diff --git a/codec.go b/codec.go new file mode 100644 index 0000000..63897d9 --- /dev/null +++ b/codec.go @@ -0,0 +1,47 @@ +package ficha + +import ( + "encoding/json" + "time" +) + +// payload is the internal structure that gets encrypted into a token. +// It has two compartments for consumer use: +// +// - Permissions: a flat list of string identifiers that ficha understands +// and provides check methods for (Has, HasAll, HasAny, etc.). +// Consumers choose any string convention they like (e.g. "orders:read", +// "admin", "team:42:write"). +// +// - Data: a freeform JSON blob that ficha treats as opaque bytes. +// Consumers marshal/unmarshal it into their own types. +// +// The payload itself is unexported because consumers never construct one directly. +type payload struct { + ID string `json:"id"` + Iat int64 `json:"iat"` + Exp int64 `json:"exp"` + Permissions []string `json:"perms,omitempty"` + Data json.RawMessage `json:"data,omitempty"` +} + +// encodePayload serializes a payload to bytes ready for encryption. +func encodePayload(p payload) ([]byte, error) { + return json.Marshal(p) +} + +// decodePayload deserializes bytes (post-decryption) back into a payload. +func decodePayload(b []byte) (payload, error) { + var p payload + if err := json.Unmarshal(b, &p); err != nil { + return payload{}, err + } + return p, nil +} + +// expired reports whether the payload's expiry has passed at the given time. +// Uses >= so a token expiring exactly at `now` is considered expired +// (the conservative choice). +func (p payload) expired(now time.Time) bool { + return now.Unix() >= p.Exp +} diff --git a/codec_test.go b/codec_test.go new file mode 100644 index 0000000..6e8560b --- /dev/null +++ b/codec_test.go @@ -0,0 +1,244 @@ +package ficha + +import ( + "encoding/json" + "testing" + "time" +) + +func TestPayloadRoundtripFull(t *testing.T) { + type meta struct { + UserID string `json:"user_id"` + TenantID string `json:"tenant_id"` + } + consumerData, err := json.Marshal(meta{UserID: "u_123", TenantID: "t_42"}) + if err != nil { + t.Fatalf("marshal consumer data: %v", err) + } + + original := payload{ + ID: "tok_abc", + Iat: 1_700_000_000, + Exp: 1_700_003_600, + Permissions: []string{"orders:read", "orders:write", "admin"}, + Data: consumerData, + } + + encoded, err := encodePayload(original) + if err != nil { + t.Fatalf("encodePayload: %v", err) + } + if len(encoded) == 0 { + t.Fatal("encodePayload returned empty bytes") + } + + decoded, err := decodePayload(encoded) + if err != nil { + t.Fatalf("decodePayload: %v", err) + } + + if decoded.ID != original.ID { + t.Errorf("ID: got %q, want %q", decoded.ID, original.ID) + } + if decoded.Iat != original.Iat { + t.Errorf("Iat: got %d, want %d", decoded.Iat, original.Iat) + } + if decoded.Exp != original.Exp { + t.Errorf("Exp: got %d, want %d", decoded.Exp, original.Exp) + } + + if len(decoded.Permissions) != len(original.Permissions) { + t.Fatalf("Permissions length: got %d, want %d", + len(decoded.Permissions), len(original.Permissions)) + } + for i, p := range original.Permissions { + if decoded.Permissions[i] != p { + t.Errorf("Permissions[%d]: got %q, want %q", i, decoded.Permissions[i], p) + } + } + + var got meta + if err := json.Unmarshal(decoded.Data, &got); err != nil { + t.Fatalf("unmarshal consumer data: %v", err) + } + if got.UserID != "u_123" || got.TenantID != "t_42" { + t.Errorf("Data: got %+v", got) + } +} + +func TestPayloadRoundtripPermissionsOnly(t *testing.T) { + original := payload{ + ID: "tok_perms", + Iat: 1_700_000_000, + Exp: 1_700_003_600, + Permissions: []string{"read", "write"}, + } + + encoded, err := encodePayload(original) + if err != nil { + t.Fatalf("encodePayload: %v", err) + } + + // omitempty should keep "data" out of the encoded JSON. + if contains := bytesContain(encoded, `"data"`); contains { + t.Errorf("encoded payload should omit empty Data field, got: %s", encoded) + } + + decoded, err := decodePayload(encoded) + if err != nil { + t.Fatalf("decodePayload: %v", err) + } + if len(decoded.Permissions) != 2 { + t.Errorf("Permissions length: got %d, want 2", len(decoded.Permissions)) + } + if decoded.Data != nil { + t.Errorf("Data should be nil, got %s", decoded.Data) + } +} + +func TestPayloadRoundtripDataOnly(t *testing.T) { + type meta struct { + Note string `json:"note"` + } + consumerData, _ := json.Marshal(meta{Note: "hello"}) + + original := payload{ + ID: "tok_data", + Iat: 1_700_000_000, + Exp: 1_700_003_600, + Data: consumerData, + } + + encoded, err := encodePayload(original) + if err != nil { + t.Fatalf("encodePayload: %v", err) + } + + // omitempty should keep "perms" out of the encoded JSON. + if contains := bytesContain(encoded, `"perms"`); contains { + t.Errorf("encoded payload should omit empty Permissions field, got: %s", encoded) + } + + decoded, err := decodePayload(encoded) + if err != nil { + t.Fatalf("decodePayload: %v", err) + } + if len(decoded.Permissions) != 0 { + t.Errorf("Permissions should be empty, got %v", decoded.Permissions) + } + if len(decoded.Data) == 0 { + t.Error("Data should not be empty") + } +} + +func TestPayloadRoundtripMinimal(t *testing.T) { + // A token with no permissions and no data is unusual but valid — + // e.g., an "I am authenticated" token that just proves identity. + original := payload{ + ID: "tok_minimal", + Iat: 1_700_000_000, + Exp: 1_700_003_600, + } + + encoded, err := encodePayload(original) + if err != nil { + t.Fatalf("encodePayload: %v", err) + } + + decoded, err := decodePayload(encoded) + if err != nil { + t.Fatalf("decodePayload: %v", err) + } + if decoded.ID != original.ID { + t.Errorf("ID: got %q, want %q", decoded.ID, original.ID) + } +} + +func TestPayloadExpired(t *testing.T) { + tests := []struct { + name string + exp int64 + now time.Time + want bool + }{ + { + name: "future expiry is not expired", + exp: time.Now().Add(1 * time.Hour).Unix(), + now: time.Now(), + want: false, + }, + { + name: "past expiry is expired", + exp: time.Now().Add(-1 * time.Hour).Unix(), + now: time.Now(), + want: true, + }, + { + name: "exactly at expiry is expired (conservative)", + exp: 1_700_000_000, + now: time.Unix(1_700_000_000, 0), + want: true, + }, + { + name: "one second before expiry is not expired", + exp: 1_700_000_000, + now: time.Unix(1_699_999_999, 0), + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := payload{Exp: tc.exp} + if got := p.expired(tc.now); got != tc.want { + t.Errorf("expired() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestDecodePayloadMalformed(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + {"empty", []byte{}}, + {"not json", []byte("hello world")}, + {"truncated", []byte(`{"id":"abc","iat":`)}, + {"wrong type for iat", []byte(`{"id":"abc","iat":"notanumber","exp":1}`)}, + {"perms wrong type", []byte(`{"id":"abc","iat":1,"exp":2,"perms":"notalist"}`)}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if _, err := decodePayload(tc.input); err == nil { + t.Errorf("expected error for input %q, got nil", tc.input) + } + }) + } +} + +// bytesContain is a small helper for checking JSON contents in tests. +func bytesContain(haystack []byte, needle string) bool { + return len(needle) > 0 && bytesIndex(haystack, []byte(needle)) >= 0 +} + +func bytesIndex(s, sep []byte) int { + n := len(sep) + if n == 0 { + return 0 + } + for i := 0; i+n <= len(s); i++ { + match := true + for j := range n { + if s[i+j] != sep[j] { + match = false + break + } + } + if match { + return i + } + } + return -1 +}