244 lines
5.7 KiB
Go
244 lines
5.7 KiB
Go
package ficha
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"testing"
	"time"
)
|
|
|
|
func TestPayloadRoundtripFull(t *testing.T) {
|
|
type meta struct {
|
|
UserID string `json:"user_id"`
|
|
TenantID string `json:"tenant_id"`
|
|
}
|
|
consumerData, err := json.Marshal(meta{UserID: "u_123", TenantID: "t_42"})
|
|
if err != nil {
|
|
t.Fatalf("marshal consumer data: %v", err)
|
|
}
|
|
|
|
original := payload{
|
|
ID: "tok_abc",
|
|
Iat: 1_700_000_000,
|
|
Exp: 1_700_003_600,
|
|
Permissions: []string{"orders:read", "orders:write", "admin"},
|
|
Data: consumerData,
|
|
}
|
|
|
|
encoded, err := encodePayload(original)
|
|
if err != nil {
|
|
t.Fatalf("encodePayload: %v", err)
|
|
}
|
|
if len(encoded) == 0 {
|
|
t.Fatal("encodePayload returned empty bytes")
|
|
}
|
|
|
|
decoded, err := decodePayload(encoded)
|
|
if err != nil {
|
|
t.Fatalf("decodePayload: %v", err)
|
|
}
|
|
|
|
if decoded.ID != original.ID {
|
|
t.Errorf("ID: got %q, want %q", decoded.ID, original.ID)
|
|
}
|
|
if decoded.Iat != original.Iat {
|
|
t.Errorf("Iat: got %d, want %d", decoded.Iat, original.Iat)
|
|
}
|
|
if decoded.Exp != original.Exp {
|
|
t.Errorf("Exp: got %d, want %d", decoded.Exp, original.Exp)
|
|
}
|
|
|
|
if len(decoded.Permissions) != len(original.Permissions) {
|
|
t.Fatalf("Permissions length: got %d, want %d",
|
|
len(decoded.Permissions), len(original.Permissions))
|
|
}
|
|
for i, p := range original.Permissions {
|
|
if decoded.Permissions[i] != p {
|
|
t.Errorf("Permissions[%d]: got %q, want %q", i, decoded.Permissions[i], p)
|
|
}
|
|
}
|
|
|
|
var got meta
|
|
if err := json.Unmarshal(decoded.Data, &got); err != nil {
|
|
t.Fatalf("unmarshal consumer data: %v", err)
|
|
}
|
|
if got.UserID != "u_123" || got.TenantID != "t_42" {
|
|
t.Errorf("Data: got %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestPayloadRoundtripPermissionsOnly(t *testing.T) {
|
|
original := payload{
|
|
ID: "tok_perms",
|
|
Iat: 1_700_000_000,
|
|
Exp: 1_700_003_600,
|
|
Permissions: []string{"read", "write"},
|
|
}
|
|
|
|
encoded, err := encodePayload(original)
|
|
if err != nil {
|
|
t.Fatalf("encodePayload: %v", err)
|
|
}
|
|
|
|
// omitempty should keep "data" out of the encoded JSON.
|
|
if contains := bytesContain(encoded, `"data"`); contains {
|
|
t.Errorf("encoded payload should omit empty Data field, got: %s", encoded)
|
|
}
|
|
|
|
decoded, err := decodePayload(encoded)
|
|
if err != nil {
|
|
t.Fatalf("decodePayload: %v", err)
|
|
}
|
|
if len(decoded.Permissions) != 2 {
|
|
t.Errorf("Permissions length: got %d, want 2", len(decoded.Permissions))
|
|
}
|
|
if decoded.Data != nil {
|
|
t.Errorf("Data should be nil, got %s", decoded.Data)
|
|
}
|
|
}
|
|
|
|
func TestPayloadRoundtripDataOnly(t *testing.T) {
|
|
type meta struct {
|
|
Note string `json:"note"`
|
|
}
|
|
consumerData, _ := json.Marshal(meta{Note: "hello"})
|
|
|
|
original := payload{
|
|
ID: "tok_data",
|
|
Iat: 1_700_000_000,
|
|
Exp: 1_700_003_600,
|
|
Data: consumerData,
|
|
}
|
|
|
|
encoded, err := encodePayload(original)
|
|
if err != nil {
|
|
t.Fatalf("encodePayload: %v", err)
|
|
}
|
|
|
|
// omitempty should keep "perms" out of the encoded JSON.
|
|
if contains := bytesContain(encoded, `"perms"`); contains {
|
|
t.Errorf("encoded payload should omit empty Permissions field, got: %s", encoded)
|
|
}
|
|
|
|
decoded, err := decodePayload(encoded)
|
|
if err != nil {
|
|
t.Fatalf("decodePayload: %v", err)
|
|
}
|
|
if len(decoded.Permissions) != 0 {
|
|
t.Errorf("Permissions should be empty, got %v", decoded.Permissions)
|
|
}
|
|
if len(decoded.Data) == 0 {
|
|
t.Error("Data should not be empty")
|
|
}
|
|
}
|
|
|
|
func TestPayloadRoundtripMinimal(t *testing.T) {
|
|
// A token with no permissions and no data is unusual but valid —
|
|
// e.g., an "I am authenticated" token that just proves identity.
|
|
original := payload{
|
|
ID: "tok_minimal",
|
|
Iat: 1_700_000_000,
|
|
Exp: 1_700_003_600,
|
|
}
|
|
|
|
encoded, err := encodePayload(original)
|
|
if err != nil {
|
|
t.Fatalf("encodePayload: %v", err)
|
|
}
|
|
|
|
decoded, err := decodePayload(encoded)
|
|
if err != nil {
|
|
t.Fatalf("decodePayload: %v", err)
|
|
}
|
|
if decoded.ID != original.ID {
|
|
t.Errorf("ID: got %q, want %q", decoded.ID, original.ID)
|
|
}
|
|
}
|
|
|
|
func TestPayloadExpired(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
exp int64
|
|
now time.Time
|
|
want bool
|
|
}{
|
|
{
|
|
name: "future expiry is not expired",
|
|
exp: time.Now().Add(1 * time.Hour).Unix(),
|
|
now: time.Now(),
|
|
want: false,
|
|
},
|
|
{
|
|
name: "past expiry is expired",
|
|
exp: time.Now().Add(-1 * time.Hour).Unix(),
|
|
now: time.Now(),
|
|
want: true,
|
|
},
|
|
{
|
|
name: "exactly at expiry is expired (conservative)",
|
|
exp: 1_700_000_000,
|
|
now: time.Unix(1_700_000_000, 0),
|
|
want: true,
|
|
},
|
|
{
|
|
name: "one second before expiry is not expired",
|
|
exp: 1_700_000_000,
|
|
now: time.Unix(1_699_999_999, 0),
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
p := payload{Exp: tc.exp}
|
|
if got := p.expired(tc.now); got != tc.want {
|
|
t.Errorf("expired() = %v, want %v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDecodePayloadMalformed(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
}{
|
|
{"empty", []byte{}},
|
|
{"not json", []byte("hello world")},
|
|
{"truncated", []byte(`{"id":"abc","iat":`)},
|
|
{"wrong type for iat", []byte(`{"id":"abc","iat":"notanumber","exp":1}`)},
|
|
{"perms wrong type", []byte(`{"id":"abc","iat":1,"exp":2,"perms":"notalist"}`)},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if _, err := decodePayload(tc.input); err == nil {
|
|
t.Errorf("expected error for input %q, got nil", tc.input)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// bytesContain is a small helper for checking JSON contents in tests.
|
|
func bytesContain(haystack []byte, needle string) bool {
|
|
return len(needle) > 0 && bytesIndex(haystack, []byte(needle)) >= 0
|
|
}
|
|
|
|
// bytesIndex returns the index of the first occurrence of sep in s, or -1 if
// sep is not present. An empty sep matches at index 0.
//
// The former hand-rolled O(len(s)*len(sep)) scan is replaced by bytes.Index,
// which has identical semantics. The wrapper is kept so existing callers such
// as bytesContain are unaffected.
func bytesIndex(s, sep []byte) int {
	return bytes.Index(s, sep)
}
|