Rebuild for v1.0.0: postgres-only, slug-keyed authz, predicate API

Drops the Dialect/Queries abstraction in favor of a single PostgreSQL 16+
implementation collapsed into the root authkit package, removes the
public store interfaces, and reshapes the authorization model around
seeded slugs (roles, permissions, abilities) with optional labels.

Schema is now squashed into one migrations/0001_init.sql and applied
automatically on authkit.New (opt-out via Config.SkipAutoMigrate). A
schema verifier checks tables/columns/types/nullability on startup,
tolerates extra columns, and falls back to default table names when a
configured override is missing.

Auth API: CreateUser + SetPassword replace Register; password is
nullable. Email OTP (RequestEmailOTP/ConsumeEmailOTP) joins magic links
and password reset, all with anti-enumeration silent-success defaults
and a Config.RevealUnknownEmail opt-in. Service tokens drop owner
columns and validate ability slugs against authkit_abilities at issue.
Direct user permissions live alongside role-derived ones; queries
return their UNION.

Predicate API: HasRole/HasPermission/HasAbility leaves with
AnyLogin/AllLogin/AnyServiceKey/AllServiceKey combinators. Validate
runs at middleware construction, panicking on unknown slugs.

Middleware collapses to RequireLogin (cookie + JWT), RequireGuest
(configurable OnAuthenticated), and RequireServiceKey. UserIDFromCtx /
UserFromCtx (lazy) / RefreshUserInCtx provide request-lifetime user
caching. Cookie defaults flip to Secure=true and HttpOnly=true via
*bool with BoolPtr opt-out.

CLIs ship under cmd/perms, cmd/roles, cmd/abilities for seeding the
authorization vocabulary; the library never seeds rows itself.

Tests cover unit-level (slug validation + fuzz, opaque secrets, email
normalization, extractors, predicates, OTP generator) and integration
flows gated on AUTHKIT_TEST_DATABASE_URL (every Auth method, schema
drift detection, migration idempotency, lazy user cache, all middleware
paths).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
juancwu 2026-04-26 23:27:30 +00:00
commit d3c5367492
80 changed files with 5605 additions and 4565 deletions

487
README.md
View file

@ -1,408 +1,283 @@
# authkit
A pragmatic authentication and authorization toolkit for Go web services.
A pragmatic authentication and authorization toolkit for Go web services
on PostgreSQL 16+.
`authkit` ships interfaces for users, sessions, tokens, API keys, roles, and
permissions, plus default `database/sql` Postgres implementations and
framework-neutral HTTP middleware. It supports both opaque server-side
sessions and JWT access tokens with rotating refresh tokens, hashes passwords
with Argon2id, and pairs naturally with [`lightmux`](https://git.juancwu.dev/juancwu/lightmux)
or any `net/http` stack.
`authkit` is a library, not a service. Drop it into a `net/http` stack and
get registration, password login, opaque server-side sessions, JWT access
tokens with rotating refresh, email verification, password reset,
magic-link login, email OTP, and machine-targeted service tokens with
consumer-defined abilities. Authorization is flat RBAC with both
role-derived and direct user permissions.
> **Status:** v1.0.0 development. The API is being stabilised; expect
> breaking changes until the v1.0.0 tag.
## Install
```
go get git.juancwu.dev/juancwu/authkit@v0.1.0
```sh
go get git.juancwu.dev/juancwu/authkit
```
`authkit` itself depends only on `database/sql` and the Go standard library
plus `golang-jwt`, `google/uuid`, `golang.org/x/crypto`, and `errx`. Bring
your own driver: `pgx`, `lib/pq`, or anything else that registers a
`database/sql` driver.
`authkit` depends only on the Go standard library, `golang-jwt`,
`google/uuid`, `golang.org/x/crypto`, and `errx`. Bring your own driver:
`pgx`, `lib/pq`, or anything else that registers a `database/sql` driver.
```go
import _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
```
PostgreSQL 12+ is sufficient — the schema avoids `gen_random_uuid()` and
`pgcrypto` so no extensions are required.
PostgreSQL 16 or newer is required.
## What's included
**Authentication flows**
- Email + password registration and login (Argon2id PHC-encoded hashes)
**Authentication**
- Email-only registration (`CreateUser`); password is optional and can be
set later via `SetPassword`
- Password login with Argon2id PHC-encoded hashes
- Opaque server-side sessions with sliding TTL bounded by an absolute cap
- JWT access tokens (HS256) with rotating refresh tokens and reuse detection
- Email verification, password reset, and magic-link passwordless login
- HS256 JWT access tokens with rotating refresh tokens and reuse
detection
- Email verification, password reset, magic-link login, email OTP
**Authorization**
- Roles and permissions with many-to-many wiring (resolved on user-bound
`Principal`s)
- Owner-agnostic service tokens with custom abilities for server-to-server
auth (no FK on owner; cascade-on-delete is the consumer's responsibility)
- A `Principal` for user-bound auth (sessions, JWTs) and a `ServiceKey` for
service-token auth — middleware composes around both subject types
- Roles and permissions (flat RBAC)
- Direct user-permission grants in addition to role-derived ones —
`UserPermissions` returns the UNION
- Service tokens with consumer-defined abilities (machine credentials, no
user owner)
**Predicate API for middleware authz**
- Leaves: `HasRole(slug)`, `HasPermission(slug)`, `HasAbility(slug)`
- Combinators: `AnyLogin`, `AllLogin`, `AnyServiceKey`, `AllServiceKey`
- Compose freely:
`AnyLogin(HasRole("admin"), AllLogin(HasRole("manager"), HasRole("ads_manager")))`
**HTTP middleware**
- `RequireLogin` — accept session cookie OR JWT, optionally constrain by
`LoginAuthz`
- `RequireGuest` — block authenticated requests (with a configurable
`OnAuthenticated` callback for redirects)
- `RequireServiceKey` — accept a service token, optionally constrain by
`ServiceKeyAuthz`
**Storage**
- Interfaces for every store so callers can plug in their own backends
- Default Postgres implementation built on `*sql.DB` (`sqlstore` package)
- Override table names via `Schema` without forking — useful when authkit
lives alongside an existing application schema
- A `Dialect` abstraction so future MySQL / SQLite implementations slot in
without changes to store code
- Embedded versioned migrations applied by a `Migrate(ctx, db, dialect, schema)`
helper that takes a session-scoped advisory lock
**HTTP**
- User-bound: `middleware.RequireSession`, `RequireJWT`, `RequireAny`
- Service-bound: `middleware.RequireServiceKey`
- Either: `middleware.RequireAnyOrServiceKey` (Session/JWT, falling through to
ServiceKey)
- Authz: `middleware.RequireRole`, `RequireAnyRole`, `RequirePermission`
(operate on `*Principal`); `middleware.RequireAbility` (operates on
`*ServiceKey`)
- `middleware.PrincipalFrom(ctx)` and `middleware.ServiceKeyFrom(ctx)` to
read the authenticated subject in handlers
- PostgreSQL 16+ only
- Migrations and schema verification run on startup (opt-out via
`Config.SkipAutoMigrate` / `Config.SkipSchemaVerify`)
- Override individual table names via `Schema.Tables`
- Schema verifier tolerates extra columns; flags missing tables, missing
columns, type drift, and nullability drift
**Errors**
- Sentinel errors (`ErrEmailTaken`, `ErrInvalidCredentials`, `ErrTokenInvalid`,
`ErrTokenReused`, `ErrSessionInvalid`, `ErrServiceKeyInvalid`,
`ErrPermissionDenied`, ...) compatible with `errors.Is`
- Sentinel errors compatible with `errors.Is`
- All internal errors wrap with [`errx`](https://git.juancwu.dev/juancwu/errx)
for op tags
## Out of scope (v1)
MFA/TOTP, OAuth/social login, soft-delete, in-memory permission caching,
pluggable JWT signers (HS256 only), built-in HTTP handlers, MySQL/SQLite
dialects (architecture supports them; only Postgres ships in v1), and
column-name overrides in `Schema` (table-name overrides only).
## Quick start
### 1. Open a database and run migrations
### 1. Open a database
```go
import (
"database/sql"
"git.juancwu.dev/juancwu/authkit/sqlstore"
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
_ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
_ "github.com/jackc/pgx/v5/stdlib"
)
db, err := sql.Open("pgx", os.Getenv("DATABASE_URL"))
if err != nil { /* ... */ }
defer db.Close()
if err := sqlstore.Migrate(ctx, db, pgdialect.New(), sqlstore.DefaultSchema()); err != nil {
log.Fatal(err)
}
```
`Migrate` is idempotent and safe to call from multiple processes — it takes
a session-scoped `pg_advisory_lock` to serialise rollouts.
`sqlx` users can pass `sqlxDB.DB` (the underlying `*sql.DB`) to the same
calls — the library only cares about `*sql.DB`.
### 2. Wire the service
### 2. Construct Auth
```go
import (
"context"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/hasher"
)
stores, err := sqlstore.New(db, pgdialect.New(), sqlstore.DefaultSchema())
if err != nil { /* ... */ }
auth := authkit.New(authkit.Deps{
Users: stores.Users,
Sessions: stores.Sessions,
Tokens: stores.Tokens,
ServiceKeys: stores.ServiceKeys,
Roles: stores.Roles,
Permissions: stores.Permissions,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
auth, err := authkit.New(ctx, authkit.Deps{
DB: db,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
}, authkit.Config{
JWTSecret: []byte(os.Getenv("JWT_SECRET")),
JWTIssuer: "myapp",
SessionCookieSecure: true,
SessionCookieHTTPOnly: true,
JWTSecret: []byte(os.Getenv("JWT_SECRET")),
JWTIssuer: "myapp",
})
if err != nil { log.Fatal(err) }
```
`Config` zero values fall back to sensible defaults (24h idle / 30d absolute
session TTL, 15m access tokens, 30d refresh tokens, 48h email-verify, 1h
password-reset, 15m magic-link). `JWTSecret` and all seven `Deps` fields are
required; `New` panics on a misconfiguration.
`New` runs migrations and verifies the schema. `Config` zero values fall
back to sane defaults: 24h idle / 30d absolute session TTL, 15m access /
30d refresh, 48h email-verify, 1h password-reset, 15m magic-link, 10m
email OTP with 5 attempts. Cookie defaults: `Secure=true`, `HttpOnly=true`,
`SameSite=Lax`. Pass `authkit.BoolPtr(false)` to opt out for local dev.
### 3. Use the service
### 3. Seed roles, permissions, and abilities
`authkit` does not seed any rows. Use the bundled CLIs:
```sh
go install git.juancwu.dev/juancwu/authkit/cmd/perms@latest
go install git.juancwu.dev/juancwu/authkit/cmd/roles@latest
go install git.juancwu.dev/juancwu/authkit/cmd/abilities@latest
export AUTHKIT_DATABASE_URL=postgres://...
perms create posts:write --label "Write posts"
perms create posts:read --label "Read posts"
roles create editor --label "Editor"
roles grant editor posts:write
roles grant editor posts:read
abilities create events:write --label "Events ingest"
```
Or call the equivalent methods on `*authkit.Auth` from your own seed
script. Slugs match `^[a-z][a-z0-9_:-]*$` (max 64 bytes); invalid slugs
return `ErrSlugInvalid`.
### 4. User flows
```go
// Registration + password login
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
u, err = auth.LoginPassword(ctx, "alice@example.com", "hunter2hunter2")
// Email-only account, password set later.
u, _ := auth.CreateUser(ctx, "alice@example.com")
_ = auth.SetPassword(ctx, u.ID, "hunter2hunter2")
u, _ = auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2") // case-insensitive
// Opaque session (cookie-friendly)
plaintext, sess, err := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP)
// Opaque session.
plaintext, sess, _ := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP)
http.SetCookie(w, auth.SessionCookie(plaintext, sess.ExpiresAt))
// JWT + rotating refresh
access, refresh, err := auth.IssueJWT(ctx, u.ID)
access, refresh, err = auth.RefreshJWT(ctx, refresh) // old refresh is consumed
// JWT + rotating refresh.
access, refresh, _ := auth.IssueJWT(ctx, u.ID)
access, refresh, _ = auth.RefreshJWT(ctx, refresh) // old refresh is consumed
// Service token (owner-agnostic; ownerKind labels the namespace).
// Service tokens are the only credential type that carries free-form abilities.
plaintext, sk, err := auth.IssueServiceKey(ctx,
"application", appID, "events-ingest",
[]string{"events:write"}, nil)
got, err := auth.AuthenticateServiceKey(ctx, plaintext)
// got.OwnerKind == "application"; got.OwnerID == appID
err = auth.RevokeServiceKey(ctx, plaintext)
// Magic link / OTP / password reset (anti-enumeration: silent on unknown email).
linkToken, _ := auth.RequestMagicLink(ctx, "alice@example.com")
otpCode, _ := auth.RequestEmailOTP(ctx, "alice@example.com")
resetToken, _ := auth.RequestPasswordReset(ctx, "alice@example.com")
// Email verification + password reset + magic link
tok, err := auth.RequestEmailVerification(ctx, u.ID)
_, err = auth.ConfirmEmail(ctx, tok)
tok, err = auth.RequestPasswordReset(ctx, "alice@example.com")
err = auth.ConfirmPasswordReset(ctx, tok, "new-password")
tok, err = auth.RequestMagicLink(ctx, "alice@example.com")
u, err = auth.ConsumeMagicLink(ctx, tok)
// Service token with abilities.
plaintext, sk, _ := auth.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{
Name: "events-ingest",
Abilities: []string{"events:write"},
})
got, _ := auth.AuthenticateServiceKey(ctx, plaintext)
```
The plaintext returned by `IssueSession`, `IssueJWT`, `IssueServiceKey`, and
the token-minting flows is **show-once** — only its SHA-256 hash is stored.
Show it to the user immediately; you cannot recover it later.
The plaintext returned by every issue/mint flow is **show-once** — only
its SHA-256 hash is stored. Show it to the user immediately; you cannot
recover it later.
### 4. Wire middleware
`authkit/middleware` returns standard `func(http.Handler) http.Handler`
values, so it composes with `lightmux.Mux.Use`/`Group`/`Handle` and any
`net/http` mux that accepts the same shape.
### 5. Wire middleware
```go
import (
authkitmw "git.juancwu.dev/juancwu/authkit/middleware"
"git.juancwu.dev/juancwu/lightmux"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/middleware"
)
mux := lightmux.New()
// Default RequireLogin reads the session cookie and falls through to a
// Bearer JWT.
loginMW := middleware.RequireLogin(middleware.LoginOptions{Auth: auth})
cookieAuth := authkitmw.RequireSession(authkitmw.Options{
Auth: auth,
Extractor: authkit.ChainExtractors(
authkit.CookieExtractor("authkit_session"),
authkit.BearerExtractor(),
// Constrain on roles/permissions:
adminMW := middleware.RequireLogin(middleware.LoginOptions{
Auth: auth,
Authz: authkit.AnyLogin(
authkit.HasRole("admin"),
authkit.AllLogin(authkit.HasRole("manager"), authkit.HasRole("ads_manager")),
),
})
me := mux.Group("/me", cookieAuth)
me.Get("", func(w http.ResponseWriter, r *http.Request) {
p := authkitmw.MustPrincipal(r)
json.NewEncoder(w).Encode(p)
// Login/register pages: block if already authenticated.
guestMW := middleware.RequireGuest(middleware.GuestOptions{
Auth: auth,
OnAuthenticated: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
},
})
// RBAC: stack authz on top of any auth method
admin := mux.Group("/admin", cookieAuth, authkitmw.RequireRole("admin"))
// Service-token route with a per-endpoint ability check
api := mux.Group("/api/v1", authkitmw.RequireServiceKey(authkitmw.Options{Auth: auth}))
api.Get("/events", eventsHandler, authkitmw.RequireAbility("events:write"))
// Mixed route — accept either a session cookie or a service token
mixed := mux.Group("/v1", authkitmw.RequireAnyOrServiceKey(authkitmw.Options{Auth: auth}))
mixed.Get("/profile", func(w http.ResponseWriter, r *http.Request) {
if p, ok := authkitmw.PrincipalFrom(r.Context()); ok {
// user request
_ = p
} else if k, ok := authkitmw.ServiceKeyFrom(r.Context()); ok {
// service request
_ = k
}
// Service tokens with an ability gate.
apiMW := middleware.RequireServiceKey(middleware.ServiceKeyOptions{
Auth: auth,
Authz: authkit.AllServiceKey(authkit.HasAbility("events:write")),
})
```
`Options.Extractor` defaults to `BearerExtractor`; pass `CookieExtractor` (or
chain extractors) when reading session cookies. `Options.OnUnauth` and
`Options.OnForbidden` default to a JSON `401` / `403`; override them to match
your error envelope.
`RequireLogin` and `RequireServiceKey` panic at construction if any slug
referenced by the predicate isn't registered in the database — typos fail
at boot, not at request time.
## Custom table names
### 6. Read the user in handlers
Pass a non-default `Schema` to use your own table names. Identifiers must
match `^[a-zA-Z_][a-zA-Z0-9_]*$`; anything else is rejected at `New()` and
`Migrate()` time, so SQL injection through the schema is impossible.
Middleware attaches the `user_id` to the request context. Handlers fetch
the full user lazily:
```go
schema := sqlstore.DefaultSchema()
schema.Tables.Users = "accounts"
schema.Tables.ServiceKeys = "service_credentials"
func handle(w http.ResponseWriter, r *http.Request) {
id, _ := authkit.UserIDFromCtx(r.Context()) // never queries the DB
u, err := authkit.UserFromCtx(r.Context()) // lazy-load + per-request cache
if err != nil { /* handle */ }
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
```
The bundled migration files use the default `authkit_*` names. If you
override, you're responsible for matching DDL (most consumers with custom
naming already have their own DDL pipeline).
Column-name overrides are not exposed in v1 — the column set is fixed for
each table. Adding column overrides later is purely additive.
## How things work
### Secret token format
Sessions, refresh tokens, service tokens, email-verify tokens,
password-reset tokens, and magic-link tokens all share one format:
```
plaintext = "<prefix>_" + base64url(32 random bytes, no padding)
lookup = sha256(plaintext)
```
Plaintext is returned to the caller exactly once and never persisted; the
SHA-256 is the database lookup key. Random bytes come from `crypto/rand` (or
`Config.Random` for tests). The mint/parse/hash helpers are exported as
`MintOpaqueSecret`, `ParseOpaqueSecret`, and `HashOpaqueSecret` for callers
building bespoke token storage on top of the same shape.
### User credentials vs. service tokens
`authkit` exposes two distinct subject types, and middleware composes around
them differently.
**User credentials** — sessions and JWTs — prove **identity**. They are
produced by `IssueSession` / `IssueJWT` and authenticate via
`AuthenticateSession` / `AuthenticateJWT`, which return a `*Principal`
carrying `UserID`, `Method`, and the user's roles + permissions resolved
through RBAC. Authorization on these requests is **role/permission-based**
via `RequireRole` / `RequirePermission`. User credentials carry no abilities;
"what this user may do" is answered by the user's RBAC, not by anything
embedded on the credential itself.
**Service tokens** — `IssueServiceKey` — prove **"this caller may do X"**.
They are owner-agnostic: `OwnerKind` labels the namespace ("application",
"tenant", whatever) and `OwnerID` identifies the entity within it. The
database column has **no foreign key** on purpose — `authkit` makes no
assumption about what the owner is, and cascade-on-delete is the consumer's
responsibility. `AuthenticateServiceKey` returns a `*ServiceKey` directly
(no `*Principal`, no role/permission resolution). Authorization on these
requests is **ability-based** via `RequireAbility`; the abilities slice is
free-form and not linked to `authkit_roles` / `authkit_permissions`.
```go
plaintext, key, err := auth.IssueServiceKey(ctx,
"application", appID, "events-ingest",
[]string{"events:write"}, nil)
k, err := auth.AuthenticateServiceKey(ctx, plaintext)
// k.OwnerKind == "application"; k.OwnerID == appID; k.HasAbility("events:write")
err = auth.RevokeServiceKey(ctx, plaintext)
```
When a consumer-owned entity (an application, a tenant) is deleted, the
consumer must revoke the associated service tokens itself — typically by
iterating `ListServiceKeys(ctx, ownerKind, ownerID)`.
### JWT revocation
Access tokens carry `sv` (`session_version`) in their claims. When you call
`RevokeAllUserSessions` or `ChangePassword`, the user's `session_version`
column increments and every outstanding access token fails the next
`AuthenticateJWT`. This is the only way to invalidate a JWT before its
`exp`.
### Refresh token rotation
Each `RefreshJWT` consumes the presented refresh token and issues a new one
on the same chain (the `chain_id` column on `authkit_tokens`). If a
*consumed* refresh token is ever presented again — a strong replay signal —
the entire chain is deleted via `TokenStore.DeleteByChain` and the call
returns `ErrTokenReused`.
### Sliding session TTL
Each authenticated request via `AuthenticateSession` slides `expires_at` to
`now + Config.SessionIdleTTL`, capped at `created_at + Config.SessionAbsoluteTTL`.
Long-lived idle sessions still hit the absolute boundary.
### Schema and migrations
`sqlstore.Migrate` applies every embedded `.sql` file under
`sqlstore/dialect/postgres/migrations/` whose version (filename without
`.sql`) is not in `authkit_schema_migrations`. The dialect's
`AcquireMigrationLock` (Postgres uses `pg_advisory_lock`) serialises
concurrent migrators. Each migration owns its own transaction so future
migrations can use statements like `CREATE INDEX CONCURRENTLY`.
Every default table is prefixed `authkit_` so the schema can live alongside
your application's own tables in a shared database.
### Driver and dialect architecture
The `sqlstore` package speaks `database/sql` only. Driver-specific behaviour
lives behind a small `Dialect` interface:
```go
type Dialect interface {
Name() string
BuildQueries(s Schema) Queries
Bootstrap(ctx context.Context, db *sql.DB) error
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
Migrations() fs.FS
IsUniqueViolation(err error) bool
Placeholder(n int) string
PlaceholderList(start, count int) string
// After an admin-side update that should be visible:
u, err = authkit.RefreshUserInCtx(r.Context())
_ = u; _ = id
}
```
v1 ships `dialect/postgres`. A future MySQL or SQLite dialect adds a new
implementation; no changes to store code.
The cache lives only for the request lifetime — nothing persists across
requests. For service-token routes, use `authkit.ServiceKeyFromCtx`.
## Schema verification and drift
On `New`, `authkit` introspects `information_schema.columns` and verifies
the live database matches the expected layout (table presence, column
names, `data_type`, `is_nullable`). Extra columns are tolerated; missing
tables/columns and type drift fail with `ErrSchemaDrift`.
When a table cannot be found under the configured name, the verifier
falls back to the default `authkit_*` name. This handles migrations from
custom names back to defaults without manual intervention.
## Configuration reference
| Field | Default | Notes |
|---|---|---|
| `Schema` | `DefaultSchema()` | Override individual `Tables` fields; missing fields fall back to defaults |
| `SkipAutoMigrate` | `false` | Disables migration run inside `New` |
| `SkipSchemaVerify` | `false` | Disables schema check inside `New` |
| `SessionIdleTTL` | 24h | Sliding window applied on each authenticated request |
| `SessionAbsoluteTTL` | 30d | Cap from `created_at`; sliding never exceeds this |
| `SessionCookieName` | `authkit_session` | |
| `SessionCookieSecure` | `*true` | Pass `BoolPtr(false)` for local HTTP dev |
| `SessionCookieHTTPOnly` | `*true` | Pass `BoolPtr(false)` if JS must read it (rarely correct) |
| `SessionCookieSameSite` | `Lax` | |
| `SessionCookieSecure` / `HTTPOnly` | `false` / `false` | Set both to `true` in production |
| `JWTSecret` | — (required) | HS256 key |
| `JWTIssuer` / `JWTAudience` | empty | When set, parser enforces them |
| `AccessTokenTTL` | 15m | |
| `RefreshTokenTTL` | 30d | |
| `AccessTokenTTL` / `RefreshTokenTTL` | 15m / 30d | |
| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | |
| `Clock` | `time.Now().UTC` | Controls every observable timestamp; override for deterministic tests |
| `EmailOTPTTL` / `EmailOTPDigits` / `EmailOTPMaxAttempts` | 10m / 6 / 5 | |
| `RevealUnknownEmail` | `false` | Default anti-enumeration: silent success on unknown email |
| `Clock` | `time.Now().UTC` | Override for deterministic tests |
| `Random` | `crypto/rand.Reader` | Override for deterministic tests |
| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit |
## Implementing your own store
Every store is a small interface with explicit semantics — see `stores.go`.
The most subtle contract is `TokenStore.ConsumeToken`: it MUST mark the token
consumed and return it in a single statement (`UPDATE ... RETURNING` on
Postgres / SQLite 3.35+) so two concurrent callers cannot both succeed.
| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit. Panics in the hook are recovered. |
## Testing
```
go test ./... # unit tests, no DB
AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./sqlstore... # integration tests
```sh
go test ./... # unit tests, no DB required
AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./... -run Integration
```
Unit tests cover token mint/parse, Argon2id encode/verify (including
`needsRehash` on parameter change), JWT issue/parse (incl. expired,
`sv`-mismatch, refresh rotation, reuse detection), session lifecycle, email
verification, password reset cascading session invalidation, magic-link
self-verification, API keys with abilities, and RBAC role-permission
resolution. Integration tests run the full `sqlstore` contract against a
real Postgres when `AUTHKIT_TEST_DATABASE_URL` is set.
The unit suite covers slug validation (incl. fuzz), opaque-secret
roundtrip, email normalization, HTTP extractors, predicate combinators,
and OTP code generation. Integration tests cover every database-bound
flow: registration, login, sessions, JWT refresh + reuse, magic link,
email OTP (incl. attempt cap), password reset, service tokens, RBAC,
direct user permissions, schema verification (drift cases + fallback),
migration idempotency, lazy user-context cache, and middleware behavior.
## License

View file

@ -3,6 +3,7 @@ package authkit
import (
"context"
"crypto/rand"
"database/sql"
"io"
"net/http"
"time"
@ -10,31 +11,53 @@ import (
"git.juancwu.dev/juancwu/errx"
)
// Deps bundles every backing store and the password hasher the Auth service
// depends on. All fields are required; New panics on a nil dep so misuse is
// caught at boot rather than under load.
// Hasher is the password hashing interface. The default implementation is
// hasher.Argon2id; consumers can swap in alternative KDFs as long as the
// encoded form lets Verify roundtrip and report needsRehash on parameter
// drift.
type Hasher interface {
Hash(password string) (string, error)
Verify(password, encoded string) (ok bool, needsRehash bool, err error)
}
// Deps bundles the runtime dependencies the Auth service requires. DB and
// Hasher are required; New panics on either being nil.
type Deps struct {
Users UserStore
Sessions SessionStore
Tokens TokenStore
ServiceKeys ServiceKeyStore
Roles RoleStore
Permissions PermissionStore
Hasher Hasher
DB *sql.DB
Hasher Hasher
}
// Config tunes session/JWT/token TTLs, cookie shape, JWT signing material,
// and optional hooks. Any zero-valued duration is replaced with a sane
// default in New; required fields (notably JWTSecret) cause New to panic.
// schema overrides, and optional hooks. Zero-valued durations are replaced
// with sane defaults in New; required fields (notably JWTSecret) cause New
// to panic.
type Config struct {
// Session (opaque) cookies + DB-backed lifetime
SessionIdleTTL time.Duration
SessionAbsoluteTTL time.Duration
SessionCookieName string
SessionCookieDomain string
SessionCookiePath string
SessionCookieSecure bool
SessionCookieHTTPOnly bool
// Schema lets consumers override table names. Zero value uses
// DefaultSchema().
Schema Schema
// SkipAutoMigrate disables the migration run inside New. The verifier
// still runs; consumers running their own migrate pipeline should set
// this and call authkit.Migrate manually before New (or skip it
// entirely if they manage DDL out-of-band).
SkipAutoMigrate bool
// SkipSchemaVerify disables the startup schema check. Recommended only
// for tests that expect schema drift; production callers should let the
// verifier run.
SkipSchemaVerify bool
// Sessions
SessionIdleTTL time.Duration
SessionAbsoluteTTL time.Duration
SessionCookieName string
SessionCookieDomain string
SessionCookiePath string
// SessionCookieSecure / SessionCookieHTTPOnly use *bool so a nil value
// means "fall back to the secure default (true)" while *bool(false) is
// an explicit opt-out for local dev. BoolPtr is a one-line constructor.
SessionCookieSecure *bool
SessionCookieHTTPOnly *bool
SessionCookieSameSite http.SameSite
// JWT (HS256)
@ -45,38 +68,82 @@ type Config struct {
RefreshTokenTTL time.Duration
// Single-use tokens
EmailVerifyTTL time.Duration
PasswordResetTTL time.Duration
MagicLinkTTL time.Duration
EmailVerifyTTL time.Duration
PasswordResetTTL time.Duration
MagicLinkTTL time.Duration
EmailOTPTTL time.Duration
EmailOTPMaxAttempts int
EmailOTPDigits int
// Hooks (optional)
// RevealUnknownEmail flips request flows (RequestPasswordReset,
// RequestMagicLink, RequestEmailOTP) from anti-enumeration silent
// success to returning ErrUserNotFound when the email isn't
// registered. Default false (silent).
RevealUnknownEmail bool
// Hooks
Clock func() time.Time
Random io.Reader
LoginHook func(ctx context.Context, email string, success bool) error
}
// Auth is the high-level service that composes the stores and hasher into the
// flows callers use: registration, login, sessions, JWTs, magic links, API
// keys, and authz checks. It is safe for concurrent use; method receivers
// Auth is the high-level service. Safe for concurrent use; method receivers
// never mutate Auth state after construction.
type Auth struct {
deps Deps
cfg Config
db *sql.DB
hasher Hasher
cfg Config
q queries
schema Schema
}
// New validates Deps and Config, fills in defaults, and returns a ready
// service. It panics on missing deps or missing JWT secret rather than
// returning an error — these are programmer errors, not runtime ones.
func New(deps Deps, cfg Config) *Auth {
if deps.Users == nil || deps.Sessions == nil || deps.Tokens == nil ||
deps.ServiceKeys == nil || deps.Roles == nil || deps.Permissions == nil ||
deps.Hasher == nil {
panic(errx.New("authkit.New", "all Deps fields are required"))
// New validates Deps and Config, fills in defaults, runs migrations
// (unless SkipAutoMigrate), verifies the schema (unless SkipSchemaVerify),
// and returns a ready service.
//
// Panics on missing deps, missing JWT secret, invalid schema, or schema
// drift — these are programmer/operator errors, not runtime failures.
func New(ctx context.Context, deps Deps, cfg Config) (*Auth, error) {
const op = "authkit.New"
if deps.DB == nil {
panic(errx.New(op, "Deps.DB is required"))
}
if deps.Hasher == nil {
panic(errx.New(op, "Deps.Hasher is required"))
}
if len(cfg.JWTSecret) == 0 {
panic(errx.New("authkit.New", "Config.JWTSecret is required"))
panic(errx.New(op, "Config.JWTSecret is required"))
}
cfg.Schema = mergeSchemaDefaults(cfg.Schema)
if err := cfg.Schema.Validate(); err != nil {
panic(errx.Wrap(op, err))
}
cfg = applyDefaults(cfg)
a := &Auth{
db: deps.DB,
hasher: deps.Hasher,
cfg: cfg,
q: buildQueries(cfg.Schema.Tables),
schema: cfg.Schema,
}
if !cfg.SkipAutoMigrate {
if err := Migrate(ctx, deps.DB, cfg.Schema); err != nil {
return nil, errx.Wrap(op, err)
}
}
if !cfg.SkipSchemaVerify {
if err := VerifySchema(ctx, deps.DB, cfg.Schema); err != nil {
return nil, errx.Wrap(op, err)
}
}
return a, nil
}
func applyDefaults(cfg Config) Config {
if cfg.SessionIdleTTL == 0 {
cfg.SessionIdleTTL = 24 * time.Hour
}
@ -92,6 +159,14 @@ func New(deps Deps, cfg Config) *Auth {
if cfg.SessionCookieSameSite == 0 {
cfg.SessionCookieSameSite = http.SameSiteLaxMode
}
// Secure & HTTPOnly default to true. Consumers wanting plain HTTP for
// local dev must pass an explicit *false via BoolPtr.
if cfg.SessionCookieSecure == nil {
cfg.SessionCookieSecure = BoolPtr(true)
}
if cfg.SessionCookieHTTPOnly == nil {
cfg.SessionCookieHTTPOnly = BoolPtr(true)
}
if cfg.AccessTokenTTL == 0 {
cfg.AccessTokenTTL = 15 * time.Minute
}
@ -107,15 +182,34 @@ func New(deps Deps, cfg Config) *Auth {
if cfg.MagicLinkTTL == 0 {
cfg.MagicLinkTTL = 15 * time.Minute
}
if cfg.EmailOTPTTL == 0 {
cfg.EmailOTPTTL = 10 * time.Minute
}
if cfg.EmailOTPMaxAttempts == 0 {
cfg.EmailOTPMaxAttempts = 5
}
if cfg.EmailOTPDigits == 0 {
cfg.EmailOTPDigits = 6
}
if cfg.Clock == nil {
cfg.Clock = func() time.Time { return time.Now().UTC() }
}
if cfg.Random == nil {
cfg.Random = rand.Reader
}
return &Auth{deps: deps, cfg: cfg}
return cfg
}
// BoolPtr is a one-line helper for Config fields that take *bool. Use it to
// opt out of the secure cookie defaults: cfg.SessionCookieSecure = BoolPtr(false).
func BoolPtr(b bool) *bool { return &b }
// now returns the configured wall clock, defaulting to time.Now in UTC.
func (a *Auth) now() time.Time { return a.cfg.Clock() }
// DB exposes the underlying *sql.DB. Useful for callers that want to run
// admin queries on the same pool.
func (a *Auth) DB() *sql.DB { return a.db }
// Schema returns the configured schema.
func (a *Auth) Schema() Schema { return a.schema }

176
authz.go Normal file
View file

@ -0,0 +1,176 @@
package authkit
import (
"context"
"fmt"
)
// LoginAuthz is a predicate over a *Principal. Used by middleware that
// gates handlers on a user's roles or permissions.
type LoginAuthz interface {
// Match reports whether the principal satisfies the predicate.
Match(p *Principal) bool
// Validate verifies that every slug referenced by this predicate exists
// in the database. Called at middleware-construction time so typos fail
// at boot rather than at request time.
Validate(ctx context.Context, a *Auth) error
}
// ServiceKeyAuthz is the analogous predicate type for service-token
// authorization.
type ServiceKeyAuthz interface {
Match(k *ServiceKey) bool
Validate(ctx context.Context, a *Auth) error
}
// HasRole returns a leaf predicate satisfied when the principal carries the
// given role slug.
func HasRole(slug string) LoginAuthz { return roleLeaf{slug: slug} }
// HasPermission returns a leaf predicate satisfied when the principal
// carries the given permission slug (resolved through any combination of
// roles and direct grants).
func HasPermission(slug string) LoginAuthz { return permLeaf{slug: slug} }
// HasAbility returns a leaf predicate satisfied when the service key
// carries the given ability slug.
func HasAbility(slug string) ServiceKeyAuthz { return abilityLeaf{slug: slug} }
// AnyLogin returns a predicate satisfied when at least one child predicate
// matches. With no children, AnyLogin matches nothing (returns false).
func AnyLogin(preds ...LoginAuthz) LoginAuthz { return anyLogin{preds: preds} }
// AllLogin returns a predicate satisfied when every child predicate
// matches. With no children, AllLogin matches everything (returns true).
func AllLogin(preds ...LoginAuthz) LoginAuthz { return allLogin{preds: preds} }
// AnyServiceKey returns a service-key predicate satisfied when at least one
// child matches.
func AnyServiceKey(preds ...ServiceKeyAuthz) ServiceKeyAuthz {
return anyService{preds: preds}
}
// AllServiceKey returns a service-key predicate satisfied when every child
// matches.
func AllServiceKey(preds ...ServiceKeyAuthz) ServiceKeyAuthz {
return allService{preds: preds}
}
// ─── leaves ────────────────────────────────────────────────────────────────
type roleLeaf struct{ slug string }
func (l roleLeaf) Match(p *Principal) bool { return p != nil && p.HasRole(l.slug) }
func (l roleLeaf) Validate(ctx context.Context, a *Auth) error {
if err := validateSlug("authkit.HasRole", l.slug); err != nil {
return err
}
if _, err := a.storeGetRoleBySlug(ctx, l.slug); err != nil {
return fmt.Errorf("authkit.HasRole(%q): %w", l.slug, err)
}
return nil
}
type permLeaf struct{ slug string }
func (l permLeaf) Match(p *Principal) bool { return p != nil && p.HasPermission(l.slug) }
func (l permLeaf) Validate(ctx context.Context, a *Auth) error {
if err := validateSlug("authkit.HasPermission", l.slug); err != nil {
return err
}
if _, err := a.storeGetPermissionBySlug(ctx, l.slug); err != nil {
return fmt.Errorf("authkit.HasPermission(%q): %w", l.slug, err)
}
return nil
}
type abilityLeaf struct{ slug string }
func (l abilityLeaf) Match(k *ServiceKey) bool { return k != nil && k.HasAbility(l.slug) }
func (l abilityLeaf) Validate(ctx context.Context, a *Auth) error {
if err := validateSlug("authkit.HasAbility", l.slug); err != nil {
return err
}
if _, err := a.storeGetAbilityBySlug(ctx, l.slug); err != nil {
return fmt.Errorf("authkit.HasAbility(%q): %w", l.slug, err)
}
return nil
}
// ─── combinators ───────────────────────────────────────────────────────────
type anyLogin struct{ preds []LoginAuthz }
func (a anyLogin) Match(p *Principal) bool {
for _, pr := range a.preds {
if pr.Match(p) {
return true
}
}
return false
}
func (a anyLogin) Validate(ctx context.Context, auth *Auth) error {
for _, p := range a.preds {
if err := p.Validate(ctx, auth); err != nil {
return err
}
}
return nil
}
type allLogin struct{ preds []LoginAuthz }
func (a allLogin) Match(p *Principal) bool {
for _, pr := range a.preds {
if !pr.Match(p) {
return false
}
}
return true
}
func (a allLogin) Validate(ctx context.Context, auth *Auth) error {
for _, p := range a.preds {
if err := p.Validate(ctx, auth); err != nil {
return err
}
}
return nil
}
type anyService struct{ preds []ServiceKeyAuthz }
func (a anyService) Match(k *ServiceKey) bool {
for _, pr := range a.preds {
if pr.Match(k) {
return true
}
}
return false
}
func (a anyService) Validate(ctx context.Context, auth *Auth) error {
for _, p := range a.preds {
if err := p.Validate(ctx, auth); err != nil {
return err
}
}
return nil
}
type allService struct{ preds []ServiceKeyAuthz }
func (a allService) Match(k *ServiceKey) bool {
for _, pr := range a.preds {
if !pr.Match(k) {
return false
}
}
return true
}
func (a allService) Validate(ctx context.Context, auth *Auth) error {
for _, p := range a.preds {
if err := p.Validate(ctx, auth); err != nil {
return err
}
}
return nil
}

90
authz_test.go Normal file
View file

@ -0,0 +1,90 @@
package authkit
import (
"testing"
"github.com/google/uuid"
)
func TestPredicateLeavesAndCombinators(t *testing.T) {
p := &Principal{
UserID: uuid.New(),
Roles: []string{"admin", "manager"},
Permissions: []string{"posts:read", "posts:write"},
}
if !HasRole("admin").Match(p) {
t.Fatalf("HasRole admin should match")
}
if HasRole("nope").Match(p) {
t.Fatalf("HasRole nope should not match")
}
if !HasPermission("posts:write").Match(p) {
t.Fatalf("HasPermission posts:write should match")
}
// AnyLogin: short-circuit on first match.
any1 := AnyLogin(HasRole("nope"), HasRole("admin"))
if !any1.Match(p) {
t.Fatalf("AnyLogin with one match should match")
}
any2 := AnyLogin(HasRole("nope"), HasRole("missing"))
if any2.Match(p) {
t.Fatalf("AnyLogin with no matches should not match")
}
// AnyLogin with no children: vacuously false.
if AnyLogin().Match(p) {
t.Fatalf("AnyLogin() should be false (no candidates can satisfy)")
}
// AllLogin: every child must match.
all1 := AllLogin(HasRole("admin"), HasRole("manager"))
if !all1.Match(p) {
t.Fatalf("AllLogin with all matches should match")
}
all2 := AllLogin(HasRole("admin"), HasRole("missing"))
if all2.Match(p) {
t.Fatalf("AllLogin with one missing should not match")
}
if !AllLogin().Match(p) {
t.Fatalf("AllLogin() should be true (vacuous truth)")
}
// Nested: Admin OR (Manager AND AdsManager). Without ads_manager, the
// AND-arm fails but the Admin-arm succeeds.
expr := AnyLogin(
HasRole("admin"),
AllLogin(HasRole("manager"), HasRole("ads_manager")),
)
if !expr.Match(p) {
t.Fatalf("Admin OR (Manager AND AdsManager) should match: admin alone qualifies")
}
// Same expression against a non-admin manager who lacks ads_manager:
pNonAdmin := &Principal{Roles: []string{"manager"}}
if expr.Match(pNonAdmin) {
t.Fatalf("manager without ads_manager should not match the compound")
}
pBoth := &Principal{Roles: []string{"manager", "ads_manager"}}
if !expr.Match(pBoth) {
t.Fatalf("manager+ads_manager should match the AND-arm")
}
}
func TestServiceKeyPredicates(t *testing.T) {
k := &ServiceKey{Abilities: []string{"events:write", "events:read"}}
if !HasAbility("events:write").Match(k) {
t.Fatalf("HasAbility events:write should match")
}
if HasAbility("admin:nuke").Match(k) {
t.Fatalf("HasAbility admin:nuke should not match")
}
if !AllServiceKey(HasAbility("events:write"), HasAbility("events:read")).Match(k) {
t.Fatalf("AllServiceKey should match when key carries both")
}
if AnyServiceKey(HasAbility("admin:nuke"), HasAbility("missing")).Match(k) {
t.Fatalf("AnyServiceKey should not match with no candidates")
}
}

109
cmd/abilities/main.go Normal file
View file

@ -0,0 +1,109 @@
// Command abilities is the seeding CLI for service-token abilities.
//
// abilities create <slug> [--label "..."]
// abilities list
// abilities delete <slug>
package main
import (
"context"
"errors"
"fmt"
"os"
"git.juancwu.dev/juancwu/authkit/cmd/internal/clihelp"
)
func main() {
if len(os.Args) < 2 {
usage()
os.Exit(2)
}
sub := os.Args[1]
args := os.Args[2:]
switch sub {
case "create":
runCreate(args)
case "list":
runList(args)
case "delete", "rm":
runDelete(args)
case "-h", "--help", "help":
usage()
default:
fmt.Fprintf(os.Stderr, "unknown subcommand %q\n\n", sub)
usage()
os.Exit(2)
}
}
func usage() {
fmt.Fprintln(os.Stderr, `usage: abilities <subcommand> [args]
Subcommands:
create <slug> [--label "..."] create an ability
list list every ability
delete <slug> delete an ability
Common flags:
--dsn PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)`)
}
func runCreate(args []string) {
fs, dsn := clihelp.DSNFlag("abilities create")
label := fs.String("label", "", "optional human label")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 1 {
clihelp.Fail(errors.New("create takes exactly one slug argument"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
ab, err := a.CreateAbility(ctx, rest[0], *label)
if err != nil {
clihelp.Fail(err)
}
fmt.Printf("created ability %s (id=%s, label=%q)\n", ab.Slug, ab.ID, ab.Label)
}
func runList(args []string) {
fs, dsn := clihelp.DSNFlag("abilities list")
_ = fs.Parse(args)
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
abilities, err := a.ListAbilities(ctx)
if err != nil {
clihelp.Fail(err)
}
for _, ab := range abilities {
fmt.Printf("%s\t%s\n", ab.Slug, ab.Label)
}
}
func runDelete(args []string) {
fs, dsn := clihelp.DSNFlag("abilities delete")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 1 {
clihelp.Fail(errors.New("delete takes exactly one slug argument"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
if err := a.DeleteAbility(ctx, rest[0]); err != nil {
clihelp.Fail(err)
}
fmt.Printf("deleted ability %s\n", rest[0])
}

View file

@ -0,0 +1,70 @@
// Package clihelp is a small helper used by the cmd/perms, cmd/roles, and
// cmd/abilities seeding CLIs to dial Postgres, build an *authkit.Auth, and
// share argument-parsing scaffolding.
package clihelp
import (
"context"
"database/sql"
"errors"
"flag"
"fmt"
"os"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/hasher"
_ "github.com/jackc/pgx/v5/stdlib"
)
// DSNFlag returns a flag.FlagSet pre-populated with --dsn. Callers add their
// own flags and call Parse on it.
func DSNFlag(name string) (*flag.FlagSet, *string) {
fs := flag.NewFlagSet(name, flag.ExitOnError)
dsn := fs.String("dsn", "", "PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)")
return fs, dsn
}
// Dial opens a database connection using either the supplied DSN or the
// AUTHKIT_DATABASE_URL env var, then constructs an *authkit.Auth ready to
// run seed operations. Migrations and schema verification both run as part
// of New.
//
// The CLIs never sign JWTs or hash passwords, but Auth.New requires a JWT
// secret and a hasher — we supply a dummy secret and the default Argon2id
// hasher so the constructor passes.
func Dial(ctx context.Context, dsn string) (*authkit.Auth, *sql.DB, error) {
if dsn == "" {
dsn = os.Getenv("AUTHKIT_DATABASE_URL")
}
if dsn == "" {
return nil, nil, errors.New("no DSN: pass --dsn or set AUTHKIT_DATABASE_URL")
}
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, nil, fmt.Errorf("sql.Open: %w", err)
}
if err := db.PingContext(ctx); err != nil {
_ = db.Close()
return nil, nil, fmt.Errorf("ping: %w", err)
}
a, err := authkit.New(ctx, authkit.Deps{
DB: db,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
}, authkit.Config{
// JWT secret is unused by seed flows but required by New.
JWTSecret: []byte("authkit-cli-not-used-for-anything-real"),
})
if err != nil {
_ = db.Close()
return nil, nil, err
}
return a, db, nil
}
// Fail prints err to stderr and exits with status 1. Used by every CLI's
// top-level dispatch.
func Fail(err error) {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}

114
cmd/perms/main.go Normal file
View file

@ -0,0 +1,114 @@
// Command perms is the seeding CLI for authkit permissions.
//
// perms create <slug> [--label "..."]
// perms list
// perms delete <slug>
//
// Database connection comes from --dsn or $AUTHKIT_DATABASE_URL.
package main
import (
"context"
"errors"
"fmt"
"os"
"git.juancwu.dev/juancwu/authkit/cmd/internal/clihelp"
)
func main() {
if len(os.Args) < 2 {
usage()
os.Exit(2)
}
sub := os.Args[1]
args := os.Args[2:]
switch sub {
case "create":
runCreate(args)
case "list":
runList(args)
case "delete", "rm":
runDelete(args)
case "-h", "--help", "help":
usage()
default:
fmt.Fprintf(os.Stderr, "unknown subcommand %q\n\n", sub)
usage()
os.Exit(2)
}
}
func usage() {
fmt.Fprintln(os.Stderr, `usage: perms <subcommand> [args]
Subcommands:
create <slug> [--label "..."] create a permission
list list every permission
delete <slug> delete a permission
Common flags:
--dsn PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)`)
}
func runCreate(args []string) {
fs, dsn := clihelp.DSNFlag("perms create")
label := fs.String("label", "", "optional human label")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 1 {
clihelp.Fail(errors.New("create takes exactly one slug argument"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
p, err := a.CreatePermission(ctx, rest[0], *label)
if err != nil {
clihelp.Fail(err)
}
fmt.Printf("created permission %s (id=%s, label=%q)\n", p.Slug, p.ID, p.Label)
}
func runList(args []string) {
fs, dsn := clihelp.DSNFlag("perms list")
_ = fs.Parse(args)
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
perms, err := a.ListPermissions(ctx)
if err != nil {
clihelp.Fail(err)
}
for _, p := range perms {
fmt.Printf("%s\t%s\n", p.Slug, p.Label)
}
}
func runDelete(args []string) {
fs, dsn := clihelp.DSNFlag("perms delete")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 1 {
clihelp.Fail(errors.New("delete takes exactly one slug argument"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
if err := a.DeletePermission(ctx, rest[0]); err != nil {
clihelp.Fail(err)
}
fmt.Printf("deleted permission %s\n", rest[0])
}

182
cmd/roles/main.go Normal file
View file

@ -0,0 +1,182 @@
// Command roles is the seeding CLI for authkit roles, plus role↔permission
// linking.
//
// roles create <slug> [--label "..."]
// roles list
// roles delete <slug>
// roles grant <role-slug> <perm-slug>
// roles revoke <role-slug> <perm-slug>
// roles permissions <role-slug>
package main
import (
"context"
"errors"
"fmt"
"os"
"git.juancwu.dev/juancwu/authkit/cmd/internal/clihelp"
)
func main() {
if len(os.Args) < 2 {
usage()
os.Exit(2)
}
sub := os.Args[1]
args := os.Args[2:]
switch sub {
case "create":
runCreate(args)
case "list":
runList(args)
case "delete", "rm":
runDelete(args)
case "grant":
runGrant(args)
case "revoke":
runRevoke(args)
case "permissions", "perms":
runPermissions(args)
case "-h", "--help", "help":
usage()
default:
fmt.Fprintf(os.Stderr, "unknown subcommand %q\n\n", sub)
usage()
os.Exit(2)
}
}
func usage() {
fmt.Fprintln(os.Stderr, `usage: roles <subcommand> [args]
Subcommands:
create <slug> [--label "..."] create a role
list list every role
delete <slug> delete a role
grant <role-slug> <perm-slug> grant a permission to a role
revoke <role-slug> <perm-slug> revoke a permission from a role
permissions <role-slug> list permissions granted to a role
Common flags:
--dsn PostgreSQL DSN (defaults to $AUTHKIT_DATABASE_URL)`)
}
func runCreate(args []string) {
fs, dsn := clihelp.DSNFlag("roles create")
label := fs.String("label", "", "optional human label")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 1 {
clihelp.Fail(errors.New("create takes exactly one slug argument"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
r, err := a.CreateRole(ctx, rest[0], *label)
if err != nil {
clihelp.Fail(err)
}
fmt.Printf("created role %s (id=%s, label=%q)\n", r.Slug, r.ID, r.Label)
}
func runList(args []string) {
fs, dsn := clihelp.DSNFlag("roles list")
_ = fs.Parse(args)
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
roles, err := a.ListRoles(ctx)
if err != nil {
clihelp.Fail(err)
}
for _, r := range roles {
fmt.Printf("%s\t%s\n", r.Slug, r.Label)
}
}
func runDelete(args []string) {
fs, dsn := clihelp.DSNFlag("roles delete")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 1 {
clihelp.Fail(errors.New("delete takes exactly one slug argument"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
if err := a.DeleteRole(ctx, rest[0]); err != nil {
clihelp.Fail(err)
}
fmt.Printf("deleted role %s\n", rest[0])
}
func runGrant(args []string) {
fs, dsn := clihelp.DSNFlag("roles grant")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 2 {
clihelp.Fail(errors.New("grant takes <role-slug> <perm-slug>"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
if err := a.GrantPermissionToRole(ctx, rest[0], rest[1]); err != nil {
clihelp.Fail(err)
}
fmt.Printf("granted %s to role %s\n", rest[1], rest[0])
}
func runRevoke(args []string) {
fs, dsn := clihelp.DSNFlag("roles revoke")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 2 {
clihelp.Fail(errors.New("revoke takes <role-slug> <perm-slug>"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
if err := a.RevokePermissionFromRole(ctx, rest[0], rest[1]); err != nil {
clihelp.Fail(err)
}
fmt.Printf("revoked %s from role %s\n", rest[1], rest[0])
}
func runPermissions(args []string) {
fs, dsn := clihelp.DSNFlag("roles permissions")
_ = fs.Parse(args)
rest := fs.Args()
if len(rest) != 1 {
clihelp.Fail(errors.New("permissions takes exactly one role-slug argument"))
}
ctx := context.Background()
a, db, err := clihelp.Dial(ctx, *dsn)
if err != nil {
clihelp.Fail(err)
}
defer db.Close()
perms, err := a.ListRolePermissions(ctx, rest[0])
if err != nil {
clihelp.Fail(err)
}
for _, p := range perms {
fmt.Printf("%s\t%s\n", p.Slug, p.Label)
}
}

34
doc.go
View file

@ -1,14 +1,24 @@
// Package authkit is an authentication and authorization toolkit for Go web
// services. It defines storage interfaces (UserStore, SessionStore, TokenStore,
// ServiceKeyStore, RoleStore, PermissionStore) and a high-level Auth service
// that composes them to support registration, password login, opaque
// server-side sessions, JWT access plus rotating refresh tokens, email
// verification, password resets, magic-link passwordless login, role-based
// access control, and owner-agnostic service tokens with custom abilities for
// server-to-server auth.
// Package authkit is a pragmatic authentication and authorization toolkit
// for Go web services on PostgreSQL 16+.
//
// Default Postgres implementations of every store live in the pgstore
// subpackage. Argon2id password hashing lives in hasher. Framework-neutral
// HTTP middleware (compatible with lightmux and any net/http stack) lives in
// middleware.
// Drop authkit into a net/http stack and get registration, password login,
// opaque server-side sessions, JWT access tokens with rotating refresh,
// email verification, password reset, magic-link login, email OTP, and
// owner-agnostic service tokens with consumer-defined abilities.
// Authorization is flat RBAC with both role-derived and direct user
// permissions.
//
// Roles, permissions, and abilities are seeded by the consumer (typically
// via the cmd/perms, cmd/roles, and cmd/abilities CLIs that ship with this
// repo). The library does not seed any rows automatically — applications
// own their authorization vocabulary.
//
// Migrations and schema verification run at startup. Set
// Config.SkipAutoMigrate to disable.
//
// The library does not send email or otherwise reach out to users.
// Token-minting flows (RequestEmailVerification, RequestPasswordReset,
// RequestMagicLink, RequestEmailOTP, IssueServiceKey, IssueSession,
// IssueJWT) return the plaintext to the caller exactly once — show it to
// the user immediately; only its SHA-256 hash is persisted.
package authkit

11
email.go Normal file
View file

@ -0,0 +1,11 @@
package authkit
import "strings"
// normalizeEmail produces the lookup form used by GetUserByEmail and the
// email_normalized column. Trim + lowercase is intentional; we do not
// collapse Gmail-style "+" addressing or strip dots — that's a policy
// decision callers can layer on top.
func normalizeEmail(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}

21
email_test.go Normal file
View file

@ -0,0 +1,21 @@
package authkit
import "testing"
func TestNormalizeEmail(t *testing.T) {
cases := []struct {
in, want string
}{
{"alice@example.com", "alice@example.com"},
{"Alice@Example.com", "alice@example.com"},
{"ALICE@EXAMPLE.COM", "alice@example.com"},
{" alice@example.com ", "alice@example.com"},
{"\talice@EXAMPLE.com\n", "alice@example.com"},
}
for _, c := range cases {
got := normalizeEmail(c.in)
if got != c.want {
t.Fatalf("normalizeEmail(%q) = %q, want %q", c.in, got, c.want)
}
}
}

View file

@ -2,11 +2,12 @@ package authkit
import "errors"
// Sentinel errors. Internal call sites wrap these via errx so callers using
// errors.Is(err, authkit.ErrFoo) get reliable matching across wrap chains.
var (
ErrEmailTaken = errors.New("authkit: email already registered")
ErrUserNotFound = errors.New("authkit: user not found")
ErrInvalidCredentials = errors.New("authkit: invalid credentials")
ErrEmailNotVerified = errors.New("authkit: email not verified")
ErrTokenInvalid = errors.New("authkit: invalid or expired token")
ErrTokenReused = errors.New("authkit: token reuse detected")
ErrSessionInvalid = errors.New("authkit: invalid or expired session")
@ -14,5 +15,10 @@ var (
ErrPermissionDenied = errors.New("authkit: permission denied")
ErrRoleNotFound = errors.New("authkit: role not found")
ErrPermissionNotFound = errors.New("authkit: permission not found")
ErrConfigInvalid = errors.New("authkit: invalid configuration")
ErrAbilityNotFound = errors.New("authkit: ability not found")
ErrSlugInvalid = errors.New("authkit: invalid slug")
ErrSlugTaken = errors.New("authkit: slug already in use")
ErrOTPInvalid = errors.New("authkit: invalid or expired OTP")
ErrNoUserContext = errors.New("authkit: no user on request context")
ErrSchemaDrift = errors.New("authkit: database schema does not match expected layout")
)

View file

@ -5,8 +5,8 @@ import (
"strings"
)
// Extractor pulls a credential string out of an HTTP request. It returns
// (value, true) when a value was found, otherwise ("", false).
// Extractor pulls a credential string out of an HTTP request. Returns
// (value, true) when found, ("", false) otherwise.
type Extractor func(r *http.Request) (string, bool)
// BearerExtractor reads the value following "Bearer " in the Authorization

86
extractor_test.go Normal file
View file

@ -0,0 +1,86 @@
package authkit
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestBearerExtractor(t *testing.T) {
ex := BearerExtractor()
cases := []struct {
name string
header string
want string
ok bool
}{
{"plain bearer", "Bearer abc", "abc", true},
{"lowercase bearer", "bearer abc", "abc", true},
{"mixed case", "BeArEr abc", "abc", true},
{"no header", "", "", false},
{"non-bearer scheme", "Basic abc", "", false},
{"bearer with no token", "Bearer ", "", false},
{"bearer with whitespace", "Bearer abc ", "abc", true},
{"too short", "Bearer", "", false},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
if c.header != "" {
r.Header.Set("Authorization", c.header)
}
got, ok := ex(r)
if ok != c.ok || got != c.want {
t.Fatalf("got (%q, %v), want (%q, %v)", got, ok, c.want, c.ok)
}
})
}
}
func TestCookieExtractor(t *testing.T) {
ex := CookieExtractor("session")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.AddCookie(&http.Cookie{Name: "session", Value: "abc"})
got, ok := ex(r)
if !ok || got != "abc" {
t.Fatalf("got (%q, %v), want (\"abc\", true)", got, ok)
}
r2 := httptest.NewRequest(http.MethodGet, "/", nil)
if _, ok := ex(r2); ok {
t.Fatalf("missing cookie should not extract")
}
r3 := httptest.NewRequest(http.MethodGet, "/", nil)
r3.AddCookie(&http.Cookie{Name: "session", Value: ""})
if _, ok := ex(r3); ok {
t.Fatalf("empty cookie value should not extract")
}
}
func TestHeaderExtractor(t *testing.T) {
ex := HeaderExtractor("X-API-Token")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("X-API-Token", " abc ")
got, ok := ex(r)
if !ok || got != "abc" {
t.Fatalf("got (%q, %v), want (\"abc\", true)", got, ok)
}
}
func TestChainExtractors(t *testing.T) {
a := func(r *http.Request) (string, bool) { return "", false }
b := func(r *http.Request) (string, bool) { return "from-b", true }
c := func(r *http.Request) (string, bool) { return "from-c", true }
chain := ChainExtractors(a, b, c)
r := httptest.NewRequest(http.MethodGet, "/", nil)
got, ok := chain(r)
if !ok || got != "from-b" {
t.Fatalf("chain should return first hit; got (%q, %v)", got, ok)
}
none := ChainExtractors(a, a)
if _, ok := none(r); ok {
t.Fatalf("chain of misses should not extract")
}
}

View file

@ -1,7 +1,11 @@
// Package hasher provides authkit.Hasher implementations. The default
// implementation, Argon2id, encodes hashes in the standard PHC string format
// (https://github.com/P-H-C/phc-string-format) so callers can introspect
// parameters and migrate.
// Package hasher provides password-hashing primitives that satisfy the
// authkit.Hasher interface. The default implementation, Argon2id, encodes
// hashes in the standard PHC string format (https://github.com/P-H-C/phc-string-format)
// so callers can introspect parameters and migrate.
//
// This package intentionally does not import authkit — the Hasher interface
// is structurally satisfied. That keeps the dependency arrow one-way and
// lets test code in the authkit package itself import this package.
package hasher
import (
@ -13,7 +17,6 @@ import (
"io"
"strings"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"golang.org/x/crypto/argon2"
)
@ -40,24 +43,27 @@ func DefaultArgon2idParams() Argon2idParams {
}
}
type argon2idHasher struct {
// Argon2idHasher implements password hashing via Argon2id. It satisfies
// authkit.Hasher through structural typing — pass *Argon2idHasher into
// authkit.Deps.Hasher.
type Argon2idHasher struct {
params Argon2idParams
rng io.Reader
}
// NewArgon2id builds an authkit.Hasher backed by Argon2id. If params is the
// zero value, DefaultArgon2idParams() is used. rng defaults to crypto/rand.
func NewArgon2id(params Argon2idParams, rng io.Reader) authkit.Hasher {
// NewArgon2id builds an *Argon2idHasher. If params is the zero value,
// DefaultArgon2idParams() is used. rng defaults to crypto/rand.
func NewArgon2id(params Argon2idParams, rng io.Reader) *Argon2idHasher {
if params == (Argon2idParams{}) {
params = DefaultArgon2idParams()
}
if rng == nil {
rng = rand.Reader
}
return &argon2idHasher{params: params, rng: rng}
return &Argon2idHasher{params: params, rng: rng}
}
func (h *argon2idHasher) Hash(password string) (string, error) {
func (h *Argon2idHasher) Hash(password string) (string, error) {
const op = "authkit.hasher.Argon2id.Hash"
if password == "" {
return "", errx.New(op, "password is empty")
@ -71,7 +77,7 @@ func (h *argon2idHasher) Hash(password string) (string, error) {
return encodePHC(h.params, salt, key), nil
}
func (h *argon2idHasher) Verify(password, encoded string) (bool, bool, error) {
func (h *Argon2idHasher) Verify(password, encoded string) (bool, bool, error) {
const op = "authkit.hasher.Argon2id.Verify"
got, salt, key, err := decodePHC(encoded)
if err != nil {

7
jwt.go
View file

@ -6,9 +6,9 @@ import (
"github.com/google/uuid"
)
// accessClaims is the JWT shape issued by IssueJWT. The session_version
// field carries the User.SessionVersion at issue time so AuthenticateJWT
// can detect global revocations (logout-everywhere, password change).
// accessClaims is the JWT shape issued by IssueJWT. session_version carries
// User.SessionVersion at issue time so AuthenticateJWT can detect global
// revocations (logout-everywhere, password change).
type accessClaims struct {
jwt.RegisteredClaims
SessionVersion int `json:"sv"`
@ -39,6 +39,7 @@ func (a *Auth) signAccessToken(userID uuid.UUID, sessionVersion int) (string, er
}
// parseAccessToken validates the signature and returns the parsed claims.
// Strictly enforces HS256 — alg=none and asymmetric algorithms are rejected.
func (a *Auth) parseAccessToken(token string) (*accessClaims, error) {
const op = "authkit.parseAccessToken"
opts := []jwt.ParserOption{

View file

@ -1,99 +0,0 @@
package authkit
import (
"context"
"errors"
"testing"
"time"
)
func TestJWTIssueAndAuthenticate(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "alice@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
access, refresh, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
if access == "" || refresh == "" {
t.Fatalf("empty tokens")
}
p, err := a.AuthenticateJWT(context.Background(), access)
if err != nil {
t.Fatalf("AuthenticateJWT: %v", err)
}
if p.UserID != u.ID {
t.Fatalf("principal user id mismatch")
}
if p.Method != AuthMethodJWT {
t.Fatalf("principal method = %s, want jwt", p.Method)
}
}
func TestJWTSessionVersionMismatchRejected(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "bob@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
access, _, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
if err := a.RevokeAllUserSessions(context.Background(), u.ID); err != nil {
t.Fatalf("RevokeAllUserSessions: %v", err)
}
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid after session bump, got %v", err)
}
}
func TestJWTRefreshRotationAndReuseDetection(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "carol@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
_, refresh1, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
_, refresh2, err := a.RefreshJWT(context.Background(), refresh1)
if err != nil {
t.Fatalf("first RefreshJWT: %v", err)
}
if refresh1 == refresh2 {
t.Fatalf("refresh token did not rotate")
}
// Replaying refresh1 must surface ErrTokenReused and revoke the chain.
if _, _, err := a.RefreshJWT(context.Background(), refresh1); !errors.Is(err, ErrTokenReused) {
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
}
// After chain revocation, even refresh2 (the legitimate next one) must
// be rejected.
if _, _, err := a.RefreshJWT(context.Background(), refresh2); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on post-revoke refresh, got %v", err)
}
}
func TestJWTExpiredTokenRejected(t *testing.T) {
a := newTestAuth(t)
now := time.Now().UTC()
a.cfg.Clock = func() time.Time { return now }
u, err := a.Register(context.Background(), "dan@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
access, _, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
// Advance clock past TTL.
a.cfg.Clock = func() time.Time { return now.Add(10 * time.Minute) }
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid for expired token, got %v", err)
}
}

View file

@ -1,612 +0,0 @@
package authkit
// In-memory store fakes used by service-level tests. Kept in a _test.go file
// so they don't ship in the public API. Each fake is intentionally minimal —
// it satisfies the interface and supports the flows the tests exercise; not
// every method is wired up (a few panic to flag accidental use).
import (
"bytes"
"context"
"errors"
"slices"
"sync"
"time"
"github.com/google/uuid"
)
type memUserStore struct {
mu sync.Mutex
byID map[uuid.UUID]*User
byEml map[string]uuid.UUID
}
func newMemUserStore() *memUserStore {
return &memUserStore{byID: map[uuid.UUID]*User{}, byEml: map[string]uuid.UUID{}}
}
func (s *memUserStore) CreateUser(_ context.Context, u *User) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.byEml[u.EmailNormalized]; exists {
return ErrEmailTaken
}
if u.ID == uuid.Nil {
u.ID = uuid.New()
}
cp := *u
s.byID[u.ID] = &cp
s.byEml[u.EmailNormalized] = u.ID
return nil
}
func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*User, error) {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return nil, ErrUserNotFound
}
cp := *u
return &cp, nil
}
func (s *memUserStore) GetUserByEmail(_ context.Context, ne string) (*User, error) {
s.mu.Lock()
defer s.mu.Unlock()
id, ok := s.byEml[ne]
if !ok {
return nil, ErrUserNotFound
}
u := *s.byID[id]
return &u, nil
}
func (s *memUserStore) UpdateUser(_ context.Context, u *User) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.byID[u.ID]; !ok {
return ErrUserNotFound
}
cp := *u
s.byID[u.ID] = &cp
return nil
}
func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
delete(s.byID, id)
delete(s.byEml, u.EmailNormalized)
return nil
}
func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, h string) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
u.PasswordHash = h
return nil
}
func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
u.EmailVerifiedAt = &at
return nil
}
func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return 0, ErrUserNotFound
}
u.SessionVersion++
return u.SessionVersion, nil
}
func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return 0, ErrUserNotFound
}
u.FailedLogins++
return u.FailedLogins, nil
}
func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
u.FailedLogins = 0
return nil
}
type memSessionStore struct {
mu sync.Mutex
m map[string]*Session
}
func newMemSessionStore() *memSessionStore { return &memSessionStore{m: map[string]*Session{}} }
func (s *memSessionStore) CreateSession(_ context.Context, ses *Session) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *ses
s.m[string(ses.IDHash)] = &cp
return nil
}
func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.m[string(h)]
if !ok {
return nil, ErrSessionInvalid
}
cp := *v
return &cp, nil
}
func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, exp time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.m[string(h)]
if !ok {
return ErrSessionInvalid
}
v.LastSeenAt = lastSeen
v.ExpiresAt = exp
return nil
}
func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.m, string(h))
return nil
}
func (s *memSessionStore) DeleteUserSessions(_ context.Context, uid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, v := range s.m {
if v.UserID == uid {
delete(s.m, k)
}
}
return nil
}
func (s *memSessionStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
var n int64
for k, v := range s.m {
if !v.ExpiresAt.After(now) {
delete(s.m, k)
n++
}
}
return n, nil
}
type memTokenStore struct {
mu sync.Mutex
// Keyed by kind+hex(hash) so identical hashes across kinds don't collide.
m map[string]*Token
}
func newMemTokenStore() *memTokenStore { return &memTokenStore{m: map[string]*Token{}} }
func tokKey(kind TokenKind, h []byte) string {
return string(kind) + ":" + string(h)
}
func (s *memTokenStore) CreateToken(_ context.Context, t *Token) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *t
s.m[tokKey(t.Kind, t.Hash)] = &cp
return nil
}
func (s *memTokenStore) ConsumeToken(_ context.Context, kind TokenKind, h []byte, now time.Time) (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.m[tokKey(kind, h)]
if !ok || t.ConsumedAt != nil || !t.ExpiresAt.After(now) {
return nil, ErrTokenInvalid
}
t.ConsumedAt = &now
cp := *t
return &cp, nil
}
func (s *memTokenStore) GetToken(_ context.Context, kind TokenKind, h []byte) (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.m[tokKey(kind, h)]
if !ok {
return nil, ErrTokenInvalid
}
cp := *t
return &cp, nil
}
func (s *memTokenStore) DeleteByChain(_ context.Context, chainID string) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
var n int64
for k, t := range s.m {
if t.ChainID != nil && *t.ChainID == chainID {
delete(s.m, k)
n++
}
}
return n, nil
}
func (s *memTokenStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
var n int64
for k, t := range s.m {
if !t.ExpiresAt.After(now) {
delete(s.m, k)
n++
}
}
return n, nil
}
type memServiceKeyStore struct {
mu sync.Mutex
m map[string]*ServiceKey
}
func newMemServiceKeyStore() *memServiceKeyStore {
return &memServiceKeyStore{m: map[string]*ServiceKey{}}
}
func (s *memServiceKeyStore) CreateServiceKey(_ context.Context, k *ServiceKey) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
s.m[string(k.IDHash)] = &cp
return nil
}
func (s *memServiceKeyStore) GetServiceKey(_ context.Context, h []byte) (*ServiceKey, error) {
s.mu.Lock()
defer s.mu.Unlock()
k, ok := s.m[string(h)]
if !ok {
return nil, ErrServiceKeyInvalid
}
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
return &cp, nil
}
func (s *memServiceKeyStore) ListServiceKeysByOwner(_ context.Context, ownerKind string, owner uuid.UUID) ([]*ServiceKey, error) {
s.mu.Lock()
defer s.mu.Unlock()
var out []*ServiceKey
for _, k := range s.m {
if k.OwnerKind == ownerKind && k.OwnerID == owner {
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
out = append(out, &cp)
}
}
return out, nil
}
func (s *memServiceKeyStore) TouchServiceKey(_ context.Context, h []byte, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
if k, ok := s.m[string(h)]; ok {
k.LastUsedAt = &at
}
return nil
}
func (s *memServiceKeyStore) RevokeServiceKey(_ context.Context, h []byte, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
k, ok := s.m[string(h)]
if !ok {
return ErrServiceKeyInvalid
}
if k.RevokedAt != nil {
return ErrServiceKeyInvalid
}
k.RevokedAt = &at
return nil
}
type memRoleStore struct {
mu sync.Mutex
roles map[uuid.UUID]*Role
rolesByNm map[string]uuid.UUID
userRoles map[uuid.UUID]map[uuid.UUID]struct{}
}
func newMemRoleStore() *memRoleStore {
return &memRoleStore{
roles: map[uuid.UUID]*Role{},
rolesByNm: map[string]uuid.UUID{},
userRoles: map[uuid.UUID]map[uuid.UUID]struct{}{},
}
}
func (s *memRoleStore) CreateRole(_ context.Context, r *Role) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, dup := s.rolesByNm[r.Name]; dup {
return errors.New("role exists")
}
if r.ID == uuid.Nil {
r.ID = uuid.New()
}
cp := *r
s.roles[r.ID] = &cp
s.rolesByNm[r.Name] = r.ID
return nil
}
func (s *memRoleStore) GetRoleByID(_ context.Context, id uuid.UUID) (*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.roles[id]
if !ok {
return nil, ErrRoleNotFound
}
cp := *r
return &cp, nil
}
func (s *memRoleStore) GetRoleByName(_ context.Context, n string) (*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
id, ok := s.rolesByNm[n]
if !ok {
return nil, ErrRoleNotFound
}
r := *s.roles[id]
return &r, nil
}
func (s *memRoleStore) ListRoles(_ context.Context) ([]*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*Role, 0, len(s.roles))
for _, r := range s.roles {
cp := *r
out = append(out, &cp)
}
return out, nil
}
func (s *memRoleStore) DeleteRole(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.roles[id]
if !ok {
return ErrRoleNotFound
}
delete(s.roles, id)
delete(s.rolesByNm, r.Name)
for _, m := range s.userRoles {
delete(m, id)
}
return nil
}
func (s *memRoleStore) AssignRoleToUser(_ context.Context, uid, rid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
m, ok := s.userRoles[uid]
if !ok {
m = map[uuid.UUID]struct{}{}
s.userRoles[uid] = m
}
m[rid] = struct{}{}
return nil
}
func (s *memRoleStore) RemoveRoleFromUser(_ context.Context, uid, rid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
if m, ok := s.userRoles[uid]; ok {
delete(m, rid)
}
return nil
}
func (s *memRoleStore) GetUserRoles(_ context.Context, uid uuid.UUID) ([]*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := []*Role{}
for rid := range s.userRoles[uid] {
if r, ok := s.roles[rid]; ok {
cp := *r
out = append(out, &cp)
}
}
return out, nil
}
func (s *memRoleStore) HasAnyRole(_ context.Context, uid uuid.UUID, names []string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
for rid := range s.userRoles[uid] {
r := s.roles[rid]
if slices.Contains(names, r.Name) {
return true, nil
}
}
return false, nil
}
type memPermStore struct {
mu sync.Mutex
perms map[uuid.UUID]*Permission
permsByNm map[string]uuid.UUID
rolePerms map[uuid.UUID]map[uuid.UUID]struct{}
roles *memRoleStore
}
func newMemPermStore(rs *memRoleStore) *memPermStore {
return &memPermStore{
perms: map[uuid.UUID]*Permission{},
permsByNm: map[string]uuid.UUID{},
rolePerms: map[uuid.UUID]map[uuid.UUID]struct{}{},
roles: rs,
}
}
func (s *memPermStore) CreatePermission(_ context.Context, p *Permission) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, dup := s.permsByNm[p.Name]; dup {
return errors.New("perm exists")
}
if p.ID == uuid.Nil {
p.ID = uuid.New()
}
cp := *p
s.perms[p.ID] = &cp
s.permsByNm[p.Name] = p.ID
return nil
}
func (s *memPermStore) GetPermissionByID(_ context.Context, id uuid.UUID) (*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok := s.perms[id]
if !ok {
return nil, ErrPermissionNotFound
}
cp := *p
return &cp, nil
}
func (s *memPermStore) GetPermissionByName(_ context.Context, n string) (*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
id, ok := s.permsByNm[n]
if !ok {
return nil, ErrPermissionNotFound
}
p := *s.perms[id]
return &p, nil
}
func (s *memPermStore) ListPermissions(_ context.Context) ([]*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*Permission, 0, len(s.perms))
for _, p := range s.perms {
cp := *p
out = append(out, &cp)
}
return out, nil
}
func (s *memPermStore) DeletePermission(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
p, ok := s.perms[id]
if !ok {
return ErrPermissionNotFound
}
delete(s.perms, id)
delete(s.permsByNm, p.Name)
for _, m := range s.rolePerms {
delete(m, id)
}
return nil
}
func (s *memPermStore) AssignPermissionToRole(_ context.Context, rid, pid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
m, ok := s.rolePerms[rid]
if !ok {
m = map[uuid.UUID]struct{}{}
s.rolePerms[rid] = m
}
m[pid] = struct{}{}
return nil
}
func (s *memPermStore) RemovePermissionFromRole(_ context.Context, rid, pid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
if m, ok := s.rolePerms[rid]; ok {
delete(m, pid)
}
return nil
}
func (s *memPermStore) GetRolePermissions(_ context.Context, rid uuid.UUID) ([]*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := []*Permission{}
for pid := range s.rolePerms[rid] {
if p, ok := s.perms[pid]; ok {
cp := *p
out = append(out, &cp)
}
}
return out, nil
}
func (s *memPermStore) GetUserPermissions(_ context.Context, uid uuid.UUID) ([]*Permission, error) {
s.roles.mu.Lock()
roleIDs := make([]uuid.UUID, 0)
for rid := range s.roles.userRoles[uid] {
roleIDs = append(roleIDs, rid)
}
s.roles.mu.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
seen := map[uuid.UUID]struct{}{}
out := []*Permission{}
for _, rid := range roleIDs {
for pid := range s.rolePerms[rid] {
if _, dup := seen[pid]; dup {
continue
}
seen[pid] = struct{}{}
if p, ok := s.perms[pid]; ok {
cp := *p
out = append(out, &cp)
}
}
}
return out, nil
}
// Stub Hasher: stores plaintext for trivial verify. Tests of hashing live in
// the hasher package itself; we just need *something* callable here.
type stubHasher struct{}
func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil }
func (stubHasher) Verify(p, encoded string) (bool, bool, error) {
want := "stub:" + p
if !bytes.Equal([]byte(want), []byte(encoded)) {
return false, false, nil
}
return true, false, nil
}
// newTestAuth wires the fakes into Auth with deterministic config.
func newTestAuth(t interface{ Helper() }) *Auth {
if h, ok := t.(interface{ Helper() }); ok {
h.Helper()
}
roles := newMemRoleStore()
return New(Deps{
Users: newMemUserStore(),
Sessions: newMemSessionStore(),
Tokens: newMemTokenStore(),
ServiceKeys: newMemServiceKeyStore(),
Roles: roles,
Permissions: newMemPermStore(roles),
Hasher: stubHasher{},
}, Config{
JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"),
JWTIssuer: "authkit-test",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: 1 * time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,
PasswordResetTTL: time.Hour,
MagicLinkTTL: time.Minute,
})
}

View file

@ -1,84 +0,0 @@
package middleware
import (
"net/http"
"git.juancwu.dev/juancwu/authkit"
)
// authzGuard wraps the common pattern of "look up the Principal, run a
// predicate, succeed or 403". onForbidden defaults to JSON 403.
func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler {
if onForbidden == nil {
onForbidden = defaultJSONError(http.StatusForbidden)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p, ok := PrincipalFrom(r.Context())
if !ok {
// No auth middleware ran upstream; treat as forbidden
// rather than crashing — composition is the caller's
// responsibility but a 403 is the safer default.
onForbidden(w, r, authkit.ErrPermissionDenied)
return
}
if !pred(p) {
onForbidden(w, r, authkit.ErrPermissionDenied)
return
}
next.ServeHTTP(w, r)
})
}
}
// RequireRole permits requests whose Principal holds the named role.
func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
return p.HasRole(name)
})
}
// RequireAnyRole permits requests whose Principal holds at least one of the
// named roles.
func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
return p.HasAnyRole(names...)
})
}
// RequirePermission permits requests whose Principal holds the named
// permission (resolved via roles at auth time).
func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
return p.HasPermission(name)
})
}
// RequireAbility permits requests whose ServiceKey carries the named ability.
// Abilities live only on service tokens — this middleware reads
// *authkit.ServiceKey from the request context (placed by RequireServiceKey
// or RequireAnyOrServiceKey) and 403s any request authenticated as a user
// (session or JWT), which by definition has no abilities.
func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
onForb := firstOrNil(onForbidden)
if onForb == nil {
onForb = defaultJSONError(http.StatusForbidden)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
k, ok := ServiceKeyFrom(r.Context())
if !ok || !k.HasAbility(name) {
onForb(w, r, authkit.ErrPermissionDenied)
return
}
next.ServeHTTP(w, r)
})
}
}
func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) {
if len(s) == 0 {
return nil
}
return s[0]
}

View file

@ -1,64 +0,0 @@
// Package middleware provides framework-neutral HTTP middleware for authkit.
// Every middleware function returns the standard func(http.Handler)
// http.Handler type, so it composes with lightmux's Use/Group/Handle as well
// as any net/http stack that uses the same signature.
package middleware
import (
"context"
"net/http"
"git.juancwu.dev/juancwu/authkit"
)
// principalKey and serviceKeyKey are unexported context keys. Using distinct
// empty struct types guarantees no collision with caller-defined keys.
type principalKey struct{}
type serviceKeyKey struct{}
// withPrincipal stashes p on the request context for downstream handlers.
func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context {
return context.WithValue(ctx, principalKey{}, p)
}
// PrincipalFrom retrieves the authenticated Principal placed by RequireSession
// or RequireJWT. The boolean is false if no user-bound auth middleware ran for
// this request (e.g. the request was authenticated via service key instead).
func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) {
p, ok := ctx.Value(principalKey{}).(*authkit.Principal)
return p, ok
}
// MustPrincipal panics if no Principal is on the context. Use only on
// handlers known to be behind a Require* middleware that authenticates a
// user (RequireSession or RequireJWT).
func MustPrincipal(r *http.Request) *authkit.Principal {
p, ok := PrincipalFrom(r.Context())
if !ok {
panic("authkit/middleware: no principal on request context")
}
return p
}
// withServiceKey stashes k on the request context for downstream handlers.
func withServiceKey(ctx context.Context, k *authkit.ServiceKey) context.Context {
return context.WithValue(ctx, serviceKeyKey{}, k)
}
// ServiceKeyFrom retrieves the authenticated ServiceKey placed by
// RequireServiceKey. The boolean is false if no service-key middleware ran
// for this request.
func ServiceKeyFrom(ctx context.Context) (*authkit.ServiceKey, bool) {
k, ok := ctx.Value(serviceKeyKey{}).(*authkit.ServiceKey)
return k, ok
}
// MustServiceKey panics if no ServiceKey is on the context. Use only on
// handlers known to be behind RequireServiceKey.
func MustServiceKey(r *http.Request) *authkit.ServiceKey {
k, ok := ServiceKeyFrom(r.Context())
if !ok {
panic("authkit/middleware: no service key on request context")
}
return k
}

View file

@ -1,45 +1,236 @@
// Package middleware provides framework-neutral HTTP middleware for authkit.
// Every middleware function returns the standard func(http.Handler)
// http.Handler shape so it composes with lightmux.Mux.Use/Group/Handle, with
// chi/gorilla, or with any net/http stack accepting that signature.
//
// Three primitives:
// - RequireLogin — accept session OR JWT, optionally constrain by Authz
// - RequireGuest — reject authenticated requests
// - RequireServiceKey — accept a service token, optionally constrain by Authz
//
// All three attach the relevant subject to the request context via
// authkit.WithUserContext or authkit.WithServiceKey, so handlers can read
// it via authkit.UserIDFromCtx / authkit.UserFromCtx /
// authkit.ServiceKeyFromCtx.
package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"git.juancwu.dev/juancwu/authkit"
)
// Options configures auth middleware. Auth is required; the rest fall back
// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403
// on authz failure.
type Options struct {
Auth *authkit.Auth
Extractor authkit.Extractor
// LoginOptions configures RequireLogin.
type LoginOptions struct {
// Auth is required.
Auth *authkit.Auth
// SessionExtractor reads the session plaintext from the request.
// Defaults to a cookie extractor using Auth.SessionCookieName().
SessionExtractor authkit.Extractor
// JWTExtractor reads the JWT access token from the request. Defaults
// to BearerExtractor.
JWTExtractor authkit.Extractor
// Authz, if non-nil, gates the request on a predicate over the
// resolved *Principal. Validate is called once at construction; an
// invalid predicate (unknown slug) panics.
Authz authkit.LoginAuthz
// OnUnauth handles "no credential / bad credential" failures (HTTP
// 401). Default: JSON {"error":"Unauthorized"}.
OnUnauth func(w http.ResponseWriter, r *http.Request, err error)
// OnForbidden handles "credential ok but Authz failed" (HTTP 403).
// Default: JSON {"error":"Forbidden"}.
OnForbidden func(w http.ResponseWriter, r *http.Request, err error)
}
// RequireLogin returns middleware that authenticates the request via either
// a session cookie or a JWT (in that order) and, if Authz is set, gates the
// resolved Principal against the predicate.
//
// Panics at construction time if Auth is nil or Authz references unknown
// slugs.
func RequireLogin(opts LoginOptions) func(http.Handler) http.Handler {
if opts.Auth == nil {
panic("authkit/middleware: LoginOptions.Auth is required")
}
if opts.Authz != nil {
if err := opts.Authz.Validate(context.Background(), opts.Auth); err != nil {
panic(fmt.Sprintf("authkit/middleware: %v", err))
}
}
sessionEx := opts.SessionExtractor
if sessionEx == nil {
sessionEx = authkit.CookieExtractor(opts.Auth.SessionCookieName())
}
jwtEx := opts.JWTExtractor
if jwtEx == nil {
jwtEx = authkit.BearerExtractor()
}
onUnauth := opts.OnUnauth
if onUnauth == nil {
onUnauth = defaultJSONError(http.StatusUnauthorized)
}
onForbidden := opts.OnForbidden
if onForbidden == nil {
onForbidden = defaultJSONError(http.StatusForbidden)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p, err := authenticatePrincipal(r, opts.Auth, sessionEx, jwtEx)
if err != nil {
onUnauth(w, r, err)
return
}
if opts.Authz != nil && !opts.Authz.Match(p) {
onForbidden(w, r, authkit.ErrPermissionDenied)
return
}
ctx := authkit.WithUserContext(r.Context(), opts.Auth, p.UserID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// authenticatePrincipal tries the session extractor first, then the JWT
// extractor. Returns the first successful Principal or the last error.
func authenticatePrincipal(r *http.Request, a *authkit.Auth, sessionEx, jwtEx authkit.Extractor) (*authkit.Principal, error) {
if v, ok := sessionEx(r); ok && v != "" {
if p, err := a.AuthenticateSession(r.Context(), v); err == nil {
return p, nil
}
}
if v, ok := jwtEx(r); ok && v != "" {
if p, err := a.AuthenticateJWT(r.Context(), v); err == nil {
return p, nil
}
}
return nil, authkit.ErrSessionInvalid
}
// GuestOptions configures RequireGuest.
type GuestOptions struct {
// Auth is required.
Auth *authkit.Auth
SessionExtractor authkit.Extractor
JWTExtractor authkit.Extractor
// OnAuthenticated handles requests that present a valid credential
// (where a guest was expected). Default: JSON 403.
OnAuthenticated func(w http.ResponseWriter, r *http.Request)
}
// RequireGuest returns middleware that rejects requests carrying a valid
// session or JWT. Useful for /login or /register pages where authenticated
// users should be redirected away.
//
// Default rejection is HTTP 403 JSON. Pass Options.OnAuthenticated to
// implement a redirect or custom response.
func RequireGuest(opts GuestOptions) func(http.Handler) http.Handler {
if opts.Auth == nil {
panic("authkit/middleware: GuestOptions.Auth is required")
}
sessionEx := opts.SessionExtractor
if sessionEx == nil {
sessionEx = authkit.CookieExtractor(opts.Auth.SessionCookieName())
}
jwtEx := opts.JWTExtractor
if jwtEx == nil {
jwtEx = authkit.BearerExtractor()
}
onAuthenticated := opts.OnAuthenticated
if onAuthenticated == nil {
onAuthenticated = func(w http.ResponseWriter, r *http.Request) {
defaultJSONError(http.StatusForbidden)(w, r, authkit.ErrPermissionDenied)
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := authenticatePrincipal(r, opts.Auth, sessionEx, jwtEx); err == nil {
onAuthenticated(w, r)
return
}
next.ServeHTTP(w, r)
})
}
}
// ServiceKeyOptions configures RequireServiceKey.
type ServiceKeyOptions struct {
Auth *authkit.Auth
// Extractor reads the service token plaintext. Defaults to
// BearerExtractor.
Extractor authkit.Extractor
// Authz, if non-nil, gates the request on a predicate over the
// resolved *ServiceKey.
Authz authkit.ServiceKeyAuthz
OnUnauth func(w http.ResponseWriter, r *http.Request, err error)
OnForbidden func(w http.ResponseWriter, r *http.Request, err error)
}
func (o Options) extractor() authkit.Extractor {
if o.Extractor != nil {
return o.Extractor
// RequireServiceKey returns middleware that authenticates the request via a
// service token and, if Authz is set, gates on the predicate.
//
// Panics at construction time if Auth is nil or Authz references unknown
// ability slugs.
func RequireServiceKey(opts ServiceKeyOptions) func(http.Handler) http.Handler {
if opts.Auth == nil {
panic("authkit/middleware: ServiceKeyOptions.Auth is required")
}
if opts.Authz != nil {
if err := opts.Authz.Validate(context.Background(), opts.Auth); err != nil {
panic(fmt.Sprintf("authkit/middleware: %v", err))
}
}
ex := opts.Extractor
if ex == nil {
ex = authkit.BearerExtractor()
}
onUnauth := opts.OnUnauth
if onUnauth == nil {
onUnauth = defaultJSONError(http.StatusUnauthorized)
}
onForbidden := opts.OnForbidden
if onForbidden == nil {
onForbidden = defaultJSONError(http.StatusForbidden)
}
return authkit.BearerExtractor()
}
func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) {
if o.OnUnauth != nil {
return o.OnUnauth
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, ok := ex(r)
if !ok || raw == "" {
onUnauth(w, r, authkit.ErrServiceKeyInvalid)
return
}
k, err := opts.Auth.AuthenticateServiceKey(r.Context(), raw)
if err != nil {
onUnauth(w, r, err)
return
}
if opts.Authz != nil && !opts.Authz.Match(k) {
onForbidden(w, r, authkit.ErrPermissionDenied)
return
}
ctx := authkit.WithServiceKey(r.Context(), k)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
return defaultJSONError(http.StatusUnauthorized)
}
func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) {
if o.OnForbidden != nil {
return o.OnForbidden
}
return defaultJSONError(http.StatusForbidden)
}
func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) {
return func(w http.ResponseWriter, _ *http.Request, err error) {
return func(w http.ResponseWriter, _ *http.Request, _ error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(map[string]string{
@ -47,169 +238,3 @@ func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, e
})
}
}
// RequireSession authenticates the request via an opaque session string. The
// extractor is consulted first; if no extractor is set the default Bearer
// extractor is used. For cookie-based session lookup, set
// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName).
func RequireSession(opts Options) func(http.Handler) http.Handler {
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
return opts.Auth.AuthenticateSession(r.Context(), raw)
})
}
// RequireJWT authenticates the request via an HS256 JWT.
func RequireJWT(opts Options) func(http.Handler) http.Handler {
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
return opts.Auth.AuthenticateJWT(r.Context(), raw)
})
}
// RequireServiceKey authenticates the request via an opaque service token
// secret. On success the resolved *authkit.ServiceKey is placed on the
// request context; downstream handlers retrieve it via ServiceKeyFrom. Note
// that this middleware does NOT place a *Principal on the context — service
// tokens have no user — so user-bound authz middleware (RequireRole,
// RequirePermission) will reject service-key requests with 403.
func RequireServiceKey(opts Options) func(http.Handler) http.Handler {
return requireWithServiceKey(opts, func(r *http.Request, raw string) (*authkit.ServiceKey, error) {
return opts.Auth.AuthenticateServiceKey(r.Context(), raw)
})
}
// RequireAny tries each user-bound method in order until one succeeds. The
// default set is [Session, JWT]; service tokens are NOT included because
// they yield a different subject type. For routes that accept either a user
// credential or a service token, use RequireAnyOrServiceKey.
func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler {
if len(methods) == 0 {
methods = []authkit.AuthMethod{
authkit.AuthMethodSession,
authkit.AuthMethodJWT,
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, ok := opts.extractor()(r)
if !ok || raw == "" {
opts.onUnauth()(w, r, authkit.ErrSessionInvalid)
return
}
var (
p *authkit.Principal
lastErr error
)
for _, m := range methods {
switch m {
case authkit.AuthMethodSession:
p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw)
case authkit.AuthMethodJWT:
p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw)
}
if lastErr == nil && p != nil {
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
return
}
}
opts.onUnauth()(w, r, lastErr)
})
}
}
// RequireAnyOrServiceKey tries the user-bound methods first (default
// [Session, JWT]); on failure, falls through to a service-key lookup. The
// downstream handler sees either a *Principal or a *ServiceKey on context —
// retrieve via PrincipalFrom or ServiceKeyFrom and dispatch accordingly.
func RequireAnyOrServiceKey(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler {
if opts.Auth == nil {
panic("authkit/middleware: Options.Auth is required")
}
if len(methods) == 0 {
methods = []authkit.AuthMethod{
authkit.AuthMethodSession,
authkit.AuthMethodJWT,
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, ok := opts.extractor()(r)
if !ok || raw == "" {
opts.onUnauth()(w, r, authkit.ErrSessionInvalid)
return
}
var lastErr error
for _, m := range methods {
var p *authkit.Principal
switch m {
case authkit.AuthMethodSession:
p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw)
case authkit.AuthMethodJWT:
p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw)
}
if lastErr == nil && p != nil {
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
return
}
}
k, err := opts.Auth.AuthenticateServiceKey(r.Context(), raw)
if err == nil && k != nil {
next.ServeHTTP(w, r.WithContext(withServiceKey(r.Context(), k)))
return
}
if lastErr == nil {
lastErr = err
}
opts.onUnauth()(w, r, lastErr)
})
}
}
// requireWith is the shared scaffolding for the single-method user-bound
// Require* middlewares.
func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler {
if opts.Auth == nil {
panic("authkit/middleware: Options.Auth is required")
}
extractor := opts.extractor()
onUnauth := opts.onUnauth()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, ok := extractor(r)
if !ok || raw == "" {
onUnauth(w, r, authkit.ErrSessionInvalid)
return
}
p, err := authn(r, raw)
if err != nil {
onUnauth(w, r, err)
return
}
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
})
}
}
// requireWithServiceKey is the service-key analogue of requireWith. It places
// a *ServiceKey (not a *Principal) on the request context.
func requireWithServiceKey(opts Options, authn func(r *http.Request, raw string) (*authkit.ServiceKey, error)) func(http.Handler) http.Handler {
if opts.Auth == nil {
panic("authkit/middleware: Options.Auth is required")
}
extractor := opts.extractor()
onUnauth := opts.onUnauth()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, ok := extractor(r)
if !ok || raw == "" {
onUnauth(w, r, authkit.ErrServiceKeyInvalid)
return
}
k, err := authn(r, raw)
if err != nil {
onUnauth(w, r, err)
return
}
next.ServeHTTP(w, r.WithContext(withServiceKey(r.Context(), k)))
})
}
}

View file

@ -1,319 +1,92 @@
package middleware_test
// Integration tests for the middleware package. Skipped when
// AUTHKIT_TEST_DATABASE_URL is not set.
import (
"context"
"errors"
"database/sql"
"fmt"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"sync"
"os"
"testing"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/hasher"
"git.juancwu.dev/juancwu/authkit/middleware"
"github.com/google/uuid"
_ "github.com/jackc/pgx/v5/stdlib"
)
// ─── minimal in-memory stores ──────────────────────────────────────────────
//
// The middleware package can't import the parent's _test stores, so we wire
// up a fresh-but-minimal set here. Only the methods actually exercised by
// the middleware tests below have meaningful bodies; unused store methods
// panic to surface unexpected call paths.
type memUserStore struct {
mu sync.Mutex
m map[uuid.UUID]*authkit.User
}
func newMemUserStore() *memUserStore { return &memUserStore{m: map[uuid.UUID]*authkit.User{}} }
func (s *memUserStore) CreateUser(_ context.Context, u *authkit.User) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, existing := range s.m {
if existing.EmailNormalized == u.EmailNormalized {
return authkit.ErrEmailTaken
}
}
cp := *u
s.m[u.ID] = &cp
return nil
}
func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*authkit.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.m[id]
if !ok {
return nil, authkit.ErrUserNotFound
}
cp := *u
return &cp, nil
}
func (s *memUserStore) GetUserByEmail(_ context.Context, normalized string) (*authkit.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
for _, u := range s.m {
if u.EmailNormalized == normalized {
cp := *u
return &cp, nil
}
}
return nil, authkit.ErrUserNotFound
}
func (s *memUserStore) UpdateUser(_ context.Context, u *authkit.User) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *u
s.m[u.ID] = &cp
return nil
}
func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.m, id)
return nil
}
func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, encoded string) error {
s.mu.Lock()
defer s.mu.Unlock()
if u, ok := s.m[id]; ok {
u.PasswordHash = encoded
}
return nil
}
func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
if u, ok := s.m[id]; ok {
u.EmailVerifiedAt = &at
}
return nil
}
func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
if u, ok := s.m[id]; ok {
u.SessionVersion++
return u.SessionVersion, nil
}
return 0, authkit.ErrUserNotFound
}
func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
if u, ok := s.m[id]; ok {
u.FailedLogins++
return u.FailedLogins, nil
}
return 0, authkit.ErrUserNotFound
}
func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
if u, ok := s.m[id]; ok {
u.FailedLogins = 0
}
return nil
}
type memSessionStore struct {
mu sync.Mutex
m map[string]*authkit.Session
}
func newMemSessionStore() *memSessionStore {
return &memSessionStore{m: map[string]*authkit.Session{}}
}
func (s *memSessionStore) CreateSession(_ context.Context, sess *authkit.Session) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *sess
s.m[string(sess.IDHash)] = &cp
return nil
}
func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*authkit.Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess, ok := s.m[string(h)]
if !ok {
return nil, authkit.ErrSessionInvalid
}
cp := *sess
return &cp, nil
}
func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, newExp time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
if sess, ok := s.m[string(h)]; ok {
sess.LastSeenAt = lastSeen
sess.ExpiresAt = newExp
}
return nil
}
func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.m, string(h))
return nil
}
func (s *memSessionStore) DeleteUserSessions(_ context.Context, _ uuid.UUID) error { return nil }
func (s *memSessionStore) DeleteExpired(_ context.Context, _ time.Time) (int64, error) {
return 0, nil
}
type memTokenStore struct{}
func (memTokenStore) CreateToken(_ context.Context, _ *authkit.Token) error { return nil }
func (memTokenStore) ConsumeToken(_ context.Context, _ authkit.TokenKind, _ []byte, _ time.Time) (*authkit.Token, error) {
return nil, authkit.ErrTokenInvalid
}
func (memTokenStore) GetToken(_ context.Context, _ authkit.TokenKind, _ []byte) (*authkit.Token, error) {
return nil, authkit.ErrTokenInvalid
}
func (memTokenStore) DeleteByChain(_ context.Context, _ string) (int64, error) { return 0, nil }
func (memTokenStore) DeleteExpired(_ context.Context, _ time.Time) (int64, error) {
return 0, nil
}
type memServiceKeyStore struct {
mu sync.Mutex
m map[string]*authkit.ServiceKey
}
func newMemServiceKeyStore() *memServiceKeyStore {
return &memServiceKeyStore{m: map[string]*authkit.ServiceKey{}}
}
func (s *memServiceKeyStore) CreateServiceKey(_ context.Context, k *authkit.ServiceKey) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
s.m[string(k.IDHash)] = &cp
return nil
}
func (s *memServiceKeyStore) GetServiceKey(_ context.Context, h []byte) (*authkit.ServiceKey, error) {
s.mu.Lock()
defer s.mu.Unlock()
k, ok := s.m[string(h)]
if !ok {
return nil, authkit.ErrServiceKeyInvalid
}
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
return &cp, nil
}
func (s *memServiceKeyStore) ListServiceKeysByOwner(_ context.Context, kind string, owner uuid.UUID) ([]*authkit.ServiceKey, error) {
s.mu.Lock()
defer s.mu.Unlock()
var out []*authkit.ServiceKey
for _, k := range s.m {
if k.OwnerKind == kind && k.OwnerID == owner {
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
out = append(out, &cp)
}
}
return out, nil
}
func (s *memServiceKeyStore) TouchServiceKey(_ context.Context, h []byte, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
if k, ok := s.m[string(h)]; ok {
k.LastUsedAt = &at
}
return nil
}
func (s *memServiceKeyStore) RevokeServiceKey(_ context.Context, h []byte, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
k, ok := s.m[string(h)]
if !ok {
return authkit.ErrServiceKeyInvalid
}
if k.RevokedAt != nil {
return authkit.ErrServiceKeyInvalid
}
k.RevokedAt = &at
return nil
}
type memRoleStore struct{}
func (memRoleStore) CreateRole(_ context.Context, _ *authkit.Role) error { return nil }
func (memRoleStore) GetRoleByID(_ context.Context, _ uuid.UUID) (*authkit.Role, error) {
return nil, authkit.ErrRoleNotFound
}
func (memRoleStore) GetRoleByName(_ context.Context, _ string) (*authkit.Role, error) {
return nil, authkit.ErrRoleNotFound
}
func (memRoleStore) ListRoles(_ context.Context) ([]*authkit.Role, error) { return nil, nil }
func (memRoleStore) DeleteRole(_ context.Context, _ uuid.UUID) error { return nil }
func (memRoleStore) AssignRoleToUser(_ context.Context, _, _ uuid.UUID) error { return nil }
func (memRoleStore) RemoveRoleFromUser(_ context.Context, _, _ uuid.UUID) error { return nil }
func (memRoleStore) GetUserRoles(_ context.Context, _ uuid.UUID) ([]*authkit.Role, error) {
return nil, nil
}
func (memRoleStore) HasAnyRole(_ context.Context, _ uuid.UUID, _ []string) (bool, error) {
return false, nil
}
type memPermStore struct{}
func (memPermStore) CreatePermission(_ context.Context, _ *authkit.Permission) error { return nil }
func (memPermStore) GetPermissionByID(_ context.Context, _ uuid.UUID) (*authkit.Permission, error) {
return nil, authkit.ErrPermissionNotFound
}
func (memPermStore) GetPermissionByName(_ context.Context, _ string) (*authkit.Permission, error) {
return nil, authkit.ErrPermissionNotFound
}
func (memPermStore) ListPermissions(_ context.Context) ([]*authkit.Permission, error) {
return nil, nil
}
func (memPermStore) DeletePermission(_ context.Context, _ uuid.UUID) error { return nil }
func (memPermStore) AssignPermissionToRole(_ context.Context, _, _ uuid.UUID) error { return nil }
func (memPermStore) RemovePermissionFromRole(_ context.Context, _, _ uuid.UUID) error { return nil }
func (memPermStore) GetRolePermissions(_ context.Context, _ uuid.UUID) ([]*authkit.Permission, error) {
return nil, nil
}
func (memPermStore) GetUserPermissions(_ context.Context, _ uuid.UUID) ([]*authkit.Permission, error) {
return nil, nil
}
type stubHasher struct{}
func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil }
func (stubHasher) Verify(p, encoded string) (bool, bool, error) {
return encoded == "stub:"+p, false, nil
}
func newTestAuth(t *testing.T) *authkit.Auth {
func freshAuth(t *testing.T) *authkit.Auth {
t.Helper()
return authkit.New(authkit.Deps{
Users: newMemUserStore(),
Sessions: newMemSessionStore(),
Tokens: memTokenStore{},
ServiceKeys: newMemServiceKeyStore(),
Roles: memRoleStore{},
Permissions: memPermStore{},
Hasher: stubHasher{},
url := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
if url == "" {
t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
}
db, err := sql.Open("pgx", url)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("ping: %v", err)
}
dropAuthkitTables(t, db)
t.Cleanup(func() { dropAuthkitTables(t, db) })
a, err := authkit.New(context.Background(), authkit.Deps{
DB: db,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
}, authkit.Config{
JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"),
JWTIssuer: "mw-test",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: 1 * time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,
PasswordResetTTL: time.Hour,
MagicLinkTTL: time.Minute,
JWTSecret: []byte("integration-secret-thirty-two!!!"),
JWTIssuer: "authkit-mw-int",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,
PasswordResetTTL: time.Hour,
MagicLinkTTL: time.Minute,
EmailOTPTTL: time.Minute,
EmailOTPMaxAttempts: 3,
// Plain HTTP for tests so secure-cookie defaults don't interfere
// with httptest's HTTP server.
SessionCookieSecure: authkit.BoolPtr(false),
})
if err != nil {
t.Fatalf("authkit.New: %v", err)
}
return a
}
// Bearer-style request helper.
func req(token string) *http.Request {
func dropAuthkitTables(t *testing.T, db *sql.DB) {
t.Helper()
tables := []string{
"authkit_service_key_abilities",
"authkit_user_permissions",
"authkit_user_roles",
"authkit_role_permissions",
"authkit_service_keys",
"authkit_abilities",
"authkit_roles",
"authkit_permissions",
"authkit_tokens",
"authkit_sessions",
"authkit_users",
"authkit_schema_migrations",
}
ctx := context.Background()
for _, name := range tables {
_, _ = db.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name))
}
}
// reqWithBearer issues a request carrying Authorization: Bearer <token>.
func reqWithBearer(token string) *http.Request {
r := httptest.NewRequest(http.MethodGet, "/", nil)
if token != "" {
r.Header.Set("Authorization", "Bearer "+token)
@ -323,191 +96,263 @@ func req(token string) *http.Request {
func ok200(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }
// ─── tests ─────────────────────────────────────────────────────────────────
// ─── RequireLogin ──────────────────────────────────────────────────────────
func TestRequireServiceKey_Authenticates(t *testing.T) {
a := newTestAuth(t)
plain, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "ci", []string{"events:write"}, nil)
func TestRequireLogin_AcceptsSessionCookie(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "alice@example.com")
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
t.Fatalf("CreateUser: %v", err)
}
plain, _, err := a.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
var seen *authkit.ServiceKey
handler := middleware.RequireServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
k, ok := middleware.ServiceKeyFrom(r.Context())
if !ok {
t.Fatalf("no ServiceKey on context")
handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uid, ok := authkit.UserIDFromCtx(r.Context())
if !ok || uid != u.ID {
t.Fatalf("user_id missing or wrong on context: ok=%v id=%v", ok, uid)
}
seen = k
w.WriteHeader(http.StatusOK)
}))
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.AddCookie(a.SessionCookie(plain, time.Now().Add(time.Hour)))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req(plain))
handler.ServeHTTP(rr, r)
if rr.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", rr.Code)
}
if seen == nil || !seen.HasAbility("events:write") {
t.Fatalf("expected ServiceKey with events:write ability; got %+v", seen)
t.Fatalf("expected 200, got %d", rr.Code)
}
}
func TestRequireServiceKey_RejectsRevoked(t *testing.T) {
a := newTestAuth(t)
plain, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "ci", nil, nil)
func TestRequireLogin_AcceptsJWT(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "j@j.com")
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
t.Fatalf("CreateUser: %v", err)
}
if err := a.RevokeServiceKey(context.Background(), plain); err != nil {
t.Fatalf("RevokeServiceKey: %v", err)
access, _, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(ok200))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, reqWithBearer(access))
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
}
func TestRequireLogin_RejectsUnauthenticated(t *testing.T) {
a := freshAuth(t)
handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(ok200))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", rr.Code)
}
}
func TestRequireLogin_AuthzRoleGate(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateRole(ctx, "admin", ""); err != nil {
t.Fatalf("CreateRole: %v", err)
}
u, err := a.CreateUser(ctx, "noadmin@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
access, _, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
handler := middleware.RequireLogin(middleware.LoginOptions{
Auth: a,
Authz: authkit.HasRole("admin"),
})(http.HandlerFunc(ok200))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, reqWithBearer(access))
if rr.Code != http.StatusForbidden {
t.Fatalf("non-admin should get 403, got %d", rr.Code)
}
// Promote the user to admin and retry.
if err := a.AssignRole(ctx, u.ID, "admin"); err != nil {
t.Fatalf("AssignRole: %v", err)
}
access2, _, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, reqWithBearer(access2))
if rr.Code != http.StatusOK {
t.Fatalf("admin should get 200, got %d", rr.Code)
}
}
func TestRequireLogin_PanicsOnUnknownSlug(t *testing.T) {
a := freshAuth(t)
defer func() {
if r := recover(); r == nil {
t.Fatalf("expected panic on unknown role slug")
}
}()
middleware.RequireLogin(middleware.LoginOptions{
Auth: a,
Authz: authkit.HasRole("never-registered"),
})
}
// ─── RequireGuest ──────────────────────────────────────────────────────────
func TestRequireGuest_LetsUnauthenticatedThrough(t *testing.T) {
a := freshAuth(t)
called := false
handler := middleware.RequireServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := middleware.RequireGuest(middleware.GuestOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req(plain))
if rr.Code != http.StatusUnauthorized {
t.Fatalf("status = %d, want 401", rr.Code)
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
if !called {
t.Fatalf("guest middleware should pass through unauthenticated request")
}
if called {
t.Fatalf("handler should not have been invoked for revoked key")
}
}
func TestRequireAbility_AcceptsServiceKeyWithAbility(t *testing.T) {
a := newTestAuth(t)
plain, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "ci", []string{"events:write"}, nil)
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
chain := middleware.RequireServiceKey(middleware.Options{Auth: a})(
middleware.RequireAbility("events:write")(http.HandlerFunc(ok200)))
rr := httptest.NewRecorder()
chain.ServeHTTP(rr, req(plain))
if rr.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", rr.Code)
}
// Same chain but ability the key does not carry → 403.
chainBad := middleware.RequireServiceKey(middleware.Options{Auth: a})(
middleware.RequireAbility("admin:nuke")(http.HandlerFunc(ok200)))
rr = httptest.NewRecorder()
chainBad.ServeHTTP(rr, req(plain))
if rr.Code != http.StatusForbidden {
t.Fatalf("missing-ability status = %d, want 403", rr.Code)
t.Fatalf("expected 200, got %d", rr.Code)
}
}
func TestRequireAbility_RejectsUserPrincipal(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "alice@example.com", "hunter2hunter2")
func TestRequireGuest_BlocksAuthenticated(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "g@g.com")
if err != nil {
t.Fatalf("Register: %v", err)
t.Fatalf("CreateUser: %v", err)
}
plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
access, _, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueSession: %v", err)
t.Fatalf("IssueJWT: %v", err)
}
chain := middleware.RequireSession(middleware.Options{Auth: a})(
middleware.RequireAbility("events:write")(http.HandlerFunc(ok200)))
rr := httptest.NewRecorder()
chain.ServeHTTP(rr, req(plain))
if rr.Code != http.StatusForbidden {
t.Fatalf("status = %d, want 403 (user principal carries no abilities)", rr.Code)
}
}
func TestRequireRole_RejectsServiceKey(t *testing.T) {
a := newTestAuth(t)
plain, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "ci", nil, nil)
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
chain := middleware.RequireServiceKey(middleware.Options{Auth: a})(
middleware.RequireRole("admin")(http.HandlerFunc(ok200)))
rr := httptest.NewRecorder()
chain.ServeHTTP(rr, req(plain))
if rr.Code != http.StatusForbidden {
t.Fatalf("status = %d, want 403 (service key carries no Principal/role)", rr.Code)
}
}
func TestRequireAnyOrServiceKey(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "alice@example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
sessionPlain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
servicePlain, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "ci", nil, nil)
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
type subject struct {
hasPrincipal bool
hasServiceKey bool
}
var got subject
handler := middleware.RequireAnyOrServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hp := middleware.PrincipalFrom(r.Context())
_, hs := middleware.ServiceKeyFrom(r.Context())
got = subject{hp, hs}
w.WriteHeader(http.StatusOK)
handlerCalled := false
handler := middleware.RequireGuest(middleware.GuestOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
}))
// Session token → Principal in context, no ServiceKey.
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req(sessionPlain))
if rr.Code != http.StatusOK {
t.Fatalf("session: status = %d, want 200", rr.Code)
handler.ServeHTTP(rr, reqWithBearer(access))
if rr.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", rr.Code)
}
if !got.hasPrincipal || got.hasServiceKey {
t.Fatalf("session: ctx subject = %+v, want principal-only", got)
}
// Service token → ServiceKey in context, no Principal.
rr = httptest.NewRecorder()
got = subject{}
handler.ServeHTTP(rr, req(servicePlain))
if rr.Code != http.StatusOK {
t.Fatalf("service: status = %d, want 200", rr.Code)
}
if got.hasPrincipal || !got.hasServiceKey {
t.Fatalf("service: ctx subject = %+v, want servicekey-only", got)
}
// Garbage token → 401, neither subject set.
rr = httptest.NewRecorder()
got = subject{}
handler.ServeHTTP(rr, req(strings.Repeat("x", 50)))
if rr.Code != http.StatusUnauthorized {
t.Fatalf("garbage: status = %d, want 401", rr.Code)
if handlerCalled {
t.Fatalf("handler should not run for authenticated request")
}
}
// Sanity check: the constructed *authkit.Auth should satisfy errors.Is on the
// canonical sentinels — ensures our minimal stores are wired correctly.
func TestSentinelsReachable(t *testing.T) {
a := newTestAuth(t)
_, err := a.AuthenticateServiceKey(context.Background(), "sk_not-real")
if !errors.Is(err, authkit.ErrServiceKeyInvalid) {
t.Fatalf("expected ErrServiceKeyInvalid, got %v", err)
func TestRequireGuest_CustomOnAuthenticated(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "custom@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
access, _, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
handler := middleware.RequireGuest(middleware.GuestOptions{
Auth: a,
OnAuthenticated: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
},
})(http.HandlerFunc(ok200))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, reqWithBearer(access))
if rr.Code != http.StatusSeeOther {
t.Fatalf("expected 303, got %d", rr.Code)
}
if got := rr.Header().Get("Location"); got != "/dashboard" {
t.Fatalf("expected Location=/dashboard, got %q", got)
}
}
// ─── RequireServiceKey ─────────────────────────────────────────────────────
func TestRequireServiceKey_AbilityGate(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateAbility(ctx, "events:write", ""); err != nil {
t.Fatalf("CreateAbility: %v", err)
}
plain, _, err := a.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{
Name: "ci",
Abilities: []string{"events:write"},
})
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
handler := middleware.RequireServiceKey(middleware.ServiceKeyOptions{
Auth: a,
Authz: authkit.HasAbility("events:write"),
})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
k, ok := authkit.ServiceKeyFromCtx(r.Context())
if !ok || !k.HasAbility("events:write") {
t.Fatalf("expected ServiceKey with events:write on context")
}
w.WriteHeader(http.StatusOK)
}))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, reqWithBearer(plain))
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
}
func TestRequireServiceKey_AbilityGateRejectsMissing(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateAbility(ctx, "events:write", ""); err != nil {
t.Fatalf("CreateAbility events:write: %v", err)
}
if _, err := a.CreateAbility(ctx, "admin:nuke", ""); err != nil {
t.Fatalf("CreateAbility admin:nuke: %v", err)
}
plain, _, err := a.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{
Name: "ci",
Abilities: []string{"events:write"},
})
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
handler := middleware.RequireServiceKey(middleware.ServiceKeyOptions{
Auth: a,
Authz: authkit.HasAbility("admin:nuke"),
})(http.HandlerFunc(ok200))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, reqWithBearer(plain))
if rr.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", rr.Code)
}
}
func TestRequireServiceKey_PanicsOnUnknownAbility(t *testing.T) {
a := freshAuth(t)
defer func() {
if r := recover(); r == nil {
t.Fatalf("expected panic on unknown ability slug")
}
}()
middleware.RequireServiceKey(middleware.ServiceKeyOptions{
Auth: a,
Authz: authkit.HasAbility("never-registered"),
})
}

136
migrations/0001_init.sql Normal file
View file

@ -0,0 +1,136 @@
-- 0001_init.sql
-- Initial authkit schema for PostgreSQL 16+. All tables prefixed authkit_ so
-- the library can be embedded in an existing application database. Each
-- migration owns its transaction and inserts its version row at the bottom;
-- the runner only orchestrates file discovery and concurrency.
BEGIN;
CREATE TABLE IF NOT EXISTS authkit_schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL
);
-- Users. Password is nullable so accounts can be created without a credential
-- and have one set later (invite flows, magic-link-only accounts, etc.).
CREATE TABLE IF NOT EXISTS authkit_users (
id UUID PRIMARY KEY,
email TEXT NOT NULL,
email_normalized TEXT NOT NULL,
email_verified_at TIMESTAMPTZ,
password_hash TEXT,
session_version INTEGER NOT NULL DEFAULT 0,
last_login_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq
ON authkit_users (email_normalized);
-- Opaque server-side sessions.
CREATE TABLE IF NOT EXISTS authkit_sessions (
id_hash BYTEA PRIMARY KEY,
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
user_agent TEXT NOT NULL DEFAULT '',
ip TEXT,
created_at TIMESTAMPTZ NOT NULL,
last_seen_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL
);
CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id);
CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at);
-- Single-use tokens (refresh, email-verify, password-reset, magic-link, email-otp).
-- attempts_remaining is non-null only for tokens that allow retries (email_otp);
-- ConsumeToken decrements and zeroes-out on exhaustion.
CREATE TABLE IF NOT EXISTS authkit_tokens (
hash BYTEA NOT NULL,
kind TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
chain_id TEXT,
consumed_at TIMESTAMPTZ,
attempts_remaining INTEGER,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (kind, hash)
);
CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id);
CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at);
CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx
ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL;
-- Service tokens. No owner column: these are machine credentials, intended to
-- be created by applications for outbound API calls or inbound automation.
-- Consumers tag them with whatever metadata they need via Name.
CREATE TABLE IF NOT EXISTS authkit_service_keys (
id_hash BYTEA PRIMARY KEY,
name TEXT NOT NULL,
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ,
revoked_at TIMESTAMPTZ
);
-- Roles, permissions, and abilities are seeded by the consumer (typically via
-- the cmd/roles, cmd/perms, cmd/abilities CLIs). They share the same shape:
-- normalised slug as the unique business key, optional human label.
CREATE TABLE IF NOT EXISTS authkit_roles (
id UUID PRIMARY KEY,
slug TEXT NOT NULL UNIQUE,
label TEXT,
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_permissions (
id UUID PRIMARY KEY,
slug TEXT NOT NULL UNIQUE,
label TEXT,
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_abilities (
id UUID PRIMARY KEY,
slug TEXT NOT NULL UNIQUE,
label TEXT,
created_at TIMESTAMPTZ NOT NULL
);
-- Role ↔ Permission (defines what permissions a role grants).
CREATE TABLE IF NOT EXISTS authkit_role_permissions (
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE,
PRIMARY KEY (role_id, permission_id)
);
-- User ↔ Role (which roles a user holds).
CREATE TABLE IF NOT EXISTS authkit_user_roles (
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
granted_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (user_id, role_id)
);
CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id);
-- User ↔ Permission (direct grants, in addition to permissions resolved
-- through roles). GetUserPermissions returns the UNION of both paths.
CREATE TABLE IF NOT EXISTS authkit_user_permissions (
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE,
granted_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (user_id, permission_id)
);
CREATE INDEX IF NOT EXISTS authkit_user_permissions_perm_id_idx ON authkit_user_permissions(permission_id);
-- ServiceKey ↔ Ability (which abilities a service key carries).
CREATE TABLE IF NOT EXISTS authkit_service_key_abilities (
service_key_id_hash BYTEA NOT NULL REFERENCES authkit_service_keys(id_hash) ON DELETE CASCADE,
ability_id UUID NOT NULL REFERENCES authkit_abilities(id) ON DELETE CASCADE,
granted_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (service_key_id_hash, ability_id)
);
CREATE INDEX IF NOT EXISTS authkit_service_key_abilities_ability_idx ON authkit_service_key_abilities(ability_id);
INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now())
ON CONFLICT (version) DO NOTHING;
COMMIT;

View file

@ -7,6 +7,9 @@ import (
"github.com/google/uuid"
)
// User is the canonical account record. Password hash is empty (and stored
// NULL in the DB) when no credential has been set — accounts created via
// invite or magic-link-only flows live in this state until SetPassword runs.
type User struct {
ID uuid.UUID
Email string
@ -14,12 +17,12 @@ type User struct {
EmailVerifiedAt *time.Time
PasswordHash string
SessionVersion int
FailedLogins int
LastLoginAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
// Session is an opaque server-side credential bound to one user.
type Session struct {
IDHash []byte
UserID uuid.UUID
@ -30,35 +33,36 @@ type Session struct {
ExpiresAt time.Time
}
// TokenKind enumerates the single-use credentials persisted in authkit_tokens.
type TokenKind string
const (
TokenEmailVerify TokenKind = "email_verify"
TokenPasswordReset TokenKind = "password_reset"
TokenMagicLink TokenKind = "magic_link"
TokenEmailOTP TokenKind = "email_otp"
TokenRefresh TokenKind = "refresh"
)
// Token is one row in authkit_tokens. AttemptsRemaining is non-nil only for
// tokens that allow retry on incorrect input (email OTPs); other kinds are
// strictly one-shot via ConsumeToken.
type Token struct {
Hash []byte
Kind TokenKind
UserID uuid.UUID
ChainID *string
ConsumedAt *time.Time
CreatedAt time.Time
ExpiresAt time.Time
Hash []byte
Kind TokenKind
UserID uuid.UUID
ChainID *string
ConsumedAt *time.Time
AttemptsRemaining *int
CreatedAt time.Time
ExpiresAt time.Time
}
// ServiceKey is an owner-agnostic credential for server-to-server auth.
// OwnerID is not constrained to authkit_users — OwnerKind labels the owner
// namespace (e.g. "application", "tenant") and consumers manage their own
// cascade-on-delete. It is the only credential type that carries free-form
// abilities; user-bound credentials (sessions, JWTs) prove identity and
// resolve permissions through RBAC instead.
// ServiceKey is a machine credential. It carries no identity — service tokens
// are produced by applications for outbound API access or inbound automation,
// and authorize via Abilities resolved through the join table.
type ServiceKey struct {
IDHash []byte
OwnerID uuid.UUID
OwnerKind string
Name string
Abilities []string
LastUsedAt *time.Time
@ -67,26 +71,41 @@ type ServiceKey struct {
RevokedAt *time.Time
}
// HasAbility reports whether the service key carries the named ability.
func (k *ServiceKey) HasAbility(name string) bool {
// HasAbility reports whether the service key carries the named ability slug.
func (k *ServiceKey) HasAbility(slug string) bool {
for _, a := range k.Abilities {
if a == name {
if a == slug {
return true
}
}
return false
}
// Role groups permissions for assignment to users. Slug is the immutable
// business key; Label is an optional human-readable name.
type Role struct {
ID uuid.UUID
Name string
Description string
CreatedAt time.Time
ID uuid.UUID
Slug string
Label string
CreatedAt time.Time
}
// Permission is a unit of authorization. Granted to users either through a
// role or directly via authkit_user_permissions.
type Permission struct {
ID uuid.UUID
Name string
Description string
CreatedAt time.Time
ID uuid.UUID
Slug string
Label string
CreatedAt time.Time
}
// Ability is a unit of authorization for service tokens. Abilities are a
// separate vocabulary from Permissions because they target machines, not
// users — keep them distinct so middleware predicates remain clear about
// which subject they're authorizing.
type Ability struct {
ID uuid.UUID
Slug string
Label string
CreatedAt time.Time
}

View file

@ -6,6 +6,7 @@ import (
"github.com/google/uuid"
)
// AuthMethod tags how a Principal was authenticated.
type AuthMethod string
const (
@ -13,10 +14,10 @@ const (
AuthMethodJWT AuthMethod = "jwt"
)
// Principal represents an authenticated user. It is produced only by
// user-bound auth methods (session, JWT) and carries identity plus
// RBAC-resolved roles/permissions. Service-token auth produces a
// *ServiceKey instead — those credentials carry abilities, not identity.
// Principal represents an authenticated user. Produced only by user-bound
// auth methods (session, JWT) and carries identity plus RBAC-resolved roles
// and permissions. Service-token auth produces a *ServiceKey instead — those
// credentials carry abilities, not identity.
type Principal struct {
UserID uuid.UUID
Method AuthMethod
@ -27,27 +28,32 @@ type Principal struct {
ExpiresAt time.Time
}
func (p *Principal) HasRole(name string) bool {
// HasRole reports whether the principal holds the named role slug.
func (p *Principal) HasRole(slug string) bool {
for _, r := range p.Roles {
if r == name {
if r == slug {
return true
}
}
return false
}
func (p *Principal) HasAnyRole(names ...string) bool {
for _, n := range names {
if p.HasRole(n) {
// HasAnyRole reports whether the principal holds at least one of the named
// role slugs.
func (p *Principal) HasAnyRole(slugs ...string) bool {
for _, s := range slugs {
if p.HasRole(s) {
return true
}
}
return false
}
func (p *Principal) HasPermission(name string) bool {
// HasPermission reports whether the principal holds the named permission
// slug, resolved through any combination of roles and direct grants.
func (p *Principal) HasPermission(slug string) bool {
for _, perm := range p.Permissions {
if perm == name {
if perm == slug {
return true
}
}

View file

@ -7,79 +7,122 @@ import (
"github.com/google/uuid"
)
// UserPermissions returns the union of permission names a user holds via
// their assigned roles. Resolved at call time; v1 does not cache.
// AssignRole assigns roleSlug to userID. Idempotent — a duplicate insert
// is a no-op via ON CONFLICT.
func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleSlug string) error {
const op = "authkit.Auth.AssignRole"
r, err := a.storeGetRoleBySlug(ctx, roleSlug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeAssignRoleToUser(ctx, userID, r.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RemoveRole removes roleSlug from userID. Idempotent on missing
// assignments.
func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleSlug string) error {
const op = "authkit.Auth.RemoveRole"
r, err := a.storeGetRoleBySlug(ctx, roleSlug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeRemoveRoleFromUser(ctx, userID, r.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// UserRoles returns the role slugs assigned to a user.
func (a *Auth) UserRoles(ctx context.Context, userID uuid.UUID) ([]string, error) {
const op = "authkit.Auth.UserRoles"
roles, err := a.storeGetUserRoles(ctx, userID)
if err != nil {
return nil, errx.Wrap(op, err)
}
out := make([]string, len(roles))
for i, r := range roles {
out[i] = r.Slug
}
return out, nil
}
// HasRole reports whether the user holds the named role.
func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, slug string) (bool, error) {
const op = "authkit.Auth.HasRole"
ok, err := a.storeHasAnyRole(ctx, userID, []string{slug})
if err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// HasAnyRole reports whether the user holds at least one of the named roles.
func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, slugs []string) (bool, error) {
const op = "authkit.Auth.HasAnyRole"
ok, err := a.storeHasAnyRole(ctx, userID, slugs)
if err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// GrantPermissionToUser adds a direct permission grant (not through any
// role). Idempotent.
func (a *Auth) GrantPermissionToUser(ctx context.Context, userID uuid.UUID, permSlug string) error {
const op = "authkit.Auth.GrantPermissionToUser"
p, err := a.storeGetPermissionBySlug(ctx, permSlug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeGrantPermissionToUser(ctx, userID, p.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RevokePermissionFromUser removes a direct permission grant.
func (a *Auth) RevokePermissionFromUser(ctx context.Context, userID uuid.UUID, permSlug string) error {
const op = "authkit.Auth.RevokePermissionFromUser"
p, err := a.storeGetPermissionBySlug(ctx, permSlug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeRevokePermissionFromUser(ctx, userID, p.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// UserPermissions returns the union of permission slugs the user holds via
// roles and direct grants.
func (a *Auth) UserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
const op = "authkit.Auth.UserPermissions"
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
perms, err := a.storeGetUserPermissions(ctx, userID)
if err != nil {
return nil, errx.Wrap(op, err)
}
out := make([]string, len(perms))
for i, p := range perms {
out[i] = p.Name
out[i] = p.Slug
}
return out, nil
}
// HasPermission checks whether a user holds the named permission via any
// assigned role.
func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
// HasPermission reports whether the user holds the named permission, via
// any combination of role-derived and direct grants.
func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, permSlug string) (bool, error) {
const op = "authkit.Auth.HasPermission"
perms, err := a.UserPermissions(ctx, userID)
if err != nil {
return false, errx.Wrap(op, err)
}
for _, p := range perms {
if p == name {
if p == permSlug {
return true, nil
}
}
return false, nil
}
// HasRole checks whether a user is assigned the named role.
func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
const op = "authkit.Auth.HasRole"
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, []string{name})
if err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// HasAnyRole checks whether a user holds at least one of the named roles.
func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
const op = "authkit.Auth.HasAnyRole"
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, names)
if err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// AssignRole is a convenience that looks up a role by name and assigns it.
func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleName string) error {
const op = "authkit.Auth.AssignRole"
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Roles.AssignRoleToUser(ctx, userID, r.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RemoveRole is the symmetric helper for AssignRole.
func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleName string) error {
const op = "authkit.Auth.RemoveRole"
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Roles.RemoveRoleFromUser(ctx, userID, r.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}

View file

@ -13,7 +13,7 @@ import (
// preserves that chain so reuse-detection can revoke the whole family.
func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) {
const op = "authkit.Auth.IssueJWT"
u, err := a.deps.Users.GetUserByID(ctx, userID)
u, err := a.storeGetUserByID(ctx, userID)
if err != nil {
return "", "", errx.Wrap(op, err)
}
@ -40,7 +40,7 @@ func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal,
if err != nil {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
u, err := a.deps.Users.GetUserByID(ctx, uid)
u, err := a.storeGetUserByID(ctx, uid)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return nil, errx.Wrap(op, ErrTokenInvalid)
@ -64,27 +64,27 @@ func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal,
}, nil
}
// RefreshJWT consumes the presented refresh token and mints a new access +
// refresh pair. Reuse of an already-consumed refresh token deletes the
// entire chain (logout-everywhere on that device family) and returns
// RefreshJWT consumes the presented refresh token and mints a new
// access+refresh pair. Reuse of an already-consumed refresh token deletes
// the entire chain (logout-everywhere on that device family) and returns
// ErrTokenReused.
func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) {
const op = "authkit.Auth.RefreshJWT"
hash, ok := parseSecret(prefixRefresh, plaintextRefresh)
hash, ok := ParseOpaqueSecret(prefixRefresh, plaintextRefresh)
if !ok {
return "", "", errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
consumed, err := a.deps.Tokens.ConsumeToken(ctx, TokenRefresh, hash, now)
consumed, err := a.storeConsumeToken(ctx, TokenRefresh, hash, now)
if err != nil {
// Differentiate plain-invalid (never existed / expired) from
// reuse (existed, already consumed). The presence-check below is
// the reuse signal.
// Differentiate plain-invalid (never existed / expired) from reuse
// (existed, already consumed). Existence-with-consumed is the
// reuse signal.
if errors.Is(err, ErrTokenInvalid) {
if existing, gerr := a.deps.Tokens.GetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil {
if existing, gerr := a.storeGetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil {
if existing.ChainID != nil && *existing.ChainID != "" {
_, _ = a.deps.Tokens.DeleteByChain(ctx, *existing.ChainID)
_, _ = a.storeDeleteByChain(ctx, *existing.ChainID)
}
return "", "", errx.Wrap(op, ErrTokenReused)
}
@ -98,7 +98,7 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access,
}
if chainID == "" {
// Defensive: every refresh token should be chain-bound. Fall back
// to a fresh chain so we never throw on missing metadata.
// to a fresh chain rather than throwing on missing metadata.
chainID = uuid.NewString()
}
@ -113,11 +113,9 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access,
return access, refresh, nil
}
// mintRefreshToken stores a fresh refresh token bound to chainID and returns
// the plaintext.
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) {
const op = "authkit.Auth.mintRefreshToken"
plaintext, hash, err := mintSecret(prefixRefresh, a.cfg.Random)
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixRefresh)
if err != nil {
return "", errx.Wrap(op, err)
}
@ -130,17 +128,16 @@ func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID s
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.RefreshTokenTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
if err := a.storeCreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// userSessionVersion fetches the current session_version. Errors collapse to
// 0 on the assumption that AuthenticateJWT will reject stale tokens cleanly
// — but we still need a value to embed in the freshly-minted access token.
// userSessionVersion fetches the current session_version. Errors collapse
// to 0 — a stale token will fail AuthenticateJWT cleanly anyway.
func (a *Auth) userSessionVersion(ctx context.Context, userID uuid.UUID) int {
if u, err := a.deps.Users.GetUserByID(ctx, userID); err == nil {
if u, err := a.storeGetUserByID(ctx, userID); err == nil {
return u.SessionVersion
}
return 0

67
service_jwt_test.go Normal file
View file

@ -0,0 +1,67 @@
package authkit
import (
"context"
"errors"
"testing"
)
func TestIntegration_JWTIssueAuthenticate(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "j@j.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
access, refresh, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
if access == "" || refresh == "" {
t.Fatalf("missing access/refresh")
}
p, err := a.AuthenticateJWT(ctx, access)
if err != nil {
t.Fatalf("AuthenticateJWT: %v", err)
}
if p.UserID != u.ID {
t.Fatalf("user id mismatch")
}
if p.Method != AuthMethodJWT {
t.Fatalf("method = %s, want jwt", p.Method)
}
}
func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "ref@r.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
_, refresh1, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
_, refresh2, err := a.RefreshJWT(ctx, refresh1)
if err != nil {
t.Fatalf("RefreshJWT: %v", err)
}
if refresh1 == refresh2 {
t.Fatalf("refresh did not rotate")
}
if _, _, err := a.RefreshJWT(ctx, refresh1); !errors.Is(err, ErrTokenReused) {
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
}
if _, _, err := a.RefreshJWT(ctx, refresh2); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err)
}
}
func TestIntegration_JWTInvalidPrefix(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, _, err := a.RefreshJWT(ctx, "not-a-refresh-token"); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid for malformed input, got %v", err)
}
}

View file

@ -2,20 +2,27 @@ package authkit
import (
"context"
"errors"
"git.juancwu.dev/juancwu/errx"
)
// RequestMagicLink mints a single-use magic-link token for the email and
// returns the plaintext for delivery. ErrUserNotFound is returned for
// unregistered emails.
// returns the plaintext for delivery.
//
// Default behavior is anti-enumeration: if the email is not registered,
// returns ("", nil) — the caller cannot distinguish "exists" from "doesn't
// exist". Set Config.RevealUnknownEmail = true to surface ErrUserNotFound.
func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, error) {
const op = "authkit.Auth.RequestMagicLink"
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
if errors.Is(err, ErrUserNotFound) && !a.cfg.RevealUnknownEmail {
return "", nil
}
return "", errx.Wrap(op, err)
}
plaintext, hash, err := mintSecret(prefixMagicLink, a.cfg.Random)
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixMagicLink)
if err != nil {
return "", errx.Wrap(op, err)
}
@ -27,34 +34,33 @@ func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, erro
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.MagicLinkTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
if err := a.storeCreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// ConsumeMagicLink consumes the magic-link token and returns the
// authenticated user. Callers typically follow this with IssueSession or
// IssueJWT to actually log the user in.
// authenticated user. Callers typically follow with IssueSession or
// IssueJWT to actually log the user in. A successful consume implicitly
// verifies the email (the user demonstrably controls the inbox).
func (a *Auth) ConsumeMagicLink(ctx context.Context, plaintextToken string) (*User, error) {
const op = "authkit.Auth.ConsumeMagicLink"
hash, ok := parseSecret(prefixMagicLink, plaintextToken)
hash, ok := ParseOpaqueSecret(prefixMagicLink, plaintextToken)
if !ok {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenMagicLink, hash, now)
t, err := a.storeConsumeToken(ctx, TokenMagicLink, hash, now)
if err != nil {
return nil, errx.Wrap(op, err)
}
u, err := a.deps.Users.GetUserByID(ctx, t.UserID)
u, err := a.storeGetUserByID(ctx, t.UserID)
if err != nil {
return nil, errx.Wrap(op, err)
}
// A successful magic-link login also implicitly verifies the email
// (the user demonstrably controls the inbox).
if u.EmailVerifiedAt == nil {
if err := a.deps.Users.SetEmailVerified(ctx, u.ID, now); err == nil {
if err := a.storeSetEmailVerified(ctx, u.ID, now); err == nil {
u.EmailVerifiedAt = &now
}
}

41
service_magic_test.go Normal file
View file

@ -0,0 +1,41 @@
package authkit
import (
"context"
"errors"
"testing"
)
func TestIntegration_MagicLinkFlow(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateUser(ctx, "m@m.com"); err != nil {
t.Fatalf("CreateUser: %v", err)
}
tok, err := a.RequestMagicLink(ctx, "m@m.com")
if err != nil {
t.Fatalf("RequestMagicLink: %v", err)
}
u, err := a.ConsumeMagicLink(ctx, tok)
if err != nil {
t.Fatalf("ConsumeMagicLink: %v", err)
}
if u.EmailVerifiedAt == nil {
t.Fatalf("magic link should imply email verification")
}
if _, err := a.ConsumeMagicLink(ctx, tok); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on reuse, got %v", err)
}
}
func TestIntegration_MagicLinkUnknownEmailIsSilent(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
tok, err := a.RequestMagicLink(ctx, "nobody@example.com")
if err != nil {
t.Fatalf("expected silent success, got %v", err)
}
if tok != "" {
t.Fatalf("expected empty token for unknown email, got %q", tok)
}
}

136
service_otp.go Normal file
View file

@ -0,0 +1,136 @@
package authkit
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"git.juancwu.dev/juancwu/errx"
)
// RequestEmailOTP mints a numeric one-time code for the email and returns
// the plaintext for delivery. Anti-enumeration: unknown email returns
// ("", nil) unless Config.RevealUnknownEmail is set.
//
// Code length is Config.EmailOTPDigits (default 6). Brute-force resistance
// comes from Config.EmailOTPMaxAttempts (default 5): after N wrong tries
// the code is invalidated, forcing the caller to request a new one.
func (a *Auth) RequestEmailOTP(ctx context.Context, email string) (string, error) {
const op = "authkit.Auth.RequestEmailOTP"
u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
if errors.Is(err, ErrUserNotFound) && !a.cfg.RevealUnknownEmail {
return "", nil
}
return "", errx.Wrap(op, err)
}
code, err := generateOTPCode(a.cfg.Random, a.cfg.EmailOTPDigits)
if err != nil {
return "", errx.Wrap(op, err)
}
now := a.now()
attempts := a.cfg.EmailOTPMaxAttempts
t := &Token{
Hash: hashOTPCode(code),
Kind: TokenEmailOTP,
UserID: u.ID,
AttemptsRemaining: &attempts,
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.EmailOTPTTL),
}
if err := a.storeCreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return code, nil
}
// ConsumeEmailOTP verifies a code against the most recent active OTP for
// the user behind email. Successful match consumes the row. A wrong code
// decrements attempts_remaining and returns ErrOTPInvalid; reaching zero
// attempts invalidates the OTP. A successful consume implicitly verifies
// the email.
func (a *Auth) ConsumeEmailOTP(ctx context.Context, email, code string) (*User, error) {
const op = "authkit.Auth.ConsumeEmailOTP"
if code == "" {
return nil, errx.Wrap(op, ErrOTPInvalid)
}
u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// Don't leak account existence — same error shape as wrong code.
return nil, errx.Wrap(op, ErrOTPInvalid)
}
return nil, errx.Wrap(op, err)
}
now := a.now()
active, err := a.storeGetActiveOTPForUser(ctx, TokenEmailOTP, uuidArg(u.ID), now)
if err != nil {
return nil, errx.Wrap(op, err)
}
if !bytes.Equal(active.Hash, hashOTPCode(code)) {
// Wrong code: decrement attempts on the active OTP. If this
// drives attempts_remaining to 0, the row is consumed atomically
// inside the same UPDATE.
_, derr := a.storeDecrementOTPAttempt(ctx, TokenEmailOTP, active.Hash, now)
if derr != nil && !errors.Is(derr, ErrTokenInvalid) {
return nil, errx.Wrap(op, derr)
}
return nil, errx.Wrap(op, ErrOTPInvalid)
}
if _, err := a.storeConsumeOTPByHash(ctx, TokenEmailOTP, active.Hash, now); err != nil {
return nil, errx.Wrap(op, err)
}
if u.EmailVerifiedAt == nil {
if err := a.storeSetEmailVerified(ctx, u.ID, now); err == nil {
u.EmailVerifiedAt = &now
}
}
return u, nil
}
// generateOTPCode produces a numeric code of length digits using a CSPRNG
// (defaults to crypto/rand if rng is nil). Uniformly distributed; rejects
// rolls that would bias the mod operator.
func generateOTPCode(rng io.Reader, digits int) (string, error) {
if digits <= 0 || digits > 12 {
return "", fmt.Errorf("invalid OTP digits: %d", digits)
}
if rng == nil {
rng = rand.Reader
}
max := uint64(1)
for i := 0; i < digits; i++ {
max *= 10
}
// Reject rolls that fall in the "leftover" partial keyspace at the top
// of uint64 to keep the distribution uniform.
limit := (^uint64(0)) - ((^uint64(0)) % max)
var buf [8]byte
for {
if _, err := io.ReadFull(rng, buf[:]); err != nil {
return "", err
}
v := binary.BigEndian.Uint64(buf[:])
if v >= limit {
continue
}
return fmt.Sprintf("%0*d", digits, v%max), nil
}
}
// hashOTPCode returns sha256(code). Codes are short and low-entropy, so the
// hash is purely a database lookup key — the brute-force defense is
// attempts_remaining, not hash strength.
func hashOTPCode(code string) []byte {
sum := sha256.Sum256([]byte(code))
return sum[:]
}

85
service_otp_test.go Normal file
View file

@ -0,0 +1,85 @@
package authkit
import (
"context"
"errors"
"testing"
)
func TestIntegration_OTPHappyPath(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateUser(ctx, "otp@example.com"); err != nil {
t.Fatalf("CreateUser: %v", err)
}
code, err := a.RequestEmailOTP(ctx, "otp@example.com")
if err != nil {
t.Fatalf("RequestEmailOTP: %v", err)
}
if len(code) != 6 {
t.Fatalf("expected 6-digit OTP, got %q", code)
}
u, err := a.ConsumeEmailOTP(ctx, "otp@example.com", code)
if err != nil {
t.Fatalf("ConsumeEmailOTP: %v", err)
}
if u.EmailVerifiedAt == nil {
t.Fatalf("OTP consume should imply email verification")
}
// Re-using the same code must fail.
if _, err := a.ConsumeEmailOTP(ctx, "otp@example.com", code); !errors.Is(err, ErrOTPInvalid) {
t.Fatalf("expected ErrOTPInvalid on reuse, got %v", err)
}
}
func TestIntegration_OTPWrongCodeDecrementsAndExhausts(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateUser(ctx, "otp@example.com"); err != nil {
t.Fatalf("CreateUser: %v", err)
}
code, err := a.RequestEmailOTP(ctx, "otp@example.com")
if err != nil {
t.Fatalf("RequestEmailOTP: %v", err)
}
// freshAuth configures EmailOTPMaxAttempts=3.
for i := 0; i < 3; i++ {
if _, err := a.ConsumeEmailOTP(ctx, "otp@example.com", "000000"); !errors.Is(err, ErrOTPInvalid) {
t.Fatalf("attempt %d: expected ErrOTPInvalid, got %v", i, err)
}
}
// After exhausting attempts, even the correct code must fail (the OTP
// row was consumed when attempts hit zero).
if _, err := a.ConsumeEmailOTP(ctx, "otp@example.com", code); !errors.Is(err, ErrOTPInvalid) {
t.Fatalf("expected ErrOTPInvalid after exhausting attempts, got %v", err)
}
}
func TestIntegration_OTPUnknownEmailIsSilent(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
code, err := a.RequestEmailOTP(ctx, "nobody@example.com")
if err != nil {
t.Fatalf("expected silent success, got %v", err)
}
if code != "" {
t.Fatalf("expected empty code, got %q", code)
}
}
func TestIntegration_OTPGeneratorProducesDigitsOnly(t *testing.T) {
for i := 0; i < 50; i++ {
code, err := generateOTPCode(nil, 6)
if err != nil {
t.Fatalf("generateOTPCode: %v", err)
}
if len(code) != 6 {
t.Fatalf("expected 6 chars, got %d (%q)", len(code), code)
}
for _, c := range code {
if c < '0' || c > '9' {
t.Fatalf("non-digit %q in code %q", c, code)
}
}
}
}

View file

@ -8,16 +8,20 @@ import (
)
// RequestPasswordReset mints a single-use password-reset token for the user
// behind email and returns the plaintext for the caller to deliver via email.
// Returns ErrUserNotFound when the email isn't registered (per project
// policy of distinct errors over anti-enumeration).
// behind email and returns the plaintext for delivery.
//
// Default behavior is anti-enumeration: unknown email returns ("", nil).
// Set Config.RevealUnknownEmail = true to surface ErrUserNotFound.
func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, error) {
const op = "authkit.Auth.RequestPasswordReset"
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
if errors.Is(err, ErrUserNotFound) && !a.cfg.RevealUnknownEmail {
return "", nil
}
return "", errx.Wrap(op, err)
}
plaintext, hash, err := mintSecret(prefixPasswordRset, a.cfg.Random)
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixPasswordRset)
if err != nil {
return "", errx.Wrap(op, err)
}
@ -29,40 +33,37 @@ func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string,
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.PasswordResetTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
if err := a.storeCreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// ConfirmPasswordReset consumes the reset token, sets the new password,
// bumps the user's session_version, and revokes outstanding sessions so the
// reset constitutes a global logout.
// bumps session_version, and revokes outstanding sessions so the reset is
// a global logout.
func (a *Auth) ConfirmPasswordReset(ctx context.Context, plaintextToken, newPassword string) error {
const op = "authkit.Auth.ConfirmPasswordReset"
hash, ok := parseSecret(prefixPasswordRset, plaintextToken)
hash, ok := ParseOpaqueSecret(prefixPasswordRset, plaintextToken)
if !ok {
return errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenPasswordReset, hash, now)
t, err := a.storeConsumeToken(ctx, TokenPasswordReset, hash, now)
if err != nil {
return errx.Wrap(op, err)
}
newHash, err := a.deps.Hasher.Hash(newPassword)
newHash, err := a.hasher.Hash(newPassword)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Users.SetPassword(ctx, t.UserID, newHash); err != nil {
if errors.Is(err, ErrUserNotFound) {
return errx.Wrap(op, ErrUserNotFound)
}
if err := a.storeSetPassword(ctx, t.UserID, newHash); err != nil {
return errx.Wrap(op, err)
}
if _, err := a.deps.Users.BumpSessionVersion(ctx, t.UserID); err != nil {
if _, err := a.storeBumpSessionVersion(ctx, t.UserID); err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Sessions.DeleteUserSessions(ctx, t.UserID); err != nil {
if err := a.storeDeleteUserSessions(ctx, t.UserID); err != nil {
return errx.Wrap(op, err)
}
return nil

201
service_seed.go Normal file
View file

@ -0,0 +1,201 @@
package authkit
import (
"context"
"git.juancwu.dev/juancwu/errx"
)
// CreateRole inserts a new role. Slug must be a valid normalized slug;
// returns ErrSlugInvalid otherwise. Returns ErrSlugTaken if the slug is
// already in use.
func (a *Auth) CreateRole(ctx context.Context, slug, label string) (*Role, error) {
const op = "authkit.Auth.CreateRole"
if err := validateSlug(op, slug); err != nil {
return nil, err
}
r := &Role{Slug: slug, Label: label}
if err := a.storeCreateRole(ctx, r); err != nil {
return nil, errx.Wrap(op, err)
}
return r, nil
}
// GetRoleBySlug fetches a role by its slug.
func (a *Auth) GetRoleBySlug(ctx context.Context, slug string) (*Role, error) {
const op = "authkit.Auth.GetRoleBySlug"
r, err := a.storeGetRoleBySlug(ctx, slug)
if err != nil {
return nil, errx.Wrap(op, err)
}
return r, nil
}
// ListRoles returns every role ordered by slug.
func (a *Auth) ListRoles(ctx context.Context) ([]*Role, error) {
const op = "authkit.Auth.ListRoles"
out, err := a.storeListRoles(ctx)
if err != nil {
return nil, errx.Wrap(op, err)
}
return out, nil
}
// DeleteRole removes a role by its slug. Cascades to user_roles and
// role_permissions. Returns ErrRoleNotFound if absent.
func (a *Auth) DeleteRole(ctx context.Context, slug string) error {
const op = "authkit.Auth.DeleteRole"
r, err := a.storeGetRoleBySlug(ctx, slug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeDeleteRole(ctx, r.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// CreatePermission inserts a new permission.
func (a *Auth) CreatePermission(ctx context.Context, slug, label string) (*Permission, error) {
const op = "authkit.Auth.CreatePermission"
if err := validateSlug(op, slug); err != nil {
return nil, err
}
p := &Permission{Slug: slug, Label: label}
if err := a.storeCreatePermission(ctx, p); err != nil {
return nil, errx.Wrap(op, err)
}
return p, nil
}
// GetPermissionBySlug fetches a permission by its slug.
func (a *Auth) GetPermissionBySlug(ctx context.Context, slug string) (*Permission, error) {
const op = "authkit.Auth.GetPermissionBySlug"
p, err := a.storeGetPermissionBySlug(ctx, slug)
if err != nil {
return nil, errx.Wrap(op, err)
}
return p, nil
}
// ListPermissions returns every permission ordered by slug.
func (a *Auth) ListPermissions(ctx context.Context) ([]*Permission, error) {
const op = "authkit.Auth.ListPermissions"
out, err := a.storeListPermissions(ctx)
if err != nil {
return nil, errx.Wrap(op, err)
}
return out, nil
}
// DeletePermission removes a permission by its slug. Cascades to
// role_permissions and user_permissions.
func (a *Auth) DeletePermission(ctx context.Context, slug string) error {
const op = "authkit.Auth.DeletePermission"
p, err := a.storeGetPermissionBySlug(ctx, slug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeDeletePermission(ctx, p.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// CreateAbility inserts a new ability for service tokens.
func (a *Auth) CreateAbility(ctx context.Context, slug, label string) (*Ability, error) {
const op = "authkit.Auth.CreateAbility"
if err := validateSlug(op, slug); err != nil {
return nil, err
}
ab := &Ability{Slug: slug, Label: label}
if err := a.storeCreateAbility(ctx, ab); err != nil {
return nil, errx.Wrap(op, err)
}
return ab, nil
}
// GetAbilityBySlug fetches an ability by its slug.
func (a *Auth) GetAbilityBySlug(ctx context.Context, slug string) (*Ability, error) {
const op = "authkit.Auth.GetAbilityBySlug"
ab, err := a.storeGetAbilityBySlug(ctx, slug)
if err != nil {
return nil, errx.Wrap(op, err)
}
return ab, nil
}
// ListAbilities returns every ability ordered by slug.
func (a *Auth) ListAbilities(ctx context.Context) ([]*Ability, error) {
const op = "authkit.Auth.ListAbilities"
out, err := a.storeListAbilities(ctx)
if err != nil {
return nil, errx.Wrap(op, err)
}
return out, nil
}
// DeleteAbility removes an ability by its slug. Cascades to
// service_key_abilities.
func (a *Auth) DeleteAbility(ctx context.Context, slug string) error {
const op = "authkit.Auth.DeleteAbility"
ab, err := a.storeGetAbilityBySlug(ctx, slug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeDeleteAbility(ctx, ab.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// GrantPermissionToRole adds permSlug to roleSlug's permission set.
// Idempotent.
func (a *Auth) GrantPermissionToRole(ctx context.Context, roleSlug, permSlug string) error {
const op = "authkit.Auth.GrantPermissionToRole"
r, err := a.storeGetRoleBySlug(ctx, roleSlug)
if err != nil {
return errx.Wrap(op, err)
}
p, err := a.storeGetPermissionBySlug(ctx, permSlug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeAssignPermissionToRole(ctx, r.ID, p.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RevokePermissionFromRole removes permSlug from roleSlug's permission set.
// Idempotent.
func (a *Auth) RevokePermissionFromRole(ctx context.Context, roleSlug, permSlug string) error {
const op = "authkit.Auth.RevokePermissionFromRole"
r, err := a.storeGetRoleBySlug(ctx, roleSlug)
if err != nil {
return errx.Wrap(op, err)
}
p, err := a.storeGetPermissionBySlug(ctx, permSlug)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeRemovePermissionFromRole(ctx, r.ID, p.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// ListRolePermissions returns permissions granted to a role through the
// role-permission link only (not direct user grants).
func (a *Auth) ListRolePermissions(ctx context.Context, roleSlug string) ([]*Permission, error) {
const op = "authkit.Auth.ListRolePermissions"
r, err := a.storeGetRoleBySlug(ctx, roleSlug)
if err != nil {
return nil, errx.Wrap(op, err)
}
out, err := a.storeGetRolePermissions(ctx, r.ID)
if err != nil {
return nil, errx.Wrap(op, err)
}
return out, nil
}

136
service_seed_test.go Normal file
View file

@ -0,0 +1,136 @@
package authkit
import (
"context"
"errors"
"testing"
)
func TestIntegration_SeedRolesAndPermissions(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
r, err := a.CreateRole(ctx, "editor", "Editor")
if err != nil {
t.Fatalf("CreateRole: %v", err)
}
if r.Slug != "editor" || r.Label != "Editor" {
t.Fatalf("role round-trip mismatch: %+v", r)
}
if _, err := a.CreateRole(ctx, "editor", ""); !errors.Is(err, ErrSlugTaken) {
t.Fatalf("expected ErrSlugTaken on duplicate, got %v", err)
}
if _, err := a.CreateRole(ctx, "Editor", ""); !errors.Is(err, ErrSlugInvalid) {
t.Fatalf("expected ErrSlugInvalid on uppercase, got %v", err)
}
if _, err := a.CreatePermission(ctx, "posts:write", "Write posts"); err != nil {
t.Fatalf("CreatePermission: %v", err)
}
if err := a.GrantPermissionToRole(ctx, "editor", "posts:write"); err != nil {
t.Fatalf("GrantPermissionToRole: %v", err)
}
roles, err := a.ListRoles(ctx)
if err != nil {
t.Fatalf("ListRoles: %v", err)
}
if len(roles) != 1 || roles[0].Slug != "editor" {
t.Fatalf("ListRoles unexpected: %+v", roles)
}
perms, err := a.ListRolePermissions(ctx, "editor")
if err != nil {
t.Fatalf("ListRolePermissions: %v", err)
}
if len(perms) != 1 || perms[0].Slug != "posts:write" {
t.Fatalf("ListRolePermissions unexpected: %+v", perms)
}
}
func TestIntegration_DirectUserPermissionsUnion(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "u@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
// Permission via role.
if _, err := a.CreateRole(ctx, "editor", ""); err != nil {
t.Fatalf("CreateRole: %v", err)
}
if _, err := a.CreatePermission(ctx, "posts:read", ""); err != nil {
t.Fatalf("CreatePermission posts:read: %v", err)
}
if err := a.GrantPermissionToRole(ctx, "editor", "posts:read"); err != nil {
t.Fatalf("GrantPermissionToRole: %v", err)
}
if err := a.AssignRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("AssignRole: %v", err)
}
// Direct permission grant.
if _, err := a.CreatePermission(ctx, "billing:view", ""); err != nil {
t.Fatalf("CreatePermission billing:view: %v", err)
}
if err := a.GrantPermissionToUser(ctx, u.ID, "billing:view"); err != nil {
t.Fatalf("GrantPermissionToUser: %v", err)
}
got, err := a.UserPermissions(ctx, u.ID)
if err != nil {
t.Fatalf("UserPermissions: %v", err)
}
if len(got) != 2 {
t.Fatalf("expected 2 perms (UNION), got %v", got)
}
want := map[string]bool{"posts:read": true, "billing:view": true}
for _, p := range got {
if !want[p] {
t.Fatalf("unexpected permission %q", p)
}
delete(want, p)
}
if len(want) != 0 {
t.Fatalf("missing permissions: %v", want)
}
// Revoke direct grant; only role-derived remains.
if err := a.RevokePermissionFromUser(ctx, u.ID, "billing:view"); err != nil {
t.Fatalf("RevokePermissionFromUser: %v", err)
}
got2, err := a.UserPermissions(ctx, u.ID)
if err != nil {
t.Fatalf("UserPermissions post-revoke: %v", err)
}
if len(got2) != 1 || got2[0] != "posts:read" {
t.Fatalf("expected only posts:read after revoke, got %v", got2)
}
}
func TestIntegration_RolePermissionMembershipQueries(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "h@h.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if _, err := a.CreateRole(ctx, "manager", ""); err != nil {
t.Fatalf("CreateRole: %v", err)
}
if err := a.AssignRole(ctx, u.ID, "manager"); err != nil {
t.Fatalf("AssignRole: %v", err)
}
ok, err := a.HasRole(ctx, u.ID, "manager")
if err != nil || !ok {
t.Fatalf("HasRole: ok=%v err=%v", ok, err)
}
ok, err = a.HasAnyRole(ctx, u.ID, []string{"admin", "manager"})
if err != nil || !ok {
t.Fatalf("HasAnyRole: ok=%v err=%v", ok, err)
}
ok, err = a.HasAnyRole(ctx, u.ID, []string{"admin", "ops"})
if err != nil || ok {
t.Fatalf("HasAnyRole should be false: ok=%v err=%v", ok, err)
}
}

View file

@ -2,19 +2,50 @@ package authkit
import (
"context"
"errors"
"time"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
// IssueServiceKey mints a fresh owner-agnostic service token. ownerKind is a
// consumer-defined namespace label (e.g. "application", "tenant") and ownerID
// is the owning entity's id; authkit makes no assumption about either. The
// plaintext is returned (show-once) and the SHA-256 lookup hash is stored.
// Pass ttl=nil for a non-expiring key.
func (a *Auth) IssueServiceKey(ctx context.Context, ownerKind string, ownerID uuid.UUID, name string, abilities []string, ttl *time.Duration) (string, *ServiceKey, error) {
// IssueServiceKeyParams is the input shape for IssueServiceKey. Abilities
// are slugs that must already exist in authkit_abilities — issue fails with
// ErrAbilityNotFound if any slug is unknown. TTL is optional; nil means
// non-expiring.
type IssueServiceKeyParams struct {
Name string
Abilities []string
TTL *time.Duration
}
// IssueServiceKey mints a fresh service token. Plaintext is returned exactly
// once (show-once); only the SHA-256 hash is persisted. Each ability slug is
// resolved to its row before insertion, so the service key carries a
// well-defined set of abilities even after later slug renames or deletes.
func (a *Auth) IssueServiceKey(ctx context.Context, params IssueServiceKeyParams) (string, *ServiceKey, error) {
const op = "authkit.Auth.IssueServiceKey"
if params.Name == "" {
return "", nil, errx.New(op, "Name is required")
}
abilityIDs := make([]uuid.UUID, 0, len(params.Abilities))
resolved := make([]string, 0, len(params.Abilities))
seen := map[string]struct{}{}
for _, slug := range params.Abilities {
if _, dup := seen[slug]; dup {
continue
}
seen[slug] = struct{}{}
ab, err := a.storeGetAbilityBySlug(ctx, slug)
if err != nil {
if errors.Is(err, ErrAbilityNotFound) {
return "", nil, errx.Wrapf(op, ErrAbilityNotFound, "ability %q is not registered", slug)
}
return "", nil, errx.Wrap(op, err)
}
abilityIDs = append(abilityIDs, ab.ID)
resolved = append(resolved, ab.Slug)
}
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixServiceKey)
if err != nil {
return "", nil, errx.Wrap(op, err)
@ -22,34 +53,29 @@ func (a *Auth) IssueServiceKey(ctx context.Context, ownerKind string, ownerID uu
now := a.now()
k := &ServiceKey{
IDHash: hash,
OwnerID: ownerID,
OwnerKind: ownerKind,
Name: name,
Abilities: append([]string(nil), abilities...),
Name: params.Name,
Abilities: resolved,
CreatedAt: now,
}
if ttl != nil {
exp := now.Add(*ttl)
if params.TTL != nil {
exp := now.Add(*params.TTL)
k.ExpiresAt = &exp
}
if err := a.deps.ServiceKeys.CreateServiceKey(ctx, k); err != nil {
if err := a.storeCreateServiceKey(ctx, k, abilityIDs); err != nil {
return "", nil, errx.Wrap(op, err)
}
return plaintext, k, nil
}
// AuthenticateServiceKey validates a service token, touches last_used_at
// (best-effort), and returns the stored *ServiceKey. Unlike API keys, no
// Principal is returned — service tokens have no owning user, so the
// Principal abstraction does not fit. Consumers needing a Principal can
// build one from the returned key.
// AuthenticateServiceKey validates a service token and returns the stored
// *ServiceKey with its abilities resolved.
func (a *Auth) AuthenticateServiceKey(ctx context.Context, plaintext string) (*ServiceKey, error) {
const op = "authkit.Auth.AuthenticateServiceKey"
hash, ok := ParseOpaqueSecret(prefixServiceKey, plaintext)
if !ok {
return nil, errx.Wrap(op, ErrServiceKeyInvalid)
}
k, err := a.deps.ServiceKeys.GetServiceKey(ctx, hash)
k, err := a.storeGetServiceKey(ctx, hash)
if err != nil {
return nil, errx.Wrap(op, err)
}
@ -60,10 +86,8 @@ func (a *Auth) AuthenticateServiceKey(ctx context.Context, plaintext string) (*S
if k.ExpiresAt != nil && !k.ExpiresAt.After(now) {
return nil, errx.Wrap(op, ErrServiceKeyInvalid)
}
_ = a.deps.ServiceKeys.TouchServiceKey(ctx, hash, now)
out := *k
out.Abilities = append([]string(nil), k.Abilities...)
return &out, nil
_ = a.storeTouchServiceKey(ctx, hash, now)
return k, nil
}
// RevokeServiceKey marks a service token revoked. Idempotent on
@ -74,17 +98,17 @@ func (a *Auth) RevokeServiceKey(ctx context.Context, plaintext string) error {
if !ok {
return errx.Wrap(op, ErrServiceKeyInvalid)
}
if err := a.deps.ServiceKeys.RevokeServiceKey(ctx, hash, a.now()); err != nil {
if err := a.storeRevokeServiceKey(ctx, hash, a.now()); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// ListServiceKeys returns every service token issued for the given
// (ownerKind, ownerID) pair, including revoked and expired keys.
func (a *Auth) ListServiceKeys(ctx context.Context, ownerKind string, ownerID uuid.UUID) ([]*ServiceKey, error) {
// ListServiceKeys returns every service token, including revoked and
// expired ones, ordered by creation time descending.
func (a *Auth) ListServiceKeys(ctx context.Context) ([]*ServiceKey, error) {
const op = "authkit.Auth.ListServiceKeys"
out, err := a.deps.ServiceKeys.ListServiceKeysByOwner(ctx, ownerKind, ownerID)
out, err := a.storeListServiceKeys(ctx)
if err != nil {
return nil, errx.Wrap(op, err)
}

View file

@ -2,183 +2,123 @@ package authkit
import (
"context"
"encoding/base64"
"errors"
"strings"
"testing"
"time"
"github.com/google/uuid"
)
func TestServiceKeyRoundtrip(t *testing.T) {
a := newTestAuth(t)
appID := uuid.New()
plaintext, k, err := a.IssueServiceKey(context.Background(),
"application", appID, "events-ingest",
[]string{"events:write", "events:read"}, nil)
func TestIntegration_ServiceKeyRoundtrip(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateAbility(ctx, "events:write", "Write events"); err != nil {
t.Fatalf("CreateAbility events:write: %v", err)
}
if _, err := a.CreateAbility(ctx, "events:read", "Read events"); err != nil {
t.Fatalf("CreateAbility events:read: %v", err)
}
plain, k, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{
Name: "events-ingest",
Abilities: []string{"events:write", "events:read"},
})
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
if plaintext == "" || k == nil {
t.Fatalf("missing plaintext or key")
if !strings.HasPrefix(plain, "sk_") {
t.Fatalf("plaintext should start with sk_: %q", plain)
}
got, err := a.AuthenticateServiceKey(context.Background(), plaintext)
if k.Name != "events-ingest" {
t.Fatalf("name mismatch: %q", k.Name)
}
if len(k.Abilities) != 2 {
t.Fatalf("expected 2 abilities, got %v", k.Abilities)
}
got, err := a.AuthenticateServiceKey(ctx, plain)
if err != nil {
t.Fatalf("AuthenticateServiceKey: %v", err)
}
if got.OwnerKind != "application" || got.OwnerID != appID {
t.Fatalf("owner mismatch: kind=%q id=%v", got.OwnerKind, got.OwnerID)
}
if got.Name != "events-ingest" {
t.Fatalf("name mismatch: %q", got.Name)
}
if len(got.Abilities) != 2 || got.Abilities[0] != "events:write" || got.Abilities[1] != "events:read" {
t.Fatalf("abilities mismatch: %+v", got.Abilities)
}
got.Abilities[0] = "tampered"
again, err := a.AuthenticateServiceKey(context.Background(), plaintext)
if err != nil {
t.Fatalf("AuthenticateServiceKey (re-auth): %v", err)
}
if again.Abilities[0] != "events:write" {
t.Fatalf("returned slice was not deep-copied; saw mutation: %+v", again.Abilities)
if !got.HasAbility("events:write") || !got.HasAbility("events:read") {
t.Fatalf("missing expected abilities: %+v", got.Abilities)
}
}
func TestServiceKeyPlaintextShape(t *testing.T) {
a := newTestAuth(t)
plaintext, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "name", nil, nil)
func TestIntegration_ServiceKeyRejectsUnknownAbility(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
_, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{
Name: "x",
Abilities: []string{"never-registered"},
})
if !errors.Is(err, ErrAbilityNotFound) {
t.Fatalf("expected ErrAbilityNotFound, got %v", err)
}
}
func TestIntegration_ServiceKeyRevoke(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateAbility(ctx, "ops", ""); err != nil {
t.Fatalf("CreateAbility: %v", err)
}
plain, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{
Name: "ci",
Abilities: []string{"ops"},
})
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
if !strings.HasPrefix(plaintext, "sk_") {
t.Fatalf("plaintext missing sk_ prefix: %q", plaintext)
}
body := strings.TrimPrefix(plaintext, "sk_")
raw, err := base64.RawURLEncoding.DecodeString(body)
if err != nil {
t.Fatalf("base64 decode: %v", err)
}
if len(raw) != 32 {
t.Fatalf("body decoded to %d bytes, want 32", len(raw))
}
}
func TestServiceKeyWrongPrefix(t *testing.T) {
a := newTestAuth(t)
_, err := a.AuthenticateServiceKey(context.Background(), "ak_not-a-service-key")
if !errors.Is(err, ErrServiceKeyInvalid) {
t.Fatalf("expected ErrServiceKeyInvalid for wrong prefix, got %v", err)
}
}
func TestServiceKeyAfterRevoke(t *testing.T) {
a := newTestAuth(t)
plaintext, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "ci", nil, nil)
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
if err := a.RevokeServiceKey(context.Background(), plaintext); err != nil {
if err := a.RevokeServiceKey(ctx, plain); err != nil {
t.Fatalf("RevokeServiceKey: %v", err)
}
if _, err := a.AuthenticateServiceKey(context.Background(), plaintext); !errors.Is(err, ErrServiceKeyInvalid) {
if _, err := a.AuthenticateServiceKey(ctx, plain); !errors.Is(err, ErrServiceKeyInvalid) {
t.Fatalf("expected ErrServiceKeyInvalid post-revoke, got %v", err)
}
}
func TestServiceKeyAfterExpiry(t *testing.T) {
a := newTestAuth(t)
func TestIntegration_ServiceKeyExpiry(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateAbility(ctx, "ops", ""); err != nil {
t.Fatalf("CreateAbility: %v", err)
}
now := time.Now().UTC()
a.cfg.Clock = func() time.Time { return now }
ttl := time.Minute
plaintext, _, err := a.IssueServiceKey(context.Background(),
"application", uuid.New(), "ephemeral", nil, &ttl)
plain, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{
Name: "ephemeral",
Abilities: []string{"ops"},
TTL: &ttl,
})
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
a.cfg.Clock = func() time.Time { return now.Add(2 * time.Minute) }
if _, err := a.AuthenticateServiceKey(context.Background(), plaintext); !errors.Is(err, ErrServiceKeyInvalid) {
if _, err := a.AuthenticateServiceKey(ctx, plain); !errors.Is(err, ErrServiceKeyInvalid) {
t.Fatalf("expected ErrServiceKeyInvalid post-expiry, got %v", err)
}
}
func TestServiceKeyListByOwner(t *testing.T) {
a := newTestAuth(t)
appA := uuid.New()
appB := uuid.New()
for i := 0; i < 2; i++ {
if _, _, err := a.IssueServiceKey(context.Background(), "application", appA, "k", nil, nil); err != nil {
t.Fatalf("Issue appA #%d: %v", i, err)
func TestIntegration_ServiceKeyList(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateAbility(ctx, "ops", ""); err != nil {
t.Fatalf("CreateAbility: %v", err)
}
for i := 0; i < 3; i++ {
if _, _, err := a.IssueServiceKey(ctx, IssueServiceKeyParams{
Name: "k",
Abilities: []string{"ops"},
}); err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
}
if _, _, err := a.IssueServiceKey(context.Background(), "application", appB, "k", nil, nil); err != nil {
t.Fatalf("Issue appB: %v", err)
}
gotA, err := a.ListServiceKeys(context.Background(), "application", appA)
out, err := a.ListServiceKeys(ctx)
if err != nil {
t.Fatalf("ListServiceKeys appA: %v", err)
t.Fatalf("ListServiceKeys: %v", err)
}
if len(gotA) != 2 {
t.Fatalf("ListServiceKeys appA = %d keys, want 2", len(gotA))
}
gotB, err := a.ListServiceKeys(context.Background(), "application", appB)
if err != nil {
t.Fatalf("ListServiceKeys appB: %v", err)
}
if len(gotB) != 1 {
t.Fatalf("ListServiceKeys appB = %d keys, want 1", len(gotB))
}
gotTenantA, err := a.ListServiceKeys(context.Background(), "tenant", appA)
if err != nil {
t.Fatalf("ListServiceKeys tenant/appA: %v", err)
}
if len(gotTenantA) != 0 {
t.Fatalf("ListServiceKeys tenant/appA = %d, want 0 (different owner_kind)", len(gotTenantA))
}
}
func TestServiceKeyHasAbility(t *testing.T) {
k := &ServiceKey{Abilities: []string{"events:write", "events:read"}}
if !k.HasAbility("events:write") {
t.Fatalf("expected HasAbility(events:write) = true")
}
if !k.HasAbility("events:read") {
t.Fatalf("expected HasAbility(events:read) = true")
}
if k.HasAbility("admin:nuke") {
t.Fatalf("expected HasAbility(admin:nuke) = false")
}
empty := &ServiceKey{}
if empty.HasAbility("anything") {
t.Fatalf("HasAbility on empty Abilities must be false")
}
}
func TestServiceKeyTouchUpdatesLastUsedAt(t *testing.T) {
a := newTestAuth(t)
appID := uuid.New()
plaintext, _, err := a.IssueServiceKey(context.Background(), "application", appID, "k", nil, nil)
if err != nil {
t.Fatalf("IssueServiceKey: %v", err)
}
keys, err := a.ListServiceKeys(context.Background(), "application", appID)
if err != nil || len(keys) != 1 {
t.Fatalf("pre-touch list: err=%v len=%d", err, len(keys))
}
if keys[0].LastUsedAt != nil {
t.Fatalf("expected LastUsedAt=nil before authenticate, got %v", *keys[0].LastUsedAt)
}
if _, err := a.AuthenticateServiceKey(context.Background(), plaintext); err != nil {
t.Fatalf("AuthenticateServiceKey: %v", err)
}
keys, err = a.ListServiceKeys(context.Background(), "application", appID)
if err != nil || len(keys) != 1 {
t.Fatalf("post-touch list: err=%v len=%d", err, len(keys))
}
if keys[0].LastUsedAt == nil {
t.Fatalf("expected LastUsedAt to be set after authenticate")
if len(out) != 3 {
t.Fatalf("expected 3 keys, got %d", len(out))
}
}

View file

@ -14,7 +14,7 @@ import (
// returns the plaintext (for the cookie) plus the stored Session.
func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent string, ip netip.Addr) (string, *Session, error) {
const op = "authkit.Auth.IssueSession"
plaintext, hash, err := mintSecret(prefixSession, a.cfg.Random)
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixSession)
if err != nil {
return "", nil, errx.Wrap(op, err)
}
@ -32,38 +32,35 @@ func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent str
LastSeenAt: now,
ExpiresAt: expires,
}
if err := a.deps.Sessions.CreateSession(ctx, s); err != nil {
if err := a.storeCreateSession(ctx, s); err != nil {
return "", nil, errx.Wrap(op, err)
}
return plaintext, s, nil
}
// AuthenticateSession validates an opaque session string, slides the TTL,
// resolves the user's roles+permissions, and returns a Principal. Expired or
// unknown sessions return ErrSessionInvalid.
// resolves the user's roles+permissions, and returns a Principal.
func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Principal, error) {
const op = "authkit.Auth.AuthenticateSession"
hash, ok := parseSecret(prefixSession, plaintext)
hash, ok := ParseOpaqueSecret(prefixSession, plaintext)
if !ok {
return nil, errx.Wrap(op, ErrSessionInvalid)
}
s, err := a.deps.Sessions.GetSession(ctx, hash)
s, err := a.storeGetSession(ctx, hash)
if err != nil {
return nil, errx.Wrap(op, err)
}
now := a.now()
if !s.ExpiresAt.After(now) {
_ = a.deps.Sessions.DeleteSession(ctx, hash)
_ = a.storeDeleteSession(ctx, hash)
return nil, errx.Wrap(op, ErrSessionInvalid)
}
// Slide the idle TTL, capped at created_at + AbsoluteTTL so an active
// session still expires at the absolute boundary.
newExpires := now.Add(a.cfg.SessionIdleTTL)
if cap := s.CreatedAt.Add(a.cfg.SessionAbsoluteTTL); newExpires.After(cap) {
newExpires = cap
}
if err := a.deps.Sessions.TouchSession(ctx, hash, now, newExpires); err != nil {
if err := a.storeTouchSession(ctx, hash, now, newExpires); err != nil {
return nil, errx.Wrap(op, err)
}
@ -82,73 +79,85 @@ func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Prin
}, nil
}
// RevokeSession deletes a single session by its plaintext id. Idempotent:
// missing sessions are not an error (logout twice should not 500).
// RevokeSession deletes a single session by its plaintext id. Idempotent
// missing sessions are not an error.
func (a *Auth) RevokeSession(ctx context.Context, plaintext string) error {
const op = "authkit.Auth.RevokeSession"
hash, ok := parseSecret(prefixSession, plaintext)
hash, ok := ParseOpaqueSecret(prefixSession, plaintext)
if !ok {
return nil
}
if err := a.deps.Sessions.DeleteSession(ctx, hash); err != nil {
if err := a.storeDeleteSession(ctx, hash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RevokeAllUserSessions kills every active session for the user and bumps
// the user's session_version (invalidating outstanding JWT access tokens).
// session_version (invalidating outstanding JWT access tokens).
func (a *Auth) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.Auth.RevokeAllUserSessions"
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
if err := a.storeDeleteUserSessions(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
if _, err := a.storeBumpSessionVersion(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// SessionCookie builds an *http.Cookie pre-configured from Config. Pass the
// plaintext returned by IssueSession; pass the matching ExpiresAt from the
// returned *Session as `expires`. To clear a cookie at logout, pass an empty
// plaintext and a past expiry.
// plaintext returned by IssueSession and the matching ExpiresAt from the
// returned *Session.
func (a *Auth) SessionCookie(plaintext string, expires time.Time) *http.Cookie {
c := &http.Cookie{
return &http.Cookie{
Name: a.cfg.SessionCookieName,
Value: plaintext,
Path: a.cfg.SessionCookiePath,
Domain: a.cfg.SessionCookieDomain,
Secure: a.cfg.SessionCookieSecure,
HttpOnly: a.cfg.SessionCookieHTTPOnly,
Secure: *a.cfg.SessionCookieSecure,
HttpOnly: *a.cfg.SessionCookieHTTPOnly,
SameSite: a.cfg.SessionCookieSameSite,
Expires: expires,
}
if plaintext == "" {
c.MaxAge = -1
}
return c
}
// resolveRolesAndPermissions fetches the user's role names and the union of
// their permission names. Both are returned as flat string slices for cheap
// containment checks on the Principal.
func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) {
roles, err := a.deps.Roles.GetUserRoles(ctx, userID)
if err != nil {
return nil, nil, err
// ClearSessionCookie returns a cookie that, when set on the response, tells
// the browser to delete the session cookie. Use on logout.
func (a *Auth) ClearSessionCookie() *http.Cookie {
return &http.Cookie{
Name: a.cfg.SessionCookieName,
Value: "",
Path: a.cfg.SessionCookiePath,
Domain: a.cfg.SessionCookieDomain,
Secure: *a.cfg.SessionCookieSecure,
HttpOnly: *a.cfg.SessionCookieHTTPOnly,
SameSite: a.cfg.SessionCookieSameSite,
MaxAge: -1,
}
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
if err != nil {
return nil, nil, err
}
rNames := make([]string, len(roles))
for i, r := range roles {
rNames[i] = r.Name
}
pNames := make([]string, len(perms))
for i, p := range perms {
pNames[i] = p.Name
}
return rNames, pNames, nil
}
// SessionCookieName returns the configured cookie name. Useful for callers
// wiring extractors without reaching into Config.
func (a *Auth) SessionCookieName() string { return a.cfg.SessionCookieName }
// resolveRolesAndPermissions fetches the user's role and permission slugs.
func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) {
roles, err := a.storeGetUserRoles(ctx, userID)
if err != nil {
return nil, nil, err
}
perms, err := a.storeGetUserPermissions(ctx, userID)
if err != nil {
return nil, nil, err
}
rSlugs := make([]string, len(roles))
for i, r := range roles {
rSlugs[i] = r.Slug
}
pSlugs := make([]string, len(perms))
for i, p := range perms {
pSlugs[i] = p.Slug
}
return rSlugs, pSlugs, nil
}

81
service_session_test.go Normal file
View file

@ -0,0 +1,81 @@
package authkit
import (
"context"
"errors"
"testing"
"time"
)
func TestIntegration_SessionLifecycle(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "s@s.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
plain, sess, err := a.IssueSession(ctx, u.ID, "ua", noIP())
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
if sess.ExpiresAt.Before(time.Now()) {
t.Fatalf("session already expired at issue")
}
p, err := a.AuthenticateSession(ctx, plain)
if err != nil {
t.Fatalf("AuthenticateSession: %v", err)
}
if p.UserID != u.ID {
t.Fatalf("principal user id mismatch")
}
if p.Method != AuthMethodSession {
t.Fatalf("method = %s, want session", p.Method)
}
if err := a.RevokeSession(ctx, plain); err != nil {
t.Fatalf("RevokeSession: %v", err)
}
if _, err := a.AuthenticateSession(ctx, plain); !errors.Is(err, ErrSessionInvalid) {
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
}
}
func TestIntegration_SessionCookieDefaultsSecure(t *testing.T) {
a := freshAuth(t)
c := a.SessionCookie("plaintext", time.Now().Add(time.Hour))
if !c.Secure {
t.Fatalf("Secure should default to true")
}
if !c.HttpOnly {
t.Fatalf("HttpOnly should default to true")
}
clear := a.ClearSessionCookie()
if clear.MaxAge != -1 || clear.Value != "" {
t.Fatalf("ClearSessionCookie should be MaxAge=-1 and Value=\"\"")
}
}
func TestIntegration_RevokeAllSessions(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "ra@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
plain, _, err := a.IssueSession(ctx, u.ID, "ua", noIP())
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
access, _, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
if err := a.RevokeAllUserSessions(ctx, u.ID); err != nil {
t.Fatalf("RevokeAllUserSessions: %v", err)
}
if _, err := a.AuthenticateSession(ctx, plain); !errors.Is(err, ErrSessionInvalid) {
t.Fatalf("session should be revoked, got %v", err)
}
if _, err := a.AuthenticateJWT(ctx, access); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("JWT should be invalidated by session_version bump, got %v", err)
}
}

View file

@ -1,189 +0,0 @@
package authkit
import (
"context"
"errors"
"net/netip"
"testing"
)
func TestRegisterAndLogin(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "Alice@Example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
if u.EmailNormalized != "alice@example.com" {
t.Fatalf("email_normalized = %q", u.EmailNormalized)
}
got, err := a.LoginPassword(context.Background(), "alice@example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("LoginPassword: %v", err)
}
if got.ID != u.ID {
t.Fatalf("login user mismatch")
}
}
func TestRegisterDuplicateEmail(t *testing.T) {
a := newTestAuth(t)
if _, err := a.Register(context.Background(), "x@y.com", "abc"); err != nil {
t.Fatalf("Register: %v", err)
}
_, err := a.Register(context.Background(), "X@Y.COM", "abc")
if !errors.Is(err, ErrEmailTaken) {
t.Fatalf("expected ErrEmailTaken, got %v", err)
}
}
func TestLoginWrongPassword(t *testing.T) {
a := newTestAuth(t)
if _, err := a.Register(context.Background(), "p@q.com", "right"); err != nil {
t.Fatalf("Register: %v", err)
}
if _, err := a.LoginPassword(context.Background(), "p@q.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
}
}
func TestSessionIssueAuthenticateRevoke(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "s@s.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
plain, sess, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
if sess == nil || plain == "" {
t.Fatalf("missing session or plaintext")
}
p, err := a.AuthenticateSession(context.Background(), plain)
if err != nil {
t.Fatalf("AuthenticateSession: %v", err)
}
if p.UserID != u.ID {
t.Fatalf("principal user id mismatch")
}
if err := a.RevokeSession(context.Background(), plain); err != nil {
t.Fatalf("RevokeSession: %v", err)
}
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
}
}
func TestEmailVerificationFlow(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "ev@e.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
tok, err := a.RequestEmailVerification(context.Background(), u.ID)
if err != nil {
t.Fatalf("RequestEmailVerification: %v", err)
}
confirmed, err := a.ConfirmEmail(context.Background(), tok)
if err != nil {
t.Fatalf("ConfirmEmail: %v", err)
}
if confirmed.EmailVerifiedAt == nil {
t.Fatalf("email_verified_at not set")
}
// Re-using the token must fail.
if _, err := a.ConfirmEmail(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err)
}
}
func TestPasswordResetFlow(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "r@r.com", "old")
if err != nil {
t.Fatalf("Register: %v", err)
}
// Issue a session that should be invalidated by the reset.
plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
tok, err := a.RequestPasswordReset(context.Background(), "r@r.com")
if err != nil {
t.Fatalf("RequestPasswordReset: %v", err)
}
if err := a.ConfirmPasswordReset(context.Background(), tok, "new"); err != nil {
t.Fatalf("ConfirmPasswordReset: %v", err)
}
if _, err := a.LoginPassword(context.Background(), "r@r.com", "old"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("old password should fail post-reset, got %v", err)
}
if _, err := a.LoginPassword(context.Background(), "r@r.com", "new"); err != nil {
t.Fatalf("new password should work post-reset, got %v", err)
}
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
t.Fatalf("session should be invalidated by reset, got %v", err)
}
}
func TestMagicLinkFlow(t *testing.T) {
a := newTestAuth(t)
if _, err := a.Register(context.Background(), "m@m.com", "pw"); err != nil {
t.Fatalf("Register: %v", err)
}
tok, err := a.RequestMagicLink(context.Background(), "m@m.com")
if err != nil {
t.Fatalf("RequestMagicLink: %v", err)
}
u, err := a.ConsumeMagicLink(context.Background(), tok)
if err != nil {
t.Fatalf("ConsumeMagicLink: %v", err)
}
if u.EmailVerifiedAt == nil {
t.Fatalf("magic link should imply email verification")
}
if _, err := a.ConsumeMagicLink(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on magic link reuse, got %v", err)
}
}
func TestRBACRolesAndPermissions(t *testing.T) {
ctx := context.Background()
a := newTestAuth(t)
u, err := a.Register(ctx, "rb@a.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
// Create role + permission, hook them up.
role := &Role{Name: "editor"}
if err := a.deps.Roles.CreateRole(ctx, role); err != nil {
t.Fatalf("CreateRole: %v", err)
}
perm := &Permission{Name: "posts:write"}
if err := a.deps.Permissions.CreatePermission(ctx, perm); err != nil {
t.Fatalf("CreatePermission: %v", err)
}
if err := a.deps.Permissions.AssignPermissionToRole(ctx, role.ID, perm.ID); err != nil {
t.Fatalf("AssignPermissionToRole: %v", err)
}
if err := a.AssignRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("AssignRole: %v", err)
}
ok, err := a.HasPermission(ctx, u.ID, "posts:write")
if err != nil || !ok {
t.Fatalf("HasPermission posts:write should be true, got %v %v", ok, err)
}
ok, err = a.HasRole(ctx, u.ID, "editor")
if err != nil || !ok {
t.Fatalf("HasRole editor should be true, got %v %v", ok, err)
}
if err := a.RemoveRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("RemoveRole: %v", err)
}
ok, _ = a.HasPermission(ctx, u.ID, "posts:write")
if ok {
t.Fatalf("HasPermission should be false after RemoveRole")
}
}

View file

@ -3,53 +3,61 @@ package authkit
import (
"context"
"errors"
"strings"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
// normalizeEmail produces the lookup form used by UserStore.GetUserByEmail
// and the email_normalized column. Trim + lowercase is intentional; we do
// not collapse Gmail-style "+" addressing or strip dots — that's a policy
// decision callers can layer on top.
func normalizeEmail(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}
// Register creates a new user with an Argon2id-hashed password. Returns
// ErrEmailTaken if the normalized email is already registered.
func (a *Auth) Register(ctx context.Context, email, password string) (*User, error) {
const op = "authkit.Auth.Register"
if email == "" || password == "" {
// CreateUser registers a new account with the given email. Password is
// optional — accounts can be created without a credential and have one set
// later via SetPassword. Returns ErrEmailTaken if the normalized email is
// already registered.
func (a *Auth) CreateUser(ctx context.Context, email string) (*User, error) {
const op = "authkit.Auth.CreateUser"
if email == "" {
return nil, errx.Wrap(op, ErrInvalidCredentials)
}
hash, err := a.deps.Hasher.Hash(password)
if err != nil {
return nil, errx.Wrap(op, err)
}
now := a.now()
u := &User{
ID: uuid.New(),
Email: email,
EmailNormalized: normalizeEmail(email),
PasswordHash: hash,
CreatedAt: now,
UpdatedAt: now,
}
if err := a.deps.Users.CreateUser(ctx, u); err != nil {
if err := a.storeCreateUser(ctx, u); err != nil {
return nil, errx.Wrap(op, err)
}
return u, nil
}
// SetPassword stores a password hash for the user. Use for the
// initial-credential flow and for administrative password changes.
// Bumping session_version is the caller's responsibility; SetPassword does
// not invalidate existing sessions on its own. ChangePassword is the
// safer wrapper for end-user-driven changes.
func (a *Auth) SetPassword(ctx context.Context, userID uuid.UUID, password string) error {
const op = "authkit.Auth.SetPassword"
if password == "" {
return errx.Wrap(op, ErrInvalidCredentials)
}
hash, err := a.hasher.Hash(password)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.storeSetPassword(ctx, userID, hash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// LoginPassword verifies the password and returns the authenticated user.
// Failure increments failed_logins; success resets it and stamps last_login_at.
// LoginHook (if configured) is invoked with the success outcome — use this to
// hook in rate limiting or audit logging.
// Failure does not increment any counter — consumers wanting lockout should
// implement it via LoginHook (see README). Success resets nothing and stamps
// last_login_at. LoginHook is invoked with the success outcome.
func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User, error) {
const op = "authkit.Auth.LoginPassword"
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
_ = a.fireLoginHook(ctx, email, false)
if errors.Is(err, ErrUserNotFound) {
@ -58,32 +66,29 @@ func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User
return nil, errx.Wrap(op, err)
}
if u.PasswordHash == "" {
// Password-less account (invite-only / magic-link-only). Treat the
// same as wrong password to avoid leaking account state.
_ = a.fireLoginHook(ctx, email, false)
return nil, errx.Wrap(op, ErrInvalidCredentials)
}
ok, needsRehash, err := a.deps.Hasher.Verify(password, u.PasswordHash)
ok, needsRehash, err := a.hasher.Verify(password, u.PasswordHash)
if err != nil {
return nil, errx.Wrap(op, err)
}
if !ok {
_, _ = a.deps.Users.IncrementFailedLogins(ctx, u.ID)
_ = a.fireLoginHook(ctx, email, false)
return nil, errx.Wrap(op, ErrInvalidCredentials)
}
now := a.now()
u.LastLoginAt = &now
u.FailedLogins = 0
if err := a.deps.Users.ResetFailedLogins(ctx, u.ID); err != nil {
return nil, errx.Wrap(op, err)
}
if err := a.deps.Users.UpdateUser(ctx, u); err != nil {
if err := a.storeUpdateUser(ctx, u); err != nil {
return nil, errx.Wrap(op, err)
}
if needsRehash {
if newHash, herr := a.deps.Hasher.Hash(password); herr == nil {
_ = a.deps.Users.SetPassword(ctx, u.ID, newHash)
if newHash, herr := a.hasher.Hash(password); herr == nil {
_ = a.storeSetPassword(ctx, u.ID, newHash)
u.PasswordHash = newHash
}
}
@ -92,46 +97,76 @@ func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User
}
// ChangePassword verifies the current password, sets the new one, and bumps
// the user's session_version so all outstanding JWT access tokens are
// instantly invalidated. Outstanding opaque sessions are also revoked.
// the user's session_version (invalidating outstanding JWTs). Outstanding
// opaque sessions are also revoked.
func (a *Auth) ChangePassword(ctx context.Context, userID uuid.UUID, oldPassword, newPassword string) error {
const op = "authkit.Auth.ChangePassword"
u, err := a.deps.Users.GetUserByID(ctx, userID)
u, err := a.storeGetUserByID(ctx, userID)
if err != nil {
return errx.Wrap(op, err)
}
if u.PasswordHash == "" {
return errx.Wrap(op, ErrInvalidCredentials)
}
ok, _, err := a.deps.Hasher.Verify(oldPassword, u.PasswordHash)
ok, _, err := a.hasher.Verify(oldPassword, u.PasswordHash)
if err != nil {
return errx.Wrap(op, err)
}
if !ok {
return errx.Wrap(op, ErrInvalidCredentials)
}
newHash, err := a.deps.Hasher.Hash(newPassword)
newHash, err := a.hasher.Hash(newPassword)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Users.SetPassword(ctx, userID, newHash); err != nil {
if err := a.storeSetPassword(ctx, userID, newHash); err != nil {
return errx.Wrap(op, err)
}
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
if _, err := a.storeBumpSessionVersion(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
if err := a.storeDeleteUserSessions(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// GetUser fetches the user by ID. Returns ErrUserNotFound if absent.
func (a *Auth) GetUser(ctx context.Context, userID uuid.UUID) (*User, error) {
const op = "authkit.Auth.GetUser"
u, err := a.storeGetUserByID(ctx, userID)
if err != nil {
return nil, errx.Wrap(op, err)
}
return u, nil
}
// GetUserByEmail fetches the user by email (input is normalized internally).
func (a *Auth) GetUserByEmail(ctx context.Context, email string) (*User, error) {
const op = "authkit.Auth.GetUserByEmail"
u, err := a.storeGetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
return nil, errx.Wrap(op, err)
}
return u, nil
}
// DeleteUser removes the user. Cascades to sessions, tokens, role
// assignments, and direct permission grants via FK ON DELETE CASCADE.
func (a *Auth) DeleteUser(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.Auth.DeleteUser"
if err := a.storeDeleteUser(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RequestEmailVerification mints a single-use email-verify token for the
// user. Return the plaintext to the caller so they can put it in an email
// link; the lookup hash is stored in TokenStore.
// user. Return the plaintext to the caller for delivery; the lookup hash
// is what's stored.
func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (string, error) {
const op = "authkit.Auth.RequestEmailVerification"
plaintext, hash, err := mintSecret(prefixEmailVerify, a.cfg.Random)
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixEmailVerify)
if err != nil {
return "", errx.Wrap(op, err)
}
@ -143,36 +178,42 @@ func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.EmailVerifyTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
if err := a.storeCreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// ConfirmEmail consumes the verification token and marks the user's email
// verified. Returns ErrTokenInvalid if the token is missing/expired/used.
// ConfirmEmail consumes a verification token and marks the user's email
// verified. Returns ErrTokenInvalid for missing/expired/already-used tokens.
func (a *Auth) ConfirmEmail(ctx context.Context, plaintextToken string) (*User, error) {
const op = "authkit.Auth.ConfirmEmail"
hash, ok := parseSecret(prefixEmailVerify, plaintextToken)
hash, ok := ParseOpaqueSecret(prefixEmailVerify, plaintextToken)
if !ok {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenEmailVerify, hash, now)
t, err := a.storeConsumeToken(ctx, TokenEmailVerify, hash, now)
if err != nil {
return nil, errx.Wrap(op, err)
}
if err := a.deps.Users.SetEmailVerified(ctx, t.UserID, now); err != nil {
if err := a.storeSetEmailVerified(ctx, t.UserID, now); err != nil {
return nil, errx.Wrap(op, err)
}
return a.deps.Users.GetUserByID(ctx, t.UserID)
return a.storeGetUserByID(ctx, t.UserID)
}
// fireLoginHook is a thin wrapper that suppresses panics from caller-supplied
// hooks; we never want a misbehaving telemetry hook to break login.
func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) error {
// fireLoginHook runs Config.LoginHook if configured. Returned errors are
// surfaced for the caller to log; they never break login. The hook is
// wrapped in recover() so a misbehaving hook can't take down the auth path.
func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) (err error) {
if a.cfg.LoginHook == nil {
return nil
}
defer func() {
if r := recover(); r != nil {
err = errx.Newf("authkit.fireLoginHook", "login hook panicked: %v", r)
}
}()
return a.cfg.LoginHook(ctx, email, success)
}

149
service_user_test.go Normal file
View file

@ -0,0 +1,149 @@
package authkit
import (
"context"
"errors"
"testing"
)
func TestIntegration_CreateUserNoPasswordThenLoginFails(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "alice@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if u.PasswordHash != "" {
t.Fatalf("password should be empty for fresh user")
}
// Login against the password-less user must fail with
// ErrInvalidCredentials, not leak account existence.
if _, err := a.LoginPassword(ctx, "alice@example.com", "anything"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
}
}
func TestIntegration_CreateUserSetPasswordThenLogin(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "bob@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if err := a.SetPassword(ctx, u.ID, "hunter2hunter2"); err != nil {
t.Fatalf("SetPassword: %v", err)
}
got, err := a.LoginPassword(ctx, "Bob@Example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("LoginPassword (case-insensitive email): %v", err)
}
if got.ID != u.ID {
t.Fatalf("user id mismatch")
}
if _, err := a.LoginPassword(ctx, "bob@example.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
}
}
func TestIntegration_CreateUserDuplicateEmail(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.CreateUser(ctx, "dup@example.com"); err != nil {
t.Fatalf("CreateUser: %v", err)
}
if _, err := a.CreateUser(ctx, "DUP@example.com"); !errors.Is(err, ErrEmailTaken) {
t.Fatalf("expected ErrEmailTaken on case-folded duplicate, got %v", err)
}
}
func TestIntegration_EmailVerificationFlow(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "ev@e.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
tok, err := a.RequestEmailVerification(ctx, u.ID)
if err != nil {
t.Fatalf("RequestEmailVerification: %v", err)
}
confirmed, err := a.ConfirmEmail(ctx, tok)
if err != nil {
t.Fatalf("ConfirmEmail: %v", err)
}
if confirmed.EmailVerifiedAt == nil {
t.Fatalf("email_verified_at not set")
}
if _, err := a.ConfirmEmail(ctx, tok); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err)
}
}
func TestIntegration_PasswordResetCascadesSessionInvalidation(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "r@r.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if err := a.SetPassword(ctx, u.ID, "old-password"); err != nil {
t.Fatalf("SetPassword: %v", err)
}
plain, _, err := a.IssueSession(ctx, u.ID, "ua", noIP())
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
tok, err := a.RequestPasswordReset(ctx, "r@r.com")
if err != nil {
t.Fatalf("RequestPasswordReset: %v", err)
}
if tok == "" {
t.Fatalf("expected token for known email")
}
if err := a.ConfirmPasswordReset(ctx, tok, "new-password"); err != nil {
t.Fatalf("ConfirmPasswordReset: %v", err)
}
if _, err := a.LoginPassword(ctx, "r@r.com", "old-password"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("old password should fail, got %v", err)
}
if _, err := a.LoginPassword(ctx, "r@r.com", "new-password"); err != nil {
t.Fatalf("new password should work: %v", err)
}
if _, err := a.AuthenticateSession(ctx, plain); !errors.Is(err, ErrSessionInvalid) {
t.Fatalf("session should be invalidated by reset: got %v", err)
}
}
func TestIntegration_PasswordResetUnknownEmailIsSilent(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
tok, err := a.RequestPasswordReset(ctx, "nobody@example.com")
if err != nil {
t.Fatalf("expected silent success, got err %v", err)
}
if tok != "" {
t.Fatalf("expected empty token for unknown email, got %q", tok)
}
}
func TestIntegration_ChangePasswordRevokesEverything(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "cp@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
if err := a.SetPassword(ctx, u.ID, "old-password"); err != nil {
t.Fatalf("SetPassword: %v", err)
}
access, _, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
if err := a.ChangePassword(ctx, u.ID, "old-password", "new-password"); err != nil {
t.Fatalf("ChangePassword: %v", err)
}
if _, err := a.AuthenticateJWT(ctx, access); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("JWT should be invalidated by password change, got %v", err)
}
}

32
slug.go Normal file
View file

@ -0,0 +1,32 @@
package authkit
import (
"regexp"
"git.juancwu.dev/juancwu/errx"
)
// MaxSlugLength is the upper bound on slug length, in bytes. Slugs are ASCII
// so this also bounds character count.
const MaxSlugLength = 64
// slugRE matches the accepted slug shape: a lowercase ASCII letter followed
// by any number of lowercase letters, digits, or one of `_`, `:`, `-`.
// Common valid forms: "admin", "ads-manager", "ads_manager", "posts:write".
var slugRE = regexp.MustCompile(`^[a-z][a-z0-9_:-]*$`)
// validateSlug returns nil when s is a syntactically valid slug. Strict
// validation, no transformation: the caller must pre-normalize before
// passing in. Wrapped with op for call-site context.
func validateSlug(op, s string) error {
if s == "" {
return errx.Wrap(op, ErrSlugInvalid)
}
if len(s) > MaxSlugLength {
return errx.Wrapf(op, ErrSlugInvalid, "slug exceeds %d bytes", MaxSlugLength)
}
if !slugRE.MatchString(s) {
return errx.Wrapf(op, ErrSlugInvalid, "slug %q does not match %s", s, slugRE.String())
}
return nil
}

87
slug_test.go Normal file
View file

@ -0,0 +1,87 @@
package authkit
import (
"errors"
"strings"
"testing"
)
func TestValidateSlug(t *testing.T) {
cases := []struct {
name string
slug string
wantErr bool
}{
{"plain lowercase", "admin", false},
{"snake", "ads_manager", false},
{"kebab", "ads-manager", false},
{"colon namespaced", "posts:write", false},
{"with digits", "v2:posts", false},
{"single char start", "a", false},
{"empty", "", true},
{"uppercase", "Admin", true},
{"starts with digit", "1admin", true},
{"starts with underscore", "_admin", true},
{"starts with hyphen", "-admin", true},
{"contains space", "ads manager", true},
{"contains slash", "posts/write", true},
{"contains dot", "posts.write", true},
{"only digits", "12345", true},
{"too long", strings.Repeat("a", MaxSlugLength+1), true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
err := validateSlug("test", c.slug)
if c.wantErr {
if err == nil {
t.Fatalf("expected ErrSlugInvalid for %q, got nil", c.slug)
}
if !errors.Is(err, ErrSlugInvalid) {
t.Fatalf("expected error wrapping ErrSlugInvalid, got %v", err)
}
return
}
if err != nil {
t.Fatalf("expected %q valid, got %v", c.slug, err)
}
})
}
}
func TestValidateSlugAtMaxLength(t *testing.T) {
s := strings.Repeat("a", MaxSlugLength)
if err := validateSlug("test", s); err != nil {
t.Fatalf("MaxSlugLength=%d should pass, got %v", MaxSlugLength, err)
}
}
func FuzzValidateSlug(f *testing.F) {
f.Add("admin")
f.Add("posts:write")
f.Add("ads-manager")
f.Add("Admin")
f.Add("")
f.Add("a/b")
f.Add(strings.Repeat("a", 65))
f.Fuzz(func(t *testing.T, s string) {
err := validateSlug("fuzz", s)
if err == nil {
// Sanity-check the validation invariants. Anything that passes
// must be ASCII, length <= MaxSlugLength, lowercase-start.
if len(s) == 0 || len(s) > MaxSlugLength {
t.Fatalf("slug %q passed but length %d violates bounds", s, len(s))
}
c := s[0]
if !(c >= 'a' && c <= 'z') {
t.Fatalf("slug %q passed but starts with %q", s, c)
}
for _, r := range s {
if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') &&
r != '_' && r != ':' && r != '-' {
t.Fatalf("slug %q passed but contains %q", s, r)
}
}
}
})
}

View file

@ -1,117 +0,0 @@
package sqlstore
import (
"context"
"database/sql"
"io/fs"
)
// Dialect describes how a particular SQL backend renders the queries authkit
// runs at runtime, bootstraps its session for migrations, and reports
// driver-specific errors. v1 ships the Postgres dialect; future MySQL or
// SQLite implementations satisfy the same interface without changes to the
// store code.
type Dialect interface {
// Name returns a stable short identifier ("postgres", "mysql", ...).
Name() string
// BuildQueries renders every templated SQL string from a validated
// Schema. Called once at New() and at Migrate() so identifiers and
// placeholder styles are baked in by the time stores execute queries.
BuildQueries(s Schema) Queries
// Bootstrap runs once per Migrate() before the migration lock is taken.
// It is the place for things that must happen outside any transaction
// (CREATE EXTENSION on Postgres, for example). Must be idempotent and
// may be a no-op.
Bootstrap(ctx context.Context, db *sql.DB) error
// AcquireMigrationLock takes a session-scoped lock on conn so concurrent
// migrators serialise. The returned release function never returns an
// error — implementations may log internally.
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
// Migrations returns the dialect's embedded .sql files, rooted such
// that fs.ReadDir(".") lists them lex-sorted by filename.
Migrations() fs.FS
// IsUniqueViolation maps a duplicate-key error from this driver to true
// so insert paths can return the matching authkit sentinel without
// depending on driver internals.
IsUniqueViolation(err error) bool
// Placeholder returns the placeholder for parameter index n (1-based)
// in this dialect — "$1" for Postgres, "?" for MySQL/SQLite.
Placeholder(n int) string
// PlaceholderList returns a comma-separated placeholder list for `count`
// parameters starting at position `start` (1-based), suitable for
// dynamic IN-clauses. The second return is the same length, prefilled
// with nil so the caller can append actual values.
PlaceholderList(start, count int) string
}
// Queries is the full set of statement templates authkit issues. Field names
// match the store method that consumes the query.
type Queries struct {
// users
CreateUser string
GetUserByID string
GetUserByEmail string
UpdateUser string
DeleteUser string
SetPassword string
SetEmailVerified string
BumpSessionVersion string
IncrementFailedLogins string
ResetFailedLogins string
// sessions
CreateSession string
GetSession string
TouchSession string
DeleteSession string
DeleteUserSessions string
DeleteExpiredSessions string
// tokens
CreateToken string
ConsumeToken string
GetToken string
DeleteByChain string
DeleteExpiredTokens string
// service keys
CreateServiceKey string
GetServiceKey string
ListServiceKeysByOwner string
TouchServiceKey string
RevokeServiceKey string
// roles
CreateRole string
GetRoleByID string
GetRoleByName string
ListRoles string
DeleteRole string
AssignRoleToUser string
RemoveRoleFromUser string
GetUserRoles string
// HasAnyRole is built at call time because the placeholder count varies.
// permissions
CreatePermission string
GetPermissionByID string
GetPermissionByName string
ListPermissions string
DeletePermission string
AssignPermissionToRole string
RemovePermissionFromRole string
GetRolePermissions string
GetUserPermissions string
// migrations
CreateMigrationsTable string
SelectAppliedVersions string
InsertAppliedVersion string
}

View file

@ -1,33 +0,0 @@
package postgres
import (
"errors"
"github.com/jackc/pgx/v5/pgconn"
)
// pgUniqueViolation is the SQLSTATE for unique_violation. Both pgx-stdlib
// and lib/pq surface this code, but only pgx-stdlib uses *pgconn.PgError.
// lib/pq uses *pq.Error which has a Code field of the same value.
const pgUniqueViolation = "23505"
// isUniqueViolation inspects err for a Postgres unique-violation, regardless
// of which driver registered the connection. We match on either the pgx
// error type or any error implementing a Code() string method (lib/pq's
// pq.Error has SQLState and Code fields; we check via reflection-free
// duck-typing through an interface).
func isUniqueViolation(err error) bool {
if err == nil {
return false
}
var pgxErr *pgconn.PgError
if errors.As(err, &pgxErr) {
return pgxErr.Code == pgUniqueViolation
}
type sqlStater interface{ SQLState() string }
var s sqlStater
if errors.As(err, &s) {
return s.SQLState() == pgUniqueViolation
}
return false
}

View file

@ -1,99 +0,0 @@
-- 0001_init.sql
-- Initial authkit schema for Postgres. Tables are prefixed authkit_ so the
-- library can be embedded in an existing application database. Each
-- migration owns its own transaction and inserts its version row at the
-- bottom; the runner only orchestrates file discovery and concurrency.
BEGIN;
CREATE TABLE IF NOT EXISTS authkit_schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_users (
id UUID PRIMARY KEY,
email TEXT NOT NULL,
email_normalized TEXT NOT NULL,
email_verified_at TIMESTAMPTZ,
password_hash TEXT,
session_version INTEGER NOT NULL DEFAULT 0,
failed_logins INTEGER NOT NULL DEFAULT 0,
last_login_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq
ON authkit_users (email_normalized);
CREATE TABLE IF NOT EXISTS authkit_sessions (
id_hash BYTEA PRIMARY KEY,
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
user_agent TEXT NOT NULL DEFAULT '',
ip TEXT,
created_at TIMESTAMPTZ NOT NULL,
last_seen_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL
);
CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id);
CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at);
CREATE TABLE IF NOT EXISTS authkit_tokens (
hash BYTEA NOT NULL,
kind TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
chain_id TEXT,
consumed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (kind, hash)
);
CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id);
CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at);
CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx
ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL;
CREATE TABLE IF NOT EXISTS authkit_api_keys (
id_hash BYTEA PRIMARY KEY,
owner_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
name TEXT NOT NULL,
abilities JSONB NOT NULL DEFAULT '[]'::jsonb,
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ,
revoked_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS authkit_api_keys_owner_id_idx ON authkit_api_keys(owner_id);
CREATE TABLE IF NOT EXISTS authkit_roles (
id UUID PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
description TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_permissions (
id UUID PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
description TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_role_permissions (
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE,
PRIMARY KEY (role_id, permission_id)
);
CREATE TABLE IF NOT EXISTS authkit_user_roles (
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
granted_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (user_id, role_id)
);
CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id);
INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now())
ON CONFLICT (version) DO NOTHING;
COMMIT;

View file

@ -1,27 +0,0 @@
-- 0002_service_keys.sql
-- Adds owner-agnostic service tokens. Unlike authkit_api_keys, owner_id is
-- intentionally NOT FK-constrained: consumers manage their own cascades, and
-- authkit has no opinion on what "owner" means here (application id, tenant
-- id, etc.).
BEGIN;
CREATE TABLE IF NOT EXISTS authkit_service_keys (
id_hash BYTEA PRIMARY KEY,
owner_id UUID NOT NULL,
owner_kind TEXT NOT NULL,
name TEXT NOT NULL,
abilities JSONB NOT NULL DEFAULT '[]'::jsonb,
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ,
revoked_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS authkit_service_keys_owner_idx
ON authkit_service_keys(owner_kind, owner_id);
INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0002_service_keys', now())
ON CONFLICT (version) DO NOTHING;
COMMIT;

View file

@ -1,13 +0,0 @@
-- 0003_drop_api_keys.sql
-- Drops the user-owned API key table. After this migration only service
-- tokens carry abilities; user-owned credentials (sessions, JWTs,
-- magic-links) prove identity, with permissions resolved via RBAC.
BEGIN;
DROP TABLE IF EXISTS authkit_api_keys CASCADE;
INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0003_drop_api_keys', now())
ON CONFLICT (version) DO NOTHING;
COMMIT;

View file

@ -1,270 +0,0 @@
// Package postgres is the Postgres dialect for authkit/sqlstore. Importing
// it does not register a driver — callers do `_ "github.com/jackc/pgx/v5/stdlib"`
// or `_ "github.com/lib/pq"` themselves and then `sql.Open(...)`.
package postgres
import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"log"
"strings"
"git.juancwu.dev/juancwu/authkit/sqlstore"
)
//go:embed migrations/*.sql
var migrationsFS embed.FS
// Dialect implements sqlstore.Dialect for Postgres. The zero value is the
// only required form; New() returns a pointer for clarity at call sites.
type Dialect struct{}
// New returns a Postgres dialect ready to pass to sqlstore.New / Migrate.
func New() *Dialect { return &Dialect{} }
func (Dialect) Name() string { return "postgres" }
// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. Stable
// across rollouts and unlikely to clash with caller advisory locks.
const advisoryLockKey int64 = 0x617574686b6974
func (Dialect) Bootstrap(ctx context.Context, db *sql.DB) error {
// Nothing to do — schema avoids gen_random_uuid()/pgcrypto. Kept as a
// hook for future migration prerequisites.
return nil
}
func (Dialect) AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (func(), error) {
if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil {
return nil, err
}
release := func() {
if _, err := conn.ExecContext(context.Background(),
"SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil {
log.Printf("authkit/postgres: pg_advisory_unlock failed: %v", err)
}
}
return release, nil
}
func (Dialect) Migrations() fs.FS {
sub, err := fs.Sub(migrationsFS, "migrations")
if err != nil {
// migrationsFS is statically populated; this can only fail if the
// embed directive is removed, which would be a build-time error.
panic(err)
}
return sub
}
func (Dialect) IsUniqueViolation(err error) bool { return isUniqueViolation(err) }
func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) }
// PlaceholderList renders a comma-separated `$start,$start+1,...` list of
// `count` placeholders. Used by HasAnyRole's dynamic IN-clause expansion.
func (Dialect) PlaceholderList(start, count int) string {
if count <= 0 {
return ""
}
var b strings.Builder
for i := 0; i < count; i++ {
if i > 0 {
b.WriteByte(',')
}
fmt.Fprintf(&b, "$%d", start+i)
}
return b.String()
}
// BuildQueries renders every query authkit issues, with table identifiers
// taken from s and `?` placeholders rewritten to `$N`. Identifiers are
// already validated by Schema.Validate — this is interpolated with
// fmt.Sprintf, so the validation gate is load-bearing.
func (Dialect) BuildQueries(s sqlstore.Schema) sqlstore.Queries {
t := s.Tables
q := sqlstore.Queries{
// users
CreateUser: `INSERT INTO ` + t.Users + `
(id, email, email_normalized, email_verified_at, password_hash,
session_version, failed_logins, last_login_at, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
GetUserByID: `SELECT id, email, email_normalized, email_verified_at,
password_hash, session_version, failed_logins, last_login_at,
created_at, updated_at FROM ` + t.Users + ` WHERE id = ?`,
GetUserByEmail: `SELECT id, email, email_normalized, email_verified_at,
password_hash, session_version, failed_logins, last_login_at,
created_at, updated_at FROM ` + t.Users + ` WHERE email_normalized = ?`,
UpdateUser: `UPDATE ` + t.Users + ` SET
email = ?, email_normalized = ?, email_verified_at = ?,
password_hash = ?, session_version = ?, failed_logins = ?,
last_login_at = ?, updated_at = ?
WHERE id = ?`,
DeleteUser: `DELETE FROM ` + t.Users + ` WHERE id = ?`,
SetPassword: `UPDATE ` + t.Users + ` SET password_hash = ?, updated_at = ? WHERE id = ?`,
SetEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = ?, updated_at = ? WHERE id = ?`,
BumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = ? WHERE id = ? RETURNING session_version`,
IncrementFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = failed_logins + 1, updated_at = ? WHERE id = ? RETURNING failed_logins`,
ResetFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = 0, updated_at = ? WHERE id = ?`,
// sessions
CreateSession: `INSERT INTO ` + t.Sessions + `
(id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
GetSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at
FROM ` + t.Sessions + ` WHERE id_hash = ?`,
TouchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = ?, expires_at = ? WHERE id_hash = ?`,
DeleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = ?`,
DeleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = ?`,
DeleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= ?`,
// tokens
CreateToken: `INSERT INTO ` + t.Tokens + `
(hash, kind, user_id, chain_id, consumed_at, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
ConsumeToken: `UPDATE ` + t.Tokens + `
SET consumed_at = ?
WHERE kind = ? AND hash = ? AND consumed_at IS NULL AND expires_at > ?
RETURNING hash, kind, user_id, chain_id, consumed_at, created_at, expires_at`,
GetToken: `SELECT hash, kind, user_id, chain_id, consumed_at, created_at, expires_at
FROM ` + t.Tokens + ` WHERE kind = ? AND hash = ?`,
DeleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = ?`,
DeleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= ?`,
// service keys
CreateServiceKey: `INSERT INTO ` + t.ServiceKeys + `
(id_hash, owner_id, owner_kind, name, abilities, last_used_at, created_at, expires_at, revoked_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
GetServiceKey: `SELECT id_hash, owner_id, owner_kind, name, abilities, last_used_at,
created_at, expires_at, revoked_at
FROM ` + t.ServiceKeys + ` WHERE id_hash = ?`,
ListServiceKeysByOwner: `SELECT id_hash, owner_id, owner_kind, name, abilities, last_used_at,
created_at, expires_at, revoked_at
FROM ` + t.ServiceKeys + ` WHERE owner_kind = ? AND owner_id = ? ORDER BY created_at DESC`,
TouchServiceKey: `UPDATE ` + t.ServiceKeys + ` SET last_used_at = ? WHERE id_hash = ?`,
RevokeServiceKey: `UPDATE ` + t.ServiceKeys + ` SET revoked_at = ? WHERE id_hash = ? AND revoked_at IS NULL`,
// roles
CreateRole: `INSERT INTO ` + t.Roles + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
GetRoleByID: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE id = ?`,
GetRoleByName: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE name = ?`,
ListRoles: `SELECT id, name, description, created_at FROM ` + t.Roles + ` ORDER BY name`,
DeleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = ?`,
AssignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`,
RemoveRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = ? AND role_id = ?`,
GetUserRoles: `SELECT r.id, r.name, r.description, r.created_at
FROM ` + t.Roles + ` r
JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id
WHERE ur.user_id = ? ORDER BY r.name`,
// permissions
CreatePermission: `INSERT INTO ` + t.Permissions + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
GetPermissionByID: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE id = ?`,
GetPermissionByName: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE name = ?`,
ListPermissions: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` ORDER BY name`,
DeletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = ?`,
AssignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) VALUES (?, ?)
ON CONFLICT DO NOTHING`,
RemovePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = ? AND permission_id = ?`,
GetRolePermissions: `SELECT p.id, p.name, p.description, p.created_at
FROM ` + t.Permissions + ` p
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
WHERE rp.role_id = ? ORDER BY p.name`,
GetUserPermissions: `SELECT DISTINCT p.id, p.name, p.description, p.created_at
FROM ` + t.Permissions + ` p
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id
WHERE ur.user_id = ? ORDER BY p.name`,
// migrations
CreateMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL
)`,
SelectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations,
InsertAppliedVersion: `INSERT INTO ` + t.SchemaMigrations + ` (version, applied_at) VALUES (?, ?)`,
}
// Rewrite `?` placeholders to `$N`. Each query is independent; numbering
// resets per query.
q.CreateUser = rebind(q.CreateUser)
q.GetUserByID = rebind(q.GetUserByID)
q.GetUserByEmail = rebind(q.GetUserByEmail)
q.UpdateUser = rebind(q.UpdateUser)
q.DeleteUser = rebind(q.DeleteUser)
q.SetPassword = rebind(q.SetPassword)
q.SetEmailVerified = rebind(q.SetEmailVerified)
q.BumpSessionVersion = rebind(q.BumpSessionVersion)
q.IncrementFailedLogins = rebind(q.IncrementFailedLogins)
q.ResetFailedLogins = rebind(q.ResetFailedLogins)
q.CreateSession = rebind(q.CreateSession)
q.GetSession = rebind(q.GetSession)
q.TouchSession = rebind(q.TouchSession)
q.DeleteSession = rebind(q.DeleteSession)
q.DeleteUserSessions = rebind(q.DeleteUserSessions)
q.DeleteExpiredSessions = rebind(q.DeleteExpiredSessions)
q.CreateToken = rebind(q.CreateToken)
q.ConsumeToken = rebind(q.ConsumeToken)
q.GetToken = rebind(q.GetToken)
q.DeleteByChain = rebind(q.DeleteByChain)
q.DeleteExpiredTokens = rebind(q.DeleteExpiredTokens)
q.CreateServiceKey = rebind(q.CreateServiceKey)
q.GetServiceKey = rebind(q.GetServiceKey)
q.ListServiceKeysByOwner = rebind(q.ListServiceKeysByOwner)
q.TouchServiceKey = rebind(q.TouchServiceKey)
q.RevokeServiceKey = rebind(q.RevokeServiceKey)
q.CreateRole = rebind(q.CreateRole)
q.GetRoleByID = rebind(q.GetRoleByID)
q.GetRoleByName = rebind(q.GetRoleByName)
q.ListRoles = rebind(q.ListRoles)
q.DeleteRole = rebind(q.DeleteRole)
q.AssignRoleToUser = rebind(q.AssignRoleToUser)
q.RemoveRoleFromUser = rebind(q.RemoveRoleFromUser)
q.GetUserRoles = rebind(q.GetUserRoles)
q.CreatePermission = rebind(q.CreatePermission)
q.GetPermissionByID = rebind(q.GetPermissionByID)
q.GetPermissionByName = rebind(q.GetPermissionByName)
q.ListPermissions = rebind(q.ListPermissions)
q.DeletePermission = rebind(q.DeletePermission)
q.AssignPermissionToRole = rebind(q.AssignPermissionToRole)
q.RemovePermissionFromRole = rebind(q.RemovePermissionFromRole)
q.GetRolePermissions = rebind(q.GetRolePermissions)
q.GetUserPermissions = rebind(q.GetUserPermissions)
q.SelectAppliedVersions = rebind(q.SelectAppliedVersions)
q.InsertAppliedVersion = rebind(q.InsertAppliedVersion)
// CreateMigrationsTable contains no parameters.
return q
}
// rebind walks s and replaces each unquoted `?` with $1, $2, ... in order.
// Our query strings contain no string literals that include `?` (verified
// by inspection of every query in BuildQueries); a literal-aware rewriter
// would be more defensive but is not needed for v1.
func rebind(s string) string {
var b strings.Builder
b.Grow(len(s) + 16)
n := 1
for i := 0; i < len(s); i++ {
c := s[i]
if c == '?' {
fmt.Fprintf(&b, "$%d", n)
n++
continue
}
b.WriteByte(c)
}
return b.String()
}
// Compile-time interface compliance check.
var _ sqlstore.Dialect = (*Dialect)(nil)

View file

@ -1,116 +0,0 @@
package sqlstore
import (
"context"
"database/sql"
"io/fs"
"sort"
"strings"
"time"
"git.juancwu.dev/juancwu/errx"
)
// Migrate applies every embedded migration the dialect ships that has not
// yet been recorded in the schema-migrations table. It is safe to call
// repeatedly and concurrently across processes — the dialect's session
// lock serialises rollouts.
//
// Each migration .sql file is responsible for owning its own
// BEGIN/COMMIT and inserting a row into the schema-migrations table on
// success. The runner only handles file discovery, version tracking, and
// concurrency.
func Migrate(ctx context.Context, db *sql.DB, dialect Dialect, schema Schema) error {
const op = "authkit.sqlstore.Migrate"
if db == nil {
return errx.New(op, "db is required")
}
if dialect == nil {
return errx.New(op, "dialect is required")
}
if err := schema.Validate(); err != nil {
return errx.Wrap(op, err)
}
if err := dialect.Bootstrap(ctx, db); err != nil {
return errx.Wrap(op, err)
}
conn, err := db.Conn(ctx)
if err != nil {
return errx.Wrap(op, err)
}
defer conn.Close()
release, err := dialect.AcquireMigrationLock(ctx, conn)
if err != nil {
return errx.Wrap(op, err)
}
defer release()
q := dialect.BuildQueries(schema)
if _, err := conn.ExecContext(ctx, q.CreateMigrationsTable); err != nil {
return errx.Wrap(op, err)
}
applied, err := loadAppliedVersions(ctx, conn, q.SelectAppliedVersions)
if err != nil {
return errx.Wrap(op, err)
}
migs := dialect.Migrations()
files, err := fs.ReadDir(migs, ".")
if err != nil {
return errx.Wrap(op, err)
}
names := make([]string, 0, len(files))
for _, f := range files {
if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
names = append(names, f.Name())
}
}
sort.Strings(names)
for _, name := range names {
version := strings.TrimSuffix(name, ".sql")
if _, ok := applied[version]; ok {
continue
}
body, err := fs.ReadFile(migs, name)
if err != nil {
return errx.Wrapf(op, err, "read %s", name)
}
if _, err := conn.ExecContext(ctx, string(body)); err != nil {
return errx.Wrapf(op, err, "apply %s", version)
}
}
return nil
}
// applyVersionRow is intentionally not exposed: migration files own their
// own version-row insert. We keep this helper around in case a dialect ever
// needs to backfill versions from outside a migration body — currently
// unused.
var _ = applyVersionRow
func applyVersionRow(ctx context.Context, conn *sql.Conn, insertQ, version string, at time.Time) error {
_, err := conn.ExecContext(ctx, insertQ, version, at)
return err
}
func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) {
rows, err := conn.QueryContext(ctx, q)
if err != nil {
return nil, err
}
defer rows.Close()
out := make(map[string]struct{})
for rows.Next() {
var v string
if err := rows.Scan(&v); err != nil {
return nil, err
}
out[v] = struct{}{}
}
return out, rows.Err()
}

View file

@ -1,301 +0,0 @@
package sqlstore
import (
"context"
"fmt"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type roleStore struct{ storeBase }
type permissionStore struct{ storeBase }
// ----- roleStore ------------------------------------------------------------
func (s *roleStore) CreateRole(ctx context.Context, r *authkit.Role) error {
const op = "authkit.sqlstore.RoleStore.CreateRole"
if r.ID == uuid.Nil {
r.ID = uuid.New()
}
if r.CreatedAt.IsZero() {
r.CreatedAt = time.Now().UTC()
}
if _, err := s.db.ExecContext(ctx, s.q.CreateRole,
uuidArg(r.ID), r.Name, r.Description, r.CreatedAt); err != nil {
if s.d.IsUniqueViolation(err) {
return errx.Wrapf(op, err, "role %q already exists", r.Name)
}
return errx.Wrap(op, err)
}
return nil
}
func (s *roleStore) GetRoleByID(ctx context.Context, id uuid.UUID) (*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.GetRoleByID"
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
}
return r, nil
}
func (s *roleStore) GetRoleByName(ctx context.Context, name string) (*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.GetRoleByName"
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByName, name))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
}
return r, nil
}
func (s *roleStore) ListRoles(ctx context.Context) ([]*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.ListRoles"
rows, err := s.db.QueryContext(ctx, s.q.ListRoles)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Role
for rows.Next() {
r, err := scanRole(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, r)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *roleStore) DeleteRole(ctx context.Context, id uuid.UUID) error {
const op = "authkit.sqlstore.RoleStore.DeleteRole"
tag, err := s.db.ExecContext(ctx, s.q.DeleteRole, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrRoleNotFound)
}
return nil
}
func (s *roleStore) AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error {
const op = "authkit.sqlstore.RoleStore.AssignRoleToUser"
if _, err := s.db.ExecContext(ctx, s.q.AssignRoleToUser,
uuidArg(userID), uuidArg(roleID), time.Now().UTC()); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *roleStore) RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error {
const op = "authkit.sqlstore.RoleStore.RemoveRoleFromUser"
if _, err := s.db.ExecContext(ctx, s.q.RemoveRoleFromUser,
uuidArg(userID), uuidArg(roleID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *roleStore) GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.GetUserRoles"
rows, err := s.db.QueryContext(ctx, s.q.GetUserRoles, uuidArg(userID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Role
for rows.Next() {
r, err := scanRole(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, r)
}
return out, errx.Wrap(op, rows.Err())
}
// HasAnyRole builds the IN-clause at call time because the placeholder count
// depends on len(names). Identifier substitution comes from the validated
// Schema; values are bound, never interpolated.
func (s *roleStore) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
const op = "authkit.sqlstore.RoleStore.HasAnyRole"
if len(names) == 0 {
return false, nil
}
// Placeholder $1 is the user_id; $2..$N+1 cover the names slice.
listSQL := s.d.PlaceholderList(2, len(names))
q := fmt.Sprintf(`SELECT EXISTS (
SELECT 1 FROM %s ur JOIN %s r ON r.id = ur.role_id
WHERE ur.user_id = %s AND r.name IN (%s)
)`, s.s.Tables.UserRoles, s.s.Tables.Roles, s.d.Placeholder(1), listSQL)
args := make([]any, 0, 1+len(names))
args = append(args, uuidArg(userID))
for _, n := range names {
args = append(args, n)
}
var ok bool
if err := s.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// ----- permissionStore ------------------------------------------------------
func (s *permissionStore) CreatePermission(ctx context.Context, p *authkit.Permission) error {
const op = "authkit.sqlstore.PermissionStore.CreatePermission"
if p.ID == uuid.Nil {
p.ID = uuid.New()
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now().UTC()
}
if _, err := s.db.ExecContext(ctx, s.q.CreatePermission,
uuidArg(p.ID), p.Name, p.Description, p.CreatedAt); err != nil {
if s.d.IsUniqueViolation(err) {
return errx.Wrapf(op, err, "permission %q already exists", p.Name)
}
return errx.Wrap(op, err)
}
return nil
}
func (s *permissionStore) GetPermissionByID(ctx context.Context, id uuid.UUID) (*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetPermissionByID"
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
}
return p, nil
}
func (s *permissionStore) GetPermissionByName(ctx context.Context, name string) (*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetPermissionByName"
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByName, name))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
}
return p, nil
}
func (s *permissionStore) ListPermissions(ctx context.Context) ([]*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.ListPermissions"
rows, err := s.db.QueryContext(ctx, s.q.ListPermissions)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *permissionStore) DeletePermission(ctx context.Context, id uuid.UUID) error {
const op = "authkit.sqlstore.PermissionStore.DeletePermission"
tag, err := s.db.ExecContext(ctx, s.q.DeletePermission, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrPermissionNotFound)
}
return nil
}
func (s *permissionStore) AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error {
const op = "authkit.sqlstore.PermissionStore.AssignPermissionToRole"
if _, err := s.db.ExecContext(ctx, s.q.AssignPermissionToRole,
uuidArg(roleID), uuidArg(permID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *permissionStore) RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error {
const op = "authkit.sqlstore.PermissionStore.RemovePermissionFromRole"
if _, err := s.db.ExecContext(ctx, s.q.RemovePermissionFromRole,
uuidArg(roleID), uuidArg(permID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *permissionStore) GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetRolePermissions"
rows, err := s.db.QueryContext(ctx, s.q.GetRolePermissions, uuidArg(roleID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *permissionStore) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetUserPermissions"
rows, err := s.db.QueryContext(ctx, s.q.GetUserPermissions, uuidArg(userID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func scanRole(row rowScanner) (*authkit.Role, error) {
var (
r authkit.Role
idStr string
)
if err := row.Scan(&idStr, &r.Name, &r.Description, &r.CreatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
r.ID = id
return &r, nil
}
func scanPermission(row rowScanner) (*authkit.Permission, error) {
var (
p authkit.Permission
idStr string
)
if err := row.Scan(&idStr, &p.Name, &p.Description, &p.CreatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
p.ID = id
return &p, nil
}

View file

@ -1,86 +0,0 @@
package sqlstore
import (
"database/sql"
"net/netip"
"time"
"github.com/google/uuid"
)
// rowScanner is the lowest-common-denominator interface satisfied by both
// *sql.Row and *sql.Rows so scanXxx helpers serve QueryRow and Query loops
// uniformly.
type rowScanner interface {
Scan(dest ...any) error
}
// nullableTime returns nil when t is the zero time, otherwise &t. Callers
// should usually accept *time.Time on the model side and bind via this.
func nullableTime(t *time.Time) any {
if t == nil {
return nil
}
return *t
}
// nullableString turns "" into nil so columns store NULL rather than an
// empty string. Used for password hashes and similar optional text columns.
func nullableString(s string) any {
if s == "" {
return nil
}
return s
}
// nullableAddrString returns the string form of a netip.Addr when valid, or
// nil to bind as SQL NULL. Pairs with scanAddr.
func nullableAddrString(a netip.Addr) any {
if !a.IsValid() {
return nil
}
return a.String()
}
// scanAddr parses a *string column into a netip.Addr. The zero Addr value is
// returned when the column was NULL or empty.
func scanAddr(s *string) (netip.Addr, error) {
if s == nil || *s == "" {
return netip.Addr{}, nil
}
return netip.ParseAddr(*s)
}
// uuidArg returns the canonical string form of a UUID for binding. Every
// store uses this rather than passing uuid.UUID directly to keep behaviour
// identical across drivers (some accept driver.Valuer, some don't).
func uuidArg(id uuid.UUID) any { return id.String() }
// scanUUID reads a string column and parses it back to a uuid.UUID.
func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) }
// chainArg returns either a *string or nil for binding the chain_id column.
func chainArg(c *string) any {
if c == nil {
return nil
}
return *c
}
// scanNullStringPtr converts sql.NullString to *string for the model.
func scanNullStringPtr(ns sql.NullString) *string {
if !ns.Valid {
return nil
}
v := ns.String
return &v
}
// scanNullTimePtr converts sql.NullTime to *time.Time for the model.
func scanNullTimePtr(nt sql.NullTime) *time.Time {
if !nt.Valid {
return nil
}
t := nt.Time
return &t
}

View file

@ -1,78 +0,0 @@
package sqlstore
import (
"regexp"
"git.juancwu.dev/juancwu/errx"
)
// Schema lets consumers map authkit storage to their own table names.
// Column overrides are intentionally not present in v1 — adding them later
// is purely additive.
type Schema struct {
Tables Tables
}
// Tables is the per-table identifier override set. Every field must be a
// valid unquoted SQL identifier (matching identifierRE). Validation runs at
// New() and Migrate() time so SQL injection through Schema is impossible
// past that gate.
type Tables struct {
Users string
Sessions string
Tokens string
ServiceKeys string
Roles string
Permissions string
UserRoles string
RolePermissions string
SchemaMigrations string
}
// DefaultSchema returns the stock authkit_* names used by the embedded
// migration files.
func DefaultSchema() Schema {
return Schema{Tables: Tables{
Users: "authkit_users",
Sessions: "authkit_sessions",
Tokens: "authkit_tokens",
ServiceKeys: "authkit_service_keys",
Roles: "authkit_roles",
Permissions: "authkit_permissions",
UserRoles: "authkit_user_roles",
RolePermissions: "authkit_role_permissions",
SchemaMigrations: "authkit_schema_migrations",
}}
}
// identifierRE matches the safe ASCII identifier subset shared by Postgres,
// MySQL and SQLite when not quoted. Anything outside this set is rejected
// rather than escaped — Schema is not the place to support exotic names.
var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// Validate ensures every Schema.Tables field is a non-empty, safe identifier.
func (s Schema) Validate() error {
const op = "authkit.sqlstore.Schema.Validate"
checks := []struct {
field, value string
}{
{"Users", s.Tables.Users},
{"Sessions", s.Tables.Sessions},
{"Tokens", s.Tables.Tokens},
{"ServiceKeys", s.Tables.ServiceKeys},
{"Roles", s.Tables.Roles},
{"Permissions", s.Tables.Permissions},
{"UserRoles", s.Tables.UserRoles},
{"RolePermissions", s.Tables.RolePermissions},
{"SchemaMigrations", s.Tables.SchemaMigrations},
}
for _, c := range checks {
if c.value == "" {
return errx.Newf(op, "Schema.Tables.%s is empty", c.field)
}
if !identifierRE.MatchString(c.value) {
return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value)
}
}
return nil
}

View file

@ -1,116 +0,0 @@
package sqlstore
import (
"context"
"database/sql"
"encoding/json"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type serviceKeyStore struct{ storeBase }
func (s *serviceKeyStore) CreateServiceKey(ctx context.Context, k *authkit.ServiceKey) error {
const op = "authkit.sqlstore.ServiceKeyStore.CreateServiceKey"
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now().UTC()
}
if k.Abilities == nil {
k.Abilities = []string{}
}
abilities, err := json.Marshal(k.Abilities)
if err != nil {
return errx.Wrap(op, err)
}
_, err = s.db.ExecContext(ctx, s.q.CreateServiceKey,
k.IDHash, uuidArg(k.OwnerID), k.OwnerKind, k.Name, abilities,
nullableTime(k.LastUsedAt), k.CreatedAt,
nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt))
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *serviceKeyStore) GetServiceKey(ctx context.Context, idHash []byte) (*authkit.ServiceKey, error) {
const op = "authkit.sqlstore.ServiceKeyStore.GetServiceKey"
k, err := scanServiceKey(s.db.QueryRowContext(ctx, s.q.GetServiceKey, idHash))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrServiceKeyInvalid))
}
return k, nil
}
func (s *serviceKeyStore) ListServiceKeysByOwner(ctx context.Context, ownerKind string, ownerID uuid.UUID) ([]*authkit.ServiceKey, error) {
const op = "authkit.sqlstore.ServiceKeyStore.ListServiceKeysByOwner"
rows, err := s.db.QueryContext(ctx, s.q.ListServiceKeysByOwner, ownerKind, uuidArg(ownerID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.ServiceKey
for rows.Next() {
k, err := scanServiceKey(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, k)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *serviceKeyStore) TouchServiceKey(ctx context.Context, idHash []byte, at time.Time) error {
const op = "authkit.sqlstore.ServiceKeyStore.TouchServiceKey"
if _, err := s.db.ExecContext(ctx, s.q.TouchServiceKey, at, idHash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *serviceKeyStore) RevokeServiceKey(ctx context.Context, idHash []byte, at time.Time) error {
const op = "authkit.sqlstore.ServiceKeyStore.RevokeServiceKey"
tag, err := s.db.ExecContext(ctx, s.q.RevokeServiceKey, at, idHash)
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrServiceKeyInvalid)
}
return nil
}
func scanServiceKey(row rowScanner) (*authkit.ServiceKey, error) {
var (
k authkit.ServiceKey
ownerIDStr string
abilitiesRaw []byte
lastUsed sql.NullTime
expires sql.NullTime
revoked sql.NullTime
)
if err := row.Scan(&k.IDHash, &ownerIDStr, &k.OwnerKind, &k.Name, &abilitiesRaw,
&lastUsed, &k.CreatedAt, &expires, &revoked); err != nil {
return nil, err
}
owner, err := scanUUID(ownerIDStr)
if err != nil {
return nil, err
}
k.OwnerID = owner
if len(abilitiesRaw) > 0 {
if err := json.Unmarshal(abilitiesRaw, &k.Abilities); err != nil {
return nil, err
}
}
if k.Abilities == nil {
k.Abilities = []string{}
}
k.LastUsedAt = scanNullTimePtr(lastUsed)
k.ExpiresAt = scanNullTimePtr(expires)
k.RevokedAt = scanNullTimePtr(revoked)
return &k, nil
}

View file

@ -1,98 +0,0 @@
package sqlstore
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type sessionStore struct{ storeBase }
func (s *sessionStore) CreateSession(ctx context.Context, ses *authkit.Session) error {
const op = "authkit.sqlstore.SessionStore.CreateSession"
now := time.Now().UTC()
if ses.CreatedAt.IsZero() {
ses.CreatedAt = now
}
if ses.LastSeenAt.IsZero() {
ses.LastSeenAt = ses.CreatedAt
}
_, err := s.db.ExecContext(ctx, s.q.CreateSession,
ses.IDHash, uuidArg(ses.UserID), ses.UserAgent, nullableAddrString(ses.IP),
ses.CreatedAt, ses.LastSeenAt, ses.ExpiresAt)
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *sessionStore) GetSession(ctx context.Context, idHash []byte) (*authkit.Session, error) {
const op = "authkit.sqlstore.SessionStore.GetSession"
var (
ses authkit.Session
uidStr string
ipStr sql.NullString
)
err := s.db.QueryRowContext(ctx, s.q.GetSession, idHash).Scan(
&ses.IDHash, &uidStr, &ses.UserAgent, &ipStr,
&ses.CreatedAt, &ses.LastSeenAt, &ses.ExpiresAt)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrSessionInvalid))
}
uid, err := scanUUID(uidStr)
if err != nil {
return nil, errx.Wrap(op, err)
}
ses.UserID = uid
if ipStr.Valid {
addr, err := scanAddr(&ipStr.String)
if err != nil {
return nil, errx.Wrap(op, err)
}
ses.IP = addr
}
return &ses, nil
}
func (s *sessionStore) TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error {
const op = "authkit.sqlstore.SessionStore.TouchSession"
tag, err := s.db.ExecContext(ctx, s.q.TouchSession, lastSeenAt, newExpiresAt, idHash)
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrSessionInvalid)
}
return nil
}
func (s *sessionStore) DeleteSession(ctx context.Context, idHash []byte) error {
const op = "authkit.sqlstore.SessionStore.DeleteSession"
if _, err := s.db.ExecContext(ctx, s.q.DeleteSession, idHash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *sessionStore) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.sqlstore.SessionStore.DeleteUserSessions"
if _, err := s.db.ExecContext(ctx, s.q.DeleteUserSessions, uuidArg(userID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *sessionStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
const op = "authkit.sqlstore.SessionStore.DeleteExpired"
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredSessions, now)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}

View file

@ -1,59 +0,0 @@
// Package sqlstore provides database/sql-backed implementations of every
// authkit store interface. It works with any driver (pgx-stdlib, lib/pq,
// sqlx wrapping *sql.DB, ...) and any consumer-supplied table naming via
// Schema. Driver-specific behaviour lives in a Dialect; v1 ships
// dialect/postgres.
package sqlstore
import (
"database/sql"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
)
// Stores bundles every store implementation produced by New, ready to drop
// into authkit.Deps.
type Stores struct {
Users authkit.UserStore
Sessions authkit.SessionStore
Tokens authkit.TokenStore
ServiceKeys authkit.ServiceKeyStore
Roles authkit.RoleStore
Permissions authkit.PermissionStore
}
// New validates the Schema, builds the dialect's templated queries, and
// returns store implementations bound to db.
func New(db *sql.DB, dialect Dialect, schema Schema) (*Stores, error) {
const op = "authkit.sqlstore.New"
if db == nil {
return nil, errx.New(op, "db is required")
}
if dialect == nil {
return nil, errx.New(op, "dialect is required")
}
if err := schema.Validate(); err != nil {
return nil, errx.Wrap(op, err)
}
q := dialect.BuildQueries(schema)
base := storeBase{db: db, q: q, d: dialect, s: schema}
return &Stores{
Users: &userStore{storeBase: base},
Sessions: &sessionStore{storeBase: base},
Tokens: &tokenStore{storeBase: base},
ServiceKeys: &serviceKeyStore{storeBase: base},
Roles: &roleStore{storeBase: base},
Permissions: &permissionStore{storeBase: base},
}, nil
}
// storeBase carries the shared dependencies every store needs. Embedded into
// each concrete store struct.
type storeBase struct {
db *sql.DB
q Queries
d Dialect
s Schema
}

View file

@ -1,284 +0,0 @@
package sqlstore_test
// Integration tests against a real Postgres. Skipped unless
// AUTHKIT_TEST_DATABASE_URL is set. Each test acquires a fresh schema by
// running Migrate against a randomly-named set of tables — no cleanup
// fixture, no external Docker dependency.
import (
"context"
"database/sql"
"errors"
"fmt"
"net/netip"
"os"
"testing"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/hasher"
"git.juancwu.dev/juancwu/authkit/sqlstore"
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
"github.com/google/uuid"
_ "github.com/jackc/pgx/v5/stdlib"
)
func envURL(t *testing.T) string {
t.Helper()
url := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
if url == "" {
t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
}
return url
}
// freshDB opens a connection, runs Migrate against the default schema, and
// schedules a teardown that drops every authkit_* table. Tests must run
// sequentially against a single database (the package's tests do, by
// default — go test serialises within a package unless t.Parallel is
// called).
func freshDB(t *testing.T) (*authkit.Auth, *sql.DB, sqlstore.Schema) {
t.Helper()
url := envURL(t)
db, err := sql.Open("pgx", url)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("ping: %v", err)
}
schema := sqlstore.DefaultSchema()
// Drop first to start clean — previous failed tests may have left rows.
dropAuthkitTables(t, db, schema)
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
t.Fatalf("Migrate: %v", err)
}
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
stores, err := sqlstore.New(db, pgdialect.New(), schema)
if err != nil {
t.Fatalf("sqlstore.New: %v", err)
}
auth := authkit.New(authkit.Deps{
Users: stores.Users,
Sessions: stores.Sessions,
Tokens: stores.Tokens,
ServiceKeys: stores.ServiceKeys,
Roles: stores.Roles,
Permissions: stores.Permissions,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
}, authkit.Config{
JWTSecret: []byte("integration-secret-thirty-two!!!!"),
JWTIssuer: "authkit-int",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: 1 * time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,
PasswordResetTTL: time.Hour,
MagicLinkTTL: time.Minute,
})
return auth, db, schema
}
func dropAuthkitTables(t *testing.T, db *sql.DB, s sqlstore.Schema) {
t.Helper()
tables := []string{
s.Tables.UserRoles, s.Tables.RolePermissions,
s.Tables.Roles, s.Tables.Permissions,
s.Tables.ServiceKeys, s.Tables.Tokens,
s.Tables.Sessions, s.Tables.Users,
s.Tables.SchemaMigrations,
}
for _, name := range tables {
_, _ = db.ExecContext(context.Background(),
fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name))
}
}
func TestIntegration_MigrateIdempotent(t *testing.T) {
url := envURL(t)
db, err := sql.Open("pgx", url)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
schema := sqlstore.DefaultSchema()
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
for i := 0; i < 3; i++ {
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
t.Fatalf("Migrate iter %d: %v", i, err)
}
}
}
func TestIntegration_RegisterAndLogin(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
got, err := auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("LoginPassword: %v", err)
}
if got.ID != u.ID {
t.Fatalf("login user id mismatch")
}
if _, err := auth.Register(ctx, "alice@example.com", "x"); !errors.Is(err, authkit.ErrEmailTaken) {
t.Fatalf("expected ErrEmailTaken, got %v", err)
}
if _, err := auth.LoginPassword(ctx, "alice@example.com", "wrong"); !errors.Is(err, authkit.ErrInvalidCredentials) {
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
}
}
func TestIntegration_SessionLifecycle(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "s@s.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
plain, sess, err := auth.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
if sess.ExpiresAt.Before(time.Now()) {
t.Fatalf("session already expired at issue")
}
if _, err := auth.AuthenticateSession(ctx, plain); err != nil {
t.Fatalf("AuthenticateSession: %v", err)
}
if err := auth.RevokeSession(ctx, plain); err != nil {
t.Fatalf("RevokeSession: %v", err)
}
if _, err := auth.AuthenticateSession(ctx, plain); !errors.Is(err, authkit.ErrSessionInvalid) {
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
}
}
func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "j@j.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
_, refresh1, err := auth.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
_, refresh2, err := auth.RefreshJWT(ctx, refresh1)
if err != nil {
t.Fatalf("first RefreshJWT: %v", err)
}
if refresh1 == refresh2 {
t.Fatalf("refresh token did not rotate")
}
if _, _, err := auth.RefreshJWT(ctx, refresh1); !errors.Is(err, authkit.ErrTokenReused) {
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
}
if _, _, err := auth.RefreshJWT(ctx, refresh2); !errors.Is(err, authkit.ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err)
}
}
func TestIntegration_ServiceKeyFlow(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
appA := uuid.New()
appB := uuid.New()
plainA1, _, err := auth.IssueServiceKey(ctx, "application", appA, "events-ingest",
[]string{"events:write"}, nil)
if err != nil {
t.Fatalf("IssueServiceKey appA #1: %v", err)
}
if _, _, err := auth.IssueServiceKey(ctx, "application", appA, "events-ingest-2", nil, nil); err != nil {
t.Fatalf("IssueServiceKey appA #2: %v", err)
}
if _, _, err := auth.IssueServiceKey(ctx, "application", appB, "billing", nil, nil); err != nil {
t.Fatalf("IssueServiceKey appB: %v", err)
}
got, err := auth.AuthenticateServiceKey(ctx, plainA1)
if err != nil {
t.Fatalf("AuthenticateServiceKey: %v", err)
}
if got.OwnerKind != "application" || got.OwnerID != appA {
t.Fatalf("owner mismatch: kind=%q id=%v", got.OwnerKind, got.OwnerID)
}
if len(got.Abilities) != 1 || got.Abilities[0] != "events:write" {
t.Fatalf("abilities mismatch: %+v", got.Abilities)
}
listA, err := auth.ListServiceKeys(ctx, "application", appA)
if err != nil {
t.Fatalf("ListServiceKeys appA: %v", err)
}
if len(listA) != 2 {
t.Fatalf("ListServiceKeys appA = %d, want 2", len(listA))
}
listB, err := auth.ListServiceKeys(ctx, "application", appB)
if err != nil {
t.Fatalf("ListServiceKeys appB: %v", err)
}
if len(listB) != 1 {
t.Fatalf("ListServiceKeys appB = %d, want 1", len(listB))
}
if err := auth.RevokeServiceKey(ctx, plainA1); err != nil {
t.Fatalf("RevokeServiceKey: %v", err)
}
if _, err := auth.AuthenticateServiceKey(ctx, plainA1); !errors.Is(err, authkit.ErrServiceKeyInvalid) {
t.Fatalf("expected ErrServiceKeyInvalid post-revoke, got %v", err)
}
}
func TestIntegration_RBAC(t *testing.T) {
auth, db, schema := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "rb@b.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
r := &authkit.Role{Name: "editor"}
if err := stores.Roles.CreateRole(ctx, r); err != nil {
t.Fatalf("CreateRole: %v", err)
}
p := &authkit.Permission{Name: "posts:write"}
if err := stores.Permissions.CreatePermission(ctx, p); err != nil {
t.Fatalf("CreatePermission: %v", err)
}
if err := stores.Permissions.AssignPermissionToRole(ctx, r.ID, p.ID); err != nil {
t.Fatalf("AssignPermissionToRole: %v", err)
}
if err := auth.AssignRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("AssignRole: %v", err)
}
ok, err := auth.HasPermission(ctx, u.ID, "posts:write")
if err != nil || !ok {
t.Fatalf("HasPermission: %v %v", ok, err)
}
ok, err = auth.HasAnyRole(ctx, u.ID, []string{"editor", "admin"})
if err != nil || !ok {
t.Fatalf("HasAnyRole: %v %v", ok, err)
}
if err := auth.RemoveRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("RemoveRole: %v", err)
}
ok, _ = auth.HasPermission(ctx, u.ID, "posts:write")
if ok {
t.Fatalf("HasPermission should be false after RemoveRole")
}
}

View file

@ -1,92 +0,0 @@
package sqlstore
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
)
type tokenStore struct{ storeBase }
func (s *tokenStore) CreateToken(ctx context.Context, t *authkit.Token) error {
const op = "authkit.sqlstore.TokenStore.CreateToken"
if t.CreatedAt.IsZero() {
t.CreatedAt = time.Now().UTC()
}
_, err := s.db.ExecContext(ctx, s.q.CreateToken,
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
nullableTime(t.ConsumedAt), t.CreatedAt, t.ExpiresAt)
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
// ConsumeToken does the find-and-mark-consumed in a single statement so two
// concurrent callers cannot both successfully consume the same token. The
// row is returned for inspection (e.g. ChainID for refresh rotation).
func (s *tokenStore) ConsumeToken(ctx context.Context, kind authkit.TokenKind, hash []byte, now time.Time) (*authkit.Token, error) {
const op = "authkit.sqlstore.TokenStore.ConsumeToken"
row := s.db.QueryRowContext(ctx, s.q.ConsumeToken, now, string(kind), hash, now)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
}
return t, nil
}
func (s *tokenStore) GetToken(ctx context.Context, kind authkit.TokenKind, hash []byte) (*authkit.Token, error) {
const op = "authkit.sqlstore.TokenStore.GetToken"
row := s.db.QueryRowContext(ctx, s.q.GetToken, string(kind), hash)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
}
return t, nil
}
func (s *tokenStore) DeleteByChain(ctx context.Context, chainID string) (int64, error) {
const op = "authkit.sqlstore.TokenStore.DeleteByChain"
tag, err := s.db.ExecContext(ctx, s.q.DeleteByChain, chainID)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}
func (s *tokenStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
const op = "authkit.sqlstore.TokenStore.DeleteExpired"
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredTokens, now)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}
func scanToken(row rowScanner) (*authkit.Token, error) {
var (
t authkit.Token
kind string
userIDStr string
chainID sql.NullString
consumedAt sql.NullTime
)
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID,
&consumedAt, &t.CreatedAt, &t.ExpiresAt); err != nil {
return nil, err
}
t.Kind = authkit.TokenKind(kind)
uid, err := scanUUID(userIDStr)
if err != nil {
return nil, err
}
t.UserID = uid
t.ChainID = scanNullStringPtr(chainID)
t.ConsumedAt = scanNullTimePtr(consumedAt)
return &t, nil
}

View file

@ -1,186 +0,0 @@
package sqlstore
import (
"context"
"database/sql"
"errors"
"strings"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type userStore struct{ storeBase }
func (s *userStore) CreateUser(ctx context.Context, u *authkit.User) error {
const op = "authkit.sqlstore.UserStore.CreateUser"
if u.ID == uuid.Nil {
u.ID = uuid.New()
}
if u.EmailNormalized == "" {
u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email))
}
now := time.Now().UTC()
if u.CreatedAt.IsZero() {
u.CreatedAt = now
}
if u.UpdatedAt.IsZero() {
u.UpdatedAt = now
}
_, err := s.db.ExecContext(ctx, s.q.CreateUser,
uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt)
if err != nil {
if s.d.IsUniqueViolation(err) {
return errx.Wrap(op, authkit.ErrEmailTaken)
}
return errx.Wrap(op, err)
}
return nil
}
func (s *userStore) GetUserByID(ctx context.Context, id uuid.UUID) (*authkit.User, error) {
const op = "authkit.sqlstore.UserStore.GetUserByID"
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return u, nil
}
func (s *userStore) GetUserByEmail(ctx context.Context, normalizedEmail string) (*authkit.User, error) {
const op = "authkit.sqlstore.UserStore.GetUserByEmail"
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByEmail, normalizedEmail))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return u, nil
}
func (s *userStore) UpdateUser(ctx context.Context, u *authkit.User) error {
const op = "authkit.sqlstore.UserStore.UpdateUser"
u.UpdatedAt = time.Now().UTC()
tag, err := s.db.ExecContext(ctx, s.q.UpdateUser,
u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
const op = "authkit.sqlstore.UserStore.DeleteUser"
tag, err := s.db.ExecContext(ctx, s.q.DeleteUser, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error {
const op = "authkit.sqlstore.UserStore.SetPassword"
tag, err := s.db.ExecContext(ctx, s.q.SetPassword,
nullableString(encodedHash), time.Now().UTC(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error {
const op = "authkit.sqlstore.UserStore.SetEmailVerified"
tag, err := s.db.ExecContext(ctx, s.q.SetEmailVerified, at, time.Now().UTC(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) {
const op = "authkit.sqlstore.UserStore.BumpSessionVersion"
var v int
if err := s.db.QueryRowContext(ctx, s.q.BumpSessionVersion,
time.Now().UTC(), uuidArg(userID)).Scan(&v); err != nil {
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return v, nil
}
func (s *userStore) IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) {
const op = "authkit.sqlstore.UserStore.IncrementFailedLogins"
var n int
if err := s.db.QueryRowContext(ctx, s.q.IncrementFailedLogins,
time.Now().UTC(), uuidArg(userID)).Scan(&n); err != nil {
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return n, nil
}
func (s *userStore) ResetFailedLogins(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.sqlstore.UserStore.ResetFailedLogins"
tag, err := s.db.ExecContext(ctx, s.q.ResetFailedLogins, time.Now().UTC(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func scanUser(row rowScanner) (*authkit.User, error) {
var (
u authkit.User
idStr string
passwordHash sql.NullString
emailVerified sql.NullTime
lastLogin sql.NullTime
)
if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified,
&passwordHash, &u.SessionVersion, &u.FailedLogins, &lastLogin,
&u.CreatedAt, &u.UpdatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
u.ID = id
if passwordHash.Valid {
u.PasswordHash = passwordHash.String
}
u.EmailVerifiedAt = scanNullTimePtr(emailVerified)
u.LastLoginAt = scanNullTimePtr(lastLogin)
return &u, nil
}
// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so
// callers get reliable errors.Is targets through errx wrapping.
func mapNotFound(err error, sentinel error) error {
if errors.Is(err, sql.ErrNoRows) {
return sentinel
}
return err
}

94
store_abilities.go Normal file
View file

@ -0,0 +1,94 @@
package authkit
import (
"context"
"database/sql"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
func (a *Auth) storeCreateAbility(ctx context.Context, ab *Ability) error {
const op = "authkit.storeCreateAbility"
if ab.ID == uuid.Nil {
ab.ID = uuid.New()
}
if ab.CreatedAt.IsZero() {
ab.CreatedAt = a.now()
}
if _, err := a.db.ExecContext(ctx, a.q.createAbility,
uuidArg(ab.ID), ab.Slug, nullableLabel(ab.Label), ab.CreatedAt); err != nil {
if isUniqueViolation(err) {
return errx.Wrap(op, ErrSlugTaken)
}
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetAbilityByID(ctx context.Context, id uuid.UUID) (*Ability, error) {
const op = "authkit.storeGetAbilityByID"
ab, err := scanAbility(a.db.QueryRowContext(ctx, a.q.getAbilityByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrAbilityNotFound))
}
return ab, nil
}
func (a *Auth) storeGetAbilityBySlug(ctx context.Context, slug string) (*Ability, error) {
const op = "authkit.storeGetAbilityBySlug"
ab, err := scanAbility(a.db.QueryRowContext(ctx, a.q.getAbilityBySlug, slug))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrAbilityNotFound))
}
return ab, nil
}
func (a *Auth) storeListAbilities(ctx context.Context) ([]*Ability, error) {
const op = "authkit.storeListAbilities"
rows, err := a.db.QueryContext(ctx, a.q.listAbilities)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*Ability
for rows.Next() {
ab, err := scanAbility(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, ab)
}
return out, errx.Wrap(op, rows.Err())
}
func (a *Auth) storeDeleteAbility(ctx context.Context, id uuid.UUID) error {
const op = "authkit.storeDeleteAbility"
tag, err := a.db.ExecContext(ctx, a.q.deleteAbility, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrAbilityNotFound)
}
return nil
}
func scanAbility(row rowScanner) (*Ability, error) {
var (
ab Ability
idStr string
label sql.NullString
)
if err := row.Scan(&idStr, &ab.Slug, &label, &ab.CreatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
ab.ID = id
ab.Label = scanNullString(label)
return &ab, nil
}

40
store_errors.go Normal file
View file

@ -0,0 +1,40 @@
package authkit
import (
"errors"
"github.com/jackc/pgx/v5/pgconn"
)
const (
pgUniqueViolation = "23505"
pgUndefinedTable = "42P01"
)
// isUniqueViolation reports whether err is a Postgres unique-violation,
// regardless of which driver is registered.
func isUniqueViolation(err error) bool {
return matchesSQLState(err, pgUniqueViolation)
}
// isMissingTable reports whether err is a Postgres undefined_table error.
// Used to distinguish "schema not yet bootstrapped" from real failures.
func isMissingTable(err error) bool {
return matchesSQLState(err, pgUndefinedTable)
}
func matchesSQLState(err error, code string) bool {
if err == nil {
return false
}
var pgxErr *pgconn.PgError
if errors.As(err, &pgxErr) {
return pgxErr.Code == code
}
type sqlStater interface{ SQLState() string }
var s sqlStater
if errors.As(err, &s) {
return s.SQLState() == code
}
return false
}

122
store_migrate.go Normal file
View file

@ -0,0 +1,122 @@
package authkit
import (
"context"
"database/sql"
"embed"
"io/fs"
"log"
"sort"
"strings"
"git.juancwu.dev/juancwu/errx"
)
//go:embed migrations/*.sql
var migrationsFS embed.FS
// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64.
// Stable across rollouts and unlikely to clash with caller advisory locks.
const advisoryLockKey int64 = 0x617574686b6974
// Migrate applies every embedded migration not yet recorded in the
// schema-migrations table. Safe to call repeatedly and concurrently across
// processes; the advisory lock serialises rollouts. Each migration owns its
// own BEGIN/COMMIT.
//
// Embedded migrations hard-code the default authkit_* names. If the consumer
// has overridden any table name, Migrate is a no-op and the consumer is
// responsible for managing DDL out-of-band.
func Migrate(ctx context.Context, db *sql.DB, schema Schema) error {
const op = "authkit.Migrate"
if db == nil {
return errx.New(op, "db is required")
}
if err := schema.Validate(); err != nil {
return errx.Wrap(op, err)
}
if !schema.isDefault() {
// Custom-named schemas: consumer owns DDL. The verifier still runs
// against the configured names (with default-name fallback) to
// confirm the tables exist and match the expected layout.
return nil
}
conn, err := db.Conn(ctx)
if err != nil {
return errx.Wrap(op, err)
}
defer conn.Close()
if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil {
return errx.Wrap(op, err)
}
defer func() {
if _, err := conn.ExecContext(context.Background(),
"SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil {
log.Printf("authkit: pg_advisory_unlock failed: %v", err)
}
}()
q := buildQueries(schema.Tables)
if _, err := conn.ExecContext(ctx, q.createMigrationsTable); err != nil {
return errx.Wrap(op, err)
}
applied, err := loadAppliedVersions(ctx, conn, q.selectAppliedVersions)
if err != nil {
return errx.Wrap(op, err)
}
migs, err := fs.Sub(migrationsFS, "migrations")
if err != nil {
return errx.Wrap(op, err)
}
files, err := fs.ReadDir(migs, ".")
if err != nil {
return errx.Wrap(op, err)
}
names := make([]string, 0, len(files))
for _, f := range files {
if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
names = append(names, f.Name())
}
}
sort.Strings(names)
for _, name := range names {
version := strings.TrimSuffix(name, ".sql")
if _, ok := applied[version]; ok {
continue
}
body, err := fs.ReadFile(migs, name)
if err != nil {
return errx.Wrapf(op, err, "read %s", name)
}
if _, err := conn.ExecContext(ctx, string(body)); err != nil {
return errx.Wrapf(op, err, "apply %s", version)
}
}
return nil
}
func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) {
rows, err := conn.QueryContext(ctx, q)
if err != nil {
if isMissingTable(err) {
return map[string]struct{}{}, nil
}
return nil, err
}
defer rows.Close()
out := make(map[string]struct{})
for rows.Next() {
var v string
if err := rows.Scan(&v); err != nil {
return nil, err
}
out[v] = struct{}{}
}
return out, rows.Err()
}

166
store_permissions.go Normal file
View file

@ -0,0 +1,166 @@
package authkit
import (
"context"
"database/sql"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
func (a *Auth) storeCreatePermission(ctx context.Context, p *Permission) error {
const op = "authkit.storeCreatePermission"
if p.ID == uuid.Nil {
p.ID = uuid.New()
}
if p.CreatedAt.IsZero() {
p.CreatedAt = a.now()
}
if _, err := a.db.ExecContext(ctx, a.q.createPermission,
uuidArg(p.ID), p.Slug, nullableLabel(p.Label), p.CreatedAt); err != nil {
if isUniqueViolation(err) {
return errx.Wrap(op, ErrSlugTaken)
}
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error) {
const op = "authkit.storeGetPermissionByID"
p, err := scanPermission(a.db.QueryRowContext(ctx, a.q.getPermissionByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrPermissionNotFound))
}
return p, nil
}
func (a *Auth) storeGetPermissionBySlug(ctx context.Context, slug string) (*Permission, error) {
const op = "authkit.storeGetPermissionBySlug"
p, err := scanPermission(a.db.QueryRowContext(ctx, a.q.getPermissionBySlug, slug))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrPermissionNotFound))
}
return p, nil
}
func (a *Auth) storeListPermissions(ctx context.Context) ([]*Permission, error) {
const op = "authkit.storeListPermissions"
rows, err := a.db.QueryContext(ctx, a.q.listPermissions)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func (a *Auth) storeDeletePermission(ctx context.Context, id uuid.UUID) error {
const op = "authkit.storeDeletePermission"
tag, err := a.db.ExecContext(ctx, a.q.deletePermission, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrPermissionNotFound)
}
return nil
}
func (a *Auth) storeAssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error {
const op = "authkit.storeAssignPermissionToRole"
if _, err := a.db.ExecContext(ctx, a.q.assignPermissionToRole,
uuidArg(roleID), uuidArg(permID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeRemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error {
const op = "authkit.storeRemovePermissionFromRole"
if _, err := a.db.ExecContext(ctx, a.q.removePermissionFromRole,
uuidArg(roleID), uuidArg(permID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error) {
const op = "authkit.storeGetRolePermissions"
rows, err := a.db.QueryContext(ctx, a.q.getRolePermissions, uuidArg(roleID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func (a *Auth) storeGetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error) {
const op = "authkit.storeGetUserPermissions"
rows, err := a.db.QueryContext(ctx, a.q.getUserPermissions, uuidArg(userID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func (a *Auth) storeGrantPermissionToUser(ctx context.Context, userID, permID uuid.UUID) error {
const op = "authkit.storeGrantPermissionToUser"
if _, err := a.db.ExecContext(ctx, a.q.grantPermissionToUser,
uuidArg(userID), uuidArg(permID), a.now()); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeRevokePermissionFromUser(ctx context.Context, userID, permID uuid.UUID) error {
const op = "authkit.storeRevokePermissionFromUser"
if _, err := a.db.ExecContext(ctx, a.q.revokePermissionFromUser,
uuidArg(userID), uuidArg(permID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func scanPermission(row rowScanner) (*Permission, error) {
var (
p Permission
idStr string
label sql.NullString
)
if err := row.Scan(&idStr, &p.Slug, &label, &p.CreatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
p.ID = id
p.Label = scanNullString(label)
return &p, nil
}

249
store_queries.go Normal file
View file

@ -0,0 +1,249 @@
package authkit
import (
"fmt"
"strings"
)
// queries holds every SQL string the store issues, with table identifiers
// already substituted from a validated Schema. Built once at New() to avoid
// per-call concatenation. Identifiers are interpolated via concatenation,
// safe because Schema.Validate gated them through identifierRE.
type queries struct {
// users
createUser string
getUserByID string
getUserByEmail string
updateUser string
deleteUser string
setPassword string
setEmailVerified string
bumpSessionVersion string
// sessions
createSession string
getSession string
touchSession string
deleteSession string
deleteUserSessions string
deleteExpiredSessions string
// tokens
createToken string
consumeToken string
getToken string
getOTPForUser string
decrementOTPAttempt string
consumeOTPByID string
deleteByChain string
deleteExpiredTokens string
// service keys
createServiceKey string
getServiceKey string
listServiceKeys string
touchServiceKey string
revokeServiceKey string
getServiceKeyAbilities string
insertServiceKeyAbil string
// roles
createRole string
getRoleByID string
getRoleBySlug string
listRoles string
deleteRole string
assignRoleToUser string
removeRoleFromUser string
getUserRoles string
hasAnyRolePrefix string
// permissions
createPermission string
getPermissionByID string
getPermissionBySlug string
listPermissions string
deletePermission string
assignPermissionToRole string
removePermissionFromRole string
getRolePermissions string
getUserPermissions string
// direct user permissions
grantPermissionToUser string
revokePermissionFromUser string
// abilities
createAbility string
getAbilityByID string
getAbilityBySlug string
listAbilities string
deleteAbility string
// migrations
createMigrationsTable string
selectAppliedVersions string
}
func buildQueries(t Tables) queries {
return queries{
// users
createUser: `INSERT INTO ` + t.Users + `
(id, email, email_normalized, email_verified_at, password_hash,
session_version, last_login_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
getUserByID: `SELECT id, email, email_normalized, email_verified_at,
password_hash, session_version, last_login_at, created_at, updated_at
FROM ` + t.Users + ` WHERE id = $1`,
getUserByEmail: `SELECT id, email, email_normalized, email_verified_at,
password_hash, session_version, last_login_at, created_at, updated_at
FROM ` + t.Users + ` WHERE email_normalized = $1`,
updateUser: `UPDATE ` + t.Users + ` SET
email = $1, email_normalized = $2, email_verified_at = $3,
password_hash = $4, session_version = $5,
last_login_at = $6, updated_at = $7
WHERE id = $8`,
deleteUser: `DELETE FROM ` + t.Users + ` WHERE id = $1`,
setPassword: `UPDATE ` + t.Users + ` SET password_hash = $1, updated_at = $2 WHERE id = $3`,
setEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = $1, updated_at = $2 WHERE id = $3`,
bumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = $1 WHERE id = $2 RETURNING session_version`,
// sessions
createSession: `INSERT INTO ` + t.Sessions + `
(id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
getSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at
FROM ` + t.Sessions + ` WHERE id_hash = $1`,
touchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = $1, expires_at = $2 WHERE id_hash = $3`,
deleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = $1`,
deleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = $1`,
deleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= $1`,
// tokens
createToken: `INSERT INTO ` + t.Tokens + `
(hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
consumeToken: `UPDATE ` + t.Tokens + `
SET consumed_at = $1
WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $4
RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`,
getToken: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at
FROM ` + t.Tokens + ` WHERE kind = $1 AND hash = $2`,
// getOTPForUser returns the most recent unconsumed, unexpired OTP for
// the user, used to verify a code by hash-comparing client input.
getOTPForUser: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at
FROM ` + t.Tokens + `
WHERE kind = $1 AND user_id = $2 AND consumed_at IS NULL AND expires_at > $3
ORDER BY created_at DESC LIMIT 1`,
// decrementOTPAttempt drops attempts_remaining by 1 and consumes the
// row when it hits zero. Used after a wrong-code submission.
decrementOTPAttempt: `UPDATE ` + t.Tokens + `
SET attempts_remaining = GREATEST(COALESCE(attempts_remaining, 0) - 1, 0),
consumed_at = CASE WHEN COALESCE(attempts_remaining, 0) - 1 <= 0 THEN $1 ELSE consumed_at END
WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $1
RETURNING attempts_remaining`,
// consumeOTPByID is the success path: mark the matched OTP consumed.
consumeOTPByID: `UPDATE ` + t.Tokens + `
SET consumed_at = $1
WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $1
RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`,
deleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = $1`,
deleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= $1`,
// service keys
createServiceKey: `INSERT INTO ` + t.ServiceKeys + `
(id_hash, name, last_used_at, created_at, expires_at, revoked_at)
VALUES ($1, $2, $3, $4, $5, $6)`,
getServiceKey: `SELECT id_hash, name, last_used_at, created_at, expires_at, revoked_at
FROM ` + t.ServiceKeys + ` WHERE id_hash = $1`,
listServiceKeys: `SELECT id_hash, name, last_used_at, created_at, expires_at, revoked_at
FROM ` + t.ServiceKeys + ` ORDER BY created_at DESC`,
touchServiceKey: `UPDATE ` + t.ServiceKeys + ` SET last_used_at = $1 WHERE id_hash = $2`,
revokeServiceKey: `UPDATE ` + t.ServiceKeys + ` SET revoked_at = $1 WHERE id_hash = $2 AND revoked_at IS NULL`,
getServiceKeyAbilities: `SELECT a.slug FROM ` + t.Abilities + ` a
JOIN ` + t.ServiceKeyAbilities + ` ska ON ska.ability_id = a.id
WHERE ska.service_key_id_hash = $1 ORDER BY a.slug`,
insertServiceKeyAbil: `INSERT INTO ` + t.ServiceKeyAbilities + `
(service_key_id_hash, ability_id, granted_at) VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING`,
// roles
createRole: `INSERT INTO ` + t.Roles + ` (id, slug, label, created_at) VALUES ($1, $2, $3, $4)`,
getRoleByID: `SELECT id, slug, label, created_at FROM ` + t.Roles + ` WHERE id = $1`,
getRoleBySlug: `SELECT id, slug, label, created_at FROM ` + t.Roles + ` WHERE slug = $1`,
listRoles: `SELECT id, slug, label, created_at FROM ` + t.Roles + ` ORDER BY slug`,
deleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = $1`,
assignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at)
VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`,
removeRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = $1 AND role_id = $2`,
getUserRoles: `SELECT r.id, r.slug, r.label, r.created_at
FROM ` + t.Roles + ` r JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id
WHERE ur.user_id = $1 ORDER BY r.slug`,
hasAnyRolePrefix: `SELECT EXISTS (
SELECT 1 FROM ` + t.UserRoles + ` ur JOIN ` + t.Roles + ` r ON r.id = ur.role_id
WHERE ur.user_id = $1 AND r.slug IN (`,
// permissions
createPermission: `INSERT INTO ` + t.Permissions + ` (id, slug, label, created_at) VALUES ($1, $2, $3, $4)`,
getPermissionByID: `SELECT id, slug, label, created_at FROM ` + t.Permissions + ` WHERE id = $1`,
getPermissionBySlug: `SELECT id, slug, label, created_at FROM ` + t.Permissions + ` WHERE slug = $1`,
listPermissions: `SELECT id, slug, label, created_at FROM ` + t.Permissions + ` ORDER BY slug`,
deletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = $1`,
assignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id)
VALUES ($1, $2) ON CONFLICT DO NOTHING`,
removePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = $1 AND permission_id = $2`,
getRolePermissions: `SELECT p.id, p.slug, p.label, p.created_at
FROM ` + t.Permissions + ` p JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
WHERE rp.role_id = $1 ORDER BY p.slug`,
// UNION of role-derived and direct user permissions.
getUserPermissions: `SELECT DISTINCT p.id, p.slug, p.label, p.created_at FROM ` + t.Permissions + ` p
WHERE p.id IN (
SELECT rp.permission_id FROM ` + t.RolePermissions + ` rp
JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id
WHERE ur.user_id = $1
UNION
SELECT up.permission_id FROM ` + t.UserPermissions + ` up
WHERE up.user_id = $1
) ORDER BY p.slug`,
// direct user permissions
grantPermissionToUser: `INSERT INTO ` + t.UserPermissions + `
(user_id, permission_id, granted_at) VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING`,
revokePermissionFromUser: `DELETE FROM ` + t.UserPermissions + `
WHERE user_id = $1 AND permission_id = $2`,
// abilities
createAbility: `INSERT INTO ` + t.Abilities + ` (id, slug, label, created_at) VALUES ($1, $2, $3, $4)`,
getAbilityByID: `SELECT id, slug, label, created_at FROM ` + t.Abilities + ` WHERE id = $1`,
getAbilityBySlug: `SELECT id, slug, label, created_at FROM ` + t.Abilities + ` WHERE slug = $1`,
listAbilities: `SELECT id, slug, label, created_at FROM ` + t.Abilities + ` ORDER BY slug`,
deleteAbility: `DELETE FROM ` + t.Abilities + ` WHERE id = $1`,
// migrations
createMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL
)`,
selectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations,
}
}
// hasAnyRoleSQL renders the dynamic IN-clause for HasAnyRole. Generated
// query is parameterized: $1 = user_id, $2..$N+1 = role slugs.
func (q queries) hasAnyRoleSQL(n int) string {
if n <= 0 {
return ""
}
var b strings.Builder
b.Grow(len(q.hasAnyRolePrefix) + 8*n + 4)
b.WriteString(q.hasAnyRolePrefix)
for i := 0; i < n; i++ {
if i > 0 {
b.WriteByte(',')
}
fmt.Fprintf(&b, "$%d", i+2)
}
b.WriteString("))")
return b.String()
}

151
store_roles.go Normal file
View file

@ -0,0 +1,151 @@
package authkit
import (
"context"
"database/sql"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
func (a *Auth) storeCreateRole(ctx context.Context, r *Role) error {
const op = "authkit.storeCreateRole"
if r.ID == uuid.Nil {
r.ID = uuid.New()
}
if r.CreatedAt.IsZero() {
r.CreatedAt = a.now()
}
if _, err := a.db.ExecContext(ctx, a.q.createRole,
uuidArg(r.ID), r.Slug, nullableLabel(r.Label), r.CreatedAt); err != nil {
if isUniqueViolation(err) {
return errx.Wrap(op, ErrSlugTaken)
}
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error) {
const op = "authkit.storeGetRoleByID"
r, err := scanRole(a.db.QueryRowContext(ctx, a.q.getRoleByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrRoleNotFound))
}
return r, nil
}
func (a *Auth) storeGetRoleBySlug(ctx context.Context, slug string) (*Role, error) {
const op = "authkit.storeGetRoleBySlug"
r, err := scanRole(a.db.QueryRowContext(ctx, a.q.getRoleBySlug, slug))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrRoleNotFound))
}
return r, nil
}
func (a *Auth) storeListRoles(ctx context.Context) ([]*Role, error) {
const op = "authkit.storeListRoles"
rows, err := a.db.QueryContext(ctx, a.q.listRoles)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*Role
for rows.Next() {
r, err := scanRole(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, r)
}
return out, errx.Wrap(op, rows.Err())
}
func (a *Auth) storeDeleteRole(ctx context.Context, id uuid.UUID) error {
const op = "authkit.storeDeleteRole"
tag, err := a.db.ExecContext(ctx, a.q.deleteRole, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrRoleNotFound)
}
return nil
}
func (a *Auth) storeAssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error {
const op = "authkit.storeAssignRoleToUser"
if _, err := a.db.ExecContext(ctx, a.q.assignRoleToUser,
uuidArg(userID), uuidArg(roleID), a.now()); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeRemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error {
const op = "authkit.storeRemoveRoleFromUser"
if _, err := a.db.ExecContext(ctx, a.q.removeRoleFromUser,
uuidArg(userID), uuidArg(roleID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error) {
const op = "authkit.storeGetUserRoles"
rows, err := a.db.QueryContext(ctx, a.q.getUserRoles, uuidArg(userID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*Role
for rows.Next() {
r, err := scanRole(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, r)
}
return out, errx.Wrap(op, rows.Err())
}
// storeHasAnyRole builds the IN-clause at call time because the placeholder
// count depends on len(slugs). Identifier substitution comes from the
// validated Schema; values are bound, never interpolated.
func (a *Auth) storeHasAnyRole(ctx context.Context, userID uuid.UUID, slugs []string) (bool, error) {
const op = "authkit.storeHasAnyRole"
if len(slugs) == 0 {
return false, nil
}
q := a.q.hasAnyRoleSQL(len(slugs))
args := make([]any, 0, 1+len(slugs))
args = append(args, uuidArg(userID))
for _, s := range slugs {
args = append(args, s)
}
var ok bool
if err := a.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
func scanRole(row rowScanner) (*Role, error) {
var (
r Role
idStr string
label sql.NullString
)
if err := row.Scan(&idStr, &r.Slug, &label, &r.CreatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
r.ID = id
r.Label = scanNullString(label)
return &r, nil
}

101
store_scan.go Normal file
View file

@ -0,0 +1,101 @@
package authkit
import (
"database/sql"
"net/netip"
"time"
"github.com/google/uuid"
)
// rowScanner is satisfied by both *sql.Row and *sql.Rows so scan helpers
// serve QueryRow and Query loops uniformly.
type rowScanner interface {
Scan(dest ...any) error
}
func nullableTime(t *time.Time) any {
if t == nil {
return nil
}
return *t
}
func nullableString(s string) any {
if s == "" {
return nil
}
return s
}
func nullableLabel(s string) any {
// Labels are user-facing display strings — empty string means "no label",
// which we store NULL for clarity.
if s == "" {
return nil
}
return s
}
func nullableAddrString(a netip.Addr) any {
if !a.IsValid() {
return nil
}
return a.String()
}
func scanAddr(s *string) (netip.Addr, error) {
if s == nil || *s == "" {
return netip.Addr{}, nil
}
return netip.ParseAddr(*s)
}
func uuidArg(id uuid.UUID) any { return id.String() }
func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) }
func chainArg(c *string) any {
if c == nil {
return nil
}
return *c
}
func nullableInt(n *int) any {
if n == nil {
return nil
}
return *n
}
func scanNullStringPtr(ns sql.NullString) *string {
if !ns.Valid {
return nil
}
v := ns.String
return &v
}
func scanNullString(ns sql.NullString) string {
if !ns.Valid {
return ""
}
return ns.String
}
func scanNullTimePtr(nt sql.NullTime) *time.Time {
if !nt.Valid {
return nil
}
t := nt.Time
return &t
}
func scanNullIntPtr(ni sql.NullInt32) *int {
if !ni.Valid {
return nil
}
v := int(ni.Int32)
return &v
}

140
store_schema.go Normal file
View file

@ -0,0 +1,140 @@
package authkit
import (
"regexp"
"git.juancwu.dev/juancwu/errx"
)
// Schema lets consumers map authkit storage to their own table names. Column
// overrides are not exposed in v1 — the column set is fixed.
type Schema struct {
Tables Tables
}
// Tables holds per-table identifier overrides. Every field must be a valid
// unquoted SQL identifier (matching identifierRE). Validation runs at
// New()/Migrate() time so SQL injection through Schema is impossible past
// that gate.
type Tables struct {
Users string
Sessions string
Tokens string
ServiceKeys string
ServiceKeyAbilities string
Roles string
Permissions string
Abilities string
UserRoles string
UserPermissions string
RolePermissions string
SchemaMigrations string
}
// DefaultSchema returns the stock authkit_* names matching the embedded
// migration files.
func DefaultSchema() Schema {
return Schema{Tables: defaultTables()}
}
func defaultTables() Tables {
return Tables{
Users: "authkit_users",
Sessions: "authkit_sessions",
Tokens: "authkit_tokens",
ServiceKeys: "authkit_service_keys",
ServiceKeyAbilities: "authkit_service_key_abilities",
Roles: "authkit_roles",
Permissions: "authkit_permissions",
Abilities: "authkit_abilities",
UserRoles: "authkit_user_roles",
UserPermissions: "authkit_user_permissions",
RolePermissions: "authkit_role_permissions",
SchemaMigrations: "authkit_schema_migrations",
}
}
// identifierRE matches the safe ASCII identifier subset shared by Postgres
// when not quoted. Anything outside this set is rejected at validation time.
var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// Validate ensures every Schema.Tables field is a non-empty, safe identifier.
func (s Schema) Validate() error {
const op = "authkit.Schema.Validate"
checks := []struct {
field, value string
}{
{"Users", s.Tables.Users},
{"Sessions", s.Tables.Sessions},
{"Tokens", s.Tables.Tokens},
{"ServiceKeys", s.Tables.ServiceKeys},
{"ServiceKeyAbilities", s.Tables.ServiceKeyAbilities},
{"Roles", s.Tables.Roles},
{"Permissions", s.Tables.Permissions},
{"Abilities", s.Tables.Abilities},
{"UserRoles", s.Tables.UserRoles},
{"UserPermissions", s.Tables.UserPermissions},
{"RolePermissions", s.Tables.RolePermissions},
{"SchemaMigrations", s.Tables.SchemaMigrations},
}
for _, c := range checks {
if c.value == "" {
return errx.Newf(op, "Schema.Tables.%s is empty", c.field)
}
if !identifierRE.MatchString(c.value) {
return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value)
}
}
return nil
}
// isDefault reports whether every Schema.Tables entry matches the default
// names. Embedded migrations hard-code the defaults, so they only run
// unmodified against the default schema.
func (s Schema) isDefault() bool {
return s.Tables == defaultTables()
}
// mergeSchemaDefaults fills in any blank Tables field from the default set
// so callers can override one or two table names without having to copy
// the whole DefaultSchema structure.
func mergeSchemaDefaults(s Schema) Schema {
def := defaultTables()
if s.Tables.Users == "" {
s.Tables.Users = def.Users
}
if s.Tables.Sessions == "" {
s.Tables.Sessions = def.Sessions
}
if s.Tables.Tokens == "" {
s.Tables.Tokens = def.Tokens
}
if s.Tables.ServiceKeys == "" {
s.Tables.ServiceKeys = def.ServiceKeys
}
if s.Tables.ServiceKeyAbilities == "" {
s.Tables.ServiceKeyAbilities = def.ServiceKeyAbilities
}
if s.Tables.Roles == "" {
s.Tables.Roles = def.Roles
}
if s.Tables.Permissions == "" {
s.Tables.Permissions = def.Permissions
}
if s.Tables.Abilities == "" {
s.Tables.Abilities = def.Abilities
}
if s.Tables.UserRoles == "" {
s.Tables.UserRoles = def.UserRoles
}
if s.Tables.UserPermissions == "" {
s.Tables.UserPermissions = def.UserPermissions
}
if s.Tables.RolePermissions == "" {
s.Tables.RolePermissions = def.RolePermissions
}
if s.Tables.SchemaMigrations == "" {
s.Tables.SchemaMigrations = def.SchemaMigrations
}
return s
}

136
store_service_keys.go Normal file
View file

@ -0,0 +1,136 @@
package authkit
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
func (a *Auth) storeCreateServiceKey(ctx context.Context, k *ServiceKey, abilityIDs []uuid.UUID) error {
const op = "authkit.storeCreateServiceKey"
if k.CreatedAt.IsZero() {
k.CreatedAt = a.now()
}
tx, err := a.db.BeginTx(ctx, nil)
if err != nil {
return errx.Wrap(op, err)
}
defer func() { _ = tx.Rollback() }()
if _, err := tx.ExecContext(ctx, a.q.createServiceKey,
k.IDHash, k.Name, nullableTime(k.LastUsedAt), k.CreatedAt,
nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt)); err != nil {
return errx.Wrap(op, err)
}
now := a.now()
for _, id := range abilityIDs {
if _, err := tx.ExecContext(ctx, a.q.insertServiceKeyAbil,
k.IDHash, uuidArg(id), now); err != nil {
return errx.Wrap(op, err)
}
}
if err := tx.Commit(); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetServiceKey(ctx context.Context, idHash []byte) (*ServiceKey, error) {
const op = "authkit.storeGetServiceKey"
k, err := scanServiceKey(a.db.QueryRowContext(ctx, a.q.getServiceKey, idHash))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrServiceKeyInvalid))
}
abilities, err := a.storeServiceKeyAbilities(ctx, idHash)
if err != nil {
return nil, errx.Wrap(op, err)
}
k.Abilities = abilities
return k, nil
}
func (a *Auth) storeListServiceKeys(ctx context.Context) ([]*ServiceKey, error) {
const op = "authkit.storeListServiceKeys"
rows, err := a.db.QueryContext(ctx, a.q.listServiceKeys)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*ServiceKey
for rows.Next() {
k, err := scanServiceKey(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, k)
}
if err := rows.Err(); err != nil {
return nil, errx.Wrap(op, err)
}
for _, k := range out {
abilities, err := a.storeServiceKeyAbilities(ctx, k.IDHash)
if err != nil {
return nil, errx.Wrap(op, err)
}
k.Abilities = abilities
}
return out, nil
}
func (a *Auth) storeServiceKeyAbilities(ctx context.Context, idHash []byte) ([]string, error) {
rows, err := a.db.QueryContext(ctx, a.q.getServiceKeyAbilities, idHash)
if err != nil {
return nil, err
}
defer rows.Close()
out := []string{}
for rows.Next() {
var slug string
if err := rows.Scan(&slug); err != nil {
return nil, err
}
out = append(out, slug)
}
return out, rows.Err()
}
func (a *Auth) storeTouchServiceKey(ctx context.Context, idHash []byte, at time.Time) error {
const op = "authkit.storeTouchServiceKey"
if _, err := a.db.ExecContext(ctx, a.q.touchServiceKey, at, idHash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeRevokeServiceKey(ctx context.Context, idHash []byte, at time.Time) error {
const op = "authkit.storeRevokeServiceKey"
tag, err := a.db.ExecContext(ctx, a.q.revokeServiceKey, at, idHash)
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrServiceKeyInvalid)
}
return nil
}
func scanServiceKey(row rowScanner) (*ServiceKey, error) {
var (
k ServiceKey
lastUsed sql.NullTime
expires sql.NullTime
revoked sql.NullTime
)
if err := row.Scan(&k.IDHash, &k.Name, &lastUsed, &k.CreatedAt,
&expires, &revoked); err != nil {
return nil, err
}
k.LastUsedAt = scanNullTimePtr(lastUsed)
k.ExpiresAt = scanNullTimePtr(expires)
k.RevokedAt = scanNullTimePtr(revoked)
return &k, nil
}

95
store_sessions.go Normal file
View file

@ -0,0 +1,95 @@
package authkit
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
func (a *Auth) storeCreateSession(ctx context.Context, s *Session) error {
const op = "authkit.storeCreateSession"
now := a.now()
if s.CreatedAt.IsZero() {
s.CreatedAt = now
}
if s.LastSeenAt.IsZero() {
s.LastSeenAt = s.CreatedAt
}
_, err := a.db.ExecContext(ctx, a.q.createSession,
s.IDHash, uuidArg(s.UserID), s.UserAgent, nullableAddrString(s.IP),
s.CreatedAt, s.LastSeenAt, s.ExpiresAt)
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetSession(ctx context.Context, idHash []byte) (*Session, error) {
const op = "authkit.storeGetSession"
var (
s Session
uidStr string
ipStr sql.NullString
)
err := a.db.QueryRowContext(ctx, a.q.getSession, idHash).Scan(
&s.IDHash, &uidStr, &s.UserAgent, &ipStr,
&s.CreatedAt, &s.LastSeenAt, &s.ExpiresAt)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrSessionInvalid))
}
uid, err := scanUUID(uidStr)
if err != nil {
return nil, errx.Wrap(op, err)
}
s.UserID = uid
if ipStr.Valid {
addr, err := scanAddr(&ipStr.String)
if err != nil {
return nil, errx.Wrap(op, err)
}
s.IP = addr
}
return &s, nil
}
func (a *Auth) storeTouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error {
const op = "authkit.storeTouchSession"
tag, err := a.db.ExecContext(ctx, a.q.touchSession, lastSeenAt, newExpiresAt, idHash)
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrSessionInvalid)
}
return nil
}
func (a *Auth) storeDeleteSession(ctx context.Context, idHash []byte) error {
const op = "authkit.storeDeleteSession"
if _, err := a.db.ExecContext(ctx, a.q.deleteSession, idHash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeDeleteUserSessions(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.storeDeleteUserSessions"
if _, err := a.db.ExecContext(ctx, a.q.deleteUserSessions, uuidArg(userID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeDeleteExpiredSessions(ctx context.Context, now time.Time) (int64, error) {
const op = "authkit.storeDeleteExpiredSessions"
tag, err := a.db.ExecContext(ctx, a.q.deleteExpiredSessions, now)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}

134
store_tokens.go Normal file
View file

@ -0,0 +1,134 @@
package authkit
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/errx"
)
func (a *Auth) storeCreateToken(ctx context.Context, t *Token) error {
const op = "authkit.storeCreateToken"
if t.CreatedAt.IsZero() {
t.CreatedAt = a.now()
}
_, err := a.db.ExecContext(ctx, a.q.createToken,
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
nullableTime(t.ConsumedAt), nullableInt(t.AttemptsRemaining),
t.CreatedAt, t.ExpiresAt)
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
// storeConsumeToken atomically marks the matching unexpired, unconsumed token
// as consumed and returns it. Returns ErrTokenInvalid if no row matched.
// Implementations MUST do this in one statement to prevent double-spend
// under concurrent callers.
func (a *Auth) storeConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) {
const op = "authkit.storeConsumeToken"
row := a.db.QueryRowContext(ctx, a.q.consumeToken, now, string(kind), hash, now)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
}
return t, nil
}
func (a *Auth) storeGetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error) {
const op = "authkit.storeGetToken"
row := a.db.QueryRowContext(ctx, a.q.getToken, string(kind), hash)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
}
return t, nil
}
// storeGetActiveOTPForUser returns the most recent unconsumed, unexpired OTP
// row for a user. Used by ConsumeEmailOTP to verify a code by hash-comparing
// client input.
func (a *Auth) storeGetActiveOTPForUser(ctx context.Context, kind TokenKind, userID any, now time.Time) (*Token, error) {
const op = "authkit.storeGetActiveOTPForUser"
row := a.db.QueryRowContext(ctx, a.q.getOTPForUser, string(kind), userID, now)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrOTPInvalid))
}
return t, nil
}
// storeDecrementOTPAttempt drops attempts_remaining by 1 on the matched
// (kind, hash) row, consuming it when zero. Returns the new
// attempts_remaining (0 = consumed). ErrTokenInvalid when no row matched.
func (a *Auth) storeDecrementOTPAttempt(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (int, error) {
const op = "authkit.storeDecrementOTPAttempt"
var remaining sql.NullInt32
if err := a.db.QueryRowContext(ctx, a.q.decrementOTPAttempt,
now, string(kind), hash).Scan(&remaining); err != nil {
return 0, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
}
if !remaining.Valid {
return 0, nil
}
return int(remaining.Int32), nil
}
// storeConsumeOTPByHash marks an OTP row consumed by direct hash match. Used
// on the success path of ConsumeEmailOTP.
func (a *Auth) storeConsumeOTPByHash(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) {
const op = "authkit.storeConsumeOTPByHash"
row := a.db.QueryRowContext(ctx, a.q.consumeOTPByID, now, string(kind), hash)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrOTPInvalid))
}
return t, nil
}
func (a *Auth) storeDeleteByChain(ctx context.Context, chainID string) (int64, error) {
const op = "authkit.storeDeleteByChain"
tag, err := a.db.ExecContext(ctx, a.q.deleteByChain, chainID)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}
func (a *Auth) storeDeleteExpiredTokens(ctx context.Context, now time.Time) (int64, error) {
const op = "authkit.storeDeleteExpiredTokens"
tag, err := a.db.ExecContext(ctx, a.q.deleteExpiredTokens, now)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}
func scanToken(row rowScanner) (*Token, error) {
var (
t Token
kind string
userIDStr string
chainID sql.NullString
consumedAt sql.NullTime
attempts sql.NullInt32
)
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID,
&consumedAt, &attempts, &t.CreatedAt, &t.ExpiresAt); err != nil {
return nil, err
}
t.Kind = TokenKind(kind)
uid, err := scanUUID(userIDStr)
if err != nil {
return nil, err
}
t.UserID = uid
t.ChainID = scanNullStringPtr(chainID)
t.ConsumedAt = scanNullTimePtr(consumedAt)
t.AttemptsRemaining = scanNullIntPtr(attempts)
return &t, nil
}

158
store_users.go Normal file
View file

@ -0,0 +1,158 @@
package authkit
import (
"context"
"database/sql"
"errors"
"strings"
"time"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
func (a *Auth) storeCreateUser(ctx context.Context, u *User) error {
const op = "authkit.storeCreateUser"
if u.ID == uuid.Nil {
u.ID = uuid.New()
}
if u.EmailNormalized == "" {
u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email))
}
now := a.now()
if u.CreatedAt.IsZero() {
u.CreatedAt = now
}
if u.UpdatedAt.IsZero() {
u.UpdatedAt = now
}
_, err := a.db.ExecContext(ctx, a.q.createUser,
uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
nullableString(u.PasswordHash), u.SessionVersion,
nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
return errx.Wrap(op, ErrEmailTaken)
}
return errx.Wrap(op, err)
}
return nil
}
func (a *Auth) storeGetUserByID(ctx context.Context, id uuid.UUID) (*User, error) {
const op = "authkit.storeGetUserByID"
u, err := scanUser(a.db.QueryRowContext(ctx, a.q.getUserByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrUserNotFound))
}
return u, nil
}
func (a *Auth) storeGetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error) {
const op = "authkit.storeGetUserByEmail"
u, err := scanUser(a.db.QueryRowContext(ctx, a.q.getUserByEmail, normalizedEmail))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, ErrUserNotFound))
}
return u, nil
}
func (a *Auth) storeUpdateUser(ctx context.Context, u *User) error {
const op = "authkit.storeUpdateUser"
u.UpdatedAt = a.now()
tag, err := a.db.ExecContext(ctx, a.q.updateUser,
u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
nullableString(u.PasswordHash), u.SessionVersion,
nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrUserNotFound)
}
return nil
}
func (a *Auth) storeDeleteUser(ctx context.Context, id uuid.UUID) error {
const op = "authkit.storeDeleteUser"
tag, err := a.db.ExecContext(ctx, a.q.deleteUser, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrUserNotFound)
}
return nil
}
func (a *Auth) storeSetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error {
const op = "authkit.storeSetPassword"
tag, err := a.db.ExecContext(ctx, a.q.setPassword,
nullableString(encodedHash), a.now(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrUserNotFound)
}
return nil
}
func (a *Auth) storeSetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error {
const op = "authkit.storeSetEmailVerified"
tag, err := a.db.ExecContext(ctx, a.q.setEmailVerified, at, a.now(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, ErrUserNotFound)
}
return nil
}
func (a *Auth) storeBumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) {
const op = "authkit.storeBumpSessionVersion"
var v int
if err := a.db.QueryRowContext(ctx, a.q.bumpSessionVersion,
a.now(), uuidArg(userID)).Scan(&v); err != nil {
return 0, errx.Wrap(op, mapNotFound(err, ErrUserNotFound))
}
return v, nil
}
func scanUser(row rowScanner) (*User, error) {
var (
u User
idStr string
passwordHash sql.NullString
emailVerified sql.NullTime
lastLogin sql.NullTime
)
if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified,
&passwordHash, &u.SessionVersion, &lastLogin,
&u.CreatedAt, &u.UpdatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
u.ID = id
u.PasswordHash = scanNullString(passwordHash)
u.EmailVerifiedAt = scanNullTimePtr(emailVerified)
u.LastLoginAt = scanNullTimePtr(lastLogin)
return &u, nil
}
// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so
// callers get reliable errors.Is targets through errx wrapping.
func mapNotFound(err error, sentinel error) error {
if errors.Is(err, sql.ErrNoRows) {
return sentinel
}
return err
}

278
store_verify.go Normal file
View file

@ -0,0 +1,278 @@
package authkit
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"git.juancwu.dev/juancwu/errx"
)
// columnSpec describes one expected column. dataType is matched against
// information_schema.columns.data_type (e.g. "uuid", "text", "bytea",
// "timestamp with time zone", "integer"). nullable is matched against
// is_nullable ("YES"/"NO"). Extra columns on the live table are allowed.
type columnSpec struct {
name string
dataType string
nullable bool
}
// tableSpec is the expected layout for one logical table.
type tableSpec struct {
logicalKey string // matches a Tables field name; used for error messages
configured string
defaultNm string
columns []columnSpec
}
// expectedSchema returns the full per-table column specification matching
// migrations/0001_init.sql. The verifier walks each tableSpec, looking up
// information_schema for the configured name first and falling back to the
// default name when the configured name has no rows.
func expectedSchema(s Schema) []tableSpec {
def := defaultTables()
t := s.Tables
return []tableSpec{
{
logicalKey: "Users",
configured: t.Users,
defaultNm: def.Users,
columns: []columnSpec{
{"id", "uuid", false},
{"email", "text", false},
{"email_normalized", "text", false},
{"email_verified_at", "timestamp with time zone", true},
{"password_hash", "text", true},
{"session_version", "integer", false},
{"last_login_at", "timestamp with time zone", true},
{"created_at", "timestamp with time zone", false},
{"updated_at", "timestamp with time zone", false},
},
},
{
logicalKey: "Sessions",
configured: t.Sessions,
defaultNm: def.Sessions,
columns: []columnSpec{
{"id_hash", "bytea", false},
{"user_id", "uuid", false},
{"user_agent", "text", false},
{"ip", "text", true},
{"created_at", "timestamp with time zone", false},
{"last_seen_at", "timestamp with time zone", false},
{"expires_at", "timestamp with time zone", false},
},
},
{
logicalKey: "Tokens",
configured: t.Tokens,
defaultNm: def.Tokens,
columns: []columnSpec{
{"hash", "bytea", false},
{"kind", "text", false},
{"user_id", "uuid", false},
{"chain_id", "text", true},
{"consumed_at", "timestamp with time zone", true},
{"attempts_remaining", "integer", true},
{"created_at", "timestamp with time zone", false},
{"expires_at", "timestamp with time zone", false},
},
},
{
logicalKey: "ServiceKeys",
configured: t.ServiceKeys,
defaultNm: def.ServiceKeys,
columns: []columnSpec{
{"id_hash", "bytea", false},
{"name", "text", false},
{"last_used_at", "timestamp with time zone", true},
{"created_at", "timestamp with time zone", false},
{"expires_at", "timestamp with time zone", true},
{"revoked_at", "timestamp with time zone", true},
},
},
{
logicalKey: "ServiceKeyAbilities",
configured: t.ServiceKeyAbilities,
defaultNm: def.ServiceKeyAbilities,
columns: []columnSpec{
{"service_key_id_hash", "bytea", false},
{"ability_id", "uuid", false},
{"granted_at", "timestamp with time zone", false},
},
},
{
logicalKey: "Roles",
configured: t.Roles,
defaultNm: def.Roles,
columns: []columnSpec{
{"id", "uuid", false},
{"slug", "text", false},
{"label", "text", true},
{"created_at", "timestamp with time zone", false},
},
},
{
logicalKey: "Permissions",
configured: t.Permissions,
defaultNm: def.Permissions,
columns: []columnSpec{
{"id", "uuid", false},
{"slug", "text", false},
{"label", "text", true},
{"created_at", "timestamp with time zone", false},
},
},
{
logicalKey: "Abilities",
configured: t.Abilities,
defaultNm: def.Abilities,
columns: []columnSpec{
{"id", "uuid", false},
{"slug", "text", false},
{"label", "text", true},
{"created_at", "timestamp with time zone", false},
},
},
{
logicalKey: "UserRoles",
configured: t.UserRoles,
defaultNm: def.UserRoles,
columns: []columnSpec{
{"user_id", "uuid", false},
{"role_id", "uuid", false},
{"granted_at", "timestamp with time zone", false},
},
},
{
logicalKey: "UserPermissions",
configured: t.UserPermissions,
defaultNm: def.UserPermissions,
columns: []columnSpec{
{"user_id", "uuid", false},
{"permission_id", "uuid", false},
{"granted_at", "timestamp with time zone", false},
},
},
{
logicalKey: "RolePermissions",
configured: t.RolePermissions,
defaultNm: def.RolePermissions,
columns: []columnSpec{
{"role_id", "uuid", false},
{"permission_id", "uuid", false},
},
},
}
}
// VerifySchema introspects the live database against the expected layout for
// the given schema. Returns a wrapped ErrSchemaDrift describing every
// missing/mismatched table or column. Extra columns on a table are allowed.
//
// For tables with non-default names, VerifySchema looks up the configured
// name first; if no rows are found, it falls back to the default name. This
// handles the case where a consumer migrated under custom names but later
// removed the overrides — drift is detected against whichever set of names
// actually exists.
func VerifySchema(ctx context.Context, db *sql.DB, schema Schema) error {
const op = "authkit.VerifySchema"
if db == nil {
return errx.New(op, "db is required")
}
if err := schema.Validate(); err != nil {
return errx.Wrap(op, err)
}
specs := expectedSchema(schema)
var problems []string
for _, spec := range specs {
live, foundUnder, err := loadTableColumns(ctx, db, spec.configured, spec.defaultNm)
if err != nil {
return errx.Wrap(op, err)
}
if foundUnder == "" {
problems = append(problems, fmt.Sprintf(
"table %q (%s): not found (also tried %q)",
spec.configured, spec.logicalKey, spec.defaultNm))
continue
}
for _, want := range spec.columns {
got, ok := live[want.name]
if !ok {
problems = append(problems, fmt.Sprintf(
"table %q column %q: missing", foundUnder, want.name))
continue
}
if got.dataType != want.dataType {
problems = append(problems, fmt.Sprintf(
"table %q column %q: data_type=%q, want %q",
foundUnder, want.name, got.dataType, want.dataType))
}
if got.nullable != want.nullable {
problems = append(problems, fmt.Sprintf(
"table %q column %q: nullable=%v, want %v",
foundUnder, want.name, got.nullable, want.nullable))
}
}
}
if len(problems) > 0 {
sort.Strings(problems)
return errx.Wrapf(op, ErrSchemaDrift,
"%d issue(s):\n - %s", len(problems), strings.Join(problems, "\n - "))
}
return nil
}
// loadTableColumns queries information_schema for a table's columns. If the
// configured name has no rows AND defaultName differs, it falls back to the
// default. Returns the columns map and the name actually used (empty string
// when neither exists).
func loadTableColumns(ctx context.Context, db *sql.DB, configured, defaultName string) (map[string]columnSpec, string, error) {
live, err := queryColumns(ctx, db, configured)
if err != nil {
return nil, "", err
}
if len(live) > 0 {
return live, configured, nil
}
if defaultName != "" && defaultName != configured {
live, err = queryColumns(ctx, db, defaultName)
if err != nil {
return nil, "", err
}
if len(live) > 0 {
return live, defaultName, nil
}
}
return nil, "", nil
}
func queryColumns(ctx context.Context, db *sql.DB, table string) (map[string]columnSpec, error) {
const q = `SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = $1`
rows, err := db.QueryContext(ctx, q, table)
if err != nil {
return nil, err
}
defer rows.Close()
out := make(map[string]columnSpec)
for rows.Next() {
var name, dataType, isNullable string
if err := rows.Scan(&name, &dataType, &isNullable); err != nil {
return nil, err
}
out[name] = columnSpec{
name: name,
dataType: dataType,
nullable: isNullable == "YES",
}
}
return out, rows.Err()
}

87
store_verify_test.go Normal file
View file

@ -0,0 +1,87 @@
package authkit
import (
"context"
"database/sql"
"errors"
"testing"
_ "github.com/jackc/pgx/v5/stdlib"
)
func TestIntegration_VerifySchemaPasses(t *testing.T) {
a := freshAuth(t)
if err := VerifySchema(context.Background(), a.DB(), DefaultSchema()); err != nil {
t.Fatalf("VerifySchema after Migrate should pass: %v", err)
}
}
func TestIntegration_VerifyAllowsExtraColumns(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
// Add a column the verifier doesn't expect — should still pass.
if _, err := a.DB().ExecContext(ctx,
"ALTER TABLE authkit_users ADD COLUMN consumer_extra TEXT"); err != nil {
t.Fatalf("ALTER ADD: %v", err)
}
if err := VerifySchema(ctx, a.DB(), DefaultSchema()); err != nil {
t.Fatalf("VerifySchema should tolerate extra columns: %v", err)
}
}
func TestIntegration_VerifyDetectsMissingColumn(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.DB().ExecContext(ctx,
"ALTER TABLE authkit_users DROP COLUMN last_login_at"); err != nil {
t.Fatalf("ALTER DROP: %v", err)
}
err := VerifySchema(ctx, a.DB(), DefaultSchema())
if !errors.Is(err, ErrSchemaDrift) {
t.Fatalf("expected ErrSchemaDrift on missing column, got %v", err)
}
}
func TestIntegration_VerifyDetectsMissingTable(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
if _, err := a.DB().ExecContext(ctx,
"DROP TABLE authkit_user_permissions"); err != nil {
t.Fatalf("DROP: %v", err)
}
err := VerifySchema(ctx, a.DB(), DefaultSchema())
if !errors.Is(err, ErrSchemaDrift) {
t.Fatalf("expected ErrSchemaDrift on missing table, got %v", err)
}
}
func TestIntegration_VerifyFallbackToDefaultName(t *testing.T) {
// Migrate created tables under default names. Construct a custom schema
// pointing at non-existent table names — verifier should fall back to
// the defaults and pass.
a := freshAuth(t)
ctx := context.Background()
custom := DefaultSchema()
custom.Tables.Users = "renamed_users_does_not_exist"
if err := VerifySchema(ctx, a.DB(), custom); err != nil {
t.Fatalf("VerifySchema should fall back to default name when configured table is missing: %v", err)
}
}
func TestIntegration_MigrateIdempotent(t *testing.T) {
url := dbURL(t)
db, err := sql.Open("pgx", url)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
t.Cleanup(func() { dropAllAuthkitTables(t, db, DefaultSchema()) })
dropAllAuthkitTables(t, db, DefaultSchema())
for i := 0; i < 3; i++ {
if err := Migrate(context.Background(), db, DefaultSchema()); err != nil {
t.Fatalf("Migrate iter %d: %v", i, err)
}
}
}

View file

@ -1,82 +0,0 @@
package authkit
import (
"context"
"time"
"github.com/google/uuid"
)
type UserStore interface {
CreateUser(ctx context.Context, u *User) error
GetUserByID(ctx context.Context, id uuid.UUID) (*User, error)
GetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error)
UpdateUser(ctx context.Context, u *User) error
DeleteUser(ctx context.Context, id uuid.UUID) error
SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error
SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error
BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error)
IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error)
ResetFailedLogins(ctx context.Context, userID uuid.UUID) error
}
type SessionStore interface {
CreateSession(ctx context.Context, s *Session) error
GetSession(ctx context.Context, idHash []byte) (*Session, error)
TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error
DeleteSession(ctx context.Context, idHash []byte) error
DeleteUserSessions(ctx context.Context, userID uuid.UUID) error
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
}
type TokenStore interface {
CreateToken(ctx context.Context, t *Token) error
// ConsumeToken atomically marks the matching unexpired, unconsumed token
// as consumed and returns it. Returns ErrTokenInvalid if no row matched.
// Implementations MUST do this in one statement (UPDATE ... RETURNING)
// to prevent double-spend under concurrent callers.
ConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error)
// GetToken returns a token without consuming it. Used for refresh-token
// reuse detection: a token that exists with consumed_at != nil is a
// replay signal.
GetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error)
DeleteByChain(ctx context.Context, chainID string) (int64, error)
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
}
type ServiceKeyStore interface {
CreateServiceKey(ctx context.Context, k *ServiceKey) error
GetServiceKey(ctx context.Context, idHash []byte) (*ServiceKey, error)
ListServiceKeysByOwner(ctx context.Context, ownerKind string, ownerID uuid.UUID) ([]*ServiceKey, error)
TouchServiceKey(ctx context.Context, idHash []byte, at time.Time) error
RevokeServiceKey(ctx context.Context, idHash []byte, at time.Time) error
}
type RoleStore interface {
CreateRole(ctx context.Context, r *Role) error
GetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error)
GetRoleByName(ctx context.Context, name string) (*Role, error)
ListRoles(ctx context.Context) ([]*Role, error)
DeleteRole(ctx context.Context, id uuid.UUID) error
AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error
RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error
GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error)
HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error)
}
type PermissionStore interface {
CreatePermission(ctx context.Context, p *Permission) error
GetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error)
GetPermissionByName(ctx context.Context, name string) (*Permission, error)
ListPermissions(ctx context.Context) ([]*Permission, error)
DeletePermission(ctx context.Context, id uuid.UUID) error
AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error
RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error
GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error)
GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error)
}
type Hasher interface {
Hash(password string) (string, error)
Verify(password, encoded string) (ok bool, needsRehash bool, err error)
}

89
testdb_test.go Normal file
View file

@ -0,0 +1,89 @@
package authkit
// Integration test infrastructure. Skipped when AUTHKIT_TEST_DATABASE_URL is
// unset so the unit-test suite remains usable without a database.
import (
"context"
"database/sql"
"fmt"
"net/netip"
"os"
"testing"
"time"
"git.juancwu.dev/juancwu/authkit/hasher"
_ "github.com/jackc/pgx/v5/stdlib"
)
// noIP returns the zero-value netip.Addr — used by tests that don't care
// about the originating IP.
func noIP() netip.Addr { return netip.Addr{} }
func dbURL(t *testing.T) string {
t.Helper()
url := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
if url == "" {
t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
}
return url
}
// freshAuth returns a fully-initialized *Auth bound to a clean database.
// All authkit_* tables are dropped before Migrate runs, so each test sees
// an empty schema.
func freshAuth(t *testing.T) *Auth {
t.Helper()
url := dbURL(t)
db, err := sql.Open("pgx", url)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("ping: %v", err)
}
dropAllAuthkitTables(t, db, DefaultSchema())
t.Cleanup(func() { dropAllAuthkitTables(t, db, DefaultSchema()) })
a, err := New(context.Background(), Deps{
DB: db,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
}, Config{
JWTSecret: []byte("integration-secret-thirty-two!!!"),
JWTIssuer: "authkit-int",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,
PasswordResetTTL: time.Hour,
MagicLinkTTL: time.Minute,
EmailOTPTTL: time.Minute,
EmailOTPMaxAttempts: 3,
})
if err != nil {
t.Fatalf("authkit.New: %v", err)
}
return a
}
func dropAllAuthkitTables(t *testing.T, db *sql.DB, s Schema) {
t.Helper()
tables := []string{
s.Tables.ServiceKeyAbilities, s.Tables.UserPermissions,
s.Tables.UserRoles, s.Tables.RolePermissions,
s.Tables.ServiceKeys, s.Tables.Abilities,
s.Tables.Roles, s.Tables.Permissions,
s.Tables.Tokens, s.Tables.Sessions, s.Tables.Users,
s.Tables.SchemaMigrations,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for _, name := range tables {
_, _ = db.ExecContext(ctx,
fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name))
}
}

View file

@ -3,7 +3,6 @@ package authkit
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"io"
"strings"
@ -26,9 +25,7 @@ const (
// MintOpaqueSecret generates a fresh opaque secret with the given prefix.
// Returns the plaintext (show once, never persist) and the SHA-256 lookup
// hash. A nil rng falls back to crypto/rand.Reader. Exposed so consumers
// building bespoke storage can produce secrets in the same shape authkit
// uses internally.
// hash. A nil rng falls back to crypto/rand.Reader.
func MintOpaqueSecret(rng io.Reader, prefix string) (plaintext string, hash []byte, err error) {
const op = "authkit.MintOpaqueSecret"
if rng == nil {
@ -53,7 +50,7 @@ func HashOpaqueSecret(plaintext string) []byte {
}
// ParseOpaqueSecret validates that a plaintext begins with the expected
// prefix and returns the lookup hash. Returns ok=false on prefix mismatch.
// prefix and returns the lookup hash.
func ParseOpaqueSecret(prefix, plaintext string) (hash []byte, ok bool) {
want := prefix + "_"
if !strings.HasPrefix(plaintext, want) {
@ -61,22 +58,3 @@ func ParseOpaqueSecret(prefix, plaintext string) (hash []byte, ok bool) {
}
return HashOpaqueSecret(plaintext), true
}
// mintSecret is the internal entry point; existing callers pass prefix
// first to match call-site readability ("mint a token of kind X").
func mintSecret(prefix string, rng io.Reader) (plaintext string, hash []byte, err error) {
return MintOpaqueSecret(rng, prefix)
}
func hashSecret(plaintext string) []byte {
return HashOpaqueSecret(plaintext)
}
func parseSecret(prefix, plaintext string) (hash []byte, ok bool) {
return ParseOpaqueSecret(prefix, plaintext)
}
// constantTimeEqual is a thin wrapper for readability at call sites.
func constantTimeEqual(a, b []byte) bool {
return subtle.ConstantTimeCompare(a, b) == 1
}

View file

@ -7,46 +7,46 @@ import (
"testing"
)
func TestMintSecretRoundtrip(t *testing.T) {
plaintext, hash, err := mintSecret(prefixSession, nil)
func TestMintOpaqueSecretRoundtrip(t *testing.T) {
plaintext, hash, err := MintOpaqueSecret(nil, prefixSession)
if err != nil {
t.Fatalf("mintSecret: %v", err)
t.Fatalf("MintOpaqueSecret: %v", err)
}
if !strings.HasPrefix(plaintext, prefixSession+"_") {
t.Fatalf("missing prefix: %q", plaintext)
}
parsed, ok := parseSecret(prefixSession, plaintext)
parsed, ok := ParseOpaqueSecret(prefixSession, plaintext)
if !ok {
t.Fatalf("parseSecret rejected our own mint")
t.Fatalf("ParseOpaqueSecret rejected our own mint")
}
if !bytes.Equal(hash, parsed) {
t.Fatalf("hash mismatch")
}
want := sha256.Sum256([]byte(plaintext))
if !bytes.Equal(hash, want[:]) {
t.Fatalf("hashSecret != sha256(plaintext)")
t.Fatalf("HashOpaqueSecret != sha256(plaintext)")
}
}
func TestParseSecretWrongPrefix(t *testing.T) {
plaintext, _, err := mintSecret(prefixSession, nil)
func TestParseOpaqueSecretWrongPrefix(t *testing.T) {
plaintext, _, err := MintOpaqueSecret(nil, prefixSession)
if err != nil {
t.Fatalf("mintSecret: %v", err)
t.Fatalf("MintOpaqueSecret: %v", err)
}
if _, ok := parseSecret(prefixServiceKey, plaintext); ok {
t.Fatalf("parseSecret should reject mismatched prefix")
if _, ok := ParseOpaqueSecret(prefixServiceKey, plaintext); ok {
t.Fatalf("ParseOpaqueSecret should reject mismatched prefix")
}
if _, ok := parseSecret(prefixSession, "sessXXXX"); ok {
t.Fatalf("parseSecret should require trailing underscore")
if _, ok := ParseOpaqueSecret(prefixSession, "sessXXXX"); ok {
t.Fatalf("ParseOpaqueSecret should require trailing underscore")
}
}
func TestMintSecretUniqueness(t *testing.T) {
func TestMintOpaqueSecretUniqueness(t *testing.T) {
seen := make(map[string]struct{}, 100)
for i := 0; i < 100; i++ {
p, _, err := mintSecret(prefixServiceKey, nil)
p, _, err := MintOpaqueSecret(nil, prefixServiceKey)
if err != nil {
t.Fatalf("mintSecret: %v", err)
t.Fatalf("MintOpaqueSecret: %v", err)
}
if _, dup := seen[p]; dup {
t.Fatalf("duplicate mint: %s", p)

105
userctx.go Normal file
View file

@ -0,0 +1,105 @@
package authkit
import (
"context"
"sync"
"github.com/google/uuid"
)
// userCtxKey is an unexported context key. The empty struct shape guarantees
// no collision with caller-defined keys.
type userCtxKey struct{}
type serviceKeyCtxKey struct{}
// userBox holds the per-request lazy-loaded user. The box pointer is what's
// stored on the context, so RefreshUserInCtx can mutate the cache visible
// to every UserFromCtx call within the same request.
type userBox struct {
mu sync.Mutex
auth *Auth
userID uuid.UUID
cached *User
}
func (b *userBox) get(ctx context.Context) (*User, error) {
b.mu.Lock()
if b.cached != nil {
u := b.cached
b.mu.Unlock()
return u, nil
}
b.mu.Unlock()
// Don't hold the lock across the DB call.
u, err := b.auth.storeGetUserByID(ctx, b.userID)
if err != nil {
return nil, err
}
b.mu.Lock()
if b.cached == nil {
b.cached = u
}
out := b.cached
b.mu.Unlock()
return out, nil
}
func (b *userBox) refresh(ctx context.Context) (*User, error) {
b.mu.Lock()
b.cached = nil
b.mu.Unlock()
return b.get(ctx)
}
// WithUserContext attaches a lazy user-context to ctx. Middleware uses this
// to record an authenticated user_id without paying for a DB read until a
// handler actually calls UserFromCtx. Custom middleware authors can use
// this directly to integrate hand-rolled auth flows.
func WithUserContext(ctx context.Context, a *Auth, userID uuid.UUID) context.Context {
return context.WithValue(ctx, userCtxKey{}, &userBox{auth: a, userID: userID})
}
// WithServiceKey attaches a *ServiceKey to ctx. Used by service-key middleware.
func WithServiceKey(ctx context.Context, k *ServiceKey) context.Context {
return context.WithValue(ctx, serviceKeyCtxKey{}, k)
}
// UserIDFromCtx returns the authenticated user_id placed by middleware via
// WithUserContext. The boolean is false when no user-bound auth ran for
// this request (e.g. a service-key request).
func UserIDFromCtx(ctx context.Context) (uuid.UUID, bool) {
b, ok := ctx.Value(userCtxKey{}).(*userBox)
if !ok {
return uuid.Nil, false
}
return b.userID, true
}
// UserFromCtx returns the authenticated *User, lazy-loading from the
// database on first call within this request and caching the result for
// subsequent calls. Returns ErrNoUserContext if no user-bound auth ran.
func UserFromCtx(ctx context.Context) (*User, error) {
b, ok := ctx.Value(userCtxKey{}).(*userBox)
if !ok {
return nil, ErrNoUserContext
}
return b.get(ctx)
}
// RefreshUserInCtx invalidates the cached user and refetches. Use after an
// admin-side update that should be visible to the rest of the request.
func RefreshUserInCtx(ctx context.Context) (*User, error) {
b, ok := ctx.Value(userCtxKey{}).(*userBox)
if !ok {
return nil, ErrNoUserContext
}
return b.refresh(ctx)
}
// ServiceKeyFromCtx returns the authenticated *ServiceKey placed by
// service-key middleware. The boolean is false when no service-key
// authentication ran for this request.
func ServiceKeyFromCtx(ctx context.Context) (*ServiceKey, bool) {
k, ok := ctx.Value(serviceKeyCtxKey{}).(*ServiceKey)
return k, ok
}

64
userctx_test.go Normal file
View file

@ -0,0 +1,64 @@
package authkit
import (
"context"
"testing"
)
func TestIntegration_UserCtxLazyAndRefresh(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "ctx@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
rctx := WithUserContext(ctx, a, u.ID)
// UserIDFromCtx is non-loading.
id, ok := UserIDFromCtx(rctx)
if !ok || id != u.ID {
t.Fatalf("UserIDFromCtx mismatch")
}
first, err := UserFromCtx(rctx)
if err != nil {
t.Fatalf("UserFromCtx (lazy load): %v", err)
}
// Mutate the underlying user out-of-band so we can prove refresh sees
// the change.
if err := a.SetPassword(ctx, u.ID, "new-secret"); err != nil {
t.Fatalf("SetPassword: %v", err)
}
// Without RefreshUserInCtx, UserFromCtx returns the cached value (which
// has the empty password hash from the initial load).
second, err := UserFromCtx(rctx)
if err != nil {
t.Fatalf("UserFromCtx (cached): %v", err)
}
if second != first {
t.Fatalf("expected cached pointer identity, got distinct pointers")
}
// Refresh: cache busts, next read sees the password hash.
refreshed, err := RefreshUserInCtx(rctx)
if err != nil {
t.Fatalf("RefreshUserInCtx: %v", err)
}
if refreshed.PasswordHash == "" {
t.Fatalf("refresh should observe the SetPassword side-effect")
}
}
func TestUserCtxNoUser(t *testing.T) {
if _, ok := UserIDFromCtx(context.Background()); ok {
t.Fatalf("UserIDFromCtx should be false on a bare context")
}
if _, err := UserFromCtx(context.Background()); err != ErrNoUserContext {
t.Fatalf("UserFromCtx should return ErrNoUserContext, got %v", err)
}
if _, err := RefreshUserInCtx(context.Background()); err != ErrNoUserContext {
t.Fatalf("RefreshUserInCtx should return ErrNoUserContext, got %v", err)
}
}