From 5288df0b9e1237c97f4556f11fd42dc62d4a2c41 Mon Sep 17 00:00:00 2001 From: juancwu Date: Wed, 29 Apr 2026 01:40:35 +0000 Subject: [PATCH] add RevocationStore interface and in-memory implementation --- revocation.go | 86 +++++++++++++++++++++++++++ revocation_test.go | 145 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 revocation.go create mode 100644 revocation_test.go diff --git a/revocation.go b/revocation.go new file mode 100644 index 0000000..10c63c2 --- /dev/null +++ b/revocation.go @@ -0,0 +1,86 @@ +package ficha + +import ( + "context" + "sync" + "time" +) + +// RevocationStore is the interface ficha uses to track revoked tokens. +// Implementations are responsible for storage; ficha provides only the +// in-memory reference implementation below. +// +// Implementations must be safe for concurrent use. +type RevocationStore interface { + // IsRevoked reports whether tokenID has been revoked. + IsRevoked(ctx context.Context, tokenID string) (bool, error) + + // Revoke marks tokenID as revoked. The until parameter is the + // token's natural expiry — implementations may discard the entry + // after that time, since expired tokens fail validation anyway. + Revoke(ctx context.Context, tokenID string, until time.Time) error +} + +// MemoryRevocationStore is an in-memory RevocationStore suitable for +// tests, single-process deployments, or as a reference implementation. +// Not suitable for production multi-server use — entries are not shared. +type MemoryRevocationStore struct { + mu sync.RWMutex + revoked map[string]time.Time + now func() time.Time +} + +// NewMemoryRevocationStore returns an empty in-memory store. +func NewMemoryRevocationStore() *MemoryRevocationStore { + return &MemoryRevocationStore{ + revoked: make(map[string]time.Time), + now: time.Now, + } +} + +func (m *MemoryRevocationStore) IsRevoked(_ context.Context, tokenID string) (bool, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + until, ok := m.revoked[tokenID] + if !ok { + return false, nil + } + if !m.now().Before(until) { + // Past expiry — would fail validation regardless. Treat as not revoked. + return false, nil + } + return true, nil +} + +func (m *MemoryRevocationStore) Revoke(_ context.Context, tokenID string, until time.Time) error { + m.mu.Lock() + defer m.mu.Unlock() + m.revoked[tokenID] = until + return nil +} + +// Cleanup removes expired entries. Call periodically to bound memory use. +// Returns the number of entries removed. +func (m *MemoryRevocationStore) Cleanup() int { + m.mu.Lock() + defer m.mu.Unlock() + + now := m.now() + removed := 0 + for id, until := range m.revoked { + if !now.Before(until) { + delete(m.revoked, id) + removed++ + } + } + return removed +} + +// Len returns the current number of tracked entries (including any not +// yet cleaned up). Mainly useful for tests and metrics. +func (m *MemoryRevocationStore) Len() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.revoked) +} diff --git a/revocation_test.go b/revocation_test.go new file mode 100644 index 0000000..b6bffeb --- /dev/null +++ b/revocation_test.go @@ -0,0 +1,145 @@ +package ficha + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestMemoryRevocationStoreEmpty(t *testing.T) { + s := NewMemoryRevocationStore() + revoked, err := s.IsRevoked(context.Background(), "anything") + if err != nil { + t.Fatalf("IsRevoked: %v", err) + } + if revoked { + t.Error("empty store should not report any token as revoked") + } + if s.Len() != 0 { + t.Errorf("Len: got %d, want 0", s.Len()) + } +} + +func TestMemoryRevocationStoreRevokeAndCheck(t *testing.T) { + s := NewMemoryRevocationStore() + ctx := context.Background() + until := time.Now().Add(1 * time.Hour) + + if err := s.Revoke(ctx, "tok_abc", until); err != nil { + t.Fatalf("Revoke: %v", err) + } + + revoked, err := s.IsRevoked(ctx, "tok_abc") + if err != nil { + t.Fatalf("IsRevoked: %v", err) + } + if !revoked { + t.Error("token should be revoked") + } + + // Different ID should not be flagged. + revoked, _ = s.IsRevoked(ctx, "tok_other") + if revoked { + t.Error("unrelated token should not be revoked") + } +} + +func TestMemoryRevocationStoreExpiredEntry(t *testing.T) { + s := NewMemoryRevocationStore() + // Inject a controllable clock. + clock := time.Unix(1_700_000_000, 0) + s.now = func() time.Time { return clock } + + ctx := context.Background() + until := time.Unix(1_700_000_500, 0) + if err := s.Revoke(ctx, "tok_abc", until); err != nil { + t.Fatalf("Revoke: %v", err) + } + + // Before expiry: revoked. + revoked, _ := s.IsRevoked(ctx, "tok_abc") + if !revoked { + t.Error("should be revoked before expiry") + } + + // Advance clock past expiry. + clock = time.Unix(1_700_000_600, 0) + + revoked, _ = s.IsRevoked(ctx, "tok_abc") + if revoked { + t.Error("expired entry should not be reported as revoked") + } +} + +func TestMemoryRevocationStoreCleanup(t *testing.T) { + s := NewMemoryRevocationStore() + clock := time.Unix(1_700_000_000, 0) + s.now = func() time.Time { return clock } + + ctx := context.Background() + _ = s.Revoke(ctx, "expired1", time.Unix(1_700_000_100, 0)) + _ = s.Revoke(ctx, "expired2", time.Unix(1_700_000_200, 0)) + _ = s.Revoke(ctx, "stillvalid", time.Unix(1_700_000_999, 0)) + + if s.Len() != 3 { + t.Errorf("Len before cleanup: got %d, want 3", s.Len()) + } + + // Move clock past the first two expiries but not the third. + clock = time.Unix(1_700_000_500, 0) + + removed := s.Cleanup() + if removed != 2 { + t.Errorf("Cleanup removed: got %d, want 2", removed) + } + if s.Len() != 1 { + t.Errorf("Len after cleanup: got %d, want 1", s.Len()) + } + + // The remaining one is still flagged. + revoked, _ := s.IsRevoked(ctx, "stillvalid") + if !revoked { + t.Error("non-expired entry should survive cleanup") + } +} + +func TestMemoryRevocationStoreReRevoke(t *testing.T) { + // Revoking the same ID twice should be idempotent. + s := NewMemoryRevocationStore() + ctx := context.Background() + until := time.Now().Add(1 * time.Hour) + + if err := s.Revoke(ctx, "tok", until); err != nil { + t.Fatalf("Revoke 1: %v", err) + } + if err := s.Revoke(ctx, "tok", until.Add(1*time.Hour)); err != nil { + t.Fatalf("Revoke 2: %v", err) + } + if s.Len() != 1 { + t.Errorf("Len: got %d, want 1", s.Len()) + } +} + +func TestMemoryRevocationStoreConcurrent(t *testing.T) { + s := NewMemoryRevocationStore() + ctx := context.Background() + until := time.Now().Add(1 * time.Hour) + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + id := "tok_" + string(rune('a'+i%26)) + for j := 0; j < 100; j++ { + _ = s.Revoke(ctx, id, until) + _, _ = s.IsRevoked(ctx, id) + } + }(i) + } + wg.Wait() +} + +// Compile-time check: MemoryRevocationStore satisfies RevocationStore. +var _ RevocationStore = (*MemoryRevocationStore)(nil)