From 2859ff800ebab4ae54a16c56086c96021eecccd5 Mon Sep 17 00:00:00 2001 From: juancwu Date: Wed, 29 Apr 2026 01:33:14 +0000 Subject: [PATCH] add KeyRing for key rotation with thread-safe access --- keys.go | 140 ++++++++++++++++++++++++++++++ keys_test.go | 238 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 378 insertions(+) create mode 100644 keys.go create mode 100644 keys_test.go diff --git a/keys.go b/keys.go new file mode 100644 index 0000000..7b70621 --- /dev/null +++ b/keys.go @@ -0,0 +1,140 @@ +package ficha + +import ( + "crypto/rand" + "errors" + "fmt" + "strings" + "sync" +) + +// GenerateKey returns a fresh 32-byte key from crypto/rand. +func GenerateKey() ([]byte, error) { + k := make([]byte, keySize) + if _, err := rand.Read(k); err != nil { + return nil, fmt.Errorf("ficha: generate key: %w", err) + } + return k, nil +} + +// KeyRing holds the active signing key plus older keys still valid +// for decryption during rotation. Safe for concurrent use. +type KeyRing struct { + mu sync.RWMutex + active string + keys map[string][]byte +} + +// NewKeyRing creates a KeyRing with one key, marked active. +func NewKeyRing(activeID string, activeKey []byte) (*KeyRing, error) { + if err := validateKeyID(activeID); err != nil { + return nil, err + } + if len(activeKey) != keySize { + return nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(activeKey)) + } + + kr := &KeyRing{ + active: activeID, + keys: make(map[string][]byte), + } + kr.keys[activeID] = append([]byte(nil), activeKey...) + return kr, nil +} + +// Add registers a new key without changing the active one. +// Use this to introduce a new key before promoting it. +func (kr *KeyRing) Add(id string, key []byte) error { + if err := validateKeyID(id); err != nil { + return err + } + if len(key) != keySize { + return fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key)) + } + + kr.mu.Lock() + defer kr.mu.Unlock() + + if _, exists := kr.keys[id]; exists { + return fmt.Errorf("ficha: key id %q already exists", id) + } + kr.keys[id] = append([]byte(nil), key...) + return nil +} + +// SetActive promotes an already-registered key to active. +// New tokens will be issued under this key; existing tokens +// signed with previous keys remain valid until their expiry. +func (kr *KeyRing) SetActive(id string) error { + kr.mu.Lock() + defer kr.mu.Unlock() + + if _, ok := kr.keys[id]; !ok { + return ErrUnknownKey + } + kr.active = id + return nil +} + +// Remove deletes a key from the ring. Tokens signed with it +// will no longer validate. Cannot remove the active key. +func (kr *KeyRing) Remove(id string) error { + kr.mu.Lock() + defer kr.mu.Unlock() + + if id == kr.active { + return errors.New("ficha: cannot remove active key") + } + if _, ok := kr.keys[id]; !ok { + return ErrUnknownKey + } + delete(kr.keys, id) + return nil +} + +// Active returns the current active key ID and key bytes. +// The returned key slice is a copy and safe to retain. +func (kr *KeyRing) Active() (id string, key []byte) { + kr.mu.RLock() + defer kr.mu.RUnlock() + + k := kr.keys[kr.active] + return kr.active, append([]byte(nil), k...) +} + +// Get returns the key bytes for the given ID, or false if unknown. +// The returned key slice is a copy. +func (kr *KeyRing) Get(id string) ([]byte, bool) { + kr.mu.RLock() + defer kr.mu.RUnlock() + + k, ok := kr.keys[id] + if !ok { + return nil, false + } + return append([]byte(nil), k...), true +} + +// IDs returns all known key IDs in unspecified order. +func (kr *KeyRing) IDs() []string { + kr.mu.RLock() + defer kr.mu.RUnlock() + + ids := make([]string, 0, len(kr.keys)) + for id := range kr.keys { + ids = append(ids, id) + } + return ids +} + +// validateKeyID enforces constraints needed by the wire format: +// non-empty, no dots (used as separator), no whitespace. +func validateKeyID(id string) error { + if id == "" { + return errors.New("ficha: key id cannot be empty") + } + if strings.ContainsAny(id, ". \t\r\n") { + return errors.New("ficha: key id cannot contain dots or whitespace") + } + return nil +} diff --git a/keys_test.go b/keys_test.go new file mode 100644 index 0000000..a8fc6c9 --- /dev/null +++ b/keys_test.go @@ -0,0 +1,238 @@ +package ficha + +import ( + "bytes" + "crypto/rand" + "errors" + "sync" + "testing" +) + +func newTestKey(t *testing.T) []byte { + t.Helper() + k := make([]byte, keySize) + if _, err := rand.Read(k); err != nil { + t.Fatalf("rand: %v", err) + } + return k +} + +func TestNewKeyRing(t *testing.T) { + k := newTestKey(t) + kr, err := NewKeyRing("k1", k) + if err != nil { + t.Fatalf("NewKeyRing: %v", err) + } + + id, got := kr.Active() + if id != "k1" { + t.Errorf("active id: got %q, want %q", id, "k1") + } + if !bytes.Equal(got, k) { + t.Errorf("active key bytes mismatch") + } +} + +func TestNewKeyRingValidation(t *testing.T) { + tests := []struct { + name string + id string + key []byte + }{ + {"empty id", "", make([]byte, keySize)}, + {"id with dot", "k.1", make([]byte, keySize)}, + {"id with space", "k 1", make([]byte, keySize)}, + {"id with tab", "k\t1", make([]byte, keySize)}, + {"key too short", "k1", make([]byte, 16)}, + {"key too long", "k1", make([]byte, 64)}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if _, err := NewKeyRing(tc.id, tc.key); err == nil { + t.Errorf("expected error, got nil") + } + }) + } +} + +func TestKeyRingAddAndGet(t *testing.T) { + k1 := newTestKey(t) + k2 := newTestKey(t) + + kr, err := NewKeyRing("k1", k1) + if err != nil { + t.Fatalf("NewKeyRing: %v", err) + } + + if err := kr.Add("k2", k2); err != nil { + t.Fatalf("Add: %v", err) + } + + got, ok := kr.Get("k2") + if !ok { + t.Fatal("Get(k2): not found") + } + if !bytes.Equal(got, k2) { + t.Error("Get(k2): bytes mismatch") + } + + // Active should still be k1. + if id, _ := kr.Active(); id != "k1" { + t.Errorf("active should still be k1, got %q", id) + } +} + +func TestKeyRingAddDuplicate(t *testing.T) { + kr, _ := NewKeyRing("k1", newTestKey(t)) + if err := kr.Add("k1", newTestKey(t)); err == nil { + t.Error("expected error adding duplicate id, got nil") + } +} + +func TestKeyRingSetActive(t *testing.T) { + kr, _ := NewKeyRing("k1", newTestKey(t)) + k2 := newTestKey(t) + if err := kr.Add("k2", k2); err != nil { + t.Fatalf("Add: %v", err) + } + + if err := kr.SetActive("k2"); err != nil { + t.Fatalf("SetActive: %v", err) + } + + id, key := kr.Active() + if id != "k2" { + t.Errorf("active id: got %q, want %q", id, "k2") + } + if !bytes.Equal(key, k2) { + t.Error("active key bytes mismatch") + } +} + +func TestKeyRingSetActiveUnknown(t *testing.T) { + kr, _ := NewKeyRing("k1", newTestKey(t)) + if err := kr.SetActive("nope"); !errors.Is(err, ErrUnknownKey) { + t.Errorf("expected ErrUnknownKey, got %v", err) + } +} + +func TestKeyRingRemove(t *testing.T) { + kr, _ := NewKeyRing("k1", newTestKey(t)) + if err := kr.Add("k2", newTestKey(t)); err != nil { + t.Fatalf("Add: %v", err) + } + + // Cannot remove active. + if err := kr.Remove("k1"); err == nil { + t.Error("expected error removing active key, got nil") + } + + // Can remove non-active. + if err := kr.Remove("k2"); err != nil { + t.Fatalf("Remove: %v", err) + } + if _, ok := kr.Get("k2"); ok { + t.Error("k2 should be gone") + } + + // Removing unknown. + if err := kr.Remove("ghost"); !errors.Is(err, ErrUnknownKey) { + t.Errorf("expected ErrUnknownKey, got %v", err) + } +} + +func TestKeyRingDefensiveCopy(t *testing.T) { + original := newTestKey(t) + kr, err := NewKeyRing("k1", original) + if err != nil { + t.Fatalf("NewKeyRing: %v", err) + } + + // Mutate the caller's slice — the ring should be unaffected. + original[0] ^= 0xFF + + _, got := kr.Active() + if got[0] == original[0] { + t.Error("KeyRing did not defensively copy on construction") + } + + // Mutate the returned slice — the ring should be unaffected. + got[0] ^= 0xFF + _, again := kr.Active() + if again[0] == got[0] { + t.Error("KeyRing did not defensively copy on Active()") + } +} + +func TestKeyRingIDs(t *testing.T) { + kr, _ := NewKeyRing("k1", newTestKey(t)) + if err := kr.Add("k2", newTestKey(t)); err != nil { + t.Fatalf("Add: %v", err) + } + if err := kr.Add("k3", newTestKey(t)); err != nil { + t.Fatalf("Add: %v", err) + } + + ids := kr.IDs() + if len(ids) != 3 { + t.Errorf("IDs length: got %d, want 3", len(ids)) + } + + want := map[string]bool{"k1": true, "k2": true, "k3": true} + for _, id := range ids { + if !want[id] { + t.Errorf("unexpected id: %q", id) + } + } +} + +func TestKeyRingConcurrent(t *testing.T) { + // Hammer the ring from many goroutines to flush out race issues. + // Run with: go test -race + kr, _ := NewKeyRing("k1", newTestKey(t)) + if err := kr.Add("k2", newTestKey(t)); err != nil { + t.Fatalf("Add: %v", err) + } + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + kr.Active() + kr.Get("k2") + kr.IDs() + } + }() + } + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 50; j++ { + _ = kr.SetActive("k1") + _ = kr.SetActive("k2") + } + }() + } + wg.Wait() +} + +func TestGenerateKey(t *testing.T) { + k1, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + if len(k1) != keySize { + t.Errorf("length: got %d, want %d", len(k1), keySize) + } + + k2, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + if bytes.Equal(k1, k2) { + t.Error("two generated keys should not be equal") + } +}