add KeyRing for key rotation with thread-safe access
This commit is contained in:
parent
a6aad1a1d6
commit
2859ff800e
2 changed files with 378 additions and 0 deletions
140
keys.go
Normal file
140
keys.go
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
package ficha
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// GenerateKey returns a fresh 32-byte key from crypto/rand.
|
||||
func GenerateKey() ([]byte, error) {
|
||||
k := make([]byte, keySize)
|
||||
if _, err := rand.Read(k); err != nil {
|
||||
return nil, fmt.Errorf("ficha: generate key: %w", err)
|
||||
}
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// KeyRing holds the active signing key plus older keys still valid
|
||||
// for decryption during rotation. Safe for concurrent use.
|
||||
type KeyRing struct {
|
||||
mu sync.RWMutex
|
||||
active string
|
||||
keys map[string][]byte
|
||||
}
|
||||
|
||||
// NewKeyRing creates a KeyRing with one key, marked active.
|
||||
func NewKeyRing(activeID string, activeKey []byte) (*KeyRing, error) {
|
||||
if err := validateKeyID(activeID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(activeKey) != keySize {
|
||||
return nil, fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(activeKey))
|
||||
}
|
||||
|
||||
kr := &KeyRing{
|
||||
active: activeID,
|
||||
keys: make(map[string][]byte),
|
||||
}
|
||||
kr.keys[activeID] = append([]byte(nil), activeKey...)
|
||||
return kr, nil
|
||||
}
|
||||
|
||||
// Add registers a new key without changing the active one.
|
||||
// Use this to introduce a new key before promoting it.
|
||||
func (kr *KeyRing) Add(id string, key []byte) error {
|
||||
if err := validateKeyID(id); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(key) != keySize {
|
||||
return fmt.Errorf("ficha: key must be %d bytes, got %d", keySize, len(key))
|
||||
}
|
||||
|
||||
kr.mu.Lock()
|
||||
defer kr.mu.Unlock()
|
||||
|
||||
if _, exists := kr.keys[id]; exists {
|
||||
return fmt.Errorf("ficha: key id %q already exists", id)
|
||||
}
|
||||
kr.keys[id] = append([]byte(nil), key...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetActive promotes an already-registered key to active.
|
||||
// New tokens will be issued under this key; existing tokens
|
||||
// signed with previous keys remain valid until their expiry.
|
||||
func (kr *KeyRing) SetActive(id string) error {
|
||||
kr.mu.Lock()
|
||||
defer kr.mu.Unlock()
|
||||
|
||||
if _, ok := kr.keys[id]; !ok {
|
||||
return ErrUnknownKey
|
||||
}
|
||||
kr.active = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes a key from the ring. Tokens signed with it
|
||||
// will no longer validate. Cannot remove the active key.
|
||||
func (kr *KeyRing) Remove(id string) error {
|
||||
kr.mu.Lock()
|
||||
defer kr.mu.Unlock()
|
||||
|
||||
if id == kr.active {
|
||||
return errors.New("ficha: cannot remove active key")
|
||||
}
|
||||
if _, ok := kr.keys[id]; !ok {
|
||||
return ErrUnknownKey
|
||||
}
|
||||
delete(kr.keys, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Active returns the current active key ID and key bytes.
|
||||
// The returned key slice is a copy and safe to retain.
|
||||
func (kr *KeyRing) Active() (id string, key []byte) {
|
||||
kr.mu.RLock()
|
||||
defer kr.mu.RUnlock()
|
||||
|
||||
k := kr.keys[kr.active]
|
||||
return kr.active, append([]byte(nil), k...)
|
||||
}
|
||||
|
||||
// Get returns the key bytes for the given ID, or false if unknown.
|
||||
// The returned key slice is a copy.
|
||||
func (kr *KeyRing) Get(id string) ([]byte, bool) {
|
||||
kr.mu.RLock()
|
||||
defer kr.mu.RUnlock()
|
||||
|
||||
k, ok := kr.keys[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return append([]byte(nil), k...), true
|
||||
}
|
||||
|
||||
// IDs returns all known key IDs in unspecified order.
|
||||
func (kr *KeyRing) IDs() []string {
|
||||
kr.mu.RLock()
|
||||
defer kr.mu.RUnlock()
|
||||
|
||||
ids := make([]string, 0, len(kr.keys))
|
||||
for id := range kr.keys {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// validateKeyID enforces constraints needed by the wire format:
|
||||
// non-empty, no dots (used as separator), no whitespace.
|
||||
func validateKeyID(id string) error {
|
||||
if id == "" {
|
||||
return errors.New("ficha: key id cannot be empty")
|
||||
}
|
||||
if strings.ContainsAny(id, ". \t\r\n") {
|
||||
return errors.New("ficha: key id cannot contain dots or whitespace")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
238
keys_test.go
Normal file
238
keys_test.go
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
package ficha
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestKey(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
k := make([]byte, keySize)
|
||||
if _, err := rand.Read(k); err != nil {
|
||||
t.Fatalf("rand: %v", err)
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
func TestNewKeyRing(t *testing.T) {
|
||||
k := newTestKey(t)
|
||||
kr, err := NewKeyRing("k1", k)
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyRing: %v", err)
|
||||
}
|
||||
|
||||
id, got := kr.Active()
|
||||
if id != "k1" {
|
||||
t.Errorf("active id: got %q, want %q", id, "k1")
|
||||
}
|
||||
if !bytes.Equal(got, k) {
|
||||
t.Errorf("active key bytes mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewKeyRingValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
key []byte
|
||||
}{
|
||||
{"empty id", "", make([]byte, keySize)},
|
||||
{"id with dot", "k.1", make([]byte, keySize)},
|
||||
{"id with space", "k 1", make([]byte, keySize)},
|
||||
{"id with tab", "k\t1", make([]byte, keySize)},
|
||||
{"key too short", "k1", make([]byte, 16)},
|
||||
{"key too long", "k1", make([]byte, 64)},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if _, err := NewKeyRing(tc.id, tc.key); err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingAddAndGet(t *testing.T) {
|
||||
k1 := newTestKey(t)
|
||||
k2 := newTestKey(t)
|
||||
|
||||
kr, err := NewKeyRing("k1", k1)
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyRing: %v", err)
|
||||
}
|
||||
|
||||
if err := kr.Add("k2", k2); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
|
||||
got, ok := kr.Get("k2")
|
||||
if !ok {
|
||||
t.Fatal("Get(k2): not found")
|
||||
}
|
||||
if !bytes.Equal(got, k2) {
|
||||
t.Error("Get(k2): bytes mismatch")
|
||||
}
|
||||
|
||||
// Active should still be k1.
|
||||
if id, _ := kr.Active(); id != "k1" {
|
||||
t.Errorf("active should still be k1, got %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingAddDuplicate(t *testing.T) {
|
||||
kr, _ := NewKeyRing("k1", newTestKey(t))
|
||||
if err := kr.Add("k1", newTestKey(t)); err == nil {
|
||||
t.Error("expected error adding duplicate id, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingSetActive(t *testing.T) {
|
||||
kr, _ := NewKeyRing("k1", newTestKey(t))
|
||||
k2 := newTestKey(t)
|
||||
if err := kr.Add("k2", k2); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
|
||||
if err := kr.SetActive("k2"); err != nil {
|
||||
t.Fatalf("SetActive: %v", err)
|
||||
}
|
||||
|
||||
id, key := kr.Active()
|
||||
if id != "k2" {
|
||||
t.Errorf("active id: got %q, want %q", id, "k2")
|
||||
}
|
||||
if !bytes.Equal(key, k2) {
|
||||
t.Error("active key bytes mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingSetActiveUnknown(t *testing.T) {
|
||||
kr, _ := NewKeyRing("k1", newTestKey(t))
|
||||
if err := kr.SetActive("nope"); !errors.Is(err, ErrUnknownKey) {
|
||||
t.Errorf("expected ErrUnknownKey, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingRemove(t *testing.T) {
|
||||
kr, _ := NewKeyRing("k1", newTestKey(t))
|
||||
if err := kr.Add("k2", newTestKey(t)); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
|
||||
// Cannot remove active.
|
||||
if err := kr.Remove("k1"); err == nil {
|
||||
t.Error("expected error removing active key, got nil")
|
||||
}
|
||||
|
||||
// Can remove non-active.
|
||||
if err := kr.Remove("k2"); err != nil {
|
||||
t.Fatalf("Remove: %v", err)
|
||||
}
|
||||
if _, ok := kr.Get("k2"); ok {
|
||||
t.Error("k2 should be gone")
|
||||
}
|
||||
|
||||
// Removing unknown.
|
||||
if err := kr.Remove("ghost"); !errors.Is(err, ErrUnknownKey) {
|
||||
t.Errorf("expected ErrUnknownKey, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingDefensiveCopy(t *testing.T) {
|
||||
original := newTestKey(t)
|
||||
kr, err := NewKeyRing("k1", original)
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyRing: %v", err)
|
||||
}
|
||||
|
||||
// Mutate the caller's slice — the ring should be unaffected.
|
||||
original[0] ^= 0xFF
|
||||
|
||||
_, got := kr.Active()
|
||||
if got[0] == original[0] {
|
||||
t.Error("KeyRing did not defensively copy on construction")
|
||||
}
|
||||
|
||||
// Mutate the returned slice — the ring should be unaffected.
|
||||
got[0] ^= 0xFF
|
||||
_, again := kr.Active()
|
||||
if again[0] == got[0] {
|
||||
t.Error("KeyRing did not defensively copy on Active()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingIDs(t *testing.T) {
|
||||
kr, _ := NewKeyRing("k1", newTestKey(t))
|
||||
if err := kr.Add("k2", newTestKey(t)); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
if err := kr.Add("k3", newTestKey(t)); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
|
||||
ids := kr.IDs()
|
||||
if len(ids) != 3 {
|
||||
t.Errorf("IDs length: got %d, want 3", len(ids))
|
||||
}
|
||||
|
||||
want := map[string]bool{"k1": true, "k2": true, "k3": true}
|
||||
for _, id := range ids {
|
||||
if !want[id] {
|
||||
t.Errorf("unexpected id: %q", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRingConcurrent(t *testing.T) {
|
||||
// Hammer the ring from many goroutines to flush out race issues.
|
||||
// Run with: go test -race
|
||||
kr, _ := NewKeyRing("k1", newTestKey(t))
|
||||
if err := kr.Add("k2", newTestKey(t)); err != nil {
|
||||
t.Fatalf("Add: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 100; j++ {
|
||||
kr.Active()
|
||||
kr.Get("k2")
|
||||
kr.IDs()
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 50; j++ {
|
||||
_ = kr.SetActive("k1")
|
||||
_ = kr.SetActive("k2")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
k1, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
if len(k1) != keySize {
|
||||
t.Errorf("length: got %d, want %d", len(k1), keySize)
|
||||
}
|
||||
|
||||
k2, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
if bytes.Equal(k1, k2) {
|
||||
t.Error("two generated keys should not be equal")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue