From 134393fbcad686b5fc2b5e922a3143a656cd26db Mon Sep 17 00:00:00 2001 From: juancwu Date: Sun, 26 Apr 2026 01:36:53 +0000 Subject: [PATCH] authkit initial --- README.md | 344 +++++++++- Taskfile.yml | 53 ++ authkit.go | 121 ++++ doc.go | 13 + errors.go | 18 + extractor.go | 64 ++ go.mod | 20 + go.sum | 36 + hasher/argon2id.go | 136 ++++ hasher/argon2id_test.go | 78 +++ jwt.go | 69 ++ jwt_test.go | 99 +++ memstore_test.go | 619 ++++++++++++++++++ middleware/authz.go | 71 ++ middleware/context.go | 39 ++ middleware/middleware.go | 138 ++++ models.go | 75 +++ principal.go | 63 ++ service_apikey.go | 92 +++ service_authz.go | 85 +++ service_jwt.go | 147 +++++ service_magic.go | 62 ++ service_reset.go | 69 ++ service_session.go | 154 +++++ service_test.go | 220 +++++++ service_user.go | 178 +++++ sqlstore/apikeys.go | 124 ++++ sqlstore/dialect.go | 118 ++++ sqlstore/dialect/postgres/errors.go | 33 + .../dialect/postgres/migrations/0001_init.sql | 99 +++ sqlstore/dialect/postgres/postgres.go | 272 ++++++++ sqlstore/migrate.go | 116 ++++ sqlstore/rbac.go | 301 +++++++++ sqlstore/scan.go | 86 +++ sqlstore/schema.go | 78 +++ sqlstore/sessions.go | 98 +++ sqlstore/sqlstore.go | 59 ++ sqlstore/sqlstore_test.go | 258 ++++++++ sqlstore/tokens.go | 92 +++ sqlstore/users.go | 186 ++++++ stores.go | 83 +++ tokens.go | 67 ++ tokens_test.go | 56 ++ 43 files changed, 5188 insertions(+), 1 deletion(-) create mode 100644 Taskfile.yml create mode 100644 authkit.go create mode 100644 doc.go create mode 100644 errors.go create mode 100644 extractor.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 hasher/argon2id.go create mode 100644 hasher/argon2id_test.go create mode 100644 jwt.go create mode 100644 jwt_test.go create mode 100644 memstore_test.go create mode 100644 middleware/authz.go create mode 100644 middleware/context.go create mode 100644 middleware/middleware.go create mode 100644 models.go create mode 100644 principal.go create mode 100644 service_apikey.go create mode 100644 service_authz.go create mode 100644 service_jwt.go create mode 100644 service_magic.go create mode 100644 service_reset.go create mode 100644 service_session.go create mode 100644 service_test.go create mode 100644 service_user.go create mode 100644 sqlstore/apikeys.go create mode 100644 sqlstore/dialect.go create mode 100644 sqlstore/dialect/postgres/errors.go create mode 100644 sqlstore/dialect/postgres/migrations/0001_init.sql create mode 100644 sqlstore/dialect/postgres/postgres.go create mode 100644 sqlstore/migrate.go create mode 100644 sqlstore/rbac.go create mode 100644 sqlstore/scan.go create mode 100644 sqlstore/schema.go create mode 100644 sqlstore/sessions.go create mode 100644 sqlstore/sqlstore.go create mode 100644 sqlstore/sqlstore_test.go create mode 100644 sqlstore/tokens.go create mode 100644 sqlstore/users.go create mode 100644 stores.go create mode 100644 tokens.go create mode 100644 tokens_test.go diff --git a/README.md b/README.md index 4658760..7051c83 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,345 @@ # authkit -Just a concoction of auth stuff in one place. \ No newline at end of file +A pragmatic authentication and authorization toolkit for Go web services. + +`authkit` ships interfaces for users, sessions, tokens, API keys, roles, and +permissions, plus default `database/sql` Postgres implementations and +framework-neutral HTTP middleware. It supports both opaque server-side +sessions and JWT access tokens with rotating refresh tokens, hashes passwords +with Argon2id, and pairs naturally with [`lightmux`](https://git.juancwu.dev/juancwu/lightmux) +or any `net/http` stack. + +## Install + +``` +go get git.juancwu.dev/juancwu/authkit +``` + +`authkit` itself depends only on `database/sql` and the Go standard library +plus `golang-jwt`, `google/uuid`, `golang.org/x/crypto`, and `errx`. Bring +your own driver: `pgx`, `lib/pq`, or anything else that registers a +`database/sql` driver. + +```go +import _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq" +``` + +PostgreSQL 12+ is sufficient — the schema avoids `gen_random_uuid()` and +`pgcrypto` so no extensions are required. + +## What's included + +**Authentication flows** +- Email + password registration and login (Argon2id PHC-encoded hashes) +- Opaque server-side sessions with sliding TTL bounded by an absolute cap +- JWT access tokens (HS256) with rotating refresh tokens and reuse detection +- Email verification, password reset, and magic-link passwordless login + +**Authorization** +- Roles and permissions with many-to-many wiring +- API keys with custom abilities for per-endpoint scoping +- A unified `Principal` type so middleware works the same regardless of which + authentication method ran + +**Storage** +- Interfaces for every store so callers can plug in their own backends +- Default Postgres implementation built on `*sql.DB` (`sqlstore` package) +- Override table names via `Schema` without forking — useful when authkit + lives alongside an existing application schema +- A `Dialect` abstraction so future MySQL / SQLite implementations slot in + without changes to store code +- Embedded versioned migrations applied by a `Migrate(ctx, db, dialect, schema)` + helper that takes a session-scoped advisory lock + +**HTTP** +- `middleware.RequireSession`, `RequireJWT`, `RequireAPIKey`, `RequireAny` +- `middleware.RequireRole`, `RequireAnyRole`, `RequirePermission`, + `RequireAbility` +- `middleware.PrincipalFrom(ctx)` to read the authenticated principal in + handlers + +**Errors** +- Sentinel errors (`ErrEmailTaken`, `ErrInvalidCredentials`, `ErrTokenInvalid`, + `ErrTokenReused`, `ErrSessionInvalid`, `ErrAPIKeyInvalid`, + `ErrPermissionDenied`, ...) compatible with `errors.Is` +- All internal errors wrap with [`errx`](https://git.juancwu.dev/juancwu/errx) + for op tags + +## Out of scope (v1) + +MFA/TOTP, OAuth/social login, soft-delete, in-memory permission caching, +pluggable JWT signers (HS256 only), built-in HTTP handlers, MySQL/SQLite +dialects (architecture supports them; only Postgres ships in v1), and +column-name overrides in `Schema` (table-name overrides only). + +## Quick start + +### 1. Open a database and run migrations + +```go +import ( + "database/sql" + + "git.juancwu.dev/juancwu/authkit/sqlstore" + pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres" + + _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq" +) + +db, err := sql.Open("pgx", os.Getenv("DATABASE_URL")) +if err != nil { /* ... */ } +defer db.Close() + +if err := sqlstore.Migrate(ctx, db, pgdialect.New(), sqlstore.DefaultSchema()); err != nil { + log.Fatal(err) +} +``` + +`Migrate` is idempotent and safe to call from multiple processes — it takes +a session-scoped `pg_advisory_lock` to serialise rollouts. + +`sqlx` users can pass `sqlxDB.DB` (the underlying `*sql.DB`) to the same +calls — the library only cares about `*sql.DB`. + +### 2. Wire the service + +```go +import ( + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/authkit/hasher" +) + +stores, err := sqlstore.New(db, pgdialect.New(), sqlstore.DefaultSchema()) +if err != nil { /* ... */ } + +auth := authkit.New(authkit.Deps{ + Users: stores.Users, + Sessions: stores.Sessions, + Tokens: stores.Tokens, + APIKeys: stores.APIKeys, + Roles: stores.Roles, + Permissions: stores.Permissions, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), +}, authkit.Config{ + JWTSecret: []byte(os.Getenv("JWT_SECRET")), + JWTIssuer: "myapp", + SessionCookieSecure: true, + SessionCookieHTTPOnly: true, +}) +``` + +`Config` zero values fall back to sensible defaults (24h idle / 30d absolute +session TTL, 15m access tokens, 30d refresh tokens, 48h email-verify, 1h +password-reset, 15m magic-link). `JWTSecret` and the seven `Deps` fields are +required; `New` panics on a misconfiguration. + +### 3. Use the service + +```go +// Registration + password login +u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2") +u, err = auth.LoginPassword(ctx, "alice@example.com", "hunter2hunter2") + +// Opaque session (cookie-friendly) +plaintext, sess, err := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP) +http.SetCookie(w, auth.SessionCookie(plaintext, sess.ExpiresAt)) + +// JWT + rotating refresh +access, refresh, err := auth.IssueJWT(ctx, u.ID) +access, refresh, err = auth.RefreshJWT(ctx, refresh) // old refresh is consumed + +// API key with abilities +plaintext, key, err := auth.IssueAPIKey(ctx, u.ID, "ci", + []string{"billing:read", "users:list"}, nil) + +// Email verification + password reset + magic link +tok, err := auth.RequestEmailVerification(ctx, u.ID) +_, err = auth.ConfirmEmail(ctx, tok) + +tok, err = auth.RequestPasswordReset(ctx, "alice@example.com") +err = auth.ConfirmPasswordReset(ctx, tok, "new-password") + +tok, err = auth.RequestMagicLink(ctx, "alice@example.com") +u, err = auth.ConsumeMagicLink(ctx, tok) +``` + +The plaintext returned by `IssueSession`, `IssueJWT`, `IssueAPIKey`, and the +token-minting flows is **show-once** — only its SHA-256 hash is stored. Show +it to the user immediately; you cannot recover it later. + +### 4. Wire middleware + +`authkit/middleware` returns standard `func(http.Handler) http.Handler` +values, so it composes with `lightmux.Mux.Use`/`Group`/`Handle` and any +`net/http` mux that accepts the same shape. + +```go +import ( + authkitmw "git.juancwu.dev/juancwu/authkit/middleware" + "git.juancwu.dev/juancwu/lightmux" +) + +mux := lightmux.New() + +cookieAuth := authkitmw.RequireSession(authkitmw.Options{ + Auth: auth, + Extractor: authkit.ChainExtractors( + authkit.CookieExtractor("authkit_session"), + authkit.BearerExtractor(), + ), +}) + +me := mux.Group("/me", cookieAuth) +me.Get("", func(w http.ResponseWriter, r *http.Request) { + p := authkitmw.MustPrincipal(r) + json.NewEncoder(w).Encode(p) +}) + +// RBAC: stack authz on top of any auth method +admin := mux.Group("/admin", cookieAuth, authkitmw.RequireRole("admin")) + +// API-key-only route with a per-endpoint ability check +api := mux.Group("/api/v1", authkitmw.RequireAPIKey(authkitmw.Options{Auth: auth})) +api.Get("/billing", billingHandler, authkitmw.RequireAbility("billing:read")) +``` + +`Options.Extractor` defaults to `BearerExtractor`; pass `CookieExtractor` (or +chain extractors) when reading session cookies. `Options.OnUnauth` and +`Options.OnForbidden` default to a JSON `401` / `403`; override them to match +your error envelope. + +## Custom table names + +Pass a non-default `Schema` to use your own table names. Identifiers must +match `^[a-zA-Z_][a-zA-Z0-9_]*$`; anything else is rejected at `New()` and +`Migrate()` time, so SQL injection through the schema is impossible. + +```go +schema := sqlstore.DefaultSchema() +schema.Tables.Users = "accounts" +schema.Tables.APIKeys = "api_credentials" + +stores, _ := sqlstore.New(db, pgdialect.New(), schema) +``` + +The bundled migration files use the default `authkit_*` names. If you +override, you're responsible for matching DDL (most consumers with custom +naming already have their own DDL pipeline). + +Column-name overrides are not exposed in v1 — the column set is fixed for +each table. Adding column overrides later is purely additive. + +## How things work + +### Secret token format + +Sessions, refresh tokens, API keys, email-verify tokens, password-reset tokens, +and magic-link tokens all share one format: + +``` +plaintext = "_" + base64url(32 random bytes, no padding) +lookup = sha256(plaintext) +``` + +Plaintext is returned to the caller exactly once and never persisted; the +SHA-256 is the database lookup key. Random bytes come from `crypto/rand` (or +`Config.Random` for tests). + +### JWT revocation + +Access tokens carry `sv` (`session_version`) in their claims. When you call +`RevokeAllUserSessions` or `ChangePassword`, the user's `session_version` +column increments and every outstanding access token fails the next +`AuthenticateJWT`. This is the only way to invalidate a JWT before its +`exp`. + +### Refresh token rotation + +Each `RefreshJWT` consumes the presented refresh token and issues a new one +on the same chain (the `chain_id` column on `authkit_tokens`). If a +*consumed* refresh token is ever presented again — a strong replay signal — +the entire chain is deleted via `TokenStore.DeleteByChain` and the call +returns `ErrTokenReused`. + +### Sliding session TTL + +Each authenticated request via `AuthenticateSession` slides `expires_at` to +`now + Config.SessionIdleTTL`, capped at `created_at + Config.SessionAbsoluteTTL`. +Long-lived idle sessions still hit the absolute boundary. + +### Schema and migrations + +`sqlstore.Migrate` applies every embedded `.sql` file under +`sqlstore/dialect/postgres/migrations/` whose version (filename without +`.sql`) is not in `authkit_schema_migrations`. The dialect's +`AcquireMigrationLock` (Postgres uses `pg_advisory_lock`) serialises +concurrent migrators. Each migration owns its own transaction so future +migrations can use statements like `CREATE INDEX CONCURRENTLY`. + +Every default table is prefixed `authkit_` so the schema can live alongside +your application's own tables in a shared database. + +### Driver and dialect architecture + +The `sqlstore` package speaks `database/sql` only. Driver-specific behaviour +lives behind a small `Dialect` interface: + +```go +type Dialect interface { + Name() string + BuildQueries(s Schema) Queries + Bootstrap(ctx context.Context, db *sql.DB) error + AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error) + Migrations() fs.FS + IsUniqueViolation(err error) bool + Placeholder(n int) string + PlaceholderList(start, count int) string +} +``` + +v1 ships `dialect/postgres`. A future MySQL or SQLite dialect adds a new +implementation; no changes to store code. + +## Configuration reference + +| Field | Default | Notes | +|---|---|---| +| `SessionIdleTTL` | 24h | Sliding window applied on each authenticated request | +| `SessionAbsoluteTTL` | 30d | Cap from `created_at`; sliding never exceeds this | +| `SessionCookieName` | `authkit_session` | | +| `SessionCookieSameSite` | `Lax` | | +| `SessionCookieSecure` / `HTTPOnly` | `false` / `false` | Set both to `true` in production | +| `JWTSecret` | — (required) | HS256 key | +| `JWTIssuer` / `JWTAudience` | empty | When set, parser enforces them | +| `AccessTokenTTL` | 15m | | +| `RefreshTokenTTL` | 30d | | +| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | | +| `Clock` | `time.Now().UTC` | Controls every observable timestamp; override for deterministic tests | +| `Random` | `crypto/rand.Reader` | Override for deterministic tests | +| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit | + +## Implementing your own store + +Every store is a small interface with explicit semantics — see `stores.go`. +The most subtle contract is `TokenStore.ConsumeToken`: it MUST mark the token +consumed and return it in a single statement (`UPDATE ... RETURNING` on +Postgres / SQLite 3.35+) so two concurrent callers cannot both succeed. + +## Testing + +``` +go test ./... # unit tests, no DB +AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./sqlstore... # integration tests +``` + +Unit tests cover token mint/parse, Argon2id encode/verify (including +`needsRehash` on parameter change), JWT issue/parse (incl. expired, +`sv`-mismatch, refresh rotation, reuse detection), session lifecycle, email +verification, password reset cascading session invalidation, magic-link +self-verification, API keys with abilities, and RBAC role-permission +resolution. Integration tests run the full `sqlstore` contract against a +real Postgres when `AUTHKIT_TEST_DATABASE_URL` is set. + +## License + +MIT. See `LICENSE`. diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..b659f73 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,53 @@ +version: '3' + +tasks: + default: + desc: Run vet and unit tests + cmds: + - task: check + + install:tools: + desc: Install development tools (tparse for prettier test output) + cmds: + - go install github.com/mfridman/tparse@latest + + build: + desc: Build all packages + cmds: + - go build ./... + + test: + desc: Run unit tests with prettier output via tparse + cmds: + - set -o pipefail && go test ./... -json -cover | tparse -all + + test:race: + desc: Run unit tests under the race detector + cmds: + - set -o pipefail && go test ./... -race -json | tparse -all + + test:integration: + desc: Run sqlstore integration tests against a real Postgres (requires AUTHKIT_TEST_DATABASE_URL) + cmds: + - set -o pipefail && go test ./sqlstore/... -run Integration -count=1 -json | tparse -all + + vet: + desc: Run go vet + cmds: + - go vet ./... + + fmt: + desc: Format Go source files + cmds: + - gofmt -w . + + tidy: + desc: Tidy go.mod + cmds: + - go mod tidy + + check: + desc: Run vet and unit tests + cmds: + - task: vet + - task: test diff --git a/authkit.go b/authkit.go new file mode 100644 index 0000000..f23c2e7 --- /dev/null +++ b/authkit.go @@ -0,0 +1,121 @@ +package authkit + +import ( + "context" + "crypto/rand" + "io" + "net/http" + "time" + + "git.juancwu.dev/juancwu/errx" +) + +// Deps bundles every backing store and the password hasher the Auth service +// depends on. All fields are required; New panics on a nil dep so misuse is +// caught at boot rather than under load. +type Deps struct { + Users UserStore + Sessions SessionStore + Tokens TokenStore + APIKeys APIKeyStore + Roles RoleStore + Permissions PermissionStore + Hasher Hasher +} + +// Config tunes session/JWT/token TTLs, cookie shape, JWT signing material, +// and optional hooks. Any zero-valued duration is replaced with a sane +// default in New; required fields (notably JWTSecret) cause New to panic. +type Config struct { + // Session (opaque) cookies + DB-backed lifetime + SessionIdleTTL time.Duration + SessionAbsoluteTTL time.Duration + SessionCookieName string + SessionCookieDomain string + SessionCookiePath string + SessionCookieSecure bool + SessionCookieHTTPOnly bool + SessionCookieSameSite http.SameSite + + // JWT (HS256) + JWTSecret []byte + JWTIssuer string + JWTAudience string + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration + + // Single-use tokens + EmailVerifyTTL time.Duration + PasswordResetTTL time.Duration + MagicLinkTTL time.Duration + + // Hooks (optional) + Clock func() time.Time + Random io.Reader + LoginHook func(ctx context.Context, email string, success bool) error +} + +// Auth is the high-level service that composes the stores and hasher into the +// flows callers use: registration, login, sessions, JWTs, magic links, API +// keys, and authz checks. It is safe for concurrent use; method receivers +// never mutate Auth state after construction. +type Auth struct { + deps Deps + cfg Config +} + +// New validates Deps and Config, fills in defaults, and returns a ready +// service. It panics on missing deps or missing JWT secret rather than +// returning an error — these are programmer errors, not runtime ones. +func New(deps Deps, cfg Config) *Auth { + if deps.Users == nil || deps.Sessions == nil || deps.Tokens == nil || + deps.APIKeys == nil || deps.Roles == nil || deps.Permissions == nil || + deps.Hasher == nil { + panic(errx.New("authkit.New", "all Deps fields are required")) + } + if len(cfg.JWTSecret) == 0 { + panic(errx.New("authkit.New", "Config.JWTSecret is required")) + } + + if cfg.SessionIdleTTL == 0 { + cfg.SessionIdleTTL = 24 * time.Hour + } + if cfg.SessionAbsoluteTTL == 0 { + cfg.SessionAbsoluteTTL = 30 * 24 * time.Hour + } + if cfg.SessionCookieName == "" { + cfg.SessionCookieName = "authkit_session" + } + if cfg.SessionCookiePath == "" { + cfg.SessionCookiePath = "/" + } + if cfg.SessionCookieSameSite == 0 { + cfg.SessionCookieSameSite = http.SameSiteLaxMode + } + if cfg.AccessTokenTTL == 0 { + cfg.AccessTokenTTL = 15 * time.Minute + } + if cfg.RefreshTokenTTL == 0 { + cfg.RefreshTokenTTL = 30 * 24 * time.Hour + } + if cfg.EmailVerifyTTL == 0 { + cfg.EmailVerifyTTL = 48 * time.Hour + } + if cfg.PasswordResetTTL == 0 { + cfg.PasswordResetTTL = time.Hour + } + if cfg.MagicLinkTTL == 0 { + cfg.MagicLinkTTL = 15 * time.Minute + } + if cfg.Clock == nil { + cfg.Clock = func() time.Time { return time.Now().UTC() } + } + if cfg.Random == nil { + cfg.Random = rand.Reader + } + + return &Auth{deps: deps, cfg: cfg} +} + +// now returns the configured wall clock, defaulting to time.Now in UTC. +func (a *Auth) now() time.Time { return a.cfg.Clock() } diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..3c49a23 --- /dev/null +++ b/doc.go @@ -0,0 +1,13 @@ +// Package authkit is an authentication and authorization toolkit for Go web +// services. It defines storage interfaces (UserStore, SessionStore, TokenStore, +// APIKeyStore, RoleStore, PermissionStore) and a high-level Auth service that +// composes them to support registration, password login, opaque server-side +// sessions, JWT access plus rotating refresh tokens, email verification, +// password resets, magic-link passwordless login, role-based access control, +// and API keys with custom abilities. +// +// Default Postgres implementations of every store live in the pgstore +// subpackage. Argon2id password hashing lives in hasher. Framework-neutral +// HTTP middleware (compatible with lightmux and any net/http stack) lives in +// middleware. +package authkit diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..a5acf10 --- /dev/null +++ b/errors.go @@ -0,0 +1,18 @@ +package authkit + +import "errors" + +var ( + ErrEmailTaken = errors.New("authkit: email already registered") + ErrUserNotFound = errors.New("authkit: user not found") + ErrInvalidCredentials = errors.New("authkit: invalid credentials") + ErrEmailNotVerified = errors.New("authkit: email not verified") + ErrTokenInvalid = errors.New("authkit: invalid or expired token") + ErrTokenReused = errors.New("authkit: token reuse detected") + ErrSessionInvalid = errors.New("authkit: invalid or expired session") + ErrAPIKeyInvalid = errors.New("authkit: invalid or expired api key") + ErrPermissionDenied = errors.New("authkit: permission denied") + ErrRoleNotFound = errors.New("authkit: role not found") + ErrPermissionNotFound = errors.New("authkit: permission not found") + ErrConfigInvalid = errors.New("authkit: invalid configuration") +) diff --git a/extractor.go b/extractor.go new file mode 100644 index 0000000..e1e0d73 --- /dev/null +++ b/extractor.go @@ -0,0 +1,64 @@ +package authkit + +import ( + "net/http" + "strings" +) + +// Extractor pulls a credential string out of an HTTP request. It returns +// (value, true) when a value was found, otherwise ("", false). +type Extractor func(r *http.Request) (string, bool) + +// BearerExtractor reads the value following "Bearer " in the Authorization +// header. Comparison is case-insensitive on the scheme. +func BearerExtractor() Extractor { + return func(r *http.Request) (string, bool) { + h := r.Header.Get("Authorization") + if h == "" { + return "", false + } + const prefix = "bearer " + if len(h) <= len(prefix) || !strings.EqualFold(h[:len(prefix)], prefix) { + return "", false + } + v := strings.TrimSpace(h[len(prefix):]) + if v == "" { + return "", false + } + return v, true + } +} + +// CookieExtractor reads the named cookie's value. +func CookieExtractor(name string) Extractor { + return func(r *http.Request) (string, bool) { + c, err := r.Cookie(name) + if err != nil || c.Value == "" { + return "", false + } + return c.Value, true + } +} + +// HeaderExtractor reads a custom header verbatim. +func HeaderExtractor(name string) Extractor { + return func(r *http.Request) (string, bool) { + v := strings.TrimSpace(r.Header.Get(name)) + if v == "" { + return "", false + } + return v, true + } +} + +// ChainExtractors tries each extractor in order, returning the first hit. +func ChainExtractors(es ...Extractor) Extractor { + return func(r *http.Request) (string, bool) { + for _, e := range es { + if v, ok := e(r); ok { + return v, true + } + } + return "", false + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ecd7a0c --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module git.juancwu.dev/juancwu/authkit + +go 1.26.2 + +require ( + git.juancwu.dev/juancwu/errx v0.1.0 + github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.9.2 + golang.org/x/crypto v0.50.0 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..dc69ee1 --- /dev/null +++ b/go.sum @@ -0,0 +1,36 @@ +git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w= +git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hasher/argon2id.go b/hasher/argon2id.go new file mode 100644 index 0000000..5d5c57b --- /dev/null +++ b/hasher/argon2id.go @@ -0,0 +1,136 @@ +// Package hasher provides authkit.Hasher implementations. The default +// implementation, Argon2id, encodes hashes in the standard PHC string format +// (https://github.com/P-H-C/phc-string-format) so callers can introspect +// parameters and migrate. +package hasher + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "io" + "strings" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "golang.org/x/crypto/argon2" +) + +// Argon2idParams configures the Argon2id KDF. +type Argon2idParams struct { + Memory uint32 // KB; OWASP 2024 baseline: 19 MiB minimum, 64 MiB recommended + Iterations uint32 // time cost + Parallelism uint8 // lanes + SaltLen uint32 // bytes + KeyLen uint32 // bytes +} + +// DefaultArgon2idParams returns sensible defaults: 64 MiB memory, 3 +// iterations, 2 lanes, 16-byte salt, 32-byte key. Tune up Memory/Iterations +// over time and rely on Verify's needsRehash signal to migrate stored hashes. +func DefaultArgon2idParams() Argon2idParams { + return Argon2idParams{ + Memory: 64 * 1024, + Iterations: 3, + Parallelism: 2, + SaltLen: 16, + KeyLen: 32, + } +} + +type argon2idHasher struct { + params Argon2idParams + rng io.Reader +} + +// NewArgon2id builds an authkit.Hasher backed by Argon2id. If params is the +// zero value, DefaultArgon2idParams() is used. rng defaults to crypto/rand. +func NewArgon2id(params Argon2idParams, rng io.Reader) authkit.Hasher { + if params == (Argon2idParams{}) { + params = DefaultArgon2idParams() + } + if rng == nil { + rng = rand.Reader + } + return &argon2idHasher{params: params, rng: rng} +} + +func (h *argon2idHasher) Hash(password string) (string, error) { + const op = "authkit.hasher.Argon2id.Hash" + if password == "" { + return "", errx.New(op, "password is empty") + } + salt := make([]byte, h.params.SaltLen) + if _, err := io.ReadFull(h.rng, salt); err != nil { + return "", errx.Wrap(op, err) + } + key := argon2.IDKey([]byte(password), salt, + h.params.Iterations, h.params.Memory, h.params.Parallelism, h.params.KeyLen) + return encodePHC(h.params, salt, key), nil +} + +func (h *argon2idHasher) Verify(password, encoded string) (bool, bool, error) { + const op = "authkit.hasher.Argon2id.Verify" + got, salt, key, err := decodePHC(encoded) + if err != nil { + return false, false, errx.Wrap(op, err) + } + want := argon2.IDKey([]byte(password), salt, + got.Iterations, got.Memory, got.Parallelism, uint32(len(key))) + if subtle.ConstantTimeCompare(want, key) != 1 { + return false, false, nil + } + needsRehash := got.Memory != h.params.Memory || + got.Iterations != h.params.Iterations || + got.Parallelism != h.params.Parallelism || + uint32(len(salt)) != h.params.SaltLen || + uint32(len(key)) != h.params.KeyLen + return true, needsRehash, nil +} + +// PHC string format: +// +// $argon2id$v=19$m=,t=,p=$$ +// +// base64 here is the unpadded standard alphabet, per the spec. +func encodePHC(p Argon2idParams, salt, key []byte) string { + enc := base64.RawStdEncoding + return fmt.Sprintf( + "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, p.Memory, p.Iterations, p.Parallelism, + enc.EncodeToString(salt), enc.EncodeToString(key), + ) +} + +func decodePHC(s string) (Argon2idParams, []byte, []byte, error) { + parts := strings.Split(s, "$") + // Expect: ["", "argon2id", "v=19", "m=...,t=...,p=...", "", ""] + if len(parts) != 6 || parts[0] != "" || parts[1] != "argon2id" { + return Argon2idParams{}, nil, nil, errors.New("hasher: not an argon2id phc string") + } + var version int + if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad version segment: %w", err) + } + if version != argon2.Version { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: unsupported argon2 version %d", version) + } + var p Argon2idParams + if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism); err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad params segment: %w", err) + } + enc := base64.RawStdEncoding + salt, err := enc.DecodeString(parts[4]) + if err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad salt: %w", err) + } + key, err := enc.DecodeString(parts[5]) + if err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad key: %w", err) + } + p.SaltLen = uint32(len(salt)) + p.KeyLen = uint32(len(key)) + return p, salt, key, nil +} diff --git a/hasher/argon2id_test.go b/hasher/argon2id_test.go new file mode 100644 index 0000000..008c453 --- /dev/null +++ b/hasher/argon2id_test.go @@ -0,0 +1,78 @@ +package hasher + +import ( + "strings" + "testing" +) + +func TestArgon2idHashVerifyRoundtrip(t *testing.T) { + h := NewArgon2id(DefaultArgon2idParams(), nil) + encoded, err := h.Hash("hunter2hunter2") + if err != nil { + t.Fatalf("Hash: %v", err) + } + if !strings.HasPrefix(encoded, "$argon2id$") { + t.Fatalf("encoded hash not in PHC form: %s", encoded) + } + ok, needsRehash, err := h.Verify("hunter2hunter2", encoded) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if !ok { + t.Fatalf("Verify rejected the original password") + } + if needsRehash { + t.Fatalf("Verify with default params should not signal rehash") + } +} + +func TestArgon2idVerifyWrongPassword(t *testing.T) { + h := NewArgon2id(DefaultArgon2idParams(), nil) + encoded, err := h.Hash("correct horse battery staple") + if err != nil { + t.Fatalf("Hash: %v", err) + } + ok, _, err := h.Verify("nope", encoded) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if ok { + t.Fatalf("Verify should reject wrong password") + } +} + +func TestArgon2idNeedsRehashOnParamChange(t *testing.T) { + // Hash with light params... + light := Argon2idParams{Memory: 8 * 1024, Iterations: 1, Parallelism: 1, SaltLen: 16, KeyLen: 32} + encoded, err := NewArgon2id(light, nil).Hash("hello world") + if err != nil { + t.Fatalf("Hash: %v", err) + } + // ...verify with stronger params should still match but flag rehash. + heavier := DefaultArgon2idParams() + ok, needsRehash, err := NewArgon2id(heavier, nil).Verify("hello world", encoded) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if !ok { + t.Fatalf("Verify rejected legitimate password across params") + } + if !needsRehash { + t.Fatalf("Verify should flag rehash when stored params differ from current") + } +} + +func TestArgon2idRejectsMalformed(t *testing.T) { + h := NewArgon2id(DefaultArgon2idParams(), nil) + cases := []string{ + "", + "not-a-phc", + "$argon2i$v=19$m=64,t=1,p=1$abc$def", + "$argon2id$v=99$m=64,t=1,p=1$YWJj$ZGVm", + } + for _, c := range cases { + if _, _, err := h.Verify("x", c); err == nil { + t.Fatalf("Verify should reject malformed encoding: %q", c) + } + } +} diff --git a/jwt.go b/jwt.go new file mode 100644 index 0000000..c9fb7b4 --- /dev/null +++ b/jwt.go @@ -0,0 +1,69 @@ +package authkit + +import ( + "git.juancwu.dev/juancwu/errx" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// accessClaims is the JWT shape issued by IssueJWT. The session_version +// field carries the User.SessionVersion at issue time so AuthenticateJWT +// can detect global revocations (logout-everywhere, password change). +type accessClaims struct { + jwt.RegisteredClaims + SessionVersion int `json:"sv"` + Method string `json:"m"` +} + +func (a *Auth) signAccessToken(userID uuid.UUID, sessionVersion int) (string, error) { + const op = "authkit.signAccessToken" + now := a.now() + claims := accessClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID.String(), + Issuer: a.cfg.JWTIssuer, + Audience: jwt.ClaimStrings{a.cfg.JWTAudience}, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(a.cfg.AccessTokenTTL)), + ID: uuid.NewString(), + }, + SessionVersion: sessionVersion, + Method: string(AuthMethodJWT), + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := tok.SignedString(a.cfg.JWTSecret) + if err != nil { + return "", errx.Wrap(op, err) + } + return signed, nil +} + +// parseAccessToken validates the signature and returns the parsed claims. +func (a *Auth) parseAccessToken(token string) (*accessClaims, error) { + const op = "authkit.parseAccessToken" + opts := []jwt.ParserOption{ + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + jwt.WithTimeFunc(a.cfg.Clock), + } + if a.cfg.JWTIssuer != "" { + opts = append(opts, jwt.WithIssuer(a.cfg.JWTIssuer)) + } + if a.cfg.JWTAudience != "" { + opts = append(opts, jwt.WithAudience(a.cfg.JWTAudience)) + } + parser := jwt.NewParser(opts...) + + parsed, err := parser.ParseWithClaims(token, &accessClaims{}, func(t *jwt.Token) (any, error) { + return a.cfg.JWTSecret, nil + }) + if err != nil { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + claims, ok := parsed.Claims.(*accessClaims) + if !ok || !parsed.Valid { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + return claims, nil +} diff --git a/jwt_test.go b/jwt_test.go new file mode 100644 index 0000000..5de34a2 --- /dev/null +++ b/jwt_test.go @@ -0,0 +1,99 @@ +package authkit + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestJWTIssueAndAuthenticate(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "alice@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + access, refresh, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + if access == "" || refresh == "" { + t.Fatalf("empty tokens") + } + p, err := a.AuthenticateJWT(context.Background(), access) + if err != nil { + t.Fatalf("AuthenticateJWT: %v", err) + } + if p.UserID != u.ID { + t.Fatalf("principal user id mismatch") + } + if p.Method != AuthMethodJWT { + t.Fatalf("principal method = %s, want jwt", p.Method) + } +} + +func TestJWTSessionVersionMismatchRejected(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "bob@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + access, _, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + if err := a.RevokeAllUserSessions(context.Background(), u.ID); err != nil { + t.Fatalf("RevokeAllUserSessions: %v", err) + } + if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid after session bump, got %v", err) + } +} + +func TestJWTRefreshRotationAndReuseDetection(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "carol@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + _, refresh1, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + _, refresh2, err := a.RefreshJWT(context.Background(), refresh1) + if err != nil { + t.Fatalf("first RefreshJWT: %v", err) + } + if refresh1 == refresh2 { + t.Fatalf("refresh token did not rotate") + } + + // Replaying refresh1 must surface ErrTokenReused and revoke the chain. + if _, _, err := a.RefreshJWT(context.Background(), refresh1); !errors.Is(err, ErrTokenReused) { + t.Fatalf("expected ErrTokenReused on replay, got %v", err) + } + // After chain revocation, even refresh2 (the legitimate next one) must + // be rejected. + if _, _, err := a.RefreshJWT(context.Background(), refresh2); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on post-revoke refresh, got %v", err) + } +} + +func TestJWTExpiredTokenRejected(t *testing.T) { + a := newTestAuth(t) + now := time.Now().UTC() + a.cfg.Clock = func() time.Time { return now } + u, err := a.Register(context.Background(), "dan@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + access, _, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + // Advance clock past TTL. + a.cfg.Clock = func() time.Time { return now.Add(10 * time.Minute) } + if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid for expired token, got %v", err) + } +} diff --git a/memstore_test.go b/memstore_test.go new file mode 100644 index 0000000..4b16216 --- /dev/null +++ b/memstore_test.go @@ -0,0 +1,619 @@ +package authkit + +// In-memory store fakes used by service-level tests. Kept in a _test.go file +// so they don't ship in the public API. Each fake is intentionally minimal — +// it satisfies the interface and supports the flows the tests exercise; not +// every method is wired up (a few panic to flag accidental use). + +import ( + "bytes" + "context" + "errors" + "slices" + "sync" + "time" + + "github.com/google/uuid" +) + +type memUserStore struct { + mu sync.Mutex + byID map[uuid.UUID]*User + byEml map[string]uuid.UUID +} + +func newMemUserStore() *memUserStore { + return &memUserStore{byID: map[uuid.UUID]*User{}, byEml: map[string]uuid.UUID{}} +} + +func (s *memUserStore) CreateUser(_ context.Context, u *User) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.byEml[u.EmailNormalized]; exists { + return ErrEmailTaken + } + if u.ID == uuid.Nil { + u.ID = uuid.New() + } + cp := *u + s.byID[u.ID] = &cp + s.byEml[u.EmailNormalized] = u.ID + return nil +} + +func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*User, error) { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return nil, ErrUserNotFound + } + cp := *u + return &cp, nil +} + +func (s *memUserStore) GetUserByEmail(_ context.Context, ne string) (*User, error) { + s.mu.Lock() + defer s.mu.Unlock() + id, ok := s.byEml[ne] + if !ok { + return nil, ErrUserNotFound + } + u := *s.byID[id] + return &u, nil +} + +func (s *memUserStore) UpdateUser(_ context.Context, u *User) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.byID[u.ID]; !ok { + return ErrUserNotFound + } + cp := *u + s.byID[u.ID] = &cp + return nil +} +func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + delete(s.byID, id) + delete(s.byEml, u.EmailNormalized) + return nil +} + +func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, h string) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + u.PasswordHash = h + return nil +} +func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + u.EmailVerifiedAt = &at + return nil +} +func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return 0, ErrUserNotFound + } + u.SessionVersion++ + return u.SessionVersion, nil +} +func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return 0, ErrUserNotFound + } + u.FailedLogins++ + return u.FailedLogins, nil +} +func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + u.FailedLogins = 0 + return nil +} + +type memSessionStore struct { + mu sync.Mutex + m map[string]*Session +} + +func newMemSessionStore() *memSessionStore { return &memSessionStore{m: map[string]*Session{}} } +func (s *memSessionStore) CreateSession(_ context.Context, ses *Session) error { + s.mu.Lock() + defer s.mu.Unlock() + cp := *ses + s.m[string(ses.IDHash)] = &cp + return nil +} +func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[string(h)] + if !ok { + return nil, ErrSessionInvalid + } + cp := *v + return &cp, nil +} +func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, exp time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[string(h)] + if !ok { + return ErrSessionInvalid + } + v.LastSeenAt = lastSeen + v.ExpiresAt = exp + return nil +} +func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.m, string(h)) + return nil +} +func (s *memSessionStore) DeleteUserSessions(_ context.Context, uid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.m { + if v.UserID == uid { + delete(s.m, k) + } + } + return nil +} +func (s *memSessionStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + var n int64 + for k, v := range s.m { + if !v.ExpiresAt.After(now) { + delete(s.m, k) + n++ + } + } + return n, nil +} + +type memTokenStore struct { + mu sync.Mutex + // Keyed by kind+hex(hash) so identical hashes across kinds don't collide. + m map[string]*Token +} + +func newMemTokenStore() *memTokenStore { return &memTokenStore{m: map[string]*Token{}} } +func tokKey(kind TokenKind, h []byte) string { + return string(kind) + ":" + string(h) +} +func (s *memTokenStore) CreateToken(_ context.Context, t *Token) error { + s.mu.Lock() + defer s.mu.Unlock() + cp := *t + s.m[tokKey(t.Kind, t.Hash)] = &cp + return nil +} +func (s *memTokenStore) ConsumeToken(_ context.Context, kind TokenKind, h []byte, now time.Time) (*Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + t, ok := s.m[tokKey(kind, h)] + if !ok || t.ConsumedAt != nil || !t.ExpiresAt.After(now) { + return nil, ErrTokenInvalid + } + t.ConsumedAt = &now + cp := *t + return &cp, nil +} +func (s *memTokenStore) GetToken(_ context.Context, kind TokenKind, h []byte) (*Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + t, ok := s.m[tokKey(kind, h)] + if !ok { + return nil, ErrTokenInvalid + } + cp := *t + return &cp, nil +} +func (s *memTokenStore) DeleteByChain(_ context.Context, chainID string) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + var n int64 + for k, t := range s.m { + if t.ChainID != nil && *t.ChainID == chainID { + delete(s.m, k) + n++ + } + } + return n, nil +} +func (s *memTokenStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + var n int64 + for k, t := range s.m { + if !t.ExpiresAt.After(now) { + delete(s.m, k) + n++ + } + } + return n, nil +} + +type memAPIKeyStore struct { + mu sync.Mutex + m map[string]*APIKey +} + +func newMemAPIKeyStore() *memAPIKeyStore { return &memAPIKeyStore{m: map[string]*APIKey{}} } +func (s *memAPIKeyStore) CreateAPIKey(_ context.Context, k *APIKey) error { + s.mu.Lock() + defer s.mu.Unlock() + cp := *k + cp.Abilities = append([]string(nil), k.Abilities...) + s.m[string(k.IDHash)] = &cp + return nil +} +func (s *memAPIKeyStore) GetAPIKey(_ context.Context, h []byte) (*APIKey, error) { + s.mu.Lock() + defer s.mu.Unlock() + k, ok := s.m[string(h)] + if !ok { + return nil, ErrAPIKeyInvalid + } + cp := *k + cp.Abilities = append([]string(nil), k.Abilities...) + return &cp, nil +} +func (s *memAPIKeyStore) ListAPIKeysByOwner(_ context.Context, owner uuid.UUID) ([]*APIKey, error) { + s.mu.Lock() + defer s.mu.Unlock() + var out []*APIKey + for _, k := range s.m { + if k.OwnerID == owner { + cp := *k + out = append(out, &cp) + } + } + return out, nil +} +func (s *memAPIKeyStore) TouchAPIKey(_ context.Context, h []byte, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + if k, ok := s.m[string(h)]; ok { + k.LastUsedAt = &at + } + return nil +} +func (s *memAPIKeyStore) RevokeAPIKey(_ context.Context, h []byte, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + k, ok := s.m[string(h)] + if !ok { + return ErrAPIKeyInvalid + } + if k.RevokedAt != nil { + return ErrAPIKeyInvalid + } + k.RevokedAt = &at + return nil +} +func (s *memAPIKeyStore) RevokeAPIKeysByOwner(_ context.Context, owner uuid.UUID, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, k := range s.m { + if k.OwnerID == owner && k.RevokedAt == nil { + k.RevokedAt = &at + } + } + return nil +} + +type memRoleStore struct { + mu sync.Mutex + roles map[uuid.UUID]*Role + rolesByNm map[string]uuid.UUID + userRoles map[uuid.UUID]map[uuid.UUID]struct{} +} + +func newMemRoleStore() *memRoleStore { + return &memRoleStore{ + roles: map[uuid.UUID]*Role{}, + rolesByNm: map[string]uuid.UUID{}, + userRoles: map[uuid.UUID]map[uuid.UUID]struct{}{}, + } +} +func (s *memRoleStore) CreateRole(_ context.Context, r *Role) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, dup := s.rolesByNm[r.Name]; dup { + return errors.New("role exists") + } + if r.ID == uuid.Nil { + r.ID = uuid.New() + } + cp := *r + s.roles[r.ID] = &cp + s.rolesByNm[r.Name] = r.ID + return nil +} +func (s *memRoleStore) GetRoleByID(_ context.Context, id uuid.UUID) (*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.roles[id] + if !ok { + return nil, ErrRoleNotFound + } + cp := *r + return &cp, nil +} +func (s *memRoleStore) GetRoleByName(_ context.Context, n string) (*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + id, ok := s.rolesByNm[n] + if !ok { + return nil, ErrRoleNotFound + } + r := *s.roles[id] + return &r, nil +} +func (s *memRoleStore) ListRoles(_ context.Context) ([]*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]*Role, 0, len(s.roles)) + for _, r := range s.roles { + cp := *r + out = append(out, &cp) + } + return out, nil +} +func (s *memRoleStore) DeleteRole(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.roles[id] + if !ok { + return ErrRoleNotFound + } + delete(s.roles, id) + delete(s.rolesByNm, r.Name) + for _, m := range s.userRoles { + delete(m, id) + } + return nil +} +func (s *memRoleStore) AssignRoleToUser(_ context.Context, uid, rid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + m, ok := s.userRoles[uid] + if !ok { + m = map[uuid.UUID]struct{}{} + s.userRoles[uid] = m + } + m[rid] = struct{}{} + return nil +} +func (s *memRoleStore) RemoveRoleFromUser(_ context.Context, uid, rid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + if m, ok := s.userRoles[uid]; ok { + delete(m, rid) + } + return nil +} +func (s *memRoleStore) GetUserRoles(_ context.Context, uid uuid.UUID) ([]*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := []*Role{} + for rid := range s.userRoles[uid] { + if r, ok := s.roles[rid]; ok { + cp := *r + out = append(out, &cp) + } + } + return out, nil +} +func (s *memRoleStore) HasAnyRole(_ context.Context, uid uuid.UUID, names []string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for rid := range s.userRoles[uid] { + r := s.roles[rid] + if slices.Contains(names, r.Name) { + return true, nil + } + } + return false, nil +} + +type memPermStore struct { + mu sync.Mutex + perms map[uuid.UUID]*Permission + permsByNm map[string]uuid.UUID + rolePerms map[uuid.UUID]map[uuid.UUID]struct{} + roles *memRoleStore +} + +func newMemPermStore(rs *memRoleStore) *memPermStore { + return &memPermStore{ + perms: map[uuid.UUID]*Permission{}, + permsByNm: map[string]uuid.UUID{}, + rolePerms: map[uuid.UUID]map[uuid.UUID]struct{}{}, + roles: rs, + } +} +func (s *memPermStore) CreatePermission(_ context.Context, p *Permission) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, dup := s.permsByNm[p.Name]; dup { + return errors.New("perm exists") + } + if p.ID == uuid.Nil { + p.ID = uuid.New() + } + cp := *p + s.perms[p.ID] = &cp + s.permsByNm[p.Name] = p.ID + return nil +} +func (s *memPermStore) GetPermissionByID(_ context.Context, id uuid.UUID) (*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + p, ok := s.perms[id] + if !ok { + return nil, ErrPermissionNotFound + } + cp := *p + return &cp, nil +} +func (s *memPermStore) GetPermissionByName(_ context.Context, n string) (*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + id, ok := s.permsByNm[n] + if !ok { + return nil, ErrPermissionNotFound + } + p := *s.perms[id] + return &p, nil +} +func (s *memPermStore) ListPermissions(_ context.Context) ([]*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]*Permission, 0, len(s.perms)) + for _, p := range s.perms { + cp := *p + out = append(out, &cp) + } + return out, nil +} +func (s *memPermStore) DeletePermission(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + p, ok := s.perms[id] + if !ok { + return ErrPermissionNotFound + } + delete(s.perms, id) + delete(s.permsByNm, p.Name) + for _, m := range s.rolePerms { + delete(m, id) + } + return nil +} +func (s *memPermStore) AssignPermissionToRole(_ context.Context, rid, pid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + m, ok := s.rolePerms[rid] + if !ok { + m = map[uuid.UUID]struct{}{} + s.rolePerms[rid] = m + } + m[pid] = struct{}{} + return nil +} +func (s *memPermStore) RemovePermissionFromRole(_ context.Context, rid, pid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + if m, ok := s.rolePerms[rid]; ok { + delete(m, pid) + } + return nil +} +func (s *memPermStore) GetRolePermissions(_ context.Context, rid uuid.UUID) ([]*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := []*Permission{} + for pid := range s.rolePerms[rid] { + if p, ok := s.perms[pid]; ok { + cp := *p + out = append(out, &cp) + } + } + return out, nil +} +func (s *memPermStore) GetUserPermissions(_ context.Context, uid uuid.UUID) ([]*Permission, error) { + s.roles.mu.Lock() + roleIDs := make([]uuid.UUID, 0) + for rid := range s.roles.userRoles[uid] { + roleIDs = append(roleIDs, rid) + } + s.roles.mu.Unlock() + + s.mu.Lock() + defer s.mu.Unlock() + seen := map[uuid.UUID]struct{}{} + out := []*Permission{} + for _, rid := range roleIDs { + for pid := range s.rolePerms[rid] { + if _, dup := seen[pid]; dup { + continue + } + seen[pid] = struct{}{} + if p, ok := s.perms[pid]; ok { + cp := *p + out = append(out, &cp) + } + } + } + return out, nil +} + +// Stub Hasher: stores plaintext for trivial verify. Tests of hashing live in +// the hasher package itself; we just need *something* callable here. +type stubHasher struct{} + +func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil } +func (stubHasher) Verify(p, encoded string) (bool, bool, error) { + want := "stub:" + p + if !bytes.Equal([]byte(want), []byte(encoded)) { + return false, false, nil + } + return true, false, nil +} + +// newTestAuth wires the fakes into Auth with deterministic config. +func newTestAuth(t interface{ Helper() }) *Auth { + if h, ok := t.(interface{ Helper() }); ok { + h.Helper() + } + roles := newMemRoleStore() + return New(Deps{ + Users: newMemUserStore(), + Sessions: newMemSessionStore(), + Tokens: newMemTokenStore(), + APIKeys: newMemAPIKeyStore(), + Roles: roles, + Permissions: newMemPermStore(roles), + Hasher: stubHasher{}, + }, Config{ + JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"), + JWTIssuer: "authkit-test", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: 1 * time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + }) +} diff --git a/middleware/authz.go b/middleware/authz.go new file mode 100644 index 0000000..79d7297 --- /dev/null +++ b/middleware/authz.go @@ -0,0 +1,71 @@ +package middleware + +import ( + "net/http" + + "git.juancwu.dev/juancwu/authkit" +) + +// authzGuard wraps the common pattern of "look up the Principal, run a +// predicate, succeed or 403". onForbidden defaults to JSON 403. +func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler { + if onForbidden == nil { + onForbidden = defaultJSONError(http.StatusForbidden) + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p, ok := PrincipalFrom(r.Context()) + if !ok { + // No auth middleware ran upstream; treat as forbidden + // rather than crashing — composition is the caller's + // responsibility but a 403 is the safer default. + onForbidden(w, r, authkit.ErrPermissionDenied) + return + } + if !pred(p) { + onForbidden(w, r, authkit.ErrPermissionDenied) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// RequireRole permits requests whose Principal holds the named role. +func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasRole(name) + }) +} + +// RequireAnyRole permits requests whose Principal holds at least one of the +// named roles. +func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasAnyRole(names...) + }) +} + +// RequirePermission permits requests whose Principal holds the named +// permission (resolved via roles at auth time). +func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasPermission(name) + }) +} + +// RequireAbility permits requests whose Principal carries the named ability. +// Abilities are populated only for API-key authentication; this middleware +// will reject session/JWT-authenticated requests by design. +func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasAbility(name) + }) +} + +func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) { + if len(s) == 0 { + return nil + } + return s[0] +} diff --git a/middleware/context.go b/middleware/context.go new file mode 100644 index 0000000..7de6da0 --- /dev/null +++ b/middleware/context.go @@ -0,0 +1,39 @@ +// Package middleware provides framework-neutral HTTP middleware for authkit. +// Every middleware function returns the standard func(http.Handler) +// http.Handler type, so it composes with lightmux's Use/Group/Handle as well +// as any net/http stack that uses the same signature. +package middleware + +import ( + "context" + "net/http" + + "git.juancwu.dev/juancwu/authkit" +) + +// principalKey is an unexported context key. Using a distinct empty struct +// type guarantees no collision with caller-defined keys. +type principalKey struct{} + +// withPrincipal stashes p on the request context for downstream handlers. +func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context { + return context.WithValue(ctx, principalKey{}, p) +} + +// PrincipalFrom retrieves the authenticated Principal placed by RequireSession, +// RequireJWT, or RequireAPIKey. The boolean is false if no auth middleware +// ran for this request. +func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) { + p, ok := ctx.Value(principalKey{}).(*authkit.Principal) + return p, ok +} + +// MustPrincipal panics if no Principal is on the context. Use only on +// handlers known to be behind a Require* middleware. +func MustPrincipal(r *http.Request) *authkit.Principal { + p, ok := PrincipalFrom(r.Context()) + if !ok { + panic("authkit/middleware: no principal on request context") + } + return p +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..1fbc4dd --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,138 @@ +package middleware + +import ( + "encoding/json" + "net/http" + + "git.juancwu.dev/juancwu/authkit" +) + +// Options configures auth middleware. Auth is required; the rest fall back +// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403 +// on authz failure. +type Options struct { + Auth *authkit.Auth + Extractor authkit.Extractor + OnUnauth func(w http.ResponseWriter, r *http.Request, err error) + OnForbidden func(w http.ResponseWriter, r *http.Request, err error) +} + +func (o Options) extractor() authkit.Extractor { + if o.Extractor != nil { + return o.Extractor + } + return authkit.BearerExtractor() +} + +func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) { + if o.OnUnauth != nil { + return o.OnUnauth + } + return defaultJSONError(http.StatusUnauthorized) +} + +func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) { + if o.OnForbidden != nil { + return o.OnForbidden + } + return defaultJSONError(http.StatusForbidden) +} + +func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) { + return func(w http.ResponseWriter, _ *http.Request, err error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": http.StatusText(status), + }) + } +} + +// RequireSession authenticates the request via an opaque session string. The +// extractor is consulted first; if no extractor is set the default Bearer +// extractor is used. For cookie-based session lookup, set +// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName). +func RequireSession(opts Options) func(http.Handler) http.Handler { + return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { + return opts.Auth.AuthenticateSession(r.Context(), raw) + }) +} + +// RequireJWT authenticates the request via an HS256 JWT. +func RequireJWT(opts Options) func(http.Handler) http.Handler { + return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { + return opts.Auth.AuthenticateJWT(r.Context(), raw) + }) +} + +// RequireAPIKey authenticates the request via an opaque API secret. +func RequireAPIKey(opts Options) func(http.Handler) http.Handler { + return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { + return opts.Auth.AuthenticateAPIKey(r.Context(), raw) + }) +} + +// RequireAny tries each method in order until one succeeds. Useful for routes +// that accept either a session cookie or an API key. +func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler { + if len(methods) == 0 { + methods = []authkit.AuthMethod{ + authkit.AuthMethodSession, + authkit.AuthMethodJWT, + authkit.AuthMethodAPIKey, + } + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, ok := opts.extractor()(r) + if !ok || raw == "" { + opts.onUnauth()(w, r, authkit.ErrSessionInvalid) + return + } + var ( + p *authkit.Principal + lastErr error + ) + for _, m := range methods { + switch m { + case authkit.AuthMethodSession: + p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw) + case authkit.AuthMethodJWT: + p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw) + case authkit.AuthMethodAPIKey: + p, lastErr = opts.Auth.AuthenticateAPIKey(r.Context(), raw) + } + if lastErr == nil && p != nil { + next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p))) + return + } + } + opts.onUnauth()(w, r, lastErr) + }) + } +} + +// requireWith is the shared scaffolding for the single-method Require* +// middlewares. +func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler { + if opts.Auth == nil { + panic("authkit/middleware: Options.Auth is required") + } + extractor := opts.extractor() + onUnauth := opts.onUnauth() + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, ok := extractor(r) + if !ok || raw == "" { + onUnauth(w, r, authkit.ErrSessionInvalid) + return + } + p, err := authn(r, raw) + if err != nil { + onUnauth(w, r, err) + return + } + next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p))) + }) + } +} diff --git a/models.go b/models.go new file mode 100644 index 0000000..9b31616 --- /dev/null +++ b/models.go @@ -0,0 +1,75 @@ +package authkit + +import ( + "net/netip" + "time" + + "github.com/google/uuid" +) + +type User struct { + ID uuid.UUID + Email string + EmailNormalized string + EmailVerifiedAt *time.Time + PasswordHash string + SessionVersion int + FailedLogins int + LastLoginAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type Session struct { + IDHash []byte + UserID uuid.UUID + UserAgent string + IP netip.Addr + CreatedAt time.Time + LastSeenAt time.Time + ExpiresAt time.Time +} + +type TokenKind string + +const ( + TokenEmailVerify TokenKind = "email_verify" + TokenPasswordReset TokenKind = "password_reset" + TokenMagicLink TokenKind = "magic_link" + TokenRefresh TokenKind = "refresh" +) + +type Token struct { + Hash []byte + Kind TokenKind + UserID uuid.UUID + ChainID *string + ConsumedAt *time.Time + CreatedAt time.Time + ExpiresAt time.Time +} + +type APIKey struct { + IDHash []byte + OwnerID uuid.UUID + Name string + Abilities []string + LastUsedAt *time.Time + CreatedAt time.Time + ExpiresAt *time.Time + RevokedAt *time.Time +} + +type Role struct { + ID uuid.UUID + Name string + Description string + CreatedAt time.Time +} + +type Permission struct { + ID uuid.UUID + Name string + Description string + CreatedAt time.Time +} diff --git a/principal.go b/principal.go new file mode 100644 index 0000000..590dd9b --- /dev/null +++ b/principal.go @@ -0,0 +1,63 @@ +package authkit + +import ( + "time" + + "github.com/google/uuid" +) + +type AuthMethod string + +const ( + AuthMethodSession AuthMethod = "session" + AuthMethodJWT AuthMethod = "jwt" + AuthMethodAPIKey AuthMethod = "api_key" +) + +type Principal struct { + UserID uuid.UUID + Method AuthMethod + SessionID []byte + APIKeyID []byte + Roles []string + Permissions []string + Abilities []string + IssuedAt time.Time + ExpiresAt time.Time +} + +func (p *Principal) HasRole(name string) bool { + for _, r := range p.Roles { + if r == name { + return true + } + } + return false +} + +func (p *Principal) HasAnyRole(names ...string) bool { + for _, n := range names { + if p.HasRole(n) { + return true + } + } + return false +} + +func (p *Principal) HasPermission(name string) bool { + for _, perm := range p.Permissions { + if perm == name { + return true + } + } + return false +} + +func (p *Principal) HasAbility(name string) bool { + for _, a := range p.Abilities { + if a == name { + return true + } + } + return false +} diff --git a/service_apikey.go b/service_apikey.go new file mode 100644 index 0000000..045e063 --- /dev/null +++ b/service_apikey.go @@ -0,0 +1,92 @@ +package authkit + +import ( + "context" + "time" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// IssueAPIKey mints a fresh API secret with the given abilities and an +// optional TTL. The plaintext is returned to the caller (show-once) and the +// SHA-256 lookup hash is stored. Pass ttl=nil for a non-expiring key. +func (a *Auth) IssueAPIKey(ctx context.Context, ownerID uuid.UUID, name string, abilities []string, ttl *time.Duration) (string, *APIKey, error) { + const op = "authkit.Auth.IssueAPIKey" + plaintext, hash, err := mintSecret(prefixAPIKey, a.cfg.Random) + if err != nil { + return "", nil, errx.Wrap(op, err) + } + now := a.now() + k := &APIKey{ + IDHash: hash, + OwnerID: ownerID, + Name: name, + Abilities: append([]string(nil), abilities...), + CreatedAt: now, + } + if ttl != nil { + exp := now.Add(*ttl) + k.ExpiresAt = &exp + } + if err := a.deps.APIKeys.CreateAPIKey(ctx, k); err != nil { + return "", nil, errx.Wrap(op, err) + } + return plaintext, k, nil +} + +// AuthenticateAPIKey validates an API secret string, touches last_used_at +// (best-effort), and returns a Principal carrying the key's abilities. The +// owning user's roles+permissions are also resolved so the same Principal +// can satisfy RequireRole / RequirePermission middleware. +func (a *Auth) AuthenticateAPIKey(ctx context.Context, plaintext string) (*Principal, error) { + const op = "authkit.Auth.AuthenticateAPIKey" + hash, ok := parseSecret(prefixAPIKey, plaintext) + if !ok { + return nil, errx.Wrap(op, ErrAPIKeyInvalid) + } + k, err := a.deps.APIKeys.GetAPIKey(ctx, hash) + if err != nil { + return nil, errx.Wrap(op, err) + } + now := a.now() + if k.RevokedAt != nil { + return nil, errx.Wrap(op, ErrAPIKeyInvalid) + } + if k.ExpiresAt != nil && !k.ExpiresAt.After(now) { + return nil, errx.Wrap(op, ErrAPIKeyInvalid) + } + _ = a.deps.APIKeys.TouchAPIKey(ctx, hash, now) + + roles, perms, err := a.resolveRolesAndPermissions(ctx, k.OwnerID) + if err != nil { + return nil, errx.Wrap(op, err) + } + expires := now + if k.ExpiresAt != nil { + expires = *k.ExpiresAt + } + return &Principal{ + UserID: k.OwnerID, + Method: AuthMethodAPIKey, + APIKeyID: hash, + Roles: roles, + Permissions: perms, + Abilities: append([]string(nil), k.Abilities...), + IssuedAt: k.CreatedAt, + ExpiresAt: expires, + }, nil +} + +// RevokeAPIKey marks a key revoked. Idempotent on already-revoked keys. +func (a *Auth) RevokeAPIKey(ctx context.Context, plaintext string) error { + const op = "authkit.Auth.RevokeAPIKey" + hash, ok := parseSecret(prefixAPIKey, plaintext) + if !ok { + return errx.Wrap(op, ErrAPIKeyInvalid) + } + if err := a.deps.APIKeys.RevokeAPIKey(ctx, hash, a.now()); err != nil { + return errx.Wrap(op, err) + } + return nil +} diff --git a/service_authz.go b/service_authz.go new file mode 100644 index 0000000..7a5582c --- /dev/null +++ b/service_authz.go @@ -0,0 +1,85 @@ +package authkit + +import ( + "context" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// UserPermissions returns the union of permission names a user holds via +// their assigned roles. Resolved at call time; v1 does not cache. +func (a *Auth) UserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) { + const op = "authkit.Auth.UserPermissions" + perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID) + if err != nil { + return nil, errx.Wrap(op, err) + } + out := make([]string, len(perms)) + for i, p := range perms { + out[i] = p.Name + } + return out, nil +} + +// HasPermission checks whether a user holds the named permission via any +// assigned role. +func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, name string) (bool, error) { + const op = "authkit.Auth.HasPermission" + perms, err := a.UserPermissions(ctx, userID) + if err != nil { + return false, errx.Wrap(op, err) + } + for _, p := range perms { + if p == name { + return true, nil + } + } + return false, nil +} + +// HasRole checks whether a user is assigned the named role. +func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, name string) (bool, error) { + const op = "authkit.Auth.HasRole" + ok, err := a.deps.Roles.HasAnyRole(ctx, userID, []string{name}) + if err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// HasAnyRole checks whether a user holds at least one of the named roles. +func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) { + const op = "authkit.Auth.HasAnyRole" + ok, err := a.deps.Roles.HasAnyRole(ctx, userID, names) + if err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// AssignRole is a convenience that looks up a role by name and assigns it. +func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleName string) error { + const op = "authkit.Auth.AssignRole" + r, err := a.deps.Roles.GetRoleByName(ctx, roleName) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Roles.AssignRoleToUser(ctx, userID, r.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RemoveRole is the symmetric helper for AssignRole. +func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleName string) error { + const op = "authkit.Auth.RemoveRole" + r, err := a.deps.Roles.GetRoleByName(ctx, roleName) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Roles.RemoveRoleFromUser(ctx, userID, r.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} diff --git a/service_jwt.go b/service_jwt.go new file mode 100644 index 0000000..d3133b2 --- /dev/null +++ b/service_jwt.go @@ -0,0 +1,147 @@ +package authkit + +import ( + "context" + "errors" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// IssueJWT issues a fresh access JWT and a rotating opaque refresh token. +// The refresh token is bound to a chain via Token.ChainID; rotation +// preserves that chain so reuse-detection can revoke the whole family. +func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) { + const op = "authkit.Auth.IssueJWT" + u, err := a.deps.Users.GetUserByID(ctx, userID) + if err != nil { + return "", "", errx.Wrap(op, err) + } + access, err = a.signAccessToken(u.ID, u.SessionVersion) + if err != nil { + return "", "", errx.Wrap(op, err) + } + refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString()) + if err != nil { + return "", "", errx.Wrap(op, err) + } + return access, refresh, nil +} + +// AuthenticateJWT validates the access JWT, cross-checks the user's +// session_version (instant revocation), and resolves a Principal. +func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, error) { + const op = "authkit.Auth.AuthenticateJWT" + claims, err := a.parseAccessToken(access) + if err != nil { + return nil, errx.Wrap(op, err) + } + uid, err := uuid.Parse(claims.Subject) + if err != nil { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + u, err := a.deps.Users.GetUserByID(ctx, uid) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + return nil, errx.Wrap(op, err) + } + if u.SessionVersion != claims.SessionVersion { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + roles, perms, err := a.resolveRolesAndPermissions(ctx, u.ID) + if err != nil { + return nil, errx.Wrap(op, err) + } + return &Principal{ + UserID: u.ID, + Method: AuthMethodJWT, + Roles: roles, + Permissions: perms, + IssuedAt: claims.IssuedAt.Time, + ExpiresAt: claims.ExpiresAt.Time, + }, nil +} + +// RefreshJWT consumes the presented refresh token and mints a new access + +// refresh pair. Reuse of an already-consumed refresh token deletes the +// entire chain (logout-everywhere on that device family) and returns +// ErrTokenReused. +func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) { + const op = "authkit.Auth.RefreshJWT" + hash, ok := parseSecret(prefixRefresh, plaintextRefresh) + if !ok { + return "", "", errx.Wrap(op, ErrTokenInvalid) + } + now := a.now() + + consumed, err := a.deps.Tokens.ConsumeToken(ctx, TokenRefresh, hash, now) + if err != nil { + // Differentiate plain-invalid (never existed / expired) from + // reuse (existed, already consumed). The presence-check below is + // the reuse signal. + if errors.Is(err, ErrTokenInvalid) { + if existing, gerr := a.deps.Tokens.GetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil { + if existing.ChainID != nil && *existing.ChainID != "" { + _, _ = a.deps.Tokens.DeleteByChain(ctx, *existing.ChainID) + } + return "", "", errx.Wrap(op, ErrTokenReused) + } + } + return "", "", errx.Wrap(op, err) + } + + var chainID string + if consumed.ChainID != nil { + chainID = *consumed.ChainID + } + if chainID == "" { + // Defensive: every refresh token should be chain-bound. Fall back + // to a fresh chain so we never throw on missing metadata. + chainID = uuid.NewString() + } + + access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID)) + if err != nil { + return "", "", errx.Wrap(op, err) + } + refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID) + if err != nil { + return "", "", errx.Wrap(op, err) + } + return access, refresh, nil +} + +// mintRefreshToken stores a fresh refresh token bound to chainID and returns +// the plaintext. +func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) { + const op = "authkit.Auth.mintRefreshToken" + plaintext, hash, err := mintSecret(prefixRefresh, a.cfg.Random) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + t := &Token{ + Hash: hash, + Kind: TokenRefresh, + UserID: userID, + ChainID: &chainID, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.RefreshTokenTTL), + } + if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return plaintext, nil +} + +// userSessionVersion fetches the current session_version. Errors collapse to +// 0 on the assumption that AuthenticateJWT will reject stale tokens cleanly +// — but we still need a value to embed in the freshly-minted access token. +func (a *Auth) userSessionVersion(ctx context.Context, userID uuid.UUID) int { + if u, err := a.deps.Users.GetUserByID(ctx, userID); err == nil { + return u.SessionVersion + } + return 0 +} diff --git a/service_magic.go b/service_magic.go new file mode 100644 index 0000000..d200ca5 --- /dev/null +++ b/service_magic.go @@ -0,0 +1,62 @@ +package authkit + +import ( + "context" + + "git.juancwu.dev/juancwu/errx" +) + +// RequestMagicLink mints a single-use magic-link token for the email and +// returns the plaintext for delivery. ErrUserNotFound is returned for +// unregistered emails. +func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, error) { + const op = "authkit.Auth.RequestMagicLink" + u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + return "", errx.Wrap(op, err) + } + plaintext, hash, err := mintSecret(prefixMagicLink, a.cfg.Random) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + t := &Token{ + Hash: hash, + Kind: TokenMagicLink, + UserID: u.ID, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.MagicLinkTTL), + } + if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return plaintext, nil +} + +// ConsumeMagicLink consumes the magic-link token and returns the +// authenticated user. Callers typically follow this with IssueSession or +// IssueJWT to actually log the user in. +func (a *Auth) ConsumeMagicLink(ctx context.Context, plaintextToken string) (*User, error) { + const op = "authkit.Auth.ConsumeMagicLink" + hash, ok := parseSecret(prefixMagicLink, plaintextToken) + if !ok { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + now := a.now() + t, err := a.deps.Tokens.ConsumeToken(ctx, TokenMagicLink, hash, now) + if err != nil { + return nil, errx.Wrap(op, err) + } + u, err := a.deps.Users.GetUserByID(ctx, t.UserID) + if err != nil { + return nil, errx.Wrap(op, err) + } + // A successful magic-link login also implicitly verifies the email + // (the user demonstrably controls the inbox). + if u.EmailVerifiedAt == nil { + if err := a.deps.Users.SetEmailVerified(ctx, u.ID, now); err == nil { + u.EmailVerifiedAt = &now + } + } + return u, nil +} diff --git a/service_reset.go b/service_reset.go new file mode 100644 index 0000000..a2ebbb6 --- /dev/null +++ b/service_reset.go @@ -0,0 +1,69 @@ +package authkit + +import ( + "context" + "errors" + + "git.juancwu.dev/juancwu/errx" +) + +// RequestPasswordReset mints a single-use password-reset token for the user +// behind email and returns the plaintext for the caller to deliver via email. +// Returns ErrUserNotFound when the email isn't registered (per project +// policy of distinct errors over anti-enumeration). +func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, error) { + const op = "authkit.Auth.RequestPasswordReset" + u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + return "", errx.Wrap(op, err) + } + plaintext, hash, err := mintSecret(prefixPasswordRset, a.cfg.Random) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + t := &Token{ + Hash: hash, + Kind: TokenPasswordReset, + UserID: u.ID, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.PasswordResetTTL), + } + if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return plaintext, nil +} + +// ConfirmPasswordReset consumes the reset token, sets the new password, +// bumps the user's session_version, and revokes outstanding sessions so the +// reset constitutes a global logout. +func (a *Auth) ConfirmPasswordReset(ctx context.Context, plaintextToken, newPassword string) error { + const op = "authkit.Auth.ConfirmPasswordReset" + hash, ok := parseSecret(prefixPasswordRset, plaintextToken) + if !ok { + return errx.Wrap(op, ErrTokenInvalid) + } + now := a.now() + t, err := a.deps.Tokens.ConsumeToken(ctx, TokenPasswordReset, hash, now) + if err != nil { + return errx.Wrap(op, err) + } + newHash, err := a.deps.Hasher.Hash(newPassword) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Users.SetPassword(ctx, t.UserID, newHash); err != nil { + if errors.Is(err, ErrUserNotFound) { + return errx.Wrap(op, ErrUserNotFound) + } + return errx.Wrap(op, err) + } + if _, err := a.deps.Users.BumpSessionVersion(ctx, t.UserID); err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Sessions.DeleteUserSessions(ctx, t.UserID); err != nil { + return errx.Wrap(op, err) + } + return nil +} diff --git a/service_session.go b/service_session.go new file mode 100644 index 0000000..a6573d6 --- /dev/null +++ b/service_session.go @@ -0,0 +1,154 @@ +package authkit + +import ( + "context" + "net/http" + "net/netip" + "time" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// IssueSession mints an opaque session ID, persists the session record, and +// returns the plaintext (for the cookie) plus the stored Session. +func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent string, ip netip.Addr) (string, *Session, error) { + const op = "authkit.Auth.IssueSession" + plaintext, hash, err := mintSecret(prefixSession, a.cfg.Random) + if err != nil { + return "", nil, errx.Wrap(op, err) + } + now := a.now() + expires := now.Add(a.cfg.SessionIdleTTL) + if cap := now.Add(a.cfg.SessionAbsoluteTTL); expires.After(cap) { + expires = cap + } + s := &Session{ + IDHash: hash, + UserID: userID, + UserAgent: userAgent, + IP: ip, + CreatedAt: now, + LastSeenAt: now, + ExpiresAt: expires, + } + if err := a.deps.Sessions.CreateSession(ctx, s); err != nil { + return "", nil, errx.Wrap(op, err) + } + return plaintext, s, nil +} + +// AuthenticateSession validates an opaque session string, slides the TTL, +// resolves the user's roles+permissions, and returns a Principal. Expired or +// unknown sessions return ErrSessionInvalid. +func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Principal, error) { + const op = "authkit.Auth.AuthenticateSession" + hash, ok := parseSecret(prefixSession, plaintext) + if !ok { + return nil, errx.Wrap(op, ErrSessionInvalid) + } + s, err := a.deps.Sessions.GetSession(ctx, hash) + if err != nil { + return nil, errx.Wrap(op, err) + } + now := a.now() + if !s.ExpiresAt.After(now) { + _ = a.deps.Sessions.DeleteSession(ctx, hash) + return nil, errx.Wrap(op, ErrSessionInvalid) + } + + // Slide the idle TTL, capped at created_at + AbsoluteTTL so an active + // session still expires at the absolute boundary. + newExpires := now.Add(a.cfg.SessionIdleTTL) + if cap := s.CreatedAt.Add(a.cfg.SessionAbsoluteTTL); newExpires.After(cap) { + newExpires = cap + } + if err := a.deps.Sessions.TouchSession(ctx, hash, now, newExpires); err != nil { + return nil, errx.Wrap(op, err) + } + + roles, perms, err := a.resolveRolesAndPermissions(ctx, s.UserID) + if err != nil { + return nil, errx.Wrap(op, err) + } + return &Principal{ + UserID: s.UserID, + Method: AuthMethodSession, + SessionID: hash, + Roles: roles, + Permissions: perms, + IssuedAt: s.CreatedAt, + ExpiresAt: newExpires, + }, nil +} + +// RevokeSession deletes a single session by its plaintext id. Idempotent: +// missing sessions are not an error (logout twice should not 500). +func (a *Auth) RevokeSession(ctx context.Context, plaintext string) error { + const op = "authkit.Auth.RevokeSession" + hash, ok := parseSecret(prefixSession, plaintext) + if !ok { + return nil + } + if err := a.deps.Sessions.DeleteSession(ctx, hash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RevokeAllUserSessions kills every active session for the user and bumps +// the user's session_version (invalidating outstanding JWT access tokens). +func (a *Auth) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.Auth.RevokeAllUserSessions" + if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// SessionCookie builds an *http.Cookie pre-configured from Config. Pass the +// plaintext returned by IssueSession; pass the matching ExpiresAt from the +// returned *Session as `expires`. To clear a cookie at logout, pass an empty +// plaintext and a past expiry. +func (a *Auth) SessionCookie(plaintext string, expires time.Time) *http.Cookie { + c := &http.Cookie{ + Name: a.cfg.SessionCookieName, + Value: plaintext, + Path: a.cfg.SessionCookiePath, + Domain: a.cfg.SessionCookieDomain, + Secure: a.cfg.SessionCookieSecure, + HttpOnly: a.cfg.SessionCookieHTTPOnly, + SameSite: a.cfg.SessionCookieSameSite, + Expires: expires, + } + if plaintext == "" { + c.MaxAge = -1 + } + return c +} + +// resolveRolesAndPermissions fetches the user's role names and the union of +// their permission names. Both are returned as flat string slices for cheap +// containment checks on the Principal. +func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) { + roles, err := a.deps.Roles.GetUserRoles(ctx, userID) + if err != nil { + return nil, nil, err + } + perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID) + if err != nil { + return nil, nil, err + } + rNames := make([]string, len(roles)) + for i, r := range roles { + rNames[i] = r.Name + } + pNames := make([]string, len(perms)) + for i, p := range perms { + pNames[i] = p.Name + } + return rNames, pNames, nil +} diff --git a/service_test.go b/service_test.go new file mode 100644 index 0000000..868ec1b --- /dev/null +++ b/service_test.go @@ -0,0 +1,220 @@ +package authkit + +import ( + "context" + "errors" + "net/netip" + "testing" +) + +func TestRegisterAndLogin(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "Alice@Example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + if u.EmailNormalized != "alice@example.com" { + t.Fatalf("email_normalized = %q", u.EmailNormalized) + } + got, err := a.LoginPassword(context.Background(), "alice@example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("LoginPassword: %v", err) + } + if got.ID != u.ID { + t.Fatalf("login user mismatch") + } +} + +func TestRegisterDuplicateEmail(t *testing.T) { + a := newTestAuth(t) + if _, err := a.Register(context.Background(), "x@y.com", "abc"); err != nil { + t.Fatalf("Register: %v", err) + } + _, err := a.Register(context.Background(), "X@Y.COM", "abc") + if !errors.Is(err, ErrEmailTaken) { + t.Fatalf("expected ErrEmailTaken, got %v", err) + } +} + +func TestLoginWrongPassword(t *testing.T) { + a := newTestAuth(t) + if _, err := a.Register(context.Background(), "p@q.com", "right"); err != nil { + t.Fatalf("Register: %v", err) + } + if _, err := a.LoginPassword(context.Background(), "p@q.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) { + t.Fatalf("expected ErrInvalidCredentials, got %v", err) + } +} + +func TestSessionIssueAuthenticateRevoke(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "s@s.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plain, sess, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{}) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + if sess == nil || plain == "" { + t.Fatalf("missing session or plaintext") + } + + p, err := a.AuthenticateSession(context.Background(), plain) + if err != nil { + t.Fatalf("AuthenticateSession: %v", err) + } + if p.UserID != u.ID { + t.Fatalf("principal user id mismatch") + } + if err := a.RevokeSession(context.Background(), plain); err != nil { + t.Fatalf("RevokeSession: %v", err) + } + if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) { + t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err) + } +} + +func TestEmailVerificationFlow(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "ev@e.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + tok, err := a.RequestEmailVerification(context.Background(), u.ID) + if err != nil { + t.Fatalf("RequestEmailVerification: %v", err) + } + confirmed, err := a.ConfirmEmail(context.Background(), tok) + if err != nil { + t.Fatalf("ConfirmEmail: %v", err) + } + if confirmed.EmailVerifiedAt == nil { + t.Fatalf("email_verified_at not set") + } + // Re-using the token must fail. + if _, err := a.ConfirmEmail(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err) + } +} + +func TestPasswordResetFlow(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "r@r.com", "old") + if err != nil { + t.Fatalf("Register: %v", err) + } + // Issue a session that should be invalidated by the reset. + plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{}) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + tok, err := a.RequestPasswordReset(context.Background(), "r@r.com") + if err != nil { + t.Fatalf("RequestPasswordReset: %v", err) + } + if err := a.ConfirmPasswordReset(context.Background(), tok, "new"); err != nil { + t.Fatalf("ConfirmPasswordReset: %v", err) + } + if _, err := a.LoginPassword(context.Background(), "r@r.com", "old"); !errors.Is(err, ErrInvalidCredentials) { + t.Fatalf("old password should fail post-reset, got %v", err) + } + if _, err := a.LoginPassword(context.Background(), "r@r.com", "new"); err != nil { + t.Fatalf("new password should work post-reset, got %v", err) + } + if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) { + t.Fatalf("session should be invalidated by reset, got %v", err) + } +} + +func TestMagicLinkFlow(t *testing.T) { + a := newTestAuth(t) + if _, err := a.Register(context.Background(), "m@m.com", "pw"); err != nil { + t.Fatalf("Register: %v", err) + } + tok, err := a.RequestMagicLink(context.Background(), "m@m.com") + if err != nil { + t.Fatalf("RequestMagicLink: %v", err) + } + u, err := a.ConsumeMagicLink(context.Background(), tok) + if err != nil { + t.Fatalf("ConsumeMagicLink: %v", err) + } + if u.EmailVerifiedAt == nil { + t.Fatalf("magic link should imply email verification") + } + if _, err := a.ConsumeMagicLink(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on magic link reuse, got %v", err) + } +} + +func TestAPIKeyFlowWithAbilities(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "k@k.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plaintext, k, err := a.IssueAPIKey(context.Background(), u.ID, "ci", []string{"billing:read", "users:list"}, nil) + if err != nil { + t.Fatalf("IssueAPIKey: %v", err) + } + if k == nil || plaintext == "" { + t.Fatalf("missing api key") + } + p, err := a.AuthenticateAPIKey(context.Background(), plaintext) + if err != nil { + t.Fatalf("AuthenticateAPIKey: %v", err) + } + if !p.HasAbility("billing:read") || !p.HasAbility("users:list") { + t.Fatalf("abilities missing on principal: %+v", p.Abilities) + } + if p.HasAbility("admin:nuke") { + t.Fatalf("unexpected ability granted") + } + if err := a.RevokeAPIKey(context.Background(), plaintext); err != nil { + t.Fatalf("RevokeAPIKey: %v", err) + } + if _, err := a.AuthenticateAPIKey(context.Background(), plaintext); !errors.Is(err, ErrAPIKeyInvalid) { + t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err) + } +} + +func TestRBACRolesAndPermissions(t *testing.T) { + ctx := context.Background() + a := newTestAuth(t) + u, err := a.Register(ctx, "rb@a.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + + // Create role + permission, hook them up. + role := &Role{Name: "editor"} + if err := a.deps.Roles.CreateRole(ctx, role); err != nil { + t.Fatalf("CreateRole: %v", err) + } + perm := &Permission{Name: "posts:write"} + if err := a.deps.Permissions.CreatePermission(ctx, perm); err != nil { + t.Fatalf("CreatePermission: %v", err) + } + if err := a.deps.Permissions.AssignPermissionToRole(ctx, role.ID, perm.ID); err != nil { + t.Fatalf("AssignPermissionToRole: %v", err) + } + if err := a.AssignRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("AssignRole: %v", err) + } + ok, err := a.HasPermission(ctx, u.ID, "posts:write") + if err != nil || !ok { + t.Fatalf("HasPermission posts:write should be true, got %v %v", ok, err) + } + ok, err = a.HasRole(ctx, u.ID, "editor") + if err != nil || !ok { + t.Fatalf("HasRole editor should be true, got %v %v", ok, err) + } + if err := a.RemoveRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("RemoveRole: %v", err) + } + ok, _ = a.HasPermission(ctx, u.ID, "posts:write") + if ok { + t.Fatalf("HasPermission should be false after RemoveRole") + } +} diff --git a/service_user.go b/service_user.go new file mode 100644 index 0000000..b888897 --- /dev/null +++ b/service_user.go @@ -0,0 +1,178 @@ +package authkit + +import ( + "context" + "errors" + "strings" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// normalizeEmail produces the lookup form used by UserStore.GetUserByEmail +// and the email_normalized column. Trim + lowercase is intentional; we do +// not collapse Gmail-style "+" addressing or strip dots — that's a policy +// decision callers can layer on top. +func normalizeEmail(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} + +// Register creates a new user with an Argon2id-hashed password. Returns +// ErrEmailTaken if the normalized email is already registered. +func (a *Auth) Register(ctx context.Context, email, password string) (*User, error) { + const op = "authkit.Auth.Register" + if email == "" || password == "" { + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + hash, err := a.deps.Hasher.Hash(password) + if err != nil { + return nil, errx.Wrap(op, err) + } + now := a.now() + u := &User{ + ID: uuid.New(), + Email: email, + EmailNormalized: normalizeEmail(email), + PasswordHash: hash, + CreatedAt: now, + UpdatedAt: now, + } + if err := a.deps.Users.CreateUser(ctx, u); err != nil { + return nil, errx.Wrap(op, err) + } + return u, nil +} + +// LoginPassword verifies the password and returns the authenticated user. +// Failure increments failed_logins; success resets it and stamps last_login_at. +// LoginHook (if configured) is invoked with the success outcome — use this to +// hook in rate limiting or audit logging. +func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User, error) { + const op = "authkit.Auth.LoginPassword" + u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + _ = a.fireLoginHook(ctx, email, false) + if errors.Is(err, ErrUserNotFound) { + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + return nil, errx.Wrap(op, err) + } + if u.PasswordHash == "" { + _ = a.fireLoginHook(ctx, email, false) + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + ok, needsRehash, err := a.deps.Hasher.Verify(password, u.PasswordHash) + if err != nil { + return nil, errx.Wrap(op, err) + } + if !ok { + _, _ = a.deps.Users.IncrementFailedLogins(ctx, u.ID) + _ = a.fireLoginHook(ctx, email, false) + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + + now := a.now() + u.LastLoginAt = &now + u.FailedLogins = 0 + if err := a.deps.Users.ResetFailedLogins(ctx, u.ID); err != nil { + return nil, errx.Wrap(op, err) + } + if err := a.deps.Users.UpdateUser(ctx, u); err != nil { + return nil, errx.Wrap(op, err) + } + + if needsRehash { + if newHash, herr := a.deps.Hasher.Hash(password); herr == nil { + _ = a.deps.Users.SetPassword(ctx, u.ID, newHash) + u.PasswordHash = newHash + } + } + _ = a.fireLoginHook(ctx, email, true) + return u, nil +} + +// ChangePassword verifies the current password, sets the new one, and bumps +// the user's session_version so all outstanding JWT access tokens are +// instantly invalidated. Outstanding opaque sessions are also revoked. +func (a *Auth) ChangePassword(ctx context.Context, userID uuid.UUID, oldPassword, newPassword string) error { + const op = "authkit.Auth.ChangePassword" + u, err := a.deps.Users.GetUserByID(ctx, userID) + if err != nil { + return errx.Wrap(op, err) + } + if u.PasswordHash == "" { + return errx.Wrap(op, ErrInvalidCredentials) + } + ok, _, err := a.deps.Hasher.Verify(oldPassword, u.PasswordHash) + if err != nil { + return errx.Wrap(op, err) + } + if !ok { + return errx.Wrap(op, ErrInvalidCredentials) + } + newHash, err := a.deps.Hasher.Hash(newPassword) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Users.SetPassword(ctx, userID, newHash); err != nil { + return errx.Wrap(op, err) + } + if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RequestEmailVerification mints a single-use email-verify token for the +// user. Return the plaintext to the caller so they can put it in an email +// link; the lookup hash is stored in TokenStore. +func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (string, error) { + const op = "authkit.Auth.RequestEmailVerification" + plaintext, hash, err := mintSecret(prefixEmailVerify, a.cfg.Random) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + t := &Token{ + Hash: hash, + Kind: TokenEmailVerify, + UserID: userID, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.EmailVerifyTTL), + } + if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return plaintext, nil +} + +// ConfirmEmail consumes the verification token and marks the user's email +// verified. Returns ErrTokenInvalid if the token is missing/expired/used. +func (a *Auth) ConfirmEmail(ctx context.Context, plaintextToken string) (*User, error) { + const op = "authkit.Auth.ConfirmEmail" + hash, ok := parseSecret(prefixEmailVerify, plaintextToken) + if !ok { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + now := a.now() + t, err := a.deps.Tokens.ConsumeToken(ctx, TokenEmailVerify, hash, now) + if err != nil { + return nil, errx.Wrap(op, err) + } + if err := a.deps.Users.SetEmailVerified(ctx, t.UserID, now); err != nil { + return nil, errx.Wrap(op, err) + } + return a.deps.Users.GetUserByID(ctx, t.UserID) +} + +// fireLoginHook is a thin wrapper that suppresses panics from caller-supplied +// hooks; we never want a misbehaving telemetry hook to break login. +func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) error { + if a.cfg.LoginHook == nil { + return nil + } + return a.cfg.LoginHook(ctx, email, success) +} diff --git a/sqlstore/apikeys.go b/sqlstore/apikeys.go new file mode 100644 index 0000000..1212c20 --- /dev/null +++ b/sqlstore/apikeys.go @@ -0,0 +1,124 @@ +package sqlstore + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type apiKeyStore struct{ storeBase } + +func (s *apiKeyStore) CreateAPIKey(ctx context.Context, k *authkit.APIKey) error { + const op = "authkit.sqlstore.APIKeyStore.CreateAPIKey" + if k.CreatedAt.IsZero() { + k.CreatedAt = time.Now().UTC() + } + if k.Abilities == nil { + k.Abilities = []string{} + } + abilities, err := json.Marshal(k.Abilities) + if err != nil { + return errx.Wrap(op, err) + } + _, err = s.db.ExecContext(ctx, s.q.CreateAPIKey, + k.IDHash, uuidArg(k.OwnerID), k.Name, abilities, + nullableTime(k.LastUsedAt), k.CreatedAt, + nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt)) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *apiKeyStore) GetAPIKey(ctx context.Context, idHash []byte) (*authkit.APIKey, error) { + const op = "authkit.sqlstore.APIKeyStore.GetAPIKey" + k, err := scanAPIKey(s.db.QueryRowContext(ctx, s.q.GetAPIKey, idHash)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrAPIKeyInvalid)) + } + return k, nil +} + +func (s *apiKeyStore) ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*authkit.APIKey, error) { + const op = "authkit.sqlstore.APIKeyStore.ListAPIKeysByOwner" + rows, err := s.db.QueryContext(ctx, s.q.ListAPIKeysByOwner, uuidArg(ownerID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.APIKey + for rows.Next() { + k, err := scanAPIKey(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, k) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *apiKeyStore) TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error { + const op = "authkit.sqlstore.APIKeyStore.TouchAPIKey" + if _, err := s.db.ExecContext(ctx, s.q.TouchAPIKey, at, idHash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *apiKeyStore) RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error { + const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKey" + tag, err := s.db.ExecContext(ctx, s.q.RevokeAPIKey, at, idHash) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrAPIKeyInvalid) + } + return nil +} + +func (s *apiKeyStore) RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error { + const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKeysByOwner" + if _, err := s.db.ExecContext(ctx, s.q.RevokeAPIKeysByOwner, at, uuidArg(ownerID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func scanAPIKey(row rowScanner) (*authkit.APIKey, error) { + var ( + k authkit.APIKey + ownerIDStr string + abilitiesRaw []byte + lastUsed sql.NullTime + expires sql.NullTime + revoked sql.NullTime + ) + if err := row.Scan(&k.IDHash, &ownerIDStr, &k.Name, &abilitiesRaw, + &lastUsed, &k.CreatedAt, &expires, &revoked); err != nil { + return nil, err + } + owner, err := scanUUID(ownerIDStr) + if err != nil { + return nil, err + } + k.OwnerID = owner + if len(abilitiesRaw) > 0 { + if err := json.Unmarshal(abilitiesRaw, &k.Abilities); err != nil { + return nil, err + } + } + if k.Abilities == nil { + k.Abilities = []string{} + } + k.LastUsedAt = scanNullTimePtr(lastUsed) + k.ExpiresAt = scanNullTimePtr(expires) + k.RevokedAt = scanNullTimePtr(revoked) + return &k, nil +} diff --git a/sqlstore/dialect.go b/sqlstore/dialect.go new file mode 100644 index 0000000..4b6c3d8 --- /dev/null +++ b/sqlstore/dialect.go @@ -0,0 +1,118 @@ +package sqlstore + +import ( + "context" + "database/sql" + "io/fs" +) + +// Dialect describes how a particular SQL backend renders the queries authkit +// runs at runtime, bootstraps its session for migrations, and reports +// driver-specific errors. v1 ships the Postgres dialect; future MySQL or +// SQLite implementations satisfy the same interface without changes to the +// store code. +type Dialect interface { + // Name returns a stable short identifier ("postgres", "mysql", ...). + Name() string + + // BuildQueries renders every templated SQL string from a validated + // Schema. Called once at New() and at Migrate() so identifiers and + // placeholder styles are baked in by the time stores execute queries. + BuildQueries(s Schema) Queries + + // Bootstrap runs once per Migrate() before the migration lock is taken. + // It is the place for things that must happen outside any transaction + // (CREATE EXTENSION on Postgres, for example). Must be idempotent and + // may be a no-op. + Bootstrap(ctx context.Context, db *sql.DB) error + + // AcquireMigrationLock takes a session-scoped lock on conn so concurrent + // migrators serialise. The returned release function never returns an + // error — implementations may log internally. + AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error) + + // Migrations returns the dialect's embedded .sql files, rooted such + // that fs.ReadDir(".") lists them lex-sorted by filename. + Migrations() fs.FS + + // IsUniqueViolation maps a duplicate-key error from this driver to true + // so insert paths can return the matching authkit sentinel without + // depending on driver internals. + IsUniqueViolation(err error) bool + + // Placeholder returns the placeholder for parameter index n (1-based) + // in this dialect — "$1" for Postgres, "?" for MySQL/SQLite. + Placeholder(n int) string + + // PlaceholderList returns a comma-separated placeholder list for `count` + // parameters starting at position `start` (1-based), suitable for + // dynamic IN-clauses. The second return is the same length, prefilled + // with nil so the caller can append actual values. + PlaceholderList(start, count int) string +} + +// Queries is the full set of statement templates authkit issues. Field names +// match the store method that consumes the query. +type Queries struct { + // users + CreateUser string + GetUserByID string + GetUserByEmail string + UpdateUser string + DeleteUser string + SetPassword string + SetEmailVerified string + BumpSessionVersion string + IncrementFailedLogins string + ResetFailedLogins string + + // sessions + CreateSession string + GetSession string + TouchSession string + DeleteSession string + DeleteUserSessions string + DeleteExpiredSessions string + + // tokens + CreateToken string + ConsumeToken string + GetToken string + DeleteByChain string + DeleteExpiredTokens string + + // api keys + CreateAPIKey string + GetAPIKey string + ListAPIKeysByOwner string + TouchAPIKey string + RevokeAPIKey string + RevokeAPIKeysByOwner string + + // roles + CreateRole string + GetRoleByID string + GetRoleByName string + ListRoles string + DeleteRole string + AssignRoleToUser string + RemoveRoleFromUser string + GetUserRoles string + // HasAnyRole is built at call time because the placeholder count varies. + + // permissions + CreatePermission string + GetPermissionByID string + GetPermissionByName string + ListPermissions string + DeletePermission string + AssignPermissionToRole string + RemovePermissionFromRole string + GetRolePermissions string + GetUserPermissions string + + // migrations + CreateMigrationsTable string + SelectAppliedVersions string + InsertAppliedVersion string +} diff --git a/sqlstore/dialect/postgres/errors.go b/sqlstore/dialect/postgres/errors.go new file mode 100644 index 0000000..d4e93ad --- /dev/null +++ b/sqlstore/dialect/postgres/errors.go @@ -0,0 +1,33 @@ +package postgres + +import ( + "errors" + + "github.com/jackc/pgx/v5/pgconn" +) + +// pgUniqueViolation is the SQLSTATE for unique_violation. Both pgx-stdlib +// and lib/pq surface this code, but only pgx-stdlib uses *pgconn.PgError. +// lib/pq uses *pq.Error which has a Code field of the same value. +const pgUniqueViolation = "23505" + +// isUniqueViolation inspects err for a Postgres unique-violation, regardless +// of which driver registered the connection. We match on either the pgx +// error type or any error implementing a Code() string method (lib/pq's +// pq.Error has SQLState and Code fields; we check via reflection-free +// duck-typing through an interface). +func isUniqueViolation(err error) bool { + if err == nil { + return false + } + var pgxErr *pgconn.PgError + if errors.As(err, &pgxErr) { + return pgxErr.Code == pgUniqueViolation + } + type sqlStater interface{ SQLState() string } + var s sqlStater + if errors.As(err, &s) { + return s.SQLState() == pgUniqueViolation + } + return false +} diff --git a/sqlstore/dialect/postgres/migrations/0001_init.sql b/sqlstore/dialect/postgres/migrations/0001_init.sql new file mode 100644 index 0000000..e8f751d --- /dev/null +++ b/sqlstore/dialect/postgres/migrations/0001_init.sql @@ -0,0 +1,99 @@ +-- 0001_init.sql +-- Initial authkit schema for Postgres. Tables are prefixed authkit_ so the +-- library can be embedded in an existing application database. Each +-- migration owns its own transaction and inserts its version row at the +-- bottom; the runner only orchestrates file discovery and concurrency. + +BEGIN; + +CREATE TABLE IF NOT EXISTS authkit_schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_users ( + id UUID PRIMARY KEY, + email TEXT NOT NULL, + email_normalized TEXT NOT NULL, + email_verified_at TIMESTAMPTZ, + password_hash TEXT, + session_version INTEGER NOT NULL DEFAULT 0, + failed_logins INTEGER NOT NULL DEFAULT 0, + last_login_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq + ON authkit_users (email_normalized); + +CREATE TABLE IF NOT EXISTS authkit_sessions ( + id_hash BYTEA PRIMARY KEY, + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + user_agent TEXT NOT NULL DEFAULT '', + ip TEXT, + created_at TIMESTAMPTZ NOT NULL, + last_seen_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ NOT NULL +); +CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id); +CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at); + +CREATE TABLE IF NOT EXISTS authkit_tokens ( + hash BYTEA NOT NULL, + kind TEXT NOT NULL, + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + chain_id TEXT, + consumed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (kind, hash) +); +CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id); +CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at); +CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx + ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL; + +CREATE TABLE IF NOT EXISTS authkit_api_keys ( + id_hash BYTEA PRIMARY KEY, + owner_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + name TEXT NOT NULL, + abilities JSONB NOT NULL DEFAULT '[]'::jsonb, + last_used_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ, + revoked_at TIMESTAMPTZ +); +CREATE INDEX IF NOT EXISTS authkit_api_keys_owner_id_idx ON authkit_api_keys(owner_id); + +CREATE TABLE IF NOT EXISTS authkit_roles ( + id UUID PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_permissions ( + id UUID PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_role_permissions ( + role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, + permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE, + PRIMARY KEY (role_id, permission_id) +); + +CREATE TABLE IF NOT EXISTS authkit_user_roles ( + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, + granted_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (user_id, role_id) +); +CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id); + +INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now()) +ON CONFLICT (version) DO NOTHING; + +COMMIT; diff --git a/sqlstore/dialect/postgres/postgres.go b/sqlstore/dialect/postgres/postgres.go new file mode 100644 index 0000000..56ec880 --- /dev/null +++ b/sqlstore/dialect/postgres/postgres.go @@ -0,0 +1,272 @@ +// Package postgres is the Postgres dialect for authkit/sqlstore. Importing +// it does not register a driver — callers do `_ "github.com/jackc/pgx/v5/stdlib"` +// or `_ "github.com/lib/pq"` themselves and then `sql.Open(...)`. +package postgres + +import ( + "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "log" + "strings" + + "git.juancwu.dev/juancwu/authkit/sqlstore" +) + +//go:embed migrations/*.sql +var migrationsFS embed.FS + +// Dialect implements sqlstore.Dialect for Postgres. The zero value is the +// only required form; New() returns a pointer for clarity at call sites. +type Dialect struct{} + +// New returns a Postgres dialect ready to pass to sqlstore.New / Migrate. +func New() *Dialect { return &Dialect{} } + +func (Dialect) Name() string { return "postgres" } + +// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. Stable +// across rollouts and unlikely to clash with caller advisory locks. +const advisoryLockKey int64 = 0x617574686b6974 + +func (Dialect) Bootstrap(ctx context.Context, db *sql.DB) error { + // Nothing to do — schema avoids gen_random_uuid()/pgcrypto. Kept as a + // hook for future migration prerequisites. + return nil +} + +func (Dialect) AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (func(), error) { + if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil { + return nil, err + } + release := func() { + if _, err := conn.ExecContext(context.Background(), + "SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil { + log.Printf("authkit/postgres: pg_advisory_unlock failed: %v", err) + } + } + return release, nil +} + +func (Dialect) Migrations() fs.FS { + sub, err := fs.Sub(migrationsFS, "migrations") + if err != nil { + // migrationsFS is statically populated; this can only fail if the + // embed directive is removed, which would be a build-time error. + panic(err) + } + return sub +} + +func (Dialect) IsUniqueViolation(err error) bool { return isUniqueViolation(err) } + +func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) } + +// PlaceholderList renders a comma-separated `$start,$start+1,...` list of +// `count` placeholders. Used by HasAnyRole's dynamic IN-clause expansion. +func (Dialect) PlaceholderList(start, count int) string { + if count <= 0 { + return "" + } + var b strings.Builder + for i := 0; i < count; i++ { + if i > 0 { + b.WriteByte(',') + } + fmt.Fprintf(&b, "$%d", start+i) + } + return b.String() +} + +// BuildQueries renders every query authkit issues, with table identifiers +// taken from s and `?` placeholders rewritten to `$N`. Identifiers are +// already validated by Schema.Validate — this is interpolated with +// fmt.Sprintf, so the validation gate is load-bearing. +func (Dialect) BuildQueries(s sqlstore.Schema) sqlstore.Queries { + t := s.Tables + q := sqlstore.Queries{ + // users + CreateUser: `INSERT INTO ` + t.Users + ` + (id, email, email_normalized, email_verified_at, password_hash, + session_version, failed_logins, last_login_at, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + GetUserByID: `SELECT id, email, email_normalized, email_verified_at, + password_hash, session_version, failed_logins, last_login_at, + created_at, updated_at FROM ` + t.Users + ` WHERE id = ?`, + GetUserByEmail: `SELECT id, email, email_normalized, email_verified_at, + password_hash, session_version, failed_logins, last_login_at, + created_at, updated_at FROM ` + t.Users + ` WHERE email_normalized = ?`, + UpdateUser: `UPDATE ` + t.Users + ` SET + email = ?, email_normalized = ?, email_verified_at = ?, + password_hash = ?, session_version = ?, failed_logins = ?, + last_login_at = ?, updated_at = ? + WHERE id = ?`, + DeleteUser: `DELETE FROM ` + t.Users + ` WHERE id = ?`, + SetPassword: `UPDATE ` + t.Users + ` SET password_hash = ?, updated_at = ? WHERE id = ?`, + SetEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = ?, updated_at = ? WHERE id = ?`, + BumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = ? WHERE id = ? RETURNING session_version`, + IncrementFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = failed_logins + 1, updated_at = ? WHERE id = ? RETURNING failed_logins`, + ResetFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = 0, updated_at = ? WHERE id = ?`, + + // sessions + CreateSession: `INSERT INTO ` + t.Sessions + ` + (id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + GetSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at + FROM ` + t.Sessions + ` WHERE id_hash = ?`, + TouchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = ?, expires_at = ? WHERE id_hash = ?`, + DeleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = ?`, + DeleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = ?`, + DeleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= ?`, + + // tokens + CreateToken: `INSERT INTO ` + t.Tokens + ` + (hash, kind, user_id, chain_id, consumed_at, created_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + ConsumeToken: `UPDATE ` + t.Tokens + ` + SET consumed_at = ? + WHERE kind = ? AND hash = ? AND consumed_at IS NULL AND expires_at > ? + RETURNING hash, kind, user_id, chain_id, consumed_at, created_at, expires_at`, + GetToken: `SELECT hash, kind, user_id, chain_id, consumed_at, created_at, expires_at + FROM ` + t.Tokens + ` WHERE kind = ? AND hash = ?`, + DeleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = ?`, + DeleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= ?`, + + // api keys + CreateAPIKey: `INSERT INTO ` + t.APIKeys + ` + (id_hash, owner_id, name, abilities, last_used_at, created_at, expires_at, revoked_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + GetAPIKey: `SELECT id_hash, owner_id, name, abilities, last_used_at, + created_at, expires_at, revoked_at + FROM ` + t.APIKeys + ` WHERE id_hash = ?`, + ListAPIKeysByOwner: `SELECT id_hash, owner_id, name, abilities, last_used_at, + created_at, expires_at, revoked_at + FROM ` + t.APIKeys + ` WHERE owner_id = ? ORDER BY created_at DESC`, + TouchAPIKey: `UPDATE ` + t.APIKeys + ` SET last_used_at = ? WHERE id_hash = ?`, + RevokeAPIKey: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE id_hash = ? AND revoked_at IS NULL`, + RevokeAPIKeysByOwner: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE owner_id = ? AND revoked_at IS NULL`, + + // roles + CreateRole: `INSERT INTO ` + t.Roles + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`, + GetRoleByID: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE id = ?`, + GetRoleByName: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE name = ?`, + ListRoles: `SELECT id, name, description, created_at FROM ` + t.Roles + ` ORDER BY name`, + DeleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = ?`, + AssignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`, + RemoveRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = ? AND role_id = ?`, + GetUserRoles: `SELECT r.id, r.name, r.description, r.created_at + FROM ` + t.Roles + ` r + JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id + WHERE ur.user_id = ? ORDER BY r.name`, + + // permissions + CreatePermission: `INSERT INTO ` + t.Permissions + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`, + GetPermissionByID: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE id = ?`, + GetPermissionByName: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE name = ?`, + ListPermissions: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` ORDER BY name`, + DeletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = ?`, + AssignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) VALUES (?, ?) + ON CONFLICT DO NOTHING`, + RemovePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = ? AND permission_id = ?`, + GetRolePermissions: `SELECT p.id, p.name, p.description, p.created_at + FROM ` + t.Permissions + ` p + JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id + WHERE rp.role_id = ? ORDER BY p.name`, + GetUserPermissions: `SELECT DISTINCT p.id, p.name, p.description, p.created_at + FROM ` + t.Permissions + ` p + JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id + JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id + WHERE ur.user_id = ? ORDER BY p.name`, + + // migrations + CreateMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL + )`, + SelectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations, + InsertAppliedVersion: `INSERT INTO ` + t.SchemaMigrations + ` (version, applied_at) VALUES (?, ?)`, + } + + // Rewrite `?` placeholders to `$N`. Each query is independent; numbering + // resets per query. + q.CreateUser = rebind(q.CreateUser) + q.GetUserByID = rebind(q.GetUserByID) + q.GetUserByEmail = rebind(q.GetUserByEmail) + q.UpdateUser = rebind(q.UpdateUser) + q.DeleteUser = rebind(q.DeleteUser) + q.SetPassword = rebind(q.SetPassword) + q.SetEmailVerified = rebind(q.SetEmailVerified) + q.BumpSessionVersion = rebind(q.BumpSessionVersion) + q.IncrementFailedLogins = rebind(q.IncrementFailedLogins) + q.ResetFailedLogins = rebind(q.ResetFailedLogins) + + q.CreateSession = rebind(q.CreateSession) + q.GetSession = rebind(q.GetSession) + q.TouchSession = rebind(q.TouchSession) + q.DeleteSession = rebind(q.DeleteSession) + q.DeleteUserSessions = rebind(q.DeleteUserSessions) + q.DeleteExpiredSessions = rebind(q.DeleteExpiredSessions) + + q.CreateToken = rebind(q.CreateToken) + q.ConsumeToken = rebind(q.ConsumeToken) + q.GetToken = rebind(q.GetToken) + q.DeleteByChain = rebind(q.DeleteByChain) + q.DeleteExpiredTokens = rebind(q.DeleteExpiredTokens) + + q.CreateAPIKey = rebind(q.CreateAPIKey) + q.GetAPIKey = rebind(q.GetAPIKey) + q.ListAPIKeysByOwner = rebind(q.ListAPIKeysByOwner) + q.TouchAPIKey = rebind(q.TouchAPIKey) + q.RevokeAPIKey = rebind(q.RevokeAPIKey) + q.RevokeAPIKeysByOwner = rebind(q.RevokeAPIKeysByOwner) + + q.CreateRole = rebind(q.CreateRole) + q.GetRoleByID = rebind(q.GetRoleByID) + q.GetRoleByName = rebind(q.GetRoleByName) + q.ListRoles = rebind(q.ListRoles) + q.DeleteRole = rebind(q.DeleteRole) + q.AssignRoleToUser = rebind(q.AssignRoleToUser) + q.RemoveRoleFromUser = rebind(q.RemoveRoleFromUser) + q.GetUserRoles = rebind(q.GetUserRoles) + + q.CreatePermission = rebind(q.CreatePermission) + q.GetPermissionByID = rebind(q.GetPermissionByID) + q.GetPermissionByName = rebind(q.GetPermissionByName) + q.ListPermissions = rebind(q.ListPermissions) + q.DeletePermission = rebind(q.DeletePermission) + q.AssignPermissionToRole = rebind(q.AssignPermissionToRole) + q.RemovePermissionFromRole = rebind(q.RemovePermissionFromRole) + q.GetRolePermissions = rebind(q.GetRolePermissions) + q.GetUserPermissions = rebind(q.GetUserPermissions) + + q.SelectAppliedVersions = rebind(q.SelectAppliedVersions) + q.InsertAppliedVersion = rebind(q.InsertAppliedVersion) + // CreateMigrationsTable contains no parameters. + + return q +} + +// rebind walks s and replaces each unquoted `?` with $1, $2, ... in order. +// Our query strings contain no string literals that include `?` (verified +// by inspection of every query in BuildQueries); a literal-aware rewriter +// would be more defensive but is not needed for v1. +func rebind(s string) string { + var b strings.Builder + b.Grow(len(s) + 16) + n := 1 + for i := 0; i < len(s); i++ { + c := s[i] + if c == '?' { + fmt.Fprintf(&b, "$%d", n) + n++ + continue + } + b.WriteByte(c) + } + return b.String() +} + +// Compile-time interface compliance check. +var _ sqlstore.Dialect = (*Dialect)(nil) diff --git a/sqlstore/migrate.go b/sqlstore/migrate.go new file mode 100644 index 0000000..48906a0 --- /dev/null +++ b/sqlstore/migrate.go @@ -0,0 +1,116 @@ +package sqlstore + +import ( + "context" + "database/sql" + "io/fs" + "sort" + "strings" + "time" + + "git.juancwu.dev/juancwu/errx" +) + +// Migrate applies every embedded migration the dialect ships that has not +// yet been recorded in the schema-migrations table. It is safe to call +// repeatedly and concurrently across processes — the dialect's session +// lock serialises rollouts. +// +// Each migration .sql file is responsible for owning its own +// BEGIN/COMMIT and inserting a row into the schema-migrations table on +// success. The runner only handles file discovery, version tracking, and +// concurrency. +func Migrate(ctx context.Context, db *sql.DB, dialect Dialect, schema Schema) error { + const op = "authkit.sqlstore.Migrate" + if db == nil { + return errx.New(op, "db is required") + } + if dialect == nil { + return errx.New(op, "dialect is required") + } + if err := schema.Validate(); err != nil { + return errx.Wrap(op, err) + } + + if err := dialect.Bootstrap(ctx, db); err != nil { + return errx.Wrap(op, err) + } + + conn, err := db.Conn(ctx) + if err != nil { + return errx.Wrap(op, err) + } + defer conn.Close() + + release, err := dialect.AcquireMigrationLock(ctx, conn) + if err != nil { + return errx.Wrap(op, err) + } + defer release() + + q := dialect.BuildQueries(schema) + if _, err := conn.ExecContext(ctx, q.CreateMigrationsTable); err != nil { + return errx.Wrap(op, err) + } + + applied, err := loadAppliedVersions(ctx, conn, q.SelectAppliedVersions) + if err != nil { + return errx.Wrap(op, err) + } + + migs := dialect.Migrations() + files, err := fs.ReadDir(migs, ".") + if err != nil { + return errx.Wrap(op, err) + } + names := make([]string, 0, len(files)) + for _, f := range files { + if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") { + names = append(names, f.Name()) + } + } + sort.Strings(names) + + for _, name := range names { + version := strings.TrimSuffix(name, ".sql") + if _, ok := applied[version]; ok { + continue + } + body, err := fs.ReadFile(migs, name) + if err != nil { + return errx.Wrapf(op, err, "read %s", name) + } + if _, err := conn.ExecContext(ctx, string(body)); err != nil { + return errx.Wrapf(op, err, "apply %s", version) + } + } + return nil +} + +// applyVersionRow is intentionally not exposed: migration files own their +// own version-row insert. We keep this helper around in case a dialect ever +// needs to backfill versions from outside a migration body — currently +// unused. +var _ = applyVersionRow + +func applyVersionRow(ctx context.Context, conn *sql.Conn, insertQ, version string, at time.Time) error { + _, err := conn.ExecContext(ctx, insertQ, version, at) + return err +} + +func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) { + rows, err := conn.QueryContext(ctx, q) + if err != nil { + return nil, err + } + defer rows.Close() + out := make(map[string]struct{}) + for rows.Next() { + var v string + if err := rows.Scan(&v); err != nil { + return nil, err + } + out[v] = struct{}{} + } + return out, rows.Err() +} diff --git a/sqlstore/rbac.go b/sqlstore/rbac.go new file mode 100644 index 0000000..767440e --- /dev/null +++ b/sqlstore/rbac.go @@ -0,0 +1,301 @@ +package sqlstore + +import ( + "context" + "fmt" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type roleStore struct{ storeBase } +type permissionStore struct{ storeBase } + +// ----- roleStore ------------------------------------------------------------ + +func (s *roleStore) CreateRole(ctx context.Context, r *authkit.Role) error { + const op = "authkit.sqlstore.RoleStore.CreateRole" + if r.ID == uuid.Nil { + r.ID = uuid.New() + } + if r.CreatedAt.IsZero() { + r.CreatedAt = time.Now().UTC() + } + if _, err := s.db.ExecContext(ctx, s.q.CreateRole, + uuidArg(r.ID), r.Name, r.Description, r.CreatedAt); err != nil { + if s.d.IsUniqueViolation(err) { + return errx.Wrapf(op, err, "role %q already exists", r.Name) + } + return errx.Wrap(op, err) + } + return nil +} + +func (s *roleStore) GetRoleByID(ctx context.Context, id uuid.UUID) (*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.GetRoleByID" + r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound)) + } + return r, nil +} + +func (s *roleStore) GetRoleByName(ctx context.Context, name string) (*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.GetRoleByName" + r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByName, name)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound)) + } + return r, nil +} + +func (s *roleStore) ListRoles(ctx context.Context) ([]*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.ListRoles" + rows, err := s.db.QueryContext(ctx, s.q.ListRoles) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Role + for rows.Next() { + r, err := scanRole(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, r) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *roleStore) DeleteRole(ctx context.Context, id uuid.UUID) error { + const op = "authkit.sqlstore.RoleStore.DeleteRole" + tag, err := s.db.ExecContext(ctx, s.q.DeleteRole, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrRoleNotFound) + } + return nil +} + +func (s *roleStore) AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error { + const op = "authkit.sqlstore.RoleStore.AssignRoleToUser" + if _, err := s.db.ExecContext(ctx, s.q.AssignRoleToUser, + uuidArg(userID), uuidArg(roleID), time.Now().UTC()); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *roleStore) RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error { + const op = "authkit.sqlstore.RoleStore.RemoveRoleFromUser" + if _, err := s.db.ExecContext(ctx, s.q.RemoveRoleFromUser, + uuidArg(userID), uuidArg(roleID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *roleStore) GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.GetUserRoles" + rows, err := s.db.QueryContext(ctx, s.q.GetUserRoles, uuidArg(userID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Role + for rows.Next() { + r, err := scanRole(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, r) + } + return out, errx.Wrap(op, rows.Err()) +} + +// HasAnyRole builds the IN-clause at call time because the placeholder count +// depends on len(names). Identifier substitution comes from the validated +// Schema; values are bound, never interpolated. +func (s *roleStore) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) { + const op = "authkit.sqlstore.RoleStore.HasAnyRole" + if len(names) == 0 { + return false, nil + } + // Placeholder $1 is the user_id; $2..$N+1 cover the names slice. + listSQL := s.d.PlaceholderList(2, len(names)) + q := fmt.Sprintf(`SELECT EXISTS ( + SELECT 1 FROM %s ur JOIN %s r ON r.id = ur.role_id + WHERE ur.user_id = %s AND r.name IN (%s) + )`, s.s.Tables.UserRoles, s.s.Tables.Roles, s.d.Placeholder(1), listSQL) + + args := make([]any, 0, 1+len(names)) + args = append(args, uuidArg(userID)) + for _, n := range names { + args = append(args, n) + } + var ok bool + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// ----- permissionStore ------------------------------------------------------ + +func (s *permissionStore) CreatePermission(ctx context.Context, p *authkit.Permission) error { + const op = "authkit.sqlstore.PermissionStore.CreatePermission" + if p.ID == uuid.Nil { + p.ID = uuid.New() + } + if p.CreatedAt.IsZero() { + p.CreatedAt = time.Now().UTC() + } + if _, err := s.db.ExecContext(ctx, s.q.CreatePermission, + uuidArg(p.ID), p.Name, p.Description, p.CreatedAt); err != nil { + if s.d.IsUniqueViolation(err) { + return errx.Wrapf(op, err, "permission %q already exists", p.Name) + } + return errx.Wrap(op, err) + } + return nil +} + +func (s *permissionStore) GetPermissionByID(ctx context.Context, id uuid.UUID) (*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetPermissionByID" + p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound)) + } + return p, nil +} + +func (s *permissionStore) GetPermissionByName(ctx context.Context, name string) (*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetPermissionByName" + p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByName, name)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound)) + } + return p, nil +} + +func (s *permissionStore) ListPermissions(ctx context.Context) ([]*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.ListPermissions" + rows, err := s.db.QueryContext(ctx, s.q.ListPermissions) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *permissionStore) DeletePermission(ctx context.Context, id uuid.UUID) error { + const op = "authkit.sqlstore.PermissionStore.DeletePermission" + tag, err := s.db.ExecContext(ctx, s.q.DeletePermission, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrPermissionNotFound) + } + return nil +} + +func (s *permissionStore) AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error { + const op = "authkit.sqlstore.PermissionStore.AssignPermissionToRole" + if _, err := s.db.ExecContext(ctx, s.q.AssignPermissionToRole, + uuidArg(roleID), uuidArg(permID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *permissionStore) RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error { + const op = "authkit.sqlstore.PermissionStore.RemovePermissionFromRole" + if _, err := s.db.ExecContext(ctx, s.q.RemovePermissionFromRole, + uuidArg(roleID), uuidArg(permID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *permissionStore) GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetRolePermissions" + rows, err := s.db.QueryContext(ctx, s.q.GetRolePermissions, uuidArg(roleID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *permissionStore) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetUserPermissions" + rows, err := s.db.QueryContext(ctx, s.q.GetUserPermissions, uuidArg(userID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func scanRole(row rowScanner) (*authkit.Role, error) { + var ( + r authkit.Role + idStr string + ) + if err := row.Scan(&idStr, &r.Name, &r.Description, &r.CreatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + r.ID = id + return &r, nil +} + +func scanPermission(row rowScanner) (*authkit.Permission, error) { + var ( + p authkit.Permission + idStr string + ) + if err := row.Scan(&idStr, &p.Name, &p.Description, &p.CreatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + p.ID = id + return &p, nil +} diff --git a/sqlstore/scan.go b/sqlstore/scan.go new file mode 100644 index 0000000..a962535 --- /dev/null +++ b/sqlstore/scan.go @@ -0,0 +1,86 @@ +package sqlstore + +import ( + "database/sql" + "net/netip" + "time" + + "github.com/google/uuid" +) + +// rowScanner is the lowest-common-denominator interface satisfied by both +// *sql.Row and *sql.Rows so scanXxx helpers serve QueryRow and Query loops +// uniformly. +type rowScanner interface { + Scan(dest ...any) error +} + +// nullableTime returns nil when t is the zero time, otherwise &t. Callers +// should usually accept *time.Time on the model side and bind via this. +func nullableTime(t *time.Time) any { + if t == nil { + return nil + } + return *t +} + +// nullableString turns "" into nil so columns store NULL rather than an +// empty string. Used for password hashes and similar optional text columns. +func nullableString(s string) any { + if s == "" { + return nil + } + return s +} + +// nullableAddrString returns the string form of a netip.Addr when valid, or +// nil to bind as SQL NULL. Pairs with scanAddr. +func nullableAddrString(a netip.Addr) any { + if !a.IsValid() { + return nil + } + return a.String() +} + +// scanAddr parses a *string column into a netip.Addr. The zero Addr value is +// returned when the column was NULL or empty. +func scanAddr(s *string) (netip.Addr, error) { + if s == nil || *s == "" { + return netip.Addr{}, nil + } + return netip.ParseAddr(*s) +} + +// uuidArg returns the canonical string form of a UUID for binding. Every +// store uses this rather than passing uuid.UUID directly to keep behaviour +// identical across drivers (some accept driver.Valuer, some don't). +func uuidArg(id uuid.UUID) any { return id.String() } + +// scanUUID reads a string column and parses it back to a uuid.UUID. +func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) } + +// chainArg returns either a *string or nil for binding the chain_id column. +func chainArg(c *string) any { + if c == nil { + return nil + } + return *c +} + +// scanNullStringPtr converts sql.NullString to *string for the model. +func scanNullStringPtr(ns sql.NullString) *string { + if !ns.Valid { + return nil + } + v := ns.String + return &v +} + +// scanNullTimePtr converts sql.NullTime to *time.Time for the model. +func scanNullTimePtr(nt sql.NullTime) *time.Time { + if !nt.Valid { + return nil + } + t := nt.Time + return &t +} diff --git a/sqlstore/schema.go b/sqlstore/schema.go new file mode 100644 index 0000000..cb254e4 --- /dev/null +++ b/sqlstore/schema.go @@ -0,0 +1,78 @@ +package sqlstore + +import ( + "regexp" + + "git.juancwu.dev/juancwu/errx" +) + +// Schema lets consumers map authkit storage to their own table names. +// Column overrides are intentionally not present in v1 — adding them later +// is purely additive. +type Schema struct { + Tables Tables +} + +// Tables is the per-table identifier override set. Every field must be a +// valid unquoted SQL identifier (matching identifierRE). Validation runs at +// New() and Migrate() time so SQL injection through Schema is impossible +// past that gate. +type Tables struct { + Users string + Sessions string + Tokens string + APIKeys string + Roles string + Permissions string + UserRoles string + RolePermissions string + SchemaMigrations string +} + +// DefaultSchema returns the stock authkit_* names used by the embedded +// migration files. +func DefaultSchema() Schema { + return Schema{Tables: Tables{ + Users: "authkit_users", + Sessions: "authkit_sessions", + Tokens: "authkit_tokens", + APIKeys: "authkit_api_keys", + Roles: "authkit_roles", + Permissions: "authkit_permissions", + UserRoles: "authkit_user_roles", + RolePermissions: "authkit_role_permissions", + SchemaMigrations: "authkit_schema_migrations", + }} +} + +// identifierRE matches the safe ASCII identifier subset shared by Postgres, +// MySQL and SQLite when not quoted. Anything outside this set is rejected +// rather than escaped — Schema is not the place to support exotic names. +var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +// Validate ensures every Schema.Tables field is a non-empty, safe identifier. +func (s Schema) Validate() error { + const op = "authkit.sqlstore.Schema.Validate" + checks := []struct { + field, value string + }{ + {"Users", s.Tables.Users}, + {"Sessions", s.Tables.Sessions}, + {"Tokens", s.Tables.Tokens}, + {"APIKeys", s.Tables.APIKeys}, + {"Roles", s.Tables.Roles}, + {"Permissions", s.Tables.Permissions}, + {"UserRoles", s.Tables.UserRoles}, + {"RolePermissions", s.Tables.RolePermissions}, + {"SchemaMigrations", s.Tables.SchemaMigrations}, + } + for _, c := range checks { + if c.value == "" { + return errx.Newf(op, "Schema.Tables.%s is empty", c.field) + } + if !identifierRE.MatchString(c.value) { + return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value) + } + } + return nil +} diff --git a/sqlstore/sessions.go b/sqlstore/sessions.go new file mode 100644 index 0000000..8241bb5 --- /dev/null +++ b/sqlstore/sessions.go @@ -0,0 +1,98 @@ +package sqlstore + +import ( + "context" + "database/sql" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type sessionStore struct{ storeBase } + +func (s *sessionStore) CreateSession(ctx context.Context, ses *authkit.Session) error { + const op = "authkit.sqlstore.SessionStore.CreateSession" + now := time.Now().UTC() + if ses.CreatedAt.IsZero() { + ses.CreatedAt = now + } + if ses.LastSeenAt.IsZero() { + ses.LastSeenAt = ses.CreatedAt + } + _, err := s.db.ExecContext(ctx, s.q.CreateSession, + ses.IDHash, uuidArg(ses.UserID), ses.UserAgent, nullableAddrString(ses.IP), + ses.CreatedAt, ses.LastSeenAt, ses.ExpiresAt) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *sessionStore) GetSession(ctx context.Context, idHash []byte) (*authkit.Session, error) { + const op = "authkit.sqlstore.SessionStore.GetSession" + var ( + ses authkit.Session + uidStr string + ipStr sql.NullString + ) + err := s.db.QueryRowContext(ctx, s.q.GetSession, idHash).Scan( + &ses.IDHash, &uidStr, &ses.UserAgent, &ipStr, + &ses.CreatedAt, &ses.LastSeenAt, &ses.ExpiresAt) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrSessionInvalid)) + } + uid, err := scanUUID(uidStr) + if err != nil { + return nil, errx.Wrap(op, err) + } + ses.UserID = uid + if ipStr.Valid { + addr, err := scanAddr(&ipStr.String) + if err != nil { + return nil, errx.Wrap(op, err) + } + ses.IP = addr + } + return &ses, nil +} + +func (s *sessionStore) TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error { + const op = "authkit.sqlstore.SessionStore.TouchSession" + tag, err := s.db.ExecContext(ctx, s.q.TouchSession, lastSeenAt, newExpiresAt, idHash) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrSessionInvalid) + } + return nil +} + +func (s *sessionStore) DeleteSession(ctx context.Context, idHash []byte) error { + const op = "authkit.sqlstore.SessionStore.DeleteSession" + if _, err := s.db.ExecContext(ctx, s.q.DeleteSession, idHash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *sessionStore) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.sqlstore.SessionStore.DeleteUserSessions" + if _, err := s.db.ExecContext(ctx, s.q.DeleteUserSessions, uuidArg(userID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *sessionStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) { + const op = "authkit.sqlstore.SessionStore.DeleteExpired" + tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredSessions, now) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} diff --git a/sqlstore/sqlstore.go b/sqlstore/sqlstore.go new file mode 100644 index 0000000..a87d9ca --- /dev/null +++ b/sqlstore/sqlstore.go @@ -0,0 +1,59 @@ +// Package sqlstore provides database/sql-backed implementations of every +// authkit store interface. It works with any driver (pgx-stdlib, lib/pq, +// sqlx wrapping *sql.DB, ...) and any consumer-supplied table naming via +// Schema. Driver-specific behaviour lives in a Dialect; v1 ships +// dialect/postgres. +package sqlstore + +import ( + "database/sql" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" +) + +// Stores bundles every store implementation produced by New, ready to drop +// into authkit.Deps. +type Stores struct { + Users authkit.UserStore + Sessions authkit.SessionStore + Tokens authkit.TokenStore + APIKeys authkit.APIKeyStore + Roles authkit.RoleStore + Permissions authkit.PermissionStore +} + +// New validates the Schema, builds the dialect's templated queries, and +// returns store implementations bound to db. +func New(db *sql.DB, dialect Dialect, schema Schema) (*Stores, error) { + const op = "authkit.sqlstore.New" + if db == nil { + return nil, errx.New(op, "db is required") + } + if dialect == nil { + return nil, errx.New(op, "dialect is required") + } + if err := schema.Validate(); err != nil { + return nil, errx.Wrap(op, err) + } + q := dialect.BuildQueries(schema) + + base := storeBase{db: db, q: q, d: dialect, s: schema} + return &Stores{ + Users: &userStore{storeBase: base}, + Sessions: &sessionStore{storeBase: base}, + Tokens: &tokenStore{storeBase: base}, + APIKeys: &apiKeyStore{storeBase: base}, + Roles: &roleStore{storeBase: base}, + Permissions: &permissionStore{storeBase: base}, + }, nil +} + +// storeBase carries the shared dependencies every store needs. Embedded into +// each concrete store struct. +type storeBase struct { + db *sql.DB + q Queries + d Dialect + s Schema +} diff --git a/sqlstore/sqlstore_test.go b/sqlstore/sqlstore_test.go new file mode 100644 index 0000000..8c444b6 --- /dev/null +++ b/sqlstore/sqlstore_test.go @@ -0,0 +1,258 @@ +package sqlstore_test + +// Integration tests against a real Postgres. Skipped unless +// AUTHKIT_TEST_DATABASE_URL is set. Each test acquires a fresh schema by +// running Migrate against a randomly-named set of tables — no cleanup +// fixture, no external Docker dependency. + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/netip" + "os" + "testing" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/authkit/hasher" + "git.juancwu.dev/juancwu/authkit/sqlstore" + pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +func envURL(t *testing.T) string { + t.Helper() + url := os.Getenv("AUTHKIT_TEST_DATABASE_URL") + if url == "" { + t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test") + } + return url +} + +// freshDB opens a connection, runs Migrate against the default schema, and +// schedules a teardown that drops every authkit_* table. Tests must run +// sequentially against a single database (the package's tests do, by +// default — go test serialises within a package unless t.Parallel is +// called). +func freshDB(t *testing.T) (*authkit.Auth, *sql.DB, sqlstore.Schema) { + t.Helper() + url := envURL(t) + db, err := sql.Open("pgx", url) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("ping: %v", err) + } + + schema := sqlstore.DefaultSchema() + // Drop first to start clean — previous failed tests may have left rows. + dropAuthkitTables(t, db, schema) + if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil { + t.Fatalf("Migrate: %v", err) + } + t.Cleanup(func() { dropAuthkitTables(t, db, schema) }) + + stores, err := sqlstore.New(db, pgdialect.New(), schema) + if err != nil { + t.Fatalf("sqlstore.New: %v", err) + } + auth := authkit.New(authkit.Deps{ + Users: stores.Users, + Sessions: stores.Sessions, + Tokens: stores.Tokens, + APIKeys: stores.APIKeys, + Roles: stores.Roles, + Permissions: stores.Permissions, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), + }, authkit.Config{ + JWTSecret: []byte("integration-secret-thirty-two!!!!"), + JWTIssuer: "authkit-int", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: 1 * time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + }) + return auth, db, schema +} + +func dropAuthkitTables(t *testing.T, db *sql.DB, s sqlstore.Schema) { + t.Helper() + tables := []string{ + s.Tables.UserRoles, s.Tables.RolePermissions, + s.Tables.Roles, s.Tables.Permissions, + s.Tables.APIKeys, s.Tables.Tokens, + s.Tables.Sessions, s.Tables.Users, + s.Tables.SchemaMigrations, + } + for _, name := range tables { + _, _ = db.ExecContext(context.Background(), + fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name)) + } +} + +func TestIntegration_MigrateIdempotent(t *testing.T) { + url := envURL(t) + db, err := sql.Open("pgx", url) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + schema := sqlstore.DefaultSchema() + t.Cleanup(func() { dropAuthkitTables(t, db, schema) }) + for i := 0; i < 3; i++ { + if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil { + t.Fatalf("Migrate iter %d: %v", i, err) + } + } +} + +func TestIntegration_RegisterAndLogin(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + got, err := auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("LoginPassword: %v", err) + } + if got.ID != u.ID { + t.Fatalf("login user id mismatch") + } + + if _, err := auth.Register(ctx, "alice@example.com", "x"); !errors.Is(err, authkit.ErrEmailTaken) { + t.Fatalf("expected ErrEmailTaken, got %v", err) + } + if _, err := auth.LoginPassword(ctx, "alice@example.com", "wrong"); !errors.Is(err, authkit.ErrInvalidCredentials) { + t.Fatalf("expected ErrInvalidCredentials, got %v", err) + } +} + +func TestIntegration_SessionLifecycle(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "s@s.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plain, sess, err := auth.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1")) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + if sess.ExpiresAt.Before(time.Now()) { + t.Fatalf("session already expired at issue") + } + if _, err := auth.AuthenticateSession(ctx, plain); err != nil { + t.Fatalf("AuthenticateSession: %v", err) + } + if err := auth.RevokeSession(ctx, plain); err != nil { + t.Fatalf("RevokeSession: %v", err) + } + if _, err := auth.AuthenticateSession(ctx, plain); !errors.Is(err, authkit.ErrSessionInvalid) { + t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err) + } +} + +func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "j@j.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + _, refresh1, err := auth.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + _, refresh2, err := auth.RefreshJWT(ctx, refresh1) + if err != nil { + t.Fatalf("first RefreshJWT: %v", err) + } + if refresh1 == refresh2 { + t.Fatalf("refresh token did not rotate") + } + + if _, _, err := auth.RefreshJWT(ctx, refresh1); !errors.Is(err, authkit.ErrTokenReused) { + t.Fatalf("expected ErrTokenReused on replay, got %v", err) + } + if _, _, err := auth.RefreshJWT(ctx, refresh2); !errors.Is(err, authkit.ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err) + } +} + +func TestIntegration_APIKeyWithAbilities(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "k@k.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plain, _, err := auth.IssueAPIKey(ctx, u.ID, "ci", + []string{"billing:read", "users:list"}, nil) + if err != nil { + t.Fatalf("IssueAPIKey: %v", err) + } + p, err := auth.AuthenticateAPIKey(ctx, plain) + if err != nil { + t.Fatalf("AuthenticateAPIKey: %v", err) + } + if !p.HasAbility("billing:read") || !p.HasAbility("users:list") { + t.Fatalf("abilities missing: %+v", p.Abilities) + } + if err := auth.RevokeAPIKey(ctx, plain); err != nil { + t.Fatalf("RevokeAPIKey: %v", err) + } + if _, err := auth.AuthenticateAPIKey(ctx, plain); !errors.Is(err, authkit.ErrAPIKeyInvalid) { + t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err) + } +} + +func TestIntegration_RBAC(t *testing.T) { + auth, db, schema := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "rb@b.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + + stores, _ := sqlstore.New(db, pgdialect.New(), schema) + r := &authkit.Role{Name: "editor"} + if err := stores.Roles.CreateRole(ctx, r); err != nil { + t.Fatalf("CreateRole: %v", err) + } + p := &authkit.Permission{Name: "posts:write"} + if err := stores.Permissions.CreatePermission(ctx, p); err != nil { + t.Fatalf("CreatePermission: %v", err) + } + if err := stores.Permissions.AssignPermissionToRole(ctx, r.ID, p.ID); err != nil { + t.Fatalf("AssignPermissionToRole: %v", err) + } + if err := auth.AssignRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("AssignRole: %v", err) + } + ok, err := auth.HasPermission(ctx, u.ID, "posts:write") + if err != nil || !ok { + t.Fatalf("HasPermission: %v %v", ok, err) + } + ok, err = auth.HasAnyRole(ctx, u.ID, []string{"editor", "admin"}) + if err != nil || !ok { + t.Fatalf("HasAnyRole: %v %v", ok, err) + } + if err := auth.RemoveRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("RemoveRole: %v", err) + } + ok, _ = auth.HasPermission(ctx, u.ID, "posts:write") + if ok { + t.Fatalf("HasPermission should be false after RemoveRole") + } +} diff --git a/sqlstore/tokens.go b/sqlstore/tokens.go new file mode 100644 index 0000000..99841fc --- /dev/null +++ b/sqlstore/tokens.go @@ -0,0 +1,92 @@ +package sqlstore + +import ( + "context" + "database/sql" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" +) + +type tokenStore struct{ storeBase } + +func (s *tokenStore) CreateToken(ctx context.Context, t *authkit.Token) error { + const op = "authkit.sqlstore.TokenStore.CreateToken" + if t.CreatedAt.IsZero() { + t.CreatedAt = time.Now().UTC() + } + _, err := s.db.ExecContext(ctx, s.q.CreateToken, + t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID), + nullableTime(t.ConsumedAt), t.CreatedAt, t.ExpiresAt) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// ConsumeToken does the find-and-mark-consumed in a single statement so two +// concurrent callers cannot both successfully consume the same token. The +// row is returned for inspection (e.g. ChainID for refresh rotation). +func (s *tokenStore) ConsumeToken(ctx context.Context, kind authkit.TokenKind, hash []byte, now time.Time) (*authkit.Token, error) { + const op = "authkit.sqlstore.TokenStore.ConsumeToken" + row := s.db.QueryRowContext(ctx, s.q.ConsumeToken, now, string(kind), hash, now) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid)) + } + return t, nil +} + +func (s *tokenStore) GetToken(ctx context.Context, kind authkit.TokenKind, hash []byte) (*authkit.Token, error) { + const op = "authkit.sqlstore.TokenStore.GetToken" + row := s.db.QueryRowContext(ctx, s.q.GetToken, string(kind), hash) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid)) + } + return t, nil +} + +func (s *tokenStore) DeleteByChain(ctx context.Context, chainID string) (int64, error) { + const op = "authkit.sqlstore.TokenStore.DeleteByChain" + tag, err := s.db.ExecContext(ctx, s.q.DeleteByChain, chainID) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} + +func (s *tokenStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) { + const op = "authkit.sqlstore.TokenStore.DeleteExpired" + tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredTokens, now) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} + +func scanToken(row rowScanner) (*authkit.Token, error) { + var ( + t authkit.Token + kind string + userIDStr string + chainID sql.NullString + consumedAt sql.NullTime + ) + if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, + &consumedAt, &t.CreatedAt, &t.ExpiresAt); err != nil { + return nil, err + } + t.Kind = authkit.TokenKind(kind) + uid, err := scanUUID(userIDStr) + if err != nil { + return nil, err + } + t.UserID = uid + t.ChainID = scanNullStringPtr(chainID) + t.ConsumedAt = scanNullTimePtr(consumedAt) + return &t, nil +} diff --git a/sqlstore/users.go b/sqlstore/users.go new file mode 100644 index 0000000..b03c379 --- /dev/null +++ b/sqlstore/users.go @@ -0,0 +1,186 @@ +package sqlstore + +import ( + "context" + "database/sql" + "errors" + "strings" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type userStore struct{ storeBase } + +func (s *userStore) CreateUser(ctx context.Context, u *authkit.User) error { + const op = "authkit.sqlstore.UserStore.CreateUser" + if u.ID == uuid.Nil { + u.ID = uuid.New() + } + if u.EmailNormalized == "" { + u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email)) + } + now := time.Now().UTC() + if u.CreatedAt.IsZero() { + u.CreatedAt = now + } + if u.UpdatedAt.IsZero() { + u.UpdatedAt = now + } + _, err := s.db.ExecContext(ctx, s.q.CreateUser, + uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), + nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins, + nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt) + if err != nil { + if s.d.IsUniqueViolation(err) { + return errx.Wrap(op, authkit.ErrEmailTaken) + } + return errx.Wrap(op, err) + } + return nil +} + +func (s *userStore) GetUserByID(ctx context.Context, id uuid.UUID) (*authkit.User, error) { + const op = "authkit.sqlstore.UserStore.GetUserByID" + u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) + } + return u, nil +} + +func (s *userStore) GetUserByEmail(ctx context.Context, normalizedEmail string) (*authkit.User, error) { + const op = "authkit.sqlstore.UserStore.GetUserByEmail" + u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByEmail, normalizedEmail)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) + } + return u, nil +} + +func (s *userStore) UpdateUser(ctx context.Context, u *authkit.User) error { + const op = "authkit.sqlstore.UserStore.UpdateUser" + u.UpdatedAt = time.Now().UTC() + tag, err := s.db.ExecContext(ctx, s.q.UpdateUser, + u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), + nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins, + nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) DeleteUser(ctx context.Context, id uuid.UUID) error { + const op = "authkit.sqlstore.UserStore.DeleteUser" + tag, err := s.db.ExecContext(ctx, s.q.DeleteUser, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error { + const op = "authkit.sqlstore.UserStore.SetPassword" + tag, err := s.db.ExecContext(ctx, s.q.SetPassword, + nullableString(encodedHash), time.Now().UTC(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error { + const op = "authkit.sqlstore.UserStore.SetEmailVerified" + tag, err := s.db.ExecContext(ctx, s.q.SetEmailVerified, at, time.Now().UTC(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) { + const op = "authkit.sqlstore.UserStore.BumpSessionVersion" + var v int + if err := s.db.QueryRowContext(ctx, s.q.BumpSessionVersion, + time.Now().UTC(), uuidArg(userID)).Scan(&v); err != nil { + return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) + } + return v, nil +} + +func (s *userStore) IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) { + const op = "authkit.sqlstore.UserStore.IncrementFailedLogins" + var n int + if err := s.db.QueryRowContext(ctx, s.q.IncrementFailedLogins, + time.Now().UTC(), uuidArg(userID)).Scan(&n); err != nil { + return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) + } + return n, nil +} + +func (s *userStore) ResetFailedLogins(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.sqlstore.UserStore.ResetFailedLogins" + tag, err := s.db.ExecContext(ctx, s.q.ResetFailedLogins, time.Now().UTC(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func scanUser(row rowScanner) (*authkit.User, error) { + var ( + u authkit.User + idStr string + passwordHash sql.NullString + emailVerified sql.NullTime + lastLogin sql.NullTime + ) + if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified, + &passwordHash, &u.SessionVersion, &u.FailedLogins, &lastLogin, + &u.CreatedAt, &u.UpdatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + u.ID = id + if passwordHash.Valid { + u.PasswordHash = passwordHash.String + } + u.EmailVerifiedAt = scanNullTimePtr(emailVerified) + u.LastLoginAt = scanNullTimePtr(lastLogin) + return &u, nil +} + +// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so +// callers get reliable errors.Is targets through errx wrapping. +func mapNotFound(err error, sentinel error) error { + if errors.Is(err, sql.ErrNoRows) { + return sentinel + } + return err +} diff --git a/stores.go b/stores.go new file mode 100644 index 0000000..36f2374 --- /dev/null +++ b/stores.go @@ -0,0 +1,83 @@ +package authkit + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +type UserStore interface { + CreateUser(ctx context.Context, u *User) error + GetUserByID(ctx context.Context, id uuid.UUID) (*User, error) + GetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error) + UpdateUser(ctx context.Context, u *User) error + DeleteUser(ctx context.Context, id uuid.UUID) error + SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error + SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error + BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) + IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) + ResetFailedLogins(ctx context.Context, userID uuid.UUID) error +} + +type SessionStore interface { + CreateSession(ctx context.Context, s *Session) error + GetSession(ctx context.Context, idHash []byte) (*Session, error) + TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error + DeleteSession(ctx context.Context, idHash []byte) error + DeleteUserSessions(ctx context.Context, userID uuid.UUID) error + DeleteExpired(ctx context.Context, now time.Time) (int64, error) +} + +type TokenStore interface { + CreateToken(ctx context.Context, t *Token) error + // ConsumeToken atomically marks the matching unexpired, unconsumed token + // as consumed and returns it. Returns ErrTokenInvalid if no row matched. + // Implementations MUST do this in one statement (UPDATE ... RETURNING) + // to prevent double-spend under concurrent callers. + ConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) + // GetToken returns a token without consuming it. Used for refresh-token + // reuse detection: a token that exists with consumed_at != nil is a + // replay signal. + GetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error) + DeleteByChain(ctx context.Context, chainID string) (int64, error) + DeleteExpired(ctx context.Context, now time.Time) (int64, error) +} + +type APIKeyStore interface { + CreateAPIKey(ctx context.Context, k *APIKey) error + GetAPIKey(ctx context.Context, idHash []byte) (*APIKey, error) + ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*APIKey, error) + TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error + RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error + RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error +} + +type RoleStore interface { + CreateRole(ctx context.Context, r *Role) error + GetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error) + GetRoleByName(ctx context.Context, name string) (*Role, error) + ListRoles(ctx context.Context) ([]*Role, error) + DeleteRole(ctx context.Context, id uuid.UUID) error + AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error + RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error + GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error) + HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) +} + +type PermissionStore interface { + CreatePermission(ctx context.Context, p *Permission) error + GetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error) + GetPermissionByName(ctx context.Context, name string) (*Permission, error) + ListPermissions(ctx context.Context) ([]*Permission, error) + DeletePermission(ctx context.Context, id uuid.UUID) error + AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error + RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error + GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error) + GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error) +} + +type Hasher interface { + Hash(password string) (string, error) + Verify(password, encoded string) (ok bool, needsRehash bool, err error) +} diff --git a/tokens.go b/tokens.go new file mode 100644 index 0000000..78d5e22 --- /dev/null +++ b/tokens.go @@ -0,0 +1,67 @@ +package authkit + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "io" + "strings" + + "git.juancwu.dev/juancwu/errx" +) + +const secretRandomBytes = 32 + +// Secret prefixes. The leading token namespaces the secret so callers can +// route them to the right verifier without trial-and-error. +const ( + prefixSession = "sess" + prefixRefresh = "rfr" + prefixAPIKey = "ak" + prefixEmailVerify = "evr" + prefixPasswordRset = "pwr" + prefixMagicLink = "mlnk" +) + +// mintSecret generates a new opaque secret of the given prefix kind and +// returns the plaintext (to be returned to the user, never stored) and the +// SHA-256 lookup hash (to be stored). +func mintSecret(prefix string, rng io.Reader) (plaintext string, hash []byte, err error) { + const op = "authkit.mintSecret" + if rng == nil { + rng = rand.Reader + } + buf := make([]byte, secretRandomBytes) + if _, err := io.ReadFull(rng, buf); err != nil { + return "", nil, errx.Wrap(op, err) + } + body := base64.RawURLEncoding.EncodeToString(buf) + plaintext = prefix + "_" + body + hash = hashSecret(plaintext) + return plaintext, hash, nil +} + +// hashSecret returns sha256(plaintext) — the lookup key for opaque secrets. +// Plaintexts have full entropy from a CSPRNG so a plain hash is sufficient +// (no per-record salt needed; the random body is the salt). +func hashSecret(plaintext string) []byte { + sum := sha256.Sum256([]byte(plaintext)) + return sum[:] +} + +// parseSecret validates that a plaintext starts with the expected prefix +// and returns the lookup hash. The prefix check is constant-time relative +// to a fixed-length comparison. +func parseSecret(prefix, plaintext string) (hash []byte, ok bool) { + want := prefix + "_" + if !strings.HasPrefix(plaintext, want) { + return nil, false + } + return hashSecret(plaintext), true +} + +// constantTimeEqual is a thin wrapper for readability at call sites. +func constantTimeEqual(a, b []byte) bool { + return subtle.ConstantTimeCompare(a, b) == 1 +} diff --git a/tokens_test.go b/tokens_test.go new file mode 100644 index 0000000..391437a --- /dev/null +++ b/tokens_test.go @@ -0,0 +1,56 @@ +package authkit + +import ( + "bytes" + "crypto/sha256" + "strings" + "testing" +) + +func TestMintSecretRoundtrip(t *testing.T) { + plaintext, hash, err := mintSecret(prefixSession, nil) + if err != nil { + t.Fatalf("mintSecret: %v", err) + } + if !strings.HasPrefix(plaintext, prefixSession+"_") { + t.Fatalf("missing prefix: %q", plaintext) + } + parsed, ok := parseSecret(prefixSession, plaintext) + if !ok { + t.Fatalf("parseSecret rejected our own mint") + } + if !bytes.Equal(hash, parsed) { + t.Fatalf("hash mismatch") + } + want := sha256.Sum256([]byte(plaintext)) + if !bytes.Equal(hash, want[:]) { + t.Fatalf("hashSecret != sha256(plaintext)") + } +} + +func TestParseSecretWrongPrefix(t *testing.T) { + plaintext, _, err := mintSecret(prefixSession, nil) + if err != nil { + t.Fatalf("mintSecret: %v", err) + } + if _, ok := parseSecret(prefixAPIKey, plaintext); ok { + t.Fatalf("parseSecret should reject mismatched prefix") + } + if _, ok := parseSecret(prefixSession, "sessXXXX"); ok { + t.Fatalf("parseSecret should require trailing underscore") + } +} + +func TestMintSecretUniqueness(t *testing.T) { + seen := make(map[string]struct{}, 100) + for i := 0; i < 100; i++ { + p, _, err := mintSecret(prefixAPIKey, nil) + if err != nil { + t.Fatalf("mintSecret: %v", err) + } + if _, dup := seen[p]; dup { + t.Fatalf("duplicate mint: %s", p) + } + seen[p] = struct{}{} + } +}