add payload codec with permissions and data compartments
This commit is contained in:
parent
5a7dfddc38
commit
71483982ca
2 changed files with 291 additions and 0 deletions
244
codec_test.go
Normal file
244
codec_test.go
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
package ficha
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPayloadRoundtripFull(t *testing.T) {
|
||||
type meta struct {
|
||||
UserID string `json:"user_id"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
}
|
||||
consumerData, err := json.Marshal(meta{UserID: "u_123", TenantID: "t_42"})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal consumer data: %v", err)
|
||||
}
|
||||
|
||||
original := payload{
|
||||
ID: "tok_abc",
|
||||
Iat: 1_700_000_000,
|
||||
Exp: 1_700_003_600,
|
||||
Permissions: []string{"orders:read", "orders:write", "admin"},
|
||||
Data: consumerData,
|
||||
}
|
||||
|
||||
encoded, err := encodePayload(original)
|
||||
if err != nil {
|
||||
t.Fatalf("encodePayload: %v", err)
|
||||
}
|
||||
if len(encoded) == 0 {
|
||||
t.Fatal("encodePayload returned empty bytes")
|
||||
}
|
||||
|
||||
decoded, err := decodePayload(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodePayload: %v", err)
|
||||
}
|
||||
|
||||
if decoded.ID != original.ID {
|
||||
t.Errorf("ID: got %q, want %q", decoded.ID, original.ID)
|
||||
}
|
||||
if decoded.Iat != original.Iat {
|
||||
t.Errorf("Iat: got %d, want %d", decoded.Iat, original.Iat)
|
||||
}
|
||||
if decoded.Exp != original.Exp {
|
||||
t.Errorf("Exp: got %d, want %d", decoded.Exp, original.Exp)
|
||||
}
|
||||
|
||||
if len(decoded.Permissions) != len(original.Permissions) {
|
||||
t.Fatalf("Permissions length: got %d, want %d",
|
||||
len(decoded.Permissions), len(original.Permissions))
|
||||
}
|
||||
for i, p := range original.Permissions {
|
||||
if decoded.Permissions[i] != p {
|
||||
t.Errorf("Permissions[%d]: got %q, want %q", i, decoded.Permissions[i], p)
|
||||
}
|
||||
}
|
||||
|
||||
var got meta
|
||||
if err := json.Unmarshal(decoded.Data, &got); err != nil {
|
||||
t.Fatalf("unmarshal consumer data: %v", err)
|
||||
}
|
||||
if got.UserID != "u_123" || got.TenantID != "t_42" {
|
||||
t.Errorf("Data: got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadRoundtripPermissionsOnly(t *testing.T) {
|
||||
original := payload{
|
||||
ID: "tok_perms",
|
||||
Iat: 1_700_000_000,
|
||||
Exp: 1_700_003_600,
|
||||
Permissions: []string{"read", "write"},
|
||||
}
|
||||
|
||||
encoded, err := encodePayload(original)
|
||||
if err != nil {
|
||||
t.Fatalf("encodePayload: %v", err)
|
||||
}
|
||||
|
||||
// omitempty should keep "data" out of the encoded JSON.
|
||||
if contains := bytesContain(encoded, `"data"`); contains {
|
||||
t.Errorf("encoded payload should omit empty Data field, got: %s", encoded)
|
||||
}
|
||||
|
||||
decoded, err := decodePayload(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodePayload: %v", err)
|
||||
}
|
||||
if len(decoded.Permissions) != 2 {
|
||||
t.Errorf("Permissions length: got %d, want 2", len(decoded.Permissions))
|
||||
}
|
||||
if decoded.Data != nil {
|
||||
t.Errorf("Data should be nil, got %s", decoded.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadRoundtripDataOnly(t *testing.T) {
|
||||
type meta struct {
|
||||
Note string `json:"note"`
|
||||
}
|
||||
consumerData, _ := json.Marshal(meta{Note: "hello"})
|
||||
|
||||
original := payload{
|
||||
ID: "tok_data",
|
||||
Iat: 1_700_000_000,
|
||||
Exp: 1_700_003_600,
|
||||
Data: consumerData,
|
||||
}
|
||||
|
||||
encoded, err := encodePayload(original)
|
||||
if err != nil {
|
||||
t.Fatalf("encodePayload: %v", err)
|
||||
}
|
||||
|
||||
// omitempty should keep "perms" out of the encoded JSON.
|
||||
if contains := bytesContain(encoded, `"perms"`); contains {
|
||||
t.Errorf("encoded payload should omit empty Permissions field, got: %s", encoded)
|
||||
}
|
||||
|
||||
decoded, err := decodePayload(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodePayload: %v", err)
|
||||
}
|
||||
if len(decoded.Permissions) != 0 {
|
||||
t.Errorf("Permissions should be empty, got %v", decoded.Permissions)
|
||||
}
|
||||
if len(decoded.Data) == 0 {
|
||||
t.Error("Data should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadRoundtripMinimal(t *testing.T) {
|
||||
// A token with no permissions and no data is unusual but valid —
|
||||
// e.g., an "I am authenticated" token that just proves identity.
|
||||
original := payload{
|
||||
ID: "tok_minimal",
|
||||
Iat: 1_700_000_000,
|
||||
Exp: 1_700_003_600,
|
||||
}
|
||||
|
||||
encoded, err := encodePayload(original)
|
||||
if err != nil {
|
||||
t.Fatalf("encodePayload: %v", err)
|
||||
}
|
||||
|
||||
decoded, err := decodePayload(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodePayload: %v", err)
|
||||
}
|
||||
if decoded.ID != original.ID {
|
||||
t.Errorf("ID: got %q, want %q", decoded.ID, original.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exp int64
|
||||
now time.Time
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "future expiry is not expired",
|
||||
exp: time.Now().Add(1 * time.Hour).Unix(),
|
||||
now: time.Now(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "past expiry is expired",
|
||||
exp: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
now: time.Now(),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exactly at expiry is expired (conservative)",
|
||||
exp: 1_700_000_000,
|
||||
now: time.Unix(1_700_000_000, 0),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "one second before expiry is not expired",
|
||||
exp: 1_700_000_000,
|
||||
now: time.Unix(1_699_999_999, 0),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := payload{Exp: tc.exp}
|
||||
if got := p.expired(tc.now); got != tc.want {
|
||||
t.Errorf("expired() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodePayloadMalformed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
}{
|
||||
{"empty", []byte{}},
|
||||
{"not json", []byte("hello world")},
|
||||
{"truncated", []byte(`{"id":"abc","iat":`)},
|
||||
{"wrong type for iat", []byte(`{"id":"abc","iat":"notanumber","exp":1}`)},
|
||||
{"perms wrong type", []byte(`{"id":"abc","iat":1,"exp":2,"perms":"notalist"}`)},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if _, err := decodePayload(tc.input); err == nil {
|
||||
t.Errorf("expected error for input %q, got nil", tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// bytesContain is a small helper for checking JSON contents in tests.
|
||||
func bytesContain(haystack []byte, needle string) bool {
|
||||
return len(needle) > 0 && bytesIndex(haystack, []byte(needle)) >= 0
|
||||
}
|
||||
|
||||
func bytesIndex(s, sep []byte) int {
|
||||
n := len(sep)
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
for i := 0; i+n <= len(s); i++ {
|
||||
match := true
|
||||
for j := range n {
|
||||
if s[i+j] != sep[j] {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue