diff --git a/README.md b/README.md index 926335b..ddee04c 100644 --- a/README.md +++ b/README.md @@ -1,408 +1,283 @@ # authkit -A pragmatic authentication and authorization toolkit for Go web services. +A pragmatic authentication and authorization toolkit for Go web services +on PostgreSQL 16+. -`authkit` ships interfaces for users, sessions, tokens, API keys, roles, and -permissions, plus default `database/sql` Postgres implementations and -framework-neutral HTTP middleware. It supports both opaque server-side -sessions and JWT access tokens with rotating refresh tokens, hashes passwords -with Argon2id, and pairs naturally with [`lightmux`](https://git.juancwu.dev/juancwu/lightmux) -or any `net/http` stack. +`authkit` is a library, not a service. Drop it into a `net/http` stack and +get registration, password login, opaque server-side sessions, JWT access +tokens with rotating refresh, email verification, password reset, +magic-link login, email OTP, and machine-targeted service tokens with +consumer-defined abilities. Authorization is flat RBAC with both +role-derived and direct user permissions. + +> **Status:** v1.0.0 development. The API is being stabilised; expect +> breaking changes until the v1.0.0 tag. ## Install -``` -go get git.juancwu.dev/juancwu/authkit@v0.1.0 +```sh +go get git.juancwu.dev/juancwu/authkit ``` -`authkit` itself depends only on `database/sql` and the Go standard library -plus `golang-jwt`, `google/uuid`, `golang.org/x/crypto`, and `errx`. Bring -your own driver: `pgx`, `lib/pq`, or anything else that registers a -`database/sql` driver. +`authkit` depends only on the Go standard library, `golang-jwt`, +`google/uuid`, `golang.org/x/crypto`, and `errx`. Bring your own driver: +`pgx`, `lib/pq`, or anything else that registers a `database/sql` driver. ```go import _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq" ``` -PostgreSQL 12+ is sufficient — the schema avoids `gen_random_uuid()` and -`pgcrypto` so no extensions are required. +PostgreSQL 16 or newer is required. ## What's included -**Authentication flows** -- Email + password registration and login (Argon2id PHC-encoded hashes) +**Authentication** +- Email-only registration (`CreateUser`); password is optional and can be + set later via `SetPassword` +- Password login with Argon2id PHC-encoded hashes - Opaque server-side sessions with sliding TTL bounded by an absolute cap -- JWT access tokens (HS256) with rotating refresh tokens and reuse detection -- Email verification, password reset, and magic-link passwordless login +- HS256 JWT access tokens with rotating refresh tokens and reuse + detection +- Email verification, password reset, magic-link login, email OTP **Authorization** -- Roles and permissions with many-to-many wiring (resolved on user-bound - `Principal`s) -- Owner-agnostic service tokens with custom abilities for server-to-server - auth (no FK on owner; cascade-on-delete is the consumer's responsibility) -- A `Principal` for user-bound auth (sessions, JWTs) and a `ServiceKey` for - service-token auth — middleware composes around both subject types +- Roles and permissions (flat RBAC) +- Direct user-permission grants in addition to role-derived ones — + `UserPermissions` returns the UNION +- Service tokens with consumer-defined abilities (machine credentials, no + user owner) + +**Predicate API for middleware authz** +- Leaves: `HasRole(slug)`, `HasPermission(slug)`, `HasAbility(slug)` +- Combinators: `AnyLogin`, `AllLogin`, `AnyServiceKey`, `AllServiceKey` +- Compose freely: + `AnyLogin(HasRole("admin"), AllLogin(HasRole("manager"), HasRole("ads_manager")))` + +**HTTP middleware** +- `RequireLogin` — accept session cookie OR JWT, optionally constrain by + `LoginAuthz` +- `RequireGuest` — block authenticated requests (with a configurable + `OnAuthenticated` callback for redirects) +- `RequireServiceKey` — accept a service token, optionally constrain by + `ServiceKeyAuthz` **Storage** -- Interfaces for every store so callers can plug in their own backends -- Default Postgres implementation built on `*sql.DB` (`sqlstore` package) -- Override table names via `Schema` without forking — useful when authkit - lives alongside an existing application schema -- A `Dialect` abstraction so future MySQL / SQLite implementations slot in - without changes to store code -- Embedded versioned migrations applied by a `Migrate(ctx, db, dialect, schema)` - helper that takes a session-scoped advisory lock - -**HTTP** -- User-bound: `middleware.RequireSession`, `RequireJWT`, `RequireAny` -- Service-bound: `middleware.RequireServiceKey` -- Either: `middleware.RequireAnyOrServiceKey` (Session/JWT, falling through to - ServiceKey) -- Authz: `middleware.RequireRole`, `RequireAnyRole`, `RequirePermission` - (operate on `*Principal`); `middleware.RequireAbility` (operates on - `*ServiceKey`) -- `middleware.PrincipalFrom(ctx)` and `middleware.ServiceKeyFrom(ctx)` to - read the authenticated subject in handlers +- PostgreSQL 16+ only +- Migrations and schema verification run on startup (opt-out via + `Config.SkipAutoMigrate` / `Config.SkipSchemaVerify`) +- Override individual table names via `Schema.Tables` +- Schema verifier tolerates extra columns; flags missing tables, missing + columns, type drift, and nullability drift **Errors** -- Sentinel errors (`ErrEmailTaken`, `ErrInvalidCredentials`, `ErrTokenInvalid`, - `ErrTokenReused`, `ErrSessionInvalid`, `ErrServiceKeyInvalid`, - `ErrPermissionDenied`, ...) compatible with `errors.Is` +- Sentinel errors compatible with `errors.Is` - All internal errors wrap with [`errx`](https://git.juancwu.dev/juancwu/errx) - for op tags - -## Out of scope (v1) - -MFA/TOTP, OAuth/social login, soft-delete, in-memory permission caching, -pluggable JWT signers (HS256 only), built-in HTTP handlers, MySQL/SQLite -dialects (architecture supports them; only Postgres ships in v1), and -column-name overrides in `Schema` (table-name overrides only). ## Quick start -### 1. Open a database and run migrations +### 1. Open a database ```go import ( "database/sql" - - "git.juancwu.dev/juancwu/authkit/sqlstore" - pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres" - - _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq" + _ "github.com/jackc/pgx/v5/stdlib" ) db, err := sql.Open("pgx", os.Getenv("DATABASE_URL")) if err != nil { /* ... */ } defer db.Close() - -if err := sqlstore.Migrate(ctx, db, pgdialect.New(), sqlstore.DefaultSchema()); err != nil { - log.Fatal(err) -} ``` -`Migrate` is idempotent and safe to call from multiple processes — it takes -a session-scoped `pg_advisory_lock` to serialise rollouts. - -`sqlx` users can pass `sqlxDB.DB` (the underlying `*sql.DB`) to the same -calls — the library only cares about `*sql.DB`. - -### 2. Wire the service +### 2. Construct Auth ```go import ( + "context" + "git.juancwu.dev/juancwu/authkit" "git.juancwu.dev/juancwu/authkit/hasher" ) -stores, err := sqlstore.New(db, pgdialect.New(), sqlstore.DefaultSchema()) -if err != nil { /* ... */ } - -auth := authkit.New(authkit.Deps{ - Users: stores.Users, - Sessions: stores.Sessions, - Tokens: stores.Tokens, - ServiceKeys: stores.ServiceKeys, - Roles: stores.Roles, - Permissions: stores.Permissions, - Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), +auth, err := authkit.New(ctx, authkit.Deps{ + DB: db, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), }, authkit.Config{ - JWTSecret: []byte(os.Getenv("JWT_SECRET")), - JWTIssuer: "myapp", - SessionCookieSecure: true, - SessionCookieHTTPOnly: true, + JWTSecret: []byte(os.Getenv("JWT_SECRET")), + JWTIssuer: "myapp", }) +if err != nil { log.Fatal(err) } ``` -`Config` zero values fall back to sensible defaults (24h idle / 30d absolute -session TTL, 15m access tokens, 30d refresh tokens, 48h email-verify, 1h -password-reset, 15m magic-link). `JWTSecret` and all seven `Deps` fields are -required; `New` panics on a misconfiguration. +`New` runs migrations and verifies the schema. `Config` zero values fall +back to sane defaults: 24h idle / 30d absolute session TTL, 15m access / +30d refresh, 48h email-verify, 1h password-reset, 15m magic-link, 10m +email OTP with 5 attempts. Cookie defaults: `Secure=true`, `HttpOnly=true`, +`SameSite=Lax`. Pass `authkit.BoolPtr(false)` to opt out for local dev. -### 3. Use the service +### 3. Seed roles, permissions, and abilities + +`authkit` does not seed any rows. Use the bundled CLIs: + +```sh +go install git.juancwu.dev/juancwu/authkit/cmd/perms@latest +go install git.juancwu.dev/juancwu/authkit/cmd/roles@latest +go install git.juancwu.dev/juancwu/authkit/cmd/abilities@latest + +export AUTHKIT_DATABASE_URL=postgres://... + +perms create posts:write --label "Write posts" +perms create posts:read --label "Read posts" +roles create editor --label "Editor" +roles grant editor posts:write +roles grant editor posts:read + +abilities create events:write --label "Events ingest" +``` + +Or call the equivalent methods on `*authkit.Auth` from your own seed +script. Slugs match `^[a-z][a-z0-9_:-]*$` (max 64 bytes); invalid slugs +return `ErrSlugInvalid`. + +### 4. User flows ```go -// Registration + password login -u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2") -u, err = auth.LoginPassword(ctx, "alice@example.com", "hunter2hunter2") +// Email-only account, password set later. +u, _ := auth.CreateUser(ctx, "alice@example.com") +_ = auth.SetPassword(ctx, u.ID, "hunter2hunter2") +u, _ = auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2") // case-insensitive -// Opaque session (cookie-friendly) -plaintext, sess, err := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP) +// Opaque session. +plaintext, sess, _ := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP) http.SetCookie(w, auth.SessionCookie(plaintext, sess.ExpiresAt)) -// JWT + rotating refresh -access, refresh, err := auth.IssueJWT(ctx, u.ID) -access, refresh, err = auth.RefreshJWT(ctx, refresh) // old refresh is consumed +// JWT + rotating refresh. +access, refresh, _ := auth.IssueJWT(ctx, u.ID) +access, refresh, _ = auth.RefreshJWT(ctx, refresh) // old refresh is consumed -// Service token (owner-agnostic; ownerKind labels the namespace). -// Service tokens are the only credential type that carries free-form abilities. -plaintext, sk, err := auth.IssueServiceKey(ctx, - "application", appID, "events-ingest", - []string{"events:write"}, nil) -got, err := auth.AuthenticateServiceKey(ctx, plaintext) -// got.OwnerKind == "application"; got.OwnerID == appID -err = auth.RevokeServiceKey(ctx, plaintext) +// Magic link / OTP / password reset (anti-enumeration: silent on unknown email). +linkToken, _ := auth.RequestMagicLink(ctx, "alice@example.com") +otpCode, _ := auth.RequestEmailOTP(ctx, "alice@example.com") +resetToken, _ := auth.RequestPasswordReset(ctx, "alice@example.com") -// Email verification + password reset + magic link -tok, err := auth.RequestEmailVerification(ctx, u.ID) -_, err = auth.ConfirmEmail(ctx, tok) - -tok, err = auth.RequestPasswordReset(ctx, "alice@example.com") -err = auth.ConfirmPasswordReset(ctx, tok, "new-password") - -tok, err = auth.RequestMagicLink(ctx, "alice@example.com") -u, err = auth.ConsumeMagicLink(ctx, tok) +// Service token with abilities. +plaintext, sk, _ := auth.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{ + Name: "events-ingest", + Abilities: []string{"events:write"}, +}) +got, _ := auth.AuthenticateServiceKey(ctx, plaintext) ``` -The plaintext returned by `IssueSession`, `IssueJWT`, `IssueServiceKey`, and -the token-minting flows is **show-once** — only its SHA-256 hash is stored. -Show it to the user immediately; you cannot recover it later. +The plaintext returned by every issue/mint flow is **show-once** — only +its SHA-256 hash is stored. Show it to the user immediately; you cannot +recover it later. -### 4. Wire middleware - -`authkit/middleware` returns standard `func(http.Handler) http.Handler` -values, so it composes with `lightmux.Mux.Use`/`Group`/`Handle` and any -`net/http` mux that accepts the same shape. +### 5. Wire middleware ```go import ( - authkitmw "git.juancwu.dev/juancwu/authkit/middleware" - "git.juancwu.dev/juancwu/lightmux" + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/authkit/middleware" ) -mux := lightmux.New() +// Default RequireLogin reads the session cookie and falls through to a +// Bearer JWT. +loginMW := middleware.RequireLogin(middleware.LoginOptions{Auth: auth}) -cookieAuth := authkitmw.RequireSession(authkitmw.Options{ - Auth: auth, - Extractor: authkit.ChainExtractors( - authkit.CookieExtractor("authkit_session"), - authkit.BearerExtractor(), +// Constrain on roles/permissions: +adminMW := middleware.RequireLogin(middleware.LoginOptions{ + Auth: auth, + Authz: authkit.AnyLogin( + authkit.HasRole("admin"), + authkit.AllLogin(authkit.HasRole("manager"), authkit.HasRole("ads_manager")), ), }) -me := mux.Group("/me", cookieAuth) -me.Get("", func(w http.ResponseWriter, r *http.Request) { - p := authkitmw.MustPrincipal(r) - json.NewEncoder(w).Encode(p) +// Login/register pages: block if already authenticated. +guestMW := middleware.RequireGuest(middleware.GuestOptions{ + Auth: auth, + OnAuthenticated: func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) + }, }) -// RBAC: stack authz on top of any auth method -admin := mux.Group("/admin", cookieAuth, authkitmw.RequireRole("admin")) - -// Service-token route with a per-endpoint ability check -api := mux.Group("/api/v1", authkitmw.RequireServiceKey(authkitmw.Options{Auth: auth})) -api.Get("/events", eventsHandler, authkitmw.RequireAbility("events:write")) - -// Mixed route — accept either a session cookie or a service token -mixed := mux.Group("/v1", authkitmw.RequireAnyOrServiceKey(authkitmw.Options{Auth: auth})) -mixed.Get("/profile", func(w http.ResponseWriter, r *http.Request) { - if p, ok := authkitmw.PrincipalFrom(r.Context()); ok { - // user request - _ = p - } else if k, ok := authkitmw.ServiceKeyFrom(r.Context()); ok { - // service request - _ = k - } +// Service tokens with an ability gate. +apiMW := middleware.RequireServiceKey(middleware.ServiceKeyOptions{ + Auth: auth, + Authz: authkit.AllServiceKey(authkit.HasAbility("events:write")), }) ``` -`Options.Extractor` defaults to `BearerExtractor`; pass `CookieExtractor` (or -chain extractors) when reading session cookies. `Options.OnUnauth` and -`Options.OnForbidden` default to a JSON `401` / `403`; override them to match -your error envelope. +`RequireLogin` and `RequireServiceKey` panic at construction if any slug +referenced by the predicate isn't registered in the database — typos fail +at boot, not at request time. -## Custom table names +### 6. Read the user in handlers -Pass a non-default `Schema` to use your own table names. Identifiers must -match `^[a-zA-Z_][a-zA-Z0-9_]*$`; anything else is rejected at `New()` and -`Migrate()` time, so SQL injection through the schema is impossible. +Middleware attaches the `user_id` to the request context. Handlers fetch +the full user lazily: ```go -schema := sqlstore.DefaultSchema() -schema.Tables.Users = "accounts" -schema.Tables.ServiceKeys = "service_credentials" +func handle(w http.ResponseWriter, r *http.Request) { + id, _ := authkit.UserIDFromCtx(r.Context()) // never queries the DB + u, err := authkit.UserFromCtx(r.Context()) // lazy-load + per-request cache + if err != nil { /* handle */ } -stores, _ := sqlstore.New(db, pgdialect.New(), schema) -``` - -The bundled migration files use the default `authkit_*` names. If you -override, you're responsible for matching DDL (most consumers with custom -naming already have their own DDL pipeline). - -Column-name overrides are not exposed in v1 — the column set is fixed for -each table. Adding column overrides later is purely additive. - -## How things work - -### Secret token format - -Sessions, refresh tokens, service tokens, email-verify tokens, -password-reset tokens, and magic-link tokens all share one format: - -``` -plaintext = "_" + base64url(32 random bytes, no padding) -lookup = sha256(plaintext) -``` - -Plaintext is returned to the caller exactly once and never persisted; the -SHA-256 is the database lookup key. Random bytes come from `crypto/rand` (or -`Config.Random` for tests). The mint/parse/hash helpers are exported as -`MintOpaqueSecret`, `ParseOpaqueSecret`, and `HashOpaqueSecret` for callers -building bespoke token storage on top of the same shape. - -### User credentials vs. service tokens - -`authkit` exposes two distinct subject types, and middleware composes around -them differently. - -**User credentials** — sessions and JWTs — prove **identity**. They are -produced by `IssueSession` / `IssueJWT` and authenticate via -`AuthenticateSession` / `AuthenticateJWT`, which return a `*Principal` -carrying `UserID`, `Method`, and the user's roles + permissions resolved -through RBAC. Authorization on these requests is **role/permission-based** -via `RequireRole` / `RequirePermission`. User credentials carry no abilities; -"what this user may do" is answered by the user's RBAC, not by anything -embedded on the credential itself. - -**Service tokens** — `IssueServiceKey` — prove **"this caller may do X"**. -They are owner-agnostic: `OwnerKind` labels the namespace ("application", -"tenant", whatever) and `OwnerID` identifies the entity within it. The -database column has **no foreign key** on purpose — `authkit` makes no -assumption about what the owner is, and cascade-on-delete is the consumer's -responsibility. `AuthenticateServiceKey` returns a `*ServiceKey` directly -(no `*Principal`, no role/permission resolution). Authorization on these -requests is **ability-based** via `RequireAbility`; the abilities slice is -free-form and not linked to `authkit_roles` / `authkit_permissions`. - -```go -plaintext, key, err := auth.IssueServiceKey(ctx, - "application", appID, "events-ingest", - []string{"events:write"}, nil) - -k, err := auth.AuthenticateServiceKey(ctx, plaintext) -// k.OwnerKind == "application"; k.OwnerID == appID; k.HasAbility("events:write") - -err = auth.RevokeServiceKey(ctx, plaintext) -``` - -When a consumer-owned entity (an application, a tenant) is deleted, the -consumer must revoke the associated service tokens itself — typically by -iterating `ListServiceKeys(ctx, ownerKind, ownerID)`. - -### JWT revocation - -Access tokens carry `sv` (`session_version`) in their claims. When you call -`RevokeAllUserSessions` or `ChangePassword`, the user's `session_version` -column increments and every outstanding access token fails the next -`AuthenticateJWT`. This is the only way to invalidate a JWT before its -`exp`. - -### Refresh token rotation - -Each `RefreshJWT` consumes the presented refresh token and issues a new one -on the same chain (the `chain_id` column on `authkit_tokens`). If a -*consumed* refresh token is ever presented again — a strong replay signal — -the entire chain is deleted via `TokenStore.DeleteByChain` and the call -returns `ErrTokenReused`. - -### Sliding session TTL - -Each authenticated request via `AuthenticateSession` slides `expires_at` to -`now + Config.SessionIdleTTL`, capped at `created_at + Config.SessionAbsoluteTTL`. -Long-lived idle sessions still hit the absolute boundary. - -### Schema and migrations - -`sqlstore.Migrate` applies every embedded `.sql` file under -`sqlstore/dialect/postgres/migrations/` whose version (filename without -`.sql`) is not in `authkit_schema_migrations`. The dialect's -`AcquireMigrationLock` (Postgres uses `pg_advisory_lock`) serialises -concurrent migrators. Each migration owns its own transaction so future -migrations can use statements like `CREATE INDEX CONCURRENTLY`. - -Every default table is prefixed `authkit_` so the schema can live alongside -your application's own tables in a shared database. - -### Driver and dialect architecture - -The `sqlstore` package speaks `database/sql` only. Driver-specific behaviour -lives behind a small `Dialect` interface: - -```go -type Dialect interface { - Name() string - BuildQueries(s Schema) Queries - Bootstrap(ctx context.Context, db *sql.DB) error - AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error) - Migrations() fs.FS - IsUniqueViolation(err error) bool - Placeholder(n int) string - PlaceholderList(start, count int) string + // After an admin-side update that should be visible: + u, err = authkit.RefreshUserInCtx(r.Context()) + _ = u; _ = id } ``` -v1 ships `dialect/postgres`. A future MySQL or SQLite dialect adds a new -implementation; no changes to store code. +The cache lives only for the request lifetime — nothing persists across +requests. For service-token routes, use `authkit.ServiceKeyFromCtx`. + +## Schema verification and drift + +On `New`, `authkit` introspects `information_schema.columns` and verifies +the live database matches the expected layout (table presence, column +names, `data_type`, `is_nullable`). Extra columns are tolerated; missing +tables/columns and type drift fail with `ErrSchemaDrift`. + +When a table cannot be found under the configured name, the verifier +falls back to the default `authkit_*` name. This handles migrations from +custom names back to defaults without manual intervention. ## Configuration reference | Field | Default | Notes | |---|---|---| +| `Schema` | `DefaultSchema()` | Override individual `Tables` fields; missing fields fall back to defaults | +| `SkipAutoMigrate` | `false` | Disables migration run inside `New` | +| `SkipSchemaVerify` | `false` | Disables schema check inside `New` | | `SessionIdleTTL` | 24h | Sliding window applied on each authenticated request | | `SessionAbsoluteTTL` | 30d | Cap from `created_at`; sliding never exceeds this | | `SessionCookieName` | `authkit_session` | | +| `SessionCookieSecure` | `*true` | Pass `BoolPtr(false)` for local HTTP dev | +| `SessionCookieHTTPOnly` | `*true` | Pass `BoolPtr(false)` if JS must read it (rarely correct) | | `SessionCookieSameSite` | `Lax` | | -| `SessionCookieSecure` / `HTTPOnly` | `false` / `false` | Set both to `true` in production | | `JWTSecret` | — (required) | HS256 key | -| `JWTIssuer` / `JWTAudience` | empty | When set, parser enforces them | -| `AccessTokenTTL` | 15m | | -| `RefreshTokenTTL` | 30d | | +| `AccessTokenTTL` / `RefreshTokenTTL` | 15m / 30d | | | `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | | -| `Clock` | `time.Now().UTC` | Controls every observable timestamp; override for deterministic tests | +| `EmailOTPTTL` / `EmailOTPDigits` / `EmailOTPMaxAttempts` | 10m / 6 / 5 | | +| `RevealUnknownEmail` | `false` | Default anti-enumeration: silent success on unknown email | +| `Clock` | `time.Now().UTC` | Override for deterministic tests | | `Random` | `crypto/rand.Reader` | Override for deterministic tests | -| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit | - -## Implementing your own store - -Every store is a small interface with explicit semantics — see `stores.go`. -The most subtle contract is `TokenStore.ConsumeToken`: it MUST mark the token -consumed and return it in a single statement (`UPDATE ... RETURNING` on -Postgres / SQLite 3.35+) so two concurrent callers cannot both succeed. +| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit. Panics in the hook are recovered. | ## Testing -``` -go test ./... # unit tests, no DB -AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./sqlstore... # integration tests +```sh +go test ./... # unit tests, no DB required +AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./... -run Integration ``` -Unit tests cover token mint/parse, Argon2id encode/verify (including -`needsRehash` on parameter change), JWT issue/parse (incl. expired, -`sv`-mismatch, refresh rotation, reuse detection), session lifecycle, email -verification, password reset cascading session invalidation, magic-link -self-verification, API keys with abilities, and RBAC role-permission -resolution. Integration tests run the full `sqlstore` contract against a -real Postgres when `AUTHKIT_TEST_DATABASE_URL` is set. +The unit suite covers slug validation (incl. fuzz), opaque-secret +roundtrip, email normalization, HTTP extractors, predicate combinators, +and OTP code generation. Integration tests cover every database-bound +flow: registration, login, sessions, JWT refresh + reuse, magic link, +email OTP (incl. attempt cap), password reset, service tokens, RBAC, +direct user permissions, schema verification (drift cases + fallback), +migration idempotency, lazy user-context cache, and middleware behavior. ## License diff --git a/authkit.go b/authkit.go index f1c69a2..b357915 100644 --- a/authkit.go +++ b/authkit.go @@ -3,6 +3,7 @@ package authkit import ( "context" "crypto/rand" + "database/sql" "io" "net/http" "time" @@ -10,31 +11,53 @@ import ( "git.juancwu.dev/juancwu/errx" ) -// Deps bundles every backing store and the password hasher the Auth service -// depends on. All fields are required; New panics on a nil dep so misuse is -// caught at boot rather than under load. +// Hasher is the password hashing interface. The default implementation is +// hasher.Argon2id; consumers can swap in alternative KDFs as long as the +// encoded form lets Verify roundtrip and report needsRehash on parameter +// drift. +type Hasher interface { + Hash(password string) (string, error) + Verify(password, encoded string) (ok bool, needsRehash bool, err error) +} + +// Deps bundles the runtime dependencies the Auth service requires. DB and +// Hasher are required; New panics on either being nil. type Deps struct { - Users UserStore - Sessions SessionStore - Tokens TokenStore - ServiceKeys ServiceKeyStore - Roles RoleStore - Permissions PermissionStore - Hasher Hasher + DB *sql.DB + Hasher Hasher } // Config tunes session/JWT/token TTLs, cookie shape, JWT signing material, -// and optional hooks. Any zero-valued duration is replaced with a sane -// default in New; required fields (notably JWTSecret) cause New to panic. +// schema overrides, and optional hooks. Zero-valued durations are replaced +// with sane defaults in New; required fields (notably JWTSecret) cause New +// to panic. type Config struct { - // Session (opaque) cookies + DB-backed lifetime - SessionIdleTTL time.Duration - SessionAbsoluteTTL time.Duration - SessionCookieName string - SessionCookieDomain string - SessionCookiePath string - SessionCookieSecure bool - SessionCookieHTTPOnly bool + // Schema lets consumers override table names. Zero value uses + // DefaultSchema(). + Schema Schema + + // SkipAutoMigrate disables the migration run inside New. The verifier + // still runs; consumers running their own migrate pipeline should set + // this and call authkit.Migrate manually before New (or skip it + // entirely if they manage DDL out-of-band). + SkipAutoMigrate bool + + // SkipSchemaVerify disables the startup schema check. Recommended only + // for tests that expect schema drift; production callers should let the + // verifier run. + SkipSchemaVerify bool + + // Sessions + SessionIdleTTL time.Duration + SessionAbsoluteTTL time.Duration + SessionCookieName string + SessionCookieDomain string + SessionCookiePath string + // SessionCookieSecure / SessionCookieHTTPOnly use *bool so a nil value + // means "fall back to the secure default (true)" while *bool(false) is + // an explicit opt-out for local dev. BoolPtr is a one-line constructor. + SessionCookieSecure *bool + SessionCookieHTTPOnly *bool SessionCookieSameSite http.SameSite // JWT (HS256) @@ -45,38 +68,82 @@ type Config struct { RefreshTokenTTL time.Duration // Single-use tokens - EmailVerifyTTL time.Duration - PasswordResetTTL time.Duration - MagicLinkTTL time.Duration + EmailVerifyTTL time.Duration + PasswordResetTTL time.Duration + MagicLinkTTL time.Duration + EmailOTPTTL time.Duration + EmailOTPMaxAttempts int + EmailOTPDigits int - // Hooks (optional) + // RevealUnknownEmail flips request flows (RequestPasswordReset, + // RequestMagicLink, RequestEmailOTP) from anti-enumeration silent + // success to returning ErrUserNotFound when the email isn't + // registered. Default false (silent). + RevealUnknownEmail bool + + // Hooks Clock func() time.Time Random io.Reader LoginHook func(ctx context.Context, email string, success bool) error } -// Auth is the high-level service that composes the stores and hasher into the -// flows callers use: registration, login, sessions, JWTs, magic links, API -// keys, and authz checks. It is safe for concurrent use; method receivers +// Auth is the high-level service. Safe for concurrent use; method receivers // never mutate Auth state after construction. type Auth struct { - deps Deps - cfg Config + db *sql.DB + hasher Hasher + cfg Config + q queries + schema Schema } -// New validates Deps and Config, fills in defaults, and returns a ready -// service. It panics on missing deps or missing JWT secret rather than -// returning an error — these are programmer errors, not runtime ones. -func New(deps Deps, cfg Config) *Auth { - if deps.Users == nil || deps.Sessions == nil || deps.Tokens == nil || - deps.ServiceKeys == nil || deps.Roles == nil || deps.Permissions == nil || - deps.Hasher == nil { - panic(errx.New("authkit.New", "all Deps fields are required")) +// New validates Deps and Config, fills in defaults, runs migrations +// (unless SkipAutoMigrate), verifies the schema (unless SkipSchemaVerify), +// and returns a ready service. +// +// Panics on missing deps, missing JWT secret, invalid schema, or schema +// drift — these are programmer/operator errors, not runtime failures. +func New(ctx context.Context, deps Deps, cfg Config) (*Auth, error) { + const op = "authkit.New" + if deps.DB == nil { + panic(errx.New(op, "Deps.DB is required")) + } + if deps.Hasher == nil { + panic(errx.New(op, "Deps.Hasher is required")) } if len(cfg.JWTSecret) == 0 { - panic(errx.New("authkit.New", "Config.JWTSecret is required")) + panic(errx.New(op, "Config.JWTSecret is required")) } + cfg.Schema = mergeSchemaDefaults(cfg.Schema) + if err := cfg.Schema.Validate(); err != nil { + panic(errx.Wrap(op, err)) + } + + cfg = applyDefaults(cfg) + + a := &Auth{ + db: deps.DB, + hasher: deps.Hasher, + cfg: cfg, + q: buildQueries(cfg.Schema.Tables), + schema: cfg.Schema, + } + + if !cfg.SkipAutoMigrate { + if err := Migrate(ctx, deps.DB, cfg.Schema); err != nil { + return nil, errx.Wrap(op, err) + } + } + if !cfg.SkipSchemaVerify { + if err := VerifySchema(ctx, deps.DB, cfg.Schema); err != nil { + return nil, errx.Wrap(op, err) + } + } + return a, nil +} + +func applyDefaults(cfg Config) Config { if cfg.SessionIdleTTL == 0 { cfg.SessionIdleTTL = 24 * time.Hour } @@ -92,6 +159,14 @@ func New(deps Deps, cfg Config) *Auth { if cfg.SessionCookieSameSite == 0 { cfg.SessionCookieSameSite = http.SameSiteLaxMode } + // Secure & HTTPOnly default to true. Consumers wanting plain HTTP for + // local dev must pass an explicit *false via BoolPtr. + if cfg.SessionCookieSecure == nil { + cfg.SessionCookieSecure = BoolPtr(true) + } + if cfg.SessionCookieHTTPOnly == nil { + cfg.SessionCookieHTTPOnly = BoolPtr(true) + } if cfg.AccessTokenTTL == 0 { cfg.AccessTokenTTL = 15 * time.Minute } @@ -107,15 +182,34 @@ func New(deps Deps, cfg Config) *Auth { if cfg.MagicLinkTTL == 0 { cfg.MagicLinkTTL = 15 * time.Minute } + if cfg.EmailOTPTTL == 0 { + cfg.EmailOTPTTL = 10 * time.Minute + } + if cfg.EmailOTPMaxAttempts == 0 { + cfg.EmailOTPMaxAttempts = 5 + } + if cfg.EmailOTPDigits == 0 { + cfg.EmailOTPDigits = 6 + } if cfg.Clock == nil { cfg.Clock = func() time.Time { return time.Now().UTC() } } if cfg.Random == nil { cfg.Random = rand.Reader } - - return &Auth{deps: deps, cfg: cfg} + return cfg } +// BoolPtr is a one-line helper for Config fields that take *bool. Use it to +// opt out of the secure cookie defaults: cfg.SessionCookieSecure = BoolPtr(false). +func BoolPtr(b bool) *bool { return &b } + // now returns the configured wall clock, defaulting to time.Now in UTC. func (a *Auth) now() time.Time { return a.cfg.Clock() } + +// DB exposes the underlying *sql.DB. Useful for callers that want to run +// admin queries on the same pool. +func (a *Auth) DB() *sql.DB { return a.db } + +// Schema returns the configured schema. +func (a *Auth) Schema() Schema { return a.schema } diff --git a/authz.go b/authz.go new file mode 100644 index 0000000..bcc4601 --- /dev/null +++ b/authz.go @@ -0,0 +1,176 @@ +package authkit + +import ( + "context" + "fmt" +) + +// LoginAuthz is a predicate over a *Principal. Used by middleware that +// gates handlers on a user's roles or permissions. +type LoginAuthz interface { + // Match reports whether the principal satisfies the predicate. + Match(p *Principal) bool + // Validate verifies that every slug referenced by this predicate exists + // in the database. Called at middleware-construction time so typos fail + // at boot rather than at request time. + Validate(ctx context.Context, a *Auth) error +} + +// ServiceKeyAuthz is the analogous predicate type for service-token +// authorization. +type ServiceKeyAuthz interface { + Match(k *ServiceKey) bool + Validate(ctx context.Context, a *Auth) error +} + +// HasRole returns a leaf predicate satisfied when the principal carries the +// given role slug. +func HasRole(slug string) LoginAuthz { return roleLeaf{slug: slug} } + +// HasPermission returns a leaf predicate satisfied when the principal +// carries the given permission slug (resolved through any combination of +// roles and direct grants). +func HasPermission(slug string) LoginAuthz { return permLeaf{slug: slug} } + +// HasAbility returns a leaf predicate satisfied when the service key +// carries the given ability slug. +func HasAbility(slug string) ServiceKeyAuthz { return abilityLeaf{slug: slug} } + +// AnyLogin returns a predicate satisfied when at least one child predicate +// matches. With no children, AnyLogin matches nothing (returns false). +func AnyLogin(preds ...LoginAuthz) LoginAuthz { return anyLogin{preds: preds} } + +// AllLogin returns a predicate satisfied when every child predicate +// matches. With no children, AllLogin matches everything (returns true). +func AllLogin(preds ...LoginAuthz) LoginAuthz { return allLogin{preds: preds} } + +// AnyServiceKey returns a service-key predicate satisfied when at least one +// child matches. +func AnyServiceKey(preds ...ServiceKeyAuthz) ServiceKeyAuthz { + return anyService{preds: preds} +} + +// AllServiceKey returns a service-key predicate satisfied when every child +// matches. +func AllServiceKey(preds ...ServiceKeyAuthz) ServiceKeyAuthz { + return allService{preds: preds} +} + +// ─── leaves ──────────────────────────────────────────────────────────────── + +type roleLeaf struct{ slug string } + +func (l roleLeaf) Match(p *Principal) bool { return p != nil && p.HasRole(l.slug) } +func (l roleLeaf) Validate(ctx context.Context, a *Auth) error { + if err := validateSlug("authkit.HasRole", l.slug); err != nil { + return err + } + if _, err := a.storeGetRoleBySlug(ctx, l.slug); err != nil { + return fmt.Errorf("authkit.HasRole(%q): %w", l.slug, err) + } + return nil +} + +type permLeaf struct{ slug string } + +func (l permLeaf) Match(p *Principal) bool { return p != nil && p.HasPermission(l.slug) } +func (l permLeaf) Validate(ctx context.Context, a *Auth) error { + if err := validateSlug("authkit.HasPermission", l.slug); err != nil { + return err + } + if _, err := a.storeGetPermissionBySlug(ctx, l.slug); err != nil { + return fmt.Errorf("authkit.HasPermission(%q): %w", l.slug, err) + } + return nil +} + +type abilityLeaf struct{ slug string } + +func (l abilityLeaf) Match(k *ServiceKey) bool { return k != nil && k.HasAbility(l.slug) } +func (l abilityLeaf) Validate(ctx context.Context, a *Auth) error { + if err := validateSlug("authkit.HasAbility", l.slug); err != nil { + return err + } + if _, err := a.storeGetAbilityBySlug(ctx, l.slug); err != nil { + return fmt.Errorf("authkit.HasAbility(%q): %w", l.slug, err) + } + return nil +} + +// ─── combinators ─────────────────────────────────────────────────────────── + +type anyLogin struct{ preds []LoginAuthz } + +func (a anyLogin) Match(p *Principal) bool { + for _, pr := range a.preds { + if pr.Match(p) { + return true + } + } + return false +} +func (a anyLogin) Validate(ctx context.Context, auth *Auth) error { + for _, p := range a.preds { + if err := p.Validate(ctx, auth); err != nil { + return err + } + } + return nil +} + +type allLogin struct{ preds []LoginAuthz } + +func (a allLogin) Match(p *Principal) bool { + for _, pr := range a.preds { + if !pr.Match(p) { + return false + } + } + return true +} +func (a allLogin) Validate(ctx context.Context, auth *Auth) error { + for _, p := range a.preds { + if err := p.Validate(ctx, auth); err != nil { + return err + } + } + return nil +} + +type anyService struct{ preds []ServiceKeyAuthz } + +func (a anyService) Match(k *ServiceKey) bool { + for _, pr := range a.preds { + if pr.Match(k) { + return true + } + } + return false +} +func (a anyService) Validate(ctx context.Context, auth *Auth) error { + for _, p := range a.preds { + if err := p.Validate(ctx, auth); err != nil { + return err + } + } + return nil +} + +type allService struct{ preds []ServiceKeyAuthz } + +func (a allService) Match(k *ServiceKey) bool { + for _, pr := range a.preds { + if !pr.Match(k) { + return false + } + } + return true +} +func (a allService) Validate(ctx context.Context, auth *Auth) error { + for _, p := range a.preds { + if err := p.Validate(ctx, auth); err != nil { + return err + } + } + return nil +} diff --git a/authz_test.go b/authz_test.go new file mode 100644 index 0000000..e798d9f --- /dev/null +++ b/authz_test.go @@ -0,0 +1,90 @@ +package authkit + +import ( + "testing" + + "github.com/google/uuid" +) + +func TestPredicateLeavesAndCombinators(t *testing.T) { + p := &Principal{ + UserID: uuid.New(), + Roles: []string{"admin", "manager"}, + Permissions: []string{"posts:read", "posts:write"}, + } + + if !HasRole("admin").Match(p) { + t.Fatalf("HasRole admin should match") + } + if HasRole("nope").Match(p) { + t.Fatalf("HasRole nope should not match") + } + if !HasPermission("posts:write").Match(p) { + t.Fatalf("HasPermission posts:write should match") + } + + // AnyLogin: short-circuit on first match. + any1 := AnyLogin(HasRole("nope"), HasRole("admin")) + if !any1.Match(p) { + t.Fatalf("AnyLogin with one match should match") + } + any2 := AnyLogin(HasRole("nope"), HasRole("missing")) + if any2.Match(p) { + t.Fatalf("AnyLogin with no matches should not match") + } + // AnyLogin with no children: vacuously false. + if AnyLogin().Match(p) { + t.Fatalf("AnyLogin() should be false (no candidates can satisfy)") + } + + // AllLogin: every child must match. + all1 := AllLogin(HasRole("admin"), HasRole("manager")) + if !all1.Match(p) { + t.Fatalf("AllLogin with all matches should match") + } + all2 := AllLogin(HasRole("admin"), HasRole("missing")) + if all2.Match(p) { + t.Fatalf("AllLogin with one missing should not match") + } + if !AllLogin().Match(p) { + t.Fatalf("AllLogin() should be true (vacuous truth)") + } + + // Nested: Admin OR (Manager AND AdsManager). Without ads_manager, the + // AND-arm fails but the Admin-arm succeeds. + expr := AnyLogin( + HasRole("admin"), + AllLogin(HasRole("manager"), HasRole("ads_manager")), + ) + if !expr.Match(p) { + t.Fatalf("Admin OR (Manager AND AdsManager) should match: admin alone qualifies") + } + + // Same expression against a non-admin manager who lacks ads_manager: + pNonAdmin := &Principal{Roles: []string{"manager"}} + if expr.Match(pNonAdmin) { + t.Fatalf("manager without ads_manager should not match the compound") + } + pBoth := &Principal{Roles: []string{"manager", "ads_manager"}} + if !expr.Match(pBoth) { + t.Fatalf("manager+ads_manager should match the AND-arm") + } +} + +func TestServiceKeyPredicates(t *testing.T) { + k := &ServiceKey{Abilities: []string{"events:write", "events:read"}} + + if !HasAbility("events:write").Match(k) { + t.Fatalf("HasAbility events:write should match") + } + if HasAbility("admin:nuke").Match(k) { + t.Fatalf("HasAbility admin:nuke should not match") + } + + if !AllServiceKey(HasAbility("events:write"), HasAbility("events:read")).Match(k) { + t.Fatalf("AllServiceKey should match when key carries both") + } + if AnyServiceKey(HasAbility("admin:nuke"), HasAbility("missing")).Match(k) { + t.Fatalf("AnyServiceKey should not match with no candidates") + } +} diff --git a/cmd/abilities/main.go b/cmd/abilities/main.go new file mode 100644 index 0000000..920b338 --- /dev/null +++ b/cmd/abilities/main.go @@ -0,0 +1,109 @@ +// Command abilities is the seeding CLI for service-token abilities. +// +// abilities create [--label "..."] +// abilities list +// abilities delete +package main + +import ( + "context" + "errors" + "fmt" + "os" + + "git.juancwu.dev/juancwu/authkit/cmd/internal/clihelp" +) + +func main() { + if len(os.Args) < 2 { + usage() + os.Exit(2) + } + sub := os.Args[1] + args := os.Args[2:] + + switch sub { + case "create": + runCreate(args) + case "list": + runList(args) + case "delete", "rm": + runDelete(args) + case "-h", "--help", "help": + usage() + default: + fmt.Fprintf(os.Stderr, "unknown subcommand %q\n\n", sub) + usage() + os.Exit(2) + } +} + +func usage() { + fmt.Fprintln(os.Stderr, `usage: abilities [args] + +Subcommands: + create [--label "..."] create an ability + list list every ability + delete delete an ability + +Common flags: + --dsn PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)`) +} + +func runCreate(args []string) { + fs, dsn := clihelp.DSNFlag("abilities create") + label := fs.String("label", "", "optional human label") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 1 { + clihelp.Fail(errors.New("create takes exactly one slug argument")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + ab, err := a.CreateAbility(ctx, rest[0], *label) + if err != nil { + clihelp.Fail(err) + } + fmt.Printf("created ability %s (id=%s, label=%q)\n", ab.Slug, ab.ID, ab.Label) +} + +func runList(args []string) { + fs, dsn := clihelp.DSNFlag("abilities list") + _ = fs.Parse(args) + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + abilities, err := a.ListAbilities(ctx) + if err != nil { + clihelp.Fail(err) + } + for _, ab := range abilities { + fmt.Printf("%s\t%s\n", ab.Slug, ab.Label) + } +} + +func runDelete(args []string) { + fs, dsn := clihelp.DSNFlag("abilities delete") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 1 { + clihelp.Fail(errors.New("delete takes exactly one slug argument")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + if err := a.DeleteAbility(ctx, rest[0]); err != nil { + clihelp.Fail(err) + } + fmt.Printf("deleted ability %s\n", rest[0]) +} diff --git a/cmd/internal/clihelp/clihelp.go b/cmd/internal/clihelp/clihelp.go new file mode 100644 index 0000000..e6f48ea --- /dev/null +++ b/cmd/internal/clihelp/clihelp.go @@ -0,0 +1,70 @@ +// Package clihelp is a small helper used by the cmd/perms, cmd/roles, and +// cmd/abilities seeding CLIs to dial Postgres, build an *authkit.Auth, and +// share argument-parsing scaffolding. +package clihelp + +import ( + "context" + "database/sql" + "errors" + "flag" + "fmt" + "os" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/authkit/hasher" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +// DSNFlag returns a flag.FlagSet pre-populated with --dsn. Callers add their +// own flags and call Parse on it. +func DSNFlag(name string) (*flag.FlagSet, *string) { + fs := flag.NewFlagSet(name, flag.ExitOnError) + dsn := fs.String("dsn", "", "PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)") + return fs, dsn +} + +// Dial opens a database connection using either the supplied DSN or the +// AUTHKIT_DATABASE_URL env var, then constructs an *authkit.Auth ready to +// run seed operations. Migrations and schema verification both run as part +// of New. +// +// The CLIs never sign JWTs or hash passwords, but Auth.New requires a JWT +// secret and a hasher — we supply a dummy secret and the default Argon2id +// hasher so the constructor passes. +func Dial(ctx context.Context, dsn string) (*authkit.Auth, *sql.DB, error) { + if dsn == "" { + dsn = os.Getenv("AUTHKIT_DATABASE_URL") + } + if dsn == "" { + return nil, nil, errors.New("no DSN: pass --dsn or set AUTHKIT_DATABASE_URL") + } + db, err := sql.Open("pgx", dsn) + if err != nil { + return nil, nil, fmt.Errorf("sql.Open: %w", err) + } + if err := db.PingContext(ctx); err != nil { + _ = db.Close() + return nil, nil, fmt.Errorf("ping: %w", err) + } + a, err := authkit.New(ctx, authkit.Deps{ + DB: db, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), + }, authkit.Config{ + // JWT secret is unused by seed flows but required by New. + JWTSecret: []byte("authkit-cli-not-used-for-anything-real"), + }) + if err != nil { + _ = db.Close() + return nil, nil, err + } + return a, db, nil +} + +// Fail prints err to stderr and exits with status 1. Used by every CLI's +// top-level dispatch. +func Fail(err error) { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) +} diff --git a/cmd/perms/main.go b/cmd/perms/main.go new file mode 100644 index 0000000..943afaa --- /dev/null +++ b/cmd/perms/main.go @@ -0,0 +1,114 @@ +// Command perms is the seeding CLI for authkit permissions. +// +// perms create [--label "..."] +// perms list +// perms delete +// +// Database connection comes from --dsn or $AUTHKIT_DATABASE_URL. +package main + +import ( + "context" + "errors" + "fmt" + "os" + + "git.juancwu.dev/juancwu/authkit/cmd/internal/clihelp" +) + +func main() { + if len(os.Args) < 2 { + usage() + os.Exit(2) + } + sub := os.Args[1] + args := os.Args[2:] + + switch sub { + case "create": + runCreate(args) + case "list": + runList(args) + case "delete", "rm": + runDelete(args) + case "-h", "--help", "help": + usage() + default: + fmt.Fprintf(os.Stderr, "unknown subcommand %q\n\n", sub) + usage() + os.Exit(2) + } +} + +func usage() { + fmt.Fprintln(os.Stderr, `usage: perms [args] + +Subcommands: + create [--label "..."] create a permission + list list every permission + delete delete a permission + +Common flags: + --dsn PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)`) +} + +func runCreate(args []string) { + fs, dsn := clihelp.DSNFlag("perms create") + label := fs.String("label", "", "optional human label") + _ = fs.Parse(args) + + rest := fs.Args() + if len(rest) != 1 { + clihelp.Fail(errors.New("create takes exactly one slug argument")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + + p, err := a.CreatePermission(ctx, rest[0], *label) + if err != nil { + clihelp.Fail(err) + } + fmt.Printf("created permission %s (id=%s, label=%q)\n", p.Slug, p.ID, p.Label) +} + +func runList(args []string) { + fs, dsn := clihelp.DSNFlag("perms list") + _ = fs.Parse(args) + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + + perms, err := a.ListPermissions(ctx) + if err != nil { + clihelp.Fail(err) + } + for _, p := range perms { + fmt.Printf("%s\t%s\n", p.Slug, p.Label) + } +} + +func runDelete(args []string) { + fs, dsn := clihelp.DSNFlag("perms delete") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 1 { + clihelp.Fail(errors.New("delete takes exactly one slug argument")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + if err := a.DeletePermission(ctx, rest[0]); err != nil { + clihelp.Fail(err) + } + fmt.Printf("deleted permission %s\n", rest[0]) +} diff --git a/cmd/roles/main.go b/cmd/roles/main.go new file mode 100644 index 0000000..592e911 --- /dev/null +++ b/cmd/roles/main.go @@ -0,0 +1,182 @@ +// Command roles is the seeding CLI for authkit roles, plus role↔permission +// linking. +// +// roles create [--label "..."] +// roles list +// roles delete +// roles grant +// roles revoke +// roles permissions +package main + +import ( + "context" + "errors" + "fmt" + "os" + + "git.juancwu.dev/juancwu/authkit/cmd/internal/clihelp" +) + +func main() { + if len(os.Args) < 2 { + usage() + os.Exit(2) + } + sub := os.Args[1] + args := os.Args[2:] + + switch sub { + case "create": + runCreate(args) + case "list": + runList(args) + case "delete", "rm": + runDelete(args) + case "grant": + runGrant(args) + case "revoke": + runRevoke(args) + case "permissions", "perms": + runPermissions(args) + case "-h", "--help", "help": + usage() + default: + fmt.Fprintf(os.Stderr, "unknown subcommand %q\n\n", sub) + usage() + os.Exit(2) + } +} + +func usage() { + fmt.Fprintln(os.Stderr, `usage: roles [args] + +Subcommands: + create [--label "..."] create a role + list list every role + delete delete a role + grant grant a permission to a role + revoke revoke a permission from a role + permissions list permissions granted to a role + +Common flags: + --dsn PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)`) +} + +func runCreate(args []string) { + fs, dsn := clihelp.DSNFlag("roles create") + label := fs.String("label", "", "optional human label") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 1 { + clihelp.Fail(errors.New("create takes exactly one slug argument")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + r, err := a.CreateRole(ctx, rest[0], *label) + if err != nil { + clihelp.Fail(err) + } + fmt.Printf("created role %s (id=%s, label=%q)\n", r.Slug, r.ID, r.Label) +} + +func runList(args []string) { + fs, dsn := clihelp.DSNFlag("roles list") + _ = fs.Parse(args) + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + roles, err := a.ListRoles(ctx) + if err != nil { + clihelp.Fail(err) + } + for _, r := range roles { + fmt.Printf("%s\t%s\n", r.Slug, r.Label) + } +} + +func runDelete(args []string) { + fs, dsn := clihelp.DSNFlag("roles delete") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 1 { + clihelp.Fail(errors.New("delete takes exactly one slug argument")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + if err := a.DeleteRole(ctx, rest[0]); err != nil { + clihelp.Fail(err) + } + fmt.Printf("deleted role %s\n", rest[0]) +} + +func runGrant(args []string) { + fs, dsn := clihelp.DSNFlag("roles grant") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 2 { + clihelp.Fail(errors.New("grant takes ")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + if err := a.GrantPermissionToRole(ctx, rest[0], rest[1]); err != nil { + clihelp.Fail(err) + } + fmt.Printf("granted %s to role %s\n", rest[1], rest[0]) +} + +func runRevoke(args []string) { + fs, dsn := clihelp.DSNFlag("roles revoke") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 2 { + clihelp.Fail(errors.New("revoke takes ")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + if err := a.RevokePermissionFromRole(ctx, rest[0], rest[1]); err != nil { + clihelp.Fail(err) + } + fmt.Printf("revoked %s from role %s\n", rest[1], rest[0]) +} + +func runPermissions(args []string) { + fs, dsn := clihelp.DSNFlag("roles permissions") + _ = fs.Parse(args) + rest := fs.Args() + if len(rest) != 1 { + clihelp.Fail(errors.New("permissions takes exactly one role-slug argument")) + } + ctx := context.Background() + a, db, err := clihelp.Dial(ctx, *dsn) + if err != nil { + clihelp.Fail(err) + } + defer db.Close() + perms, err := a.ListRolePermissions(ctx, rest[0]) + if err != nil { + clihelp.Fail(err) + } + for _, p := range perms { + fmt.Printf("%s\t%s\n", p.Slug, p.Label) + } +} diff --git a/doc.go b/doc.go index bbefef9..ae2a9d4 100644 --- a/doc.go +++ b/doc.go @@ -1,14 +1,24 @@ -// Package authkit is an authentication and authorization toolkit for Go web -// services. It defines storage interfaces (UserStore, SessionStore, TokenStore, -// ServiceKeyStore, RoleStore, PermissionStore) and a high-level Auth service -// that composes them to support registration, password login, opaque -// server-side sessions, JWT access plus rotating refresh tokens, email -// verification, password resets, magic-link passwordless login, role-based -// access control, and owner-agnostic service tokens with custom abilities for -// server-to-server auth. +// Package authkit is a pragmatic authentication and authorization toolkit +// for Go web services on PostgreSQL 16+. // -// Default Postgres implementations of every store live in the pgstore -// subpackage. Argon2id password hashing lives in hasher. Framework-neutral -// HTTP middleware (compatible with lightmux and any net/http stack) lives in -// middleware. +// Drop authkit into a net/http stack and get registration, password login, +// opaque server-side sessions, JWT access tokens with rotating refresh, +// email verification, password reset, magic-link login, email OTP, and +// owner-agnostic service tokens with consumer-defined abilities. +// Authorization is flat RBAC with both role-derived and direct user +// permissions. +// +// Roles, permissions, and abilities are seeded by the consumer (typically +// via the cmd/perms, cmd/roles, and cmd/abilities CLIs that ship with this +// repo). The library does not seed any rows automatically — applications +// own their authorization vocabulary. +// +// Migrations and schema verification run at startup. Set +// Config.SkipAutoMigrate to disable. +// +// The library does not send email or otherwise reach out to users. +// Token-minting flows (RequestEmailVerification, RequestPasswordReset, +// RequestMagicLink, RequestEmailOTP, IssueServiceKey, IssueSession, +// IssueJWT) return the plaintext to the caller exactly once — show it to +// the user immediately; only its SHA-256 hash is persisted. package authkit diff --git a/email.go b/email.go new file mode 100644 index 0000000..719219d --- /dev/null +++ b/email.go @@ -0,0 +1,11 @@ +package authkit + +import "strings" + +// normalizeEmail produces the lookup form used by GetUserByEmail and the +// email_normalized column. Trim + lowercase is intentional; we do not +// collapse Gmail-style "+" addressing or strip dots — that's a policy +// decision callers can layer on top. +func normalizeEmail(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} diff --git a/email_test.go b/email_test.go new file mode 100644 index 0000000..24a4cd8 --- /dev/null +++ b/email_test.go @@ -0,0 +1,21 @@ +package authkit + +import "testing" + +func TestNormalizeEmail(t *testing.T) { + cases := []struct { + in, want string + }{ + {"alice@example.com", "alice@example.com"}, + {"Alice@Example.com", "alice@example.com"}, + {"ALICE@EXAMPLE.COM", "alice@example.com"}, + {" alice@example.com ", "alice@example.com"}, + {"\talice@EXAMPLE.com\n", "alice@example.com"}, + } + for _, c := range cases { + got := normalizeEmail(c.in) + if got != c.want { + t.Fatalf("normalizeEmail(%q) = %q, want %q", c.in, got, c.want) + } + } +} diff --git a/errors.go b/errors.go index 753f348..5d63e74 100644 --- a/errors.go +++ b/errors.go @@ -2,11 +2,12 @@ package authkit import "errors" +// Sentinel errors. Internal call sites wrap these via errx so callers using +// errors.Is(err, authkit.ErrFoo) get reliable matching across wrap chains. var ( ErrEmailTaken = errors.New("authkit: email already registered") ErrUserNotFound = errors.New("authkit: user not found") ErrInvalidCredentials = errors.New("authkit: invalid credentials") - ErrEmailNotVerified = errors.New("authkit: email not verified") ErrTokenInvalid = errors.New("authkit: invalid or expired token") ErrTokenReused = errors.New("authkit: token reuse detected") ErrSessionInvalid = errors.New("authkit: invalid or expired session") @@ -14,5 +15,10 @@ var ( ErrPermissionDenied = errors.New("authkit: permission denied") ErrRoleNotFound = errors.New("authkit: role not found") ErrPermissionNotFound = errors.New("authkit: permission not found") - ErrConfigInvalid = errors.New("authkit: invalid configuration") + ErrAbilityNotFound = errors.New("authkit: ability not found") + ErrSlugInvalid = errors.New("authkit: invalid slug") + ErrSlugTaken = errors.New("authkit: slug already in use") + ErrOTPInvalid = errors.New("authkit: invalid or expired OTP") + ErrNoUserContext = errors.New("authkit: no user on request context") + ErrSchemaDrift = errors.New("authkit: database schema does not match expected layout") ) diff --git a/extractor.go b/extractor.go index e1e0d73..0d67f5e 100644 --- a/extractor.go +++ b/extractor.go @@ -5,8 +5,8 @@ import ( "strings" ) -// Extractor pulls a credential string out of an HTTP request. It returns -// (value, true) when a value was found, otherwise ("", false). +// Extractor pulls a credential string out of an HTTP request. Returns +// (value, true) when found, ("", false) otherwise. type Extractor func(r *http.Request) (string, bool) // BearerExtractor reads the value following "Bearer " in the Authorization diff --git a/extractor_test.go b/extractor_test.go new file mode 100644 index 0000000..e51f716 --- /dev/null +++ b/extractor_test.go @@ -0,0 +1,86 @@ +package authkit + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestBearerExtractor(t *testing.T) { + ex := BearerExtractor() + cases := []struct { + name string + header string + want string + ok bool + }{ + {"plain bearer", "Bearer abc", "abc", true}, + {"lowercase bearer", "bearer abc", "abc", true}, + {"mixed case", "BeArEr abc", "abc", true}, + {"no header", "", "", false}, + {"non-bearer scheme", "Basic abc", "", false}, + {"bearer with no token", "Bearer ", "", false}, + {"bearer with whitespace", "Bearer abc ", "abc", true}, + {"too short", "Bearer", "", false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + if c.header != "" { + r.Header.Set("Authorization", c.header) + } + got, ok := ex(r) + if ok != c.ok || got != c.want { + t.Fatalf("got (%q, %v), want (%q, %v)", got, ok, c.want, c.ok) + } + }) + } +} + +func TestCookieExtractor(t *testing.T) { + ex := CookieExtractor("session") + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: "session", Value: "abc"}) + got, ok := ex(r) + if !ok || got != "abc" { + t.Fatalf("got (%q, %v), want (\"abc\", true)", got, ok) + } + + r2 := httptest.NewRequest(http.MethodGet, "/", nil) + if _, ok := ex(r2); ok { + t.Fatalf("missing cookie should not extract") + } + + r3 := httptest.NewRequest(http.MethodGet, "/", nil) + r3.AddCookie(&http.Cookie{Name: "session", Value: ""}) + if _, ok := ex(r3); ok { + t.Fatalf("empty cookie value should not extract") + } +} + +func TestHeaderExtractor(t *testing.T) { + ex := HeaderExtractor("X-API-Token") + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("X-API-Token", " abc ") + got, ok := ex(r) + if !ok || got != "abc" { + t.Fatalf("got (%q, %v), want (\"abc\", true)", got, ok) + } +} + +func TestChainExtractors(t *testing.T) { + a := func(r *http.Request) (string, bool) { return "", false } + b := func(r *http.Request) (string, bool) { return "from-b", true } + c := func(r *http.Request) (string, bool) { return "from-c", true } + chain := ChainExtractors(a, b, c) + r := httptest.NewRequest(http.MethodGet, "/", nil) + got, ok := chain(r) + if !ok || got != "from-b" { + t.Fatalf("chain should return first hit; got (%q, %v)", got, ok) + } + + none := ChainExtractors(a, a) + if _, ok := none(r); ok { + t.Fatalf("chain of misses should not extract") + } +} diff --git a/hasher/argon2id.go b/hasher/argon2id.go index 5d5c57b..9abe727 100644 --- a/hasher/argon2id.go +++ b/hasher/argon2id.go @@ -1,7 +1,11 @@ -// Package hasher provides authkit.Hasher implementations. The default -// implementation, Argon2id, encodes hashes in the standard PHC string format -// (https://github.com/P-H-C/phc-string-format) so callers can introspect -// parameters and migrate. +// Package hasher provides password-hashing primitives that satisfy the +// authkit.Hasher interface. The default implementation, Argon2id, encodes +// hashes in the standard PHC string format (https://github.com/P-H-C/phc-string-format) +// so callers can introspect parameters and migrate. +// +// This package intentionally does not import authkit — the Hasher interface +// is structurally satisfied. That keeps the dependency arrow one-way and +// lets test code in the authkit package itself import this package. package hasher import ( @@ -13,7 +17,6 @@ import ( "io" "strings" - "git.juancwu.dev/juancwu/authkit" "git.juancwu.dev/juancwu/errx" "golang.org/x/crypto/argon2" ) @@ -40,24 +43,27 @@ func DefaultArgon2idParams() Argon2idParams { } } -type argon2idHasher struct { +// Argon2idHasher implements password hashing via Argon2id. It satisfies +// authkit.Hasher through structural typing — pass *Argon2idHasher into +// authkit.Deps.Hasher. +type Argon2idHasher struct { params Argon2idParams rng io.Reader } -// NewArgon2id builds an authkit.Hasher backed by Argon2id. If params is the -// zero value, DefaultArgon2idParams() is used. rng defaults to crypto/rand. -func NewArgon2id(params Argon2idParams, rng io.Reader) authkit.Hasher { +// NewArgon2id builds an *Argon2idHasher. If params is the zero value, +// DefaultArgon2idParams() is used. rng defaults to crypto/rand. +func NewArgon2id(params Argon2idParams, rng io.Reader) *Argon2idHasher { if params == (Argon2idParams{}) { params = DefaultArgon2idParams() } if rng == nil { rng = rand.Reader } - return &argon2idHasher{params: params, rng: rng} + return &Argon2idHasher{params: params, rng: rng} } -func (h *argon2idHasher) Hash(password string) (string, error) { +func (h *Argon2idHasher) Hash(password string) (string, error) { const op = "authkit.hasher.Argon2id.Hash" if password == "" { return "", errx.New(op, "password is empty") @@ -71,7 +77,7 @@ func (h *argon2idHasher) Hash(password string) (string, error) { return encodePHC(h.params, salt, key), nil } -func (h *argon2idHasher) Verify(password, encoded string) (bool, bool, error) { +func (h *Argon2idHasher) Verify(password, encoded string) (bool, bool, error) { const op = "authkit.hasher.Argon2id.Verify" got, salt, key, err := decodePHC(encoded) if err != nil { diff --git a/jwt.go b/jwt.go index c9fb7b4..dc812df 100644 --- a/jwt.go +++ b/jwt.go @@ -6,9 +6,9 @@ import ( "github.com/google/uuid" ) -// accessClaims is the JWT shape issued by IssueJWT. The session_version -// field carries the User.SessionVersion at issue time so AuthenticateJWT -// can detect global revocations (logout-everywhere, password change). +// accessClaims is the JWT shape issued by IssueJWT. session_version carries +// User.SessionVersion at issue time so AuthenticateJWT can detect global +// revocations (logout-everywhere, password change). type accessClaims struct { jwt.RegisteredClaims SessionVersion int `json:"sv"` @@ -39,6 +39,7 @@ func (a *Auth) signAccessToken(userID uuid.UUID, sessionVersion int) (string, er } // parseAccessToken validates the signature and returns the parsed claims. +// Strictly enforces HS256 — alg=none and asymmetric algorithms are rejected. func (a *Auth) parseAccessToken(token string) (*accessClaims, error) { const op = "authkit.parseAccessToken" opts := []jwt.ParserOption{ diff --git a/jwt_test.go b/jwt_test.go deleted file mode 100644 index 5de34a2..0000000 --- a/jwt_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package authkit - -import ( - "context" - "errors" - "testing" - "time" -) - -func TestJWTIssueAndAuthenticate(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "alice@example.com", "hunter2") - if err != nil { - t.Fatalf("Register: %v", err) - } - access, refresh, err := a.IssueJWT(context.Background(), u.ID) - if err != nil { - t.Fatalf("IssueJWT: %v", err) - } - if access == "" || refresh == "" { - t.Fatalf("empty tokens") - } - p, err := a.AuthenticateJWT(context.Background(), access) - if err != nil { - t.Fatalf("AuthenticateJWT: %v", err) - } - if p.UserID != u.ID { - t.Fatalf("principal user id mismatch") - } - if p.Method != AuthMethodJWT { - t.Fatalf("principal method = %s, want jwt", p.Method) - } -} - -func TestJWTSessionVersionMismatchRejected(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "bob@example.com", "hunter2") - if err != nil { - t.Fatalf("Register: %v", err) - } - access, _, err := a.IssueJWT(context.Background(), u.ID) - if err != nil { - t.Fatalf("IssueJWT: %v", err) - } - if err := a.RevokeAllUserSessions(context.Background(), u.ID); err != nil { - t.Fatalf("RevokeAllUserSessions: %v", err) - } - if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) { - t.Fatalf("expected ErrTokenInvalid after session bump, got %v", err) - } -} - -func TestJWTRefreshRotationAndReuseDetection(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "carol@example.com", "hunter2") - if err != nil { - t.Fatalf("Register: %v", err) - } - _, refresh1, err := a.IssueJWT(context.Background(), u.ID) - if err != nil { - t.Fatalf("IssueJWT: %v", err) - } - _, refresh2, err := a.RefreshJWT(context.Background(), refresh1) - if err != nil { - t.Fatalf("first RefreshJWT: %v", err) - } - if refresh1 == refresh2 { - t.Fatalf("refresh token did not rotate") - } - - // Replaying refresh1 must surface ErrTokenReused and revoke the chain. - if _, _, err := a.RefreshJWT(context.Background(), refresh1); !errors.Is(err, ErrTokenReused) { - t.Fatalf("expected ErrTokenReused on replay, got %v", err) - } - // After chain revocation, even refresh2 (the legitimate next one) must - // be rejected. - if _, _, err := a.RefreshJWT(context.Background(), refresh2); !errors.Is(err, ErrTokenInvalid) { - t.Fatalf("expected ErrTokenInvalid on post-revoke refresh, got %v", err) - } -} - -func TestJWTExpiredTokenRejected(t *testing.T) { - a := newTestAuth(t) - now := time.Now().UTC() - a.cfg.Clock = func() time.Time { return now } - u, err := a.Register(context.Background(), "dan@example.com", "hunter2") - if err != nil { - t.Fatalf("Register: %v", err) - } - access, _, err := a.IssueJWT(context.Background(), u.ID) - if err != nil { - t.Fatalf("IssueJWT: %v", err) - } - // Advance clock past TTL. - a.cfg.Clock = func() time.Time { return now.Add(10 * time.Minute) } - if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) { - t.Fatalf("expected ErrTokenInvalid for expired token, got %v", err) - } -} diff --git a/memstore_test.go b/memstore_test.go deleted file mode 100644 index 0d0e408..0000000 --- a/memstore_test.go +++ /dev/null @@ -1,612 +0,0 @@ -package authkit - -// In-memory store fakes used by service-level tests. Kept in a _test.go file -// so they don't ship in the public API. Each fake is intentionally minimal — -// it satisfies the interface and supports the flows the tests exercise; not -// every method is wired up (a few panic to flag accidental use). - -import ( - "bytes" - "context" - "errors" - "slices" - "sync" - "time" - - "github.com/google/uuid" -) - -type memUserStore struct { - mu sync.Mutex - byID map[uuid.UUID]*User - byEml map[string]uuid.UUID -} - -func newMemUserStore() *memUserStore { - return &memUserStore{byID: map[uuid.UUID]*User{}, byEml: map[string]uuid.UUID{}} -} - -func (s *memUserStore) CreateUser(_ context.Context, u *User) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, exists := s.byEml[u.EmailNormalized]; exists { - return ErrEmailTaken - } - if u.ID == uuid.Nil { - u.ID = uuid.New() - } - cp := *u - s.byID[u.ID] = &cp - s.byEml[u.EmailNormalized] = u.ID - return nil -} - -func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*User, error) { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.byID[id] - if !ok { - return nil, ErrUserNotFound - } - cp := *u - return &cp, nil -} - -func (s *memUserStore) GetUserByEmail(_ context.Context, ne string) (*User, error) { - s.mu.Lock() - defer s.mu.Unlock() - id, ok := s.byEml[ne] - if !ok { - return nil, ErrUserNotFound - } - u := *s.byID[id] - return &u, nil -} - -func (s *memUserStore) UpdateUser(_ context.Context, u *User) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.byID[u.ID]; !ok { - return ErrUserNotFound - } - cp := *u - s.byID[u.ID] = &cp - return nil -} -func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.byID[id] - if !ok { - return ErrUserNotFound - } - delete(s.byID, id) - delete(s.byEml, u.EmailNormalized) - return nil -} - -func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, h string) error { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.byID[id] - if !ok { - return ErrUserNotFound - } - u.PasswordHash = h - return nil -} -func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.byID[id] - if !ok { - return ErrUserNotFound - } - u.EmailVerifiedAt = &at - return nil -} -func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.byID[id] - if !ok { - return 0, ErrUserNotFound - } - u.SessionVersion++ - return u.SessionVersion, nil -} -func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.byID[id] - if !ok { - return 0, ErrUserNotFound - } - u.FailedLogins++ - return u.FailedLogins, nil -} -func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.byID[id] - if !ok { - return ErrUserNotFound - } - u.FailedLogins = 0 - return nil -} - -type memSessionStore struct { - mu sync.Mutex - m map[string]*Session -} - -func newMemSessionStore() *memSessionStore { return &memSessionStore{m: map[string]*Session{}} } -func (s *memSessionStore) CreateSession(_ context.Context, ses *Session) error { - s.mu.Lock() - defer s.mu.Unlock() - cp := *ses - s.m[string(ses.IDHash)] = &cp - return nil -} -func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*Session, error) { - s.mu.Lock() - defer s.mu.Unlock() - v, ok := s.m[string(h)] - if !ok { - return nil, ErrSessionInvalid - } - cp := *v - return &cp, nil -} -func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, exp time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - v, ok := s.m[string(h)] - if !ok { - return ErrSessionInvalid - } - v.LastSeenAt = lastSeen - v.ExpiresAt = exp - return nil -} -func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.m, string(h)) - return nil -} -func (s *memSessionStore) DeleteUserSessions(_ context.Context, uid uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - for k, v := range s.m { - if v.UserID == uid { - delete(s.m, k) - } - } - return nil -} -func (s *memSessionStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) { - s.mu.Lock() - defer s.mu.Unlock() - var n int64 - for k, v := range s.m { - if !v.ExpiresAt.After(now) { - delete(s.m, k) - n++ - } - } - return n, nil -} - -type memTokenStore struct { - mu sync.Mutex - // Keyed by kind+hex(hash) so identical hashes across kinds don't collide. - m map[string]*Token -} - -func newMemTokenStore() *memTokenStore { return &memTokenStore{m: map[string]*Token{}} } -func tokKey(kind TokenKind, h []byte) string { - return string(kind) + ":" + string(h) -} -func (s *memTokenStore) CreateToken(_ context.Context, t *Token) error { - s.mu.Lock() - defer s.mu.Unlock() - cp := *t - s.m[tokKey(t.Kind, t.Hash)] = &cp - return nil -} -func (s *memTokenStore) ConsumeToken(_ context.Context, kind TokenKind, h []byte, now time.Time) (*Token, error) { - s.mu.Lock() - defer s.mu.Unlock() - t, ok := s.m[tokKey(kind, h)] - if !ok || t.ConsumedAt != nil || !t.ExpiresAt.After(now) { - return nil, ErrTokenInvalid - } - t.ConsumedAt = &now - cp := *t - return &cp, nil -} -func (s *memTokenStore) GetToken(_ context.Context, kind TokenKind, h []byte) (*Token, error) { - s.mu.Lock() - defer s.mu.Unlock() - t, ok := s.m[tokKey(kind, h)] - if !ok { - return nil, ErrTokenInvalid - } - cp := *t - return &cp, nil -} -func (s *memTokenStore) DeleteByChain(_ context.Context, chainID string) (int64, error) { - s.mu.Lock() - defer s.mu.Unlock() - var n int64 - for k, t := range s.m { - if t.ChainID != nil && *t.ChainID == chainID { - delete(s.m, k) - n++ - } - } - return n, nil -} -func (s *memTokenStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) { - s.mu.Lock() - defer s.mu.Unlock() - var n int64 - for k, t := range s.m { - if !t.ExpiresAt.After(now) { - delete(s.m, k) - n++ - } - } - return n, nil -} - -type memServiceKeyStore struct { - mu sync.Mutex - m map[string]*ServiceKey -} - -func newMemServiceKeyStore() *memServiceKeyStore { - return &memServiceKeyStore{m: map[string]*ServiceKey{}} -} -func (s *memServiceKeyStore) CreateServiceKey(_ context.Context, k *ServiceKey) error { - s.mu.Lock() - defer s.mu.Unlock() - cp := *k - cp.Abilities = append([]string(nil), k.Abilities...) - s.m[string(k.IDHash)] = &cp - return nil -} -func (s *memServiceKeyStore) GetServiceKey(_ context.Context, h []byte) (*ServiceKey, error) { - s.mu.Lock() - defer s.mu.Unlock() - k, ok := s.m[string(h)] - if !ok { - return nil, ErrServiceKeyInvalid - } - cp := *k - cp.Abilities = append([]string(nil), k.Abilities...) - return &cp, nil -} -func (s *memServiceKeyStore) ListServiceKeysByOwner(_ context.Context, ownerKind string, owner uuid.UUID) ([]*ServiceKey, error) { - s.mu.Lock() - defer s.mu.Unlock() - var out []*ServiceKey - for _, k := range s.m { - if k.OwnerKind == ownerKind && k.OwnerID == owner { - cp := *k - cp.Abilities = append([]string(nil), k.Abilities...) - out = append(out, &cp) - } - } - return out, nil -} -func (s *memServiceKeyStore) TouchServiceKey(_ context.Context, h []byte, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - if k, ok := s.m[string(h)]; ok { - k.LastUsedAt = &at - } - return nil -} -func (s *memServiceKeyStore) RevokeServiceKey(_ context.Context, h []byte, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - k, ok := s.m[string(h)] - if !ok { - return ErrServiceKeyInvalid - } - if k.RevokedAt != nil { - return ErrServiceKeyInvalid - } - k.RevokedAt = &at - return nil -} - -type memRoleStore struct { - mu sync.Mutex - roles map[uuid.UUID]*Role - rolesByNm map[string]uuid.UUID - userRoles map[uuid.UUID]map[uuid.UUID]struct{} -} - -func newMemRoleStore() *memRoleStore { - return &memRoleStore{ - roles: map[uuid.UUID]*Role{}, - rolesByNm: map[string]uuid.UUID{}, - userRoles: map[uuid.UUID]map[uuid.UUID]struct{}{}, - } -} -func (s *memRoleStore) CreateRole(_ context.Context, r *Role) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, dup := s.rolesByNm[r.Name]; dup { - return errors.New("role exists") - } - if r.ID == uuid.Nil { - r.ID = uuid.New() - } - cp := *r - s.roles[r.ID] = &cp - s.rolesByNm[r.Name] = r.ID - return nil -} -func (s *memRoleStore) GetRoleByID(_ context.Context, id uuid.UUID) (*Role, error) { - s.mu.Lock() - defer s.mu.Unlock() - r, ok := s.roles[id] - if !ok { - return nil, ErrRoleNotFound - } - cp := *r - return &cp, nil -} -func (s *memRoleStore) GetRoleByName(_ context.Context, n string) (*Role, error) { - s.mu.Lock() - defer s.mu.Unlock() - id, ok := s.rolesByNm[n] - if !ok { - return nil, ErrRoleNotFound - } - r := *s.roles[id] - return &r, nil -} -func (s *memRoleStore) ListRoles(_ context.Context) ([]*Role, error) { - s.mu.Lock() - defer s.mu.Unlock() - out := make([]*Role, 0, len(s.roles)) - for _, r := range s.roles { - cp := *r - out = append(out, &cp) - } - return out, nil -} -func (s *memRoleStore) DeleteRole(_ context.Context, id uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - r, ok := s.roles[id] - if !ok { - return ErrRoleNotFound - } - delete(s.roles, id) - delete(s.rolesByNm, r.Name) - for _, m := range s.userRoles { - delete(m, id) - } - return nil -} -func (s *memRoleStore) AssignRoleToUser(_ context.Context, uid, rid uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - m, ok := s.userRoles[uid] - if !ok { - m = map[uuid.UUID]struct{}{} - s.userRoles[uid] = m - } - m[rid] = struct{}{} - return nil -} -func (s *memRoleStore) RemoveRoleFromUser(_ context.Context, uid, rid uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - if m, ok := s.userRoles[uid]; ok { - delete(m, rid) - } - return nil -} -func (s *memRoleStore) GetUserRoles(_ context.Context, uid uuid.UUID) ([]*Role, error) { - s.mu.Lock() - defer s.mu.Unlock() - out := []*Role{} - for rid := range s.userRoles[uid] { - if r, ok := s.roles[rid]; ok { - cp := *r - out = append(out, &cp) - } - } - return out, nil -} -func (s *memRoleStore) HasAnyRole(_ context.Context, uid uuid.UUID, names []string) (bool, error) { - s.mu.Lock() - defer s.mu.Unlock() - for rid := range s.userRoles[uid] { - r := s.roles[rid] - if slices.Contains(names, r.Name) { - return true, nil - } - } - return false, nil -} - -type memPermStore struct { - mu sync.Mutex - perms map[uuid.UUID]*Permission - permsByNm map[string]uuid.UUID - rolePerms map[uuid.UUID]map[uuid.UUID]struct{} - roles *memRoleStore -} - -func newMemPermStore(rs *memRoleStore) *memPermStore { - return &memPermStore{ - perms: map[uuid.UUID]*Permission{}, - permsByNm: map[string]uuid.UUID{}, - rolePerms: map[uuid.UUID]map[uuid.UUID]struct{}{}, - roles: rs, - } -} -func (s *memPermStore) CreatePermission(_ context.Context, p *Permission) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, dup := s.permsByNm[p.Name]; dup { - return errors.New("perm exists") - } - if p.ID == uuid.Nil { - p.ID = uuid.New() - } - cp := *p - s.perms[p.ID] = &cp - s.permsByNm[p.Name] = p.ID - return nil -} -func (s *memPermStore) GetPermissionByID(_ context.Context, id uuid.UUID) (*Permission, error) { - s.mu.Lock() - defer s.mu.Unlock() - p, ok := s.perms[id] - if !ok { - return nil, ErrPermissionNotFound - } - cp := *p - return &cp, nil -} -func (s *memPermStore) GetPermissionByName(_ context.Context, n string) (*Permission, error) { - s.mu.Lock() - defer s.mu.Unlock() - id, ok := s.permsByNm[n] - if !ok { - return nil, ErrPermissionNotFound - } - p := *s.perms[id] - return &p, nil -} -func (s *memPermStore) ListPermissions(_ context.Context) ([]*Permission, error) { - s.mu.Lock() - defer s.mu.Unlock() - out := make([]*Permission, 0, len(s.perms)) - for _, p := range s.perms { - cp := *p - out = append(out, &cp) - } - return out, nil -} -func (s *memPermStore) DeletePermission(_ context.Context, id uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - p, ok := s.perms[id] - if !ok { - return ErrPermissionNotFound - } - delete(s.perms, id) - delete(s.permsByNm, p.Name) - for _, m := range s.rolePerms { - delete(m, id) - } - return nil -} -func (s *memPermStore) AssignPermissionToRole(_ context.Context, rid, pid uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - m, ok := s.rolePerms[rid] - if !ok { - m = map[uuid.UUID]struct{}{} - s.rolePerms[rid] = m - } - m[pid] = struct{}{} - return nil -} -func (s *memPermStore) RemovePermissionFromRole(_ context.Context, rid, pid uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - if m, ok := s.rolePerms[rid]; ok { - delete(m, pid) - } - return nil -} -func (s *memPermStore) GetRolePermissions(_ context.Context, rid uuid.UUID) ([]*Permission, error) { - s.mu.Lock() - defer s.mu.Unlock() - out := []*Permission{} - for pid := range s.rolePerms[rid] { - if p, ok := s.perms[pid]; ok { - cp := *p - out = append(out, &cp) - } - } - return out, nil -} -func (s *memPermStore) GetUserPermissions(_ context.Context, uid uuid.UUID) ([]*Permission, error) { - s.roles.mu.Lock() - roleIDs := make([]uuid.UUID, 0) - for rid := range s.roles.userRoles[uid] { - roleIDs = append(roleIDs, rid) - } - s.roles.mu.Unlock() - - s.mu.Lock() - defer s.mu.Unlock() - seen := map[uuid.UUID]struct{}{} - out := []*Permission{} - for _, rid := range roleIDs { - for pid := range s.rolePerms[rid] { - if _, dup := seen[pid]; dup { - continue - } - seen[pid] = struct{}{} - if p, ok := s.perms[pid]; ok { - cp := *p - out = append(out, &cp) - } - } - } - return out, nil -} - -// Stub Hasher: stores plaintext for trivial verify. Tests of hashing live in -// the hasher package itself; we just need *something* callable here. -type stubHasher struct{} - -func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil } -func (stubHasher) Verify(p, encoded string) (bool, bool, error) { - want := "stub:" + p - if !bytes.Equal([]byte(want), []byte(encoded)) { - return false, false, nil - } - return true, false, nil -} - -// newTestAuth wires the fakes into Auth with deterministic config. -func newTestAuth(t interface{ Helper() }) *Auth { - if h, ok := t.(interface{ Helper() }); ok { - h.Helper() - } - roles := newMemRoleStore() - return New(Deps{ - Users: newMemUserStore(), - Sessions: newMemSessionStore(), - Tokens: newMemTokenStore(), - ServiceKeys: newMemServiceKeyStore(), - Roles: roles, - Permissions: newMemPermStore(roles), - Hasher: stubHasher{}, - }, Config{ - JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"), - JWTIssuer: "authkit-test", - AccessTokenTTL: 2 * time.Minute, - RefreshTokenTTL: 1 * time.Hour, - SessionIdleTTL: time.Hour, - SessionAbsoluteTTL: 24 * time.Hour, - EmailVerifyTTL: time.Hour, - PasswordResetTTL: time.Hour, - MagicLinkTTL: time.Minute, - }) -} diff --git a/middleware/authz.go b/middleware/authz.go deleted file mode 100644 index 38a69b0..0000000 --- a/middleware/authz.go +++ /dev/null @@ -1,84 +0,0 @@ -package middleware - -import ( - "net/http" - - "git.juancwu.dev/juancwu/authkit" -) - -// authzGuard wraps the common pattern of "look up the Principal, run a -// predicate, succeed or 403". onForbidden defaults to JSON 403. -func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler { - if onForbidden == nil { - onForbidden = defaultJSONError(http.StatusForbidden) - } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p, ok := PrincipalFrom(r.Context()) - if !ok { - // No auth middleware ran upstream; treat as forbidden - // rather than crashing — composition is the caller's - // responsibility but a 403 is the safer default. - onForbidden(w, r, authkit.ErrPermissionDenied) - return - } - if !pred(p) { - onForbidden(w, r, authkit.ErrPermissionDenied) - return - } - next.ServeHTTP(w, r) - }) - } -} - -// RequireRole permits requests whose Principal holds the named role. -func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { - return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { - return p.HasRole(name) - }) -} - -// RequireAnyRole permits requests whose Principal holds at least one of the -// named roles. -func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { - return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { - return p.HasAnyRole(names...) - }) -} - -// RequirePermission permits requests whose Principal holds the named -// permission (resolved via roles at auth time). -func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { - return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { - return p.HasPermission(name) - }) -} - -// RequireAbility permits requests whose ServiceKey carries the named ability. -// Abilities live only on service tokens — this middleware reads -// *authkit.ServiceKey from the request context (placed by RequireServiceKey -// or RequireAnyOrServiceKey) and 403s any request authenticated as a user -// (session or JWT), which by definition has no abilities. -func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { - onForb := firstOrNil(onForbidden) - if onForb == nil { - onForb = defaultJSONError(http.StatusForbidden) - } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - k, ok := ServiceKeyFrom(r.Context()) - if !ok || !k.HasAbility(name) { - onForb(w, r, authkit.ErrPermissionDenied) - return - } - next.ServeHTTP(w, r) - }) - } -} - -func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) { - if len(s) == 0 { - return nil - } - return s[0] -} diff --git a/middleware/context.go b/middleware/context.go deleted file mode 100644 index d1dd960..0000000 --- a/middleware/context.go +++ /dev/null @@ -1,64 +0,0 @@ -// Package middleware provides framework-neutral HTTP middleware for authkit. -// Every middleware function returns the standard func(http.Handler) -// http.Handler type, so it composes with lightmux's Use/Group/Handle as well -// as any net/http stack that uses the same signature. -package middleware - -import ( - "context" - "net/http" - - "git.juancwu.dev/juancwu/authkit" -) - -// principalKey and serviceKeyKey are unexported context keys. Using distinct -// empty struct types guarantees no collision with caller-defined keys. -type principalKey struct{} -type serviceKeyKey struct{} - -// withPrincipal stashes p on the request context for downstream handlers. -func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context { - return context.WithValue(ctx, principalKey{}, p) -} - -// PrincipalFrom retrieves the authenticated Principal placed by RequireSession -// or RequireJWT. The boolean is false if no user-bound auth middleware ran for -// this request (e.g. the request was authenticated via service key instead). -func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) { - p, ok := ctx.Value(principalKey{}).(*authkit.Principal) - return p, ok -} - -// MustPrincipal panics if no Principal is on the context. Use only on -// handlers known to be behind a Require* middleware that authenticates a -// user (RequireSession or RequireJWT). -func MustPrincipal(r *http.Request) *authkit.Principal { - p, ok := PrincipalFrom(r.Context()) - if !ok { - panic("authkit/middleware: no principal on request context") - } - return p -} - -// withServiceKey stashes k on the request context for downstream handlers. -func withServiceKey(ctx context.Context, k *authkit.ServiceKey) context.Context { - return context.WithValue(ctx, serviceKeyKey{}, k) -} - -// ServiceKeyFrom retrieves the authenticated ServiceKey placed by -// RequireServiceKey. The boolean is false if no service-key middleware ran -// for this request. -func ServiceKeyFrom(ctx context.Context) (*authkit.ServiceKey, bool) { - k, ok := ctx.Value(serviceKeyKey{}).(*authkit.ServiceKey) - return k, ok -} - -// MustServiceKey panics if no ServiceKey is on the context. Use only on -// handlers known to be behind RequireServiceKey. -func MustServiceKey(r *http.Request) *authkit.ServiceKey { - k, ok := ServiceKeyFrom(r.Context()) - if !ok { - panic("authkit/middleware: no service key on request context") - } - return k -} diff --git a/middleware/middleware.go b/middleware/middleware.go index 79ec7bc..88ce310 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,45 +1,236 @@ +// Package middleware provides framework-neutral HTTP middleware for authkit. +// Every middleware function returns the standard func(http.Handler) +// http.Handler shape so it composes with lightmux.Mux.Use/Group/Handle, with +// chi/gorilla, or with any net/http stack accepting that signature. +// +// Three primitives: +// - RequireLogin — accept session OR JWT, optionally constrain by Authz +// - RequireGuest — reject authenticated requests +// - RequireServiceKey — accept a service token, optionally constrain by Authz +// +// All three attach the relevant subject to the request context via +// authkit.WithUserContext or authkit.WithServiceKey, so handlers can read +// it via authkit.UserIDFromCtx / authkit.UserFromCtx / +// authkit.ServiceKeyFromCtx. package middleware import ( + "context" "encoding/json" + "fmt" "net/http" "git.juancwu.dev/juancwu/authkit" ) -// Options configures auth middleware. Auth is required; the rest fall back -// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403 -// on authz failure. -type Options struct { - Auth *authkit.Auth - Extractor authkit.Extractor +// LoginOptions configures RequireLogin. +type LoginOptions struct { + // Auth is required. + Auth *authkit.Auth + + // SessionExtractor reads the session plaintext from the request. + // Defaults to a cookie extractor using Auth.SessionCookieName(). + SessionExtractor authkit.Extractor + + // JWTExtractor reads the JWT access token from the request. Defaults + // to BearerExtractor. + JWTExtractor authkit.Extractor + + // Authz, if non-nil, gates the request on a predicate over the + // resolved *Principal. Validate is called once at construction; an + // invalid predicate (unknown slug) panics. + Authz authkit.LoginAuthz + + // OnUnauth handles "no credential / bad credential" failures (HTTP + // 401). Default: JSON {"error":"Unauthorized"}. + OnUnauth func(w http.ResponseWriter, r *http.Request, err error) + + // OnForbidden handles "credential ok but Authz failed" (HTTP 403). + // Default: JSON {"error":"Forbidden"}. + OnForbidden func(w http.ResponseWriter, r *http.Request, err error) +} + +// RequireLogin returns middleware that authenticates the request via either +// a session cookie or a JWT (in that order) and, if Authz is set, gates the +// resolved Principal against the predicate. +// +// Panics at construction time if Auth is nil or Authz references unknown +// slugs. +func RequireLogin(opts LoginOptions) func(http.Handler) http.Handler { + if opts.Auth == nil { + panic("authkit/middleware: LoginOptions.Auth is required") + } + if opts.Authz != nil { + if err := opts.Authz.Validate(context.Background(), opts.Auth); err != nil { + panic(fmt.Sprintf("authkit/middleware: %v", err)) + } + } + sessionEx := opts.SessionExtractor + if sessionEx == nil { + sessionEx = authkit.CookieExtractor(opts.Auth.SessionCookieName()) + } + jwtEx := opts.JWTExtractor + if jwtEx == nil { + jwtEx = authkit.BearerExtractor() + } + onUnauth := opts.OnUnauth + if onUnauth == nil { + onUnauth = defaultJSONError(http.StatusUnauthorized) + } + onForbidden := opts.OnForbidden + if onForbidden == nil { + onForbidden = defaultJSONError(http.StatusForbidden) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p, err := authenticatePrincipal(r, opts.Auth, sessionEx, jwtEx) + if err != nil { + onUnauth(w, r, err) + return + } + if opts.Authz != nil && !opts.Authz.Match(p) { + onForbidden(w, r, authkit.ErrPermissionDenied) + return + } + ctx := authkit.WithUserContext(r.Context(), opts.Auth, p.UserID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// authenticatePrincipal tries the session extractor first, then the JWT +// extractor. Returns the first successful Principal or the last error. +func authenticatePrincipal(r *http.Request, a *authkit.Auth, sessionEx, jwtEx authkit.Extractor) (*authkit.Principal, error) { + if v, ok := sessionEx(r); ok && v != "" { + if p, err := a.AuthenticateSession(r.Context(), v); err == nil { + return p, nil + } + } + if v, ok := jwtEx(r); ok && v != "" { + if p, err := a.AuthenticateJWT(r.Context(), v); err == nil { + return p, nil + } + } + return nil, authkit.ErrSessionInvalid +} + +// GuestOptions configures RequireGuest. +type GuestOptions struct { + // Auth is required. + Auth *authkit.Auth + + SessionExtractor authkit.Extractor + JWTExtractor authkit.Extractor + + // OnAuthenticated handles requests that present a valid credential + // (where a guest was expected). Default: JSON 403. + OnAuthenticated func(w http.ResponseWriter, r *http.Request) +} + +// RequireGuest returns middleware that rejects requests carrying a valid +// session or JWT. Useful for /login or /register pages where authenticated +// users should be redirected away. +// +// Default rejection is HTTP 403 JSON. Pass Options.OnAuthenticated to +// implement a redirect or custom response. +func RequireGuest(opts GuestOptions) func(http.Handler) http.Handler { + if opts.Auth == nil { + panic("authkit/middleware: GuestOptions.Auth is required") + } + sessionEx := opts.SessionExtractor + if sessionEx == nil { + sessionEx = authkit.CookieExtractor(opts.Auth.SessionCookieName()) + } + jwtEx := opts.JWTExtractor + if jwtEx == nil { + jwtEx = authkit.BearerExtractor() + } + onAuthenticated := opts.OnAuthenticated + if onAuthenticated == nil { + onAuthenticated = func(w http.ResponseWriter, r *http.Request) { + defaultJSONError(http.StatusForbidden)(w, r, authkit.ErrPermissionDenied) + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := authenticatePrincipal(r, opts.Auth, sessionEx, jwtEx); err == nil { + onAuthenticated(w, r) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// ServiceKeyOptions configures RequireServiceKey. +type ServiceKeyOptions struct { + Auth *authkit.Auth + + // Extractor reads the service token plaintext. Defaults to + // BearerExtractor. + Extractor authkit.Extractor + + // Authz, if non-nil, gates the request on a predicate over the + // resolved *ServiceKey. + Authz authkit.ServiceKeyAuthz + OnUnauth func(w http.ResponseWriter, r *http.Request, err error) OnForbidden func(w http.ResponseWriter, r *http.Request, err error) } -func (o Options) extractor() authkit.Extractor { - if o.Extractor != nil { - return o.Extractor +// RequireServiceKey returns middleware that authenticates the request via a +// service token and, if Authz is set, gates on the predicate. +// +// Panics at construction time if Auth is nil or Authz references unknown +// ability slugs. +func RequireServiceKey(opts ServiceKeyOptions) func(http.Handler) http.Handler { + if opts.Auth == nil { + panic("authkit/middleware: ServiceKeyOptions.Auth is required") + } + if opts.Authz != nil { + if err := opts.Authz.Validate(context.Background(), opts.Auth); err != nil { + panic(fmt.Sprintf("authkit/middleware: %v", err)) + } + } + ex := opts.Extractor + if ex == nil { + ex = authkit.BearerExtractor() + } + onUnauth := opts.OnUnauth + if onUnauth == nil { + onUnauth = defaultJSONError(http.StatusUnauthorized) + } + onForbidden := opts.OnForbidden + if onForbidden == nil { + onForbidden = defaultJSONError(http.StatusForbidden) } - return authkit.BearerExtractor() -} -func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) { - if o.OnUnauth != nil { - return o.OnUnauth + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, ok := ex(r) + if !ok || raw == "" { + onUnauth(w, r, authkit.ErrServiceKeyInvalid) + return + } + k, err := opts.Auth.AuthenticateServiceKey(r.Context(), raw) + if err != nil { + onUnauth(w, r, err) + return + } + if opts.Authz != nil && !opts.Authz.Match(k) { + onForbidden(w, r, authkit.ErrPermissionDenied) + return + } + ctx := authkit.WithServiceKey(r.Context(), k) + next.ServeHTTP(w, r.WithContext(ctx)) + }) } - return defaultJSONError(http.StatusUnauthorized) -} - -func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) { - if o.OnForbidden != nil { - return o.OnForbidden - } - return defaultJSONError(http.StatusForbidden) } func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) { - return func(w http.ResponseWriter, _ *http.Request, err error) { + return func(w http.ResponseWriter, _ *http.Request, _ error) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(map[string]string{ @@ -47,169 +238,3 @@ func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, e }) } } - -// RequireSession authenticates the request via an opaque session string. The -// extractor is consulted first; if no extractor is set the default Bearer -// extractor is used. For cookie-based session lookup, set -// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName). -func RequireSession(opts Options) func(http.Handler) http.Handler { - return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { - return opts.Auth.AuthenticateSession(r.Context(), raw) - }) -} - -// RequireJWT authenticates the request via an HS256 JWT. -func RequireJWT(opts Options) func(http.Handler) http.Handler { - return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { - return opts.Auth.AuthenticateJWT(r.Context(), raw) - }) -} - -// RequireServiceKey authenticates the request via an opaque service token -// secret. On success the resolved *authkit.ServiceKey is placed on the -// request context; downstream handlers retrieve it via ServiceKeyFrom. Note -// that this middleware does NOT place a *Principal on the context — service -// tokens have no user — so user-bound authz middleware (RequireRole, -// RequirePermission) will reject service-key requests with 403. -func RequireServiceKey(opts Options) func(http.Handler) http.Handler { - return requireWithServiceKey(opts, func(r *http.Request, raw string) (*authkit.ServiceKey, error) { - return opts.Auth.AuthenticateServiceKey(r.Context(), raw) - }) -} - -// RequireAny tries each user-bound method in order until one succeeds. The -// default set is [Session, JWT]; service tokens are NOT included because -// they yield a different subject type. For routes that accept either a user -// credential or a service token, use RequireAnyOrServiceKey. -func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler { - if len(methods) == 0 { - methods = []authkit.AuthMethod{ - authkit.AuthMethodSession, - authkit.AuthMethodJWT, - } - } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - raw, ok := opts.extractor()(r) - if !ok || raw == "" { - opts.onUnauth()(w, r, authkit.ErrSessionInvalid) - return - } - var ( - p *authkit.Principal - lastErr error - ) - for _, m := range methods { - switch m { - case authkit.AuthMethodSession: - p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw) - case authkit.AuthMethodJWT: - p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw) - } - if lastErr == nil && p != nil { - next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p))) - return - } - } - opts.onUnauth()(w, r, lastErr) - }) - } -} - -// RequireAnyOrServiceKey tries the user-bound methods first (default -// [Session, JWT]); on failure, falls through to a service-key lookup. The -// downstream handler sees either a *Principal or a *ServiceKey on context — -// retrieve via PrincipalFrom or ServiceKeyFrom and dispatch accordingly. -func RequireAnyOrServiceKey(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler { - if opts.Auth == nil { - panic("authkit/middleware: Options.Auth is required") - } - if len(methods) == 0 { - methods = []authkit.AuthMethod{ - authkit.AuthMethodSession, - authkit.AuthMethodJWT, - } - } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - raw, ok := opts.extractor()(r) - if !ok || raw == "" { - opts.onUnauth()(w, r, authkit.ErrSessionInvalid) - return - } - var lastErr error - for _, m := range methods { - var p *authkit.Principal - switch m { - case authkit.AuthMethodSession: - p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw) - case authkit.AuthMethodJWT: - p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw) - } - if lastErr == nil && p != nil { - next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p))) - return - } - } - k, err := opts.Auth.AuthenticateServiceKey(r.Context(), raw) - if err == nil && k != nil { - next.ServeHTTP(w, r.WithContext(withServiceKey(r.Context(), k))) - return - } - if lastErr == nil { - lastErr = err - } - opts.onUnauth()(w, r, lastErr) - }) - } -} - -// requireWith is the shared scaffolding for the single-method user-bound -// Require* middlewares. -func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler { - if opts.Auth == nil { - panic("authkit/middleware: Options.Auth is required") - } - extractor := opts.extractor() - onUnauth := opts.onUnauth() - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - raw, ok := extractor(r) - if !ok || raw == "" { - onUnauth(w, r, authkit.ErrSessionInvalid) - return - } - p, err := authn(r, raw) - if err != nil { - onUnauth(w, r, err) - return - } - next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p))) - }) - } -} - -// requireWithServiceKey is the service-key analogue of requireWith. It places -// a *ServiceKey (not a *Principal) on the request context. -func requireWithServiceKey(opts Options, authn func(r *http.Request, raw string) (*authkit.ServiceKey, error)) func(http.Handler) http.Handler { - if opts.Auth == nil { - panic("authkit/middleware: Options.Auth is required") - } - extractor := opts.extractor() - onUnauth := opts.onUnauth() - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - raw, ok := extractor(r) - if !ok || raw == "" { - onUnauth(w, r, authkit.ErrServiceKeyInvalid) - return - } - k, err := authn(r, raw) - if err != nil { - onUnauth(w, r, err) - return - } - next.ServeHTTP(w, r.WithContext(withServiceKey(r.Context(), k))) - }) - } -} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index f4941ee..8d45490 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,319 +1,92 @@ package middleware_test +// Integration tests for the middleware package. Skipped when +// AUTHKIT_TEST_DATABASE_URL is not set. + import ( "context" - "errors" + "database/sql" + "fmt" "net/http" "net/http/httptest" "net/netip" - "strings" - "sync" + "os" "testing" "time" "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/authkit/hasher" "git.juancwu.dev/juancwu/authkit/middleware" - "github.com/google/uuid" + + _ "github.com/jackc/pgx/v5/stdlib" ) -// ─── minimal in-memory stores ────────────────────────────────────────────── -// -// The middleware package can't import the parent's _test stores, so we wire -// up a fresh-but-minimal set here. Only the methods actually exercised by -// the middleware tests below have meaningful bodies; unused store methods -// panic to surface unexpected call paths. - -type memUserStore struct { - mu sync.Mutex - m map[uuid.UUID]*authkit.User -} - -func newMemUserStore() *memUserStore { return &memUserStore{m: map[uuid.UUID]*authkit.User{}} } - -func (s *memUserStore) CreateUser(_ context.Context, u *authkit.User) error { - s.mu.Lock() - defer s.mu.Unlock() - for _, existing := range s.m { - if existing.EmailNormalized == u.EmailNormalized { - return authkit.ErrEmailTaken - } - } - cp := *u - s.m[u.ID] = &cp - return nil -} -func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*authkit.User, error) { - s.mu.Lock() - defer s.mu.Unlock() - u, ok := s.m[id] - if !ok { - return nil, authkit.ErrUserNotFound - } - cp := *u - return &cp, nil -} -func (s *memUserStore) GetUserByEmail(_ context.Context, normalized string) (*authkit.User, error) { - s.mu.Lock() - defer s.mu.Unlock() - for _, u := range s.m { - if u.EmailNormalized == normalized { - cp := *u - return &cp, nil - } - } - return nil, authkit.ErrUserNotFound -} -func (s *memUserStore) UpdateUser(_ context.Context, u *authkit.User) error { - s.mu.Lock() - defer s.mu.Unlock() - cp := *u - s.m[u.ID] = &cp - return nil -} -func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.m, id) - return nil -} -func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, encoded string) error { - s.mu.Lock() - defer s.mu.Unlock() - if u, ok := s.m[id]; ok { - u.PasswordHash = encoded - } - return nil -} -func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - if u, ok := s.m[id]; ok { - u.EmailVerifiedAt = &at - } - return nil -} -func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - if u, ok := s.m[id]; ok { - u.SessionVersion++ - return u.SessionVersion, nil - } - return 0, authkit.ErrUserNotFound -} -func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - if u, ok := s.m[id]; ok { - u.FailedLogins++ - return u.FailedLogins, nil - } - return 0, authkit.ErrUserNotFound -} -func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error { - s.mu.Lock() - defer s.mu.Unlock() - if u, ok := s.m[id]; ok { - u.FailedLogins = 0 - } - return nil -} - -type memSessionStore struct { - mu sync.Mutex - m map[string]*authkit.Session -} - -func newMemSessionStore() *memSessionStore { - return &memSessionStore{m: map[string]*authkit.Session{}} -} -func (s *memSessionStore) CreateSession(_ context.Context, sess *authkit.Session) error { - s.mu.Lock() - defer s.mu.Unlock() - cp := *sess - s.m[string(sess.IDHash)] = &cp - return nil -} -func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*authkit.Session, error) { - s.mu.Lock() - defer s.mu.Unlock() - sess, ok := s.m[string(h)] - if !ok { - return nil, authkit.ErrSessionInvalid - } - cp := *sess - return &cp, nil -} -func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, newExp time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - if sess, ok := s.m[string(h)]; ok { - sess.LastSeenAt = lastSeen - sess.ExpiresAt = newExp - } - return nil -} -func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.m, string(h)) - return nil -} -func (s *memSessionStore) DeleteUserSessions(_ context.Context, _ uuid.UUID) error { return nil } -func (s *memSessionStore) DeleteExpired(_ context.Context, _ time.Time) (int64, error) { - return 0, nil -} - -type memTokenStore struct{} - -func (memTokenStore) CreateToken(_ context.Context, _ *authkit.Token) error { return nil } -func (memTokenStore) ConsumeToken(_ context.Context, _ authkit.TokenKind, _ []byte, _ time.Time) (*authkit.Token, error) { - return nil, authkit.ErrTokenInvalid -} -func (memTokenStore) GetToken(_ context.Context, _ authkit.TokenKind, _ []byte) (*authkit.Token, error) { - return nil, authkit.ErrTokenInvalid -} -func (memTokenStore) DeleteByChain(_ context.Context, _ string) (int64, error) { return 0, nil } -func (memTokenStore) DeleteExpired(_ context.Context, _ time.Time) (int64, error) { - return 0, nil -} - -type memServiceKeyStore struct { - mu sync.Mutex - m map[string]*authkit.ServiceKey -} - -func newMemServiceKeyStore() *memServiceKeyStore { - return &memServiceKeyStore{m: map[string]*authkit.ServiceKey{}} -} -func (s *memServiceKeyStore) CreateServiceKey(_ context.Context, k *authkit.ServiceKey) error { - s.mu.Lock() - defer s.mu.Unlock() - cp := *k - cp.Abilities = append([]string(nil), k.Abilities...) - s.m[string(k.IDHash)] = &cp - return nil -} -func (s *memServiceKeyStore) GetServiceKey(_ context.Context, h []byte) (*authkit.ServiceKey, error) { - s.mu.Lock() - defer s.mu.Unlock() - k, ok := s.m[string(h)] - if !ok { - return nil, authkit.ErrServiceKeyInvalid - } - cp := *k - cp.Abilities = append([]string(nil), k.Abilities...) - return &cp, nil -} -func (s *memServiceKeyStore) ListServiceKeysByOwner(_ context.Context, kind string, owner uuid.UUID) ([]*authkit.ServiceKey, error) { - s.mu.Lock() - defer s.mu.Unlock() - var out []*authkit.ServiceKey - for _, k := range s.m { - if k.OwnerKind == kind && k.OwnerID == owner { - cp := *k - cp.Abilities = append([]string(nil), k.Abilities...) - out = append(out, &cp) - } - } - return out, nil -} -func (s *memServiceKeyStore) TouchServiceKey(_ context.Context, h []byte, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - if k, ok := s.m[string(h)]; ok { - k.LastUsedAt = &at - } - return nil -} -func (s *memServiceKeyStore) RevokeServiceKey(_ context.Context, h []byte, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - k, ok := s.m[string(h)] - if !ok { - return authkit.ErrServiceKeyInvalid - } - if k.RevokedAt != nil { - return authkit.ErrServiceKeyInvalid - } - k.RevokedAt = &at - return nil -} - -type memRoleStore struct{} - -func (memRoleStore) CreateRole(_ context.Context, _ *authkit.Role) error { return nil } -func (memRoleStore) GetRoleByID(_ context.Context, _ uuid.UUID) (*authkit.Role, error) { - return nil, authkit.ErrRoleNotFound -} -func (memRoleStore) GetRoleByName(_ context.Context, _ string) (*authkit.Role, error) { - return nil, authkit.ErrRoleNotFound -} -func (memRoleStore) ListRoles(_ context.Context) ([]*authkit.Role, error) { return nil, nil } -func (memRoleStore) DeleteRole(_ context.Context, _ uuid.UUID) error { return nil } -func (memRoleStore) AssignRoleToUser(_ context.Context, _, _ uuid.UUID) error { return nil } -func (memRoleStore) RemoveRoleFromUser(_ context.Context, _, _ uuid.UUID) error { return nil } -func (memRoleStore) GetUserRoles(_ context.Context, _ uuid.UUID) ([]*authkit.Role, error) { - return nil, nil -} -func (memRoleStore) HasAnyRole(_ context.Context, _ uuid.UUID, _ []string) (bool, error) { - return false, nil -} - -type memPermStore struct{} - -func (memPermStore) CreatePermission(_ context.Context, _ *authkit.Permission) error { return nil } -func (memPermStore) GetPermissionByID(_ context.Context, _ uuid.UUID) (*authkit.Permission, error) { - return nil, authkit.ErrPermissionNotFound -} -func (memPermStore) GetPermissionByName(_ context.Context, _ string) (*authkit.Permission, error) { - return nil, authkit.ErrPermissionNotFound -} -func (memPermStore) ListPermissions(_ context.Context) ([]*authkit.Permission, error) { - return nil, nil -} -func (memPermStore) DeletePermission(_ context.Context, _ uuid.UUID) error { return nil } -func (memPermStore) AssignPermissionToRole(_ context.Context, _, _ uuid.UUID) error { return nil } -func (memPermStore) RemovePermissionFromRole(_ context.Context, _, _ uuid.UUID) error { return nil } -func (memPermStore) GetRolePermissions(_ context.Context, _ uuid.UUID) ([]*authkit.Permission, error) { - return nil, nil -} -func (memPermStore) GetUserPermissions(_ context.Context, _ uuid.UUID) ([]*authkit.Permission, error) { - return nil, nil -} - -type stubHasher struct{} - -func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil } -func (stubHasher) Verify(p, encoded string) (bool, bool, error) { - return encoded == "stub:"+p, false, nil -} - -func newTestAuth(t *testing.T) *authkit.Auth { +func freshAuth(t *testing.T) *authkit.Auth { t.Helper() - return authkit.New(authkit.Deps{ - Users: newMemUserStore(), - Sessions: newMemSessionStore(), - Tokens: memTokenStore{}, - ServiceKeys: newMemServiceKeyStore(), - Roles: memRoleStore{}, - Permissions: memPermStore{}, - Hasher: stubHasher{}, + url := os.Getenv("AUTHKIT_TEST_DATABASE_URL") + if url == "" { + t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test") + } + db, err := sql.Open("pgx", url) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("ping: %v", err) + } + dropAuthkitTables(t, db) + t.Cleanup(func() { dropAuthkitTables(t, db) }) + + a, err := authkit.New(context.Background(), authkit.Deps{ + DB: db, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), }, authkit.Config{ - JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"), - JWTIssuer: "mw-test", - AccessTokenTTL: 2 * time.Minute, - RefreshTokenTTL: 1 * time.Hour, - SessionIdleTTL: time.Hour, - SessionAbsoluteTTL: 24 * time.Hour, - EmailVerifyTTL: time.Hour, - PasswordResetTTL: time.Hour, - MagicLinkTTL: time.Minute, + JWTSecret: []byte("integration-secret-thirty-two!!!"), + JWTIssuer: "authkit-mw-int", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + EmailOTPTTL: time.Minute, + EmailOTPMaxAttempts: 3, + // Plain HTTP for tests so secure-cookie defaults don't interfere + // with httptest's HTTP server. + SessionCookieSecure: authkit.BoolPtr(false), }) + if err != nil { + t.Fatalf("authkit.New: %v", err) + } + return a } -// Bearer-style request helper. -func req(token string) *http.Request { +func dropAuthkitTables(t *testing.T, db *sql.DB) { + t.Helper() + tables := []string{ + "authkit_service_key_abilities", + "authkit_user_permissions", + "authkit_user_roles", + "authkit_role_permissions", + "authkit_service_keys", + "authkit_abilities", + "authkit_roles", + "authkit_permissions", + "authkit_tokens", + "authkit_sessions", + "authkit_users", + "authkit_schema_migrations", + } + ctx := context.Background() + for _, name := range tables { + _, _ = db.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name)) + } +} + +// reqWithBearer issues a request carrying Authorization: Bearer . +func reqWithBearer(token string) *http.Request { r := httptest.NewRequest(http.MethodGet, "/", nil) if token != "" { r.Header.Set("Authorization", "Bearer "+token) @@ -323,191 +96,263 @@ func req(token string) *http.Request { func ok200(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) } -// ─── tests ───────────────────────────────────────────────────────────────── +// ─── RequireLogin ────────────────────────────────────────────────────────── -func TestRequireServiceKey_Authenticates(t *testing.T) { - a := newTestAuth(t) - plain, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "ci", []string{"events:write"}, nil) +func TestRequireLogin_AcceptsSessionCookie(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "alice@example.com") if err != nil { - t.Fatalf("IssueServiceKey: %v", err) + t.Fatalf("CreateUser: %v", err) + } + plain, _, err := a.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1")) + if err != nil { + t.Fatalf("IssueSession: %v", err) } - var seen *authkit.ServiceKey - handler := middleware.RequireServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - k, ok := middleware.ServiceKeyFrom(r.Context()) - if !ok { - t.Fatalf("no ServiceKey on context") + handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uid, ok := authkit.UserIDFromCtx(r.Context()) + if !ok || uid != u.ID { + t.Fatalf("user_id missing or wrong on context: ok=%v id=%v", ok, uid) } - seen = k w.WriteHeader(http.StatusOK) })) + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(a.SessionCookie(plain, time.Now().Add(time.Hour))) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req(plain)) + handler.ServeHTTP(rr, r) if rr.Code != http.StatusOK { - t.Fatalf("status = %d, want 200", rr.Code) - } - if seen == nil || !seen.HasAbility("events:write") { - t.Fatalf("expected ServiceKey with events:write ability; got %+v", seen) + t.Fatalf("expected 200, got %d", rr.Code) } } -func TestRequireServiceKey_RejectsRevoked(t *testing.T) { - a := newTestAuth(t) - plain, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "ci", nil, nil) +func TestRequireLogin_AcceptsJWT(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "j@j.com") if err != nil { - t.Fatalf("IssueServiceKey: %v", err) + t.Fatalf("CreateUser: %v", err) } - if err := a.RevokeServiceKey(context.Background(), plain); err != nil { - t.Fatalf("RevokeServiceKey: %v", err) + access, _, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) } + handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(ok200)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, reqWithBearer(access)) + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } +} + +func TestRequireLogin_RejectsUnauthenticated(t *testing.T) { + a := freshAuth(t) + handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(ok200)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rr.Code) + } +} + +func TestRequireLogin_AuthzRoleGate(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateRole(ctx, "admin", ""); err != nil { + t.Fatalf("CreateRole: %v", err) + } + u, err := a.CreateUser(ctx, "noadmin@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + access, _, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + handler := middleware.RequireLogin(middleware.LoginOptions{ + Auth: a, + Authz: authkit.HasRole("admin"), + })(http.HandlerFunc(ok200)) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, reqWithBearer(access)) + if rr.Code != http.StatusForbidden { + t.Fatalf("non-admin should get 403, got %d", rr.Code) + } + + // Promote the user to admin and retry. + if err := a.AssignRole(ctx, u.ID, "admin"); err != nil { + t.Fatalf("AssignRole: %v", err) + } + access2, _, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, reqWithBearer(access2)) + if rr.Code != http.StatusOK { + t.Fatalf("admin should get 200, got %d", rr.Code) + } +} + +func TestRequireLogin_PanicsOnUnknownSlug(t *testing.T) { + a := freshAuth(t) + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic on unknown role slug") + } + }() + middleware.RequireLogin(middleware.LoginOptions{ + Auth: a, + Authz: authkit.HasRole("never-registered"), + }) +} + +// ─── RequireGuest ────────────────────────────────────────────────────────── + +func TestRequireGuest_LetsUnauthenticatedThrough(t *testing.T) { + a := freshAuth(t) called := false - handler := middleware.RequireServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := middleware.RequireGuest(middleware.GuestOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true w.WriteHeader(http.StatusOK) })) rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req(plain)) - if rr.Code != http.StatusUnauthorized { - t.Fatalf("status = %d, want 401", rr.Code) + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + if !called { + t.Fatalf("guest middleware should pass through unauthenticated request") } - if called { - t.Fatalf("handler should not have been invoked for revoked key") - } -} - -func TestRequireAbility_AcceptsServiceKeyWithAbility(t *testing.T) { - a := newTestAuth(t) - plain, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "ci", []string{"events:write"}, nil) - if err != nil { - t.Fatalf("IssueServiceKey: %v", err) - } - chain := middleware.RequireServiceKey(middleware.Options{Auth: a})( - middleware.RequireAbility("events:write")(http.HandlerFunc(ok200))) - - rr := httptest.NewRecorder() - chain.ServeHTTP(rr, req(plain)) if rr.Code != http.StatusOK { - t.Fatalf("status = %d, want 200", rr.Code) - } - - // Same chain but ability the key does not carry → 403. - chainBad := middleware.RequireServiceKey(middleware.Options{Auth: a})( - middleware.RequireAbility("admin:nuke")(http.HandlerFunc(ok200))) - - rr = httptest.NewRecorder() - chainBad.ServeHTTP(rr, req(plain)) - if rr.Code != http.StatusForbidden { - t.Fatalf("missing-ability status = %d, want 403", rr.Code) + t.Fatalf("expected 200, got %d", rr.Code) } } -func TestRequireAbility_RejectsUserPrincipal(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "alice@example.com", "hunter2hunter2") +func TestRequireGuest_BlocksAuthenticated(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "g@g.com") if err != nil { - t.Fatalf("Register: %v", err) + t.Fatalf("CreateUser: %v", err) } - plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.MustParseAddr("127.0.0.1")) + access, _, err := a.IssueJWT(ctx, u.ID) if err != nil { - t.Fatalf("IssueSession: %v", err) + t.Fatalf("IssueJWT: %v", err) } - chain := middleware.RequireSession(middleware.Options{Auth: a})( - middleware.RequireAbility("events:write")(http.HandlerFunc(ok200))) - - rr := httptest.NewRecorder() - chain.ServeHTTP(rr, req(plain)) - if rr.Code != http.StatusForbidden { - t.Fatalf("status = %d, want 403 (user principal carries no abilities)", rr.Code) - } -} - -func TestRequireRole_RejectsServiceKey(t *testing.T) { - a := newTestAuth(t) - plain, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "ci", nil, nil) - if err != nil { - t.Fatalf("IssueServiceKey: %v", err) - } - chain := middleware.RequireServiceKey(middleware.Options{Auth: a})( - middleware.RequireRole("admin")(http.HandlerFunc(ok200))) - - rr := httptest.NewRecorder() - chain.ServeHTTP(rr, req(plain)) - if rr.Code != http.StatusForbidden { - t.Fatalf("status = %d, want 403 (service key carries no Principal/role)", rr.Code) - } -} - -func TestRequireAnyOrServiceKey(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "alice@example.com", "hunter2hunter2") - if err != nil { - t.Fatalf("Register: %v", err) - } - sessionPlain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.MustParseAddr("127.0.0.1")) - if err != nil { - t.Fatalf("IssueSession: %v", err) - } - servicePlain, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "ci", nil, nil) - if err != nil { - t.Fatalf("IssueServiceKey: %v", err) - } - - type subject struct { - hasPrincipal bool - hasServiceKey bool - } - var got subject - handler := middleware.RequireAnyOrServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, hp := middleware.PrincipalFrom(r.Context()) - _, hs := middleware.ServiceKeyFrom(r.Context()) - got = subject{hp, hs} - w.WriteHeader(http.StatusOK) + handlerCalled := false + handler := middleware.RequireGuest(middleware.GuestOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true })) - // Session token → Principal in context, no ServiceKey. rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req(sessionPlain)) - if rr.Code != http.StatusOK { - t.Fatalf("session: status = %d, want 200", rr.Code) + handler.ServeHTTP(rr, reqWithBearer(access)) + if rr.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rr.Code) } - if !got.hasPrincipal || got.hasServiceKey { - t.Fatalf("session: ctx subject = %+v, want principal-only", got) - } - - // Service token → ServiceKey in context, no Principal. - rr = httptest.NewRecorder() - got = subject{} - handler.ServeHTTP(rr, req(servicePlain)) - if rr.Code != http.StatusOK { - t.Fatalf("service: status = %d, want 200", rr.Code) - } - if got.hasPrincipal || !got.hasServiceKey { - t.Fatalf("service: ctx subject = %+v, want servicekey-only", got) - } - - // Garbage token → 401, neither subject set. - rr = httptest.NewRecorder() - got = subject{} - handler.ServeHTTP(rr, req(strings.Repeat("x", 50))) - if rr.Code != http.StatusUnauthorized { - t.Fatalf("garbage: status = %d, want 401", rr.Code) + if handlerCalled { + t.Fatalf("handler should not run for authenticated request") } } -// Sanity check: the constructed *authkit.Auth should satisfy errors.Is on the -// canonical sentinels — ensures our minimal stores are wired correctly. -func TestSentinelsReachable(t *testing.T) { - a := newTestAuth(t) - _, err := a.AuthenticateServiceKey(context.Background(), "sk_not-real") - if !errors.Is(err, authkit.ErrServiceKeyInvalid) { - t.Fatalf("expected ErrServiceKeyInvalid, got %v", err) +func TestRequireGuest_CustomOnAuthenticated(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "custom@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + access, _, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + handler := middleware.RequireGuest(middleware.GuestOptions{ + Auth: a, + OnAuthenticated: func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) + }, + })(http.HandlerFunc(ok200)) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, reqWithBearer(access)) + if rr.Code != http.StatusSeeOther { + t.Fatalf("expected 303, got %d", rr.Code) + } + if got := rr.Header().Get("Location"); got != "/dashboard" { + t.Fatalf("expected Location=/dashboard, got %q", got) } } + +// ─── RequireServiceKey ───────────────────────────────────────────────────── + +func TestRequireServiceKey_AbilityGate(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateAbility(ctx, "events:write", ""); err != nil { + t.Fatalf("CreateAbility: %v", err) + } + plain, _, err := a.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{ + Name: "ci", + Abilities: []string{"events:write"}, + }) + if err != nil { + t.Fatalf("IssueServiceKey: %v", err) + } + + handler := middleware.RequireServiceKey(middleware.ServiceKeyOptions{ + Auth: a, + Authz: authkit.HasAbility("events:write"), + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + k, ok := authkit.ServiceKeyFromCtx(r.Context()) + if !ok || !k.HasAbility("events:write") { + t.Fatalf("expected ServiceKey with events:write on context") + } + w.WriteHeader(http.StatusOK) + })) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, reqWithBearer(plain)) + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } +} + +func TestRequireServiceKey_AbilityGateRejectsMissing(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateAbility(ctx, "events:write", ""); err != nil { + t.Fatalf("CreateAbility events:write: %v", err) + } + if _, err := a.CreateAbility(ctx, "admin:nuke", ""); err != nil { + t.Fatalf("CreateAbility admin:nuke: %v", err) + } + plain, _, err := a.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{ + Name: "ci", + Abilities: []string{"events:write"}, + }) + if err != nil { + t.Fatalf("IssueServiceKey: %v", err) + } + handler := middleware.RequireServiceKey(middleware.ServiceKeyOptions{ + Auth: a, + Authz: authkit.HasAbility("admin:nuke"), + })(http.HandlerFunc(ok200)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, reqWithBearer(plain)) + if rr.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", rr.Code) + } +} + +func TestRequireServiceKey_PanicsOnUnknownAbility(t *testing.T) { + a := freshAuth(t) + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic on unknown ability slug") + } + }() + middleware.RequireServiceKey(middleware.ServiceKeyOptions{ + Auth: a, + Authz: authkit.HasAbility("never-registered"), + }) +} diff --git a/migrations/0001_init.sql b/migrations/0001_init.sql new file mode 100644 index 0000000..0f9ae3e --- /dev/null +++ b/migrations/0001_init.sql @@ -0,0 +1,136 @@ +-- 0001_init.sql +-- Initial authkit schema for PostgreSQL 16+. All tables prefixed authkit_ so +-- the library can be embedded in an existing application database. Each +-- migration owns its transaction and inserts its version row at the bottom; +-- the runner only orchestrates file discovery and concurrency. + +BEGIN; + +CREATE TABLE IF NOT EXISTS authkit_schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL +); + +-- Users. Password is nullable so accounts can be created without a credential +-- and have one set later (invite flows, magic-link-only accounts, etc.). +CREATE TABLE IF NOT EXISTS authkit_users ( + id UUID PRIMARY KEY, + email TEXT NOT NULL, + email_normalized TEXT NOT NULL, + email_verified_at TIMESTAMPTZ, + password_hash TEXT, + session_version INTEGER NOT NULL DEFAULT 0, + last_login_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq + ON authkit_users (email_normalized); + +-- Opaque server-side sessions. +CREATE TABLE IF NOT EXISTS authkit_sessions ( + id_hash BYTEA PRIMARY KEY, + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + user_agent TEXT NOT NULL DEFAULT '', + ip TEXT, + created_at TIMESTAMPTZ NOT NULL, + last_seen_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ NOT NULL +); +CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id); +CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at); + +-- Single-use tokens (refresh, email-verify, password-reset, magic-link, email-otp). +-- attempts_remaining is non-null only for tokens that allow retries (email_otp); +-- ConsumeToken decrements and zeroes-out on exhaustion. +CREATE TABLE IF NOT EXISTS authkit_tokens ( + hash BYTEA NOT NULL, + kind TEXT NOT NULL, + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + chain_id TEXT, + consumed_at TIMESTAMPTZ, + attempts_remaining INTEGER, + created_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (kind, hash) +); +CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id); +CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at); +CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx + ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL; + +-- Service tokens. No owner column: these are machine credentials, intended to +-- be created by applications for outbound API calls or inbound automation. +-- Consumers tag them with whatever metadata they need via Name. +CREATE TABLE IF NOT EXISTS authkit_service_keys ( + id_hash BYTEA PRIMARY KEY, + name TEXT NOT NULL, + last_used_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ, + revoked_at TIMESTAMPTZ +); + +-- Roles, permissions, and abilities are seeded by the consumer (typically via +-- the cmd/roles, cmd/perms, cmd/abilities CLIs). They share the same shape: +-- normalised slug as the unique business key, optional human label. +CREATE TABLE IF NOT EXISTS authkit_roles ( + id UUID PRIMARY KEY, + slug TEXT NOT NULL UNIQUE, + label TEXT, + created_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_permissions ( + id UUID PRIMARY KEY, + slug TEXT NOT NULL UNIQUE, + label TEXT, + created_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_abilities ( + id UUID PRIMARY KEY, + slug TEXT NOT NULL UNIQUE, + label TEXT, + created_at TIMESTAMPTZ NOT NULL +); + +-- Role ↔ Permission (defines what permissions a role grants). +CREATE TABLE IF NOT EXISTS authkit_role_permissions ( + role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, + permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE, + PRIMARY KEY (role_id, permission_id) +); + +-- User ↔ Role (which roles a user holds). +CREATE TABLE IF NOT EXISTS authkit_user_roles ( + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, + granted_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (user_id, role_id) +); +CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id); + +-- User ↔ Permission (direct grants, in addition to permissions resolved +-- through roles). GetUserPermissions returns the UNION of both paths. +CREATE TABLE IF NOT EXISTS authkit_user_permissions ( + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE, + granted_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (user_id, permission_id) +); +CREATE INDEX IF NOT EXISTS authkit_user_permissions_perm_id_idx ON authkit_user_permissions(permission_id); + +-- ServiceKey ↔ Ability (which abilities a service key carries). +CREATE TABLE IF NOT EXISTS authkit_service_key_abilities ( + service_key_id_hash BYTEA NOT NULL REFERENCES authkit_service_keys(id_hash) ON DELETE CASCADE, + ability_id UUID NOT NULL REFERENCES authkit_abilities(id) ON DELETE CASCADE, + granted_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (service_key_id_hash, ability_id) +); +CREATE INDEX IF NOT EXISTS authkit_service_key_abilities_ability_idx ON authkit_service_key_abilities(ability_id); + +INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now()) +ON CONFLICT (version) DO NOTHING; + +COMMIT; diff --git a/models.go b/models.go index 2425ea7..108c4ed 100644 --- a/models.go +++ b/models.go @@ -7,6 +7,9 @@ import ( "github.com/google/uuid" ) +// User is the canonical account record. Password hash is empty (and stored +// NULL in the DB) when no credential has been set — accounts created via +// invite or magic-link-only flows live in this state until SetPassword runs. type User struct { ID uuid.UUID Email string @@ -14,12 +17,12 @@ type User struct { EmailVerifiedAt *time.Time PasswordHash string SessionVersion int - FailedLogins int LastLoginAt *time.Time CreatedAt time.Time UpdatedAt time.Time } +// Session is an opaque server-side credential bound to one user. type Session struct { IDHash []byte UserID uuid.UUID @@ -30,35 +33,36 @@ type Session struct { ExpiresAt time.Time } +// TokenKind enumerates the single-use credentials persisted in authkit_tokens. type TokenKind string const ( TokenEmailVerify TokenKind = "email_verify" TokenPasswordReset TokenKind = "password_reset" TokenMagicLink TokenKind = "magic_link" + TokenEmailOTP TokenKind = "email_otp" TokenRefresh TokenKind = "refresh" ) +// Token is one row in authkit_tokens. AttemptsRemaining is non-nil only for +// tokens that allow retry on incorrect input (email OTPs); other kinds are +// strictly one-shot via ConsumeToken. type Token struct { - Hash []byte - Kind TokenKind - UserID uuid.UUID - ChainID *string - ConsumedAt *time.Time - CreatedAt time.Time - ExpiresAt time.Time + Hash []byte + Kind TokenKind + UserID uuid.UUID + ChainID *string + ConsumedAt *time.Time + AttemptsRemaining *int + CreatedAt time.Time + ExpiresAt time.Time } -// ServiceKey is an owner-agnostic credential for server-to-server auth. -// OwnerID is not constrained to authkit_users — OwnerKind labels the owner -// namespace (e.g. "application", "tenant") and consumers manage their own -// cascade-on-delete. It is the only credential type that carries free-form -// abilities; user-bound credentials (sessions, JWTs) prove identity and -// resolve permissions through RBAC instead. +// ServiceKey is a machine credential. It carries no identity — service tokens +// are produced by applications for outbound API access or inbound automation, +// and authorize via Abilities resolved through the join table. type ServiceKey struct { IDHash []byte - OwnerID uuid.UUID - OwnerKind string Name string Abilities []string LastUsedAt *time.Time @@ -67,26 +71,41 @@ type ServiceKey struct { RevokedAt *time.Time } -// HasAbility reports whether the service key carries the named ability. -func (k *ServiceKey) HasAbility(name string) bool { +// HasAbility reports whether the service key carries the named ability slug. +func (k *ServiceKey) HasAbility(slug string) bool { for _, a := range k.Abilities { - if a == name { + if a == slug { return true } } return false } +// Role groups permissions for assignment to users. Slug is the immutable +// business key; Label is an optional human-readable name. type Role struct { - ID uuid.UUID - Name string - Description string - CreatedAt time.Time + ID uuid.UUID + Slug string + Label string + CreatedAt time.Time } +// Permission is a unit of authorization. Granted to users either through a +// role or directly via authkit_user_permissions. type Permission struct { - ID uuid.UUID - Name string - Description string - CreatedAt time.Time + ID uuid.UUID + Slug string + Label string + CreatedAt time.Time +} + +// Ability is a unit of authorization for service tokens. Abilities are a +// separate vocabulary from Permissions because they target machines, not +// users — keep them distinct so middleware predicates remain clear about +// which subject they're authorizing. +type Ability struct { + ID uuid.UUID + Slug string + Label string + CreatedAt time.Time } diff --git a/principal.go b/principal.go index b52a963..6b4b35a 100644 --- a/principal.go +++ b/principal.go @@ -6,6 +6,7 @@ import ( "github.com/google/uuid" ) +// AuthMethod tags how a Principal was authenticated. type AuthMethod string const ( @@ -13,10 +14,10 @@ const ( AuthMethodJWT AuthMethod = "jwt" ) -// Principal represents an authenticated user. It is produced only by -// user-bound auth methods (session, JWT) and carries identity plus -// RBAC-resolved roles/permissions. Service-token auth produces a -// *ServiceKey instead — those credentials carry abilities, not identity. +// Principal represents an authenticated user. Produced only by user-bound +// auth methods (session, JWT) and carries identity plus RBAC-resolved roles +// and permissions. Service-token auth produces a *ServiceKey instead — those +// credentials carry abilities, not identity. type Principal struct { UserID uuid.UUID Method AuthMethod @@ -27,27 +28,32 @@ type Principal struct { ExpiresAt time.Time } -func (p *Principal) HasRole(name string) bool { +// HasRole reports whether the principal holds the named role slug. +func (p *Principal) HasRole(slug string) bool { for _, r := range p.Roles { - if r == name { + if r == slug { return true } } return false } -func (p *Principal) HasAnyRole(names ...string) bool { - for _, n := range names { - if p.HasRole(n) { +// HasAnyRole reports whether the principal holds at least one of the named +// role slugs. +func (p *Principal) HasAnyRole(slugs ...string) bool { + for _, s := range slugs { + if p.HasRole(s) { return true } } return false } -func (p *Principal) HasPermission(name string) bool { +// HasPermission reports whether the principal holds the named permission +// slug, resolved through any combination of roles and direct grants. +func (p *Principal) HasPermission(slug string) bool { for _, perm := range p.Permissions { - if perm == name { + if perm == slug { return true } } diff --git a/service_authz.go b/service_authz.go index 7a5582c..2033cf1 100644 --- a/service_authz.go +++ b/service_authz.go @@ -7,79 +7,122 @@ import ( "github.com/google/uuid" ) -// UserPermissions returns the union of permission names a user holds via -// their assigned roles. Resolved at call time; v1 does not cache. +// AssignRole assigns roleSlug to userID. Idempotent — a duplicate insert +// is a no-op via ON CONFLICT. +func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleSlug string) error { + const op = "authkit.Auth.AssignRole" + r, err := a.storeGetRoleBySlug(ctx, roleSlug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeAssignRoleToUser(ctx, userID, r.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RemoveRole removes roleSlug from userID. Idempotent on missing +// assignments. +func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleSlug string) error { + const op = "authkit.Auth.RemoveRole" + r, err := a.storeGetRoleBySlug(ctx, roleSlug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeRemoveRoleFromUser(ctx, userID, r.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// UserRoles returns the role slugs assigned to a user. +func (a *Auth) UserRoles(ctx context.Context, userID uuid.UUID) ([]string, error) { + const op = "authkit.Auth.UserRoles" + roles, err := a.storeGetUserRoles(ctx, userID) + if err != nil { + return nil, errx.Wrap(op, err) + } + out := make([]string, len(roles)) + for i, r := range roles { + out[i] = r.Slug + } + return out, nil +} + +// HasRole reports whether the user holds the named role. +func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, slug string) (bool, error) { + const op = "authkit.Auth.HasRole" + ok, err := a.storeHasAnyRole(ctx, userID, []string{slug}) + if err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// HasAnyRole reports whether the user holds at least one of the named roles. +func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, slugs []string) (bool, error) { + const op = "authkit.Auth.HasAnyRole" + ok, err := a.storeHasAnyRole(ctx, userID, slugs) + if err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// GrantPermissionToUser adds a direct permission grant (not through any +// role). Idempotent. +func (a *Auth) GrantPermissionToUser(ctx context.Context, userID uuid.UUID, permSlug string) error { + const op = "authkit.Auth.GrantPermissionToUser" + p, err := a.storeGetPermissionBySlug(ctx, permSlug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeGrantPermissionToUser(ctx, userID, p.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RevokePermissionFromUser removes a direct permission grant. +func (a *Auth) RevokePermissionFromUser(ctx context.Context, userID uuid.UUID, permSlug string) error { + const op = "authkit.Auth.RevokePermissionFromUser" + p, err := a.storeGetPermissionBySlug(ctx, permSlug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeRevokePermissionFromUser(ctx, userID, p.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// UserPermissions returns the union of permission slugs the user holds via +// roles and direct grants. func (a *Auth) UserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) { const op = "authkit.Auth.UserPermissions" - perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID) + perms, err := a.storeGetUserPermissions(ctx, userID) if err != nil { return nil, errx.Wrap(op, err) } out := make([]string, len(perms)) for i, p := range perms { - out[i] = p.Name + out[i] = p.Slug } return out, nil } -// HasPermission checks whether a user holds the named permission via any -// assigned role. -func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, name string) (bool, error) { +// HasPermission reports whether the user holds the named permission, via +// any combination of role-derived and direct grants. +func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, permSlug string) (bool, error) { const op = "authkit.Auth.HasPermission" perms, err := a.UserPermissions(ctx, userID) if err != nil { return false, errx.Wrap(op, err) } for _, p := range perms { - if p == name { + if p == permSlug { return true, nil } } return false, nil } - -// HasRole checks whether a user is assigned the named role. -func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, name string) (bool, error) { - const op = "authkit.Auth.HasRole" - ok, err := a.deps.Roles.HasAnyRole(ctx, userID, []string{name}) - if err != nil { - return false, errx.Wrap(op, err) - } - return ok, nil -} - -// HasAnyRole checks whether a user holds at least one of the named roles. -func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) { - const op = "authkit.Auth.HasAnyRole" - ok, err := a.deps.Roles.HasAnyRole(ctx, userID, names) - if err != nil { - return false, errx.Wrap(op, err) - } - return ok, nil -} - -// AssignRole is a convenience that looks up a role by name and assigns it. -func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleName string) error { - const op = "authkit.Auth.AssignRole" - r, err := a.deps.Roles.GetRoleByName(ctx, roleName) - if err != nil { - return errx.Wrap(op, err) - } - if err := a.deps.Roles.AssignRoleToUser(ctx, userID, r.ID); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -// RemoveRole is the symmetric helper for AssignRole. -func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleName string) error { - const op = "authkit.Auth.RemoveRole" - r, err := a.deps.Roles.GetRoleByName(ctx, roleName) - if err != nil { - return errx.Wrap(op, err) - } - if err := a.deps.Roles.RemoveRoleFromUser(ctx, userID, r.ID); err != nil { - return errx.Wrap(op, err) - } - return nil -} diff --git a/service_jwt.go b/service_jwt.go index d3133b2..f79ba5e 100644 --- a/service_jwt.go +++ b/service_jwt.go @@ -13,7 +13,7 @@ import ( // preserves that chain so reuse-detection can revoke the whole family. func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) { const op = "authkit.Auth.IssueJWT" - u, err := a.deps.Users.GetUserByID(ctx, userID) + u, err := a.storeGetUserByID(ctx, userID) if err != nil { return "", "", errx.Wrap(op, err) } @@ -40,7 +40,7 @@ func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, if err != nil { return nil, errx.Wrap(op, ErrTokenInvalid) } - u, err := a.deps.Users.GetUserByID(ctx, uid) + u, err := a.storeGetUserByID(ctx, uid) if err != nil { if errors.Is(err, ErrUserNotFound) { return nil, errx.Wrap(op, ErrTokenInvalid) @@ -64,27 +64,27 @@ func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, }, nil } -// RefreshJWT consumes the presented refresh token and mints a new access + -// refresh pair. Reuse of an already-consumed refresh token deletes the -// entire chain (logout-everywhere on that device family) and returns +// RefreshJWT consumes the presented refresh token and mints a new +// access+refresh pair. Reuse of an already-consumed refresh token deletes +// the entire chain (logout-everywhere on that device family) and returns // ErrTokenReused. func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) { const op = "authkit.Auth.RefreshJWT" - hash, ok := parseSecret(prefixRefresh, plaintextRefresh) + hash, ok := ParseOpaqueSecret(prefixRefresh, plaintextRefresh) if !ok { return "", "", errx.Wrap(op, ErrTokenInvalid) } now := a.now() - consumed, err := a.deps.Tokens.ConsumeToken(ctx, TokenRefresh, hash, now) + consumed, err := a.storeConsumeToken(ctx, TokenRefresh, hash, now) if err != nil { - // Differentiate plain-invalid (never existed / expired) from - // reuse (existed, already consumed). The presence-check below is - // the reuse signal. + // Differentiate plain-invalid (never existed / expired) from reuse + // (existed, already consumed). Existence-with-consumed is the + // reuse signal. if errors.Is(err, ErrTokenInvalid) { - if existing, gerr := a.deps.Tokens.GetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil { + if existing, gerr := a.storeGetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil { if existing.ChainID != nil && *existing.ChainID != "" { - _, _ = a.deps.Tokens.DeleteByChain(ctx, *existing.ChainID) + _, _ = a.storeDeleteByChain(ctx, *existing.ChainID) } return "", "", errx.Wrap(op, ErrTokenReused) } @@ -98,7 +98,7 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, } if chainID == "" { // Defensive: every refresh token should be chain-bound. Fall back - // to a fresh chain so we never throw on missing metadata. + // to a fresh chain rather than throwing on missing metadata. chainID = uuid.NewString() } @@ -113,11 +113,9 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, return access, refresh, nil } -// mintRefreshToken stores a fresh refresh token bound to chainID and returns -// the plaintext. func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) { const op = "authkit.Auth.mintRefreshToken" - plaintext, hash, err := mintSecret(prefixRefresh, a.cfg.Random) + plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixRefresh) if err != nil { return "", errx.Wrap(op, err) } @@ -130,17 +128,16 @@ func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID s CreatedAt: now, ExpiresAt: now.Add(a.cfg.RefreshTokenTTL), } - if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + if err := a.storeCreateToken(ctx, t); err != nil { return "", errx.Wrap(op, err) } return plaintext, nil } -// userSessionVersion fetches the current session_version. Errors collapse to -// 0 on the assumption that AuthenticateJWT will reject stale tokens cleanly -// — but we still need a value to embed in the freshly-minted access token. +// userSessionVersion fetches the current session_version. Errors collapse +// to 0 — a stale token will fail AuthenticateJWT cleanly anyway. func (a *Auth) userSessionVersion(ctx context.Context, userID uuid.UUID) int { - if u, err := a.deps.Users.GetUserByID(ctx, userID); err == nil { + if u, err := a.storeGetUserByID(ctx, userID); err == nil { return u.SessionVersion } return 0 diff --git a/service_jwt_test.go b/service_jwt_test.go new file mode 100644 index 0000000..3f94d4d --- /dev/null +++ b/service_jwt_test.go @@ -0,0 +1,67 @@ +package authkit + +import ( + "context" + "errors" + "testing" +) + +func TestIntegration_JWTIssueAuthenticate(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "j@j.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + access, refresh, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + if access == "" || refresh == "" { + t.Fatalf("missing access/refresh") + } + p, err := a.AuthenticateJWT(ctx, access) + if err != nil { + t.Fatalf("AuthenticateJWT: %v", err) + } + if p.UserID != u.ID { + t.Fatalf("user id mismatch") + } + if p.Method != AuthMethodJWT { + t.Fatalf("method = %s, want jwt", p.Method) + } +} + +func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "ref@r.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + _, refresh1, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + _, refresh2, err := a.RefreshJWT(ctx, refresh1) + if err != nil { + t.Fatalf("RefreshJWT: %v", err) + } + if refresh1 == refresh2 { + t.Fatalf("refresh did not rotate") + } + if _, _, err := a.RefreshJWT(ctx, refresh1); !errors.Is(err, ErrTokenReused) { + t.Fatalf("expected ErrTokenReused on replay, got %v", err) + } + if _, _, err := a.RefreshJWT(ctx, refresh2); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err) + } +} + +func TestIntegration_JWTInvalidPrefix(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, _, err := a.RefreshJWT(ctx, "not-a-refresh-token"); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid for malformed input, got %v", err) + } +} diff --git a/service_magic.go b/service_magic.go index d200ca5..2c9303e 100644 --- a/service_magic.go +++ b/service_magic.go @@ -2,20 +2,27 @@ package authkit import ( "context" + "errors" "git.juancwu.dev/juancwu/errx" ) // RequestMagicLink mints a single-use magic-link token for the email and -// returns the plaintext for delivery. ErrUserNotFound is returned for -// unregistered emails. +// returns the plaintext for delivery. +// +// Default behavior is anti-enumeration: if the email is not registered, +// returns ("", nil) — the caller cannot distinguish "exists" from "doesn't +// exist". Set Config.RevealUnknownEmail = true to surface ErrUserNotFound. func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, error) { const op = "authkit.Auth.RequestMagicLink" - u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email)) if err != nil { + if errors.Is(err, ErrUserNotFound) && !a.cfg.RevealUnknownEmail { + return "", nil + } return "", errx.Wrap(op, err) } - plaintext, hash, err := mintSecret(prefixMagicLink, a.cfg.Random) + plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixMagicLink) if err != nil { return "", errx.Wrap(op, err) } @@ -27,34 +34,33 @@ func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, erro CreatedAt: now, ExpiresAt: now.Add(a.cfg.MagicLinkTTL), } - if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + if err := a.storeCreateToken(ctx, t); err != nil { return "", errx.Wrap(op, err) } return plaintext, nil } // ConsumeMagicLink consumes the magic-link token and returns the -// authenticated user. Callers typically follow this with IssueSession or -// IssueJWT to actually log the user in. +// authenticated user. Callers typically follow with IssueSession or +// IssueJWT to actually log the user in. A successful consume implicitly +// verifies the email (the user demonstrably controls the inbox). func (a *Auth) ConsumeMagicLink(ctx context.Context, plaintextToken string) (*User, error) { const op = "authkit.Auth.ConsumeMagicLink" - hash, ok := parseSecret(prefixMagicLink, plaintextToken) + hash, ok := ParseOpaqueSecret(prefixMagicLink, plaintextToken) if !ok { return nil, errx.Wrap(op, ErrTokenInvalid) } now := a.now() - t, err := a.deps.Tokens.ConsumeToken(ctx, TokenMagicLink, hash, now) + t, err := a.storeConsumeToken(ctx, TokenMagicLink, hash, now) if err != nil { return nil, errx.Wrap(op, err) } - u, err := a.deps.Users.GetUserByID(ctx, t.UserID) + u, err := a.storeGetUserByID(ctx, t.UserID) if err != nil { return nil, errx.Wrap(op, err) } - // A successful magic-link login also implicitly verifies the email - // (the user demonstrably controls the inbox). if u.EmailVerifiedAt == nil { - if err := a.deps.Users.SetEmailVerified(ctx, u.ID, now); err == nil { + if err := a.storeSetEmailVerified(ctx, u.ID, now); err == nil { u.EmailVerifiedAt = &now } } diff --git a/service_magic_test.go b/service_magic_test.go new file mode 100644 index 0000000..778e793 --- /dev/null +++ b/service_magic_test.go @@ -0,0 +1,41 @@ +package authkit + +import ( + "context" + "errors" + "testing" +) + +func TestIntegration_MagicLinkFlow(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateUser(ctx, "m@m.com"); err != nil { + t.Fatalf("CreateUser: %v", err) + } + tok, err := a.RequestMagicLink(ctx, "m@m.com") + if err != nil { + t.Fatalf("RequestMagicLink: %v", err) + } + u, err := a.ConsumeMagicLink(ctx, tok) + if err != nil { + t.Fatalf("ConsumeMagicLink: %v", err) + } + if u.EmailVerifiedAt == nil { + t.Fatalf("magic link should imply email verification") + } + if _, err := a.ConsumeMagicLink(ctx, tok); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on reuse, got %v", err) + } +} + +func TestIntegration_MagicLinkUnknownEmailIsSilent(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + tok, err := a.RequestMagicLink(ctx, "nobody@example.com") + if err != nil { + t.Fatalf("expected silent success, got %v", err) + } + if tok != "" { + t.Fatalf("expected empty token for unknown email, got %q", tok) + } +} diff --git a/service_otp.go b/service_otp.go new file mode 100644 index 0000000..3e5a8f7 --- /dev/null +++ b/service_otp.go @@ -0,0 +1,136 @@ +package authkit + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "io" + + "git.juancwu.dev/juancwu/errx" +) + +// RequestEmailOTP mints a numeric one-time code for the email and returns +// the plaintext for delivery. Anti-enumeration: unknown email returns +// ("", nil) unless Config.RevealUnknownEmail is set. +// +// Code length is Config.EmailOTPDigits (default 6). Brute-force resistance +// comes from Config.EmailOTPMaxAttempts (default 5): after N wrong tries +// the code is invalidated, forcing the caller to request a new one. +func (a *Auth) RequestEmailOTP(ctx context.Context, email string) (string, error) { + const op = "authkit.Auth.RequestEmailOTP" + u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + if errors.Is(err, ErrUserNotFound) && !a.cfg.RevealUnknownEmail { + return "", nil + } + return "", errx.Wrap(op, err) + } + code, err := generateOTPCode(a.cfg.Random, a.cfg.EmailOTPDigits) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + attempts := a.cfg.EmailOTPMaxAttempts + t := &Token{ + Hash: hashOTPCode(code), + Kind: TokenEmailOTP, + UserID: u.ID, + AttemptsRemaining: &attempts, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.EmailOTPTTL), + } + if err := a.storeCreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return code, nil +} + +// ConsumeEmailOTP verifies a code against the most recent active OTP for +// the user behind email. Successful match consumes the row. A wrong code +// decrements attempts_remaining and returns ErrOTPInvalid; reaching zero +// attempts invalidates the OTP. A successful consume implicitly verifies +// the email. +func (a *Auth) ConsumeEmailOTP(ctx context.Context, email, code string) (*User, error) { + const op = "authkit.Auth.ConsumeEmailOTP" + if code == "" { + return nil, errx.Wrap(op, ErrOTPInvalid) + } + u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // Don't leak account existence — same error shape as wrong code. + return nil, errx.Wrap(op, ErrOTPInvalid) + } + return nil, errx.Wrap(op, err) + } + + now := a.now() + active, err := a.storeGetActiveOTPForUser(ctx, TokenEmailOTP, uuidArg(u.ID), now) + if err != nil { + return nil, errx.Wrap(op, err) + } + + if !bytes.Equal(active.Hash, hashOTPCode(code)) { + // Wrong code: decrement attempts on the active OTP. If this + // drives attempts_remaining to 0, the row is consumed atomically + // inside the same UPDATE. + _, derr := a.storeDecrementOTPAttempt(ctx, TokenEmailOTP, active.Hash, now) + if derr != nil && !errors.Is(derr, ErrTokenInvalid) { + return nil, errx.Wrap(op, derr) + } + return nil, errx.Wrap(op, ErrOTPInvalid) + } + + if _, err := a.storeConsumeOTPByHash(ctx, TokenEmailOTP, active.Hash, now); err != nil { + return nil, errx.Wrap(op, err) + } + + if u.EmailVerifiedAt == nil { + if err := a.storeSetEmailVerified(ctx, u.ID, now); err == nil { + u.EmailVerifiedAt = &now + } + } + return u, nil +} + +// generateOTPCode produces a numeric code of length digits using a CSPRNG +// (defaults to crypto/rand if rng is nil). Uniformly distributed; rejects +// rolls that would bias the mod operator. +func generateOTPCode(rng io.Reader, digits int) (string, error) { + if digits <= 0 || digits > 12 { + return "", fmt.Errorf("invalid OTP digits: %d", digits) + } + if rng == nil { + rng = rand.Reader + } + max := uint64(1) + for i := 0; i < digits; i++ { + max *= 10 + } + // Reject rolls that fall in the "leftover" partial keyspace at the top + // of uint64 to keep the distribution uniform. + limit := (^uint64(0)) - ((^uint64(0)) % max) + var buf [8]byte + for { + if _, err := io.ReadFull(rng, buf[:]); err != nil { + return "", err + } + v := binary.BigEndian.Uint64(buf[:]) + if v >= limit { + continue + } + return fmt.Sprintf("%0*d", digits, v%max), nil + } +} + +// hashOTPCode returns sha256(code). Codes are short and low-entropy, so the +// hash is purely a database lookup key — the brute-force defense is +// attempts_remaining, not hash strength. +func hashOTPCode(code string) []byte { + sum := sha256.Sum256([]byte(code)) + return sum[:] +} diff --git a/service_otp_test.go b/service_otp_test.go new file mode 100644 index 0000000..53128ed --- /dev/null +++ b/service_otp_test.go @@ -0,0 +1,85 @@ +package authkit + +import ( + "context" + "errors" + "testing" +) + +func TestIntegration_OTPHappyPath(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateUser(ctx, "otp@example.com"); err != nil { + t.Fatalf("CreateUser: %v", err) + } + code, err := a.RequestEmailOTP(ctx, "otp@example.com") + if err != nil { + t.Fatalf("RequestEmailOTP: %v", err) + } + if len(code) != 6 { + t.Fatalf("expected 6-digit OTP, got %q", code) + } + u, err := a.ConsumeEmailOTP(ctx, "otp@example.com", code) + if err != nil { + t.Fatalf("ConsumeEmailOTP: %v", err) + } + if u.EmailVerifiedAt == nil { + t.Fatalf("OTP consume should imply email verification") + } + // Re-using the same code must fail. + if _, err := a.ConsumeEmailOTP(ctx, "otp@example.com", code); !errors.Is(err, ErrOTPInvalid) { + t.Fatalf("expected ErrOTPInvalid on reuse, got %v", err) + } +} + +func TestIntegration_OTPWrongCodeDecrementsAndExhausts(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateUser(ctx, "otp@example.com"); err != nil { + t.Fatalf("CreateUser: %v", err) + } + code, err := a.RequestEmailOTP(ctx, "otp@example.com") + if err != nil { + t.Fatalf("RequestEmailOTP: %v", err) + } + // freshAuth configures EmailOTPMaxAttempts=3. + for i := 0; i < 3; i++ { + if _, err := a.ConsumeEmailOTP(ctx, "otp@example.com", "000000"); !errors.Is(err, ErrOTPInvalid) { + t.Fatalf("attempt %d: expected ErrOTPInvalid, got %v", i, err) + } + } + // After exhausting attempts, even the correct code must fail (the OTP + // row was consumed when attempts hit zero). + if _, err := a.ConsumeEmailOTP(ctx, "otp@example.com", code); !errors.Is(err, ErrOTPInvalid) { + t.Fatalf("expected ErrOTPInvalid after exhausting attempts, got %v", err) + } +} + +func TestIntegration_OTPUnknownEmailIsSilent(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + code, err := a.RequestEmailOTP(ctx, "nobody@example.com") + if err != nil { + t.Fatalf("expected silent success, got %v", err) + } + if code != "" { + t.Fatalf("expected empty code, got %q", code) + } +} + +func TestIntegration_OTPGeneratorProducesDigitsOnly(t *testing.T) { + for i := 0; i < 50; i++ { + code, err := generateOTPCode(nil, 6) + if err != nil { + t.Fatalf("generateOTPCode: %v", err) + } + if len(code) != 6 { + t.Fatalf("expected 6 chars, got %d (%q)", len(code), code) + } + for _, c := range code { + if c < '0' || c > '9' { + t.Fatalf("non-digit %q in code %q", c, code) + } + } + } +} diff --git a/service_reset.go b/service_reset.go index a2ebbb6..7e9d98e 100644 --- a/service_reset.go +++ b/service_reset.go @@ -8,16 +8,20 @@ import ( ) // RequestPasswordReset mints a single-use password-reset token for the user -// behind email and returns the plaintext for the caller to deliver via email. -// Returns ErrUserNotFound when the email isn't registered (per project -// policy of distinct errors over anti-enumeration). +// behind email and returns the plaintext for delivery. +// +// Default behavior is anti-enumeration: unknown email returns ("", nil). +// Set Config.RevealUnknownEmail = true to surface ErrUserNotFound. func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, error) { const op = "authkit.Auth.RequestPasswordReset" - u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email)) if err != nil { + if errors.Is(err, ErrUserNotFound) && !a.cfg.RevealUnknownEmail { + return "", nil + } return "", errx.Wrap(op, err) } - plaintext, hash, err := mintSecret(prefixPasswordRset, a.cfg.Random) + plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixPasswordRset) if err != nil { return "", errx.Wrap(op, err) } @@ -29,40 +33,37 @@ func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, CreatedAt: now, ExpiresAt: now.Add(a.cfg.PasswordResetTTL), } - if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + if err := a.storeCreateToken(ctx, t); err != nil { return "", errx.Wrap(op, err) } return plaintext, nil } // ConfirmPasswordReset consumes the reset token, sets the new password, -// bumps the user's session_version, and revokes outstanding sessions so the -// reset constitutes a global logout. +// bumps session_version, and revokes outstanding sessions so the reset is +// a global logout. func (a *Auth) ConfirmPasswordReset(ctx context.Context, plaintextToken, newPassword string) error { const op = "authkit.Auth.ConfirmPasswordReset" - hash, ok := parseSecret(prefixPasswordRset, plaintextToken) + hash, ok := ParseOpaqueSecret(prefixPasswordRset, plaintextToken) if !ok { return errx.Wrap(op, ErrTokenInvalid) } now := a.now() - t, err := a.deps.Tokens.ConsumeToken(ctx, TokenPasswordReset, hash, now) + t, err := a.storeConsumeToken(ctx, TokenPasswordReset, hash, now) if err != nil { return errx.Wrap(op, err) } - newHash, err := a.deps.Hasher.Hash(newPassword) + newHash, err := a.hasher.Hash(newPassword) if err != nil { return errx.Wrap(op, err) } - if err := a.deps.Users.SetPassword(ctx, t.UserID, newHash); err != nil { - if errors.Is(err, ErrUserNotFound) { - return errx.Wrap(op, ErrUserNotFound) - } + if err := a.storeSetPassword(ctx, t.UserID, newHash); err != nil { return errx.Wrap(op, err) } - if _, err := a.deps.Users.BumpSessionVersion(ctx, t.UserID); err != nil { + if _, err := a.storeBumpSessionVersion(ctx, t.UserID); err != nil { return errx.Wrap(op, err) } - if err := a.deps.Sessions.DeleteUserSessions(ctx, t.UserID); err != nil { + if err := a.storeDeleteUserSessions(ctx, t.UserID); err != nil { return errx.Wrap(op, err) } return nil diff --git a/service_seed.go b/service_seed.go new file mode 100644 index 0000000..4c1b5e5 --- /dev/null +++ b/service_seed.go @@ -0,0 +1,201 @@ +package authkit + +import ( + "context" + + "git.juancwu.dev/juancwu/errx" +) + +// CreateRole inserts a new role. Slug must be a valid normalized slug; +// returns ErrSlugInvalid otherwise. Returns ErrSlugTaken if the slug is +// already in use. +func (a *Auth) CreateRole(ctx context.Context, slug, label string) (*Role, error) { + const op = "authkit.Auth.CreateRole" + if err := validateSlug(op, slug); err != nil { + return nil, err + } + r := &Role{Slug: slug, Label: label} + if err := a.storeCreateRole(ctx, r); err != nil { + return nil, errx.Wrap(op, err) + } + return r, nil +} + +// GetRoleBySlug fetches a role by its slug. +func (a *Auth) GetRoleBySlug(ctx context.Context, slug string) (*Role, error) { + const op = "authkit.Auth.GetRoleBySlug" + r, err := a.storeGetRoleBySlug(ctx, slug) + if err != nil { + return nil, errx.Wrap(op, err) + } + return r, nil +} + +// ListRoles returns every role ordered by slug. +func (a *Auth) ListRoles(ctx context.Context) ([]*Role, error) { + const op = "authkit.Auth.ListRoles" + out, err := a.storeListRoles(ctx) + if err != nil { + return nil, errx.Wrap(op, err) + } + return out, nil +} + +// DeleteRole removes a role by its slug. Cascades to user_roles and +// role_permissions. Returns ErrRoleNotFound if absent. +func (a *Auth) DeleteRole(ctx context.Context, slug string) error { + const op = "authkit.Auth.DeleteRole" + r, err := a.storeGetRoleBySlug(ctx, slug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeDeleteRole(ctx, r.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// CreatePermission inserts a new permission. +func (a *Auth) CreatePermission(ctx context.Context, slug, label string) (*Permission, error) { + const op = "authkit.Auth.CreatePermission" + if err := validateSlug(op, slug); err != nil { + return nil, err + } + p := &Permission{Slug: slug, Label: label} + if err := a.storeCreatePermission(ctx, p); err != nil { + return nil, errx.Wrap(op, err) + } + return p, nil +} + +// GetPermissionBySlug fetches a permission by its slug. +func (a *Auth) GetPermissionBySlug(ctx context.Context, slug string) (*Permission, error) { + const op = "authkit.Auth.GetPermissionBySlug" + p, err := a.storeGetPermissionBySlug(ctx, slug) + if err != nil { + return nil, errx.Wrap(op, err) + } + return p, nil +} + +// ListPermissions returns every permission ordered by slug. +func (a *Auth) ListPermissions(ctx context.Context) ([]*Permission, error) { + const op = "authkit.Auth.ListPermissions" + out, err := a.storeListPermissions(ctx) + if err != nil { + return nil, errx.Wrap(op, err) + } + return out, nil +} + +// DeletePermission removes a permission by its slug. Cascades to +// role_permissions and user_permissions. +func (a *Auth) DeletePermission(ctx context.Context, slug string) error { + const op = "authkit.Auth.DeletePermission" + p, err := a.storeGetPermissionBySlug(ctx, slug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeDeletePermission(ctx, p.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// CreateAbility inserts a new ability for service tokens. +func (a *Auth) CreateAbility(ctx context.Context, slug, label string) (*Ability, error) { + const op = "authkit.Auth.CreateAbility" + if err := validateSlug(op, slug); err != nil { + return nil, err + } + ab := &Ability{Slug: slug, Label: label} + if err := a.storeCreateAbility(ctx, ab); err != nil { + return nil, errx.Wrap(op, err) + } + return ab, nil +} + +// GetAbilityBySlug fetches an ability by its slug. +func (a *Auth) GetAbilityBySlug(ctx context.Context, slug string) (*Ability, error) { + const op = "authkit.Auth.GetAbilityBySlug" + ab, err := a.storeGetAbilityBySlug(ctx, slug) + if err != nil { + return nil, errx.Wrap(op, err) + } + return ab, nil +} + +// ListAbilities returns every ability ordered by slug. +func (a *Auth) ListAbilities(ctx context.Context) ([]*Ability, error) { + const op = "authkit.Auth.ListAbilities" + out, err := a.storeListAbilities(ctx) + if err != nil { + return nil, errx.Wrap(op, err) + } + return out, nil +} + +// DeleteAbility removes an ability by its slug. Cascades to +// service_key_abilities. +func (a *Auth) DeleteAbility(ctx context.Context, slug string) error { + const op = "authkit.Auth.DeleteAbility" + ab, err := a.storeGetAbilityBySlug(ctx, slug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeDeleteAbility(ctx, ab.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// GrantPermissionToRole adds permSlug to roleSlug's permission set. +// Idempotent. +func (a *Auth) GrantPermissionToRole(ctx context.Context, roleSlug, permSlug string) error { + const op = "authkit.Auth.GrantPermissionToRole" + r, err := a.storeGetRoleBySlug(ctx, roleSlug) + if err != nil { + return errx.Wrap(op, err) + } + p, err := a.storeGetPermissionBySlug(ctx, permSlug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeAssignPermissionToRole(ctx, r.ID, p.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RevokePermissionFromRole removes permSlug from roleSlug's permission set. +// Idempotent. +func (a *Auth) RevokePermissionFromRole(ctx context.Context, roleSlug, permSlug string) error { + const op = "authkit.Auth.RevokePermissionFromRole" + r, err := a.storeGetRoleBySlug(ctx, roleSlug) + if err != nil { + return errx.Wrap(op, err) + } + p, err := a.storeGetPermissionBySlug(ctx, permSlug) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeRemovePermissionFromRole(ctx, r.ID, p.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// ListRolePermissions returns permissions granted to a role through the +// role-permission link only (not direct user grants). +func (a *Auth) ListRolePermissions(ctx context.Context, roleSlug string) ([]*Permission, error) { + const op = "authkit.Auth.ListRolePermissions" + r, err := a.storeGetRoleBySlug(ctx, roleSlug) + if err != nil { + return nil, errx.Wrap(op, err) + } + out, err := a.storeGetRolePermissions(ctx, r.ID) + if err != nil { + return nil, errx.Wrap(op, err) + } + return out, nil +} diff --git a/service_seed_test.go b/service_seed_test.go new file mode 100644 index 0000000..c6ff152 --- /dev/null +++ b/service_seed_test.go @@ -0,0 +1,136 @@ +package authkit + +import ( + "context" + "errors" + "testing" +) + +func TestIntegration_SeedRolesAndPermissions(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + r, err := a.CreateRole(ctx, "editor", "Editor") + if err != nil { + t.Fatalf("CreateRole: %v", err) + } + if r.Slug != "editor" || r.Label != "Editor" { + t.Fatalf("role round-trip mismatch: %+v", r) + } + if _, err := a.CreateRole(ctx, "editor", ""); !errors.Is(err, ErrSlugTaken) { + t.Fatalf("expected ErrSlugTaken on duplicate, got %v", err) + } + if _, err := a.CreateRole(ctx, "Editor", ""); !errors.Is(err, ErrSlugInvalid) { + t.Fatalf("expected ErrSlugInvalid on uppercase, got %v", err) + } + + if _, err := a.CreatePermission(ctx, "posts:write", "Write posts"); err != nil { + t.Fatalf("CreatePermission: %v", err) + } + if err := a.GrantPermissionToRole(ctx, "editor", "posts:write"); err != nil { + t.Fatalf("GrantPermissionToRole: %v", err) + } + + roles, err := a.ListRoles(ctx) + if err != nil { + t.Fatalf("ListRoles: %v", err) + } + if len(roles) != 1 || roles[0].Slug != "editor" { + t.Fatalf("ListRoles unexpected: %+v", roles) + } + + perms, err := a.ListRolePermissions(ctx, "editor") + if err != nil { + t.Fatalf("ListRolePermissions: %v", err) + } + if len(perms) != 1 || perms[0].Slug != "posts:write" { + t.Fatalf("ListRolePermissions unexpected: %+v", perms) + } +} + +func TestIntegration_DirectUserPermissionsUnion(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "u@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + + // Permission via role. + if _, err := a.CreateRole(ctx, "editor", ""); err != nil { + t.Fatalf("CreateRole: %v", err) + } + if _, err := a.CreatePermission(ctx, "posts:read", ""); err != nil { + t.Fatalf("CreatePermission posts:read: %v", err) + } + if err := a.GrantPermissionToRole(ctx, "editor", "posts:read"); err != nil { + t.Fatalf("GrantPermissionToRole: %v", err) + } + if err := a.AssignRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("AssignRole: %v", err) + } + + // Direct permission grant. + if _, err := a.CreatePermission(ctx, "billing:view", ""); err != nil { + t.Fatalf("CreatePermission billing:view: %v", err) + } + if err := a.GrantPermissionToUser(ctx, u.ID, "billing:view"); err != nil { + t.Fatalf("GrantPermissionToUser: %v", err) + } + + got, err := a.UserPermissions(ctx, u.ID) + if err != nil { + t.Fatalf("UserPermissions: %v", err) + } + if len(got) != 2 { + t.Fatalf("expected 2 perms (UNION), got %v", got) + } + want := map[string]bool{"posts:read": true, "billing:view": true} + for _, p := range got { + if !want[p] { + t.Fatalf("unexpected permission %q", p) + } + delete(want, p) + } + if len(want) != 0 { + t.Fatalf("missing permissions: %v", want) + } + + // Revoke direct grant; only role-derived remains. + if err := a.RevokePermissionFromUser(ctx, u.ID, "billing:view"); err != nil { + t.Fatalf("RevokePermissionFromUser: %v", err) + } + got2, err := a.UserPermissions(ctx, u.ID) + if err != nil { + t.Fatalf("UserPermissions post-revoke: %v", err) + } + if len(got2) != 1 || got2[0] != "posts:read" { + t.Fatalf("expected only posts:read after revoke, got %v", got2) + } +} + +func TestIntegration_RolePermissionMembershipQueries(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "h@h.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + if _, err := a.CreateRole(ctx, "manager", ""); err != nil { + t.Fatalf("CreateRole: %v", err) + } + if err := a.AssignRole(ctx, u.ID, "manager"); err != nil { + t.Fatalf("AssignRole: %v", err) + } + ok, err := a.HasRole(ctx, u.ID, "manager") + if err != nil || !ok { + t.Fatalf("HasRole: ok=%v err=%v", ok, err) + } + ok, err = a.HasAnyRole(ctx, u.ID, []string{"admin", "manager"}) + if err != nil || !ok { + t.Fatalf("HasAnyRole: ok=%v err=%v", ok, err) + } + ok, err = a.HasAnyRole(ctx, u.ID, []string{"admin", "ops"}) + if err != nil || ok { + t.Fatalf("HasAnyRole should be false: ok=%v err=%v", ok, err) + } +} diff --git a/service_service_key.go b/service_service_key.go index 1ccc9b3..fe3d895 100644 --- a/service_service_key.go +++ b/service_service_key.go @@ -2,19 +2,50 @@ package authkit import ( "context" + "errors" "time" "git.juancwu.dev/juancwu/errx" "github.com/google/uuid" ) -// IssueServiceKey mints a fresh owner-agnostic service token. ownerKind is a -// consumer-defined namespace label (e.g. "application", "tenant") and ownerID -// is the owning entity's id; authkit makes no assumption about either. The -// plaintext is returned (show-once) and the SHA-256 lookup hash is stored. -// Pass ttl=nil for a non-expiring key. -func (a *Auth) IssueServiceKey(ctx context.Context, ownerKind string, ownerID uuid.UUID, name string, abilities []string, ttl *time.Duration) (string, *ServiceKey, error) { +// IssueServiceKeyParams is the input shape for IssueServiceKey. Abilities +// are slugs that must already exist in authkit_abilities — issue fails with +// ErrAbilityNotFound if any slug is unknown. TTL is optional; nil means +// non-expiring. +type IssueServiceKeyParams struct { + Name string + Abilities []string + TTL *time.Duration +} + +// IssueServiceKey mints a fresh service token. Plaintext is returned exactly +// once (show-once); only the SHA-256 hash is persisted. Each ability slug is +// resolved to its row before insertion, so the service key carries a +// well-defined set of abilities even after later slug renames or deletes. +func (a *Auth) IssueServiceKey(ctx context.Context, params IssueServiceKeyParams) (string, *ServiceKey, error) { const op = "authkit.Auth.IssueServiceKey" + if params.Name == "" { + return "", nil, errx.New(op, "Name is required") + } + abilityIDs := make([]uuid.UUID, 0, len(params.Abilities)) + resolved := make([]string, 0, len(params.Abilities)) + seen := map[string]struct{}{} + for _, slug := range params.Abilities { + if _, dup := seen[slug]; dup { + continue + } + seen[slug] = struct{}{} + ab, err := a.storeGetAbilityBySlug(ctx, slug) + if err != nil { + if errors.Is(err, ErrAbilityNotFound) { + return "", nil, errx.Wrapf(op, ErrAbilityNotFound, "ability %q is not registered", slug) + } + return "", nil, errx.Wrap(op, err) + } + abilityIDs = append(abilityIDs, ab.ID) + resolved = append(resolved, ab.Slug) + } plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixServiceKey) if err != nil { return "", nil, errx.Wrap(op, err) @@ -22,34 +53,29 @@ func (a *Auth) IssueServiceKey(ctx context.Context, ownerKind string, ownerID uu now := a.now() k := &ServiceKey{ IDHash: hash, - OwnerID: ownerID, - OwnerKind: ownerKind, - Name: name, - Abilities: append([]string(nil), abilities...), + Name: params.Name, + Abilities: resolved, CreatedAt: now, } - if ttl != nil { - exp := now.Add(*ttl) + if params.TTL != nil { + exp := now.Add(*params.TTL) k.ExpiresAt = &exp } - if err := a.deps.ServiceKeys.CreateServiceKey(ctx, k); err != nil { + if err := a.storeCreateServiceKey(ctx, k, abilityIDs); err != nil { return "", nil, errx.Wrap(op, err) } return plaintext, k, nil } -// AuthenticateServiceKey validates a service token, touches last_used_at -// (best-effort), and returns the stored *ServiceKey. Unlike API keys, no -// Principal is returned — service tokens have no owning user, so the -// Principal abstraction does not fit. Consumers needing a Principal can -// build one from the returned key. +// AuthenticateServiceKey validates a service token and returns the stored +// *ServiceKey with its abilities resolved. func (a *Auth) AuthenticateServiceKey(ctx context.Context, plaintext string) (*ServiceKey, error) { const op = "authkit.Auth.AuthenticateServiceKey" hash, ok := ParseOpaqueSecret(prefixServiceKey, plaintext) if !ok { return nil, errx.Wrap(op, ErrServiceKeyInvalid) } - k, err := a.deps.ServiceKeys.GetServiceKey(ctx, hash) + k, err := a.storeGetServiceKey(ctx, hash) if err != nil { return nil, errx.Wrap(op, err) } @@ -60,10 +86,8 @@ func (a *Auth) AuthenticateServiceKey(ctx context.Context, plaintext string) (*S if k.ExpiresAt != nil && !k.ExpiresAt.After(now) { return nil, errx.Wrap(op, ErrServiceKeyInvalid) } - _ = a.deps.ServiceKeys.TouchServiceKey(ctx, hash, now) - out := *k - out.Abilities = append([]string(nil), k.Abilities...) - return &out, nil + _ = a.storeTouchServiceKey(ctx, hash, now) + return k, nil } // RevokeServiceKey marks a service token revoked. Idempotent on @@ -74,17 +98,17 @@ func (a *Auth) RevokeServiceKey(ctx context.Context, plaintext string) error { if !ok { return errx.Wrap(op, ErrServiceKeyInvalid) } - if err := a.deps.ServiceKeys.RevokeServiceKey(ctx, hash, a.now()); err != nil { + if err := a.storeRevokeServiceKey(ctx, hash, a.now()); err != nil { return errx.Wrap(op, err) } return nil } -// ListServiceKeys returns every service token issued for the given -// (ownerKind, ownerID) pair, including revoked and expired keys. -func (a *Auth) ListServiceKeys(ctx context.Context, ownerKind string, ownerID uuid.UUID) ([]*ServiceKey, error) { +// ListServiceKeys returns every service token, including revoked and +// expired ones, ordered by creation time descending. +func (a *Auth) ListServiceKeys(ctx context.Context) ([]*ServiceKey, error) { const op = "authkit.Auth.ListServiceKeys" - out, err := a.deps.ServiceKeys.ListServiceKeysByOwner(ctx, ownerKind, ownerID) + out, err := a.storeListServiceKeys(ctx) if err != nil { return nil, errx.Wrap(op, err) } diff --git a/service_service_key_test.go b/service_service_key_test.go index c95d7a2..7480814 100644 --- a/service_service_key_test.go +++ b/service_service_key_test.go @@ -2,183 +2,123 @@ package authkit import ( "context" - "encoding/base64" "errors" "strings" "testing" "time" - - "github.com/google/uuid" ) -func TestServiceKeyRoundtrip(t *testing.T) { - a := newTestAuth(t) - appID := uuid.New() - plaintext, k, err := a.IssueServiceKey(context.Background(), - "application", appID, "events-ingest", - []string{"events:write", "events:read"}, nil) +func TestIntegration_ServiceKeyRoundtrip(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateAbility(ctx, "events:write", "Write events"); err != nil { + t.Fatalf("CreateAbility events:write: %v", err) + } + if _, err := a.CreateAbility(ctx, "events:read", "Read events"); err != nil { + t.Fatalf("CreateAbility events:read: %v", err) + } + + plain, k, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{ + Name: "events-ingest", + Abilities: []string{"events:write", "events:read"}, + }) if err != nil { t.Fatalf("IssueServiceKey: %v", err) } - if plaintext == "" || k == nil { - t.Fatalf("missing plaintext or key") + if !strings.HasPrefix(plain, "sk_") { + t.Fatalf("plaintext should start with sk_: %q", plain) } - got, err := a.AuthenticateServiceKey(context.Background(), plaintext) + if k.Name != "events-ingest" { + t.Fatalf("name mismatch: %q", k.Name) + } + if len(k.Abilities) != 2 { + t.Fatalf("expected 2 abilities, got %v", k.Abilities) + } + + got, err := a.AuthenticateServiceKey(ctx, plain) if err != nil { t.Fatalf("AuthenticateServiceKey: %v", err) } - if got.OwnerKind != "application" || got.OwnerID != appID { - t.Fatalf("owner mismatch: kind=%q id=%v", got.OwnerKind, got.OwnerID) - } - if got.Name != "events-ingest" { - t.Fatalf("name mismatch: %q", got.Name) - } - if len(got.Abilities) != 2 || got.Abilities[0] != "events:write" || got.Abilities[1] != "events:read" { - t.Fatalf("abilities mismatch: %+v", got.Abilities) - } - got.Abilities[0] = "tampered" - again, err := a.AuthenticateServiceKey(context.Background(), plaintext) - if err != nil { - t.Fatalf("AuthenticateServiceKey (re-auth): %v", err) - } - if again.Abilities[0] != "events:write" { - t.Fatalf("returned slice was not deep-copied; saw mutation: %+v", again.Abilities) + if !got.HasAbility("events:write") || !got.HasAbility("events:read") { + t.Fatalf("missing expected abilities: %+v", got.Abilities) } } -func TestServiceKeyPlaintextShape(t *testing.T) { - a := newTestAuth(t) - plaintext, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "name", nil, nil) +func TestIntegration_ServiceKeyRejectsUnknownAbility(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + _, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{ + Name: "x", + Abilities: []string{"never-registered"}, + }) + if !errors.Is(err, ErrAbilityNotFound) { + t.Fatalf("expected ErrAbilityNotFound, got %v", err) + } +} + +func TestIntegration_ServiceKeyRevoke(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateAbility(ctx, "ops", ""); err != nil { + t.Fatalf("CreateAbility: %v", err) + } + plain, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{ + Name: "ci", + Abilities: []string{"ops"}, + }) if err != nil { t.Fatalf("IssueServiceKey: %v", err) } - if !strings.HasPrefix(plaintext, "sk_") { - t.Fatalf("plaintext missing sk_ prefix: %q", plaintext) - } - body := strings.TrimPrefix(plaintext, "sk_") - raw, err := base64.RawURLEncoding.DecodeString(body) - if err != nil { - t.Fatalf("base64 decode: %v", err) - } - if len(raw) != 32 { - t.Fatalf("body decoded to %d bytes, want 32", len(raw)) - } -} - -func TestServiceKeyWrongPrefix(t *testing.T) { - a := newTestAuth(t) - _, err := a.AuthenticateServiceKey(context.Background(), "ak_not-a-service-key") - if !errors.Is(err, ErrServiceKeyInvalid) { - t.Fatalf("expected ErrServiceKeyInvalid for wrong prefix, got %v", err) - } -} - -func TestServiceKeyAfterRevoke(t *testing.T) { - a := newTestAuth(t) - plaintext, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "ci", nil, nil) - if err != nil { - t.Fatalf("IssueServiceKey: %v", err) - } - if err := a.RevokeServiceKey(context.Background(), plaintext); err != nil { + if err := a.RevokeServiceKey(ctx, plain); err != nil { t.Fatalf("RevokeServiceKey: %v", err) } - if _, err := a.AuthenticateServiceKey(context.Background(), plaintext); !errors.Is(err, ErrServiceKeyInvalid) { + if _, err := a.AuthenticateServiceKey(ctx, plain); !errors.Is(err, ErrServiceKeyInvalid) { t.Fatalf("expected ErrServiceKeyInvalid post-revoke, got %v", err) } } -func TestServiceKeyAfterExpiry(t *testing.T) { - a := newTestAuth(t) +func TestIntegration_ServiceKeyExpiry(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateAbility(ctx, "ops", ""); err != nil { + t.Fatalf("CreateAbility: %v", err) + } now := time.Now().UTC() a.cfg.Clock = func() time.Time { return now } ttl := time.Minute - plaintext, _, err := a.IssueServiceKey(context.Background(), - "application", uuid.New(), "ephemeral", nil, &ttl) + plain, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{ + Name: "ephemeral", + Abilities: []string{"ops"}, + TTL: &ttl, + }) if err != nil { t.Fatalf("IssueServiceKey: %v", err) } a.cfg.Clock = func() time.Time { return now.Add(2 * time.Minute) } - if _, err := a.AuthenticateServiceKey(context.Background(), plaintext); !errors.Is(err, ErrServiceKeyInvalid) { + if _, err := a.AuthenticateServiceKey(ctx, plain); !errors.Is(err, ErrServiceKeyInvalid) { t.Fatalf("expected ErrServiceKeyInvalid post-expiry, got %v", err) } } -func TestServiceKeyListByOwner(t *testing.T) { - a := newTestAuth(t) - appA := uuid.New() - appB := uuid.New() - for i := 0; i < 2; i++ { - if _, _, err := a.IssueServiceKey(context.Background(), "application", appA, "k", nil, nil); err != nil { - t.Fatalf("Issue appA #%d: %v", i, err) +func TestIntegration_ServiceKeyList(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateAbility(ctx, "ops", ""); err != nil { + t.Fatalf("CreateAbility: %v", err) + } + for i := 0; i < 3; i++ { + if _, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{ + Name: "k", + Abilities: []string{"ops"}, + }); err != nil { + t.Fatalf("IssueServiceKey: %v", err) } } - if _, _, err := a.IssueServiceKey(context.Background(), "application", appB, "k", nil, nil); err != nil { - t.Fatalf("Issue appB: %v", err) - } - gotA, err := a.ListServiceKeys(context.Background(), "application", appA) + out, err := a.ListServiceKeys(ctx) if err != nil { - t.Fatalf("ListServiceKeys appA: %v", err) + t.Fatalf("ListServiceKeys: %v", err) } - if len(gotA) != 2 { - t.Fatalf("ListServiceKeys appA = %d keys, want 2", len(gotA)) - } - gotB, err := a.ListServiceKeys(context.Background(), "application", appB) - if err != nil { - t.Fatalf("ListServiceKeys appB: %v", err) - } - if len(gotB) != 1 { - t.Fatalf("ListServiceKeys appB = %d keys, want 1", len(gotB)) - } - gotTenantA, err := a.ListServiceKeys(context.Background(), "tenant", appA) - if err != nil { - t.Fatalf("ListServiceKeys tenant/appA: %v", err) - } - if len(gotTenantA) != 0 { - t.Fatalf("ListServiceKeys tenant/appA = %d, want 0 (different owner_kind)", len(gotTenantA)) - } -} - -func TestServiceKeyHasAbility(t *testing.T) { - k := &ServiceKey{Abilities: []string{"events:write", "events:read"}} - if !k.HasAbility("events:write") { - t.Fatalf("expected HasAbility(events:write) = true") - } - if !k.HasAbility("events:read") { - t.Fatalf("expected HasAbility(events:read) = true") - } - if k.HasAbility("admin:nuke") { - t.Fatalf("expected HasAbility(admin:nuke) = false") - } - empty := &ServiceKey{} - if empty.HasAbility("anything") { - t.Fatalf("HasAbility on empty Abilities must be false") - } -} - -func TestServiceKeyTouchUpdatesLastUsedAt(t *testing.T) { - a := newTestAuth(t) - appID := uuid.New() - plaintext, _, err := a.IssueServiceKey(context.Background(), "application", appID, "k", nil, nil) - if err != nil { - t.Fatalf("IssueServiceKey: %v", err) - } - keys, err := a.ListServiceKeys(context.Background(), "application", appID) - if err != nil || len(keys) != 1 { - t.Fatalf("pre-touch list: err=%v len=%d", err, len(keys)) - } - if keys[0].LastUsedAt != nil { - t.Fatalf("expected LastUsedAt=nil before authenticate, got %v", *keys[0].LastUsedAt) - } - if _, err := a.AuthenticateServiceKey(context.Background(), plaintext); err != nil { - t.Fatalf("AuthenticateServiceKey: %v", err) - } - keys, err = a.ListServiceKeys(context.Background(), "application", appID) - if err != nil || len(keys) != 1 { - t.Fatalf("post-touch list: err=%v len=%d", err, len(keys)) - } - if keys[0].LastUsedAt == nil { - t.Fatalf("expected LastUsedAt to be set after authenticate") + if len(out) != 3 { + t.Fatalf("expected 3 keys, got %d", len(out)) } } diff --git a/service_session.go b/service_session.go index a6573d6..ecd0422 100644 --- a/service_session.go +++ b/service_session.go @@ -14,7 +14,7 @@ import ( // returns the plaintext (for the cookie) plus the stored Session. func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent string, ip netip.Addr) (string, *Session, error) { const op = "authkit.Auth.IssueSession" - plaintext, hash, err := mintSecret(prefixSession, a.cfg.Random) + plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixSession) if err != nil { return "", nil, errx.Wrap(op, err) } @@ -32,38 +32,35 @@ func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent str LastSeenAt: now, ExpiresAt: expires, } - if err := a.deps.Sessions.CreateSession(ctx, s); err != nil { + if err := a.storeCreateSession(ctx, s); err != nil { return "", nil, errx.Wrap(op, err) } return plaintext, s, nil } // AuthenticateSession validates an opaque session string, slides the TTL, -// resolves the user's roles+permissions, and returns a Principal. Expired or -// unknown sessions return ErrSessionInvalid. +// resolves the user's roles+permissions, and returns a Principal. func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Principal, error) { const op = "authkit.Auth.AuthenticateSession" - hash, ok := parseSecret(prefixSession, plaintext) + hash, ok := ParseOpaqueSecret(prefixSession, plaintext) if !ok { return nil, errx.Wrap(op, ErrSessionInvalid) } - s, err := a.deps.Sessions.GetSession(ctx, hash) + s, err := a.storeGetSession(ctx, hash) if err != nil { return nil, errx.Wrap(op, err) } now := a.now() if !s.ExpiresAt.After(now) { - _ = a.deps.Sessions.DeleteSession(ctx, hash) + _ = a.storeDeleteSession(ctx, hash) return nil, errx.Wrap(op, ErrSessionInvalid) } - // Slide the idle TTL, capped at created_at + AbsoluteTTL so an active - // session still expires at the absolute boundary. newExpires := now.Add(a.cfg.SessionIdleTTL) if cap := s.CreatedAt.Add(a.cfg.SessionAbsoluteTTL); newExpires.After(cap) { newExpires = cap } - if err := a.deps.Sessions.TouchSession(ctx, hash, now, newExpires); err != nil { + if err := a.storeTouchSession(ctx, hash, now, newExpires); err != nil { return nil, errx.Wrap(op, err) } @@ -82,73 +79,85 @@ func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Prin }, nil } -// RevokeSession deletes a single session by its plaintext id. Idempotent: -// missing sessions are not an error (logout twice should not 500). +// RevokeSession deletes a single session by its plaintext id. Idempotent — +// missing sessions are not an error. func (a *Auth) RevokeSession(ctx context.Context, plaintext string) error { const op = "authkit.Auth.RevokeSession" - hash, ok := parseSecret(prefixSession, plaintext) + hash, ok := ParseOpaqueSecret(prefixSession, plaintext) if !ok { return nil } - if err := a.deps.Sessions.DeleteSession(ctx, hash); err != nil { + if err := a.storeDeleteSession(ctx, hash); err != nil { return errx.Wrap(op, err) } return nil } // RevokeAllUserSessions kills every active session for the user and bumps -// the user's session_version (invalidating outstanding JWT access tokens). +// session_version (invalidating outstanding JWT access tokens). func (a *Auth) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) error { const op = "authkit.Auth.RevokeAllUserSessions" - if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil { + if err := a.storeDeleteUserSessions(ctx, userID); err != nil { return errx.Wrap(op, err) } - if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil { + if _, err := a.storeBumpSessionVersion(ctx, userID); err != nil { return errx.Wrap(op, err) } return nil } // SessionCookie builds an *http.Cookie pre-configured from Config. Pass the -// plaintext returned by IssueSession; pass the matching ExpiresAt from the -// returned *Session as `expires`. To clear a cookie at logout, pass an empty -// plaintext and a past expiry. +// plaintext returned by IssueSession and the matching ExpiresAt from the +// returned *Session. func (a *Auth) SessionCookie(plaintext string, expires time.Time) *http.Cookie { - c := &http.Cookie{ + return &http.Cookie{ Name: a.cfg.SessionCookieName, Value: plaintext, Path: a.cfg.SessionCookiePath, Domain: a.cfg.SessionCookieDomain, - Secure: a.cfg.SessionCookieSecure, - HttpOnly: a.cfg.SessionCookieHTTPOnly, + Secure: *a.cfg.SessionCookieSecure, + HttpOnly: *a.cfg.SessionCookieHTTPOnly, SameSite: a.cfg.SessionCookieSameSite, Expires: expires, } - if plaintext == "" { - c.MaxAge = -1 - } - return c } -// resolveRolesAndPermissions fetches the user's role names and the union of -// their permission names. Both are returned as flat string slices for cheap -// containment checks on the Principal. -func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) { - roles, err := a.deps.Roles.GetUserRoles(ctx, userID) - if err != nil { - return nil, nil, err +// ClearSessionCookie returns a cookie that, when set on the response, tells +// the browser to delete the session cookie. Use on logout. +func (a *Auth) ClearSessionCookie() *http.Cookie { + return &http.Cookie{ + Name: a.cfg.SessionCookieName, + Value: "", + Path: a.cfg.SessionCookiePath, + Domain: a.cfg.SessionCookieDomain, + Secure: *a.cfg.SessionCookieSecure, + HttpOnly: *a.cfg.SessionCookieHTTPOnly, + SameSite: a.cfg.SessionCookieSameSite, + MaxAge: -1, } - perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID) - if err != nil { - return nil, nil, err - } - rNames := make([]string, len(roles)) - for i, r := range roles { - rNames[i] = r.Name - } - pNames := make([]string, len(perms)) - for i, p := range perms { - pNames[i] = p.Name - } - return rNames, pNames, nil +} + +// SessionCookieName returns the configured cookie name. Useful for callers +// wiring extractors without reaching into Config. +func (a *Auth) SessionCookieName() string { return a.cfg.SessionCookieName } + +// resolveRolesAndPermissions fetches the user's role and permission slugs. +func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) { + roles, err := a.storeGetUserRoles(ctx, userID) + if err != nil { + return nil, nil, err + } + perms, err := a.storeGetUserPermissions(ctx, userID) + if err != nil { + return nil, nil, err + } + rSlugs := make([]string, len(roles)) + for i, r := range roles { + rSlugs[i] = r.Slug + } + pSlugs := make([]string, len(perms)) + for i, p := range perms { + pSlugs[i] = p.Slug + } + return rSlugs, pSlugs, nil } diff --git a/service_session_test.go b/service_session_test.go new file mode 100644 index 0000000..fc72bfa --- /dev/null +++ b/service_session_test.go @@ -0,0 +1,81 @@ +package authkit + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestIntegration_SessionLifecycle(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "s@s.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + plain, sess, err := a.IssueSession(ctx, u.ID, "ua", noIP()) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + if sess.ExpiresAt.Before(time.Now()) { + t.Fatalf("session already expired at issue") + } + p, err := a.AuthenticateSession(ctx, plain) + if err != nil { + t.Fatalf("AuthenticateSession: %v", err) + } + if p.UserID != u.ID { + t.Fatalf("principal user id mismatch") + } + if p.Method != AuthMethodSession { + t.Fatalf("method = %s, want session", p.Method) + } + if err := a.RevokeSession(ctx, plain); err != nil { + t.Fatalf("RevokeSession: %v", err) + } + if _, err := a.AuthenticateSession(ctx, plain); !errors.Is(err, ErrSessionInvalid) { + t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err) + } +} + +func TestIntegration_SessionCookieDefaultsSecure(t *testing.T) { + a := freshAuth(t) + c := a.SessionCookie("plaintext", time.Now().Add(time.Hour)) + if !c.Secure { + t.Fatalf("Secure should default to true") + } + if !c.HttpOnly { + t.Fatalf("HttpOnly should default to true") + } + clear := a.ClearSessionCookie() + if clear.MaxAge != -1 || clear.Value != "" { + t.Fatalf("ClearSessionCookie should be MaxAge=-1 and Value=\"\"") + } +} + +func TestIntegration_RevokeAllSessions(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "ra@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + plain, _, err := a.IssueSession(ctx, u.ID, "ua", noIP()) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + access, _, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + if err := a.RevokeAllUserSessions(ctx, u.ID); err != nil { + t.Fatalf("RevokeAllUserSessions: %v", err) + } + if _, err := a.AuthenticateSession(ctx, plain); !errors.Is(err, ErrSessionInvalid) { + t.Fatalf("session should be revoked, got %v", err) + } + if _, err := a.AuthenticateJWT(ctx, access); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("JWT should be invalidated by session_version bump, got %v", err) + } +} diff --git a/service_test.go b/service_test.go deleted file mode 100644 index 3e13cda..0000000 --- a/service_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package authkit - -import ( - "context" - "errors" - "net/netip" - "testing" -) - -func TestRegisterAndLogin(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "Alice@Example.com", "hunter2hunter2") - if err != nil { - t.Fatalf("Register: %v", err) - } - if u.EmailNormalized != "alice@example.com" { - t.Fatalf("email_normalized = %q", u.EmailNormalized) - } - got, err := a.LoginPassword(context.Background(), "alice@example.com", "hunter2hunter2") - if err != nil { - t.Fatalf("LoginPassword: %v", err) - } - if got.ID != u.ID { - t.Fatalf("login user mismatch") - } -} - -func TestRegisterDuplicateEmail(t *testing.T) { - a := newTestAuth(t) - if _, err := a.Register(context.Background(), "x@y.com", "abc"); err != nil { - t.Fatalf("Register: %v", err) - } - _, err := a.Register(context.Background(), "X@Y.COM", "abc") - if !errors.Is(err, ErrEmailTaken) { - t.Fatalf("expected ErrEmailTaken, got %v", err) - } -} - -func TestLoginWrongPassword(t *testing.T) { - a := newTestAuth(t) - if _, err := a.Register(context.Background(), "p@q.com", "right"); err != nil { - t.Fatalf("Register: %v", err) - } - if _, err := a.LoginPassword(context.Background(), "p@q.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) { - t.Fatalf("expected ErrInvalidCredentials, got %v", err) - } -} - -func TestSessionIssueAuthenticateRevoke(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "s@s.com", "pw") - if err != nil { - t.Fatalf("Register: %v", err) - } - plain, sess, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{}) - if err != nil { - t.Fatalf("IssueSession: %v", err) - } - if sess == nil || plain == "" { - t.Fatalf("missing session or plaintext") - } - - p, err := a.AuthenticateSession(context.Background(), plain) - if err != nil { - t.Fatalf("AuthenticateSession: %v", err) - } - if p.UserID != u.ID { - t.Fatalf("principal user id mismatch") - } - if err := a.RevokeSession(context.Background(), plain); err != nil { - t.Fatalf("RevokeSession: %v", err) - } - if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) { - t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err) - } -} - -func TestEmailVerificationFlow(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "ev@e.com", "pw") - if err != nil { - t.Fatalf("Register: %v", err) - } - tok, err := a.RequestEmailVerification(context.Background(), u.ID) - if err != nil { - t.Fatalf("RequestEmailVerification: %v", err) - } - confirmed, err := a.ConfirmEmail(context.Background(), tok) - if err != nil { - t.Fatalf("ConfirmEmail: %v", err) - } - if confirmed.EmailVerifiedAt == nil { - t.Fatalf("email_verified_at not set") - } - // Re-using the token must fail. - if _, err := a.ConfirmEmail(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) { - t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err) - } -} - -func TestPasswordResetFlow(t *testing.T) { - a := newTestAuth(t) - u, err := a.Register(context.Background(), "r@r.com", "old") - if err != nil { - t.Fatalf("Register: %v", err) - } - // Issue a session that should be invalidated by the reset. - plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{}) - if err != nil { - t.Fatalf("IssueSession: %v", err) - } - tok, err := a.RequestPasswordReset(context.Background(), "r@r.com") - if err != nil { - t.Fatalf("RequestPasswordReset: %v", err) - } - if err := a.ConfirmPasswordReset(context.Background(), tok, "new"); err != nil { - t.Fatalf("ConfirmPasswordReset: %v", err) - } - if _, err := a.LoginPassword(context.Background(), "r@r.com", "old"); !errors.Is(err, ErrInvalidCredentials) { - t.Fatalf("old password should fail post-reset, got %v", err) - } - if _, err := a.LoginPassword(context.Background(), "r@r.com", "new"); err != nil { - t.Fatalf("new password should work post-reset, got %v", err) - } - if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) { - t.Fatalf("session should be invalidated by reset, got %v", err) - } -} - -func TestMagicLinkFlow(t *testing.T) { - a := newTestAuth(t) - if _, err := a.Register(context.Background(), "m@m.com", "pw"); err != nil { - t.Fatalf("Register: %v", err) - } - tok, err := a.RequestMagicLink(context.Background(), "m@m.com") - if err != nil { - t.Fatalf("RequestMagicLink: %v", err) - } - u, err := a.ConsumeMagicLink(context.Background(), tok) - if err != nil { - t.Fatalf("ConsumeMagicLink: %v", err) - } - if u.EmailVerifiedAt == nil { - t.Fatalf("magic link should imply email verification") - } - if _, err := a.ConsumeMagicLink(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) { - t.Fatalf("expected ErrTokenInvalid on magic link reuse, got %v", err) - } -} - -func TestRBACRolesAndPermissions(t *testing.T) { - ctx := context.Background() - a := newTestAuth(t) - u, err := a.Register(ctx, "rb@a.com", "pw") - if err != nil { - t.Fatalf("Register: %v", err) - } - - // Create role + permission, hook them up. - role := &Role{Name: "editor"} - if err := a.deps.Roles.CreateRole(ctx, role); err != nil { - t.Fatalf("CreateRole: %v", err) - } - perm := &Permission{Name: "posts:write"} - if err := a.deps.Permissions.CreatePermission(ctx, perm); err != nil { - t.Fatalf("CreatePermission: %v", err) - } - if err := a.deps.Permissions.AssignPermissionToRole(ctx, role.ID, perm.ID); err != nil { - t.Fatalf("AssignPermissionToRole: %v", err) - } - if err := a.AssignRole(ctx, u.ID, "editor"); err != nil { - t.Fatalf("AssignRole: %v", err) - } - ok, err := a.HasPermission(ctx, u.ID, "posts:write") - if err != nil || !ok { - t.Fatalf("HasPermission posts:write should be true, got %v %v", ok, err) - } - ok, err = a.HasRole(ctx, u.ID, "editor") - if err != nil || !ok { - t.Fatalf("HasRole editor should be true, got %v %v", ok, err) - } - if err := a.RemoveRole(ctx, u.ID, "editor"); err != nil { - t.Fatalf("RemoveRole: %v", err) - } - ok, _ = a.HasPermission(ctx, u.ID, "posts:write") - if ok { - t.Fatalf("HasPermission should be false after RemoveRole") - } -} diff --git a/service_user.go b/service_user.go index b888897..0b3cf45 100644 --- a/service_user.go +++ b/service_user.go @@ -3,53 +3,61 @@ package authkit import ( "context" "errors" - "strings" "git.juancwu.dev/juancwu/errx" "github.com/google/uuid" ) -// normalizeEmail produces the lookup form used by UserStore.GetUserByEmail -// and the email_normalized column. Trim + lowercase is intentional; we do -// not collapse Gmail-style "+" addressing or strip dots — that's a policy -// decision callers can layer on top. -func normalizeEmail(s string) string { - return strings.ToLower(strings.TrimSpace(s)) -} - -// Register creates a new user with an Argon2id-hashed password. Returns -// ErrEmailTaken if the normalized email is already registered. -func (a *Auth) Register(ctx context.Context, email, password string) (*User, error) { - const op = "authkit.Auth.Register" - if email == "" || password == "" { +// CreateUser registers a new account with the given email. Password is +// optional — accounts can be created without a credential and have one set +// later via SetPassword. Returns ErrEmailTaken if the normalized email is +// already registered. +func (a *Auth) CreateUser(ctx context.Context, email string) (*User, error) { + const op = "authkit.Auth.CreateUser" + if email == "" { return nil, errx.Wrap(op, ErrInvalidCredentials) } - hash, err := a.deps.Hasher.Hash(password) - if err != nil { - return nil, errx.Wrap(op, err) - } now := a.now() u := &User{ ID: uuid.New(), Email: email, EmailNormalized: normalizeEmail(email), - PasswordHash: hash, CreatedAt: now, UpdatedAt: now, } - if err := a.deps.Users.CreateUser(ctx, u); err != nil { + if err := a.storeCreateUser(ctx, u); err != nil { return nil, errx.Wrap(op, err) } return u, nil } +// SetPassword stores a password hash for the user. Use for the +// initial-credential flow and for administrative password changes. +// Bumping session_version is the caller's responsibility; SetPassword does +// not invalidate existing sessions on its own. ChangePassword is the +// safer wrapper for end-user-driven changes. +func (a *Auth) SetPassword(ctx context.Context, userID uuid.UUID, password string) error { + const op = "authkit.Auth.SetPassword" + if password == "" { + return errx.Wrap(op, ErrInvalidCredentials) + } + hash, err := a.hasher.Hash(password) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.storeSetPassword(ctx, userID, hash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + // LoginPassword verifies the password and returns the authenticated user. -// Failure increments failed_logins; success resets it and stamps last_login_at. -// LoginHook (if configured) is invoked with the success outcome — use this to -// hook in rate limiting or audit logging. +// Failure does not increment any counter — consumers wanting lockout should +// implement it via LoginHook (see README). Success resets nothing and stamps +// last_login_at. LoginHook is invoked with the success outcome. func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User, error) { const op = "authkit.Auth.LoginPassword" - u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email)) if err != nil { _ = a.fireLoginHook(ctx, email, false) if errors.Is(err, ErrUserNotFound) { @@ -58,32 +66,29 @@ func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User return nil, errx.Wrap(op, err) } if u.PasswordHash == "" { + // Password-less account (invite-only / magic-link-only). Treat the + // same as wrong password to avoid leaking account state. _ = a.fireLoginHook(ctx, email, false) return nil, errx.Wrap(op, ErrInvalidCredentials) } - ok, needsRehash, err := a.deps.Hasher.Verify(password, u.PasswordHash) + ok, needsRehash, err := a.hasher.Verify(password, u.PasswordHash) if err != nil { return nil, errx.Wrap(op, err) } if !ok { - _, _ = a.deps.Users.IncrementFailedLogins(ctx, u.ID) _ = a.fireLoginHook(ctx, email, false) return nil, errx.Wrap(op, ErrInvalidCredentials) } now := a.now() u.LastLoginAt = &now - u.FailedLogins = 0 - if err := a.deps.Users.ResetFailedLogins(ctx, u.ID); err != nil { - return nil, errx.Wrap(op, err) - } - if err := a.deps.Users.UpdateUser(ctx, u); err != nil { + if err := a.storeUpdateUser(ctx, u); err != nil { return nil, errx.Wrap(op, err) } if needsRehash { - if newHash, herr := a.deps.Hasher.Hash(password); herr == nil { - _ = a.deps.Users.SetPassword(ctx, u.ID, newHash) + if newHash, herr := a.hasher.Hash(password); herr == nil { + _ = a.storeSetPassword(ctx, u.ID, newHash) u.PasswordHash = newHash } } @@ -92,46 +97,76 @@ func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User } // ChangePassword verifies the current password, sets the new one, and bumps -// the user's session_version so all outstanding JWT access tokens are -// instantly invalidated. Outstanding opaque sessions are also revoked. +// the user's session_version (invalidating outstanding JWTs). Outstanding +// opaque sessions are also revoked. func (a *Auth) ChangePassword(ctx context.Context, userID uuid.UUID, oldPassword, newPassword string) error { const op = "authkit.Auth.ChangePassword" - u, err := a.deps.Users.GetUserByID(ctx, userID) + u, err := a.storeGetUserByID(ctx, userID) if err != nil { return errx.Wrap(op, err) } if u.PasswordHash == "" { return errx.Wrap(op, ErrInvalidCredentials) } - ok, _, err := a.deps.Hasher.Verify(oldPassword, u.PasswordHash) + ok, _, err := a.hasher.Verify(oldPassword, u.PasswordHash) if err != nil { return errx.Wrap(op, err) } if !ok { return errx.Wrap(op, ErrInvalidCredentials) } - newHash, err := a.deps.Hasher.Hash(newPassword) + newHash, err := a.hasher.Hash(newPassword) if err != nil { return errx.Wrap(op, err) } - if err := a.deps.Users.SetPassword(ctx, userID, newHash); err != nil { + if err := a.storeSetPassword(ctx, userID, newHash); err != nil { return errx.Wrap(op, err) } - if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil { + if _, err := a.storeBumpSessionVersion(ctx, userID); err != nil { return errx.Wrap(op, err) } - if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil { + if err := a.storeDeleteUserSessions(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// GetUser fetches the user by ID. Returns ErrUserNotFound if absent. +func (a *Auth) GetUser(ctx context.Context, userID uuid.UUID) (*User, error) { + const op = "authkit.Auth.GetUser" + u, err := a.storeGetUserByID(ctx, userID) + if err != nil { + return nil, errx.Wrap(op, err) + } + return u, nil +} + +// GetUserByEmail fetches the user by email (input is normalized internally). +func (a *Auth) GetUserByEmail(ctx context.Context, email string) (*User, error) { + const op = "authkit.Auth.GetUserByEmail" + u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + return nil, errx.Wrap(op, err) + } + return u, nil +} + +// DeleteUser removes the user. Cascades to sessions, tokens, role +// assignments, and direct permission grants via FK ON DELETE CASCADE. +func (a *Auth) DeleteUser(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.Auth.DeleteUser" + if err := a.storeDeleteUser(ctx, userID); err != nil { return errx.Wrap(op, err) } return nil } // RequestEmailVerification mints a single-use email-verify token for the -// user. Return the plaintext to the caller so they can put it in an email -// link; the lookup hash is stored in TokenStore. +// user. Return the plaintext to the caller for delivery; the lookup hash +// is what's stored. func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (string, error) { const op = "authkit.Auth.RequestEmailVerification" - plaintext, hash, err := mintSecret(prefixEmailVerify, a.cfg.Random) + plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixEmailVerify) if err != nil { return "", errx.Wrap(op, err) } @@ -143,36 +178,42 @@ func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) ( CreatedAt: now, ExpiresAt: now.Add(a.cfg.EmailVerifyTTL), } - if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + if err := a.storeCreateToken(ctx, t); err != nil { return "", errx.Wrap(op, err) } return plaintext, nil } -// ConfirmEmail consumes the verification token and marks the user's email -// verified. Returns ErrTokenInvalid if the token is missing/expired/used. +// ConfirmEmail consumes a verification token and marks the user's email +// verified. Returns ErrTokenInvalid for missing/expired/already-used tokens. func (a *Auth) ConfirmEmail(ctx context.Context, plaintextToken string) (*User, error) { const op = "authkit.Auth.ConfirmEmail" - hash, ok := parseSecret(prefixEmailVerify, plaintextToken) + hash, ok := ParseOpaqueSecret(prefixEmailVerify, plaintextToken) if !ok { return nil, errx.Wrap(op, ErrTokenInvalid) } now := a.now() - t, err := a.deps.Tokens.ConsumeToken(ctx, TokenEmailVerify, hash, now) + t, err := a.storeConsumeToken(ctx, TokenEmailVerify, hash, now) if err != nil { return nil, errx.Wrap(op, err) } - if err := a.deps.Users.SetEmailVerified(ctx, t.UserID, now); err != nil { + if err := a.storeSetEmailVerified(ctx, t.UserID, now); err != nil { return nil, errx.Wrap(op, err) } - return a.deps.Users.GetUserByID(ctx, t.UserID) + return a.storeGetUserByID(ctx, t.UserID) } -// fireLoginHook is a thin wrapper that suppresses panics from caller-supplied -// hooks; we never want a misbehaving telemetry hook to break login. -func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) error { +// fireLoginHook runs Config.LoginHook if configured. Returned errors are +// surfaced for the caller to log; they never break login. The hook is +// wrapped in recover() so a misbehaving hook can't take down the auth path. +func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) (err error) { if a.cfg.LoginHook == nil { return nil } + defer func() { + if r := recover(); r != nil { + err = errx.Newf("authkit.fireLoginHook", "login hook panicked: %v", r) + } + }() return a.cfg.LoginHook(ctx, email, success) } diff --git a/service_user_test.go b/service_user_test.go new file mode 100644 index 0000000..52ccb3a --- /dev/null +++ b/service_user_test.go @@ -0,0 +1,149 @@ +package authkit + +import ( + "context" + "errors" + "testing" +) + +func TestIntegration_CreateUserNoPasswordThenLoginFails(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "alice@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + if u.PasswordHash != "" { + t.Fatalf("password should be empty for fresh user") + } + // Login against the password-less user must fail with + // ErrInvalidCredentials, not leak account existence. + if _, err := a.LoginPassword(ctx, "alice@example.com", "anything"); !errors.Is(err, ErrInvalidCredentials) { + t.Fatalf("expected ErrInvalidCredentials, got %v", err) + } +} + +func TestIntegration_CreateUserSetPasswordThenLogin(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "bob@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + if err := a.SetPassword(ctx, u.ID, "hunter2hunter2"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + got, err := a.LoginPassword(ctx, "Bob@Example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("LoginPassword (case-insensitive email): %v", err) + } + if got.ID != u.ID { + t.Fatalf("user id mismatch") + } + if _, err := a.LoginPassword(ctx, "bob@example.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) { + t.Fatalf("expected ErrInvalidCredentials, got %v", err) + } +} + +func TestIntegration_CreateUserDuplicateEmail(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.CreateUser(ctx, "dup@example.com"); err != nil { + t.Fatalf("CreateUser: %v", err) + } + if _, err := a.CreateUser(ctx, "DUP@example.com"); !errors.Is(err, ErrEmailTaken) { + t.Fatalf("expected ErrEmailTaken on case-folded duplicate, got %v", err) + } +} + +func TestIntegration_EmailVerificationFlow(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "ev@e.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + tok, err := a.RequestEmailVerification(ctx, u.ID) + if err != nil { + t.Fatalf("RequestEmailVerification: %v", err) + } + confirmed, err := a.ConfirmEmail(ctx, tok) + if err != nil { + t.Fatalf("ConfirmEmail: %v", err) + } + if confirmed.EmailVerifiedAt == nil { + t.Fatalf("email_verified_at not set") + } + if _, err := a.ConfirmEmail(ctx, tok); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err) + } +} + +func TestIntegration_PasswordResetCascadesSessionInvalidation(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "r@r.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + if err := a.SetPassword(ctx, u.ID, "old-password"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + plain, _, err := a.IssueSession(ctx, u.ID, "ua", noIP()) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + tok, err := a.RequestPasswordReset(ctx, "r@r.com") + if err != nil { + t.Fatalf("RequestPasswordReset: %v", err) + } + if tok == "" { + t.Fatalf("expected token for known email") + } + if err := a.ConfirmPasswordReset(ctx, tok, "new-password"); err != nil { + t.Fatalf("ConfirmPasswordReset: %v", err) + } + if _, err := a.LoginPassword(ctx, "r@r.com", "old-password"); !errors.Is(err, ErrInvalidCredentials) { + t.Fatalf("old password should fail, got %v", err) + } + if _, err := a.LoginPassword(ctx, "r@r.com", "new-password"); err != nil { + t.Fatalf("new password should work: %v", err) + } + if _, err := a.AuthenticateSession(ctx, plain); !errors.Is(err, ErrSessionInvalid) { + t.Fatalf("session should be invalidated by reset: got %v", err) + } +} + +func TestIntegration_PasswordResetUnknownEmailIsSilent(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + tok, err := a.RequestPasswordReset(ctx, "nobody@example.com") + if err != nil { + t.Fatalf("expected silent success, got err %v", err) + } + if tok != "" { + t.Fatalf("expected empty token for unknown email, got %q", tok) + } +} + +func TestIntegration_ChangePasswordRevokesEverything(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "cp@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + if err := a.SetPassword(ctx, u.ID, "old-password"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + access, _, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + if err := a.ChangePassword(ctx, u.ID, "old-password", "new-password"); err != nil { + t.Fatalf("ChangePassword: %v", err) + } + if _, err := a.AuthenticateJWT(ctx, access); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("JWT should be invalidated by password change, got %v", err) + } +} diff --git a/slug.go b/slug.go new file mode 100644 index 0000000..4ee534b --- /dev/null +++ b/slug.go @@ -0,0 +1,32 @@ +package authkit + +import ( + "regexp" + + "git.juancwu.dev/juancwu/errx" +) + +// MaxSlugLength is the upper bound on slug length, in bytes. Slugs are ASCII +// so this also bounds character count. +const MaxSlugLength = 64 + +// slugRE matches the accepted slug shape: a lowercase ASCII letter followed +// by any number of lowercase letters, digits, or one of `_`, `:`, `-`. +// Common valid forms: "admin", "ads-manager", "ads_manager", "posts:write". +var slugRE = regexp.MustCompile(`^[a-z][a-z0-9_:-]*$`) + +// validateSlug returns nil when s is a syntactically valid slug. Strict +// validation, no transformation: the caller must pre-normalize before +// passing in. Wrapped with op for call-site context. +func validateSlug(op, s string) error { + if s == "" { + return errx.Wrap(op, ErrSlugInvalid) + } + if len(s) > MaxSlugLength { + return errx.Wrapf(op, ErrSlugInvalid, "slug exceeds %d bytes", MaxSlugLength) + } + if !slugRE.MatchString(s) { + return errx.Wrapf(op, ErrSlugInvalid, "slug %q does not match %s", s, slugRE.String()) + } + return nil +} diff --git a/slug_test.go b/slug_test.go new file mode 100644 index 0000000..11d517e --- /dev/null +++ b/slug_test.go @@ -0,0 +1,87 @@ +package authkit + +import ( + "errors" + "strings" + "testing" +) + +func TestValidateSlug(t *testing.T) { + cases := []struct { + name string + slug string + wantErr bool + }{ + {"plain lowercase", "admin", false}, + {"snake", "ads_manager", false}, + {"kebab", "ads-manager", false}, + {"colon namespaced", "posts:write", false}, + {"with digits", "v2:posts", false}, + {"single char start", "a", false}, + + {"empty", "", true}, + {"uppercase", "Admin", true}, + {"starts with digit", "1admin", true}, + {"starts with underscore", "_admin", true}, + {"starts with hyphen", "-admin", true}, + {"contains space", "ads manager", true}, + {"contains slash", "posts/write", true}, + {"contains dot", "posts.write", true}, + {"only digits", "12345", true}, + {"too long", strings.Repeat("a", MaxSlugLength+1), true}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := validateSlug("test", c.slug) + if c.wantErr { + if err == nil { + t.Fatalf("expected ErrSlugInvalid for %q, got nil", c.slug) + } + if !errors.Is(err, ErrSlugInvalid) { + t.Fatalf("expected error wrapping ErrSlugInvalid, got %v", err) + } + return + } + if err != nil { + t.Fatalf("expected %q valid, got %v", c.slug, err) + } + }) + } +} + +func TestValidateSlugAtMaxLength(t *testing.T) { + s := strings.Repeat("a", MaxSlugLength) + if err := validateSlug("test", s); err != nil { + t.Fatalf("MaxSlugLength=%d should pass, got %v", MaxSlugLength, err) + } +} + +func FuzzValidateSlug(f *testing.F) { + f.Add("admin") + f.Add("posts:write") + f.Add("ads-manager") + f.Add("Admin") + f.Add("") + f.Add("a/b") + f.Add(strings.Repeat("a", 65)) + f.Fuzz(func(t *testing.T, s string) { + err := validateSlug("fuzz", s) + if err == nil { + // Sanity-check the validation invariants. Anything that passes + // must be ASCII, length <= MaxSlugLength, lowercase-start. + if len(s) == 0 || len(s) > MaxSlugLength { + t.Fatalf("slug %q passed but length %d violates bounds", s, len(s)) + } + c := s[0] + if !(c >= 'a' && c <= 'z') { + t.Fatalf("slug %q passed but starts with %q", s, c) + } + for _, r := range s { + if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && + r != '_' && r != ':' && r != '-' { + t.Fatalf("slug %q passed but contains %q", s, r) + } + } + } + }) +} diff --git a/sqlstore/dialect.go b/sqlstore/dialect.go deleted file mode 100644 index 79b1eae..0000000 --- a/sqlstore/dialect.go +++ /dev/null @@ -1,117 +0,0 @@ -package sqlstore - -import ( - "context" - "database/sql" - "io/fs" -) - -// Dialect describes how a particular SQL backend renders the queries authkit -// runs at runtime, bootstraps its session for migrations, and reports -// driver-specific errors. v1 ships the Postgres dialect; future MySQL or -// SQLite implementations satisfy the same interface without changes to the -// store code. -type Dialect interface { - // Name returns a stable short identifier ("postgres", "mysql", ...). - Name() string - - // BuildQueries renders every templated SQL string from a validated - // Schema. Called once at New() and at Migrate() so identifiers and - // placeholder styles are baked in by the time stores execute queries. - BuildQueries(s Schema) Queries - - // Bootstrap runs once per Migrate() before the migration lock is taken. - // It is the place for things that must happen outside any transaction - // (CREATE EXTENSION on Postgres, for example). Must be idempotent and - // may be a no-op. - Bootstrap(ctx context.Context, db *sql.DB) error - - // AcquireMigrationLock takes a session-scoped lock on conn so concurrent - // migrators serialise. The returned release function never returns an - // error — implementations may log internally. - AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error) - - // Migrations returns the dialect's embedded .sql files, rooted such - // that fs.ReadDir(".") lists them lex-sorted by filename. - Migrations() fs.FS - - // IsUniqueViolation maps a duplicate-key error from this driver to true - // so insert paths can return the matching authkit sentinel without - // depending on driver internals. - IsUniqueViolation(err error) bool - - // Placeholder returns the placeholder for parameter index n (1-based) - // in this dialect — "$1" for Postgres, "?" for MySQL/SQLite. - Placeholder(n int) string - - // PlaceholderList returns a comma-separated placeholder list for `count` - // parameters starting at position `start` (1-based), suitable for - // dynamic IN-clauses. The second return is the same length, prefilled - // with nil so the caller can append actual values. - PlaceholderList(start, count int) string -} - -// Queries is the full set of statement templates authkit issues. Field names -// match the store method that consumes the query. -type Queries struct { - // users - CreateUser string - GetUserByID string - GetUserByEmail string - UpdateUser string - DeleteUser string - SetPassword string - SetEmailVerified string - BumpSessionVersion string - IncrementFailedLogins string - ResetFailedLogins string - - // sessions - CreateSession string - GetSession string - TouchSession string - DeleteSession string - DeleteUserSessions string - DeleteExpiredSessions string - - // tokens - CreateToken string - ConsumeToken string - GetToken string - DeleteByChain string - DeleteExpiredTokens string - - // service keys - CreateServiceKey string - GetServiceKey string - ListServiceKeysByOwner string - TouchServiceKey string - RevokeServiceKey string - - // roles - CreateRole string - GetRoleByID string - GetRoleByName string - ListRoles string - DeleteRole string - AssignRoleToUser string - RemoveRoleFromUser string - GetUserRoles string - // HasAnyRole is built at call time because the placeholder count varies. - - // permissions - CreatePermission string - GetPermissionByID string - GetPermissionByName string - ListPermissions string - DeletePermission string - AssignPermissionToRole string - RemovePermissionFromRole string - GetRolePermissions string - GetUserPermissions string - - // migrations - CreateMigrationsTable string - SelectAppliedVersions string - InsertAppliedVersion string -} diff --git a/sqlstore/dialect/postgres/errors.go b/sqlstore/dialect/postgres/errors.go deleted file mode 100644 index d4e93ad..0000000 --- a/sqlstore/dialect/postgres/errors.go +++ /dev/null @@ -1,33 +0,0 @@ -package postgres - -import ( - "errors" - - "github.com/jackc/pgx/v5/pgconn" -) - -// pgUniqueViolation is the SQLSTATE for unique_violation. Both pgx-stdlib -// and lib/pq surface this code, but only pgx-stdlib uses *pgconn.PgError. -// lib/pq uses *pq.Error which has a Code field of the same value. -const pgUniqueViolation = "23505" - -// isUniqueViolation inspects err for a Postgres unique-violation, regardless -// of which driver registered the connection. We match on either the pgx -// error type or any error implementing a Code() string method (lib/pq's -// pq.Error has SQLState and Code fields; we check via reflection-free -// duck-typing through an interface). -func isUniqueViolation(err error) bool { - if err == nil { - return false - } - var pgxErr *pgconn.PgError - if errors.As(err, &pgxErr) { - return pgxErr.Code == pgUniqueViolation - } - type sqlStater interface{ SQLState() string } - var s sqlStater - if errors.As(err, &s) { - return s.SQLState() == pgUniqueViolation - } - return false -} diff --git a/sqlstore/dialect/postgres/migrations/0001_init.sql b/sqlstore/dialect/postgres/migrations/0001_init.sql deleted file mode 100644 index e8f751d..0000000 --- a/sqlstore/dialect/postgres/migrations/0001_init.sql +++ /dev/null @@ -1,99 +0,0 @@ --- 0001_init.sql --- Initial authkit schema for Postgres. Tables are prefixed authkit_ so the --- library can be embedded in an existing application database. Each --- migration owns its own transaction and inserts its version row at the --- bottom; the runner only orchestrates file discovery and concurrency. - -BEGIN; - -CREATE TABLE IF NOT EXISTS authkit_schema_migrations ( - version TEXT PRIMARY KEY, - applied_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS authkit_users ( - id UUID PRIMARY KEY, - email TEXT NOT NULL, - email_normalized TEXT NOT NULL, - email_verified_at TIMESTAMPTZ, - password_hash TEXT, - session_version INTEGER NOT NULL DEFAULT 0, - failed_logins INTEGER NOT NULL DEFAULT 0, - last_login_at TIMESTAMPTZ, - created_at TIMESTAMPTZ NOT NULL, - updated_at TIMESTAMPTZ NOT NULL -); -CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq - ON authkit_users (email_normalized); - -CREATE TABLE IF NOT EXISTS authkit_sessions ( - id_hash BYTEA PRIMARY KEY, - user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, - user_agent TEXT NOT NULL DEFAULT '', - ip TEXT, - created_at TIMESTAMPTZ NOT NULL, - last_seen_at TIMESTAMPTZ NOT NULL, - expires_at TIMESTAMPTZ NOT NULL -); -CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id); -CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at); - -CREATE TABLE IF NOT EXISTS authkit_tokens ( - hash BYTEA NOT NULL, - kind TEXT NOT NULL, - user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, - chain_id TEXT, - consumed_at TIMESTAMPTZ, - created_at TIMESTAMPTZ NOT NULL, - expires_at TIMESTAMPTZ NOT NULL, - PRIMARY KEY (kind, hash) -); -CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id); -CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at); -CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx - ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL; - -CREATE TABLE IF NOT EXISTS authkit_api_keys ( - id_hash BYTEA PRIMARY KEY, - owner_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, - name TEXT NOT NULL, - abilities JSONB NOT NULL DEFAULT '[]'::jsonb, - last_used_at TIMESTAMPTZ, - created_at TIMESTAMPTZ NOT NULL, - expires_at TIMESTAMPTZ, - revoked_at TIMESTAMPTZ -); -CREATE INDEX IF NOT EXISTS authkit_api_keys_owner_id_idx ON authkit_api_keys(owner_id); - -CREATE TABLE IF NOT EXISTS authkit_roles ( - id UUID PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - description TEXT NOT NULL DEFAULT '', - created_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS authkit_permissions ( - id UUID PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - description TEXT NOT NULL DEFAULT '', - created_at TIMESTAMPTZ NOT NULL -); - -CREATE TABLE IF NOT EXISTS authkit_role_permissions ( - role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, - permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE, - PRIMARY KEY (role_id, permission_id) -); - -CREATE TABLE IF NOT EXISTS authkit_user_roles ( - user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, - role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, - granted_at TIMESTAMPTZ NOT NULL, - PRIMARY KEY (user_id, role_id) -); -CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id); - -INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now()) -ON CONFLICT (version) DO NOTHING; - -COMMIT; diff --git a/sqlstore/dialect/postgres/migrations/0002_service_keys.sql b/sqlstore/dialect/postgres/migrations/0002_service_keys.sql deleted file mode 100644 index 5c11b19..0000000 --- a/sqlstore/dialect/postgres/migrations/0002_service_keys.sql +++ /dev/null @@ -1,27 +0,0 @@ --- 0002_service_keys.sql --- Adds owner-agnostic service tokens. Unlike authkit_api_keys, owner_id is --- intentionally NOT FK-constrained: consumers manage their own cascades, and --- authkit has no opinion on what "owner" means here (application id, tenant --- id, etc.). - -BEGIN; - -CREATE TABLE IF NOT EXISTS authkit_service_keys ( - id_hash BYTEA PRIMARY KEY, - owner_id UUID NOT NULL, - owner_kind TEXT NOT NULL, - name TEXT NOT NULL, - abilities JSONB NOT NULL DEFAULT '[]'::jsonb, - last_used_at TIMESTAMPTZ, - created_at TIMESTAMPTZ NOT NULL, - expires_at TIMESTAMPTZ, - revoked_at TIMESTAMPTZ -); - -CREATE INDEX IF NOT EXISTS authkit_service_keys_owner_idx - ON authkit_service_keys(owner_kind, owner_id); - -INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0002_service_keys', now()) -ON CONFLICT (version) DO NOTHING; - -COMMIT; diff --git a/sqlstore/dialect/postgres/migrations/0003_drop_api_keys.sql b/sqlstore/dialect/postgres/migrations/0003_drop_api_keys.sql deleted file mode 100644 index d67fcb4..0000000 --- a/sqlstore/dialect/postgres/migrations/0003_drop_api_keys.sql +++ /dev/null @@ -1,13 +0,0 @@ --- 0003_drop_api_keys.sql --- Drops the user-owned API key table. After this migration only service --- tokens carry abilities; user-owned credentials (sessions, JWTs, --- magic-links) prove identity, with permissions resolved via RBAC. - -BEGIN; - -DROP TABLE IF EXISTS authkit_api_keys CASCADE; - -INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0003_drop_api_keys', now()) -ON CONFLICT (version) DO NOTHING; - -COMMIT; diff --git a/sqlstore/dialect/postgres/postgres.go b/sqlstore/dialect/postgres/postgres.go deleted file mode 100644 index 8b5f615..0000000 --- a/sqlstore/dialect/postgres/postgres.go +++ /dev/null @@ -1,270 +0,0 @@ -// Package postgres is the Postgres dialect for authkit/sqlstore. Importing -// it does not register a driver — callers do `_ "github.com/jackc/pgx/v5/stdlib"` -// or `_ "github.com/lib/pq"` themselves and then `sql.Open(...)`. -package postgres - -import ( - "context" - "database/sql" - "embed" - "fmt" - "io/fs" - "log" - "strings" - - "git.juancwu.dev/juancwu/authkit/sqlstore" -) - -//go:embed migrations/*.sql -var migrationsFS embed.FS - -// Dialect implements sqlstore.Dialect for Postgres. The zero value is the -// only required form; New() returns a pointer for clarity at call sites. -type Dialect struct{} - -// New returns a Postgres dialect ready to pass to sqlstore.New / Migrate. -func New() *Dialect { return &Dialect{} } - -func (Dialect) Name() string { return "postgres" } - -// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. Stable -// across rollouts and unlikely to clash with caller advisory locks. -const advisoryLockKey int64 = 0x617574686b6974 - -func (Dialect) Bootstrap(ctx context.Context, db *sql.DB) error { - // Nothing to do — schema avoids gen_random_uuid()/pgcrypto. Kept as a - // hook for future migration prerequisites. - return nil -} - -func (Dialect) AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (func(), error) { - if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil { - return nil, err - } - release := func() { - if _, err := conn.ExecContext(context.Background(), - "SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil { - log.Printf("authkit/postgres: pg_advisory_unlock failed: %v", err) - } - } - return release, nil -} - -func (Dialect) Migrations() fs.FS { - sub, err := fs.Sub(migrationsFS, "migrations") - if err != nil { - // migrationsFS is statically populated; this can only fail if the - // embed directive is removed, which would be a build-time error. - panic(err) - } - return sub -} - -func (Dialect) IsUniqueViolation(err error) bool { return isUniqueViolation(err) } - -func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) } - -// PlaceholderList renders a comma-separated `$start,$start+1,...` list of -// `count` placeholders. Used by HasAnyRole's dynamic IN-clause expansion. -func (Dialect) PlaceholderList(start, count int) string { - if count <= 0 { - return "" - } - var b strings.Builder - for i := 0; i < count; i++ { - if i > 0 { - b.WriteByte(',') - } - fmt.Fprintf(&b, "$%d", start+i) - } - return b.String() -} - -// BuildQueries renders every query authkit issues, with table identifiers -// taken from s and `?` placeholders rewritten to `$N`. Identifiers are -// already validated by Schema.Validate — this is interpolated with -// fmt.Sprintf, so the validation gate is load-bearing. -func (Dialect) BuildQueries(s sqlstore.Schema) sqlstore.Queries { - t := s.Tables - q := sqlstore.Queries{ - // users - CreateUser: `INSERT INTO ` + t.Users + ` - (id, email, email_normalized, email_verified_at, password_hash, - session_version, failed_logins, last_login_at, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - GetUserByID: `SELECT id, email, email_normalized, email_verified_at, - password_hash, session_version, failed_logins, last_login_at, - created_at, updated_at FROM ` + t.Users + ` WHERE id = ?`, - GetUserByEmail: `SELECT id, email, email_normalized, email_verified_at, - password_hash, session_version, failed_logins, last_login_at, - created_at, updated_at FROM ` + t.Users + ` WHERE email_normalized = ?`, - UpdateUser: `UPDATE ` + t.Users + ` SET - email = ?, email_normalized = ?, email_verified_at = ?, - password_hash = ?, session_version = ?, failed_logins = ?, - last_login_at = ?, updated_at = ? - WHERE id = ?`, - DeleteUser: `DELETE FROM ` + t.Users + ` WHERE id = ?`, - SetPassword: `UPDATE ` + t.Users + ` SET password_hash = ?, updated_at = ? WHERE id = ?`, - SetEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = ?, updated_at = ? WHERE id = ?`, - BumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = ? WHERE id = ? RETURNING session_version`, - IncrementFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = failed_logins + 1, updated_at = ? WHERE id = ? RETURNING failed_logins`, - ResetFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = 0, updated_at = ? WHERE id = ?`, - - // sessions - CreateSession: `INSERT INTO ` + t.Sessions + ` - (id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - GetSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at - FROM ` + t.Sessions + ` WHERE id_hash = ?`, - TouchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = ?, expires_at = ? WHERE id_hash = ?`, - DeleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = ?`, - DeleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = ?`, - DeleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= ?`, - - // tokens - CreateToken: `INSERT INTO ` + t.Tokens + ` - (hash, kind, user_id, chain_id, consumed_at, created_at, expires_at) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - ConsumeToken: `UPDATE ` + t.Tokens + ` - SET consumed_at = ? - WHERE kind = ? AND hash = ? AND consumed_at IS NULL AND expires_at > ? - RETURNING hash, kind, user_id, chain_id, consumed_at, created_at, expires_at`, - GetToken: `SELECT hash, kind, user_id, chain_id, consumed_at, created_at, expires_at - FROM ` + t.Tokens + ` WHERE kind = ? AND hash = ?`, - DeleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = ?`, - DeleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= ?`, - - // service keys - CreateServiceKey: `INSERT INTO ` + t.ServiceKeys + ` - (id_hash, owner_id, owner_kind, name, abilities, last_used_at, created_at, expires_at, revoked_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, - GetServiceKey: `SELECT id_hash, owner_id, owner_kind, name, abilities, last_used_at, - created_at, expires_at, revoked_at - FROM ` + t.ServiceKeys + ` WHERE id_hash = ?`, - ListServiceKeysByOwner: `SELECT id_hash, owner_id, owner_kind, name, abilities, last_used_at, - created_at, expires_at, revoked_at - FROM ` + t.ServiceKeys + ` WHERE owner_kind = ? AND owner_id = ? ORDER BY created_at DESC`, - TouchServiceKey: `UPDATE ` + t.ServiceKeys + ` SET last_used_at = ? WHERE id_hash = ?`, - RevokeServiceKey: `UPDATE ` + t.ServiceKeys + ` SET revoked_at = ? WHERE id_hash = ? AND revoked_at IS NULL`, - - // roles - CreateRole: `INSERT INTO ` + t.Roles + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`, - GetRoleByID: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE id = ?`, - GetRoleByName: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE name = ?`, - ListRoles: `SELECT id, name, description, created_at FROM ` + t.Roles + ` ORDER BY name`, - DeleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = ?`, - AssignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`, - RemoveRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = ? AND role_id = ?`, - GetUserRoles: `SELECT r.id, r.name, r.description, r.created_at - FROM ` + t.Roles + ` r - JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id - WHERE ur.user_id = ? ORDER BY r.name`, - - // permissions - CreatePermission: `INSERT INTO ` + t.Permissions + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`, - GetPermissionByID: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE id = ?`, - GetPermissionByName: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE name = ?`, - ListPermissions: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` ORDER BY name`, - DeletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = ?`, - AssignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) VALUES (?, ?) - ON CONFLICT DO NOTHING`, - RemovePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = ? AND permission_id = ?`, - GetRolePermissions: `SELECT p.id, p.name, p.description, p.created_at - FROM ` + t.Permissions + ` p - JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id - WHERE rp.role_id = ? ORDER BY p.name`, - GetUserPermissions: `SELECT DISTINCT p.id, p.name, p.description, p.created_at - FROM ` + t.Permissions + ` p - JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id - JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id - WHERE ur.user_id = ? ORDER BY p.name`, - - // migrations - CreateMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` ( - version TEXT PRIMARY KEY, - applied_at TIMESTAMPTZ NOT NULL - )`, - SelectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations, - InsertAppliedVersion: `INSERT INTO ` + t.SchemaMigrations + ` (version, applied_at) VALUES (?, ?)`, - } - - // Rewrite `?` placeholders to `$N`. Each query is independent; numbering - // resets per query. - q.CreateUser = rebind(q.CreateUser) - q.GetUserByID = rebind(q.GetUserByID) - q.GetUserByEmail = rebind(q.GetUserByEmail) - q.UpdateUser = rebind(q.UpdateUser) - q.DeleteUser = rebind(q.DeleteUser) - q.SetPassword = rebind(q.SetPassword) - q.SetEmailVerified = rebind(q.SetEmailVerified) - q.BumpSessionVersion = rebind(q.BumpSessionVersion) - q.IncrementFailedLogins = rebind(q.IncrementFailedLogins) - q.ResetFailedLogins = rebind(q.ResetFailedLogins) - - q.CreateSession = rebind(q.CreateSession) - q.GetSession = rebind(q.GetSession) - q.TouchSession = rebind(q.TouchSession) - q.DeleteSession = rebind(q.DeleteSession) - q.DeleteUserSessions = rebind(q.DeleteUserSessions) - q.DeleteExpiredSessions = rebind(q.DeleteExpiredSessions) - - q.CreateToken = rebind(q.CreateToken) - q.ConsumeToken = rebind(q.ConsumeToken) - q.GetToken = rebind(q.GetToken) - q.DeleteByChain = rebind(q.DeleteByChain) - q.DeleteExpiredTokens = rebind(q.DeleteExpiredTokens) - - q.CreateServiceKey = rebind(q.CreateServiceKey) - q.GetServiceKey = rebind(q.GetServiceKey) - q.ListServiceKeysByOwner = rebind(q.ListServiceKeysByOwner) - q.TouchServiceKey = rebind(q.TouchServiceKey) - q.RevokeServiceKey = rebind(q.RevokeServiceKey) - - q.CreateRole = rebind(q.CreateRole) - q.GetRoleByID = rebind(q.GetRoleByID) - q.GetRoleByName = rebind(q.GetRoleByName) - q.ListRoles = rebind(q.ListRoles) - q.DeleteRole = rebind(q.DeleteRole) - q.AssignRoleToUser = rebind(q.AssignRoleToUser) - q.RemoveRoleFromUser = rebind(q.RemoveRoleFromUser) - q.GetUserRoles = rebind(q.GetUserRoles) - - q.CreatePermission = rebind(q.CreatePermission) - q.GetPermissionByID = rebind(q.GetPermissionByID) - q.GetPermissionByName = rebind(q.GetPermissionByName) - q.ListPermissions = rebind(q.ListPermissions) - q.DeletePermission = rebind(q.DeletePermission) - q.AssignPermissionToRole = rebind(q.AssignPermissionToRole) - q.RemovePermissionFromRole = rebind(q.RemovePermissionFromRole) - q.GetRolePermissions = rebind(q.GetRolePermissions) - q.GetUserPermissions = rebind(q.GetUserPermissions) - - q.SelectAppliedVersions = rebind(q.SelectAppliedVersions) - q.InsertAppliedVersion = rebind(q.InsertAppliedVersion) - // CreateMigrationsTable contains no parameters. - - return q -} - -// rebind walks s and replaces each unquoted `?` with $1, $2, ... in order. -// Our query strings contain no string literals that include `?` (verified -// by inspection of every query in BuildQueries); a literal-aware rewriter -// would be more defensive but is not needed for v1. -func rebind(s string) string { - var b strings.Builder - b.Grow(len(s) + 16) - n := 1 - for i := 0; i < len(s); i++ { - c := s[i] - if c == '?' { - fmt.Fprintf(&b, "$%d", n) - n++ - continue - } - b.WriteByte(c) - } - return b.String() -} - -// Compile-time interface compliance check. -var _ sqlstore.Dialect = (*Dialect)(nil) diff --git a/sqlstore/migrate.go b/sqlstore/migrate.go deleted file mode 100644 index 48906a0..0000000 --- a/sqlstore/migrate.go +++ /dev/null @@ -1,116 +0,0 @@ -package sqlstore - -import ( - "context" - "database/sql" - "io/fs" - "sort" - "strings" - "time" - - "git.juancwu.dev/juancwu/errx" -) - -// Migrate applies every embedded migration the dialect ships that has not -// yet been recorded in the schema-migrations table. It is safe to call -// repeatedly and concurrently across processes — the dialect's session -// lock serialises rollouts. -// -// Each migration .sql file is responsible for owning its own -// BEGIN/COMMIT and inserting a row into the schema-migrations table on -// success. The runner only handles file discovery, version tracking, and -// concurrency. -func Migrate(ctx context.Context, db *sql.DB, dialect Dialect, schema Schema) error { - const op = "authkit.sqlstore.Migrate" - if db == nil { - return errx.New(op, "db is required") - } - if dialect == nil { - return errx.New(op, "dialect is required") - } - if err := schema.Validate(); err != nil { - return errx.Wrap(op, err) - } - - if err := dialect.Bootstrap(ctx, db); err != nil { - return errx.Wrap(op, err) - } - - conn, err := db.Conn(ctx) - if err != nil { - return errx.Wrap(op, err) - } - defer conn.Close() - - release, err := dialect.AcquireMigrationLock(ctx, conn) - if err != nil { - return errx.Wrap(op, err) - } - defer release() - - q := dialect.BuildQueries(schema) - if _, err := conn.ExecContext(ctx, q.CreateMigrationsTable); err != nil { - return errx.Wrap(op, err) - } - - applied, err := loadAppliedVersions(ctx, conn, q.SelectAppliedVersions) - if err != nil { - return errx.Wrap(op, err) - } - - migs := dialect.Migrations() - files, err := fs.ReadDir(migs, ".") - if err != nil { - return errx.Wrap(op, err) - } - names := make([]string, 0, len(files)) - for _, f := range files { - if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") { - names = append(names, f.Name()) - } - } - sort.Strings(names) - - for _, name := range names { - version := strings.TrimSuffix(name, ".sql") - if _, ok := applied[version]; ok { - continue - } - body, err := fs.ReadFile(migs, name) - if err != nil { - return errx.Wrapf(op, err, "read %s", name) - } - if _, err := conn.ExecContext(ctx, string(body)); err != nil { - return errx.Wrapf(op, err, "apply %s", version) - } - } - return nil -} - -// applyVersionRow is intentionally not exposed: migration files own their -// own version-row insert. We keep this helper around in case a dialect ever -// needs to backfill versions from outside a migration body — currently -// unused. -var _ = applyVersionRow - -func applyVersionRow(ctx context.Context, conn *sql.Conn, insertQ, version string, at time.Time) error { - _, err := conn.ExecContext(ctx, insertQ, version, at) - return err -} - -func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) { - rows, err := conn.QueryContext(ctx, q) - if err != nil { - return nil, err - } - defer rows.Close() - out := make(map[string]struct{}) - for rows.Next() { - var v string - if err := rows.Scan(&v); err != nil { - return nil, err - } - out[v] = struct{}{} - } - return out, rows.Err() -} diff --git a/sqlstore/rbac.go b/sqlstore/rbac.go deleted file mode 100644 index 767440e..0000000 --- a/sqlstore/rbac.go +++ /dev/null @@ -1,301 +0,0 @@ -package sqlstore - -import ( - "context" - "fmt" - "time" - - "git.juancwu.dev/juancwu/authkit" - "git.juancwu.dev/juancwu/errx" - "github.com/google/uuid" -) - -type roleStore struct{ storeBase } -type permissionStore struct{ storeBase } - -// ----- roleStore ------------------------------------------------------------ - -func (s *roleStore) CreateRole(ctx context.Context, r *authkit.Role) error { - const op = "authkit.sqlstore.RoleStore.CreateRole" - if r.ID == uuid.Nil { - r.ID = uuid.New() - } - if r.CreatedAt.IsZero() { - r.CreatedAt = time.Now().UTC() - } - if _, err := s.db.ExecContext(ctx, s.q.CreateRole, - uuidArg(r.ID), r.Name, r.Description, r.CreatedAt); err != nil { - if s.d.IsUniqueViolation(err) { - return errx.Wrapf(op, err, "role %q already exists", r.Name) - } - return errx.Wrap(op, err) - } - return nil -} - -func (s *roleStore) GetRoleByID(ctx context.Context, id uuid.UUID) (*authkit.Role, error) { - const op = "authkit.sqlstore.RoleStore.GetRoleByID" - r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByID, uuidArg(id))) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound)) - } - return r, nil -} - -func (s *roleStore) GetRoleByName(ctx context.Context, name string) (*authkit.Role, error) { - const op = "authkit.sqlstore.RoleStore.GetRoleByName" - r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByName, name)) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound)) - } - return r, nil -} - -func (s *roleStore) ListRoles(ctx context.Context) ([]*authkit.Role, error) { - const op = "authkit.sqlstore.RoleStore.ListRoles" - rows, err := s.db.QueryContext(ctx, s.q.ListRoles) - if err != nil { - return nil, errx.Wrap(op, err) - } - defer rows.Close() - var out []*authkit.Role - for rows.Next() { - r, err := scanRole(rows) - if err != nil { - return nil, errx.Wrap(op, err) - } - out = append(out, r) - } - return out, errx.Wrap(op, rows.Err()) -} - -func (s *roleStore) DeleteRole(ctx context.Context, id uuid.UUID) error { - const op = "authkit.sqlstore.RoleStore.DeleteRole" - tag, err := s.db.ExecContext(ctx, s.q.DeleteRole, uuidArg(id)) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrRoleNotFound) - } - return nil -} - -func (s *roleStore) AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error { - const op = "authkit.sqlstore.RoleStore.AssignRoleToUser" - if _, err := s.db.ExecContext(ctx, s.q.AssignRoleToUser, - uuidArg(userID), uuidArg(roleID), time.Now().UTC()); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *roleStore) RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error { - const op = "authkit.sqlstore.RoleStore.RemoveRoleFromUser" - if _, err := s.db.ExecContext(ctx, s.q.RemoveRoleFromUser, - uuidArg(userID), uuidArg(roleID)); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *roleStore) GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*authkit.Role, error) { - const op = "authkit.sqlstore.RoleStore.GetUserRoles" - rows, err := s.db.QueryContext(ctx, s.q.GetUserRoles, uuidArg(userID)) - if err != nil { - return nil, errx.Wrap(op, err) - } - defer rows.Close() - var out []*authkit.Role - for rows.Next() { - r, err := scanRole(rows) - if err != nil { - return nil, errx.Wrap(op, err) - } - out = append(out, r) - } - return out, errx.Wrap(op, rows.Err()) -} - -// HasAnyRole builds the IN-clause at call time because the placeholder count -// depends on len(names). Identifier substitution comes from the validated -// Schema; values are bound, never interpolated. -func (s *roleStore) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) { - const op = "authkit.sqlstore.RoleStore.HasAnyRole" - if len(names) == 0 { - return false, nil - } - // Placeholder $1 is the user_id; $2..$N+1 cover the names slice. - listSQL := s.d.PlaceholderList(2, len(names)) - q := fmt.Sprintf(`SELECT EXISTS ( - SELECT 1 FROM %s ur JOIN %s r ON r.id = ur.role_id - WHERE ur.user_id = %s AND r.name IN (%s) - )`, s.s.Tables.UserRoles, s.s.Tables.Roles, s.d.Placeholder(1), listSQL) - - args := make([]any, 0, 1+len(names)) - args = append(args, uuidArg(userID)) - for _, n := range names { - args = append(args, n) - } - var ok bool - if err := s.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil { - return false, errx.Wrap(op, err) - } - return ok, nil -} - -// ----- permissionStore ------------------------------------------------------ - -func (s *permissionStore) CreatePermission(ctx context.Context, p *authkit.Permission) error { - const op = "authkit.sqlstore.PermissionStore.CreatePermission" - if p.ID == uuid.Nil { - p.ID = uuid.New() - } - if p.CreatedAt.IsZero() { - p.CreatedAt = time.Now().UTC() - } - if _, err := s.db.ExecContext(ctx, s.q.CreatePermission, - uuidArg(p.ID), p.Name, p.Description, p.CreatedAt); err != nil { - if s.d.IsUniqueViolation(err) { - return errx.Wrapf(op, err, "permission %q already exists", p.Name) - } - return errx.Wrap(op, err) - } - return nil -} - -func (s *permissionStore) GetPermissionByID(ctx context.Context, id uuid.UUID) (*authkit.Permission, error) { - const op = "authkit.sqlstore.PermissionStore.GetPermissionByID" - p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByID, uuidArg(id))) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound)) - } - return p, nil -} - -func (s *permissionStore) GetPermissionByName(ctx context.Context, name string) (*authkit.Permission, error) { - const op = "authkit.sqlstore.PermissionStore.GetPermissionByName" - p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByName, name)) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound)) - } - return p, nil -} - -func (s *permissionStore) ListPermissions(ctx context.Context) ([]*authkit.Permission, error) { - const op = "authkit.sqlstore.PermissionStore.ListPermissions" - rows, err := s.db.QueryContext(ctx, s.q.ListPermissions) - if err != nil { - return nil, errx.Wrap(op, err) - } - defer rows.Close() - var out []*authkit.Permission - for rows.Next() { - p, err := scanPermission(rows) - if err != nil { - return nil, errx.Wrap(op, err) - } - out = append(out, p) - } - return out, errx.Wrap(op, rows.Err()) -} - -func (s *permissionStore) DeletePermission(ctx context.Context, id uuid.UUID) error { - const op = "authkit.sqlstore.PermissionStore.DeletePermission" - tag, err := s.db.ExecContext(ctx, s.q.DeletePermission, uuidArg(id)) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrPermissionNotFound) - } - return nil -} - -func (s *permissionStore) AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error { - const op = "authkit.sqlstore.PermissionStore.AssignPermissionToRole" - if _, err := s.db.ExecContext(ctx, s.q.AssignPermissionToRole, - uuidArg(roleID), uuidArg(permID)); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *permissionStore) RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error { - const op = "authkit.sqlstore.PermissionStore.RemovePermissionFromRole" - if _, err := s.db.ExecContext(ctx, s.q.RemovePermissionFromRole, - uuidArg(roleID), uuidArg(permID)); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *permissionStore) GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*authkit.Permission, error) { - const op = "authkit.sqlstore.PermissionStore.GetRolePermissions" - rows, err := s.db.QueryContext(ctx, s.q.GetRolePermissions, uuidArg(roleID)) - if err != nil { - return nil, errx.Wrap(op, err) - } - defer rows.Close() - var out []*authkit.Permission - for rows.Next() { - p, err := scanPermission(rows) - if err != nil { - return nil, errx.Wrap(op, err) - } - out = append(out, p) - } - return out, errx.Wrap(op, rows.Err()) -} - -func (s *permissionStore) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*authkit.Permission, error) { - const op = "authkit.sqlstore.PermissionStore.GetUserPermissions" - rows, err := s.db.QueryContext(ctx, s.q.GetUserPermissions, uuidArg(userID)) - if err != nil { - return nil, errx.Wrap(op, err) - } - defer rows.Close() - var out []*authkit.Permission - for rows.Next() { - p, err := scanPermission(rows) - if err != nil { - return nil, errx.Wrap(op, err) - } - out = append(out, p) - } - return out, errx.Wrap(op, rows.Err()) -} - -func scanRole(row rowScanner) (*authkit.Role, error) { - var ( - r authkit.Role - idStr string - ) - if err := row.Scan(&idStr, &r.Name, &r.Description, &r.CreatedAt); err != nil { - return nil, err - } - id, err := scanUUID(idStr) - if err != nil { - return nil, err - } - r.ID = id - return &r, nil -} - -func scanPermission(row rowScanner) (*authkit.Permission, error) { - var ( - p authkit.Permission - idStr string - ) - if err := row.Scan(&idStr, &p.Name, &p.Description, &p.CreatedAt); err != nil { - return nil, err - } - id, err := scanUUID(idStr) - if err != nil { - return nil, err - } - p.ID = id - return &p, nil -} diff --git a/sqlstore/scan.go b/sqlstore/scan.go deleted file mode 100644 index a962535..0000000 --- a/sqlstore/scan.go +++ /dev/null @@ -1,86 +0,0 @@ -package sqlstore - -import ( - "database/sql" - "net/netip" - "time" - - "github.com/google/uuid" -) - -// rowScanner is the lowest-common-denominator interface satisfied by both -// *sql.Row and *sql.Rows so scanXxx helpers serve QueryRow and Query loops -// uniformly. -type rowScanner interface { - Scan(dest ...any) error -} - -// nullableTime returns nil when t is the zero time, otherwise &t. Callers -// should usually accept *time.Time on the model side and bind via this. -func nullableTime(t *time.Time) any { - if t == nil { - return nil - } - return *t -} - -// nullableString turns "" into nil so columns store NULL rather than an -// empty string. Used for password hashes and similar optional text columns. -func nullableString(s string) any { - if s == "" { - return nil - } - return s -} - -// nullableAddrString returns the string form of a netip.Addr when valid, or -// nil to bind as SQL NULL. Pairs with scanAddr. -func nullableAddrString(a netip.Addr) any { - if !a.IsValid() { - return nil - } - return a.String() -} - -// scanAddr parses a *string column into a netip.Addr. The zero Addr value is -// returned when the column was NULL or empty. -func scanAddr(s *string) (netip.Addr, error) { - if s == nil || *s == "" { - return netip.Addr{}, nil - } - return netip.ParseAddr(*s) -} - -// uuidArg returns the canonical string form of a UUID for binding. Every -// store uses this rather than passing uuid.UUID directly to keep behaviour -// identical across drivers (some accept driver.Valuer, some don't). -func uuidArg(id uuid.UUID) any { return id.String() } - -// scanUUID reads a string column and parses it back to a uuid.UUID. -func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) } - -// chainArg returns either a *string or nil for binding the chain_id column. -func chainArg(c *string) any { - if c == nil { - return nil - } - return *c -} - -// scanNullStringPtr converts sql.NullString to *string for the model. -func scanNullStringPtr(ns sql.NullString) *string { - if !ns.Valid { - return nil - } - v := ns.String - return &v -} - -// scanNullTimePtr converts sql.NullTime to *time.Time for the model. -func scanNullTimePtr(nt sql.NullTime) *time.Time { - if !nt.Valid { - return nil - } - t := nt.Time - return &t -} diff --git a/sqlstore/schema.go b/sqlstore/schema.go deleted file mode 100644 index 82f4182..0000000 --- a/sqlstore/schema.go +++ /dev/null @@ -1,78 +0,0 @@ -package sqlstore - -import ( - "regexp" - - "git.juancwu.dev/juancwu/errx" -) - -// Schema lets consumers map authkit storage to their own table names. -// Column overrides are intentionally not present in v1 — adding them later -// is purely additive. -type Schema struct { - Tables Tables -} - -// Tables is the per-table identifier override set. Every field must be a -// valid unquoted SQL identifier (matching identifierRE). Validation runs at -// New() and Migrate() time so SQL injection through Schema is impossible -// past that gate. -type Tables struct { - Users string - Sessions string - Tokens string - ServiceKeys string - Roles string - Permissions string - UserRoles string - RolePermissions string - SchemaMigrations string -} - -// DefaultSchema returns the stock authkit_* names used by the embedded -// migration files. -func DefaultSchema() Schema { - return Schema{Tables: Tables{ - Users: "authkit_users", - Sessions: "authkit_sessions", - Tokens: "authkit_tokens", - ServiceKeys: "authkit_service_keys", - Roles: "authkit_roles", - Permissions: "authkit_permissions", - UserRoles: "authkit_user_roles", - RolePermissions: "authkit_role_permissions", - SchemaMigrations: "authkit_schema_migrations", - }} -} - -// identifierRE matches the safe ASCII identifier subset shared by Postgres, -// MySQL and SQLite when not quoted. Anything outside this set is rejected -// rather than escaped — Schema is not the place to support exotic names. -var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) - -// Validate ensures every Schema.Tables field is a non-empty, safe identifier. -func (s Schema) Validate() error { - const op = "authkit.sqlstore.Schema.Validate" - checks := []struct { - field, value string - }{ - {"Users", s.Tables.Users}, - {"Sessions", s.Tables.Sessions}, - {"Tokens", s.Tables.Tokens}, - {"ServiceKeys", s.Tables.ServiceKeys}, - {"Roles", s.Tables.Roles}, - {"Permissions", s.Tables.Permissions}, - {"UserRoles", s.Tables.UserRoles}, - {"RolePermissions", s.Tables.RolePermissions}, - {"SchemaMigrations", s.Tables.SchemaMigrations}, - } - for _, c := range checks { - if c.value == "" { - return errx.Newf(op, "Schema.Tables.%s is empty", c.field) - } - if !identifierRE.MatchString(c.value) { - return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value) - } - } - return nil -} diff --git a/sqlstore/service_keys.go b/sqlstore/service_keys.go deleted file mode 100644 index 0c52fde..0000000 --- a/sqlstore/service_keys.go +++ /dev/null @@ -1,116 +0,0 @@ -package sqlstore - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "git.juancwu.dev/juancwu/authkit" - "git.juancwu.dev/juancwu/errx" - "github.com/google/uuid" -) - -type serviceKeyStore struct{ storeBase } - -func (s *serviceKeyStore) CreateServiceKey(ctx context.Context, k *authkit.ServiceKey) error { - const op = "authkit.sqlstore.ServiceKeyStore.CreateServiceKey" - if k.CreatedAt.IsZero() { - k.CreatedAt = time.Now().UTC() - } - if k.Abilities == nil { - k.Abilities = []string{} - } - abilities, err := json.Marshal(k.Abilities) - if err != nil { - return errx.Wrap(op, err) - } - _, err = s.db.ExecContext(ctx, s.q.CreateServiceKey, - k.IDHash, uuidArg(k.OwnerID), k.OwnerKind, k.Name, abilities, - nullableTime(k.LastUsedAt), k.CreatedAt, - nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt)) - if err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *serviceKeyStore) GetServiceKey(ctx context.Context, idHash []byte) (*authkit.ServiceKey, error) { - const op = "authkit.sqlstore.ServiceKeyStore.GetServiceKey" - k, err := scanServiceKey(s.db.QueryRowContext(ctx, s.q.GetServiceKey, idHash)) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrServiceKeyInvalid)) - } - return k, nil -} - -func (s *serviceKeyStore) ListServiceKeysByOwner(ctx context.Context, ownerKind string, ownerID uuid.UUID) ([]*authkit.ServiceKey, error) { - const op = "authkit.sqlstore.ServiceKeyStore.ListServiceKeysByOwner" - rows, err := s.db.QueryContext(ctx, s.q.ListServiceKeysByOwner, ownerKind, uuidArg(ownerID)) - if err != nil { - return nil, errx.Wrap(op, err) - } - defer rows.Close() - var out []*authkit.ServiceKey - for rows.Next() { - k, err := scanServiceKey(rows) - if err != nil { - return nil, errx.Wrap(op, err) - } - out = append(out, k) - } - return out, errx.Wrap(op, rows.Err()) -} - -func (s *serviceKeyStore) TouchServiceKey(ctx context.Context, idHash []byte, at time.Time) error { - const op = "authkit.sqlstore.ServiceKeyStore.TouchServiceKey" - if _, err := s.db.ExecContext(ctx, s.q.TouchServiceKey, at, idHash); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *serviceKeyStore) RevokeServiceKey(ctx context.Context, idHash []byte, at time.Time) error { - const op = "authkit.sqlstore.ServiceKeyStore.RevokeServiceKey" - tag, err := s.db.ExecContext(ctx, s.q.RevokeServiceKey, at, idHash) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrServiceKeyInvalid) - } - return nil -} - -func scanServiceKey(row rowScanner) (*authkit.ServiceKey, error) { - var ( - k authkit.ServiceKey - ownerIDStr string - abilitiesRaw []byte - lastUsed sql.NullTime - expires sql.NullTime - revoked sql.NullTime - ) - if err := row.Scan(&k.IDHash, &ownerIDStr, &k.OwnerKind, &k.Name, &abilitiesRaw, - &lastUsed, &k.CreatedAt, &expires, &revoked); err != nil { - return nil, err - } - owner, err := scanUUID(ownerIDStr) - if err != nil { - return nil, err - } - k.OwnerID = owner - if len(abilitiesRaw) > 0 { - if err := json.Unmarshal(abilitiesRaw, &k.Abilities); err != nil { - return nil, err - } - } - if k.Abilities == nil { - k.Abilities = []string{} - } - k.LastUsedAt = scanNullTimePtr(lastUsed) - k.ExpiresAt = scanNullTimePtr(expires) - k.RevokedAt = scanNullTimePtr(revoked) - return &k, nil -} diff --git a/sqlstore/sessions.go b/sqlstore/sessions.go deleted file mode 100644 index 8241bb5..0000000 --- a/sqlstore/sessions.go +++ /dev/null @@ -1,98 +0,0 @@ -package sqlstore - -import ( - "context" - "database/sql" - "time" - - "git.juancwu.dev/juancwu/authkit" - "git.juancwu.dev/juancwu/errx" - "github.com/google/uuid" -) - -type sessionStore struct{ storeBase } - -func (s *sessionStore) CreateSession(ctx context.Context, ses *authkit.Session) error { - const op = "authkit.sqlstore.SessionStore.CreateSession" - now := time.Now().UTC() - if ses.CreatedAt.IsZero() { - ses.CreatedAt = now - } - if ses.LastSeenAt.IsZero() { - ses.LastSeenAt = ses.CreatedAt - } - _, err := s.db.ExecContext(ctx, s.q.CreateSession, - ses.IDHash, uuidArg(ses.UserID), ses.UserAgent, nullableAddrString(ses.IP), - ses.CreatedAt, ses.LastSeenAt, ses.ExpiresAt) - if err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *sessionStore) GetSession(ctx context.Context, idHash []byte) (*authkit.Session, error) { - const op = "authkit.sqlstore.SessionStore.GetSession" - var ( - ses authkit.Session - uidStr string - ipStr sql.NullString - ) - err := s.db.QueryRowContext(ctx, s.q.GetSession, idHash).Scan( - &ses.IDHash, &uidStr, &ses.UserAgent, &ipStr, - &ses.CreatedAt, &ses.LastSeenAt, &ses.ExpiresAt) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrSessionInvalid)) - } - uid, err := scanUUID(uidStr) - if err != nil { - return nil, errx.Wrap(op, err) - } - ses.UserID = uid - if ipStr.Valid { - addr, err := scanAddr(&ipStr.String) - if err != nil { - return nil, errx.Wrap(op, err) - } - ses.IP = addr - } - return &ses, nil -} - -func (s *sessionStore) TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error { - const op = "authkit.sqlstore.SessionStore.TouchSession" - tag, err := s.db.ExecContext(ctx, s.q.TouchSession, lastSeenAt, newExpiresAt, idHash) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrSessionInvalid) - } - return nil -} - -func (s *sessionStore) DeleteSession(ctx context.Context, idHash []byte) error { - const op = "authkit.sqlstore.SessionStore.DeleteSession" - if _, err := s.db.ExecContext(ctx, s.q.DeleteSession, idHash); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *sessionStore) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error { - const op = "authkit.sqlstore.SessionStore.DeleteUserSessions" - if _, err := s.db.ExecContext(ctx, s.q.DeleteUserSessions, uuidArg(userID)); err != nil { - return errx.Wrap(op, err) - } - return nil -} - -func (s *sessionStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) { - const op = "authkit.sqlstore.SessionStore.DeleteExpired" - tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredSessions, now) - if err != nil { - return 0, errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - return n, nil -} diff --git a/sqlstore/sqlstore.go b/sqlstore/sqlstore.go deleted file mode 100644 index d20ab08..0000000 --- a/sqlstore/sqlstore.go +++ /dev/null @@ -1,59 +0,0 @@ -// Package sqlstore provides database/sql-backed implementations of every -// authkit store interface. It works with any driver (pgx-stdlib, lib/pq, -// sqlx wrapping *sql.DB, ...) and any consumer-supplied table naming via -// Schema. Driver-specific behaviour lives in a Dialect; v1 ships -// dialect/postgres. -package sqlstore - -import ( - "database/sql" - - "git.juancwu.dev/juancwu/authkit" - "git.juancwu.dev/juancwu/errx" -) - -// Stores bundles every store implementation produced by New, ready to drop -// into authkit.Deps. -type Stores struct { - Users authkit.UserStore - Sessions authkit.SessionStore - Tokens authkit.TokenStore - ServiceKeys authkit.ServiceKeyStore - Roles authkit.RoleStore - Permissions authkit.PermissionStore -} - -// New validates the Schema, builds the dialect's templated queries, and -// returns store implementations bound to db. -func New(db *sql.DB, dialect Dialect, schema Schema) (*Stores, error) { - const op = "authkit.sqlstore.New" - if db == nil { - return nil, errx.New(op, "db is required") - } - if dialect == nil { - return nil, errx.New(op, "dialect is required") - } - if err := schema.Validate(); err != nil { - return nil, errx.Wrap(op, err) - } - q := dialect.BuildQueries(schema) - - base := storeBase{db: db, q: q, d: dialect, s: schema} - return &Stores{ - Users: &userStore{storeBase: base}, - Sessions: &sessionStore{storeBase: base}, - Tokens: &tokenStore{storeBase: base}, - ServiceKeys: &serviceKeyStore{storeBase: base}, - Roles: &roleStore{storeBase: base}, - Permissions: &permissionStore{storeBase: base}, - }, nil -} - -// storeBase carries the shared dependencies every store needs. Embedded into -// each concrete store struct. -type storeBase struct { - db *sql.DB - q Queries - d Dialect - s Schema -} diff --git a/sqlstore/sqlstore_test.go b/sqlstore/sqlstore_test.go deleted file mode 100644 index 1bc721c..0000000 --- a/sqlstore/sqlstore_test.go +++ /dev/null @@ -1,284 +0,0 @@ -package sqlstore_test - -// Integration tests against a real Postgres. Skipped unless -// AUTHKIT_TEST_DATABASE_URL is set. Each test acquires a fresh schema by -// running Migrate against a randomly-named set of tables — no cleanup -// fixture, no external Docker dependency. - -import ( - "context" - "database/sql" - "errors" - "fmt" - "net/netip" - "os" - "testing" - "time" - - "git.juancwu.dev/juancwu/authkit" - "git.juancwu.dev/juancwu/authkit/hasher" - "git.juancwu.dev/juancwu/authkit/sqlstore" - pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres" - - "github.com/google/uuid" - _ "github.com/jackc/pgx/v5/stdlib" -) - -func envURL(t *testing.T) string { - t.Helper() - url := os.Getenv("AUTHKIT_TEST_DATABASE_URL") - if url == "" { - t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test") - } - return url -} - -// freshDB opens a connection, runs Migrate against the default schema, and -// schedules a teardown that drops every authkit_* table. Tests must run -// sequentially against a single database (the package's tests do, by -// default — go test serialises within a package unless t.Parallel is -// called). -func freshDB(t *testing.T) (*authkit.Auth, *sql.DB, sqlstore.Schema) { - t.Helper() - url := envURL(t) - db, err := sql.Open("pgx", url) - if err != nil { - t.Fatalf("sql.Open: %v", err) - } - t.Cleanup(func() { _ = db.Close() }) - if err := db.PingContext(context.Background()); err != nil { - t.Fatalf("ping: %v", err) - } - - schema := sqlstore.DefaultSchema() - // Drop first to start clean — previous failed tests may have left rows. - dropAuthkitTables(t, db, schema) - if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil { - t.Fatalf("Migrate: %v", err) - } - t.Cleanup(func() { dropAuthkitTables(t, db, schema) }) - - stores, err := sqlstore.New(db, pgdialect.New(), schema) - if err != nil { - t.Fatalf("sqlstore.New: %v", err) - } - auth := authkit.New(authkit.Deps{ - Users: stores.Users, - Sessions: stores.Sessions, - Tokens: stores.Tokens, - ServiceKeys: stores.ServiceKeys, - Roles: stores.Roles, - Permissions: stores.Permissions, - Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), - }, authkit.Config{ - JWTSecret: []byte("integration-secret-thirty-two!!!!"), - JWTIssuer: "authkit-int", - AccessTokenTTL: 2 * time.Minute, - RefreshTokenTTL: 1 * time.Hour, - SessionIdleTTL: time.Hour, - SessionAbsoluteTTL: 24 * time.Hour, - EmailVerifyTTL: time.Hour, - PasswordResetTTL: time.Hour, - MagicLinkTTL: time.Minute, - }) - return auth, db, schema -} - -func dropAuthkitTables(t *testing.T, db *sql.DB, s sqlstore.Schema) { - t.Helper() - tables := []string{ - s.Tables.UserRoles, s.Tables.RolePermissions, - s.Tables.Roles, s.Tables.Permissions, - s.Tables.ServiceKeys, s.Tables.Tokens, - s.Tables.Sessions, s.Tables.Users, - s.Tables.SchemaMigrations, - } - for _, name := range tables { - _, _ = db.ExecContext(context.Background(), - fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name)) - } -} - -func TestIntegration_MigrateIdempotent(t *testing.T) { - url := envURL(t) - db, err := sql.Open("pgx", url) - if err != nil { - t.Fatalf("sql.Open: %v", err) - } - t.Cleanup(func() { _ = db.Close() }) - - schema := sqlstore.DefaultSchema() - t.Cleanup(func() { dropAuthkitTables(t, db, schema) }) - for i := 0; i < 3; i++ { - if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil { - t.Fatalf("Migrate iter %d: %v", i, err) - } - } -} - -func TestIntegration_RegisterAndLogin(t *testing.T) { - auth, _, _ := freshDB(t) - ctx := context.Background() - u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2") - if err != nil { - t.Fatalf("Register: %v", err) - } - got, err := auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2") - if err != nil { - t.Fatalf("LoginPassword: %v", err) - } - if got.ID != u.ID { - t.Fatalf("login user id mismatch") - } - - if _, err := auth.Register(ctx, "alice@example.com", "x"); !errors.Is(err, authkit.ErrEmailTaken) { - t.Fatalf("expected ErrEmailTaken, got %v", err) - } - if _, err := auth.LoginPassword(ctx, "alice@example.com", "wrong"); !errors.Is(err, authkit.ErrInvalidCredentials) { - t.Fatalf("expected ErrInvalidCredentials, got %v", err) - } -} - -func TestIntegration_SessionLifecycle(t *testing.T) { - auth, _, _ := freshDB(t) - ctx := context.Background() - u, err := auth.Register(ctx, "s@s.com", "pw") - if err != nil { - t.Fatalf("Register: %v", err) - } - plain, sess, err := auth.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1")) - if err != nil { - t.Fatalf("IssueSession: %v", err) - } - if sess.ExpiresAt.Before(time.Now()) { - t.Fatalf("session already expired at issue") - } - if _, err := auth.AuthenticateSession(ctx, plain); err != nil { - t.Fatalf("AuthenticateSession: %v", err) - } - if err := auth.RevokeSession(ctx, plain); err != nil { - t.Fatalf("RevokeSession: %v", err) - } - if _, err := auth.AuthenticateSession(ctx, plain); !errors.Is(err, authkit.ErrSessionInvalid) { - t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err) - } -} - -func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) { - auth, _, _ := freshDB(t) - ctx := context.Background() - u, err := auth.Register(ctx, "j@j.com", "pw") - if err != nil { - t.Fatalf("Register: %v", err) - } - _, refresh1, err := auth.IssueJWT(ctx, u.ID) - if err != nil { - t.Fatalf("IssueJWT: %v", err) - } - _, refresh2, err := auth.RefreshJWT(ctx, refresh1) - if err != nil { - t.Fatalf("first RefreshJWT: %v", err) - } - if refresh1 == refresh2 { - t.Fatalf("refresh token did not rotate") - } - - if _, _, err := auth.RefreshJWT(ctx, refresh1); !errors.Is(err, authkit.ErrTokenReused) { - t.Fatalf("expected ErrTokenReused on replay, got %v", err) - } - if _, _, err := auth.RefreshJWT(ctx, refresh2); !errors.Is(err, authkit.ErrTokenInvalid) { - t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err) - } -} - -func TestIntegration_ServiceKeyFlow(t *testing.T) { - auth, _, _ := freshDB(t) - ctx := context.Background() - appA := uuid.New() - appB := uuid.New() - - plainA1, _, err := auth.IssueServiceKey(ctx, "application", appA, "events-ingest", - []string{"events:write"}, nil) - if err != nil { - t.Fatalf("IssueServiceKey appA #1: %v", err) - } - if _, _, err := auth.IssueServiceKey(ctx, "application", appA, "events-ingest-2", nil, nil); err != nil { - t.Fatalf("IssueServiceKey appA #2: %v", err) - } - if _, _, err := auth.IssueServiceKey(ctx, "application", appB, "billing", nil, nil); err != nil { - t.Fatalf("IssueServiceKey appB: %v", err) - } - - got, err := auth.AuthenticateServiceKey(ctx, plainA1) - if err != nil { - t.Fatalf("AuthenticateServiceKey: %v", err) - } - if got.OwnerKind != "application" || got.OwnerID != appA { - t.Fatalf("owner mismatch: kind=%q id=%v", got.OwnerKind, got.OwnerID) - } - if len(got.Abilities) != 1 || got.Abilities[0] != "events:write" { - t.Fatalf("abilities mismatch: %+v", got.Abilities) - } - - listA, err := auth.ListServiceKeys(ctx, "application", appA) - if err != nil { - t.Fatalf("ListServiceKeys appA: %v", err) - } - if len(listA) != 2 { - t.Fatalf("ListServiceKeys appA = %d, want 2", len(listA)) - } - listB, err := auth.ListServiceKeys(ctx, "application", appB) - if err != nil { - t.Fatalf("ListServiceKeys appB: %v", err) - } - if len(listB) != 1 { - t.Fatalf("ListServiceKeys appB = %d, want 1", len(listB)) - } - - if err := auth.RevokeServiceKey(ctx, plainA1); err != nil { - t.Fatalf("RevokeServiceKey: %v", err) - } - if _, err := auth.AuthenticateServiceKey(ctx, plainA1); !errors.Is(err, authkit.ErrServiceKeyInvalid) { - t.Fatalf("expected ErrServiceKeyInvalid post-revoke, got %v", err) - } -} - -func TestIntegration_RBAC(t *testing.T) { - auth, db, schema := freshDB(t) - ctx := context.Background() - u, err := auth.Register(ctx, "rb@b.com", "pw") - if err != nil { - t.Fatalf("Register: %v", err) - } - - stores, _ := sqlstore.New(db, pgdialect.New(), schema) - r := &authkit.Role{Name: "editor"} - if err := stores.Roles.CreateRole(ctx, r); err != nil { - t.Fatalf("CreateRole: %v", err) - } - p := &authkit.Permission{Name: "posts:write"} - if err := stores.Permissions.CreatePermission(ctx, p); err != nil { - t.Fatalf("CreatePermission: %v", err) - } - if err := stores.Permissions.AssignPermissionToRole(ctx, r.ID, p.ID); err != nil { - t.Fatalf("AssignPermissionToRole: %v", err) - } - if err := auth.AssignRole(ctx, u.ID, "editor"); err != nil { - t.Fatalf("AssignRole: %v", err) - } - ok, err := auth.HasPermission(ctx, u.ID, "posts:write") - if err != nil || !ok { - t.Fatalf("HasPermission: %v %v", ok, err) - } - ok, err = auth.HasAnyRole(ctx, u.ID, []string{"editor", "admin"}) - if err != nil || !ok { - t.Fatalf("HasAnyRole: %v %v", ok, err) - } - if err := auth.RemoveRole(ctx, u.ID, "editor"); err != nil { - t.Fatalf("RemoveRole: %v", err) - } - ok, _ = auth.HasPermission(ctx, u.ID, "posts:write") - if ok { - t.Fatalf("HasPermission should be false after RemoveRole") - } -} diff --git a/sqlstore/tokens.go b/sqlstore/tokens.go deleted file mode 100644 index 99841fc..0000000 --- a/sqlstore/tokens.go +++ /dev/null @@ -1,92 +0,0 @@ -package sqlstore - -import ( - "context" - "database/sql" - "time" - - "git.juancwu.dev/juancwu/authkit" - "git.juancwu.dev/juancwu/errx" -) - -type tokenStore struct{ storeBase } - -func (s *tokenStore) CreateToken(ctx context.Context, t *authkit.Token) error { - const op = "authkit.sqlstore.TokenStore.CreateToken" - if t.CreatedAt.IsZero() { - t.CreatedAt = time.Now().UTC() - } - _, err := s.db.ExecContext(ctx, s.q.CreateToken, - t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID), - nullableTime(t.ConsumedAt), t.CreatedAt, t.ExpiresAt) - if err != nil { - return errx.Wrap(op, err) - } - return nil -} - -// ConsumeToken does the find-and-mark-consumed in a single statement so two -// concurrent callers cannot both successfully consume the same token. The -// row is returned for inspection (e.g. ChainID for refresh rotation). -func (s *tokenStore) ConsumeToken(ctx context.Context, kind authkit.TokenKind, hash []byte, now time.Time) (*authkit.Token, error) { - const op = "authkit.sqlstore.TokenStore.ConsumeToken" - row := s.db.QueryRowContext(ctx, s.q.ConsumeToken, now, string(kind), hash, now) - t, err := scanToken(row) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid)) - } - return t, nil -} - -func (s *tokenStore) GetToken(ctx context.Context, kind authkit.TokenKind, hash []byte) (*authkit.Token, error) { - const op = "authkit.sqlstore.TokenStore.GetToken" - row := s.db.QueryRowContext(ctx, s.q.GetToken, string(kind), hash) - t, err := scanToken(row) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid)) - } - return t, nil -} - -func (s *tokenStore) DeleteByChain(ctx context.Context, chainID string) (int64, error) { - const op = "authkit.sqlstore.TokenStore.DeleteByChain" - tag, err := s.db.ExecContext(ctx, s.q.DeleteByChain, chainID) - if err != nil { - return 0, errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - return n, nil -} - -func (s *tokenStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) { - const op = "authkit.sqlstore.TokenStore.DeleteExpired" - tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredTokens, now) - if err != nil { - return 0, errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - return n, nil -} - -func scanToken(row rowScanner) (*authkit.Token, error) { - var ( - t authkit.Token - kind string - userIDStr string - chainID sql.NullString - consumedAt sql.NullTime - ) - if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, - &consumedAt, &t.CreatedAt, &t.ExpiresAt); err != nil { - return nil, err - } - t.Kind = authkit.TokenKind(kind) - uid, err := scanUUID(userIDStr) - if err != nil { - return nil, err - } - t.UserID = uid - t.ChainID = scanNullStringPtr(chainID) - t.ConsumedAt = scanNullTimePtr(consumedAt) - return &t, nil -} diff --git a/sqlstore/users.go b/sqlstore/users.go deleted file mode 100644 index b03c379..0000000 --- a/sqlstore/users.go +++ /dev/null @@ -1,186 +0,0 @@ -package sqlstore - -import ( - "context" - "database/sql" - "errors" - "strings" - "time" - - "git.juancwu.dev/juancwu/authkit" - "git.juancwu.dev/juancwu/errx" - "github.com/google/uuid" -) - -type userStore struct{ storeBase } - -func (s *userStore) CreateUser(ctx context.Context, u *authkit.User) error { - const op = "authkit.sqlstore.UserStore.CreateUser" - if u.ID == uuid.Nil { - u.ID = uuid.New() - } - if u.EmailNormalized == "" { - u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email)) - } - now := time.Now().UTC() - if u.CreatedAt.IsZero() { - u.CreatedAt = now - } - if u.UpdatedAt.IsZero() { - u.UpdatedAt = now - } - _, err := s.db.ExecContext(ctx, s.q.CreateUser, - uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), - nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins, - nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt) - if err != nil { - if s.d.IsUniqueViolation(err) { - return errx.Wrap(op, authkit.ErrEmailTaken) - } - return errx.Wrap(op, err) - } - return nil -} - -func (s *userStore) GetUserByID(ctx context.Context, id uuid.UUID) (*authkit.User, error) { - const op = "authkit.sqlstore.UserStore.GetUserByID" - u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByID, uuidArg(id))) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) - } - return u, nil -} - -func (s *userStore) GetUserByEmail(ctx context.Context, normalizedEmail string) (*authkit.User, error) { - const op = "authkit.sqlstore.UserStore.GetUserByEmail" - u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByEmail, normalizedEmail)) - if err != nil { - return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) - } - return u, nil -} - -func (s *userStore) UpdateUser(ctx context.Context, u *authkit.User) error { - const op = "authkit.sqlstore.UserStore.UpdateUser" - u.UpdatedAt = time.Now().UTC() - tag, err := s.db.ExecContext(ctx, s.q.UpdateUser, - u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), - nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins, - nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID)) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrUserNotFound) - } - return nil -} - -func (s *userStore) DeleteUser(ctx context.Context, id uuid.UUID) error { - const op = "authkit.sqlstore.UserStore.DeleteUser" - tag, err := s.db.ExecContext(ctx, s.q.DeleteUser, uuidArg(id)) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrUserNotFound) - } - return nil -} - -func (s *userStore) SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error { - const op = "authkit.sqlstore.UserStore.SetPassword" - tag, err := s.db.ExecContext(ctx, s.q.SetPassword, - nullableString(encodedHash), time.Now().UTC(), uuidArg(userID)) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrUserNotFound) - } - return nil -} - -func (s *userStore) SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error { - const op = "authkit.sqlstore.UserStore.SetEmailVerified" - tag, err := s.db.ExecContext(ctx, s.q.SetEmailVerified, at, time.Now().UTC(), uuidArg(userID)) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrUserNotFound) - } - return nil -} - -func (s *userStore) BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) { - const op = "authkit.sqlstore.UserStore.BumpSessionVersion" - var v int - if err := s.db.QueryRowContext(ctx, s.q.BumpSessionVersion, - time.Now().UTC(), uuidArg(userID)).Scan(&v); err != nil { - return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) - } - return v, nil -} - -func (s *userStore) IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) { - const op = "authkit.sqlstore.UserStore.IncrementFailedLogins" - var n int - if err := s.db.QueryRowContext(ctx, s.q.IncrementFailedLogins, - time.Now().UTC(), uuidArg(userID)).Scan(&n); err != nil { - return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) - } - return n, nil -} - -func (s *userStore) ResetFailedLogins(ctx context.Context, userID uuid.UUID) error { - const op = "authkit.sqlstore.UserStore.ResetFailedLogins" - tag, err := s.db.ExecContext(ctx, s.q.ResetFailedLogins, time.Now().UTC(), uuidArg(userID)) - if err != nil { - return errx.Wrap(op, err) - } - n, _ := tag.RowsAffected() - if n == 0 { - return errx.Wrap(op, authkit.ErrUserNotFound) - } - return nil -} - -func scanUser(row rowScanner) (*authkit.User, error) { - var ( - u authkit.User - idStr string - passwordHash sql.NullString - emailVerified sql.NullTime - lastLogin sql.NullTime - ) - if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified, - &passwordHash, &u.SessionVersion, &u.FailedLogins, &lastLogin, - &u.CreatedAt, &u.UpdatedAt); err != nil { - return nil, err - } - id, err := scanUUID(idStr) - if err != nil { - return nil, err - } - u.ID = id - if passwordHash.Valid { - u.PasswordHash = passwordHash.String - } - u.EmailVerifiedAt = scanNullTimePtr(emailVerified) - u.LastLoginAt = scanNullTimePtr(lastLogin) - return &u, nil -} - -// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so -// callers get reliable errors.Is targets through errx wrapping. -func mapNotFound(err error, sentinel error) error { - if errors.Is(err, sql.ErrNoRows) { - return sentinel - } - return err -} diff --git a/store_abilities.go b/store_abilities.go new file mode 100644 index 0000000..91eaa1a --- /dev/null +++ b/store_abilities.go @@ -0,0 +1,94 @@ +package authkit + +import ( + "context" + "database/sql" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +func (a *Auth) storeCreateAbility(ctx context.Context, ab *Ability) error { + const op = "authkit.storeCreateAbility" + if ab.ID == uuid.Nil { + ab.ID = uuid.New() + } + if ab.CreatedAt.IsZero() { + ab.CreatedAt = a.now() + } + if _, err := a.db.ExecContext(ctx, a.q.createAbility, + uuidArg(ab.ID), ab.Slug, nullableLabel(ab.Label), ab.CreatedAt); err != nil { + if isUniqueViolation(err) { + return errx.Wrap(op, ErrSlugTaken) + } + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetAbilityByID(ctx context.Context, id uuid.UUID) (*Ability, error) { + const op = "authkit.storeGetAbilityByID" + ab, err := scanAbility(a.db.QueryRowContext(ctx, a.q.getAbilityByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrAbilityNotFound)) + } + return ab, nil +} + +func (a *Auth) storeGetAbilityBySlug(ctx context.Context, slug string) (*Ability, error) { + const op = "authkit.storeGetAbilityBySlug" + ab, err := scanAbility(a.db.QueryRowContext(ctx, a.q.getAbilityBySlug, slug)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrAbilityNotFound)) + } + return ab, nil +} + +func (a *Auth) storeListAbilities(ctx context.Context) ([]*Ability, error) { + const op = "authkit.storeListAbilities" + rows, err := a.db.QueryContext(ctx, a.q.listAbilities) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*Ability + for rows.Next() { + ab, err := scanAbility(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, ab) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (a *Auth) storeDeleteAbility(ctx context.Context, id uuid.UUID) error { + const op = "authkit.storeDeleteAbility" + tag, err := a.db.ExecContext(ctx, a.q.deleteAbility, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrAbilityNotFound) + } + return nil +} + +func scanAbility(row rowScanner) (*Ability, error) { + var ( + ab Ability + idStr string + label sql.NullString + ) + if err := row.Scan(&idStr, &ab.Slug, &label, &ab.CreatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + ab.ID = id + ab.Label = scanNullString(label) + return &ab, nil +} diff --git a/store_errors.go b/store_errors.go new file mode 100644 index 0000000..83c16f3 --- /dev/null +++ b/store_errors.go @@ -0,0 +1,40 @@ +package authkit + +import ( + "errors" + + "github.com/jackc/pgx/v5/pgconn" +) + +const ( + pgUniqueViolation = "23505" + pgUndefinedTable = "42P01" +) + +// isUniqueViolation reports whether err is a Postgres unique-violation, +// regardless of which driver is registered. +func isUniqueViolation(err error) bool { + return matchesSQLState(err, pgUniqueViolation) +} + +// isMissingTable reports whether err is a Postgres undefined_table error. +// Used to distinguish "schema not yet bootstrapped" from real failures. +func isMissingTable(err error) bool { + return matchesSQLState(err, pgUndefinedTable) +} + +func matchesSQLState(err error, code string) bool { + if err == nil { + return false + } + var pgxErr *pgconn.PgError + if errors.As(err, &pgxErr) { + return pgxErr.Code == code + } + type sqlStater interface{ SQLState() string } + var s sqlStater + if errors.As(err, &s) { + return s.SQLState() == code + } + return false +} diff --git a/store_migrate.go b/store_migrate.go new file mode 100644 index 0000000..db26237 --- /dev/null +++ b/store_migrate.go @@ -0,0 +1,122 @@ +package authkit + +import ( + "context" + "database/sql" + "embed" + "io/fs" + "log" + "sort" + "strings" + + "git.juancwu.dev/juancwu/errx" +) + +//go:embed migrations/*.sql +var migrationsFS embed.FS + +// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. +// Stable across rollouts and unlikely to clash with caller advisory locks. +const advisoryLockKey int64 = 0x617574686b6974 + +// Migrate applies every embedded migration not yet recorded in the +// schema-migrations table. Safe to call repeatedly and concurrently across +// processes; the advisory lock serialises rollouts. Each migration owns its +// own BEGIN/COMMIT. +// +// Embedded migrations hard-code the default authkit_* names. If the consumer +// has overridden any table name, Migrate is a no-op and the consumer is +// responsible for managing DDL out-of-band. +func Migrate(ctx context.Context, db *sql.DB, schema Schema) error { + const op = "authkit.Migrate" + if db == nil { + return errx.New(op, "db is required") + } + if err := schema.Validate(); err != nil { + return errx.Wrap(op, err) + } + + if !schema.isDefault() { + // Custom-named schemas: consumer owns DDL. The verifier still runs + // against the configured names (with default-name fallback) to + // confirm the tables exist and match the expected layout. + return nil + } + + conn, err := db.Conn(ctx) + if err != nil { + return errx.Wrap(op, err) + } + defer conn.Close() + + if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil { + return errx.Wrap(op, err) + } + defer func() { + if _, err := conn.ExecContext(context.Background(), + "SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil { + log.Printf("authkit: pg_advisory_unlock failed: %v", err) + } + }() + + q := buildQueries(schema.Tables) + if _, err := conn.ExecContext(ctx, q.createMigrationsTable); err != nil { + return errx.Wrap(op, err) + } + + applied, err := loadAppliedVersions(ctx, conn, q.selectAppliedVersions) + if err != nil { + return errx.Wrap(op, err) + } + + migs, err := fs.Sub(migrationsFS, "migrations") + if err != nil { + return errx.Wrap(op, err) + } + files, err := fs.ReadDir(migs, ".") + if err != nil { + return errx.Wrap(op, err) + } + names := make([]string, 0, len(files)) + for _, f := range files { + if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") { + names = append(names, f.Name()) + } + } + sort.Strings(names) + + for _, name := range names { + version := strings.TrimSuffix(name, ".sql") + if _, ok := applied[version]; ok { + continue + } + body, err := fs.ReadFile(migs, name) + if err != nil { + return errx.Wrapf(op, err, "read %s", name) + } + if _, err := conn.ExecContext(ctx, string(body)); err != nil { + return errx.Wrapf(op, err, "apply %s", version) + } + } + return nil +} + +func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) { + rows, err := conn.QueryContext(ctx, q) + if err != nil { + if isMissingTable(err) { + return map[string]struct{}{}, nil + } + return nil, err + } + defer rows.Close() + out := make(map[string]struct{}) + for rows.Next() { + var v string + if err := rows.Scan(&v); err != nil { + return nil, err + } + out[v] = struct{}{} + } + return out, rows.Err() +} diff --git a/store_permissions.go b/store_permissions.go new file mode 100644 index 0000000..04b1aec --- /dev/null +++ b/store_permissions.go @@ -0,0 +1,166 @@ +package authkit + +import ( + "context" + "database/sql" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +func (a *Auth) storeCreatePermission(ctx context.Context, p *Permission) error { + const op = "authkit.storeCreatePermission" + if p.ID == uuid.Nil { + p.ID = uuid.New() + } + if p.CreatedAt.IsZero() { + p.CreatedAt = a.now() + } + if _, err := a.db.ExecContext(ctx, a.q.createPermission, + uuidArg(p.ID), p.Slug, nullableLabel(p.Label), p.CreatedAt); err != nil { + if isUniqueViolation(err) { + return errx.Wrap(op, ErrSlugTaken) + } + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error) { + const op = "authkit.storeGetPermissionByID" + p, err := scanPermission(a.db.QueryRowContext(ctx, a.q.getPermissionByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrPermissionNotFound)) + } + return p, nil +} + +func (a *Auth) storeGetPermissionBySlug(ctx context.Context, slug string) (*Permission, error) { + const op = "authkit.storeGetPermissionBySlug" + p, err := scanPermission(a.db.QueryRowContext(ctx, a.q.getPermissionBySlug, slug)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrPermissionNotFound)) + } + return p, nil +} + +func (a *Auth) storeListPermissions(ctx context.Context) ([]*Permission, error) { + const op = "authkit.storeListPermissions" + rows, err := a.db.QueryContext(ctx, a.q.listPermissions) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (a *Auth) storeDeletePermission(ctx context.Context, id uuid.UUID) error { + const op = "authkit.storeDeletePermission" + tag, err := a.db.ExecContext(ctx, a.q.deletePermission, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrPermissionNotFound) + } + return nil +} + +func (a *Auth) storeAssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error { + const op = "authkit.storeAssignPermissionToRole" + if _, err := a.db.ExecContext(ctx, a.q.assignPermissionToRole, + uuidArg(roleID), uuidArg(permID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeRemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error { + const op = "authkit.storeRemovePermissionFromRole" + if _, err := a.db.ExecContext(ctx, a.q.removePermissionFromRole, + uuidArg(roleID), uuidArg(permID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error) { + const op = "authkit.storeGetRolePermissions" + rows, err := a.db.QueryContext(ctx, a.q.getRolePermissions, uuidArg(roleID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (a *Auth) storeGetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error) { + const op = "authkit.storeGetUserPermissions" + rows, err := a.db.QueryContext(ctx, a.q.getUserPermissions, uuidArg(userID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (a *Auth) storeGrantPermissionToUser(ctx context.Context, userID, permID uuid.UUID) error { + const op = "authkit.storeGrantPermissionToUser" + if _, err := a.db.ExecContext(ctx, a.q.grantPermissionToUser, + uuidArg(userID), uuidArg(permID), a.now()); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeRevokePermissionFromUser(ctx context.Context, userID, permID uuid.UUID) error { + const op = "authkit.storeRevokePermissionFromUser" + if _, err := a.db.ExecContext(ctx, a.q.revokePermissionFromUser, + uuidArg(userID), uuidArg(permID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func scanPermission(row rowScanner) (*Permission, error) { + var ( + p Permission + idStr string + label sql.NullString + ) + if err := row.Scan(&idStr, &p.Slug, &label, &p.CreatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + p.ID = id + p.Label = scanNullString(label) + return &p, nil +} diff --git a/store_queries.go b/store_queries.go new file mode 100644 index 0000000..9e4fb32 --- /dev/null +++ b/store_queries.go @@ -0,0 +1,249 @@ +package authkit + +import ( + "fmt" + "strings" +) + +// queries holds every SQL string the store issues, with table identifiers +// already substituted from a validated Schema. Built once at New() to avoid +// per-call concatenation. Identifiers are interpolated via concatenation, +// safe because Schema.Validate gated them through identifierRE. +type queries struct { + // users + createUser string + getUserByID string + getUserByEmail string + updateUser string + deleteUser string + setPassword string + setEmailVerified string + bumpSessionVersion string + + // sessions + createSession string + getSession string + touchSession string + deleteSession string + deleteUserSessions string + deleteExpiredSessions string + + // tokens + createToken string + consumeToken string + getToken string + getOTPForUser string + decrementOTPAttempt string + consumeOTPByID string + deleteByChain string + deleteExpiredTokens string + + // service keys + createServiceKey string + getServiceKey string + listServiceKeys string + touchServiceKey string + revokeServiceKey string + getServiceKeyAbilities string + insertServiceKeyAbil string + + // roles + createRole string + getRoleByID string + getRoleBySlug string + listRoles string + deleteRole string + assignRoleToUser string + removeRoleFromUser string + getUserRoles string + hasAnyRolePrefix string + + // permissions + createPermission string + getPermissionByID string + getPermissionBySlug string + listPermissions string + deletePermission string + assignPermissionToRole string + removePermissionFromRole string + getRolePermissions string + getUserPermissions string + + // direct user permissions + grantPermissionToUser string + revokePermissionFromUser string + + // abilities + createAbility string + getAbilityByID string + getAbilityBySlug string + listAbilities string + deleteAbility string + + // migrations + createMigrationsTable string + selectAppliedVersions string +} + +func buildQueries(t Tables) queries { + return queries{ + // users + createUser: `INSERT INTO ` + t.Users + ` + (id, email, email_normalized, email_verified_at, password_hash, + session_version, last_login_at, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, + getUserByID: `SELECT id, email, email_normalized, email_verified_at, + password_hash, session_version, last_login_at, created_at, updated_at + FROM ` + t.Users + ` WHERE id = $1`, + getUserByEmail: `SELECT id, email, email_normalized, email_verified_at, + password_hash, session_version, last_login_at, created_at, updated_at + FROM ` + t.Users + ` WHERE email_normalized = $1`, + updateUser: `UPDATE ` + t.Users + ` SET + email = $1, email_normalized = $2, email_verified_at = $3, + password_hash = $4, session_version = $5, + last_login_at = $6, updated_at = $7 + WHERE id = $8`, + deleteUser: `DELETE FROM ` + t.Users + ` WHERE id = $1`, + setPassword: `UPDATE ` + t.Users + ` SET password_hash = $1, updated_at = $2 WHERE id = $3`, + setEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = $1, updated_at = $2 WHERE id = $3`, + bumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = $1 WHERE id = $2 RETURNING session_version`, + + // sessions + createSession: `INSERT INTO ` + t.Sessions + ` + (id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at) + VALUES ($1, $2, $3, $4, $5, $6, $7)`, + getSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at + FROM ` + t.Sessions + ` WHERE id_hash = $1`, + touchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = $1, expires_at = $2 WHERE id_hash = $3`, + deleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = $1`, + deleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = $1`, + deleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= $1`, + + // tokens + createToken: `INSERT INTO ` + t.Tokens + ` + (hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + consumeToken: `UPDATE ` + t.Tokens + ` + SET consumed_at = $1 + WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $4 + RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`, + getToken: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at + FROM ` + t.Tokens + ` WHERE kind = $1 AND hash = $2`, + // getOTPForUser returns the most recent unconsumed, unexpired OTP for + // the user, used to verify a code by hash-comparing client input. + getOTPForUser: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at + FROM ` + t.Tokens + ` + WHERE kind = $1 AND user_id = $2 AND consumed_at IS NULL AND expires_at > $3 + ORDER BY created_at DESC LIMIT 1`, + // decrementOTPAttempt drops attempts_remaining by 1 and consumes the + // row when it hits zero. Used after a wrong-code submission. + decrementOTPAttempt: `UPDATE ` + t.Tokens + ` + SET attempts_remaining = GREATEST(COALESCE(attempts_remaining, 0) - 1, 0), + consumed_at = CASE WHEN COALESCE(attempts_remaining, 0) - 1 <= 0 THEN $1 ELSE consumed_at END + WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $1 + RETURNING attempts_remaining`, + // consumeOTPByID is the success path: mark the matched OTP consumed. + consumeOTPByID: `UPDATE ` + t.Tokens + ` + SET consumed_at = $1 + WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $1 + RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`, + deleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = $1`, + deleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= $1`, + + // service keys + createServiceKey: `INSERT INTO ` + t.ServiceKeys + ` + (id_hash, name, last_used_at, created_at, expires_at, revoked_at) + VALUES ($1, $2, $3, $4, $5, $6)`, + getServiceKey: `SELECT id_hash, name, last_used_at, created_at, expires_at, revoked_at + FROM ` + t.ServiceKeys + ` WHERE id_hash = $1`, + listServiceKeys: `SELECT id_hash, name, last_used_at, created_at, expires_at, revoked_at + FROM ` + t.ServiceKeys + ` ORDER BY created_at DESC`, + touchServiceKey: `UPDATE ` + t.ServiceKeys + ` SET last_used_at = $1 WHERE id_hash = $2`, + revokeServiceKey: `UPDATE ` + t.ServiceKeys + ` SET revoked_at = $1 WHERE id_hash = $2 AND revoked_at IS NULL`, + getServiceKeyAbilities: `SELECT a.slug FROM ` + t.Abilities + ` a + JOIN ` + t.ServiceKeyAbilities + ` ska ON ska.ability_id = a.id + WHERE ska.service_key_id_hash = $1 ORDER BY a.slug`, + insertServiceKeyAbil: `INSERT INTO ` + t.ServiceKeyAbilities + ` + (service_key_id_hash, ability_id, granted_at) VALUES ($1, $2, $3) + ON CONFLICT DO NOTHING`, + + // roles + createRole: `INSERT INTO ` + t.Roles + ` (id, slug, label, created_at) VALUES ($1, $2, $3, $4)`, + getRoleByID: `SELECT id, slug, label, created_at FROM ` + t.Roles + ` WHERE id = $1`, + getRoleBySlug: `SELECT id, slug, label, created_at FROM ` + t.Roles + ` WHERE slug = $1`, + listRoles: `SELECT id, slug, label, created_at FROM ` + t.Roles + ` ORDER BY slug`, + deleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = $1`, + assignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) + VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`, + removeRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = $1 AND role_id = $2`, + getUserRoles: `SELECT r.id, r.slug, r.label, r.created_at + FROM ` + t.Roles + ` r JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id + WHERE ur.user_id = $1 ORDER BY r.slug`, + hasAnyRolePrefix: `SELECT EXISTS ( + SELECT 1 FROM ` + t.UserRoles + ` ur JOIN ` + t.Roles + ` r ON r.id = ur.role_id + WHERE ur.user_id = $1 AND r.slug IN (`, + + // permissions + createPermission: `INSERT INTO ` + t.Permissions + ` (id, slug, label, created_at) VALUES ($1, $2, $3, $4)`, + getPermissionByID: `SELECT id, slug, label, created_at FROM ` + t.Permissions + ` WHERE id = $1`, + getPermissionBySlug: `SELECT id, slug, label, created_at FROM ` + t.Permissions + ` WHERE slug = $1`, + listPermissions: `SELECT id, slug, label, created_at FROM ` + t.Permissions + ` ORDER BY slug`, + deletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = $1`, + assignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) + VALUES ($1, $2) ON CONFLICT DO NOTHING`, + removePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = $1 AND permission_id = $2`, + getRolePermissions: `SELECT p.id, p.slug, p.label, p.created_at + FROM ` + t.Permissions + ` p JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id + WHERE rp.role_id = $1 ORDER BY p.slug`, + // UNION of role-derived and direct user permissions. + getUserPermissions: `SELECT DISTINCT p.id, p.slug, p.label, p.created_at FROM ` + t.Permissions + ` p + WHERE p.id IN ( + SELECT rp.permission_id FROM ` + t.RolePermissions + ` rp + JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id + WHERE ur.user_id = $1 + UNION + SELECT up.permission_id FROM ` + t.UserPermissions + ` up + WHERE up.user_id = $1 + ) ORDER BY p.slug`, + + // direct user permissions + grantPermissionToUser: `INSERT INTO ` + t.UserPermissions + ` + (user_id, permission_id, granted_at) VALUES ($1, $2, $3) + ON CONFLICT DO NOTHING`, + revokePermissionFromUser: `DELETE FROM ` + t.UserPermissions + ` + WHERE user_id = $1 AND permission_id = $2`, + + // abilities + createAbility: `INSERT INTO ` + t.Abilities + ` (id, slug, label, created_at) VALUES ($1, $2, $3, $4)`, + getAbilityByID: `SELECT id, slug, label, created_at FROM ` + t.Abilities + ` WHERE id = $1`, + getAbilityBySlug: `SELECT id, slug, label, created_at FROM ` + t.Abilities + ` WHERE slug = $1`, + listAbilities: `SELECT id, slug, label, created_at FROM ` + t.Abilities + ` ORDER BY slug`, + deleteAbility: `DELETE FROM ` + t.Abilities + ` WHERE id = $1`, + + // migrations + createMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL + )`, + selectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations, + } +} + +// hasAnyRoleSQL renders the dynamic IN-clause for HasAnyRole. Generated +// query is parameterized: $1 = user_id, $2..$N+1 = role slugs. +func (q queries) hasAnyRoleSQL(n int) string { + if n <= 0 { + return "" + } + var b strings.Builder + b.Grow(len(q.hasAnyRolePrefix) + 8*n + 4) + b.WriteString(q.hasAnyRolePrefix) + for i := 0; i < n; i++ { + if i > 0 { + b.WriteByte(',') + } + fmt.Fprintf(&b, "$%d", i+2) + } + b.WriteString("))") + return b.String() +} diff --git a/store_roles.go b/store_roles.go new file mode 100644 index 0000000..5f9386d --- /dev/null +++ b/store_roles.go @@ -0,0 +1,151 @@ +package authkit + +import ( + "context" + "database/sql" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +func (a *Auth) storeCreateRole(ctx context.Context, r *Role) error { + const op = "authkit.storeCreateRole" + if r.ID == uuid.Nil { + r.ID = uuid.New() + } + if r.CreatedAt.IsZero() { + r.CreatedAt = a.now() + } + if _, err := a.db.ExecContext(ctx, a.q.createRole, + uuidArg(r.ID), r.Slug, nullableLabel(r.Label), r.CreatedAt); err != nil { + if isUniqueViolation(err) { + return errx.Wrap(op, ErrSlugTaken) + } + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error) { + const op = "authkit.storeGetRoleByID" + r, err := scanRole(a.db.QueryRowContext(ctx, a.q.getRoleByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrRoleNotFound)) + } + return r, nil +} + +func (a *Auth) storeGetRoleBySlug(ctx context.Context, slug string) (*Role, error) { + const op = "authkit.storeGetRoleBySlug" + r, err := scanRole(a.db.QueryRowContext(ctx, a.q.getRoleBySlug, slug)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrRoleNotFound)) + } + return r, nil +} + +func (a *Auth) storeListRoles(ctx context.Context) ([]*Role, error) { + const op = "authkit.storeListRoles" + rows, err := a.db.QueryContext(ctx, a.q.listRoles) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*Role + for rows.Next() { + r, err := scanRole(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, r) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (a *Auth) storeDeleteRole(ctx context.Context, id uuid.UUID) error { + const op = "authkit.storeDeleteRole" + tag, err := a.db.ExecContext(ctx, a.q.deleteRole, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrRoleNotFound) + } + return nil +} + +func (a *Auth) storeAssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error { + const op = "authkit.storeAssignRoleToUser" + if _, err := a.db.ExecContext(ctx, a.q.assignRoleToUser, + uuidArg(userID), uuidArg(roleID), a.now()); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeRemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error { + const op = "authkit.storeRemoveRoleFromUser" + if _, err := a.db.ExecContext(ctx, a.q.removeRoleFromUser, + uuidArg(userID), uuidArg(roleID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error) { + const op = "authkit.storeGetUserRoles" + rows, err := a.db.QueryContext(ctx, a.q.getUserRoles, uuidArg(userID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*Role + for rows.Next() { + r, err := scanRole(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, r) + } + return out, errx.Wrap(op, rows.Err()) +} + +// storeHasAnyRole builds the IN-clause at call time because the placeholder +// count depends on len(slugs). Identifier substitution comes from the +// validated Schema; values are bound, never interpolated. +func (a *Auth) storeHasAnyRole(ctx context.Context, userID uuid.UUID, slugs []string) (bool, error) { + const op = "authkit.storeHasAnyRole" + if len(slugs) == 0 { + return false, nil + } + q := a.q.hasAnyRoleSQL(len(slugs)) + args := make([]any, 0, 1+len(slugs)) + args = append(args, uuidArg(userID)) + for _, s := range slugs { + args = append(args, s) + } + var ok bool + if err := a.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +func scanRole(row rowScanner) (*Role, error) { + var ( + r Role + idStr string + label sql.NullString + ) + if err := row.Scan(&idStr, &r.Slug, &label, &r.CreatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + r.ID = id + r.Label = scanNullString(label) + return &r, nil +} diff --git a/store_scan.go b/store_scan.go new file mode 100644 index 0000000..ee422b7 --- /dev/null +++ b/store_scan.go @@ -0,0 +1,101 @@ +package authkit + +import ( + "database/sql" + "net/netip" + "time" + + "github.com/google/uuid" +) + +// rowScanner is satisfied by both *sql.Row and *sql.Rows so scan helpers +// serve QueryRow and Query loops uniformly. +type rowScanner interface { + Scan(dest ...any) error +} + +func nullableTime(t *time.Time) any { + if t == nil { + return nil + } + return *t +} + +func nullableString(s string) any { + if s == "" { + return nil + } + return s +} + +func nullableLabel(s string) any { + // Labels are user-facing display strings — empty string means "no label", + // which we store NULL for clarity. + if s == "" { + return nil + } + return s +} + +func nullableAddrString(a netip.Addr) any { + if !a.IsValid() { + return nil + } + return a.String() +} + +func scanAddr(s *string) (netip.Addr, error) { + if s == nil || *s == "" { + return netip.Addr{}, nil + } + return netip.ParseAddr(*s) +} + +func uuidArg(id uuid.UUID) any { return id.String() } + +func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) } + +func chainArg(c *string) any { + if c == nil { + return nil + } + return *c +} + +func nullableInt(n *int) any { + if n == nil { + return nil + } + return *n +} + +func scanNullStringPtr(ns sql.NullString) *string { + if !ns.Valid { + return nil + } + v := ns.String + return &v +} + +func scanNullString(ns sql.NullString) string { + if !ns.Valid { + return "" + } + return ns.String +} + +func scanNullTimePtr(nt sql.NullTime) *time.Time { + if !nt.Valid { + return nil + } + t := nt.Time + return &t +} + +func scanNullIntPtr(ni sql.NullInt32) *int { + if !ni.Valid { + return nil + } + v := int(ni.Int32) + return &v +} diff --git a/store_schema.go b/store_schema.go new file mode 100644 index 0000000..27b340d --- /dev/null +++ b/store_schema.go @@ -0,0 +1,140 @@ +package authkit + +import ( + "regexp" + + "git.juancwu.dev/juancwu/errx" +) + +// Schema lets consumers map authkit storage to their own table names. Column +// overrides are not exposed in v1 — the column set is fixed. +type Schema struct { + Tables Tables +} + +// Tables holds per-table identifier overrides. Every field must be a valid +// unquoted SQL identifier (matching identifierRE). Validation runs at +// New()/Migrate() time so SQL injection through Schema is impossible past +// that gate. +type Tables struct { + Users string + Sessions string + Tokens string + ServiceKeys string + ServiceKeyAbilities string + Roles string + Permissions string + Abilities string + UserRoles string + UserPermissions string + RolePermissions string + SchemaMigrations string +} + +// DefaultSchema returns the stock authkit_* names matching the embedded +// migration files. +func DefaultSchema() Schema { + return Schema{Tables: defaultTables()} +} + +func defaultTables() Tables { + return Tables{ + Users: "authkit_users", + Sessions: "authkit_sessions", + Tokens: "authkit_tokens", + ServiceKeys: "authkit_service_keys", + ServiceKeyAbilities: "authkit_service_key_abilities", + Roles: "authkit_roles", + Permissions: "authkit_permissions", + Abilities: "authkit_abilities", + UserRoles: "authkit_user_roles", + UserPermissions: "authkit_user_permissions", + RolePermissions: "authkit_role_permissions", + SchemaMigrations: "authkit_schema_migrations", + } +} + +// identifierRE matches the safe ASCII identifier subset shared by Postgres +// when not quoted. Anything outside this set is rejected at validation time. +var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +// Validate ensures every Schema.Tables field is a non-empty, safe identifier. +func (s Schema) Validate() error { + const op = "authkit.Schema.Validate" + checks := []struct { + field, value string + }{ + {"Users", s.Tables.Users}, + {"Sessions", s.Tables.Sessions}, + {"Tokens", s.Tables.Tokens}, + {"ServiceKeys", s.Tables.ServiceKeys}, + {"ServiceKeyAbilities", s.Tables.ServiceKeyAbilities}, + {"Roles", s.Tables.Roles}, + {"Permissions", s.Tables.Permissions}, + {"Abilities", s.Tables.Abilities}, + {"UserRoles", s.Tables.UserRoles}, + {"UserPermissions", s.Tables.UserPermissions}, + {"RolePermissions", s.Tables.RolePermissions}, + {"SchemaMigrations", s.Tables.SchemaMigrations}, + } + for _, c := range checks { + if c.value == "" { + return errx.Newf(op, "Schema.Tables.%s is empty", c.field) + } + if !identifierRE.MatchString(c.value) { + return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value) + } + } + return nil +} + +// isDefault reports whether every Schema.Tables entry matches the default +// names. Embedded migrations hard-code the defaults, so they only run +// unmodified against the default schema. +func (s Schema) isDefault() bool { + return s.Tables == defaultTables() +} + +// mergeSchemaDefaults fills in any blank Tables field from the default set +// so callers can override one or two table names without having to copy +// the whole DefaultSchema structure. +func mergeSchemaDefaults(s Schema) Schema { + def := defaultTables() + if s.Tables.Users == "" { + s.Tables.Users = def.Users + } + if s.Tables.Sessions == "" { + s.Tables.Sessions = def.Sessions + } + if s.Tables.Tokens == "" { + s.Tables.Tokens = def.Tokens + } + if s.Tables.ServiceKeys == "" { + s.Tables.ServiceKeys = def.ServiceKeys + } + if s.Tables.ServiceKeyAbilities == "" { + s.Tables.ServiceKeyAbilities = def.ServiceKeyAbilities + } + if s.Tables.Roles == "" { + s.Tables.Roles = def.Roles + } + if s.Tables.Permissions == "" { + s.Tables.Permissions = def.Permissions + } + if s.Tables.Abilities == "" { + s.Tables.Abilities = def.Abilities + } + if s.Tables.UserRoles == "" { + s.Tables.UserRoles = def.UserRoles + } + if s.Tables.UserPermissions == "" { + s.Tables.UserPermissions = def.UserPermissions + } + if s.Tables.RolePermissions == "" { + s.Tables.RolePermissions = def.RolePermissions + } + if s.Tables.SchemaMigrations == "" { + s.Tables.SchemaMigrations = def.SchemaMigrations + } + return s +} diff --git a/store_service_keys.go b/store_service_keys.go new file mode 100644 index 0000000..d6461ce --- /dev/null +++ b/store_service_keys.go @@ -0,0 +1,136 @@ +package authkit + +import ( + "context" + "database/sql" + "time" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +func (a *Auth) storeCreateServiceKey(ctx context.Context, k *ServiceKey, abilityIDs []uuid.UUID) error { + const op = "authkit.storeCreateServiceKey" + if k.CreatedAt.IsZero() { + k.CreatedAt = a.now() + } + tx, err := a.db.BeginTx(ctx, nil) + if err != nil { + return errx.Wrap(op, err) + } + defer func() { _ = tx.Rollback() }() + + if _, err := tx.ExecContext(ctx, a.q.createServiceKey, + k.IDHash, k.Name, nullableTime(k.LastUsedAt), k.CreatedAt, + nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt)); err != nil { + return errx.Wrap(op, err) + } + now := a.now() + for _, id := range abilityIDs { + if _, err := tx.ExecContext(ctx, a.q.insertServiceKeyAbil, + k.IDHash, uuidArg(id), now); err != nil { + return errx.Wrap(op, err) + } + } + if err := tx.Commit(); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetServiceKey(ctx context.Context, idHash []byte) (*ServiceKey, error) { + const op = "authkit.storeGetServiceKey" + k, err := scanServiceKey(a.db.QueryRowContext(ctx, a.q.getServiceKey, idHash)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrServiceKeyInvalid)) + } + abilities, err := a.storeServiceKeyAbilities(ctx, idHash) + if err != nil { + return nil, errx.Wrap(op, err) + } + k.Abilities = abilities + return k, nil +} + +func (a *Auth) storeListServiceKeys(ctx context.Context) ([]*ServiceKey, error) { + const op = "authkit.storeListServiceKeys" + rows, err := a.db.QueryContext(ctx, a.q.listServiceKeys) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*ServiceKey + for rows.Next() { + k, err := scanServiceKey(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, k) + } + if err := rows.Err(); err != nil { + return nil, errx.Wrap(op, err) + } + for _, k := range out { + abilities, err := a.storeServiceKeyAbilities(ctx, k.IDHash) + if err != nil { + return nil, errx.Wrap(op, err) + } + k.Abilities = abilities + } + return out, nil +} + +func (a *Auth) storeServiceKeyAbilities(ctx context.Context, idHash []byte) ([]string, error) { + rows, err := a.db.QueryContext(ctx, a.q.getServiceKeyAbilities, idHash) + if err != nil { + return nil, err + } + defer rows.Close() + out := []string{} + for rows.Next() { + var slug string + if err := rows.Scan(&slug); err != nil { + return nil, err + } + out = append(out, slug) + } + return out, rows.Err() +} + +func (a *Auth) storeTouchServiceKey(ctx context.Context, idHash []byte, at time.Time) error { + const op = "authkit.storeTouchServiceKey" + if _, err := a.db.ExecContext(ctx, a.q.touchServiceKey, at, idHash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeRevokeServiceKey(ctx context.Context, idHash []byte, at time.Time) error { + const op = "authkit.storeRevokeServiceKey" + tag, err := a.db.ExecContext(ctx, a.q.revokeServiceKey, at, idHash) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrServiceKeyInvalid) + } + return nil +} + +func scanServiceKey(row rowScanner) (*ServiceKey, error) { + var ( + k ServiceKey + lastUsed sql.NullTime + expires sql.NullTime + revoked sql.NullTime + ) + if err := row.Scan(&k.IDHash, &k.Name, &lastUsed, &k.CreatedAt, + &expires, &revoked); err != nil { + return nil, err + } + k.LastUsedAt = scanNullTimePtr(lastUsed) + k.ExpiresAt = scanNullTimePtr(expires) + k.RevokedAt = scanNullTimePtr(revoked) + return &k, nil +} diff --git a/store_sessions.go b/store_sessions.go new file mode 100644 index 0000000..c34b1db --- /dev/null +++ b/store_sessions.go @@ -0,0 +1,95 @@ +package authkit + +import ( + "context" + "database/sql" + "time" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +func (a *Auth) storeCreateSession(ctx context.Context, s *Session) error { + const op = "authkit.storeCreateSession" + now := a.now() + if s.CreatedAt.IsZero() { + s.CreatedAt = now + } + if s.LastSeenAt.IsZero() { + s.LastSeenAt = s.CreatedAt + } + _, err := a.db.ExecContext(ctx, a.q.createSession, + s.IDHash, uuidArg(s.UserID), s.UserAgent, nullableAddrString(s.IP), + s.CreatedAt, s.LastSeenAt, s.ExpiresAt) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetSession(ctx context.Context, idHash []byte) (*Session, error) { + const op = "authkit.storeGetSession" + var ( + s Session + uidStr string + ipStr sql.NullString + ) + err := a.db.QueryRowContext(ctx, a.q.getSession, idHash).Scan( + &s.IDHash, &uidStr, &s.UserAgent, &ipStr, + &s.CreatedAt, &s.LastSeenAt, &s.ExpiresAt) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrSessionInvalid)) + } + uid, err := scanUUID(uidStr) + if err != nil { + return nil, errx.Wrap(op, err) + } + s.UserID = uid + if ipStr.Valid { + addr, err := scanAddr(&ipStr.String) + if err != nil { + return nil, errx.Wrap(op, err) + } + s.IP = addr + } + return &s, nil +} + +func (a *Auth) storeTouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error { + const op = "authkit.storeTouchSession" + tag, err := a.db.ExecContext(ctx, a.q.touchSession, lastSeenAt, newExpiresAt, idHash) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrSessionInvalid) + } + return nil +} + +func (a *Auth) storeDeleteSession(ctx context.Context, idHash []byte) error { + const op = "authkit.storeDeleteSession" + if _, err := a.db.ExecContext(ctx, a.q.deleteSession, idHash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeDeleteUserSessions(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.storeDeleteUserSessions" + if _, err := a.db.ExecContext(ctx, a.q.deleteUserSessions, uuidArg(userID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeDeleteExpiredSessions(ctx context.Context, now time.Time) (int64, error) { + const op = "authkit.storeDeleteExpiredSessions" + tag, err := a.db.ExecContext(ctx, a.q.deleteExpiredSessions, now) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} diff --git a/store_tokens.go b/store_tokens.go new file mode 100644 index 0000000..f9273d7 --- /dev/null +++ b/store_tokens.go @@ -0,0 +1,134 @@ +package authkit + +import ( + "context" + "database/sql" + "time" + + "git.juancwu.dev/juancwu/errx" +) + +func (a *Auth) storeCreateToken(ctx context.Context, t *Token) error { + const op = "authkit.storeCreateToken" + if t.CreatedAt.IsZero() { + t.CreatedAt = a.now() + } + _, err := a.db.ExecContext(ctx, a.q.createToken, + t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID), + nullableTime(t.ConsumedAt), nullableInt(t.AttemptsRemaining), + t.CreatedAt, t.ExpiresAt) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// storeConsumeToken atomically marks the matching unexpired, unconsumed token +// as consumed and returns it. Returns ErrTokenInvalid if no row matched. +// Implementations MUST do this in one statement to prevent double-spend +// under concurrent callers. +func (a *Auth) storeConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) { + const op = "authkit.storeConsumeToken" + row := a.db.QueryRowContext(ctx, a.q.consumeToken, now, string(kind), hash, now) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid)) + } + return t, nil +} + +func (a *Auth) storeGetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error) { + const op = "authkit.storeGetToken" + row := a.db.QueryRowContext(ctx, a.q.getToken, string(kind), hash) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid)) + } + return t, nil +} + +// storeGetActiveOTPForUser returns the most recent unconsumed, unexpired OTP +// row for a user. Used by ConsumeEmailOTP to verify a code by hash-comparing +// client input. +func (a *Auth) storeGetActiveOTPForUser(ctx context.Context, kind TokenKind, userID any, now time.Time) (*Token, error) { + const op = "authkit.storeGetActiveOTPForUser" + row := a.db.QueryRowContext(ctx, a.q.getOTPForUser, string(kind), userID, now) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrOTPInvalid)) + } + return t, nil +} + +// storeDecrementOTPAttempt drops attempts_remaining by 1 on the matched +// (kind, hash) row, consuming it when zero. Returns the new +// attempts_remaining (0 = consumed). ErrTokenInvalid when no row matched. +func (a *Auth) storeDecrementOTPAttempt(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (int, error) { + const op = "authkit.storeDecrementOTPAttempt" + var remaining sql.NullInt32 + if err := a.db.QueryRowContext(ctx, a.q.decrementOTPAttempt, + now, string(kind), hash).Scan(&remaining); err != nil { + return 0, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid)) + } + if !remaining.Valid { + return 0, nil + } + return int(remaining.Int32), nil +} + +// storeConsumeOTPByHash marks an OTP row consumed by direct hash match. Used +// on the success path of ConsumeEmailOTP. +func (a *Auth) storeConsumeOTPByHash(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) { + const op = "authkit.storeConsumeOTPByHash" + row := a.db.QueryRowContext(ctx, a.q.consumeOTPByID, now, string(kind), hash) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrOTPInvalid)) + } + return t, nil +} + +func (a *Auth) storeDeleteByChain(ctx context.Context, chainID string) (int64, error) { + const op = "authkit.storeDeleteByChain" + tag, err := a.db.ExecContext(ctx, a.q.deleteByChain, chainID) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} + +func (a *Auth) storeDeleteExpiredTokens(ctx context.Context, now time.Time) (int64, error) { + const op = "authkit.storeDeleteExpiredTokens" + tag, err := a.db.ExecContext(ctx, a.q.deleteExpiredTokens, now) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} + +func scanToken(row rowScanner) (*Token, error) { + var ( + t Token + kind string + userIDStr string + chainID sql.NullString + consumedAt sql.NullTime + attempts sql.NullInt32 + ) + if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, + &consumedAt, &attempts, &t.CreatedAt, &t.ExpiresAt); err != nil { + return nil, err + } + t.Kind = TokenKind(kind) + uid, err := scanUUID(userIDStr) + if err != nil { + return nil, err + } + t.UserID = uid + t.ChainID = scanNullStringPtr(chainID) + t.ConsumedAt = scanNullTimePtr(consumedAt) + t.AttemptsRemaining = scanNullIntPtr(attempts) + return &t, nil +} diff --git a/store_users.go b/store_users.go new file mode 100644 index 0000000..63fa04a --- /dev/null +++ b/store_users.go @@ -0,0 +1,158 @@ +package authkit + +import ( + "context" + "database/sql" + "errors" + "strings" + "time" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +func (a *Auth) storeCreateUser(ctx context.Context, u *User) error { + const op = "authkit.storeCreateUser" + if u.ID == uuid.Nil { + u.ID = uuid.New() + } + if u.EmailNormalized == "" { + u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email)) + } + now := a.now() + if u.CreatedAt.IsZero() { + u.CreatedAt = now + } + if u.UpdatedAt.IsZero() { + u.UpdatedAt = now + } + _, err := a.db.ExecContext(ctx, a.q.createUser, + uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), + nullableString(u.PasswordHash), u.SessionVersion, + nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt) + if err != nil { + if isUniqueViolation(err) { + return errx.Wrap(op, ErrEmailTaken) + } + return errx.Wrap(op, err) + } + return nil +} + +func (a *Auth) storeGetUserByID(ctx context.Context, id uuid.UUID) (*User, error) { + const op = "authkit.storeGetUserByID" + u, err := scanUser(a.db.QueryRowContext(ctx, a.q.getUserByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrUserNotFound)) + } + return u, nil +} + +func (a *Auth) storeGetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error) { + const op = "authkit.storeGetUserByEmail" + u, err := scanUser(a.db.QueryRowContext(ctx, a.q.getUserByEmail, normalizedEmail)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, ErrUserNotFound)) + } + return u, nil +} + +func (a *Auth) storeUpdateUser(ctx context.Context, u *User) error { + const op = "authkit.storeUpdateUser" + u.UpdatedAt = a.now() + tag, err := a.db.ExecContext(ctx, a.q.updateUser, + u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), + nullableString(u.PasswordHash), u.SessionVersion, + nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrUserNotFound) + } + return nil +} + +func (a *Auth) storeDeleteUser(ctx context.Context, id uuid.UUID) error { + const op = "authkit.storeDeleteUser" + tag, err := a.db.ExecContext(ctx, a.q.deleteUser, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrUserNotFound) + } + return nil +} + +func (a *Auth) storeSetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error { + const op = "authkit.storeSetPassword" + tag, err := a.db.ExecContext(ctx, a.q.setPassword, + nullableString(encodedHash), a.now(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrUserNotFound) + } + return nil +} + +func (a *Auth) storeSetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error { + const op = "authkit.storeSetEmailVerified" + tag, err := a.db.ExecContext(ctx, a.q.setEmailVerified, at, a.now(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, ErrUserNotFound) + } + return nil +} + +func (a *Auth) storeBumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) { + const op = "authkit.storeBumpSessionVersion" + var v int + if err := a.db.QueryRowContext(ctx, a.q.bumpSessionVersion, + a.now(), uuidArg(userID)).Scan(&v); err != nil { + return 0, errx.Wrap(op, mapNotFound(err, ErrUserNotFound)) + } + return v, nil +} + +func scanUser(row rowScanner) (*User, error) { + var ( + u User + idStr string + passwordHash sql.NullString + emailVerified sql.NullTime + lastLogin sql.NullTime + ) + if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified, + &passwordHash, &u.SessionVersion, &lastLogin, + &u.CreatedAt, &u.UpdatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + u.ID = id + u.PasswordHash = scanNullString(passwordHash) + u.EmailVerifiedAt = scanNullTimePtr(emailVerified) + u.LastLoginAt = scanNullTimePtr(lastLogin) + return &u, nil +} + +// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so +// callers get reliable errors.Is targets through errx wrapping. +func mapNotFound(err error, sentinel error) error { + if errors.Is(err, sql.ErrNoRows) { + return sentinel + } + return err +} diff --git a/store_verify.go b/store_verify.go new file mode 100644 index 0000000..3310bde --- /dev/null +++ b/store_verify.go @@ -0,0 +1,278 @@ +package authkit + +import ( + "context" + "database/sql" + "fmt" + "sort" + "strings" + + "git.juancwu.dev/juancwu/errx" +) + +// columnSpec describes one expected column. dataType is matched against +// information_schema.columns.data_type (e.g. "uuid", "text", "bytea", +// "timestamp with time zone", "integer"). nullable is matched against +// is_nullable ("YES"/"NO"). Extra columns on the live table are allowed. +type columnSpec struct { + name string + dataType string + nullable bool +} + +// tableSpec is the expected layout for one logical table. +type tableSpec struct { + logicalKey string // matches a Tables field name; used for error messages + configured string + defaultNm string + columns []columnSpec +} + +// expectedSchema returns the full per-table column specification matching +// migrations/0001_init.sql. The verifier walks each tableSpec, looking up +// information_schema for the configured name first and falling back to the +// default name when the configured name has no rows. +func expectedSchema(s Schema) []tableSpec { + def := defaultTables() + t := s.Tables + return []tableSpec{ + { + logicalKey: "Users", + configured: t.Users, + defaultNm: def.Users, + columns: []columnSpec{ + {"id", "uuid", false}, + {"email", "text", false}, + {"email_normalized", "text", false}, + {"email_verified_at", "timestamp with time zone", true}, + {"password_hash", "text", true}, + {"session_version", "integer", false}, + {"last_login_at", "timestamp with time zone", true}, + {"created_at", "timestamp with time zone", false}, + {"updated_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "Sessions", + configured: t.Sessions, + defaultNm: def.Sessions, + columns: []columnSpec{ + {"id_hash", "bytea", false}, + {"user_id", "uuid", false}, + {"user_agent", "text", false}, + {"ip", "text", true}, + {"created_at", "timestamp with time zone", false}, + {"last_seen_at", "timestamp with time zone", false}, + {"expires_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "Tokens", + configured: t.Tokens, + defaultNm: def.Tokens, + columns: []columnSpec{ + {"hash", "bytea", false}, + {"kind", "text", false}, + {"user_id", "uuid", false}, + {"chain_id", "text", true}, + {"consumed_at", "timestamp with time zone", true}, + {"attempts_remaining", "integer", true}, + {"created_at", "timestamp with time zone", false}, + {"expires_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "ServiceKeys", + configured: t.ServiceKeys, + defaultNm: def.ServiceKeys, + columns: []columnSpec{ + {"id_hash", "bytea", false}, + {"name", "text", false}, + {"last_used_at", "timestamp with time zone", true}, + {"created_at", "timestamp with time zone", false}, + {"expires_at", "timestamp with time zone", true}, + {"revoked_at", "timestamp with time zone", true}, + }, + }, + { + logicalKey: "ServiceKeyAbilities", + configured: t.ServiceKeyAbilities, + defaultNm: def.ServiceKeyAbilities, + columns: []columnSpec{ + {"service_key_id_hash", "bytea", false}, + {"ability_id", "uuid", false}, + {"granted_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "Roles", + configured: t.Roles, + defaultNm: def.Roles, + columns: []columnSpec{ + {"id", "uuid", false}, + {"slug", "text", false}, + {"label", "text", true}, + {"created_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "Permissions", + configured: t.Permissions, + defaultNm: def.Permissions, + columns: []columnSpec{ + {"id", "uuid", false}, + {"slug", "text", false}, + {"label", "text", true}, + {"created_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "Abilities", + configured: t.Abilities, + defaultNm: def.Abilities, + columns: []columnSpec{ + {"id", "uuid", false}, + {"slug", "text", false}, + {"label", "text", true}, + {"created_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "UserRoles", + configured: t.UserRoles, + defaultNm: def.UserRoles, + columns: []columnSpec{ + {"user_id", "uuid", false}, + {"role_id", "uuid", false}, + {"granted_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "UserPermissions", + configured: t.UserPermissions, + defaultNm: def.UserPermissions, + columns: []columnSpec{ + {"user_id", "uuid", false}, + {"permission_id", "uuid", false}, + {"granted_at", "timestamp with time zone", false}, + }, + }, + { + logicalKey: "RolePermissions", + configured: t.RolePermissions, + defaultNm: def.RolePermissions, + columns: []columnSpec{ + {"role_id", "uuid", false}, + {"permission_id", "uuid", false}, + }, + }, + } +} + +// VerifySchema introspects the live database against the expected layout for +// the given schema. Returns a wrapped ErrSchemaDrift describing every +// missing/mismatched table or column. Extra columns on a table are allowed. +// +// For tables with non-default names, VerifySchema looks up the configured +// name first; if no rows are found, it falls back to the default name. This +// handles the case where a consumer migrated under custom names but later +// removed the overrides — drift is detected against whichever set of names +// actually exists. +func VerifySchema(ctx context.Context, db *sql.DB, schema Schema) error { + const op = "authkit.VerifySchema" + if db == nil { + return errx.New(op, "db is required") + } + if err := schema.Validate(); err != nil { + return errx.Wrap(op, err) + } + + specs := expectedSchema(schema) + var problems []string + + for _, spec := range specs { + live, foundUnder, err := loadTableColumns(ctx, db, spec.configured, spec.defaultNm) + if err != nil { + return errx.Wrap(op, err) + } + if foundUnder == "" { + problems = append(problems, fmt.Sprintf( + "table %q (%s): not found (also tried %q)", + spec.configured, spec.logicalKey, spec.defaultNm)) + continue + } + for _, want := range spec.columns { + got, ok := live[want.name] + if !ok { + problems = append(problems, fmt.Sprintf( + "table %q column %q: missing", foundUnder, want.name)) + continue + } + if got.dataType != want.dataType { + problems = append(problems, fmt.Sprintf( + "table %q column %q: data_type=%q, want %q", + foundUnder, want.name, got.dataType, want.dataType)) + } + if got.nullable != want.nullable { + problems = append(problems, fmt.Sprintf( + "table %q column %q: nullable=%v, want %v", + foundUnder, want.name, got.nullable, want.nullable)) + } + } + } + + if len(problems) > 0 { + sort.Strings(problems) + return errx.Wrapf(op, ErrSchemaDrift, + "%d issue(s):\n - %s", len(problems), strings.Join(problems, "\n - ")) + } + return nil +} + +// loadTableColumns queries information_schema for a table's columns. If the +// configured name has no rows AND defaultName differs, it falls back to the +// default. Returns the columns map and the name actually used (empty string +// when neither exists). +func loadTableColumns(ctx context.Context, db *sql.DB, configured, defaultName string) (map[string]columnSpec, string, error) { + live, err := queryColumns(ctx, db, configured) + if err != nil { + return nil, "", err + } + if len(live) > 0 { + return live, configured, nil + } + if defaultName != "" && defaultName != configured { + live, err = queryColumns(ctx, db, defaultName) + if err != nil { + return nil, "", err + } + if len(live) > 0 { + return live, defaultName, nil + } + } + return nil, "", nil +} + +func queryColumns(ctx context.Context, db *sql.DB, table string) (map[string]columnSpec, error) { + const q = `SELECT column_name, data_type, is_nullable + FROM information_schema.columns + WHERE table_schema = current_schema() AND table_name = $1` + rows, err := db.QueryContext(ctx, q, table) + if err != nil { + return nil, err + } + defer rows.Close() + out := make(map[string]columnSpec) + for rows.Next() { + var name, dataType, isNullable string + if err := rows.Scan(&name, &dataType, &isNullable); err != nil { + return nil, err + } + out[name] = columnSpec{ + name: name, + dataType: dataType, + nullable: isNullable == "YES", + } + } + return out, rows.Err() +} diff --git a/store_verify_test.go b/store_verify_test.go new file mode 100644 index 0000000..5ac96d3 --- /dev/null +++ b/store_verify_test.go @@ -0,0 +1,87 @@ +package authkit + +import ( + "context" + "database/sql" + "errors" + "testing" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +func TestIntegration_VerifySchemaPasses(t *testing.T) { + a := freshAuth(t) + if err := VerifySchema(context.Background(), a.DB(), DefaultSchema()); err != nil { + t.Fatalf("VerifySchema after Migrate should pass: %v", err) + } +} + +func TestIntegration_VerifyAllowsExtraColumns(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + // Add a column the verifier doesn't expect — should still pass. + if _, err := a.DB().ExecContext(ctx, + "ALTER TABLE authkit_users ADD COLUMN consumer_extra TEXT"); err != nil { + t.Fatalf("ALTER ADD: %v", err) + } + if err := VerifySchema(ctx, a.DB(), DefaultSchema()); err != nil { + t.Fatalf("VerifySchema should tolerate extra columns: %v", err) + } +} + +func TestIntegration_VerifyDetectsMissingColumn(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.DB().ExecContext(ctx, + "ALTER TABLE authkit_users DROP COLUMN last_login_at"); err != nil { + t.Fatalf("ALTER DROP: %v", err) + } + err := VerifySchema(ctx, a.DB(), DefaultSchema()) + if !errors.Is(err, ErrSchemaDrift) { + t.Fatalf("expected ErrSchemaDrift on missing column, got %v", err) + } +} + +func TestIntegration_VerifyDetectsMissingTable(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + if _, err := a.DB().ExecContext(ctx, + "DROP TABLE authkit_user_permissions"); err != nil { + t.Fatalf("DROP: %v", err) + } + err := VerifySchema(ctx, a.DB(), DefaultSchema()) + if !errors.Is(err, ErrSchemaDrift) { + t.Fatalf("expected ErrSchemaDrift on missing table, got %v", err) + } +} + +func TestIntegration_VerifyFallbackToDefaultName(t *testing.T) { + // Migrate created tables under default names. Construct a custom schema + // pointing at non-existent table names — verifier should fall back to + // the defaults and pass. + a := freshAuth(t) + ctx := context.Background() + + custom := DefaultSchema() + custom.Tables.Users = "renamed_users_does_not_exist" + if err := VerifySchema(ctx, a.DB(), custom); err != nil { + t.Fatalf("VerifySchema should fall back to default name when configured table is missing: %v", err) + } +} + +func TestIntegration_MigrateIdempotent(t *testing.T) { + url := dbURL(t) + db, err := sql.Open("pgx", url) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + t.Cleanup(func() { dropAllAuthkitTables(t, db, DefaultSchema()) }) + + dropAllAuthkitTables(t, db, DefaultSchema()) + for i := 0; i < 3; i++ { + if err := Migrate(context.Background(), db, DefaultSchema()); err != nil { + t.Fatalf("Migrate iter %d: %v", i, err) + } + } +} diff --git a/stores.go b/stores.go deleted file mode 100644 index a3f1737..0000000 --- a/stores.go +++ /dev/null @@ -1,82 +0,0 @@ -package authkit - -import ( - "context" - "time" - - "github.com/google/uuid" -) - -type UserStore interface { - CreateUser(ctx context.Context, u *User) error - GetUserByID(ctx context.Context, id uuid.UUID) (*User, error) - GetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error) - UpdateUser(ctx context.Context, u *User) error - DeleteUser(ctx context.Context, id uuid.UUID) error - SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error - SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error - BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) - IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) - ResetFailedLogins(ctx context.Context, userID uuid.UUID) error -} - -type SessionStore interface { - CreateSession(ctx context.Context, s *Session) error - GetSession(ctx context.Context, idHash []byte) (*Session, error) - TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error - DeleteSession(ctx context.Context, idHash []byte) error - DeleteUserSessions(ctx context.Context, userID uuid.UUID) error - DeleteExpired(ctx context.Context, now time.Time) (int64, error) -} - -type TokenStore interface { - CreateToken(ctx context.Context, t *Token) error - // ConsumeToken atomically marks the matching unexpired, unconsumed token - // as consumed and returns it. Returns ErrTokenInvalid if no row matched. - // Implementations MUST do this in one statement (UPDATE ... RETURNING) - // to prevent double-spend under concurrent callers. - ConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) - // GetToken returns a token without consuming it. Used for refresh-token - // reuse detection: a token that exists with consumed_at != nil is a - // replay signal. - GetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error) - DeleteByChain(ctx context.Context, chainID string) (int64, error) - DeleteExpired(ctx context.Context, now time.Time) (int64, error) -} - -type ServiceKeyStore interface { - CreateServiceKey(ctx context.Context, k *ServiceKey) error - GetServiceKey(ctx context.Context, idHash []byte) (*ServiceKey, error) - ListServiceKeysByOwner(ctx context.Context, ownerKind string, ownerID uuid.UUID) ([]*ServiceKey, error) - TouchServiceKey(ctx context.Context, idHash []byte, at time.Time) error - RevokeServiceKey(ctx context.Context, idHash []byte, at time.Time) error -} - -type RoleStore interface { - CreateRole(ctx context.Context, r *Role) error - GetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error) - GetRoleByName(ctx context.Context, name string) (*Role, error) - ListRoles(ctx context.Context) ([]*Role, error) - DeleteRole(ctx context.Context, id uuid.UUID) error - AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error - RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error - GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error) - HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) -} - -type PermissionStore interface { - CreatePermission(ctx context.Context, p *Permission) error - GetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error) - GetPermissionByName(ctx context.Context, name string) (*Permission, error) - ListPermissions(ctx context.Context) ([]*Permission, error) - DeletePermission(ctx context.Context, id uuid.UUID) error - AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error - RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error - GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error) - GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error) -} - -type Hasher interface { - Hash(password string) (string, error) - Verify(password, encoded string) (ok bool, needsRehash bool, err error) -} diff --git a/testdb_test.go b/testdb_test.go new file mode 100644 index 0000000..85d0140 --- /dev/null +++ b/testdb_test.go @@ -0,0 +1,89 @@ +package authkit + +// Integration test infrastructure. Skipped when AUTHKIT_TEST_DATABASE_URL is +// unset so the unit-test suite remains usable without a database. + +import ( + "context" + "database/sql" + "fmt" + "net/netip" + "os" + "testing" + "time" + + "git.juancwu.dev/juancwu/authkit/hasher" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +// noIP returns the zero-value netip.Addr — used by tests that don't care +// about the originating IP. +func noIP() netip.Addr { return netip.Addr{} } + +func dbURL(t *testing.T) string { + t.Helper() + url := os.Getenv("AUTHKIT_TEST_DATABASE_URL") + if url == "" { + t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test") + } + return url +} + +// freshAuth returns a fully-initialized *Auth bound to a clean database. +// All authkit_* tables are dropped before Migrate runs, so each test sees +// an empty schema. +func freshAuth(t *testing.T) *Auth { + t.Helper() + url := dbURL(t) + db, err := sql.Open("pgx", url) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("ping: %v", err) + } + + dropAllAuthkitTables(t, db, DefaultSchema()) + t.Cleanup(func() { dropAllAuthkitTables(t, db, DefaultSchema()) }) + + a, err := New(context.Background(), Deps{ + DB: db, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), + }, Config{ + JWTSecret: []byte("integration-secret-thirty-two!!!"), + JWTIssuer: "authkit-int", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + EmailOTPTTL: time.Minute, + EmailOTPMaxAttempts: 3, + }) + if err != nil { + t.Fatalf("authkit.New: %v", err) + } + return a +} + +func dropAllAuthkitTables(t *testing.T, db *sql.DB, s Schema) { + t.Helper() + tables := []string{ + s.Tables.ServiceKeyAbilities, s.Tables.UserPermissions, + s.Tables.UserRoles, s.Tables.RolePermissions, + s.Tables.ServiceKeys, s.Tables.Abilities, + s.Tables.Roles, s.Tables.Permissions, + s.Tables.Tokens, s.Tables.Sessions, s.Tables.Users, + s.Tables.SchemaMigrations, + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for _, name := range tables { + _, _ = db.ExecContext(ctx, + fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name)) + } +} diff --git a/tokens.go b/tokens.go index f5bee47..7b8e5ea 100644 --- a/tokens.go +++ b/tokens.go @@ -3,7 +3,6 @@ package authkit import ( "crypto/rand" "crypto/sha256" - "crypto/subtle" "encoding/base64" "io" "strings" @@ -26,9 +25,7 @@ const ( // MintOpaqueSecret generates a fresh opaque secret with the given prefix. // Returns the plaintext (show once, never persist) and the SHA-256 lookup -// hash. A nil rng falls back to crypto/rand.Reader. Exposed so consumers -// building bespoke storage can produce secrets in the same shape authkit -// uses internally. +// hash. A nil rng falls back to crypto/rand.Reader. func MintOpaqueSecret(rng io.Reader, prefix string) (plaintext string, hash []byte, err error) { const op = "authkit.MintOpaqueSecret" if rng == nil { @@ -53,7 +50,7 @@ func HashOpaqueSecret(plaintext string) []byte { } // ParseOpaqueSecret validates that a plaintext begins with the expected -// prefix and returns the lookup hash. Returns ok=false on prefix mismatch. +// prefix and returns the lookup hash. func ParseOpaqueSecret(prefix, plaintext string) (hash []byte, ok bool) { want := prefix + "_" if !strings.HasPrefix(plaintext, want) { @@ -61,22 +58,3 @@ func ParseOpaqueSecret(prefix, plaintext string) (hash []byte, ok bool) { } return HashOpaqueSecret(plaintext), true } - -// mintSecret is the internal entry point; existing callers pass prefix -// first to match call-site readability ("mint a token of kind X"). -func mintSecret(prefix string, rng io.Reader) (plaintext string, hash []byte, err error) { - return MintOpaqueSecret(rng, prefix) -} - -func hashSecret(plaintext string) []byte { - return HashOpaqueSecret(plaintext) -} - -func parseSecret(prefix, plaintext string) (hash []byte, ok bool) { - return ParseOpaqueSecret(prefix, plaintext) -} - -// constantTimeEqual is a thin wrapper for readability at call sites. -func constantTimeEqual(a, b []byte) bool { - return subtle.ConstantTimeCompare(a, b) == 1 -} diff --git a/tokens_test.go b/tokens_test.go index 28395ad..9de4776 100644 --- a/tokens_test.go +++ b/tokens_test.go @@ -7,46 +7,46 @@ import ( "testing" ) -func TestMintSecretRoundtrip(t *testing.T) { - plaintext, hash, err := mintSecret(prefixSession, nil) +func TestMintOpaqueSecretRoundtrip(t *testing.T) { + plaintext, hash, err := MintOpaqueSecret(nil, prefixSession) if err != nil { - t.Fatalf("mintSecret: %v", err) + t.Fatalf("MintOpaqueSecret: %v", err) } if !strings.HasPrefix(plaintext, prefixSession+"_") { t.Fatalf("missing prefix: %q", plaintext) } - parsed, ok := parseSecret(prefixSession, plaintext) + parsed, ok := ParseOpaqueSecret(prefixSession, plaintext) if !ok { - t.Fatalf("parseSecret rejected our own mint") + t.Fatalf("ParseOpaqueSecret rejected our own mint") } if !bytes.Equal(hash, parsed) { t.Fatalf("hash mismatch") } want := sha256.Sum256([]byte(plaintext)) if !bytes.Equal(hash, want[:]) { - t.Fatalf("hashSecret != sha256(plaintext)") + t.Fatalf("HashOpaqueSecret != sha256(plaintext)") } } -func TestParseSecretWrongPrefix(t *testing.T) { - plaintext, _, err := mintSecret(prefixSession, nil) +func TestParseOpaqueSecretWrongPrefix(t *testing.T) { + plaintext, _, err := MintOpaqueSecret(nil, prefixSession) if err != nil { - t.Fatalf("mintSecret: %v", err) + t.Fatalf("MintOpaqueSecret: %v", err) } - if _, ok := parseSecret(prefixServiceKey, plaintext); ok { - t.Fatalf("parseSecret should reject mismatched prefix") + if _, ok := ParseOpaqueSecret(prefixServiceKey, plaintext); ok { + t.Fatalf("ParseOpaqueSecret should reject mismatched prefix") } - if _, ok := parseSecret(prefixSession, "sessXXXX"); ok { - t.Fatalf("parseSecret should require trailing underscore") + if _, ok := ParseOpaqueSecret(prefixSession, "sessXXXX"); ok { + t.Fatalf("ParseOpaqueSecret should require trailing underscore") } } -func TestMintSecretUniqueness(t *testing.T) { +func TestMintOpaqueSecretUniqueness(t *testing.T) { seen := make(map[string]struct{}, 100) for i := 0; i < 100; i++ { - p, _, err := mintSecret(prefixServiceKey, nil) + p, _, err := MintOpaqueSecret(nil, prefixServiceKey) if err != nil { - t.Fatalf("mintSecret: %v", err) + t.Fatalf("MintOpaqueSecret: %v", err) } if _, dup := seen[p]; dup { t.Fatalf("duplicate mint: %s", p) diff --git a/userctx.go b/userctx.go new file mode 100644 index 0000000..942a30f --- /dev/null +++ b/userctx.go @@ -0,0 +1,105 @@ +package authkit + +import ( + "context" + "sync" + + "github.com/google/uuid" +) + +// userCtxKey is an unexported context key. The empty struct shape guarantees +// no collision with caller-defined keys. +type userCtxKey struct{} +type serviceKeyCtxKey struct{} + +// userBox holds the per-request lazy-loaded user. The box pointer is what's +// stored on the context, so RefreshUserInCtx can mutate the cache visible +// to every UserFromCtx call within the same request. +type userBox struct { + mu sync.Mutex + auth *Auth + userID uuid.UUID + cached *User +} + +func (b *userBox) get(ctx context.Context) (*User, error) { + b.mu.Lock() + if b.cached != nil { + u := b.cached + b.mu.Unlock() + return u, nil + } + b.mu.Unlock() + // Don't hold the lock across the DB call. + u, err := b.auth.storeGetUserByID(ctx, b.userID) + if err != nil { + return nil, err + } + b.mu.Lock() + if b.cached == nil { + b.cached = u + } + out := b.cached + b.mu.Unlock() + return out, nil +} + +func (b *userBox) refresh(ctx context.Context) (*User, error) { + b.mu.Lock() + b.cached = nil + b.mu.Unlock() + return b.get(ctx) +} + +// WithUserContext attaches a lazy user-context to ctx. Middleware uses this +// to record an authenticated user_id without paying for a DB read until a +// handler actually calls UserFromCtx. Custom middleware authors can use +// this directly to integrate hand-rolled auth flows. +func WithUserContext(ctx context.Context, a *Auth, userID uuid.UUID) context.Context { + return context.WithValue(ctx, userCtxKey{}, &userBox{auth: a, userID: userID}) +} + +// WithServiceKey attaches a *ServiceKey to ctx. Used by service-key middleware. +func WithServiceKey(ctx context.Context, k *ServiceKey) context.Context { + return context.WithValue(ctx, serviceKeyCtxKey{}, k) +} + +// UserIDFromCtx returns the authenticated user_id placed by middleware via +// WithUserContext. The boolean is false when no user-bound auth ran for +// this request (e.g. a service-key request). +func UserIDFromCtx(ctx context.Context) (uuid.UUID, bool) { + b, ok := ctx.Value(userCtxKey{}).(*userBox) + if !ok { + return uuid.Nil, false + } + return b.userID, true +} + +// UserFromCtx returns the authenticated *User, lazy-loading from the +// database on first call within this request and caching the result for +// subsequent calls. Returns ErrNoUserContext if no user-bound auth ran. +func UserFromCtx(ctx context.Context) (*User, error) { + b, ok := ctx.Value(userCtxKey{}).(*userBox) + if !ok { + return nil, ErrNoUserContext + } + return b.get(ctx) +} + +// RefreshUserInCtx invalidates the cached user and refetches. Use after an +// admin-side update that should be visible to the rest of the request. +func RefreshUserInCtx(ctx context.Context) (*User, error) { + b, ok := ctx.Value(userCtxKey{}).(*userBox) + if !ok { + return nil, ErrNoUserContext + } + return b.refresh(ctx) +} + +// ServiceKeyFromCtx returns the authenticated *ServiceKey placed by +// service-key middleware. The boolean is false when no service-key +// authentication ran for this request. +func ServiceKeyFromCtx(ctx context.Context) (*ServiceKey, bool) { + k, ok := ctx.Value(serviceKeyCtxKey{}).(*ServiceKey) + return k, ok +} diff --git a/userctx_test.go b/userctx_test.go new file mode 100644 index 0000000..798c101 --- /dev/null +++ b/userctx_test.go @@ -0,0 +1,64 @@ +package authkit + +import ( + "context" + "testing" +) + +func TestIntegration_UserCtxLazyAndRefresh(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "ctx@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + rctx := WithUserContext(ctx, a, u.ID) + + // UserIDFromCtx is non-loading. + id, ok := UserIDFromCtx(rctx) + if !ok || id != u.ID { + t.Fatalf("UserIDFromCtx mismatch") + } + + first, err := UserFromCtx(rctx) + if err != nil { + t.Fatalf("UserFromCtx (lazy load): %v", err) + } + + // Mutate the underlying user out-of-band so we can prove refresh sees + // the change. + if err := a.SetPassword(ctx, u.ID, "new-secret"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + + // Without RefreshUserInCtx, UserFromCtx returns the cached value (which + // has the empty password hash from the initial load). + second, err := UserFromCtx(rctx) + if err != nil { + t.Fatalf("UserFromCtx (cached): %v", err) + } + if second != first { + t.Fatalf("expected cached pointer identity, got distinct pointers") + } + + // Refresh: cache busts, next read sees the password hash. + refreshed, err := RefreshUserInCtx(rctx) + if err != nil { + t.Fatalf("RefreshUserInCtx: %v", err) + } + if refreshed.PasswordHash == "" { + t.Fatalf("refresh should observe the SetPassword side-effect") + } +} + +func TestUserCtxNoUser(t *testing.T) { + if _, ok := UserIDFromCtx(context.Background()); ok { + t.Fatalf("UserIDFromCtx should be false on a bare context") + } + if _, err := UserFromCtx(context.Background()); err != ErrNoUserContext { + t.Fatalf("UserFromCtx should return ErrNoUserContext, got %v", err) + } + if _, err := RefreshUserInCtx(context.Background()); err != ErrNoUserContext { + t.Fatalf("RefreshUserInCtx should return ErrNoUserContext, got %v", err) + } +}