diff --git a/README.md b/README.md index 4658760..7051c83 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,345 @@ # authkit -Just a concoction of auth stuff in one place. \ No newline at end of file +A pragmatic authentication and authorization toolkit for Go web services. + +`authkit` ships interfaces for users, sessions, tokens, API keys, roles, and +permissions, plus default `database/sql` Postgres implementations and +framework-neutral HTTP middleware. It supports both opaque server-side +sessions and JWT access tokens with rotating refresh tokens, hashes passwords +with Argon2id, and pairs naturally with [`lightmux`](https://git.juancwu.dev/juancwu/lightmux) +or any `net/http` stack. + +## Install + +``` +go get git.juancwu.dev/juancwu/authkit +``` + +`authkit` itself depends only on `database/sql` and the Go standard library +plus `golang-jwt`, `google/uuid`, `golang.org/x/crypto`, and `errx`. Bring +your own driver: `pgx`, `lib/pq`, or anything else that registers a +`database/sql` driver. + +```go +import _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq" +``` + +PostgreSQL 12+ is sufficient — the schema avoids `gen_random_uuid()` and +`pgcrypto` so no extensions are required. 
+ +## What's included + +**Authentication flows** +- Email + password registration and login (Argon2id PHC-encoded hashes) +- Opaque server-side sessions with sliding TTL bounded by an absolute cap +- JWT access tokens (HS256) with rotating refresh tokens and reuse detection +- Email verification, password reset, and magic-link passwordless login + +**Authorization** +- Roles and permissions with many-to-many wiring +- API keys with custom abilities for per-endpoint scoping +- A unified `Principal` type so middleware works the same regardless of which + authentication method ran + +**Storage** +- Interfaces for every store so callers can plug in their own backends +- Default Postgres implementation built on `*sql.DB` (`sqlstore` package) +- Override table names via `Schema` without forking — useful when authkit + lives alongside an existing application schema +- A `Dialect` abstraction so future MySQL / SQLite implementations slot in + without changes to store code +- Embedded versioned migrations applied by a `Migrate(ctx, db, dialect, schema)` + helper that takes a session-scoped advisory lock + +**HTTP** +- `middleware.RequireSession`, `RequireJWT`, `RequireAPIKey`, `RequireAny` +- `middleware.RequireRole`, `RequireAnyRole`, `RequirePermission`, + `RequireAbility` +- `middleware.PrincipalFrom(ctx)` to read the authenticated principal in + handlers + +**Errors** +- Sentinel errors (`ErrEmailTaken`, `ErrInvalidCredentials`, `ErrTokenInvalid`, + `ErrTokenReused`, `ErrSessionInvalid`, `ErrAPIKeyInvalid`, + `ErrPermissionDenied`, ...) compatible with `errors.Is` +- All internal errors wrap with [`errx`](https://git.juancwu.dev/juancwu/errx) + for op tags + +## Out of scope (v1) + +MFA/TOTP, OAuth/social login, soft-delete, in-memory permission caching, +pluggable JWT signers (HS256 only), built-in HTTP handlers, MySQL/SQLite +dialects (architecture supports them; only Postgres ships in v1), and +column-name overrides in `Schema` (table-name overrides only). 
+ +## Quick start + +### 1. Open a database and run migrations + +```go +import ( + "database/sql" + + "git.juancwu.dev/juancwu/authkit/sqlstore" + pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres" + + _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq" +) + +db, err := sql.Open("pgx", os.Getenv("DATABASE_URL")) +if err != nil { /* ... */ } +defer db.Close() + +if err := sqlstore.Migrate(ctx, db, pgdialect.New(), sqlstore.DefaultSchema()); err != nil { + log.Fatal(err) +} +``` + +`Migrate` is idempotent and safe to call from multiple processes — it takes +a session-scoped `pg_advisory_lock` to serialise rollouts. + +`sqlx` users can pass `sqlxDB.DB` (the underlying `*sql.DB`) to the same +calls — the library only cares about `*sql.DB`. + +### 2. Wire the service + +```go +import ( + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/authkit/hasher" +) + +stores, err := sqlstore.New(db, pgdialect.New(), sqlstore.DefaultSchema()) +if err != nil { /* ... */ } + +auth := authkit.New(authkit.Deps{ + Users: stores.Users, + Sessions: stores.Sessions, + Tokens: stores.Tokens, + APIKeys: stores.APIKeys, + Roles: stores.Roles, + Permissions: stores.Permissions, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), +}, authkit.Config{ + JWTSecret: []byte(os.Getenv("JWT_SECRET")), + JWTIssuer: "myapp", + SessionCookieSecure: true, + SessionCookieHTTPOnly: true, +}) +``` + +`Config` zero values fall back to sensible defaults (24h idle / 30d absolute +session TTL, 15m access tokens, 30d refresh tokens, 48h email-verify, 1h +password-reset, 15m magic-link). `JWTSecret` and the seven `Deps` fields are +required; `New` panics on a misconfiguration. + +### 3. 
Use the service + +```go +// Registration + password login +u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2") +u, err = auth.LoginPassword(ctx, "alice@example.com", "hunter2hunter2") + +// Opaque session (cookie-friendly) +plaintext, sess, err := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP) +http.SetCookie(w, auth.SessionCookie(plaintext, sess.ExpiresAt)) + +// JWT + rotating refresh +access, refresh, err := auth.IssueJWT(ctx, u.ID) +access, refresh, err = auth.RefreshJWT(ctx, refresh) // old refresh is consumed + +// API key with abilities +plaintext, key, err := auth.IssueAPIKey(ctx, u.ID, "ci", + []string{"billing:read", "users:list"}, nil) + +// Email verification + password reset + magic link +tok, err := auth.RequestEmailVerification(ctx, u.ID) +_, err = auth.ConfirmEmail(ctx, tok) + +tok, err = auth.RequestPasswordReset(ctx, "alice@example.com") +err = auth.ConfirmPasswordReset(ctx, tok, "new-password") + +tok, err = auth.RequestMagicLink(ctx, "alice@example.com") +u, err = auth.ConsumeMagicLink(ctx, tok) +``` + +The plaintext returned by `IssueSession`, `IssueJWT`, `IssueAPIKey`, and the +token-minting flows is **show-once** — only its SHA-256 hash is stored. Show +it to the user immediately; you cannot recover it later. + +### 4. Wire middleware + +`authkit/middleware` returns standard `func(http.Handler) http.Handler` +values, so it composes with `lightmux.Mux.Use`/`Group`/`Handle` and any +`net/http` mux that accepts the same shape. 
+ +```go +import ( + authkitmw "git.juancwu.dev/juancwu/authkit/middleware" + "git.juancwu.dev/juancwu/lightmux" +) + +mux := lightmux.New() + +cookieAuth := authkitmw.RequireSession(authkitmw.Options{ + Auth: auth, + Extractor: authkit.ChainExtractors( + authkit.CookieExtractor("authkit_session"), + authkit.BearerExtractor(), + ), +}) + +me := mux.Group("/me", cookieAuth) +me.Get("", func(w http.ResponseWriter, r *http.Request) { + p := authkitmw.MustPrincipal(r) + json.NewEncoder(w).Encode(p) +}) + +// RBAC: stack authz on top of any auth method +admin := mux.Group("/admin", cookieAuth, authkitmw.RequireRole("admin")) + +// API-key-only route with a per-endpoint ability check +api := mux.Group("/api/v1", authkitmw.RequireAPIKey(authkitmw.Options{Auth: auth})) +api.Get("/billing", billingHandler, authkitmw.RequireAbility("billing:read")) +``` + +`Options.Extractor` defaults to `BearerExtractor`; pass `CookieExtractor` (or +chain extractors) when reading session cookies. `Options.OnUnauth` and +`Options.OnForbidden` default to a JSON `401` / `403`; override them to match +your error envelope. + +## Custom table names + +Pass a non-default `Schema` to use your own table names. Identifiers must +match `^[a-zA-Z_][a-zA-Z0-9_]*$`; anything else is rejected at `New()` and +`Migrate()` time, so SQL injection through the schema is impossible. + +```go +schema := sqlstore.DefaultSchema() +schema.Tables.Users = "accounts" +schema.Tables.APIKeys = "api_credentials" + +stores, _ := sqlstore.New(db, pgdialect.New(), schema) +``` + +The bundled migration files use the default `authkit_*` names. If you +override, you're responsible for matching DDL (most consumers with custom +naming already have their own DDL pipeline). + +Column-name overrides are not exposed in v1 — the column set is fixed for +each table. Adding column overrides later is purely additive. 
+ +## How things work + +### Secret token format + +Sessions, refresh tokens, API keys, email-verify tokens, password-reset tokens, +and magic-link tokens all share one format: + +``` +plaintext = "_" + base64url(32 random bytes, no padding) +lookup = sha256(plaintext) +``` + +Plaintext is returned to the caller exactly once and never persisted; the +SHA-256 is the database lookup key. Random bytes come from `crypto/rand` (or +`Config.Random` for tests). + +### JWT revocation + +Access tokens carry `sv` (`session_version`) in their claims. When you call +`RevokeAllUserSessions` or `ChangePassword`, the user's `session_version` +column increments and every outstanding access token fails the next +`AuthenticateJWT`. This is the only way to invalidate a JWT before its +`exp`. + +### Refresh token rotation + +Each `RefreshJWT` consumes the presented refresh token and issues a new one +on the same chain (the `chain_id` column on `authkit_tokens`). If a +*consumed* refresh token is ever presented again — a strong replay signal — +the entire chain is deleted via `TokenStore.DeleteByChain` and the call +returns `ErrTokenReused`. + +### Sliding session TTL + +Each authenticated request via `AuthenticateSession` slides `expires_at` to +`now + Config.SessionIdleTTL`, capped at `created_at + Config.SessionAbsoluteTTL`. +Long-lived idle sessions still hit the absolute boundary. + +### Schema and migrations + +`sqlstore.Migrate` applies every embedded `.sql` file under +`sqlstore/dialect/postgres/migrations/` whose version (filename without +`.sql`) is not in `authkit_schema_migrations`. The dialect's +`AcquireMigrationLock` (Postgres uses `pg_advisory_lock`) serialises +concurrent migrators. Each migration owns its own transaction so future +migrations can use statements like `CREATE INDEX CONCURRENTLY`. + +Every default table is prefixed `authkit_` so the schema can live alongside +your application's own tables in a shared database. 
+ +### Driver and dialect architecture + +The `sqlstore` package speaks `database/sql` only. Driver-specific behaviour +lives behind a small `Dialect` interface: + +```go +type Dialect interface { + Name() string + BuildQueries(s Schema) Queries + Bootstrap(ctx context.Context, db *sql.DB) error + AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error) + Migrations() fs.FS + IsUniqueViolation(err error) bool + Placeholder(n int) string + PlaceholderList(start, count int) string +} +``` + +v1 ships `dialect/postgres`. A future MySQL or SQLite dialect adds a new +implementation; no changes to store code. + +## Configuration reference + +| Field | Default | Notes | +|---|---|---| +| `SessionIdleTTL` | 24h | Sliding window applied on each authenticated request | +| `SessionAbsoluteTTL` | 30d | Cap from `created_at`; sliding never exceeds this | +| `SessionCookieName` | `authkit_session` | | +| `SessionCookieSameSite` | `Lax` | | +| `SessionCookieSecure` / `HTTPOnly` | `false` / `false` | Set both to `true` in production | +| `JWTSecret` | — (required) | HS256 key | +| `JWTIssuer` / `JWTAudience` | empty | When set, parser enforces them | +| `AccessTokenTTL` | 15m | | +| `RefreshTokenTTL` | 30d | | +| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | | +| `Clock` | `time.Now().UTC` | Controls every observable timestamp; override for deterministic tests | +| `Random` | `crypto/rand.Reader` | Override for deterministic tests | +| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit | + +## Implementing your own store + +Every store is a small interface with explicit semantics — see `stores.go`. +The most subtle contract is `TokenStore.ConsumeToken`: it MUST mark the token +consumed and return it in a single statement (`UPDATE ... RETURNING` on +Postgres / SQLite 3.35+) so two concurrent callers cannot both succeed. + +## Testing + +``` +go test ./... 
# unit tests, no DB +AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./sqlstore/... # integration tests +``` + +Unit tests cover token mint/parse, Argon2id encode/verify (including +`needsRehash` on parameter change), JWT issue/parse (incl. expired, +`sv`-mismatch, refresh rotation, reuse detection), session lifecycle, email +verification, password reset cascading session invalidation, magic-link +self-verification, API keys with abilities, and RBAC role-permission +resolution. Integration tests run the full `sqlstore` contract against a +real Postgres when `AUTHKIT_TEST_DATABASE_URL` is set. + +## License + +MIT. See `LICENSE`. diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..b659f73 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,53 @@ +version: '3' + +tasks: + default: + desc: Run vet and unit tests + cmds: + - task: check + + install:tools: + desc: Install development tools (tparse for prettier test output) + cmds: + - go install github.com/mfridman/tparse@latest + + build: + desc: Build all packages + cmds: + - go build ./... + + test: + desc: Run unit tests with prettier output via tparse + cmds: + - set -o pipefail && go test ./... -json -cover | tparse -all + + test:race: + desc: Run unit tests under the race detector + cmds: + - set -o pipefail && go test ./... -race -json | tparse -all + + test:integration: + desc: Run sqlstore integration tests against a real Postgres (requires AUTHKIT_TEST_DATABASE_URL) + cmds: + - set -o pipefail && go test ./sqlstore/... -run Integration -count=1 -json | tparse -all + + vet: + desc: Run go vet + cmds: + - go vet ./... + + fmt: + desc: Format Go source files + cmds: + - gofmt -w . 
+ + tidy: + desc: Tidy go.mod + cmds: + - go mod tidy + + check: + desc: Run vet and unit tests + cmds: + - task: vet + - task: test diff --git a/authkit.go b/authkit.go new file mode 100644 index 0000000..f23c2e7 --- /dev/null +++ b/authkit.go @@ -0,0 +1,121 @@ +package authkit + +import ( + "context" + "crypto/rand" + "io" + "net/http" + "time" + + "git.juancwu.dev/juancwu/errx" +) + +// Deps bundles every backing store and the password hasher the Auth service +// depends on. All fields are required; New panics on a nil dep so misuse is +// caught at boot rather than under load. +type Deps struct { + Users UserStore + Sessions SessionStore + Tokens TokenStore + APIKeys APIKeyStore + Roles RoleStore + Permissions PermissionStore + Hasher Hasher +} + +// Config tunes session/JWT/token TTLs, cookie shape, JWT signing material, +// and optional hooks. Any zero-valued duration is replaced with a sane +// default in New; required fields (notably JWTSecret) cause New to panic. +type Config struct { + // Session (opaque) cookies + DB-backed lifetime + SessionIdleTTL time.Duration + SessionAbsoluteTTL time.Duration + SessionCookieName string + SessionCookieDomain string + SessionCookiePath string + SessionCookieSecure bool + SessionCookieHTTPOnly bool + SessionCookieSameSite http.SameSite + + // JWT (HS256) + JWTSecret []byte + JWTIssuer string + JWTAudience string + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration + + // Single-use tokens + EmailVerifyTTL time.Duration + PasswordResetTTL time.Duration + MagicLinkTTL time.Duration + + // Hooks (optional) + Clock func() time.Time + Random io.Reader + LoginHook func(ctx context.Context, email string, success bool) error +} + +// Auth is the high-level service that composes the stores and hasher into the +// flows callers use: registration, login, sessions, JWTs, magic links, API +// keys, and authz checks. It is safe for concurrent use; method receivers +// never mutate Auth state after construction. 
+type Auth struct { + deps Deps + cfg Config +} + +// New validates Deps and Config, fills in defaults, and returns a ready +// service. It panics on missing deps or missing JWT secret rather than +// returning an error — these are programmer errors, not runtime ones. +func New(deps Deps, cfg Config) *Auth { + if deps.Users == nil || deps.Sessions == nil || deps.Tokens == nil || + deps.APIKeys == nil || deps.Roles == nil || deps.Permissions == nil || + deps.Hasher == nil { + panic(errx.New("authkit.New", "all Deps fields are required")) + } + if len(cfg.JWTSecret) == 0 { + panic(errx.New("authkit.New", "Config.JWTSecret is required")) + } + + if cfg.SessionIdleTTL == 0 { + cfg.SessionIdleTTL = 24 * time.Hour + } + if cfg.SessionAbsoluteTTL == 0 { + cfg.SessionAbsoluteTTL = 30 * 24 * time.Hour + } + if cfg.SessionCookieName == "" { + cfg.SessionCookieName = "authkit_session" + } + if cfg.SessionCookiePath == "" { + cfg.SessionCookiePath = "/" + } + if cfg.SessionCookieSameSite == 0 { + cfg.SessionCookieSameSite = http.SameSiteLaxMode + } + if cfg.AccessTokenTTL == 0 { + cfg.AccessTokenTTL = 15 * time.Minute + } + if cfg.RefreshTokenTTL == 0 { + cfg.RefreshTokenTTL = 30 * 24 * time.Hour + } + if cfg.EmailVerifyTTL == 0 { + cfg.EmailVerifyTTL = 48 * time.Hour + } + if cfg.PasswordResetTTL == 0 { + cfg.PasswordResetTTL = time.Hour + } + if cfg.MagicLinkTTL == 0 { + cfg.MagicLinkTTL = 15 * time.Minute + } + if cfg.Clock == nil { + cfg.Clock = func() time.Time { return time.Now().UTC() } + } + if cfg.Random == nil { + cfg.Random = rand.Reader + } + + return &Auth{deps: deps, cfg: cfg} +} + +// now returns the configured wall clock, defaulting to time.Now in UTC. +func (a *Auth) now() time.Time { return a.cfg.Clock() } diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..3c49a23 --- /dev/null +++ b/doc.go @@ -0,0 +1,13 @@ +// Package authkit is an authentication and authorization toolkit for Go web +// services. 
It defines storage interfaces (UserStore, SessionStore, TokenStore, +// APIKeyStore, RoleStore, PermissionStore) and a high-level Auth service that +// composes them to support registration, password login, opaque server-side +// sessions, JWT access plus rotating refresh tokens, email verification, +// password resets, magic-link passwordless login, role-based access control, +// and API keys with custom abilities. +// +// Default Postgres implementations of every store live in the sqlstore +// subpackage. Argon2id password hashing lives in hasher. Framework-neutral +// HTTP middleware (compatible with lightmux and any net/http stack) lives in +// middleware. +package authkit diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..a5acf10 --- /dev/null +++ b/errors.go @@ -0,0 +1,18 @@ +package authkit + +import "errors" + +var ( + ErrEmailTaken = errors.New("authkit: email already registered") + ErrUserNotFound = errors.New("authkit: user not found") + ErrInvalidCredentials = errors.New("authkit: invalid credentials") + ErrEmailNotVerified = errors.New("authkit: email not verified") + ErrTokenInvalid = errors.New("authkit: invalid or expired token") + ErrTokenReused = errors.New("authkit: token reuse detected") + ErrSessionInvalid = errors.New("authkit: invalid or expired session") + ErrAPIKeyInvalid = errors.New("authkit: invalid or expired api key") + ErrPermissionDenied = errors.New("authkit: permission denied") + ErrRoleNotFound = errors.New("authkit: role not found") + ErrPermissionNotFound = errors.New("authkit: permission not found") + ErrConfigInvalid = errors.New("authkit: invalid configuration") +) diff --git a/extractor.go b/extractor.go new file mode 100644 index 0000000..e1e0d73 --- /dev/null +++ b/extractor.go @@ -0,0 +1,64 @@ +package authkit + +import ( + "net/http" + "strings" +) + +// Extractor pulls a credential string out of an HTTP request. It returns +// (value, true) when a value was found, otherwise ("", false). 
+type Extractor func(r *http.Request) (string, bool) + +// BearerExtractor reads the value following "Bearer " in the Authorization +// header. Comparison is case-insensitive on the scheme. +func BearerExtractor() Extractor { + return func(r *http.Request) (string, bool) { + h := r.Header.Get("Authorization") + if h == "" { + return "", false + } + const prefix = "bearer " + if len(h) <= len(prefix) || !strings.EqualFold(h[:len(prefix)], prefix) { + return "", false + } + v := strings.TrimSpace(h[len(prefix):]) + if v == "" { + return "", false + } + return v, true + } +} + +// CookieExtractor reads the named cookie's value. +func CookieExtractor(name string) Extractor { + return func(r *http.Request) (string, bool) { + c, err := r.Cookie(name) + if err != nil || c.Value == "" { + return "", false + } + return c.Value, true + } +} + +// HeaderExtractor reads a custom header verbatim. +func HeaderExtractor(name string) Extractor { + return func(r *http.Request) (string, bool) { + v := strings.TrimSpace(r.Header.Get(name)) + if v == "" { + return "", false + } + return v, true + } +} + +// ChainExtractors tries each extractor in order, returning the first hit. 
+func ChainExtractors(es ...Extractor) Extractor { + return func(r *http.Request) (string, bool) { + for _, e := range es { + if v, ok := e(r); ok { + return v, true + } + } + return "", false + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ecd7a0c --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module git.juancwu.dev/juancwu/authkit + +go 1.26.2 + +require ( + git.juancwu.dev/juancwu/errx v0.1.0 + github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.9.2 + golang.org/x/crypto v0.50.0 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..dc69ee1 --- /dev/null +++ b/go.sum @@ -0,0 +1,36 @@ +git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w= +git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile 
v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hasher/argon2id.go b/hasher/argon2id.go new file mode 100644 index 0000000..5d5c57b --- /dev/null +++ b/hasher/argon2id.go @@ -0,0 +1,136 @@ +// Package hasher provides authkit.Hasher implementations. The default +// implementation, Argon2id, encodes hashes in the standard PHC string format +// (https://github.com/P-H-C/phc-string-format) so callers can introspect +// parameters and migrate. +package hasher + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "io" + "strings" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "golang.org/x/crypto/argon2" +) + +// Argon2idParams configures the Argon2id KDF. +type Argon2idParams struct { + Memory uint32 // KB; OWASP 2024 baseline: 19 MiB minimum, 64 MiB recommended + Iterations uint32 // time cost + Parallelism uint8 // lanes + SaltLen uint32 // bytes + KeyLen uint32 // bytes +} + +// DefaultArgon2idParams returns sensible defaults: 64 MiB memory, 3 +// iterations, 2 lanes, 16-byte salt, 32-byte key. Tune up Memory/Iterations +// over time and rely on Verify's needsRehash signal to migrate stored hashes. +func DefaultArgon2idParams() Argon2idParams { + return Argon2idParams{ + Memory: 64 * 1024, + Iterations: 3, + Parallelism: 2, + SaltLen: 16, + KeyLen: 32, + } +} + +type argon2idHasher struct { + params Argon2idParams + rng io.Reader +} + +// NewArgon2id builds an authkit.Hasher backed by Argon2id. If params is the +// zero value, DefaultArgon2idParams() is used. rng defaults to crypto/rand. 
+func NewArgon2id(params Argon2idParams, rng io.Reader) authkit.Hasher { + if params == (Argon2idParams{}) { + params = DefaultArgon2idParams() + } + if rng == nil { + rng = rand.Reader + } + return &argon2idHasher{params: params, rng: rng} +} + +func (h *argon2idHasher) Hash(password string) (string, error) { + const op = "authkit.hasher.Argon2id.Hash" + if password == "" { + return "", errx.New(op, "password is empty") + } + salt := make([]byte, h.params.SaltLen) + if _, err := io.ReadFull(h.rng, salt); err != nil { + return "", errx.Wrap(op, err) + } + key := argon2.IDKey([]byte(password), salt, + h.params.Iterations, h.params.Memory, h.params.Parallelism, h.params.KeyLen) + return encodePHC(h.params, salt, key), nil +} + +func (h *argon2idHasher) Verify(password, encoded string) (bool, bool, error) { + const op = "authkit.hasher.Argon2id.Verify" + got, salt, key, err := decodePHC(encoded) + if err != nil { + return false, false, errx.Wrap(op, err) + } + want := argon2.IDKey([]byte(password), salt, + got.Iterations, got.Memory, got.Parallelism, uint32(len(key))) + if subtle.ConstantTimeCompare(want, key) != 1 { + return false, false, nil + } + needsRehash := got.Memory != h.params.Memory || + got.Iterations != h.params.Iterations || + got.Parallelism != h.params.Parallelism || + uint32(len(salt)) != h.params.SaltLen || + uint32(len(key)) != h.params.KeyLen + return true, needsRehash, nil +} + +// PHC string format: +// +// $argon2id$v=19$m=,t=,p=$$ +// +// base64 here is the unpadded standard alphabet, per the spec. 
+func encodePHC(p Argon2idParams, salt, key []byte) string { + enc := base64.RawStdEncoding + return fmt.Sprintf( + "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, p.Memory, p.Iterations, p.Parallelism, + enc.EncodeToString(salt), enc.EncodeToString(key), + ) +} + +func decodePHC(s string) (Argon2idParams, []byte, []byte, error) { + parts := strings.Split(s, "$") + // Expect: ["", "argon2id", "v=19", "m=...,t=...,p=...", "", ""] + if len(parts) != 6 || parts[0] != "" || parts[1] != "argon2id" { + return Argon2idParams{}, nil, nil, errors.New("hasher: not an argon2id phc string") + } + var version int + if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad version segment: %w", err) + } + if version != argon2.Version { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: unsupported argon2 version %d", version) + } + var p Argon2idParams + if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism); err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad params segment: %w", err) + } + enc := base64.RawStdEncoding + salt, err := enc.DecodeString(parts[4]) + if err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad salt: %w", err) + } + key, err := enc.DecodeString(parts[5]) + if err != nil { + return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad key: %w", err) + } + p.SaltLen = uint32(len(salt)) + p.KeyLen = uint32(len(key)) + return p, salt, key, nil +} diff --git a/hasher/argon2id_test.go b/hasher/argon2id_test.go new file mode 100644 index 0000000..008c453 --- /dev/null +++ b/hasher/argon2id_test.go @@ -0,0 +1,78 @@ +package hasher + +import ( + "strings" + "testing" +) + +func TestArgon2idHashVerifyRoundtrip(t *testing.T) { + h := NewArgon2id(DefaultArgon2idParams(), nil) + encoded, err := h.Hash("hunter2hunter2") + if err != nil { + t.Fatalf("Hash: %v", err) + } + if !strings.HasPrefix(encoded, 
"$argon2id$") { + t.Fatalf("encoded hash not in PHC form: %s", encoded) + } + ok, needsRehash, err := h.Verify("hunter2hunter2", encoded) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if !ok { + t.Fatalf("Verify rejected the original password") + } + if needsRehash { + t.Fatalf("Verify with default params should not signal rehash") + } +} + +func TestArgon2idVerifyWrongPassword(t *testing.T) { + h := NewArgon2id(DefaultArgon2idParams(), nil) + encoded, err := h.Hash("correct horse battery staple") + if err != nil { + t.Fatalf("Hash: %v", err) + } + ok, _, err := h.Verify("nope", encoded) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if ok { + t.Fatalf("Verify should reject wrong password") + } +} + +func TestArgon2idNeedsRehashOnParamChange(t *testing.T) { + // Hash with light params... + light := Argon2idParams{Memory: 8 * 1024, Iterations: 1, Parallelism: 1, SaltLen: 16, KeyLen: 32} + encoded, err := NewArgon2id(light, nil).Hash("hello world") + if err != nil { + t.Fatalf("Hash: %v", err) + } + // ...verify with stronger params should still match but flag rehash. 
+ heavier := DefaultArgon2idParams() + ok, needsRehash, err := NewArgon2id(heavier, nil).Verify("hello world", encoded) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if !ok { + t.Fatalf("Verify rejected legitimate password across params") + } + if !needsRehash { + t.Fatalf("Verify should flag rehash when stored params differ from current") + } +} + +func TestArgon2idRejectsMalformed(t *testing.T) { + h := NewArgon2id(DefaultArgon2idParams(), nil) + cases := []string{ + "", + "not-a-phc", + "$argon2i$v=19$m=64,t=1,p=1$abc$def", + "$argon2id$v=99$m=64,t=1,p=1$YWJj$ZGVm", + } + for _, c := range cases { + if _, _, err := h.Verify("x", c); err == nil { + t.Fatalf("Verify should reject malformed encoding: %q", c) + } + } +} diff --git a/jwt.go b/jwt.go new file mode 100644 index 0000000..c9fb7b4 --- /dev/null +++ b/jwt.go @@ -0,0 +1,69 @@ +package authkit + +import ( + "git.juancwu.dev/juancwu/errx" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// accessClaims is the JWT shape issued by IssueJWT. The session_version +// field carries the User.SessionVersion at issue time so AuthenticateJWT +// can detect global revocations (logout-everywhere, password change). 
+type accessClaims struct { + jwt.RegisteredClaims + SessionVersion int `json:"sv"` + Method string `json:"m"` +} + +func (a *Auth) signAccessToken(userID uuid.UUID, sessionVersion int) (string, error) { + const op = "authkit.signAccessToken" + now := a.now() + claims := accessClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: userID.String(), + Issuer: a.cfg.JWTIssuer, + Audience: jwt.ClaimStrings{a.cfg.JWTAudience}, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(a.cfg.AccessTokenTTL)), + ID: uuid.NewString(), + }, + SessionVersion: sessionVersion, + Method: string(AuthMethodJWT), + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := tok.SignedString(a.cfg.JWTSecret) + if err != nil { + return "", errx.Wrap(op, err) + } + return signed, nil +} + +// parseAccessToken validates the signature and returns the parsed claims. +func (a *Auth) parseAccessToken(token string) (*accessClaims, error) { + const op = "authkit.parseAccessToken" + opts := []jwt.ParserOption{ + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + jwt.WithTimeFunc(a.cfg.Clock), + } + if a.cfg.JWTIssuer != "" { + opts = append(opts, jwt.WithIssuer(a.cfg.JWTIssuer)) + } + if a.cfg.JWTAudience != "" { + opts = append(opts, jwt.WithAudience(a.cfg.JWTAudience)) + } + parser := jwt.NewParser(opts...) 
+ + parsed, err := parser.ParseWithClaims(token, &accessClaims{}, func(t *jwt.Token) (any, error) { + return a.cfg.JWTSecret, nil + }) + if err != nil { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + claims, ok := parsed.Claims.(*accessClaims) + if !ok || !parsed.Valid { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + return claims, nil +} diff --git a/jwt_test.go b/jwt_test.go new file mode 100644 index 0000000..5de34a2 --- /dev/null +++ b/jwt_test.go @@ -0,0 +1,99 @@ +package authkit + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestJWTIssueAndAuthenticate(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "alice@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + access, refresh, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + if access == "" || refresh == "" { + t.Fatalf("empty tokens") + } + p, err := a.AuthenticateJWT(context.Background(), access) + if err != nil { + t.Fatalf("AuthenticateJWT: %v", err) + } + if p.UserID != u.ID { + t.Fatalf("principal user id mismatch") + } + if p.Method != AuthMethodJWT { + t.Fatalf("principal method = %s, want jwt", p.Method) + } +} + +func TestJWTSessionVersionMismatchRejected(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "bob@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + access, _, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + if err := a.RevokeAllUserSessions(context.Background(), u.ID); err != nil { + t.Fatalf("RevokeAllUserSessions: %v", err) + } + if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid after session bump, got %v", err) + } +} + +func TestJWTRefreshRotationAndReuseDetection(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), 
"carol@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + _, refresh1, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + _, refresh2, err := a.RefreshJWT(context.Background(), refresh1) + if err != nil { + t.Fatalf("first RefreshJWT: %v", err) + } + if refresh1 == refresh2 { + t.Fatalf("refresh token did not rotate") + } + + // Replaying refresh1 must surface ErrTokenReused and revoke the chain. + if _, _, err := a.RefreshJWT(context.Background(), refresh1); !errors.Is(err, ErrTokenReused) { + t.Fatalf("expected ErrTokenReused on replay, got %v", err) + } + // After chain revocation, even refresh2 (the legitimate next one) must + // be rejected. + if _, _, err := a.RefreshJWT(context.Background(), refresh2); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on post-revoke refresh, got %v", err) + } +} + +func TestJWTExpiredTokenRejected(t *testing.T) { + a := newTestAuth(t) + now := time.Now().UTC() + a.cfg.Clock = func() time.Time { return now } + u, err := a.Register(context.Background(), "dan@example.com", "hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + access, _, err := a.IssueJWT(context.Background(), u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + // Advance clock past TTL. + a.cfg.Clock = func() time.Time { return now.Add(10 * time.Minute) } + if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid for expired token, got %v", err) + } +} diff --git a/memstore_test.go b/memstore_test.go new file mode 100644 index 0000000..4b16216 --- /dev/null +++ b/memstore_test.go @@ -0,0 +1,619 @@ +package authkit + +// In-memory store fakes used by service-level tests. Kept in a _test.go file +// so they don't ship in the public API. 
Each fake is intentionally minimal — +// it satisfies the interface and supports the flows the tests exercise; not +// every method is wired up (a few panic to flag accidental use). + +import ( + "bytes" + "context" + "errors" + "slices" + "sync" + "time" + + "github.com/google/uuid" +) + +type memUserStore struct { + mu sync.Mutex + byID map[uuid.UUID]*User + byEml map[string]uuid.UUID +} + +func newMemUserStore() *memUserStore { + return &memUserStore{byID: map[uuid.UUID]*User{}, byEml: map[string]uuid.UUID{}} +} + +func (s *memUserStore) CreateUser(_ context.Context, u *User) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.byEml[u.EmailNormalized]; exists { + return ErrEmailTaken + } + if u.ID == uuid.Nil { + u.ID = uuid.New() + } + cp := *u + s.byID[u.ID] = &cp + s.byEml[u.EmailNormalized] = u.ID + return nil +} + +func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*User, error) { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return nil, ErrUserNotFound + } + cp := *u + return &cp, nil +} + +func (s *memUserStore) GetUserByEmail(_ context.Context, ne string) (*User, error) { + s.mu.Lock() + defer s.mu.Unlock() + id, ok := s.byEml[ne] + if !ok { + return nil, ErrUserNotFound + } + u := *s.byID[id] + return &u, nil +} + +func (s *memUserStore) UpdateUser(_ context.Context, u *User) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.byID[u.ID]; !ok { + return ErrUserNotFound + } + cp := *u + s.byID[u.ID] = &cp + return nil +} +func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + delete(s.byID, id) + delete(s.byEml, u.EmailNormalized) + return nil +} + +func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, h string) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + u.PasswordHash = h + return nil +} +func (s 
*memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + u.EmailVerifiedAt = &at + return nil +} +func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return 0, ErrUserNotFound + } + u.SessionVersion++ + return u.SessionVersion, nil +} +func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return 0, ErrUserNotFound + } + u.FailedLogins++ + return u.FailedLogins, nil +} +func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + u, ok := s.byID[id] + if !ok { + return ErrUserNotFound + } + u.FailedLogins = 0 + return nil +} + +type memSessionStore struct { + mu sync.Mutex + m map[string]*Session +} + +func newMemSessionStore() *memSessionStore { return &memSessionStore{m: map[string]*Session{}} } +func (s *memSessionStore) CreateSession(_ context.Context, ses *Session) error { + s.mu.Lock() + defer s.mu.Unlock() + cp := *ses + s.m[string(ses.IDHash)] = &cp + return nil +} +func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[string(h)] + if !ok { + return nil, ErrSessionInvalid + } + cp := *v + return &cp, nil +} +func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, exp time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[string(h)] + if !ok { + return ErrSessionInvalid + } + v.LastSeenAt = lastSeen + v.ExpiresAt = exp + return nil +} +func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.m, string(h)) + return nil +} +func (s *memSessionStore) DeleteUserSessions(_ 
context.Context, uid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.m { + if v.UserID == uid { + delete(s.m, k) + } + } + return nil +} +func (s *memSessionStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + var n int64 + for k, v := range s.m { + if !v.ExpiresAt.After(now) { + delete(s.m, k) + n++ + } + } + return n, nil +} + +type memTokenStore struct { + mu sync.Mutex + // Keyed by kind+hex(hash) so identical hashes across kinds don't collide. + m map[string]*Token +} + +func newMemTokenStore() *memTokenStore { return &memTokenStore{m: map[string]*Token{}} } +func tokKey(kind TokenKind, h []byte) string { + return string(kind) + ":" + string(h) +} +func (s *memTokenStore) CreateToken(_ context.Context, t *Token) error { + s.mu.Lock() + defer s.mu.Unlock() + cp := *t + s.m[tokKey(t.Kind, t.Hash)] = &cp + return nil +} +func (s *memTokenStore) ConsumeToken(_ context.Context, kind TokenKind, h []byte, now time.Time) (*Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + t, ok := s.m[tokKey(kind, h)] + if !ok || t.ConsumedAt != nil || !t.ExpiresAt.After(now) { + return nil, ErrTokenInvalid + } + t.ConsumedAt = &now + cp := *t + return &cp, nil +} +func (s *memTokenStore) GetToken(_ context.Context, kind TokenKind, h []byte) (*Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + t, ok := s.m[tokKey(kind, h)] + if !ok { + return nil, ErrTokenInvalid + } + cp := *t + return &cp, nil +} +func (s *memTokenStore) DeleteByChain(_ context.Context, chainID string) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + var n int64 + for k, t := range s.m { + if t.ChainID != nil && *t.ChainID == chainID { + delete(s.m, k) + n++ + } + } + return n, nil +} +func (s *memTokenStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + var n int64 + for k, t := range s.m { + if !t.ExpiresAt.After(now) { + delete(s.m, k) + n++ + } + } + return 
n, nil +} + +type memAPIKeyStore struct { + mu sync.Mutex + m map[string]*APIKey +} + +func newMemAPIKeyStore() *memAPIKeyStore { return &memAPIKeyStore{m: map[string]*APIKey{}} } +func (s *memAPIKeyStore) CreateAPIKey(_ context.Context, k *APIKey) error { + s.mu.Lock() + defer s.mu.Unlock() + cp := *k + cp.Abilities = append([]string(nil), k.Abilities...) + s.m[string(k.IDHash)] = &cp + return nil +} +func (s *memAPIKeyStore) GetAPIKey(_ context.Context, h []byte) (*APIKey, error) { + s.mu.Lock() + defer s.mu.Unlock() + k, ok := s.m[string(h)] + if !ok { + return nil, ErrAPIKeyInvalid + } + cp := *k + cp.Abilities = append([]string(nil), k.Abilities...) + return &cp, nil +} +func (s *memAPIKeyStore) ListAPIKeysByOwner(_ context.Context, owner uuid.UUID) ([]*APIKey, error) { + s.mu.Lock() + defer s.mu.Unlock() + var out []*APIKey + for _, k := range s.m { + if k.OwnerID == owner { + cp := *k + out = append(out, &cp) + } + } + return out, nil +} +func (s *memAPIKeyStore) TouchAPIKey(_ context.Context, h []byte, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + if k, ok := s.m[string(h)]; ok { + k.LastUsedAt = &at + } + return nil +} +func (s *memAPIKeyStore) RevokeAPIKey(_ context.Context, h []byte, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + k, ok := s.m[string(h)] + if !ok { + return ErrAPIKeyInvalid + } + if k.RevokedAt != nil { + return ErrAPIKeyInvalid + } + k.RevokedAt = &at + return nil +} +func (s *memAPIKeyStore) RevokeAPIKeysByOwner(_ context.Context, owner uuid.UUID, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, k := range s.m { + if k.OwnerID == owner && k.RevokedAt == nil { + k.RevokedAt = &at + } + } + return nil +} + +type memRoleStore struct { + mu sync.Mutex + roles map[uuid.UUID]*Role + rolesByNm map[string]uuid.UUID + userRoles map[uuid.UUID]map[uuid.UUID]struct{} +} + +func newMemRoleStore() *memRoleStore { + return &memRoleStore{ + roles: map[uuid.UUID]*Role{}, + rolesByNm: map[string]uuid.UUID{}, + 
userRoles: map[uuid.UUID]map[uuid.UUID]struct{}{}, + } +} +func (s *memRoleStore) CreateRole(_ context.Context, r *Role) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, dup := s.rolesByNm[r.Name]; dup { + return errors.New("role exists") + } + if r.ID == uuid.Nil { + r.ID = uuid.New() + } + cp := *r + s.roles[r.ID] = &cp + s.rolesByNm[r.Name] = r.ID + return nil +} +func (s *memRoleStore) GetRoleByID(_ context.Context, id uuid.UUID) (*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.roles[id] + if !ok { + return nil, ErrRoleNotFound + } + cp := *r + return &cp, nil +} +func (s *memRoleStore) GetRoleByName(_ context.Context, n string) (*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + id, ok := s.rolesByNm[n] + if !ok { + return nil, ErrRoleNotFound + } + r := *s.roles[id] + return &r, nil +} +func (s *memRoleStore) ListRoles(_ context.Context) ([]*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]*Role, 0, len(s.roles)) + for _, r := range s.roles { + cp := *r + out = append(out, &cp) + } + return out, nil +} +func (s *memRoleStore) DeleteRole(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.roles[id] + if !ok { + return ErrRoleNotFound + } + delete(s.roles, id) + delete(s.rolesByNm, r.Name) + for _, m := range s.userRoles { + delete(m, id) + } + return nil +} +func (s *memRoleStore) AssignRoleToUser(_ context.Context, uid, rid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + m, ok := s.userRoles[uid] + if !ok { + m = map[uuid.UUID]struct{}{} + s.userRoles[uid] = m + } + m[rid] = struct{}{} + return nil +} +func (s *memRoleStore) RemoveRoleFromUser(_ context.Context, uid, rid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + if m, ok := s.userRoles[uid]; ok { + delete(m, rid) + } + return nil +} +func (s *memRoleStore) GetUserRoles(_ context.Context, uid uuid.UUID) ([]*Role, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := []*Role{} + for rid := range 
s.userRoles[uid] { + if r, ok := s.roles[rid]; ok { + cp := *r + out = append(out, &cp) + } + } + return out, nil +} +func (s *memRoleStore) HasAnyRole(_ context.Context, uid uuid.UUID, names []string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for rid := range s.userRoles[uid] { + r := s.roles[rid] + if slices.Contains(names, r.Name) { + return true, nil + } + } + return false, nil +} + +type memPermStore struct { + mu sync.Mutex + perms map[uuid.UUID]*Permission + permsByNm map[string]uuid.UUID + rolePerms map[uuid.UUID]map[uuid.UUID]struct{} + roles *memRoleStore +} + +func newMemPermStore(rs *memRoleStore) *memPermStore { + return &memPermStore{ + perms: map[uuid.UUID]*Permission{}, + permsByNm: map[string]uuid.UUID{}, + rolePerms: map[uuid.UUID]map[uuid.UUID]struct{}{}, + roles: rs, + } +} +func (s *memPermStore) CreatePermission(_ context.Context, p *Permission) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, dup := s.permsByNm[p.Name]; dup { + return errors.New("perm exists") + } + if p.ID == uuid.Nil { + p.ID = uuid.New() + } + cp := *p + s.perms[p.ID] = &cp + s.permsByNm[p.Name] = p.ID + return nil +} +func (s *memPermStore) GetPermissionByID(_ context.Context, id uuid.UUID) (*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + p, ok := s.perms[id] + if !ok { + return nil, ErrPermissionNotFound + } + cp := *p + return &cp, nil +} +func (s *memPermStore) GetPermissionByName(_ context.Context, n string) (*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + id, ok := s.permsByNm[n] + if !ok { + return nil, ErrPermissionNotFound + } + p := *s.perms[id] + return &p, nil +} +func (s *memPermStore) ListPermissions(_ context.Context) ([]*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]*Permission, 0, len(s.perms)) + for _, p := range s.perms { + cp := *p + out = append(out, &cp) + } + return out, nil +} +func (s *memPermStore) DeletePermission(_ context.Context, id uuid.UUID) error { + s.mu.Lock() + defer 
s.mu.Unlock() + p, ok := s.perms[id] + if !ok { + return ErrPermissionNotFound + } + delete(s.perms, id) + delete(s.permsByNm, p.Name) + for _, m := range s.rolePerms { + delete(m, id) + } + return nil +} +func (s *memPermStore) AssignPermissionToRole(_ context.Context, rid, pid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + m, ok := s.rolePerms[rid] + if !ok { + m = map[uuid.UUID]struct{}{} + s.rolePerms[rid] = m + } + m[pid] = struct{}{} + return nil +} +func (s *memPermStore) RemovePermissionFromRole(_ context.Context, rid, pid uuid.UUID) error { + s.mu.Lock() + defer s.mu.Unlock() + if m, ok := s.rolePerms[rid]; ok { + delete(m, pid) + } + return nil +} +func (s *memPermStore) GetRolePermissions(_ context.Context, rid uuid.UUID) ([]*Permission, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := []*Permission{} + for pid := range s.rolePerms[rid] { + if p, ok := s.perms[pid]; ok { + cp := *p + out = append(out, &cp) + } + } + return out, nil +} +func (s *memPermStore) GetUserPermissions(_ context.Context, uid uuid.UUID) ([]*Permission, error) { + s.roles.mu.Lock() + roleIDs := make([]uuid.UUID, 0) + for rid := range s.roles.userRoles[uid] { + roleIDs = append(roleIDs, rid) + } + s.roles.mu.Unlock() + + s.mu.Lock() + defer s.mu.Unlock() + seen := map[uuid.UUID]struct{}{} + out := []*Permission{} + for _, rid := range roleIDs { + for pid := range s.rolePerms[rid] { + if _, dup := seen[pid]; dup { + continue + } + seen[pid] = struct{}{} + if p, ok := s.perms[pid]; ok { + cp := *p + out = append(out, &cp) + } + } + } + return out, nil +} + +// Stub Hasher: stores plaintext for trivial verify. Tests of hashing live in +// the hasher package itself; we just need *something* callable here. 
+type stubHasher struct{} + +func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil } +func (stubHasher) Verify(p, encoded string) (bool, bool, error) { + want := "stub:" + p + if !bytes.Equal([]byte(want), []byte(encoded)) { + return false, false, nil + } + return true, false, nil +} + +// newTestAuth wires the fakes into Auth with deterministic config. +func newTestAuth(t interface{ Helper() }) *Auth { + if h, ok := t.(interface{ Helper() }); ok { + h.Helper() + } + roles := newMemRoleStore() + return New(Deps{ + Users: newMemUserStore(), + Sessions: newMemSessionStore(), + Tokens: newMemTokenStore(), + APIKeys: newMemAPIKeyStore(), + Roles: roles, + Permissions: newMemPermStore(roles), + Hasher: stubHasher{}, + }, Config{ + JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"), + JWTIssuer: "authkit-test", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: 1 * time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + }) +} diff --git a/middleware/authz.go b/middleware/authz.go new file mode 100644 index 0000000..79d7297 --- /dev/null +++ b/middleware/authz.go @@ -0,0 +1,71 @@ +package middleware + +import ( + "net/http" + + "git.juancwu.dev/juancwu/authkit" +) + +// authzGuard wraps the common pattern of "look up the Principal, run a +// predicate, succeed or 403". onForbidden defaults to JSON 403. 
+func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler { + if onForbidden == nil { + onForbidden = defaultJSONError(http.StatusForbidden) + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p, ok := PrincipalFrom(r.Context()) + if !ok { + // No auth middleware ran upstream; treat as forbidden + // rather than crashing — composition is the caller's + // responsibility but a 403 is the safer default. + onForbidden(w, r, authkit.ErrPermissionDenied) + return + } + if !pred(p) { + onForbidden(w, r, authkit.ErrPermissionDenied) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// RequireRole permits requests whose Principal holds the named role. +func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasRole(name) + }) +} + +// RequireAnyRole permits requests whose Principal holds at least one of the +// named roles. +func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasAnyRole(names...) + }) +} + +// RequirePermission permits requests whose Principal holds the named +// permission (resolved via roles at auth time). +func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasPermission(name) + }) +} + +// RequireAbility permits requests whose Principal carries the named ability. 
+// Abilities are populated only for API-key authentication; this middleware +// will reject session/JWT-authenticated requests by design. +func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler { + return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool { + return p.HasAbility(name) + }) +} + +func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) { + if len(s) == 0 { + return nil + } + return s[0] +} diff --git a/middleware/context.go b/middleware/context.go new file mode 100644 index 0000000..7de6da0 --- /dev/null +++ b/middleware/context.go @@ -0,0 +1,39 @@ +// Package middleware provides framework-neutral HTTP middleware for authkit. +// Every middleware function returns the standard func(http.Handler) +// http.Handler type, so it composes with lightmux's Use/Group/Handle as well +// as any net/http stack that uses the same signature. +package middleware + +import ( + "context" + "net/http" + + "git.juancwu.dev/juancwu/authkit" +) + +// principalKey is an unexported context key. Using a distinct empty struct +// type guarantees no collision with caller-defined keys. +type principalKey struct{} + +// withPrincipal stashes p on the request context for downstream handlers. +func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context { + return context.WithValue(ctx, principalKey{}, p) +} + +// PrincipalFrom retrieves the authenticated Principal placed by RequireSession, +// RequireJWT, or RequireAPIKey. The boolean is false if no auth middleware +// ran for this request. +func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) { + p, ok := ctx.Value(principalKey{}).(*authkit.Principal) + return p, ok +} + +// MustPrincipal panics if no Principal is on the context. Use only on +// handlers known to be behind a Require* middleware. 
+func MustPrincipal(r *http.Request) *authkit.Principal { + p, ok := PrincipalFrom(r.Context()) + if !ok { + panic("authkit/middleware: no principal on request context") + } + return p +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..1fbc4dd --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,138 @@ +package middleware + +import ( + "encoding/json" + "net/http" + + "git.juancwu.dev/juancwu/authkit" +) + +// Options configures auth middleware. Auth is required; the rest fall back +// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403 +// on authz failure. +type Options struct { + Auth *authkit.Auth + Extractor authkit.Extractor + OnUnauth func(w http.ResponseWriter, r *http.Request, err error) + OnForbidden func(w http.ResponseWriter, r *http.Request, err error) +} + +func (o Options) extractor() authkit.Extractor { + if o.Extractor != nil { + return o.Extractor + } + return authkit.BearerExtractor() +} + +func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) { + if o.OnUnauth != nil { + return o.OnUnauth + } + return defaultJSONError(http.StatusUnauthorized) +} + +func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) { + if o.OnForbidden != nil { + return o.OnForbidden + } + return defaultJSONError(http.StatusForbidden) +} + +func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) { + return func(w http.ResponseWriter, _ *http.Request, err error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": http.StatusText(status), + }) + } +} + +// RequireSession authenticates the request via an opaque session string. The +// extractor is consulted first; if no extractor is set the default Bearer +// extractor is used. 
For cookie-based session lookup, set +// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName). +func RequireSession(opts Options) func(http.Handler) http.Handler { + return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { + return opts.Auth.AuthenticateSession(r.Context(), raw) + }) +} + +// RequireJWT authenticates the request via an HS256 JWT. +func RequireJWT(opts Options) func(http.Handler) http.Handler { + return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { + return opts.Auth.AuthenticateJWT(r.Context(), raw) + }) +} + +// RequireAPIKey authenticates the request via an opaque API secret. +func RequireAPIKey(opts Options) func(http.Handler) http.Handler { + return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) { + return opts.Auth.AuthenticateAPIKey(r.Context(), raw) + }) +} + +// RequireAny tries each method in order until one succeeds. Useful for routes +// that accept either a session cookie or an API key. 
+func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler { + if len(methods) == 0 { + methods = []authkit.AuthMethod{ + authkit.AuthMethodSession, + authkit.AuthMethodJWT, + authkit.AuthMethodAPIKey, + } + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, ok := opts.extractor()(r) + if !ok || raw == "" { + opts.onUnauth()(w, r, authkit.ErrSessionInvalid) + return + } + var ( + p *authkit.Principal + lastErr error + ) + for _, m := range methods { + switch m { + case authkit.AuthMethodSession: + p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw) + case authkit.AuthMethodJWT: + p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw) + case authkit.AuthMethodAPIKey: + p, lastErr = opts.Auth.AuthenticateAPIKey(r.Context(), raw) + } + if lastErr == nil && p != nil { + next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p))) + return + } + } + opts.onUnauth()(w, r, lastErr) + }) + } +} + +// requireWith is the shared scaffolding for the single-method Require* +// middlewares. 
+func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler { + if opts.Auth == nil { + panic("authkit/middleware: Options.Auth is required") + } + extractor := opts.extractor() + onUnauth := opts.onUnauth() + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, ok := extractor(r) + if !ok || raw == "" { + onUnauth(w, r, authkit.ErrSessionInvalid) + return + } + p, err := authn(r, raw) + if err != nil { + onUnauth(w, r, err) + return + } + next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p))) + }) + } +} diff --git a/models.go b/models.go new file mode 100644 index 0000000..9b31616 --- /dev/null +++ b/models.go @@ -0,0 +1,75 @@ +package authkit + +import ( + "net/netip" + "time" + + "github.com/google/uuid" +) + +type User struct { + ID uuid.UUID + Email string + EmailNormalized string + EmailVerifiedAt *time.Time + PasswordHash string + SessionVersion int + FailedLogins int + LastLoginAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type Session struct { + IDHash []byte + UserID uuid.UUID + UserAgent string + IP netip.Addr + CreatedAt time.Time + LastSeenAt time.Time + ExpiresAt time.Time +} + +type TokenKind string + +const ( + TokenEmailVerify TokenKind = "email_verify" + TokenPasswordReset TokenKind = "password_reset" + TokenMagicLink TokenKind = "magic_link" + TokenRefresh TokenKind = "refresh" +) + +type Token struct { + Hash []byte + Kind TokenKind + UserID uuid.UUID + ChainID *string + ConsumedAt *time.Time + CreatedAt time.Time + ExpiresAt time.Time +} + +type APIKey struct { + IDHash []byte + OwnerID uuid.UUID + Name string + Abilities []string + LastUsedAt *time.Time + CreatedAt time.Time + ExpiresAt *time.Time + RevokedAt *time.Time +} + +type Role struct { + ID uuid.UUID + Name string + Description string + CreatedAt time.Time +} + +type Permission struct { + ID uuid.UUID + Name 
string + Description string + CreatedAt time.Time +} diff --git a/principal.go b/principal.go new file mode 100644 index 0000000..590dd9b --- /dev/null +++ b/principal.go @@ -0,0 +1,63 @@ +package authkit + +import ( + "time" + + "github.com/google/uuid" +) + +type AuthMethod string + +const ( + AuthMethodSession AuthMethod = "session" + AuthMethodJWT AuthMethod = "jwt" + AuthMethodAPIKey AuthMethod = "api_key" +) + +type Principal struct { + UserID uuid.UUID + Method AuthMethod + SessionID []byte + APIKeyID []byte + Roles []string + Permissions []string + Abilities []string + IssuedAt time.Time + ExpiresAt time.Time +} + +func (p *Principal) HasRole(name string) bool { + for _, r := range p.Roles { + if r == name { + return true + } + } + return false +} + +func (p *Principal) HasAnyRole(names ...string) bool { + for _, n := range names { + if p.HasRole(n) { + return true + } + } + return false +} + +func (p *Principal) HasPermission(name string) bool { + for _, perm := range p.Permissions { + if perm == name { + return true + } + } + return false +} + +func (p *Principal) HasAbility(name string) bool { + for _, a := range p.Abilities { + if a == name { + return true + } + } + return false +} diff --git a/service_apikey.go b/service_apikey.go new file mode 100644 index 0000000..045e063 --- /dev/null +++ b/service_apikey.go @@ -0,0 +1,92 @@ +package authkit + +import ( + "context" + "time" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// IssueAPIKey mints a fresh API secret with the given abilities and an +// optional TTL. The plaintext is returned to the caller (show-once) and the +// SHA-256 lookup hash is stored. Pass ttl=nil for a non-expiring key. 
+func (a *Auth) IssueAPIKey(ctx context.Context, ownerID uuid.UUID, name string, abilities []string, ttl *time.Duration) (string, *APIKey, error) { + const op = "authkit.Auth.IssueAPIKey" + plaintext, hash, err := mintSecret(prefixAPIKey, a.cfg.Random) + if err != nil { + return "", nil, errx.Wrap(op, err) + } + now := a.now() + k := &APIKey{ + IDHash: hash, + OwnerID: ownerID, + Name: name, + Abilities: append([]string(nil), abilities...), + CreatedAt: now, + } + if ttl != nil { + exp := now.Add(*ttl) + k.ExpiresAt = &exp + } + if err := a.deps.APIKeys.CreateAPIKey(ctx, k); err != nil { + return "", nil, errx.Wrap(op, err) + } + return plaintext, k, nil +} + +// AuthenticateAPIKey validates an API secret string, touches last_used_at +// (best-effort), and returns a Principal carrying the key's abilities. The +// owning user's roles+permissions are also resolved so the same Principal +// can satisfy RequireRole / RequirePermission middleware. +func (a *Auth) AuthenticateAPIKey(ctx context.Context, plaintext string) (*Principal, error) { + const op = "authkit.Auth.AuthenticateAPIKey" + hash, ok := parseSecret(prefixAPIKey, plaintext) + if !ok { + return nil, errx.Wrap(op, ErrAPIKeyInvalid) + } + k, err := a.deps.APIKeys.GetAPIKey(ctx, hash) + if err != nil { + return nil, errx.Wrap(op, err) + } + now := a.now() + if k.RevokedAt != nil { + return nil, errx.Wrap(op, ErrAPIKeyInvalid) + } + if k.ExpiresAt != nil && !k.ExpiresAt.After(now) { + return nil, errx.Wrap(op, ErrAPIKeyInvalid) + } + _ = a.deps.APIKeys.TouchAPIKey(ctx, hash, now) + + roles, perms, err := a.resolveRolesAndPermissions(ctx, k.OwnerID) + if err != nil { + return nil, errx.Wrap(op, err) + } + expires := now + if k.ExpiresAt != nil { + expires = *k.ExpiresAt + } + return &Principal{ + UserID: k.OwnerID, + Method: AuthMethodAPIKey, + APIKeyID: hash, + Roles: roles, + Permissions: perms, + Abilities: append([]string(nil), k.Abilities...), + IssuedAt: k.CreatedAt, + ExpiresAt: expires, + }, nil +} + +// 
RevokeAPIKey marks a key revoked. Idempotent on already-revoked keys. +func (a *Auth) RevokeAPIKey(ctx context.Context, plaintext string) error { + const op = "authkit.Auth.RevokeAPIKey" + hash, ok := parseSecret(prefixAPIKey, plaintext) + if !ok { + return errx.Wrap(op, ErrAPIKeyInvalid) + } + if err := a.deps.APIKeys.RevokeAPIKey(ctx, hash, a.now()); err != nil { + return errx.Wrap(op, err) + } + return nil +} diff --git a/service_authz.go b/service_authz.go new file mode 100644 index 0000000..7a5582c --- /dev/null +++ b/service_authz.go @@ -0,0 +1,85 @@ +package authkit + +import ( + "context" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// UserPermissions returns the union of permission names a user holds via +// their assigned roles. Resolved at call time; v1 does not cache. +func (a *Auth) UserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) { + const op = "authkit.Auth.UserPermissions" + perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID) + if err != nil { + return nil, errx.Wrap(op, err) + } + out := make([]string, len(perms)) + for i, p := range perms { + out[i] = p.Name + } + return out, nil +} + +// HasPermission checks whether a user holds the named permission via any +// assigned role. +func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, name string) (bool, error) { + const op = "authkit.Auth.HasPermission" + perms, err := a.UserPermissions(ctx, userID) + if err != nil { + return false, errx.Wrap(op, err) + } + for _, p := range perms { + if p == name { + return true, nil + } + } + return false, nil +} + +// HasRole checks whether a user is assigned the named role. 
+func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, name string) (bool, error) { + const op = "authkit.Auth.HasRole" + ok, err := a.deps.Roles.HasAnyRole(ctx, userID, []string{name}) + if err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// HasAnyRole checks whether a user holds at least one of the named roles. +func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) { + const op = "authkit.Auth.HasAnyRole" + ok, err := a.deps.Roles.HasAnyRole(ctx, userID, names) + if err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// AssignRole is a convenience that looks up a role by name and assigns it. +func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleName string) error { + const op = "authkit.Auth.AssignRole" + r, err := a.deps.Roles.GetRoleByName(ctx, roleName) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Roles.AssignRoleToUser(ctx, userID, r.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RemoveRole is the symmetric helper for AssignRole. +func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleName string) error { + const op = "authkit.Auth.RemoveRole" + r, err := a.deps.Roles.GetRoleByName(ctx, roleName) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Roles.RemoveRoleFromUser(ctx, userID, r.ID); err != nil { + return errx.Wrap(op, err) + } + return nil +} diff --git a/service_jwt.go b/service_jwt.go new file mode 100644 index 0000000..d3133b2 --- /dev/null +++ b/service_jwt.go @@ -0,0 +1,147 @@ +package authkit + +import ( + "context" + "errors" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// IssueJWT issues a fresh access JWT and a rotating opaque refresh token. +// The refresh token is bound to a chain via Token.ChainID; rotation +// preserves that chain so reuse-detection can revoke the whole family. 
+func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) { + const op = "authkit.Auth.IssueJWT" + u, err := a.deps.Users.GetUserByID(ctx, userID) + if err != nil { + return "", "", errx.Wrap(op, err) + } + access, err = a.signAccessToken(u.ID, u.SessionVersion) + if err != nil { + return "", "", errx.Wrap(op, err) + } + refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString()) + if err != nil { + return "", "", errx.Wrap(op, err) + } + return access, refresh, nil +} + +// AuthenticateJWT validates the access JWT, cross-checks the user's +// session_version (instant revocation), and resolves a Principal. +func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, error) { + const op = "authkit.Auth.AuthenticateJWT" + claims, err := a.parseAccessToken(access) + if err != nil { + return nil, errx.Wrap(op, err) + } + uid, err := uuid.Parse(claims.Subject) + if err != nil { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + u, err := a.deps.Users.GetUserByID(ctx, uid) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + return nil, errx.Wrap(op, err) + } + if u.SessionVersion != claims.SessionVersion { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + roles, perms, err := a.resolveRolesAndPermissions(ctx, u.ID) + if err != nil { + return nil, errx.Wrap(op, err) + } + return &Principal{ + UserID: u.ID, + Method: AuthMethodJWT, + Roles: roles, + Permissions: perms, + IssuedAt: claims.IssuedAt.Time, + ExpiresAt: claims.ExpiresAt.Time, + }, nil +} + +// RefreshJWT consumes the presented refresh token and mints a new access + +// refresh pair. Reuse of an already-consumed refresh token deletes the +// entire chain (logout-everywhere on that device family) and returns +// ErrTokenReused. 
+func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) { + const op = "authkit.Auth.RefreshJWT" + hash, ok := parseSecret(prefixRefresh, plaintextRefresh) + if !ok { + return "", "", errx.Wrap(op, ErrTokenInvalid) + } + now := a.now() + + consumed, err := a.deps.Tokens.ConsumeToken(ctx, TokenRefresh, hash, now) + if err != nil { + // Differentiate plain-invalid (never existed / expired) from + // reuse (existed, already consumed). The presence-check below is + // the reuse signal. + if errors.Is(err, ErrTokenInvalid) { + if existing, gerr := a.deps.Tokens.GetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil { + if existing.ChainID != nil && *existing.ChainID != "" { + _, _ = a.deps.Tokens.DeleteByChain(ctx, *existing.ChainID) + } + return "", "", errx.Wrap(op, ErrTokenReused) + } + } + return "", "", errx.Wrap(op, err) + } + + var chainID string + if consumed.ChainID != nil { + chainID = *consumed.ChainID + } + if chainID == "" { + // Defensive: every refresh token should be chain-bound. Fall back + // to a fresh chain so we never throw on missing metadata. + chainID = uuid.NewString() + } + + access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID)) + if err != nil { + return "", "", errx.Wrap(op, err) + } + refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID) + if err != nil { + return "", "", errx.Wrap(op, err) + } + return access, refresh, nil +} + +// mintRefreshToken stores a fresh refresh token bound to chainID and returns +// the plaintext. 
+func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) { + const op = "authkit.Auth.mintRefreshToken" + plaintext, hash, err := mintSecret(prefixRefresh, a.cfg.Random) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + t := &Token{ + Hash: hash, + Kind: TokenRefresh, + UserID: userID, + ChainID: &chainID, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.RefreshTokenTTL), + } + if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return plaintext, nil +} + +// userSessionVersion fetches the current session_version. Errors collapse to +// 0 on the assumption that AuthenticateJWT will reject stale tokens cleanly +// — but we still need a value to embed in the freshly-minted access token. +func (a *Auth) userSessionVersion(ctx context.Context, userID uuid.UUID) int { + if u, err := a.deps.Users.GetUserByID(ctx, userID); err == nil { + return u.SessionVersion + } + return 0 +} diff --git a/service_magic.go b/service_magic.go new file mode 100644 index 0000000..d200ca5 --- /dev/null +++ b/service_magic.go @@ -0,0 +1,62 @@ +package authkit + +import ( + "context" + + "git.juancwu.dev/juancwu/errx" +) + +// RequestMagicLink mints a single-use magic-link token for the email and +// returns the plaintext for delivery. ErrUserNotFound is returned for +// unregistered emails. 
+func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, error) { + const op = "authkit.Auth.RequestMagicLink" + u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + return "", errx.Wrap(op, err) + } + plaintext, hash, err := mintSecret(prefixMagicLink, a.cfg.Random) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + t := &Token{ + Hash: hash, + Kind: TokenMagicLink, + UserID: u.ID, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.MagicLinkTTL), + } + if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return plaintext, nil +} + +// ConsumeMagicLink consumes the magic-link token and returns the +// authenticated user. Callers typically follow this with IssueSession or +// IssueJWT to actually log the user in. +func (a *Auth) ConsumeMagicLink(ctx context.Context, plaintextToken string) (*User, error) { + const op = "authkit.Auth.ConsumeMagicLink" + hash, ok := parseSecret(prefixMagicLink, plaintextToken) + if !ok { + return nil, errx.Wrap(op, ErrTokenInvalid) + } + now := a.now() + t, err := a.deps.Tokens.ConsumeToken(ctx, TokenMagicLink, hash, now) + if err != nil { + return nil, errx.Wrap(op, err) + } + u, err := a.deps.Users.GetUserByID(ctx, t.UserID) + if err != nil { + return nil, errx.Wrap(op, err) + } + // A successful magic-link login also implicitly verifies the email + // (the user demonstrably controls the inbox). 
+ if u.EmailVerifiedAt == nil { + if err := a.deps.Users.SetEmailVerified(ctx, u.ID, now); err == nil { + u.EmailVerifiedAt = &now + } + } + return u, nil +} diff --git a/service_reset.go b/service_reset.go new file mode 100644 index 0000000..a2ebbb6 --- /dev/null +++ b/service_reset.go @@ -0,0 +1,69 @@ +package authkit + +import ( + "context" + "errors" + + "git.juancwu.dev/juancwu/errx" +) + +// RequestPasswordReset mints a single-use password-reset token for the user +// behind email and returns the plaintext for the caller to deliver via email. +// Returns ErrUserNotFound when the email isn't registered (per project +// policy of distinct errors over anti-enumeration). +func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, error) { + const op = "authkit.Auth.RequestPasswordReset" + u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + return "", errx.Wrap(op, err) + } + plaintext, hash, err := mintSecret(prefixPasswordRset, a.cfg.Random) + if err != nil { + return "", errx.Wrap(op, err) + } + now := a.now() + t := &Token{ + Hash: hash, + Kind: TokenPasswordReset, + UserID: u.ID, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.PasswordResetTTL), + } + if err := a.deps.Tokens.CreateToken(ctx, t); err != nil { + return "", errx.Wrap(op, err) + } + return plaintext, nil +} + +// ConfirmPasswordReset consumes the reset token, sets the new password, +// bumps the user's session_version, and revokes outstanding sessions so the +// reset constitutes a global logout. 
+func (a *Auth) ConfirmPasswordReset(ctx context.Context, plaintextToken, newPassword string) error { + const op = "authkit.Auth.ConfirmPasswordReset" + hash, ok := parseSecret(prefixPasswordRset, plaintextToken) + if !ok { + return errx.Wrap(op, ErrTokenInvalid) + } + now := a.now() + t, err := a.deps.Tokens.ConsumeToken(ctx, TokenPasswordReset, hash, now) + if err != nil { + return errx.Wrap(op, err) + } + newHash, err := a.deps.Hasher.Hash(newPassword) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Users.SetPassword(ctx, t.UserID, newHash); err != nil { + if errors.Is(err, ErrUserNotFound) { + return errx.Wrap(op, ErrUserNotFound) + } + return errx.Wrap(op, err) + } + if _, err := a.deps.Users.BumpSessionVersion(ctx, t.UserID); err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Sessions.DeleteUserSessions(ctx, t.UserID); err != nil { + return errx.Wrap(op, err) + } + return nil +} diff --git a/service_session.go b/service_session.go new file mode 100644 index 0000000..a6573d6 --- /dev/null +++ b/service_session.go @@ -0,0 +1,154 @@ +package authkit + +import ( + "context" + "net/http" + "net/netip" + "time" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// IssueSession mints an opaque session ID, persists the session record, and +// returns the plaintext (for the cookie) plus the stored Session. 
+func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent string, ip netip.Addr) (string, *Session, error) { + const op = "authkit.Auth.IssueSession" + plaintext, hash, err := mintSecret(prefixSession, a.cfg.Random) + if err != nil { + return "", nil, errx.Wrap(op, err) + } + now := a.now() + expires := now.Add(a.cfg.SessionIdleTTL) + if cap := now.Add(a.cfg.SessionAbsoluteTTL); expires.After(cap) { + expires = cap + } + s := &Session{ + IDHash: hash, + UserID: userID, + UserAgent: userAgent, + IP: ip, + CreatedAt: now, + LastSeenAt: now, + ExpiresAt: expires, + } + if err := a.deps.Sessions.CreateSession(ctx, s); err != nil { + return "", nil, errx.Wrap(op, err) + } + return plaintext, s, nil +} + +// AuthenticateSession validates an opaque session string, slides the TTL, +// resolves the user's roles+permissions, and returns a Principal. Expired or +// unknown sessions return ErrSessionInvalid. +func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Principal, error) { + const op = "authkit.Auth.AuthenticateSession" + hash, ok := parseSecret(prefixSession, plaintext) + if !ok { + return nil, errx.Wrap(op, ErrSessionInvalid) + } + s, err := a.deps.Sessions.GetSession(ctx, hash) + if err != nil { + return nil, errx.Wrap(op, err) + } + now := a.now() + if !s.ExpiresAt.After(now) { + _ = a.deps.Sessions.DeleteSession(ctx, hash) + return nil, errx.Wrap(op, ErrSessionInvalid) + } + + // Slide the idle TTL, capped at created_at + AbsoluteTTL so an active + // session still expires at the absolute boundary. 
+ newExpires := now.Add(a.cfg.SessionIdleTTL) + if cap := s.CreatedAt.Add(a.cfg.SessionAbsoluteTTL); newExpires.After(cap) { + newExpires = cap + } + if err := a.deps.Sessions.TouchSession(ctx, hash, now, newExpires); err != nil { + return nil, errx.Wrap(op, err) + } + + roles, perms, err := a.resolveRolesAndPermissions(ctx, s.UserID) + if err != nil { + return nil, errx.Wrap(op, err) + } + return &Principal{ + UserID: s.UserID, + Method: AuthMethodSession, + SessionID: hash, + Roles: roles, + Permissions: perms, + IssuedAt: s.CreatedAt, + ExpiresAt: newExpires, + }, nil +} + +// RevokeSession deletes a single session by its plaintext id. Idempotent: +// missing sessions are not an error (logout twice should not 500). +func (a *Auth) RevokeSession(ctx context.Context, plaintext string) error { + const op = "authkit.Auth.RevokeSession" + hash, ok := parseSecret(prefixSession, plaintext) + if !ok { + return nil + } + if err := a.deps.Sessions.DeleteSession(ctx, hash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RevokeAllUserSessions kills every active session for the user and bumps +// the user's session_version (invalidating outstanding JWT access tokens). +func (a *Auth) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.Auth.RevokeAllUserSessions" + if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// SessionCookie builds an *http.Cookie pre-configured from Config. Pass the +// plaintext returned by IssueSession; pass the matching ExpiresAt from the +// returned *Session as `expires`. To clear a cookie at logout, pass an empty +// plaintext and a past expiry. 
+func (a *Auth) SessionCookie(plaintext string, expires time.Time) *http.Cookie { + c := &http.Cookie{ + Name: a.cfg.SessionCookieName, + Value: plaintext, + Path: a.cfg.SessionCookiePath, + Domain: a.cfg.SessionCookieDomain, + Secure: a.cfg.SessionCookieSecure, + HttpOnly: a.cfg.SessionCookieHTTPOnly, + SameSite: a.cfg.SessionCookieSameSite, + Expires: expires, + } + if plaintext == "" { + c.MaxAge = -1 + } + return c +} + +// resolveRolesAndPermissions fetches the user's role names and the union of +// their permission names. Both are returned as flat string slices for cheap +// containment checks on the Principal. +func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) { + roles, err := a.deps.Roles.GetUserRoles(ctx, userID) + if err != nil { + return nil, nil, err + } + perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID) + if err != nil { + return nil, nil, err + } + rNames := make([]string, len(roles)) + for i, r := range roles { + rNames[i] = r.Name + } + pNames := make([]string, len(perms)) + for i, p := range perms { + pNames[i] = p.Name + } + return rNames, pNames, nil +} diff --git a/service_test.go b/service_test.go new file mode 100644 index 0000000..868ec1b --- /dev/null +++ b/service_test.go @@ -0,0 +1,220 @@ +package authkit + +import ( + "context" + "errors" + "net/netip" + "testing" +) + +func TestRegisterAndLogin(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "Alice@Example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + if u.EmailNormalized != "alice@example.com" { + t.Fatalf("email_normalized = %q", u.EmailNormalized) + } + got, err := a.LoginPassword(context.Background(), "alice@example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("LoginPassword: %v", err) + } + if got.ID != u.ID { + t.Fatalf("login user mismatch") + } +} + +func TestRegisterDuplicateEmail(t *testing.T) { + a := newTestAuth(t) + if _, 
err := a.Register(context.Background(), "x@y.com", "abc"); err != nil { + t.Fatalf("Register: %v", err) + } + _, err := a.Register(context.Background(), "X@Y.COM", "abc") + if !errors.Is(err, ErrEmailTaken) { + t.Fatalf("expected ErrEmailTaken, got %v", err) + } +} + +func TestLoginWrongPassword(t *testing.T) { + a := newTestAuth(t) + if _, err := a.Register(context.Background(), "p@q.com", "right"); err != nil { + t.Fatalf("Register: %v", err) + } + if _, err := a.LoginPassword(context.Background(), "p@q.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) { + t.Fatalf("expected ErrInvalidCredentials, got %v", err) + } +} + +func TestSessionIssueAuthenticateRevoke(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "s@s.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plain, sess, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{}) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + if sess == nil || plain == "" { + t.Fatalf("missing session or plaintext") + } + + p, err := a.AuthenticateSession(context.Background(), plain) + if err != nil { + t.Fatalf("AuthenticateSession: %v", err) + } + if p.UserID != u.ID { + t.Fatalf("principal user id mismatch") + } + if err := a.RevokeSession(context.Background(), plain); err != nil { + t.Fatalf("RevokeSession: %v", err) + } + if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) { + t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err) + } +} + +func TestEmailVerificationFlow(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "ev@e.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + tok, err := a.RequestEmailVerification(context.Background(), u.ID) + if err != nil { + t.Fatalf("RequestEmailVerification: %v", err) + } + confirmed, err := a.ConfirmEmail(context.Background(), tok) + if err != nil { + t.Fatalf("ConfirmEmail: %v", err) + } + if 
confirmed.EmailVerifiedAt == nil { + t.Fatalf("email_verified_at not set") + } + // Re-using the token must fail. + if _, err := a.ConfirmEmail(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err) + } +} + +func TestPasswordResetFlow(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "r@r.com", "old") + if err != nil { + t.Fatalf("Register: %v", err) + } + // Issue a session that should be invalidated by the reset. + plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{}) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + tok, err := a.RequestPasswordReset(context.Background(), "r@r.com") + if err != nil { + t.Fatalf("RequestPasswordReset: %v", err) + } + if err := a.ConfirmPasswordReset(context.Background(), tok, "new"); err != nil { + t.Fatalf("ConfirmPasswordReset: %v", err) + } + if _, err := a.LoginPassword(context.Background(), "r@r.com", "old"); !errors.Is(err, ErrInvalidCredentials) { + t.Fatalf("old password should fail post-reset, got %v", err) + } + if _, err := a.LoginPassword(context.Background(), "r@r.com", "new"); err != nil { + t.Fatalf("new password should work post-reset, got %v", err) + } + if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) { + t.Fatalf("session should be invalidated by reset, got %v", err) + } +} + +func TestMagicLinkFlow(t *testing.T) { + a := newTestAuth(t) + if _, err := a.Register(context.Background(), "m@m.com", "pw"); err != nil { + t.Fatalf("Register: %v", err) + } + tok, err := a.RequestMagicLink(context.Background(), "m@m.com") + if err != nil { + t.Fatalf("RequestMagicLink: %v", err) + } + u, err := a.ConsumeMagicLink(context.Background(), tok) + if err != nil { + t.Fatalf("ConsumeMagicLink: %v", err) + } + if u.EmailVerifiedAt == nil { + t.Fatalf("magic link should imply email verification") + } + if _, err := 
a.ConsumeMagicLink(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid on magic link reuse, got %v", err) + } +} + +func TestAPIKeyFlowWithAbilities(t *testing.T) { + a := newTestAuth(t) + u, err := a.Register(context.Background(), "k@k.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plaintext, k, err := a.IssueAPIKey(context.Background(), u.ID, "ci", []string{"billing:read", "users:list"}, nil) + if err != nil { + t.Fatalf("IssueAPIKey: %v", err) + } + if k == nil || plaintext == "" { + t.Fatalf("missing api key") + } + p, err := a.AuthenticateAPIKey(context.Background(), plaintext) + if err != nil { + t.Fatalf("AuthenticateAPIKey: %v", err) + } + if !p.HasAbility("billing:read") || !p.HasAbility("users:list") { + t.Fatalf("abilities missing on principal: %+v", p.Abilities) + } + if p.HasAbility("admin:nuke") { + t.Fatalf("unexpected ability granted") + } + if err := a.RevokeAPIKey(context.Background(), plaintext); err != nil { + t.Fatalf("RevokeAPIKey: %v", err) + } + if _, err := a.AuthenticateAPIKey(context.Background(), plaintext); !errors.Is(err, ErrAPIKeyInvalid) { + t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err) + } +} + +func TestRBACRolesAndPermissions(t *testing.T) { + ctx := context.Background() + a := newTestAuth(t) + u, err := a.Register(ctx, "rb@a.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + + // Create role + permission, hook them up. 
+ role := &Role{Name: "editor"} + if err := a.deps.Roles.CreateRole(ctx, role); err != nil { + t.Fatalf("CreateRole: %v", err) + } + perm := &Permission{Name: "posts:write"} + if err := a.deps.Permissions.CreatePermission(ctx, perm); err != nil { + t.Fatalf("CreatePermission: %v", err) + } + if err := a.deps.Permissions.AssignPermissionToRole(ctx, role.ID, perm.ID); err != nil { + t.Fatalf("AssignPermissionToRole: %v", err) + } + if err := a.AssignRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("AssignRole: %v", err) + } + ok, err := a.HasPermission(ctx, u.ID, "posts:write") + if err != nil || !ok { + t.Fatalf("HasPermission posts:write should be true, got %v %v", ok, err) + } + ok, err = a.HasRole(ctx, u.ID, "editor") + if err != nil || !ok { + t.Fatalf("HasRole editor should be true, got %v %v", ok, err) + } + if err := a.RemoveRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("RemoveRole: %v", err) + } + ok, _ = a.HasPermission(ctx, u.ID, "posts:write") + if ok { + t.Fatalf("HasPermission should be false after RemoveRole") + } +} diff --git a/service_user.go b/service_user.go new file mode 100644 index 0000000..b888897 --- /dev/null +++ b/service_user.go @@ -0,0 +1,178 @@ +package authkit + +import ( + "context" + "errors" + "strings" + + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +// normalizeEmail produces the lookup form used by UserStore.GetUserByEmail +// and the email_normalized column. Trim + lowercase is intentional; we do +// not collapse Gmail-style "+" addressing or strip dots — that's a policy +// decision callers can layer on top. +func normalizeEmail(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} + +// Register creates a new user with an Argon2id-hashed password. Returns +// ErrEmailTaken if the normalized email is already registered. 
+func (a *Auth) Register(ctx context.Context, email, password string) (*User, error) { + const op = "authkit.Auth.Register" + if email == "" || password == "" { + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + hash, err := a.deps.Hasher.Hash(password) + if err != nil { + return nil, errx.Wrap(op, err) + } + now := a.now() + u := &User{ + ID: uuid.New(), + Email: email, + EmailNormalized: normalizeEmail(email), + PasswordHash: hash, + CreatedAt: now, + UpdatedAt: now, + } + if err := a.deps.Users.CreateUser(ctx, u); err != nil { + return nil, errx.Wrap(op, err) + } + return u, nil +} + +// LoginPassword verifies the password and returns the authenticated user. +// Failure increments failed_logins; success resets it and stamps last_login_at. +// LoginHook (if configured) is invoked with the success outcome — use this to +// hook in rate limiting or audit logging. +func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User, error) { + const op = "authkit.Auth.LoginPassword" + u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email)) + if err != nil { + _ = a.fireLoginHook(ctx, email, false) + if errors.Is(err, ErrUserNotFound) { + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + return nil, errx.Wrap(op, err) + } + if u.PasswordHash == "" { + _ = a.fireLoginHook(ctx, email, false) + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + ok, needsRehash, err := a.deps.Hasher.Verify(password, u.PasswordHash) + if err != nil { + return nil, errx.Wrap(op, err) + } + if !ok { + _, _ = a.deps.Users.IncrementFailedLogins(ctx, u.ID) + _ = a.fireLoginHook(ctx, email, false) + return nil, errx.Wrap(op, ErrInvalidCredentials) + } + + now := a.now() + u.LastLoginAt = &now + u.FailedLogins = 0 + if err := a.deps.Users.ResetFailedLogins(ctx, u.ID); err != nil { + return nil, errx.Wrap(op, err) + } + if err := a.deps.Users.UpdateUser(ctx, u); err != nil { + return nil, errx.Wrap(op, err) + } + + if needsRehash { + if newHash, herr := 
a.deps.Hasher.Hash(password); herr == nil { + _ = a.deps.Users.SetPassword(ctx, u.ID, newHash) + u.PasswordHash = newHash + } + } + _ = a.fireLoginHook(ctx, email, true) + return u, nil +} + +// ChangePassword verifies the current password, sets the new one, and bumps +// the user's session_version so all outstanding JWT access tokens are +// instantly invalidated. Outstanding opaque sessions are also revoked. +func (a *Auth) ChangePassword(ctx context.Context, userID uuid.UUID, oldPassword, newPassword string) error { + const op = "authkit.Auth.ChangePassword" + u, err := a.deps.Users.GetUserByID(ctx, userID) + if err != nil { + return errx.Wrap(op, err) + } + if u.PasswordHash == "" { + return errx.Wrap(op, ErrInvalidCredentials) + } + ok, _, err := a.deps.Hasher.Verify(oldPassword, u.PasswordHash) + if err != nil { + return errx.Wrap(op, err) + } + if !ok { + return errx.Wrap(op, ErrInvalidCredentials) + } + newHash, err := a.deps.Hasher.Hash(newPassword) + if err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Users.SetPassword(ctx, userID, newHash); err != nil { + return errx.Wrap(op, err) + } + if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// RequestEmailVerification mints a single-use email-verify token for the +// user. Return the plaintext to the caller so they can put it in an email +// link; the lookup hash is stored in TokenStore. 
+func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (string, error) {
+	const op = "authkit.Auth.RequestEmailVerification"
+	plaintext, hash, err := mintSecret(prefixEmailVerify, a.cfg.Random)
+	if err != nil {
+		return "", errx.Wrap(op, err)
+	}
+	now := a.now()
+	t := &Token{
+		Hash:      hash,
+		Kind:      TokenEmailVerify,
+		UserID:    userID,
+		CreatedAt: now,
+		ExpiresAt: now.Add(a.cfg.EmailVerifyTTL),
+	}
+	if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
+		return "", errx.Wrap(op, err)
+	}
+	return plaintext, nil
+}
+
+// ConfirmEmail consumes the verification token and marks the user's email
+// verified. Returns ErrTokenInvalid if the token is missing/expired/used.
+func (a *Auth) ConfirmEmail(ctx context.Context, plaintextToken string) (*User, error) {
+	const op = "authkit.Auth.ConfirmEmail"
+	hash, ok := parseSecret(prefixEmailVerify, plaintextToken)
+	if !ok {
+		return nil, errx.Wrap(op, ErrTokenInvalid)
+	}
+	now := a.now()
+	t, err := a.deps.Tokens.ConsumeToken(ctx, TokenEmailVerify, hash, now)
+	if err != nil {
+		return nil, errx.Wrap(op, err)
+	}
+	if err := a.deps.Users.SetEmailVerified(ctx, t.UserID, now); err != nil {
+		return nil, errx.Wrap(op, err)
+	}
+	return a.deps.Users.GetUserByID(ctx, t.UserID)
+}
+
+// fireLoginHook invokes the caller-supplied LoginHook when one is configured
+// and is a no-op (returns nil) otherwise. NOTE(review): it does not recover
+// panics from the hook — either add a recover here or keep hooks panic-free;
+// a misbehaving telemetry hook should never break login.
+func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) error { + if a.cfg.LoginHook == nil { + return nil + } + return a.cfg.LoginHook(ctx, email, success) +} diff --git a/sqlstore/apikeys.go b/sqlstore/apikeys.go new file mode 100644 index 0000000..1212c20 --- /dev/null +++ b/sqlstore/apikeys.go @@ -0,0 +1,124 @@ +package sqlstore + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type apiKeyStore struct{ storeBase } + +func (s *apiKeyStore) CreateAPIKey(ctx context.Context, k *authkit.APIKey) error { + const op = "authkit.sqlstore.APIKeyStore.CreateAPIKey" + if k.CreatedAt.IsZero() { + k.CreatedAt = time.Now().UTC() + } + if k.Abilities == nil { + k.Abilities = []string{} + } + abilities, err := json.Marshal(k.Abilities) + if err != nil { + return errx.Wrap(op, err) + } + _, err = s.db.ExecContext(ctx, s.q.CreateAPIKey, + k.IDHash, uuidArg(k.OwnerID), k.Name, abilities, + nullableTime(k.LastUsedAt), k.CreatedAt, + nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt)) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *apiKeyStore) GetAPIKey(ctx context.Context, idHash []byte) (*authkit.APIKey, error) { + const op = "authkit.sqlstore.APIKeyStore.GetAPIKey" + k, err := scanAPIKey(s.db.QueryRowContext(ctx, s.q.GetAPIKey, idHash)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrAPIKeyInvalid)) + } + return k, nil +} + +func (s *apiKeyStore) ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*authkit.APIKey, error) { + const op = "authkit.sqlstore.APIKeyStore.ListAPIKeysByOwner" + rows, err := s.db.QueryContext(ctx, s.q.ListAPIKeysByOwner, uuidArg(ownerID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.APIKey + for rows.Next() { + k, err := scanAPIKey(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + 
out = append(out, k) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *apiKeyStore) TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error { + const op = "authkit.sqlstore.APIKeyStore.TouchAPIKey" + if _, err := s.db.ExecContext(ctx, s.q.TouchAPIKey, at, idHash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *apiKeyStore) RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error { + const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKey" + tag, err := s.db.ExecContext(ctx, s.q.RevokeAPIKey, at, idHash) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrAPIKeyInvalid) + } + return nil +} + +func (s *apiKeyStore) RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error { + const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKeysByOwner" + if _, err := s.db.ExecContext(ctx, s.q.RevokeAPIKeysByOwner, at, uuidArg(ownerID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func scanAPIKey(row rowScanner) (*authkit.APIKey, error) { + var ( + k authkit.APIKey + ownerIDStr string + abilitiesRaw []byte + lastUsed sql.NullTime + expires sql.NullTime + revoked sql.NullTime + ) + if err := row.Scan(&k.IDHash, &ownerIDStr, &k.Name, &abilitiesRaw, + &lastUsed, &k.CreatedAt, &expires, &revoked); err != nil { + return nil, err + } + owner, err := scanUUID(ownerIDStr) + if err != nil { + return nil, err + } + k.OwnerID = owner + if len(abilitiesRaw) > 0 { + if err := json.Unmarshal(abilitiesRaw, &k.Abilities); err != nil { + return nil, err + } + } + if k.Abilities == nil { + k.Abilities = []string{} + } + k.LastUsedAt = scanNullTimePtr(lastUsed) + k.ExpiresAt = scanNullTimePtr(expires) + k.RevokedAt = scanNullTimePtr(revoked) + return &k, nil +} diff --git a/sqlstore/dialect.go b/sqlstore/dialect.go new file mode 100644 index 0000000..4b6c3d8 --- /dev/null +++ b/sqlstore/dialect.go @@ -0,0 +1,118 @@ +package 
sqlstore + +import ( + "context" + "database/sql" + "io/fs" +) + +// Dialect describes how a particular SQL backend renders the queries authkit +// runs at runtime, bootstraps its session for migrations, and reports +// driver-specific errors. v1 ships the Postgres dialect; future MySQL or +// SQLite implementations satisfy the same interface without changes to the +// store code. +type Dialect interface { + // Name returns a stable short identifier ("postgres", "mysql", ...). + Name() string + + // BuildQueries renders every templated SQL string from a validated + // Schema. Called once at New() and at Migrate() so identifiers and + // placeholder styles are baked in by the time stores execute queries. + BuildQueries(s Schema) Queries + + // Bootstrap runs once per Migrate() before the migration lock is taken. + // It is the place for things that must happen outside any transaction + // (CREATE EXTENSION on Postgres, for example). Must be idempotent and + // may be a no-op. + Bootstrap(ctx context.Context, db *sql.DB) error + + // AcquireMigrationLock takes a session-scoped lock on conn so concurrent + // migrators serialise. The returned release function never returns an + // error — implementations may log internally. + AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error) + + // Migrations returns the dialect's embedded .sql files, rooted such + // that fs.ReadDir(".") lists them lex-sorted by filename. + Migrations() fs.FS + + // IsUniqueViolation maps a duplicate-key error from this driver to true + // so insert paths can return the matching authkit sentinel without + // depending on driver internals. + IsUniqueViolation(err error) bool + + // Placeholder returns the placeholder for parameter index n (1-based) + // in this dialect — "$1" for Postgres, "?" for MySQL/SQLite. 
+ Placeholder(n int) string + + // PlaceholderList returns a comma-separated placeholder list for `count` + // parameters starting at position `start` (1-based), suitable for + // dynamic IN-clauses. The second return is the same length, prefilled + // with nil so the caller can append actual values. + PlaceholderList(start, count int) string +} + +// Queries is the full set of statement templates authkit issues. Field names +// match the store method that consumes the query. +type Queries struct { + // users + CreateUser string + GetUserByID string + GetUserByEmail string + UpdateUser string + DeleteUser string + SetPassword string + SetEmailVerified string + BumpSessionVersion string + IncrementFailedLogins string + ResetFailedLogins string + + // sessions + CreateSession string + GetSession string + TouchSession string + DeleteSession string + DeleteUserSessions string + DeleteExpiredSessions string + + // tokens + CreateToken string + ConsumeToken string + GetToken string + DeleteByChain string + DeleteExpiredTokens string + + // api keys + CreateAPIKey string + GetAPIKey string + ListAPIKeysByOwner string + TouchAPIKey string + RevokeAPIKey string + RevokeAPIKeysByOwner string + + // roles + CreateRole string + GetRoleByID string + GetRoleByName string + ListRoles string + DeleteRole string + AssignRoleToUser string + RemoveRoleFromUser string + GetUserRoles string + // HasAnyRole is built at call time because the placeholder count varies. 
+ + // permissions + CreatePermission string + GetPermissionByID string + GetPermissionByName string + ListPermissions string + DeletePermission string + AssignPermissionToRole string + RemovePermissionFromRole string + GetRolePermissions string + GetUserPermissions string + + // migrations + CreateMigrationsTable string + SelectAppliedVersions string + InsertAppliedVersion string +} diff --git a/sqlstore/dialect/postgres/errors.go b/sqlstore/dialect/postgres/errors.go new file mode 100644 index 0000000..d4e93ad --- /dev/null +++ b/sqlstore/dialect/postgres/errors.go @@ -0,0 +1,33 @@ +package postgres + +import ( + "errors" + + "github.com/jackc/pgx/v5/pgconn" +) + +// pgUniqueViolation is the SQLSTATE for unique_violation. Both pgx-stdlib +// and lib/pq surface this code, but only pgx-stdlib uses *pgconn.PgError. +// lib/pq uses *pq.Error which has a Code field of the same value. +const pgUniqueViolation = "23505" + +// isUniqueViolation inspects err for a Postgres unique-violation, regardless +// of which driver registered the connection. We match on either the pgx +// error type or any error implementing a Code() string method (lib/pq's +// pq.Error has SQLState and Code fields; we check via reflection-free +// duck-typing through an interface). +func isUniqueViolation(err error) bool { + if err == nil { + return false + } + var pgxErr *pgconn.PgError + if errors.As(err, &pgxErr) { + return pgxErr.Code == pgUniqueViolation + } + type sqlStater interface{ SQLState() string } + var s sqlStater + if errors.As(err, &s) { + return s.SQLState() == pgUniqueViolation + } + return false +} diff --git a/sqlstore/dialect/postgres/migrations/0001_init.sql b/sqlstore/dialect/postgres/migrations/0001_init.sql new file mode 100644 index 0000000..e8f751d --- /dev/null +++ b/sqlstore/dialect/postgres/migrations/0001_init.sql @@ -0,0 +1,99 @@ +-- 0001_init.sql +-- Initial authkit schema for Postgres. 
Tables are prefixed authkit_ so the +-- library can be embedded in an existing application database. Each +-- migration owns its own transaction and inserts its version row at the +-- bottom; the runner only orchestrates file discovery and concurrency. + +BEGIN; + +CREATE TABLE IF NOT EXISTS authkit_schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_users ( + id UUID PRIMARY KEY, + email TEXT NOT NULL, + email_normalized TEXT NOT NULL, + email_verified_at TIMESTAMPTZ, + password_hash TEXT, + session_version INTEGER NOT NULL DEFAULT 0, + failed_logins INTEGER NOT NULL DEFAULT 0, + last_login_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq + ON authkit_users (email_normalized); + +CREATE TABLE IF NOT EXISTS authkit_sessions ( + id_hash BYTEA PRIMARY KEY, + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + user_agent TEXT NOT NULL DEFAULT '', + ip TEXT, + created_at TIMESTAMPTZ NOT NULL, + last_seen_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ NOT NULL +); +CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id); +CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at); + +CREATE TABLE IF NOT EXISTS authkit_tokens ( + hash BYTEA NOT NULL, + kind TEXT NOT NULL, + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + chain_id TEXT, + consumed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (kind, hash) +); +CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id); +CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at); +CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx + ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL; + +CREATE TABLE IF NOT EXISTS 
authkit_api_keys ( + id_hash BYTEA PRIMARY KEY, + owner_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + name TEXT NOT NULL, + abilities JSONB NOT NULL DEFAULT '[]'::jsonb, + last_used_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL, + expires_at TIMESTAMPTZ, + revoked_at TIMESTAMPTZ +); +CREATE INDEX IF NOT EXISTS authkit_api_keys_owner_id_idx ON authkit_api_keys(owner_id); + +CREATE TABLE IF NOT EXISTS authkit_roles ( + id UUID PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_permissions ( + id UUID PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL +); + +CREATE TABLE IF NOT EXISTS authkit_role_permissions ( + role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, + permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE, + PRIMARY KEY (role_id, permission_id) +); + +CREATE TABLE IF NOT EXISTS authkit_user_roles ( + user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, + role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE, + granted_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (user_id, role_id) +); +CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id); + +INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now()) +ON CONFLICT (version) DO NOTHING; + +COMMIT; diff --git a/sqlstore/dialect/postgres/postgres.go b/sqlstore/dialect/postgres/postgres.go new file mode 100644 index 0000000..56ec880 --- /dev/null +++ b/sqlstore/dialect/postgres/postgres.go @@ -0,0 +1,272 @@ +// Package postgres is the Postgres dialect for authkit/sqlstore. Importing +// it does not register a driver — callers do `_ "github.com/jackc/pgx/v5/stdlib"` +// or `_ "github.com/lib/pq"` themselves and then `sql.Open(...)`. 
+package postgres + +import ( + "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "log" + "strings" + + "git.juancwu.dev/juancwu/authkit/sqlstore" +) + +//go:embed migrations/*.sql +var migrationsFS embed.FS + +// Dialect implements sqlstore.Dialect for Postgres. The zero value is the +// only required form; New() returns a pointer for clarity at call sites. +type Dialect struct{} + +// New returns a Postgres dialect ready to pass to sqlstore.New / Migrate. +func New() *Dialect { return &Dialect{} } + +func (Dialect) Name() string { return "postgres" } + +// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. Stable +// across rollouts and unlikely to clash with caller advisory locks. +const advisoryLockKey int64 = 0x617574686b6974 + +func (Dialect) Bootstrap(ctx context.Context, db *sql.DB) error { + // Nothing to do — schema avoids gen_random_uuid()/pgcrypto. Kept as a + // hook for future migration prerequisites. + return nil +} + +func (Dialect) AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (func(), error) { + if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil { + return nil, err + } + release := func() { + if _, err := conn.ExecContext(context.Background(), + "SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil { + log.Printf("authkit/postgres: pg_advisory_unlock failed: %v", err) + } + } + return release, nil +} + +func (Dialect) Migrations() fs.FS { + sub, err := fs.Sub(migrationsFS, "migrations") + if err != nil { + // migrationsFS is statically populated; this can only fail if the + // embed directive is removed, which would be a build-time error. + panic(err) + } + return sub +} + +func (Dialect) IsUniqueViolation(err error) bool { return isUniqueViolation(err) } + +func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) } + +// PlaceholderList renders a comma-separated `$start,$start+1,...` list of +// `count` placeholders. 
Used by HasAnyRole's dynamic IN-clause expansion. +func (Dialect) PlaceholderList(start, count int) string { + if count <= 0 { + return "" + } + var b strings.Builder + for i := 0; i < count; i++ { + if i > 0 { + b.WriteByte(',') + } + fmt.Fprintf(&b, "$%d", start+i) + } + return b.String() +} + +// BuildQueries renders every query authkit issues, with table identifiers +// taken from s and `?` placeholders rewritten to `$N`. Identifiers are +// already validated by Schema.Validate — this is interpolated with +// fmt.Sprintf, so the validation gate is load-bearing. +func (Dialect) BuildQueries(s sqlstore.Schema) sqlstore.Queries { + t := s.Tables + q := sqlstore.Queries{ + // users + CreateUser: `INSERT INTO ` + t.Users + ` + (id, email, email_normalized, email_verified_at, password_hash, + session_version, failed_logins, last_login_at, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + GetUserByID: `SELECT id, email, email_normalized, email_verified_at, + password_hash, session_version, failed_logins, last_login_at, + created_at, updated_at FROM ` + t.Users + ` WHERE id = ?`, + GetUserByEmail: `SELECT id, email, email_normalized, email_verified_at, + password_hash, session_version, failed_logins, last_login_at, + created_at, updated_at FROM ` + t.Users + ` WHERE email_normalized = ?`, + UpdateUser: `UPDATE ` + t.Users + ` SET + email = ?, email_normalized = ?, email_verified_at = ?, + password_hash = ?, session_version = ?, failed_logins = ?, + last_login_at = ?, updated_at = ? + WHERE id = ?`, + DeleteUser: `DELETE FROM ` + t.Users + ` WHERE id = ?`, + SetPassword: `UPDATE ` + t.Users + ` SET password_hash = ?, updated_at = ? WHERE id = ?`, + SetEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = ?, updated_at = ? WHERE id = ?`, + BumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = ? WHERE id = ? 
RETURNING session_version`, + IncrementFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = failed_logins + 1, updated_at = ? WHERE id = ? RETURNING failed_logins`, + ResetFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = 0, updated_at = ? WHERE id = ?`, + + // sessions + CreateSession: `INSERT INTO ` + t.Sessions + ` + (id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + GetSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at + FROM ` + t.Sessions + ` WHERE id_hash = ?`, + TouchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = ?, expires_at = ? WHERE id_hash = ?`, + DeleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = ?`, + DeleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = ?`, + DeleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= ?`, + + // tokens + CreateToken: `INSERT INTO ` + t.Tokens + ` + (hash, kind, user_id, chain_id, consumed_at, created_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + ConsumeToken: `UPDATE ` + t.Tokens + ` + SET consumed_at = ? + WHERE kind = ? AND hash = ? AND consumed_at IS NULL AND expires_at > ? + RETURNING hash, kind, user_id, chain_id, consumed_at, created_at, expires_at`, + GetToken: `SELECT hash, kind, user_id, chain_id, consumed_at, created_at, expires_at + FROM ` + t.Tokens + ` WHERE kind = ? 
AND hash = ?`, + DeleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = ?`, + DeleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= ?`, + + // api keys + CreateAPIKey: `INSERT INTO ` + t.APIKeys + ` + (id_hash, owner_id, name, abilities, last_used_at, created_at, expires_at, revoked_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + GetAPIKey: `SELECT id_hash, owner_id, name, abilities, last_used_at, + created_at, expires_at, revoked_at + FROM ` + t.APIKeys + ` WHERE id_hash = ?`, + ListAPIKeysByOwner: `SELECT id_hash, owner_id, name, abilities, last_used_at, + created_at, expires_at, revoked_at + FROM ` + t.APIKeys + ` WHERE owner_id = ? ORDER BY created_at DESC`, + TouchAPIKey: `UPDATE ` + t.APIKeys + ` SET last_used_at = ? WHERE id_hash = ?`, + RevokeAPIKey: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE id_hash = ? AND revoked_at IS NULL`, + RevokeAPIKeysByOwner: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE owner_id = ? AND revoked_at IS NULL`, + + // roles + CreateRole: `INSERT INTO ` + t.Roles + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`, + GetRoleByID: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE id = ?`, + GetRoleByName: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE name = ?`, + ListRoles: `SELECT id, name, description, created_at FROM ` + t.Roles + ` ORDER BY name`, + DeleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = ?`, + AssignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`, + RemoveRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = ? AND role_id = ?`, + GetUserRoles: `SELECT r.id, r.name, r.description, r.created_at + FROM ` + t.Roles + ` r + JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id + WHERE ur.user_id = ? 
ORDER BY r.name`, + + // permissions + CreatePermission: `INSERT INTO ` + t.Permissions + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`, + GetPermissionByID: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE id = ?`, + GetPermissionByName: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE name = ?`, + ListPermissions: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` ORDER BY name`, + DeletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = ?`, + AssignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) VALUES (?, ?) + ON CONFLICT DO NOTHING`, + RemovePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = ? AND permission_id = ?`, + GetRolePermissions: `SELECT p.id, p.name, p.description, p.created_at + FROM ` + t.Permissions + ` p + JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id + WHERE rp.role_id = ? ORDER BY p.name`, + GetUserPermissions: `SELECT DISTINCT p.id, p.name, p.description, p.created_at + FROM ` + t.Permissions + ` p + JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id + JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id + WHERE ur.user_id = ? ORDER BY p.name`, + + // migrations + CreateMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL + )`, + SelectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations, + InsertAppliedVersion: `INSERT INTO ` + t.SchemaMigrations + ` (version, applied_at) VALUES (?, ?)`, + } + + // Rewrite `?` placeholders to `$N`. Each query is independent; numbering + // resets per query. 
+ q.CreateUser = rebind(q.CreateUser) + q.GetUserByID = rebind(q.GetUserByID) + q.GetUserByEmail = rebind(q.GetUserByEmail) + q.UpdateUser = rebind(q.UpdateUser) + q.DeleteUser = rebind(q.DeleteUser) + q.SetPassword = rebind(q.SetPassword) + q.SetEmailVerified = rebind(q.SetEmailVerified) + q.BumpSessionVersion = rebind(q.BumpSessionVersion) + q.IncrementFailedLogins = rebind(q.IncrementFailedLogins) + q.ResetFailedLogins = rebind(q.ResetFailedLogins) + + q.CreateSession = rebind(q.CreateSession) + q.GetSession = rebind(q.GetSession) + q.TouchSession = rebind(q.TouchSession) + q.DeleteSession = rebind(q.DeleteSession) + q.DeleteUserSessions = rebind(q.DeleteUserSessions) + q.DeleteExpiredSessions = rebind(q.DeleteExpiredSessions) + + q.CreateToken = rebind(q.CreateToken) + q.ConsumeToken = rebind(q.ConsumeToken) + q.GetToken = rebind(q.GetToken) + q.DeleteByChain = rebind(q.DeleteByChain) + q.DeleteExpiredTokens = rebind(q.DeleteExpiredTokens) + + q.CreateAPIKey = rebind(q.CreateAPIKey) + q.GetAPIKey = rebind(q.GetAPIKey) + q.ListAPIKeysByOwner = rebind(q.ListAPIKeysByOwner) + q.TouchAPIKey = rebind(q.TouchAPIKey) + q.RevokeAPIKey = rebind(q.RevokeAPIKey) + q.RevokeAPIKeysByOwner = rebind(q.RevokeAPIKeysByOwner) + + q.CreateRole = rebind(q.CreateRole) + q.GetRoleByID = rebind(q.GetRoleByID) + q.GetRoleByName = rebind(q.GetRoleByName) + q.ListRoles = rebind(q.ListRoles) + q.DeleteRole = rebind(q.DeleteRole) + q.AssignRoleToUser = rebind(q.AssignRoleToUser) + q.RemoveRoleFromUser = rebind(q.RemoveRoleFromUser) + q.GetUserRoles = rebind(q.GetUserRoles) + + q.CreatePermission = rebind(q.CreatePermission) + q.GetPermissionByID = rebind(q.GetPermissionByID) + q.GetPermissionByName = rebind(q.GetPermissionByName) + q.ListPermissions = rebind(q.ListPermissions) + q.DeletePermission = rebind(q.DeletePermission) + q.AssignPermissionToRole = rebind(q.AssignPermissionToRole) + q.RemovePermissionFromRole = rebind(q.RemovePermissionFromRole) + q.GetRolePermissions = 
rebind(q.GetRolePermissions) + q.GetUserPermissions = rebind(q.GetUserPermissions) + + q.SelectAppliedVersions = rebind(q.SelectAppliedVersions) + q.InsertAppliedVersion = rebind(q.InsertAppliedVersion) + // CreateMigrationsTable contains no parameters. + + return q +} + +// rebind walks s and replaces each unquoted `?` with $1, $2, ... in order. +// Our query strings contain no string literals that include `?` (verified +// by inspection of every query in BuildQueries); a literal-aware rewriter +// would be more defensive but is not needed for v1. +func rebind(s string) string { + var b strings.Builder + b.Grow(len(s) + 16) + n := 1 + for i := 0; i < len(s); i++ { + c := s[i] + if c == '?' { + fmt.Fprintf(&b, "$%d", n) + n++ + continue + } + b.WriteByte(c) + } + return b.String() +} + +// Compile-time interface compliance check. +var _ sqlstore.Dialect = (*Dialect)(nil) diff --git a/sqlstore/migrate.go b/sqlstore/migrate.go new file mode 100644 index 0000000..48906a0 --- /dev/null +++ b/sqlstore/migrate.go @@ -0,0 +1,116 @@ +package sqlstore + +import ( + "context" + "database/sql" + "io/fs" + "sort" + "strings" + "time" + + "git.juancwu.dev/juancwu/errx" +) + +// Migrate applies every embedded migration the dialect ships that has not +// yet been recorded in the schema-migrations table. It is safe to call +// repeatedly and concurrently across processes — the dialect's session +// lock serialises rollouts. +// +// Each migration .sql file is responsible for owning its own +// BEGIN/COMMIT and inserting a row into the schema-migrations table on +// success. The runner only handles file discovery, version tracking, and +// concurrency. 
+func Migrate(ctx context.Context, db *sql.DB, dialect Dialect, schema Schema) error { + const op = "authkit.sqlstore.Migrate" + if db == nil { + return errx.New(op, "db is required") + } + if dialect == nil { + return errx.New(op, "dialect is required") + } + if err := schema.Validate(); err != nil { + return errx.Wrap(op, err) + } + + if err := dialect.Bootstrap(ctx, db); err != nil { + return errx.Wrap(op, err) + } + + conn, err := db.Conn(ctx) + if err != nil { + return errx.Wrap(op, err) + } + defer conn.Close() + + release, err := dialect.AcquireMigrationLock(ctx, conn) + if err != nil { + return errx.Wrap(op, err) + } + defer release() + + q := dialect.BuildQueries(schema) + if _, err := conn.ExecContext(ctx, q.CreateMigrationsTable); err != nil { + return errx.Wrap(op, err) + } + + applied, err := loadAppliedVersions(ctx, conn, q.SelectAppliedVersions) + if err != nil { + return errx.Wrap(op, err) + } + + migs := dialect.Migrations() + files, err := fs.ReadDir(migs, ".") + if err != nil { + return errx.Wrap(op, err) + } + names := make([]string, 0, len(files)) + for _, f := range files { + if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") { + names = append(names, f.Name()) + } + } + sort.Strings(names) + + for _, name := range names { + version := strings.TrimSuffix(name, ".sql") + if _, ok := applied[version]; ok { + continue + } + body, err := fs.ReadFile(migs, name) + if err != nil { + return errx.Wrapf(op, err, "read %s", name) + } + if _, err := conn.ExecContext(ctx, string(body)); err != nil { + return errx.Wrapf(op, err, "apply %s", version) + } + } + return nil +} + +// applyVersionRow is intentionally not exposed: migration files own their +// own version-row insert. We keep this helper around in case a dialect ever +// needs to backfill versions from outside a migration body — currently +// unused. 
+var _ = applyVersionRow + +func applyVersionRow(ctx context.Context, conn *sql.Conn, insertQ, version string, at time.Time) error { + _, err := conn.ExecContext(ctx, insertQ, version, at) + return err +} + +func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) { + rows, err := conn.QueryContext(ctx, q) + if err != nil { + return nil, err + } + defer rows.Close() + out := make(map[string]struct{}) + for rows.Next() { + var v string + if err := rows.Scan(&v); err != nil { + return nil, err + } + out[v] = struct{}{} + } + return out, rows.Err() +} diff --git a/sqlstore/rbac.go b/sqlstore/rbac.go new file mode 100644 index 0000000..767440e --- /dev/null +++ b/sqlstore/rbac.go @@ -0,0 +1,301 @@ +package sqlstore + +import ( + "context" + "fmt" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type roleStore struct{ storeBase } +type permissionStore struct{ storeBase } + +// ----- roleStore ------------------------------------------------------------ + +func (s *roleStore) CreateRole(ctx context.Context, r *authkit.Role) error { + const op = "authkit.sqlstore.RoleStore.CreateRole" + if r.ID == uuid.Nil { + r.ID = uuid.New() + } + if r.CreatedAt.IsZero() { + r.CreatedAt = time.Now().UTC() + } + if _, err := s.db.ExecContext(ctx, s.q.CreateRole, + uuidArg(r.ID), r.Name, r.Description, r.CreatedAt); err != nil { + if s.d.IsUniqueViolation(err) { + return errx.Wrapf(op, err, "role %q already exists", r.Name) + } + return errx.Wrap(op, err) + } + return nil +} + +func (s *roleStore) GetRoleByID(ctx context.Context, id uuid.UUID) (*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.GetRoleByID" + r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound)) + } + return r, nil +} + +func (s *roleStore) GetRoleByName(ctx context.Context, name string) 
(*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.GetRoleByName" + r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByName, name)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound)) + } + return r, nil +} + +func (s *roleStore) ListRoles(ctx context.Context) ([]*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.ListRoles" + rows, err := s.db.QueryContext(ctx, s.q.ListRoles) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Role + for rows.Next() { + r, err := scanRole(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, r) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *roleStore) DeleteRole(ctx context.Context, id uuid.UUID) error { + const op = "authkit.sqlstore.RoleStore.DeleteRole" + tag, err := s.db.ExecContext(ctx, s.q.DeleteRole, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrRoleNotFound) + } + return nil +} + +func (s *roleStore) AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error { + const op = "authkit.sqlstore.RoleStore.AssignRoleToUser" + if _, err := s.db.ExecContext(ctx, s.q.AssignRoleToUser, + uuidArg(userID), uuidArg(roleID), time.Now().UTC()); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *roleStore) RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error { + const op = "authkit.sqlstore.RoleStore.RemoveRoleFromUser" + if _, err := s.db.ExecContext(ctx, s.q.RemoveRoleFromUser, + uuidArg(userID), uuidArg(roleID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *roleStore) GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*authkit.Role, error) { + const op = "authkit.sqlstore.RoleStore.GetUserRoles" + rows, err := s.db.QueryContext(ctx, s.q.GetUserRoles, uuidArg(userID)) + if err != nil { + return nil, 
errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Role + for rows.Next() { + r, err := scanRole(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, r) + } + return out, errx.Wrap(op, rows.Err()) +} + +// HasAnyRole builds the IN-clause at call time because the placeholder count +// depends on len(names). Identifier substitution comes from the validated +// Schema; values are bound, never interpolated. +func (s *roleStore) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) { + const op = "authkit.sqlstore.RoleStore.HasAnyRole" + if len(names) == 0 { + return false, nil + } + // Placeholder $1 is the user_id; $2..$N+1 cover the names slice. + listSQL := s.d.PlaceholderList(2, len(names)) + q := fmt.Sprintf(`SELECT EXISTS ( + SELECT 1 FROM %s ur JOIN %s r ON r.id = ur.role_id + WHERE ur.user_id = %s AND r.name IN (%s) + )`, s.s.Tables.UserRoles, s.s.Tables.Roles, s.d.Placeholder(1), listSQL) + + args := make([]any, 0, 1+len(names)) + args = append(args, uuidArg(userID)) + for _, n := range names { + args = append(args, n) + } + var ok bool + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil { + return false, errx.Wrap(op, err) + } + return ok, nil +} + +// ----- permissionStore ------------------------------------------------------ + +func (s *permissionStore) CreatePermission(ctx context.Context, p *authkit.Permission) error { + const op = "authkit.sqlstore.PermissionStore.CreatePermission" + if p.ID == uuid.Nil { + p.ID = uuid.New() + } + if p.CreatedAt.IsZero() { + p.CreatedAt = time.Now().UTC() + } + if _, err := s.db.ExecContext(ctx, s.q.CreatePermission, + uuidArg(p.ID), p.Name, p.Description, p.CreatedAt); err != nil { + if s.d.IsUniqueViolation(err) { + return errx.Wrapf(op, err, "permission %q already exists", p.Name) + } + return errx.Wrap(op, err) + } + return nil +} + +func (s *permissionStore) GetPermissionByID(ctx context.Context, id uuid.UUID) 
(*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetPermissionByID" + p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound)) + } + return p, nil +} + +func (s *permissionStore) GetPermissionByName(ctx context.Context, name string) (*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetPermissionByName" + p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByName, name)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound)) + } + return p, nil +} + +func (s *permissionStore) ListPermissions(ctx context.Context) ([]*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.ListPermissions" + rows, err := s.db.QueryContext(ctx, s.q.ListPermissions) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *permissionStore) DeletePermission(ctx context.Context, id uuid.UUID) error { + const op = "authkit.sqlstore.PermissionStore.DeletePermission" + tag, err := s.db.ExecContext(ctx, s.q.DeletePermission, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrPermissionNotFound) + } + return nil +} + +func (s *permissionStore) AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error { + const op = "authkit.sqlstore.PermissionStore.AssignPermissionToRole" + if _, err := s.db.ExecContext(ctx, s.q.AssignPermissionToRole, + uuidArg(roleID), uuidArg(permID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *permissionStore) RemovePermissionFromRole(ctx 
context.Context, roleID, permID uuid.UUID) error { + const op = "authkit.sqlstore.PermissionStore.RemovePermissionFromRole" + if _, err := s.db.ExecContext(ctx, s.q.RemovePermissionFromRole, + uuidArg(roleID), uuidArg(permID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *permissionStore) GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetRolePermissions" + rows, err := s.db.QueryContext(ctx, s.q.GetRolePermissions, uuidArg(roleID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func (s *permissionStore) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*authkit.Permission, error) { + const op = "authkit.sqlstore.PermissionStore.GetUserPermissions" + rows, err := s.db.QueryContext(ctx, s.q.GetUserPermissions, uuidArg(userID)) + if err != nil { + return nil, errx.Wrap(op, err) + } + defer rows.Close() + var out []*authkit.Permission + for rows.Next() { + p, err := scanPermission(rows) + if err != nil { + return nil, errx.Wrap(op, err) + } + out = append(out, p) + } + return out, errx.Wrap(op, rows.Err()) +} + +func scanRole(row rowScanner) (*authkit.Role, error) { + var ( + r authkit.Role + idStr string + ) + if err := row.Scan(&idStr, &r.Name, &r.Description, &r.CreatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + r.ID = id + return &r, nil +} + +func scanPermission(row rowScanner) (*authkit.Permission, error) { + var ( + p authkit.Permission + idStr string + ) + if err := row.Scan(&idStr, &p.Name, &p.Description, &p.CreatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + p.ID 
= id + return &p, nil +} diff --git a/sqlstore/scan.go b/sqlstore/scan.go new file mode 100644 index 0000000..a962535 --- /dev/null +++ b/sqlstore/scan.go @@ -0,0 +1,86 @@ +package sqlstore + +import ( + "database/sql" + "net/netip" + "time" + + "github.com/google/uuid" +) + +// rowScanner is the lowest-common-denominator interface satisfied by both +// *sql.Row and *sql.Rows so scanXxx helpers serve QueryRow and Query loops +// uniformly. +type rowScanner interface { + Scan(dest ...any) error +} + +// nullableTime returns nil when t is nil (binding as SQL NULL), otherwise +// the dereferenced time value. Callers should usually accept *time.Time on +// the model side and bind via this. +func nullableTime(t *time.Time) any { + if t == nil { + return nil + } + return *t +} + +// nullableString turns "" into nil so columns store NULL rather than an +// empty string. Used for password hashes and similar optional text columns. +func nullableString(s string) any { + if s == "" { + return nil + } + return s +} + +// nullableAddrString returns the string form of a netip.Addr when valid, or +// nil to bind as SQL NULL. Pairs with scanAddr. +func nullableAddrString(a netip.Addr) any { + if !a.IsValid() { + return nil + } + return a.String() +} + +// scanAddr parses a *string column into a netip.Addr. The zero Addr value is +// returned when the column was NULL or empty. +func scanAddr(s *string) (netip.Addr, error) { + if s == nil || *s == "" { + return netip.Addr{}, nil + } + return netip.ParseAddr(*s) +} + +// uuidArg returns the canonical string form of a UUID for binding. Every +// store uses this rather than passing uuid.UUID directly to keep behaviour +// identical across drivers (some accept driver.Valuer, some don't). +func uuidArg(id uuid.UUID) any { return id.String() } + +// scanUUID reads a string column and parses it back to a uuid.UUID. +func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) } + +// chainArg returns c's dereferenced string value, or nil (SQL NULL) when c +// is nil, for binding the chain_id column. 
+func chainArg(c *string) any { + if c == nil { + return nil + } + return *c +} + +// scanNullStringPtr converts sql.NullString to *string for the model. +func scanNullStringPtr(ns sql.NullString) *string { + if !ns.Valid { + return nil + } + v := ns.String + return &v +} + +// scanNullTimePtr converts sql.NullTime to *time.Time for the model. +func scanNullTimePtr(nt sql.NullTime) *time.Time { + if !nt.Valid { + return nil + } + t := nt.Time + return &t +} diff --git a/sqlstore/schema.go b/sqlstore/schema.go new file mode 100644 index 0000000..cb254e4 --- /dev/null +++ b/sqlstore/schema.go @@ -0,0 +1,78 @@ +package sqlstore + +import ( + "regexp" + + "git.juancwu.dev/juancwu/errx" +) + +// Schema lets consumers map authkit storage to their own table names. +// Column overrides are intentionally not present in v1 — adding them later +// is purely additive. +type Schema struct { + Tables Tables +} + +// Tables is the per-table identifier override set. Every field must be a +// valid unquoted SQL identifier (matching identifierRE). Validation runs at +// New() and Migrate() time so SQL injection through Schema is impossible +// past that gate. +type Tables struct { + Users string + Sessions string + Tokens string + APIKeys string + Roles string + Permissions string + UserRoles string + RolePermissions string + SchemaMigrations string +} + +// DefaultSchema returns the stock authkit_* names used by the embedded +// migration files. +func DefaultSchema() Schema { + return Schema{Tables: Tables{ + Users: "authkit_users", + Sessions: "authkit_sessions", + Tokens: "authkit_tokens", + APIKeys: "authkit_api_keys", + Roles: "authkit_roles", + Permissions: "authkit_permissions", + UserRoles: "authkit_user_roles", + RolePermissions: "authkit_role_permissions", + SchemaMigrations: "authkit_schema_migrations", + }} +} + +// identifierRE matches the safe ASCII identifier subset shared by Postgres, +// MySQL and SQLite when not quoted. 
Anything outside this set is rejected +// rather than escaped — Schema is not the place to support exotic names. +var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +// Validate ensures every Schema.Tables field is a non-empty, safe identifier. +func (s Schema) Validate() error { + const op = "authkit.sqlstore.Schema.Validate" + checks := []struct { + field, value string + }{ + {"Users", s.Tables.Users}, + {"Sessions", s.Tables.Sessions}, + {"Tokens", s.Tables.Tokens}, + {"APIKeys", s.Tables.APIKeys}, + {"Roles", s.Tables.Roles}, + {"Permissions", s.Tables.Permissions}, + {"UserRoles", s.Tables.UserRoles}, + {"RolePermissions", s.Tables.RolePermissions}, + {"SchemaMigrations", s.Tables.SchemaMigrations}, + } + for _, c := range checks { + if c.value == "" { + return errx.Newf(op, "Schema.Tables.%s is empty", c.field) + } + if !identifierRE.MatchString(c.value) { + return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value) + } + } + return nil +} diff --git a/sqlstore/sessions.go b/sqlstore/sessions.go new file mode 100644 index 0000000..8241bb5 --- /dev/null +++ b/sqlstore/sessions.go @@ -0,0 +1,98 @@ +package sqlstore + +import ( + "context" + "database/sql" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type sessionStore struct{ storeBase } + +func (s *sessionStore) CreateSession(ctx context.Context, ses *authkit.Session) error { + const op = "authkit.sqlstore.SessionStore.CreateSession" + now := time.Now().UTC() + if ses.CreatedAt.IsZero() { + ses.CreatedAt = now + } + if ses.LastSeenAt.IsZero() { + ses.LastSeenAt = ses.CreatedAt + } + _, err := s.db.ExecContext(ctx, s.q.CreateSession, + ses.IDHash, uuidArg(ses.UserID), ses.UserAgent, nullableAddrString(ses.IP), + ses.CreatedAt, ses.LastSeenAt, ses.ExpiresAt) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *sessionStore) GetSession(ctx context.Context, idHash []byte) 
(*authkit.Session, error) { + const op = "authkit.sqlstore.SessionStore.GetSession" + var ( + ses authkit.Session + uidStr string + ipStr sql.NullString + ) + err := s.db.QueryRowContext(ctx, s.q.GetSession, idHash).Scan( + &ses.IDHash, &uidStr, &ses.UserAgent, &ipStr, + &ses.CreatedAt, &ses.LastSeenAt, &ses.ExpiresAt) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrSessionInvalid)) + } + uid, err := scanUUID(uidStr) + if err != nil { + return nil, errx.Wrap(op, err) + } + ses.UserID = uid + if ipStr.Valid { + addr, err := scanAddr(&ipStr.String) + if err != nil { + return nil, errx.Wrap(op, err) + } + ses.IP = addr + } + return &ses, nil +} + +func (s *sessionStore) TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error { + const op = "authkit.sqlstore.SessionStore.TouchSession" + tag, err := s.db.ExecContext(ctx, s.q.TouchSession, lastSeenAt, newExpiresAt, idHash) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrSessionInvalid) + } + return nil +} + +func (s *sessionStore) DeleteSession(ctx context.Context, idHash []byte) error { + const op = "authkit.sqlstore.SessionStore.DeleteSession" + if _, err := s.db.ExecContext(ctx, s.q.DeleteSession, idHash); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *sessionStore) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.sqlstore.SessionStore.DeleteUserSessions" + if _, err := s.db.ExecContext(ctx, s.q.DeleteUserSessions, uuidArg(userID)); err != nil { + return errx.Wrap(op, err) + } + return nil +} + +func (s *sessionStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) { + const op = "authkit.sqlstore.SessionStore.DeleteExpired" + tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredSessions, now) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} diff 
--git a/sqlstore/sqlstore.go b/sqlstore/sqlstore.go new file mode 100644 index 0000000..a87d9ca --- /dev/null +++ b/sqlstore/sqlstore.go @@ -0,0 +1,59 @@ +// Package sqlstore provides database/sql-backed implementations of every +// authkit store interface. It works with any driver (pgx-stdlib, lib/pq, +// sqlx wrapping *sql.DB, ...) and any consumer-supplied table naming via +// Schema. Driver-specific behaviour lives in a Dialect; v1 ships +// dialect/postgres. +package sqlstore + +import ( + "database/sql" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" +) + +// Stores bundles every store implementation produced by New, ready to drop +// into authkit.Deps. +type Stores struct { + Users authkit.UserStore + Sessions authkit.SessionStore + Tokens authkit.TokenStore + APIKeys authkit.APIKeyStore + Roles authkit.RoleStore + Permissions authkit.PermissionStore +} + +// New validates the Schema, builds the dialect's templated queries, and +// returns store implementations bound to db. +func New(db *sql.DB, dialect Dialect, schema Schema) (*Stores, error) { + const op = "authkit.sqlstore.New" + if db == nil { + return nil, errx.New(op, "db is required") + } + if dialect == nil { + return nil, errx.New(op, "dialect is required") + } + if err := schema.Validate(); err != nil { + return nil, errx.Wrap(op, err) + } + q := dialect.BuildQueries(schema) + + base := storeBase{db: db, q: q, d: dialect, s: schema} + return &Stores{ + Users: &userStore{storeBase: base}, + Sessions: &sessionStore{storeBase: base}, + Tokens: &tokenStore{storeBase: base}, + APIKeys: &apiKeyStore{storeBase: base}, + Roles: &roleStore{storeBase: base}, + Permissions: &permissionStore{storeBase: base}, + }, nil +} + +// storeBase carries the shared dependencies every store needs. Embedded into +// each concrete store struct. 
+type storeBase struct { + db *sql.DB + q Queries + d Dialect + s Schema +} diff --git a/sqlstore/sqlstore_test.go b/sqlstore/sqlstore_test.go new file mode 100644 index 0000000..8c444b6 --- /dev/null +++ b/sqlstore/sqlstore_test.go @@ -0,0 +1,258 @@ +package sqlstore_test + +// Integration tests against a real Postgres. Skipped unless +// AUTHKIT_TEST_DATABASE_URL is set. Each test migrates the default table +// names fresh, dropping any leftover authkit_* tables before migrating and +// again via a t.Cleanup fixture — no external Docker dependency. + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/netip" + "os" + "testing" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/authkit/hasher" + "git.juancwu.dev/juancwu/authkit/sqlstore" + pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +func envURL(t *testing.T) string { + t.Helper() + url := os.Getenv("AUTHKIT_TEST_DATABASE_URL") + if url == "" { + t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test") + } + return url +} + +// freshDB opens a connection, runs Migrate against the default schema, and +// schedules a teardown that drops every authkit_* table. Tests must run +// sequentially against a single database (the package's tests do, by +// default — go test serialises within a package unless t.Parallel is +// called). +func freshDB(t *testing.T) (*authkit.Auth, *sql.DB, sqlstore.Schema) { + t.Helper() + url := envURL(t) + db, err := sql.Open("pgx", url) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("ping: %v", err) + } + + schema := sqlstore.DefaultSchema() + // Drop first to start clean — previous failed tests may have left rows. 
+ dropAuthkitTables(t, db, schema) + if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil { + t.Fatalf("Migrate: %v", err) + } + t.Cleanup(func() { dropAuthkitTables(t, db, schema) }) + + stores, err := sqlstore.New(db, pgdialect.New(), schema) + if err != nil { + t.Fatalf("sqlstore.New: %v", err) + } + auth := authkit.New(authkit.Deps{ + Users: stores.Users, + Sessions: stores.Sessions, + Tokens: stores.Tokens, + APIKeys: stores.APIKeys, + Roles: stores.Roles, + Permissions: stores.Permissions, + Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), + }, authkit.Config{ + JWTSecret: []byte("integration-secret-thirty-two!!!!"), + JWTIssuer: "authkit-int", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: 1 * time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + }) + return auth, db, schema +} + +func dropAuthkitTables(t *testing.T, db *sql.DB, s sqlstore.Schema) { + t.Helper() + tables := []string{ + s.Tables.UserRoles, s.Tables.RolePermissions, + s.Tables.Roles, s.Tables.Permissions, + s.Tables.APIKeys, s.Tables.Tokens, + s.Tables.Sessions, s.Tables.Users, + s.Tables.SchemaMigrations, + } + for _, name := range tables { + _, _ = db.ExecContext(context.Background(), + fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name)) + } +} + +func TestIntegration_MigrateIdempotent(t *testing.T) { + url := envURL(t) + db, err := sql.Open("pgx", url) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + schema := sqlstore.DefaultSchema() + t.Cleanup(func() { dropAuthkitTables(t, db, schema) }) + for i := 0; i < 3; i++ { + if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil { + t.Fatalf("Migrate iter %d: %v", i, err) + } + } +} + +func TestIntegration_RegisterAndLogin(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := 
context.Background() + u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("Register: %v", err) + } + got, err := auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2") + if err != nil { + t.Fatalf("LoginPassword: %v", err) + } + if got.ID != u.ID { + t.Fatalf("login user id mismatch") + } + + if _, err := auth.Register(ctx, "alice@example.com", "x"); !errors.Is(err, authkit.ErrEmailTaken) { + t.Fatalf("expected ErrEmailTaken, got %v", err) + } + if _, err := auth.LoginPassword(ctx, "alice@example.com", "wrong"); !errors.Is(err, authkit.ErrInvalidCredentials) { + t.Fatalf("expected ErrInvalidCredentials, got %v", err) + } +} + +func TestIntegration_SessionLifecycle(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "s@s.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plain, sess, err := auth.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1")) + if err != nil { + t.Fatalf("IssueSession: %v", err) + } + if sess.ExpiresAt.Before(time.Now()) { + t.Fatalf("session already expired at issue") + } + if _, err := auth.AuthenticateSession(ctx, plain); err != nil { + t.Fatalf("AuthenticateSession: %v", err) + } + if err := auth.RevokeSession(ctx, plain); err != nil { + t.Fatalf("RevokeSession: %v", err) + } + if _, err := auth.AuthenticateSession(ctx, plain); !errors.Is(err, authkit.ErrSessionInvalid) { + t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err) + } +} + +func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "j@j.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + _, refresh1, err := auth.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + _, refresh2, err := auth.RefreshJWT(ctx, refresh1) + if err != nil { + t.Fatalf("first RefreshJWT: %v", err) + } + if refresh1 == refresh2 { + 
t.Fatalf("refresh token did not rotate") + } + + if _, _, err := auth.RefreshJWT(ctx, refresh1); !errors.Is(err, authkit.ErrTokenReused) { + t.Fatalf("expected ErrTokenReused on replay, got %v", err) + } + if _, _, err := auth.RefreshJWT(ctx, refresh2); !errors.Is(err, authkit.ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err) + } +} + +func TestIntegration_APIKeyWithAbilities(t *testing.T) { + auth, _, _ := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "k@k.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + plain, _, err := auth.IssueAPIKey(ctx, u.ID, "ci", + []string{"billing:read", "users:list"}, nil) + if err != nil { + t.Fatalf("IssueAPIKey: %v", err) + } + p, err := auth.AuthenticateAPIKey(ctx, plain) + if err != nil { + t.Fatalf("AuthenticateAPIKey: %v", err) + } + if !p.HasAbility("billing:read") || !p.HasAbility("users:list") { + t.Fatalf("abilities missing: %+v", p.Abilities) + } + if err := auth.RevokeAPIKey(ctx, plain); err != nil { + t.Fatalf("RevokeAPIKey: %v", err) + } + if _, err := auth.AuthenticateAPIKey(ctx, plain); !errors.Is(err, authkit.ErrAPIKeyInvalid) { + t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err) + } +} + +func TestIntegration_RBAC(t *testing.T) { + auth, db, schema := freshDB(t) + ctx := context.Background() + u, err := auth.Register(ctx, "rb@b.com", "pw") + if err != nil { + t.Fatalf("Register: %v", err) + } + + stores, _ := sqlstore.New(db, pgdialect.New(), schema) + r := &authkit.Role{Name: "editor"} + if err := stores.Roles.CreateRole(ctx, r); err != nil { + t.Fatalf("CreateRole: %v", err) + } + p := &authkit.Permission{Name: "posts:write"} + if err := stores.Permissions.CreatePermission(ctx, p); err != nil { + t.Fatalf("CreatePermission: %v", err) + } + if err := stores.Permissions.AssignPermissionToRole(ctx, r.ID, p.ID); err != nil { + t.Fatalf("AssignPermissionToRole: %v", err) + } + if err := auth.AssignRole(ctx, u.ID, 
"editor"); err != nil { + t.Fatalf("AssignRole: %v", err) + } + ok, err := auth.HasPermission(ctx, u.ID, "posts:write") + if err != nil || !ok { + t.Fatalf("HasPermission: %v %v", ok, err) + } + ok, err = auth.HasAnyRole(ctx, u.ID, []string{"editor", "admin"}) + if err != nil || !ok { + t.Fatalf("HasAnyRole: %v %v", ok, err) + } + if err := auth.RemoveRole(ctx, u.ID, "editor"); err != nil { + t.Fatalf("RemoveRole: %v", err) + } + ok, _ = auth.HasPermission(ctx, u.ID, "posts:write") + if ok { + t.Fatalf("HasPermission should be false after RemoveRole") + } +} diff --git a/sqlstore/tokens.go b/sqlstore/tokens.go new file mode 100644 index 0000000..99841fc --- /dev/null +++ b/sqlstore/tokens.go @@ -0,0 +1,92 @@ +package sqlstore + +import ( + "context" + "database/sql" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" +) + +type tokenStore struct{ storeBase } + +func (s *tokenStore) CreateToken(ctx context.Context, t *authkit.Token) error { + const op = "authkit.sqlstore.TokenStore.CreateToken" + if t.CreatedAt.IsZero() { + t.CreatedAt = time.Now().UTC() + } + _, err := s.db.ExecContext(ctx, s.q.CreateToken, + t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID), + nullableTime(t.ConsumedAt), t.CreatedAt, t.ExpiresAt) + if err != nil { + return errx.Wrap(op, err) + } + return nil +} + +// ConsumeToken does the find-and-mark-consumed in a single statement so two +// concurrent callers cannot both successfully consume the same token. The +// row is returned for inspection (e.g. ChainID for refresh rotation). 
+func (s *tokenStore) ConsumeToken(ctx context.Context, kind authkit.TokenKind, hash []byte, now time.Time) (*authkit.Token, error) { + const op = "authkit.sqlstore.TokenStore.ConsumeToken" + row := s.db.QueryRowContext(ctx, s.q.ConsumeToken, now, string(kind), hash, now) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid)) + } + return t, nil +} + +func (s *tokenStore) GetToken(ctx context.Context, kind authkit.TokenKind, hash []byte) (*authkit.Token, error) { + const op = "authkit.sqlstore.TokenStore.GetToken" + row := s.db.QueryRowContext(ctx, s.q.GetToken, string(kind), hash) + t, err := scanToken(row) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid)) + } + return t, nil +} + +func (s *tokenStore) DeleteByChain(ctx context.Context, chainID string) (int64, error) { + const op = "authkit.sqlstore.TokenStore.DeleteByChain" + tag, err := s.db.ExecContext(ctx, s.q.DeleteByChain, chainID) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} + +func (s *tokenStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) { + const op = "authkit.sqlstore.TokenStore.DeleteExpired" + tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredTokens, now) + if err != nil { + return 0, errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + return n, nil +} + +func scanToken(row rowScanner) (*authkit.Token, error) { + var ( + t authkit.Token + kind string + userIDStr string + chainID sql.NullString + consumedAt sql.NullTime + ) + if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, + &consumedAt, &t.CreatedAt, &t.ExpiresAt); err != nil { + return nil, err + } + t.Kind = authkit.TokenKind(kind) + uid, err := scanUUID(userIDStr) + if err != nil { + return nil, err + } + t.UserID = uid + t.ChainID = scanNullStringPtr(chainID) + t.ConsumedAt = scanNullTimePtr(consumedAt) + return &t, nil +} diff --git a/sqlstore/users.go 
b/sqlstore/users.go new file mode 100644 index 0000000..b03c379 --- /dev/null +++ b/sqlstore/users.go @@ -0,0 +1,186 @@ +package sqlstore + +import ( + "context" + "database/sql" + "errors" + "strings" + "time" + + "git.juancwu.dev/juancwu/authkit" + "git.juancwu.dev/juancwu/errx" + "github.com/google/uuid" +) + +type userStore struct{ storeBase } + +func (s *userStore) CreateUser(ctx context.Context, u *authkit.User) error { + const op = "authkit.sqlstore.UserStore.CreateUser" + if u.ID == uuid.Nil { + u.ID = uuid.New() + } + if u.EmailNormalized == "" { + u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email)) + } + now := time.Now().UTC() + if u.CreatedAt.IsZero() { + u.CreatedAt = now + } + if u.UpdatedAt.IsZero() { + u.UpdatedAt = now + } + _, err := s.db.ExecContext(ctx, s.q.CreateUser, + uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), + nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins, + nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt) + if err != nil { + if s.d.IsUniqueViolation(err) { + return errx.Wrap(op, authkit.ErrEmailTaken) + } + return errx.Wrap(op, err) + } + return nil +} + +func (s *userStore) GetUserByID(ctx context.Context, id uuid.UUID) (*authkit.User, error) { + const op = "authkit.sqlstore.UserStore.GetUserByID" + u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByID, uuidArg(id))) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) + } + return u, nil +} + +func (s *userStore) GetUserByEmail(ctx context.Context, normalizedEmail string) (*authkit.User, error) { + const op = "authkit.sqlstore.UserStore.GetUserByEmail" + u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByEmail, normalizedEmail)) + if err != nil { + return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) + } + return u, nil +} + +func (s *userStore) UpdateUser(ctx context.Context, u *authkit.User) error { + const op = "authkit.sqlstore.UserStore.UpdateUser" 
+ u.UpdatedAt = time.Now().UTC() + tag, err := s.db.ExecContext(ctx, s.q.UpdateUser, + u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt), + nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins, + nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) DeleteUser(ctx context.Context, id uuid.UUID) error { + const op = "authkit.sqlstore.UserStore.DeleteUser" + tag, err := s.db.ExecContext(ctx, s.q.DeleteUser, uuidArg(id)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error { + const op = "authkit.sqlstore.UserStore.SetPassword" + tag, err := s.db.ExecContext(ctx, s.q.SetPassword, + nullableString(encodedHash), time.Now().UTC(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error { + const op = "authkit.sqlstore.UserStore.SetEmailVerified" + tag, err := s.db.ExecContext(ctx, s.q.SetEmailVerified, at, time.Now().UTC(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func (s *userStore) BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) { + const op = "authkit.sqlstore.UserStore.BumpSessionVersion" + var v int + if err := s.db.QueryRowContext(ctx, s.q.BumpSessionVersion, + time.Now().UTC(), uuidArg(userID)).Scan(&v); err != nil { + return 0, errx.Wrap(op, mapNotFound(err, 
authkit.ErrUserNotFound)) + } + return v, nil +} + +func (s *userStore) IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) { + const op = "authkit.sqlstore.UserStore.IncrementFailedLogins" + var n int + if err := s.db.QueryRowContext(ctx, s.q.IncrementFailedLogins, + time.Now().UTC(), uuidArg(userID)).Scan(&n); err != nil { + return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound)) + } + return n, nil +} + +func (s *userStore) ResetFailedLogins(ctx context.Context, userID uuid.UUID) error { + const op = "authkit.sqlstore.UserStore.ResetFailedLogins" + tag, err := s.db.ExecContext(ctx, s.q.ResetFailedLogins, time.Now().UTC(), uuidArg(userID)) + if err != nil { + return errx.Wrap(op, err) + } + n, _ := tag.RowsAffected() + if n == 0 { + return errx.Wrap(op, authkit.ErrUserNotFound) + } + return nil +} + +func scanUser(row rowScanner) (*authkit.User, error) { + var ( + u authkit.User + idStr string + passwordHash sql.NullString + emailVerified sql.NullTime + lastLogin sql.NullTime + ) + if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified, + &passwordHash, &u.SessionVersion, &u.FailedLogins, &lastLogin, + &u.CreatedAt, &u.UpdatedAt); err != nil { + return nil, err + } + id, err := scanUUID(idStr) + if err != nil { + return nil, err + } + u.ID = id + if passwordHash.Valid { + u.PasswordHash = passwordHash.String + } + u.EmailVerifiedAt = scanNullTimePtr(emailVerified) + u.LastLoginAt = scanNullTimePtr(lastLogin) + return &u, nil +} + +// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so +// callers get reliable errors.Is targets through errx wrapping. 
+func mapNotFound(err error, sentinel error) error { + if errors.Is(err, sql.ErrNoRows) { + return sentinel + } + return err +} diff --git a/stores.go b/stores.go new file mode 100644 index 0000000..36f2374 --- /dev/null +++ b/stores.go @@ -0,0 +1,83 @@ +package authkit + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +type UserStore interface { + CreateUser(ctx context.Context, u *User) error + GetUserByID(ctx context.Context, id uuid.UUID) (*User, error) + GetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error) + UpdateUser(ctx context.Context, u *User) error + DeleteUser(ctx context.Context, id uuid.UUID) error + SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error + SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error + BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) + IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) + ResetFailedLogins(ctx context.Context, userID uuid.UUID) error +} + +type SessionStore interface { + CreateSession(ctx context.Context, s *Session) error + GetSession(ctx context.Context, idHash []byte) (*Session, error) + TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error + DeleteSession(ctx context.Context, idHash []byte) error + DeleteUserSessions(ctx context.Context, userID uuid.UUID) error + DeleteExpired(ctx context.Context, now time.Time) (int64, error) +} + +type TokenStore interface { + CreateToken(ctx context.Context, t *Token) error + // ConsumeToken atomically marks the matching unexpired, unconsumed token + // as consumed and returns it. Returns ErrTokenInvalid if no row matched. + // Implementations MUST do this in one statement (UPDATE ... RETURNING) + // to prevent double-spend under concurrent callers. + ConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) + // GetToken returns a token without consuming it. 
Used for refresh-token + // reuse detection: a token that exists with consumed_at != nil is a + // replay signal. + GetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error) + DeleteByChain(ctx context.Context, chainID string) (int64, error) + DeleteExpired(ctx context.Context, now time.Time) (int64, error) +} + +type APIKeyStore interface { + CreateAPIKey(ctx context.Context, k *APIKey) error + GetAPIKey(ctx context.Context, idHash []byte) (*APIKey, error) + ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*APIKey, error) + TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error + RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error + RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error +} + +type RoleStore interface { + CreateRole(ctx context.Context, r *Role) error + GetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error) + GetRoleByName(ctx context.Context, name string) (*Role, error) + ListRoles(ctx context.Context) ([]*Role, error) + DeleteRole(ctx context.Context, id uuid.UUID) error + AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error + RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error + GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error) + HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) +} + +type PermissionStore interface { + CreatePermission(ctx context.Context, p *Permission) error + GetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error) + GetPermissionByName(ctx context.Context, name string) (*Permission, error) + ListPermissions(ctx context.Context) ([]*Permission, error) + DeletePermission(ctx context.Context, id uuid.UUID) error + AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error + RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error + GetRolePermissions(ctx context.Context, roleID uuid.UUID) 
([]*Permission, error) + GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error) +} + +type Hasher interface { + Hash(password string) (string, error) + Verify(password, encoded string) (ok bool, needsRehash bool, err error) +} diff --git a/tokens.go b/tokens.go new file mode 100644 index 0000000..78d5e22 --- /dev/null +++ b/tokens.go @@ -0,0 +1,67 @@ +package authkit + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "io" + "strings" + + "git.juancwu.dev/juancwu/errx" +) + +const secretRandomBytes = 32 + +// Secret prefixes. The leading token namespaces the secret so callers can +// route them to the right verifier without trial-and-error. +const ( + prefixSession = "sess" + prefixRefresh = "rfr" + prefixAPIKey = "ak" + prefixEmailVerify = "evr" + prefixPasswordRset = "pwr" + prefixMagicLink = "mlnk" +) + +// mintSecret generates a new opaque secret of the given prefix kind and +// returns the plaintext (to be returned to the user, never stored) and the +// SHA-256 lookup hash (to be stored). +func mintSecret(prefix string, rng io.Reader) (plaintext string, hash []byte, err error) { + const op = "authkit.mintSecret" + if rng == nil { + rng = rand.Reader + } + buf := make([]byte, secretRandomBytes) + if _, err := io.ReadFull(rng, buf); err != nil { + return "", nil, errx.Wrap(op, err) + } + body := base64.RawURLEncoding.EncodeToString(buf) + plaintext = prefix + "_" + body + hash = hashSecret(plaintext) + return plaintext, hash, nil +} + +// hashSecret returns sha256(plaintext) — the lookup key for opaque secrets. +// Plaintexts have full entropy from a CSPRNG so a plain hash is sufficient +// (no per-record salt needed; the random body is the salt). +func hashSecret(plaintext string) []byte { + sum := sha256.Sum256([]byte(plaintext)) + return sum[:] +} + +// parseSecret validates that a plaintext starts with the expected prefix +// and returns the lookup hash. 
The prefix check uses strings.HasPrefix, which is not
+// constant-time; prefixes are public identifiers, not secrets, so that is fine.
+func parseSecret(prefix, plaintext string) (hash []byte, ok bool) {
+	want := prefix + "_"
+	if !strings.HasPrefix(plaintext, want) {
+		return nil, false
+	}
+	return hashSecret(plaintext), true
+}
+
+// constantTimeEqual is a thin wrapper for readability at call sites.
+func constantTimeEqual(a, b []byte) bool {
+	return subtle.ConstantTimeCompare(a, b) == 1
+}
diff --git a/tokens_test.go b/tokens_test.go
new file mode 100644
index 0000000..391437a
--- /dev/null
+++ b/tokens_test.go
@@ -0,0 +1,56 @@
+package authkit
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"strings"
+	"testing"
+)
+
+func TestMintSecretRoundtrip(t *testing.T) {
+	plaintext, hash, err := mintSecret(prefixSession, nil)
+	if err != nil {
+		t.Fatalf("mintSecret: %v", err)
+	}
+	if !strings.HasPrefix(plaintext, prefixSession+"_") {
+		t.Fatalf("missing prefix: %q", plaintext)
+	}
+	parsed, ok := parseSecret(prefixSession, plaintext)
+	if !ok {
+		t.Fatalf("parseSecret rejected our own mint")
+	}
+	if !bytes.Equal(hash, parsed) {
+		t.Fatalf("hash mismatch")
+	}
+	want := sha256.Sum256([]byte(plaintext))
+	if !bytes.Equal(hash, want[:]) {
+		t.Fatalf("hashSecret != sha256(plaintext)")
+	}
+}
+
+func TestParseSecretWrongPrefix(t *testing.T) {
+	plaintext, _, err := mintSecret(prefixSession, nil)
+	if err != nil {
+		t.Fatalf("mintSecret: %v", err)
+	}
+	if _, ok := parseSecret(prefixAPIKey, plaintext); ok {
+		t.Fatalf("parseSecret should reject mismatched prefix")
+	}
+	if _, ok := parseSecret(prefixSession, "sessXXXX"); ok {
+		t.Fatalf("parseSecret should require trailing underscore")
+	}
+}
+
+func TestMintSecretUniqueness(t *testing.T) {
+	seen := make(map[string]struct{}, 100)
+	for i := 0; i < 100; i++ {
+		p, _, err := mintSecret(prefixAPIKey, nil)
+		if err != nil {
+			t.Fatalf("mintSecret: %v", err)
+		}
+		if _, dup := seen[p]; dup {
+			t.Fatalf("duplicate mint: %s", p)
+		}
+		seen[p] = struct{}{}
+	}
+}