authkit initial

This commit is contained in:
juancwu 2026-04-26 01:36:53 +00:00
commit 134393fbca
43 changed files with 5188 additions and 1 deletions

344
README.md
View file

@ -1,3 +1,345 @@
# authkit # authkit
Just a concoction of auth stuff in one place. A pragmatic authentication and authorization toolkit for Go web services.
`authkit` ships interfaces for users, sessions, tokens, API keys, roles, and
permissions, plus default `database/sql` Postgres implementations and
framework-neutral HTTP middleware. It supports both opaque server-side
sessions and JWT access tokens with rotating refresh tokens, hashes passwords
with Argon2id, and pairs naturally with [`lightmux`](https://git.juancwu.dev/juancwu/lightmux)
or any `net/http` stack.
## Install
```
go get git.juancwu.dev/juancwu/authkit
```
`authkit` itself depends only on `database/sql` and the Go standard library
plus `golang-jwt`, `google/uuid`, `golang.org/x/crypto`, and `errx`. Bring
your own driver: `pgx`, `lib/pq`, or anything else that registers a
`database/sql` driver.
```go
import _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
```
PostgreSQL 12+ is sufficient — the schema avoids `gen_random_uuid()` and
`pgcrypto` so no extensions are required.
## What's included
**Authentication flows**
- Email + password registration and login (Argon2id PHC-encoded hashes)
- Opaque server-side sessions with sliding TTL bounded by an absolute cap
- JWT access tokens (HS256) with rotating refresh tokens and reuse detection
- Email verification, password reset, and magic-link passwordless login
**Authorization**
- Roles and permissions with many-to-many wiring
- API keys with custom abilities for per-endpoint scoping
- A unified `Principal` type so middleware works the same regardless of which
authentication method ran
**Storage**
- Interfaces for every store so callers can plug in their own backends
- Default Postgres implementation built on `*sql.DB` (`sqlstore` package)
- Override table names via `Schema` without forking — useful when authkit
lives alongside an existing application schema
- A `Dialect` abstraction so future MySQL / SQLite implementations slot in
without changes to store code
- Embedded versioned migrations applied by a `Migrate(ctx, db, dialect, schema)`
helper that takes a session-scoped advisory lock
**HTTP**
- `middleware.RequireSession`, `RequireJWT`, `RequireAPIKey`, `RequireAny`
- `middleware.RequireRole`, `RequireAnyRole`, `RequirePermission`,
`RequireAbility`
- `middleware.PrincipalFrom(ctx)` to read the authenticated principal in
handlers
**Errors**
- Sentinel errors (`ErrEmailTaken`, `ErrInvalidCredentials`, `ErrTokenInvalid`,
`ErrTokenReused`, `ErrSessionInvalid`, `ErrAPIKeyInvalid`,
`ErrPermissionDenied`, ...) compatible with `errors.Is`
- All internal errors wrap with [`errx`](https://git.juancwu.dev/juancwu/errx)
for op tags
## Out of scope (v1)
MFA/TOTP, OAuth/social login, soft-delete, in-memory permission caching,
pluggable JWT signers (HS256 only), built-in HTTP handlers, MySQL/SQLite
dialects (architecture supports them; only Postgres ships in v1), and
column-name overrides in `Schema` (table-name overrides only).
## Quick start
### 1. Open a database and run migrations
```go
import (
"database/sql"
"git.juancwu.dev/juancwu/authkit/sqlstore"
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
_ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
)
db, err := sql.Open("pgx", os.Getenv("DATABASE_URL"))
if err != nil { /* ... */ }
defer db.Close()
if err := sqlstore.Migrate(ctx, db, pgdialect.New(), sqlstore.DefaultSchema()); err != nil {
log.Fatal(err)
}
```
`Migrate` is idempotent and safe to call from multiple processes — it takes
a session-scoped `pg_advisory_lock` to serialise rollouts.
`sqlx` users can pass `sqlxDB.DB` (the underlying `*sql.DB`) to the same
calls — the library only cares about `*sql.DB`.
### 2. Wire the service
```go
import (
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/hasher"
)
stores, err := sqlstore.New(db, pgdialect.New(), sqlstore.DefaultSchema())
if err != nil { /* ... */ }
auth := authkit.New(authkit.Deps{
Users: stores.Users,
Sessions: stores.Sessions,
Tokens: stores.Tokens,
APIKeys: stores.APIKeys,
Roles: stores.Roles,
Permissions: stores.Permissions,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
}, authkit.Config{
JWTSecret: []byte(os.Getenv("JWT_SECRET")),
JWTIssuer: "myapp",
SessionCookieSecure: true,
SessionCookieHTTPOnly: true,
})
```
`Config` zero values fall back to sensible defaults (24h idle / 30d absolute
session TTL, 15m access tokens, 30d refresh tokens, 48h email-verify, 1h
password-reset, 15m magic-link). `JWTSecret` and the seven `Deps` fields are
required; `New` panics on a misconfiguration.
### 3. Use the service
```go
// Registration + password login
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
u, err = auth.LoginPassword(ctx, "alice@example.com", "hunter2hunter2")
// Opaque session (cookie-friendly)
plaintext, sess, err := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP)
http.SetCookie(w, auth.SessionCookie(plaintext, sess.ExpiresAt))
// JWT + rotating refresh
access, refresh, err := auth.IssueJWT(ctx, u.ID)
access, refresh, err = auth.RefreshJWT(ctx, refresh) // old refresh is consumed
// API key with abilities
plaintext, key, err := auth.IssueAPIKey(ctx, u.ID, "ci",
[]string{"billing:read", "users:list"}, nil)
// Email verification + password reset + magic link
tok, err := auth.RequestEmailVerification(ctx, u.ID)
_, err = auth.ConfirmEmail(ctx, tok)
tok, err = auth.RequestPasswordReset(ctx, "alice@example.com")
err = auth.ConfirmPasswordReset(ctx, tok, "new-password")
tok, err = auth.RequestMagicLink(ctx, "alice@example.com")
u, err = auth.ConsumeMagicLink(ctx, tok)
```
The plaintext returned by `IssueSession`, `IssueJWT`, `IssueAPIKey`, and the
token-minting flows is **show-once** — only its SHA-256 hash is stored. Show
it to the user immediately; you cannot recover it later.
### 4. Wire middleware
`authkit/middleware` returns standard `func(http.Handler) http.Handler`
values, so it composes with `lightmux.Mux.Use`/`Group`/`Handle` and any
`net/http` mux that accepts the same shape.
```go
import (
authkitmw "git.juancwu.dev/juancwu/authkit/middleware"
"git.juancwu.dev/juancwu/lightmux"
)
mux := lightmux.New()
cookieAuth := authkitmw.RequireSession(authkitmw.Options{
Auth: auth,
Extractor: authkit.ChainExtractors(
authkit.CookieExtractor("authkit_session"),
authkit.BearerExtractor(),
),
})
me := mux.Group("/me", cookieAuth)
me.Get("", func(w http.ResponseWriter, r *http.Request) {
p := authkitmw.MustPrincipal(r)
json.NewEncoder(w).Encode(p)
})
// RBAC: stack authz on top of any auth method
admin := mux.Group("/admin", cookieAuth, authkitmw.RequireRole("admin"))
// API-key-only route with a per-endpoint ability check
api := mux.Group("/api/v1", authkitmw.RequireAPIKey(authkitmw.Options{Auth: auth}))
api.Get("/billing", billingHandler, authkitmw.RequireAbility("billing:read"))
```
`Options.Extractor` defaults to `BearerExtractor`; pass `CookieExtractor` (or
chain extractors) when reading session cookies. `Options.OnUnauth` and
`Options.OnForbidden` default to a JSON `401` / `403`; override them to match
your error envelope.
## Custom table names
Pass a non-default `Schema` to use your own table names. Identifiers must
match `^[a-zA-Z_][a-zA-Z0-9_]*$`; anything else is rejected at `New()` and
`Migrate()` time, so SQL injection through the schema is impossible.
```go
schema := sqlstore.DefaultSchema()
schema.Tables.Users = "accounts"
schema.Tables.APIKeys = "api_credentials"
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
```
The bundled migration files use the default `authkit_*` names. If you
override, you're responsible for matching DDL (most consumers with custom
naming already have their own DDL pipeline).
Column-name overrides are not exposed in v1 — the column set is fixed for
each table. Adding column overrides later is purely additive.
## How things work
### Secret token format
Sessions, refresh tokens, API keys, email-verify tokens, password-reset tokens,
and magic-link tokens all share one format:
```
plaintext = "<prefix>_" + base64url(32 random bytes, no padding)
lookup = sha256(plaintext)
```
Plaintext is returned to the caller exactly once and never persisted; the
SHA-256 is the database lookup key. Random bytes come from `crypto/rand` (or
`Config.Random` for tests).
### JWT revocation
Access tokens carry `sv` (`session_version`) in their claims. When you call
`RevokeAllUserSessions` or `ChangePassword`, the user's `session_version`
column increments and every outstanding access token fails the next
`AuthenticateJWT`. This is the only way to invalidate a JWT before its
`exp`.
### Refresh token rotation
Each `RefreshJWT` consumes the presented refresh token and issues a new one
on the same chain (the `chain_id` column on `authkit_tokens`). If a
*consumed* refresh token is ever presented again — a strong replay signal —
the entire chain is deleted via `TokenStore.DeleteByChain` and the call
returns `ErrTokenReused`.
### Sliding session TTL
Each authenticated request via `AuthenticateSession` slides `expires_at` to
`now + Config.SessionIdleTTL`, capped at `created_at + Config.SessionAbsoluteTTL`.
Long-lived idle sessions still hit the absolute boundary.
### Schema and migrations
`sqlstore.Migrate` applies every embedded `.sql` file under
`sqlstore/dialect/postgres/migrations/` whose version (filename without
`.sql`) is not in `authkit_schema_migrations`. The dialect's
`AcquireMigrationLock` (Postgres uses `pg_advisory_lock`) serialises
concurrent migrators. Each migration owns its own transaction so future
migrations can use statements like `CREATE INDEX CONCURRENTLY`.
Every default table is prefixed `authkit_` so the schema can live alongside
your application's own tables in a shared database.
### Driver and dialect architecture
The `sqlstore` package speaks `database/sql` only. Driver-specific behaviour
lives behind a small `Dialect` interface:
```go
type Dialect interface {
Name() string
BuildQueries(s Schema) Queries
Bootstrap(ctx context.Context, db *sql.DB) error
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
Migrations() fs.FS
IsUniqueViolation(err error) bool
Placeholder(n int) string
PlaceholderList(start, count int) string
}
```
v1 ships `dialect/postgres`. A future MySQL or SQLite dialect adds a new
implementation; no changes to store code.
## Configuration reference
| Field | Default | Notes |
|---|---|---|
| `SessionIdleTTL` | 24h | Sliding window applied on each authenticated request |
| `SessionAbsoluteTTL` | 30d | Cap from `created_at`; sliding never exceeds this |
| `SessionCookieName` | `authkit_session` | |
| `SessionCookieSameSite` | `Lax` | |
| `SessionCookieSecure` / `HTTPOnly` | `false` / `false` | Set both to `true` in production |
| `JWTSecret` | — (required) | HS256 key |
| `JWTIssuer` / `JWTAudience` | empty | When set, parser enforces them |
| `AccessTokenTTL` | 15m | |
| `RefreshTokenTTL` | 30d | |
| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | |
| `Clock` | `time.Now().UTC` | Controls every observable timestamp; override for deterministic tests |
| `Random` | `crypto/rand.Reader` | Override for deterministic tests |
| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit |
## Implementing your own store
Every store is a small interface with explicit semantics — see `stores.go`.
The most subtle contract is `TokenStore.ConsumeToken`: it MUST mark the token
consumed and return it in a single statement (`UPDATE ... RETURNING` on
Postgres / SQLite 3.35+) so two concurrent callers cannot both succeed.
## Testing
```
go test ./... # unit tests, no DB
AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./sqlstore... # integration tests
```
Unit tests cover token mint/parse, Argon2id encode/verify (including
`needsRehash` on parameter change), JWT issue/parse (incl. expired,
`sv`-mismatch, refresh rotation, reuse detection), session lifecycle, email
verification, password reset cascading session invalidation, magic-link
self-verification, API keys with abilities, and RBAC role-permission
resolution. Integration tests run the full `sqlstore` contract against a
real Postgres when `AUTHKIT_TEST_DATABASE_URL` is set.
## License
MIT. See `LICENSE`.

53
Taskfile.yml Normal file
View file

@ -0,0 +1,53 @@
version: '3'
tasks:
default:
desc: Run vet and unit tests
cmds:
- task: check
install:tools:
desc: Install development tools (tparse for prettier test output)
cmds:
- go install github.com/mfridman/tparse@latest
build:
desc: Build all packages
cmds:
- go build ./...
test:
desc: Run unit tests with prettier output via tparse
cmds:
- set -o pipefail && go test ./... -json -cover | tparse -all
test:race:
desc: Run unit tests under the race detector
cmds:
- set -o pipefail && go test ./... -race -json | tparse -all
test:integration:
desc: Run sqlstore integration tests against a real Postgres (requires AUTHKIT_TEST_DATABASE_URL)
cmds:
- set -o pipefail && go test ./sqlstore/... -run Integration -count=1 -json | tparse -all
vet:
desc: Run go vet
cmds:
- go vet ./...
fmt:
desc: Format Go source files
cmds:
- gofmt -w .
tidy:
desc: Tidy go.mod
cmds:
- go mod tidy
check:
desc: Run vet and unit tests
cmds:
- task: vet
- task: test

121
authkit.go Normal file
View file

@ -0,0 +1,121 @@
package authkit
import (
"context"
"crypto/rand"
"io"
"net/http"
"time"
"git.juancwu.dev/juancwu/errx"
)
// Deps bundles every backing store and the password hasher the Auth service
// depends on. All fields are required; New panics on a nil dep so misuse is
// caught at boot rather than under load.
type Deps struct {
Users UserStore
Sessions SessionStore
Tokens TokenStore
APIKeys APIKeyStore
Roles RoleStore
Permissions PermissionStore
Hasher Hasher
}
// Config tunes session/JWT/token TTLs, cookie shape, JWT signing material,
// and optional hooks. Any zero-valued duration is replaced with a sane
// default in New; required fields (notably JWTSecret) cause New to panic.
type Config struct {
// Session (opaque) cookies + DB-backed lifetime
SessionIdleTTL time.Duration
SessionAbsoluteTTL time.Duration
SessionCookieName string
SessionCookieDomain string
SessionCookiePath string
SessionCookieSecure bool
SessionCookieHTTPOnly bool
SessionCookieSameSite http.SameSite
// JWT (HS256)
JWTSecret []byte
JWTIssuer string
JWTAudience string
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
// Single-use tokens
EmailVerifyTTL time.Duration
PasswordResetTTL time.Duration
MagicLinkTTL time.Duration
// Hooks (optional)
Clock func() time.Time
Random io.Reader
LoginHook func(ctx context.Context, email string, success bool) error
}
// Auth is the high-level service that composes the stores and hasher into the
// flows callers use: registration, login, sessions, JWTs, magic links, API
// keys, and authz checks. It is safe for concurrent use; method receivers
// never mutate Auth state after construction.
type Auth struct {
deps Deps
cfg Config
}
// New validates Deps and Config, fills in defaults, and returns a ready
// service. It panics on missing deps or missing JWT secret rather than
// returning an error — these are programmer errors, not runtime ones.
func New(deps Deps, cfg Config) *Auth {
if deps.Users == nil || deps.Sessions == nil || deps.Tokens == nil ||
deps.APIKeys == nil || deps.Roles == nil || deps.Permissions == nil ||
deps.Hasher == nil {
panic(errx.New("authkit.New", "all Deps fields are required"))
}
if len(cfg.JWTSecret) == 0 {
panic(errx.New("authkit.New", "Config.JWTSecret is required"))
}
if cfg.SessionIdleTTL == 0 {
cfg.SessionIdleTTL = 24 * time.Hour
}
if cfg.SessionAbsoluteTTL == 0 {
cfg.SessionAbsoluteTTL = 30 * 24 * time.Hour
}
if cfg.SessionCookieName == "" {
cfg.SessionCookieName = "authkit_session"
}
if cfg.SessionCookiePath == "" {
cfg.SessionCookiePath = "/"
}
if cfg.SessionCookieSameSite == 0 {
cfg.SessionCookieSameSite = http.SameSiteLaxMode
}
if cfg.AccessTokenTTL == 0 {
cfg.AccessTokenTTL = 15 * time.Minute
}
if cfg.RefreshTokenTTL == 0 {
cfg.RefreshTokenTTL = 30 * 24 * time.Hour
}
if cfg.EmailVerifyTTL == 0 {
cfg.EmailVerifyTTL = 48 * time.Hour
}
if cfg.PasswordResetTTL == 0 {
cfg.PasswordResetTTL = time.Hour
}
if cfg.MagicLinkTTL == 0 {
cfg.MagicLinkTTL = 15 * time.Minute
}
if cfg.Clock == nil {
cfg.Clock = func() time.Time { return time.Now().UTC() }
}
if cfg.Random == nil {
cfg.Random = rand.Reader
}
return &Auth{deps: deps, cfg: cfg}
}
// now returns the configured wall clock, defaulting to time.Now in UTC.
func (a *Auth) now() time.Time { return a.cfg.Clock() }

13
doc.go Normal file
View file

@ -0,0 +1,13 @@
// Package authkit is an authentication and authorization toolkit for Go web
// services. It defines storage interfaces (UserStore, SessionStore, TokenStore,
// APIKeyStore, RoleStore, PermissionStore) and a high-level Auth service that
// composes them to support registration, password login, opaque server-side
// sessions, JWT access plus rotating refresh tokens, email verification,
// password resets, magic-link passwordless login, role-based access control,
// and API keys with custom abilities.
//
// Default Postgres implementations of every store live in the pgstore
// subpackage. Argon2id password hashing lives in hasher. Framework-neutral
// HTTP middleware (compatible with lightmux and any net/http stack) lives in
// middleware.
package authkit

18
errors.go Normal file
View file

@ -0,0 +1,18 @@
package authkit
import "errors"
var (
ErrEmailTaken = errors.New("authkit: email already registered")
ErrUserNotFound = errors.New("authkit: user not found")
ErrInvalidCredentials = errors.New("authkit: invalid credentials")
ErrEmailNotVerified = errors.New("authkit: email not verified")
ErrTokenInvalid = errors.New("authkit: invalid or expired token")
ErrTokenReused = errors.New("authkit: token reuse detected")
ErrSessionInvalid = errors.New("authkit: invalid or expired session")
ErrAPIKeyInvalid = errors.New("authkit: invalid or expired api key")
ErrPermissionDenied = errors.New("authkit: permission denied")
ErrRoleNotFound = errors.New("authkit: role not found")
ErrPermissionNotFound = errors.New("authkit: permission not found")
ErrConfigInvalid = errors.New("authkit: invalid configuration")
)

64
extractor.go Normal file
View file

@ -0,0 +1,64 @@
package authkit
import (
"net/http"
"strings"
)
// Extractor pulls a credential string out of an HTTP request. It returns
// (value, true) when a value was found, otherwise ("", false).
type Extractor func(r *http.Request) (string, bool)
// BearerExtractor reads the value following "Bearer " in the Authorization
// header. Comparison is case-insensitive on the scheme.
func BearerExtractor() Extractor {
return func(r *http.Request) (string, bool) {
h := r.Header.Get("Authorization")
if h == "" {
return "", false
}
const prefix = "bearer "
if len(h) <= len(prefix) || !strings.EqualFold(h[:len(prefix)], prefix) {
return "", false
}
v := strings.TrimSpace(h[len(prefix):])
if v == "" {
return "", false
}
return v, true
}
}
// CookieExtractor reads the named cookie's value.
func CookieExtractor(name string) Extractor {
return func(r *http.Request) (string, bool) {
c, err := r.Cookie(name)
if err != nil || c.Value == "" {
return "", false
}
return c.Value, true
}
}
// HeaderExtractor reads a custom header verbatim.
func HeaderExtractor(name string) Extractor {
return func(r *http.Request) (string, bool) {
v := strings.TrimSpace(r.Header.Get(name))
if v == "" {
return "", false
}
return v, true
}
}
// ChainExtractors tries each extractor in order, returning the first hit.
func ChainExtractors(es ...Extractor) Extractor {
return func(r *http.Request) (string, bool) {
for _, e := range es {
if v, ok := e(r); ok {
return v, true
}
}
return "", false
}
}

20
go.mod Normal file
View file

@ -0,0 +1,20 @@
module git.juancwu.dev/juancwu/authkit
go 1.26.2
require (
git.juancwu.dev/juancwu/errx v0.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.9.2
golang.org/x/crypto v0.50.0
)
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.36.0 // indirect
)

36
go.sum Normal file
View file

@ -0,0 +1,36 @@
git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w=
git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

136
hasher/argon2id.go Normal file
View file

@ -0,0 +1,136 @@
// Package hasher provides authkit.Hasher implementations. The default
// implementation, Argon2id, encodes hashes in the standard PHC string format
// (https://github.com/P-H-C/phc-string-format) so callers can introspect
// parameters and migrate.
package hasher
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"io"
"strings"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"golang.org/x/crypto/argon2"
)
// Argon2idParams configures the Argon2id KDF.
type Argon2idParams struct {
Memory uint32 // KB; OWASP 2024 baseline: 19 MiB minimum, 64 MiB recommended
Iterations uint32 // time cost
Parallelism uint8 // lanes
SaltLen uint32 // bytes
KeyLen uint32 // bytes
}
// DefaultArgon2idParams returns sensible defaults: 64 MiB memory, 3
// iterations, 2 lanes, 16-byte salt, 32-byte key. Tune up Memory/Iterations
// over time and rely on Verify's needsRehash signal to migrate stored hashes.
func DefaultArgon2idParams() Argon2idParams {
return Argon2idParams{
Memory: 64 * 1024,
Iterations: 3,
Parallelism: 2,
SaltLen: 16,
KeyLen: 32,
}
}
type argon2idHasher struct {
params Argon2idParams
rng io.Reader
}
// NewArgon2id builds an authkit.Hasher backed by Argon2id. If params is the
// zero value, DefaultArgon2idParams() is used. rng defaults to crypto/rand.
func NewArgon2id(params Argon2idParams, rng io.Reader) authkit.Hasher {
if params == (Argon2idParams{}) {
params = DefaultArgon2idParams()
}
if rng == nil {
rng = rand.Reader
}
return &argon2idHasher{params: params, rng: rng}
}
func (h *argon2idHasher) Hash(password string) (string, error) {
const op = "authkit.hasher.Argon2id.Hash"
if password == "" {
return "", errx.New(op, "password is empty")
}
salt := make([]byte, h.params.SaltLen)
if _, err := io.ReadFull(h.rng, salt); err != nil {
return "", errx.Wrap(op, err)
}
key := argon2.IDKey([]byte(password), salt,
h.params.Iterations, h.params.Memory, h.params.Parallelism, h.params.KeyLen)
return encodePHC(h.params, salt, key), nil
}
func (h *argon2idHasher) Verify(password, encoded string) (bool, bool, error) {
const op = "authkit.hasher.Argon2id.Verify"
got, salt, key, err := decodePHC(encoded)
if err != nil {
return false, false, errx.Wrap(op, err)
}
want := argon2.IDKey([]byte(password), salt,
got.Iterations, got.Memory, got.Parallelism, uint32(len(key)))
if subtle.ConstantTimeCompare(want, key) != 1 {
return false, false, nil
}
needsRehash := got.Memory != h.params.Memory ||
got.Iterations != h.params.Iterations ||
got.Parallelism != h.params.Parallelism ||
uint32(len(salt)) != h.params.SaltLen ||
uint32(len(key)) != h.params.KeyLen
return true, needsRehash, nil
}
// PHC string format:
//
// $argon2id$v=19$m=<mem>,t=<iters>,p=<lanes>$<salt b64>$<key b64>
//
// base64 here is the unpadded standard alphabet, per the spec.
func encodePHC(p Argon2idParams, salt, key []byte) string {
enc := base64.RawStdEncoding
return fmt.Sprintf(
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, p.Memory, p.Iterations, p.Parallelism,
enc.EncodeToString(salt), enc.EncodeToString(key),
)
}
func decodePHC(s string) (Argon2idParams, []byte, []byte, error) {
parts := strings.Split(s, "$")
// Expect: ["", "argon2id", "v=19", "m=...,t=...,p=...", "<salt>", "<key>"]
if len(parts) != 6 || parts[0] != "" || parts[1] != "argon2id" {
return Argon2idParams{}, nil, nil, errors.New("hasher: not an argon2id phc string")
}
var version int
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad version segment: %w", err)
}
if version != argon2.Version {
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: unsupported argon2 version %d", version)
}
var p Argon2idParams
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism); err != nil {
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad params segment: %w", err)
}
enc := base64.RawStdEncoding
salt, err := enc.DecodeString(parts[4])
if err != nil {
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad salt: %w", err)
}
key, err := enc.DecodeString(parts[5])
if err != nil {
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad key: %w", err)
}
p.SaltLen = uint32(len(salt))
p.KeyLen = uint32(len(key))
return p, salt, key, nil
}

78
hasher/argon2id_test.go Normal file
View file

@ -0,0 +1,78 @@
package hasher
import (
"strings"
"testing"
)
func TestArgon2idHashVerifyRoundtrip(t *testing.T) {
h := NewArgon2id(DefaultArgon2idParams(), nil)
encoded, err := h.Hash("hunter2hunter2")
if err != nil {
t.Fatalf("Hash: %v", err)
}
if !strings.HasPrefix(encoded, "$argon2id$") {
t.Fatalf("encoded hash not in PHC form: %s", encoded)
}
ok, needsRehash, err := h.Verify("hunter2hunter2", encoded)
if err != nil {
t.Fatalf("Verify: %v", err)
}
if !ok {
t.Fatalf("Verify rejected the original password")
}
if needsRehash {
t.Fatalf("Verify with default params should not signal rehash")
}
}
func TestArgon2idVerifyWrongPassword(t *testing.T) {
h := NewArgon2id(DefaultArgon2idParams(), nil)
encoded, err := h.Hash("correct horse battery staple")
if err != nil {
t.Fatalf("Hash: %v", err)
}
ok, _, err := h.Verify("nope", encoded)
if err != nil {
t.Fatalf("Verify: %v", err)
}
if ok {
t.Fatalf("Verify should reject wrong password")
}
}
func TestArgon2idNeedsRehashOnParamChange(t *testing.T) {
// Hash with light params...
light := Argon2idParams{Memory: 8 * 1024, Iterations: 1, Parallelism: 1, SaltLen: 16, KeyLen: 32}
encoded, err := NewArgon2id(light, nil).Hash("hello world")
if err != nil {
t.Fatalf("Hash: %v", err)
}
// ...verify with stronger params should still match but flag rehash.
heavier := DefaultArgon2idParams()
ok, needsRehash, err := NewArgon2id(heavier, nil).Verify("hello world", encoded)
if err != nil {
t.Fatalf("Verify: %v", err)
}
if !ok {
t.Fatalf("Verify rejected legitimate password across params")
}
if !needsRehash {
t.Fatalf("Verify should flag rehash when stored params differ from current")
}
}
func TestArgon2idRejectsMalformed(t *testing.T) {
h := NewArgon2id(DefaultArgon2idParams(), nil)
cases := []string{
"",
"not-a-phc",
"$argon2i$v=19$m=64,t=1,p=1$abc$def",
"$argon2id$v=99$m=64,t=1,p=1$YWJj$ZGVm",
}
for _, c := range cases {
if _, _, err := h.Verify("x", c); err == nil {
t.Fatalf("Verify should reject malformed encoding: %q", c)
}
}
}

69
jwt.go Normal file
View file

@ -0,0 +1,69 @@
package authkit
import (
"git.juancwu.dev/juancwu/errx"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// accessClaims is the JWT shape issued by IssueJWT. The session_version
// field carries the User.SessionVersion at issue time so AuthenticateJWT
// can detect global revocations (logout-everywhere, password change).
type accessClaims struct {
jwt.RegisteredClaims
SessionVersion int `json:"sv"`
Method string `json:"m"`
}
func (a *Auth) signAccessToken(userID uuid.UUID, sessionVersion int) (string, error) {
const op = "authkit.signAccessToken"
now := a.now()
claims := accessClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: userID.String(),
Issuer: a.cfg.JWTIssuer,
Audience: jwt.ClaimStrings{a.cfg.JWTAudience},
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(a.cfg.AccessTokenTTL)),
ID: uuid.NewString(),
},
SessionVersion: sessionVersion,
Method: string(AuthMethodJWT),
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := tok.SignedString(a.cfg.JWTSecret)
if err != nil {
return "", errx.Wrap(op, err)
}
return signed, nil
}
// parseAccessToken validates the signature and returns the parsed claims.
func (a *Auth) parseAccessToken(token string) (*accessClaims, error) {
const op = "authkit.parseAccessToken"
opts := []jwt.ParserOption{
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithExpirationRequired(),
jwt.WithIssuedAt(),
jwt.WithTimeFunc(a.cfg.Clock),
}
if a.cfg.JWTIssuer != "" {
opts = append(opts, jwt.WithIssuer(a.cfg.JWTIssuer))
}
if a.cfg.JWTAudience != "" {
opts = append(opts, jwt.WithAudience(a.cfg.JWTAudience))
}
parser := jwt.NewParser(opts...)
parsed, err := parser.ParseWithClaims(token, &accessClaims{}, func(t *jwt.Token) (any, error) {
return a.cfg.JWTSecret, nil
})
if err != nil {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
claims, ok := parsed.Claims.(*accessClaims)
if !ok || !parsed.Valid {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
return claims, nil
}

99
jwt_test.go Normal file
View file

@ -0,0 +1,99 @@
package authkit
import (
"context"
"errors"
"testing"
"time"
)
func TestJWTIssueAndAuthenticate(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "alice@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
access, refresh, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
if access == "" || refresh == "" {
t.Fatalf("empty tokens")
}
p, err := a.AuthenticateJWT(context.Background(), access)
if err != nil {
t.Fatalf("AuthenticateJWT: %v", err)
}
if p.UserID != u.ID {
t.Fatalf("principal user id mismatch")
}
if p.Method != AuthMethodJWT {
t.Fatalf("principal method = %s, want jwt", p.Method)
}
}
func TestJWTSessionVersionMismatchRejected(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "bob@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
access, _, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
if err := a.RevokeAllUserSessions(context.Background(), u.ID); err != nil {
t.Fatalf("RevokeAllUserSessions: %v", err)
}
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid after session bump, got %v", err)
}
}
func TestJWTRefreshRotationAndReuseDetection(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "carol@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
_, refresh1, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
_, refresh2, err := a.RefreshJWT(context.Background(), refresh1)
if err != nil {
t.Fatalf("first RefreshJWT: %v", err)
}
if refresh1 == refresh2 {
t.Fatalf("refresh token did not rotate")
}
// Replaying refresh1 must surface ErrTokenReused and revoke the chain.
if _, _, err := a.RefreshJWT(context.Background(), refresh1); !errors.Is(err, ErrTokenReused) {
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
}
// After chain revocation, even refresh2 (the legitimate next one) must
// be rejected.
if _, _, err := a.RefreshJWT(context.Background(), refresh2); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on post-revoke refresh, got %v", err)
}
}
func TestJWTExpiredTokenRejected(t *testing.T) {
a := newTestAuth(t)
now := time.Now().UTC()
a.cfg.Clock = func() time.Time { return now }
u, err := a.Register(context.Background(), "dan@example.com", "hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
access, _, err := a.IssueJWT(context.Background(), u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
// Advance clock past TTL.
a.cfg.Clock = func() time.Time { return now.Add(10 * time.Minute) }
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid for expired token, got %v", err)
}
}

619
memstore_test.go Normal file
View file

@ -0,0 +1,619 @@
package authkit
// In-memory store fakes used by service-level tests. Kept in a _test.go file
// so they don't ship in the public API. Each fake is intentionally minimal —
// it satisfies the interface and supports the flows the tests exercise; not
// every method is wired up (a few panic to flag accidental use).
import (
"bytes"
"context"
"errors"
"slices"
"sync"
"time"
"github.com/google/uuid"
)
type memUserStore struct {
mu sync.Mutex
byID map[uuid.UUID]*User
byEml map[string]uuid.UUID
}
func newMemUserStore() *memUserStore {
return &memUserStore{byID: map[uuid.UUID]*User{}, byEml: map[string]uuid.UUID{}}
}
func (s *memUserStore) CreateUser(_ context.Context, u *User) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.byEml[u.EmailNormalized]; exists {
return ErrEmailTaken
}
if u.ID == uuid.Nil {
u.ID = uuid.New()
}
cp := *u
s.byID[u.ID] = &cp
s.byEml[u.EmailNormalized] = u.ID
return nil
}
func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*User, error) {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return nil, ErrUserNotFound
}
cp := *u
return &cp, nil
}
func (s *memUserStore) GetUserByEmail(_ context.Context, ne string) (*User, error) {
s.mu.Lock()
defer s.mu.Unlock()
id, ok := s.byEml[ne]
if !ok {
return nil, ErrUserNotFound
}
u := *s.byID[id]
return &u, nil
}
func (s *memUserStore) UpdateUser(_ context.Context, u *User) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.byID[u.ID]; !ok {
return ErrUserNotFound
}
cp := *u
s.byID[u.ID] = &cp
return nil
}
func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
delete(s.byID, id)
delete(s.byEml, u.EmailNormalized)
return nil
}
func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, h string) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
u.PasswordHash = h
return nil
}
func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
u.EmailVerifiedAt = &at
return nil
}
func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return 0, ErrUserNotFound
}
u.SessionVersion++
return u.SessionVersion, nil
}
func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return 0, ErrUserNotFound
}
u.FailedLogins++
return u.FailedLogins, nil
}
func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
u, ok := s.byID[id]
if !ok {
return ErrUserNotFound
}
u.FailedLogins = 0
return nil
}
type memSessionStore struct {
mu sync.Mutex
m map[string]*Session
}
func newMemSessionStore() *memSessionStore { return &memSessionStore{m: map[string]*Session{}} }
func (s *memSessionStore) CreateSession(_ context.Context, ses *Session) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *ses
s.m[string(ses.IDHash)] = &cp
return nil
}
func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.m[string(h)]
if !ok {
return nil, ErrSessionInvalid
}
cp := *v
return &cp, nil
}
func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, exp time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.m[string(h)]
if !ok {
return ErrSessionInvalid
}
v.LastSeenAt = lastSeen
v.ExpiresAt = exp
return nil
}
func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.m, string(h))
return nil
}
func (s *memSessionStore) DeleteUserSessions(_ context.Context, uid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, v := range s.m {
if v.UserID == uid {
delete(s.m, k)
}
}
return nil
}
func (s *memSessionStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
var n int64
for k, v := range s.m {
if !v.ExpiresAt.After(now) {
delete(s.m, k)
n++
}
}
return n, nil
}
type memTokenStore struct {
mu sync.Mutex
// Keyed by kind+hex(hash) so identical hashes across kinds don't collide.
m map[string]*Token
}
func newMemTokenStore() *memTokenStore { return &memTokenStore{m: map[string]*Token{}} }
func tokKey(kind TokenKind, h []byte) string {
return string(kind) + ":" + string(h)
}
func (s *memTokenStore) CreateToken(_ context.Context, t *Token) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *t
s.m[tokKey(t.Kind, t.Hash)] = &cp
return nil
}
func (s *memTokenStore) ConsumeToken(_ context.Context, kind TokenKind, h []byte, now time.Time) (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.m[tokKey(kind, h)]
if !ok || t.ConsumedAt != nil || !t.ExpiresAt.After(now) {
return nil, ErrTokenInvalid
}
t.ConsumedAt = &now
cp := *t
return &cp, nil
}
func (s *memTokenStore) GetToken(_ context.Context, kind TokenKind, h []byte) (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.m[tokKey(kind, h)]
if !ok {
return nil, ErrTokenInvalid
}
cp := *t
return &cp, nil
}
func (s *memTokenStore) DeleteByChain(_ context.Context, chainID string) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
var n int64
for k, t := range s.m {
if t.ChainID != nil && *t.ChainID == chainID {
delete(s.m, k)
n++
}
}
return n, nil
}
func (s *memTokenStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
var n int64
for k, t := range s.m {
if !t.ExpiresAt.After(now) {
delete(s.m, k)
n++
}
}
return n, nil
}
type memAPIKeyStore struct {
mu sync.Mutex
m map[string]*APIKey
}
func newMemAPIKeyStore() *memAPIKeyStore { return &memAPIKeyStore{m: map[string]*APIKey{}} }
func (s *memAPIKeyStore) CreateAPIKey(_ context.Context, k *APIKey) error {
s.mu.Lock()
defer s.mu.Unlock()
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
s.m[string(k.IDHash)] = &cp
return nil
}
func (s *memAPIKeyStore) GetAPIKey(_ context.Context, h []byte) (*APIKey, error) {
s.mu.Lock()
defer s.mu.Unlock()
k, ok := s.m[string(h)]
if !ok {
return nil, ErrAPIKeyInvalid
}
cp := *k
cp.Abilities = append([]string(nil), k.Abilities...)
return &cp, nil
}
func (s *memAPIKeyStore) ListAPIKeysByOwner(_ context.Context, owner uuid.UUID) ([]*APIKey, error) {
s.mu.Lock()
defer s.mu.Unlock()
var out []*APIKey
for _, k := range s.m {
if k.OwnerID == owner {
cp := *k
out = append(out, &cp)
}
}
return out, nil
}
func (s *memAPIKeyStore) TouchAPIKey(_ context.Context, h []byte, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
if k, ok := s.m[string(h)]; ok {
k.LastUsedAt = &at
}
return nil
}
func (s *memAPIKeyStore) RevokeAPIKey(_ context.Context, h []byte, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
k, ok := s.m[string(h)]
if !ok {
return ErrAPIKeyInvalid
}
if k.RevokedAt != nil {
return ErrAPIKeyInvalid
}
k.RevokedAt = &at
return nil
}
func (s *memAPIKeyStore) RevokeAPIKeysByOwner(_ context.Context, owner uuid.UUID, at time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, k := range s.m {
if k.OwnerID == owner && k.RevokedAt == nil {
k.RevokedAt = &at
}
}
return nil
}
type memRoleStore struct {
mu sync.Mutex
roles map[uuid.UUID]*Role
rolesByNm map[string]uuid.UUID
userRoles map[uuid.UUID]map[uuid.UUID]struct{}
}
func newMemRoleStore() *memRoleStore {
return &memRoleStore{
roles: map[uuid.UUID]*Role{},
rolesByNm: map[string]uuid.UUID{},
userRoles: map[uuid.UUID]map[uuid.UUID]struct{}{},
}
}
func (s *memRoleStore) CreateRole(_ context.Context, r *Role) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, dup := s.rolesByNm[r.Name]; dup {
return errors.New("role exists")
}
if r.ID == uuid.Nil {
r.ID = uuid.New()
}
cp := *r
s.roles[r.ID] = &cp
s.rolesByNm[r.Name] = r.ID
return nil
}
func (s *memRoleStore) GetRoleByID(_ context.Context, id uuid.UUID) (*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.roles[id]
if !ok {
return nil, ErrRoleNotFound
}
cp := *r
return &cp, nil
}
func (s *memRoleStore) GetRoleByName(_ context.Context, n string) (*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
id, ok := s.rolesByNm[n]
if !ok {
return nil, ErrRoleNotFound
}
r := *s.roles[id]
return &r, nil
}
func (s *memRoleStore) ListRoles(_ context.Context) ([]*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*Role, 0, len(s.roles))
for _, r := range s.roles {
cp := *r
out = append(out, &cp)
}
return out, nil
}
func (s *memRoleStore) DeleteRole(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.roles[id]
if !ok {
return ErrRoleNotFound
}
delete(s.roles, id)
delete(s.rolesByNm, r.Name)
for _, m := range s.userRoles {
delete(m, id)
}
return nil
}
func (s *memRoleStore) AssignRoleToUser(_ context.Context, uid, rid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
m, ok := s.userRoles[uid]
if !ok {
m = map[uuid.UUID]struct{}{}
s.userRoles[uid] = m
}
m[rid] = struct{}{}
return nil
}
func (s *memRoleStore) RemoveRoleFromUser(_ context.Context, uid, rid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
if m, ok := s.userRoles[uid]; ok {
delete(m, rid)
}
return nil
}
func (s *memRoleStore) GetUserRoles(_ context.Context, uid uuid.UUID) ([]*Role, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := []*Role{}
for rid := range s.userRoles[uid] {
if r, ok := s.roles[rid]; ok {
cp := *r
out = append(out, &cp)
}
}
return out, nil
}
func (s *memRoleStore) HasAnyRole(_ context.Context, uid uuid.UUID, names []string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
for rid := range s.userRoles[uid] {
r := s.roles[rid]
if slices.Contains(names, r.Name) {
return true, nil
}
}
return false, nil
}
type memPermStore struct {
mu sync.Mutex
perms map[uuid.UUID]*Permission
permsByNm map[string]uuid.UUID
rolePerms map[uuid.UUID]map[uuid.UUID]struct{}
roles *memRoleStore
}
func newMemPermStore(rs *memRoleStore) *memPermStore {
return &memPermStore{
perms: map[uuid.UUID]*Permission{},
permsByNm: map[string]uuid.UUID{},
rolePerms: map[uuid.UUID]map[uuid.UUID]struct{}{},
roles: rs,
}
}
func (s *memPermStore) CreatePermission(_ context.Context, p *Permission) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, dup := s.permsByNm[p.Name]; dup {
return errors.New("perm exists")
}
if p.ID == uuid.Nil {
p.ID = uuid.New()
}
cp := *p
s.perms[p.ID] = &cp
s.permsByNm[p.Name] = p.ID
return nil
}
func (s *memPermStore) GetPermissionByID(_ context.Context, id uuid.UUID) (*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok := s.perms[id]
if !ok {
return nil, ErrPermissionNotFound
}
cp := *p
return &cp, nil
}
func (s *memPermStore) GetPermissionByName(_ context.Context, n string) (*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
id, ok := s.permsByNm[n]
if !ok {
return nil, ErrPermissionNotFound
}
p := *s.perms[id]
return &p, nil
}
func (s *memPermStore) ListPermissions(_ context.Context) ([]*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*Permission, 0, len(s.perms))
for _, p := range s.perms {
cp := *p
out = append(out, &cp)
}
return out, nil
}
func (s *memPermStore) DeletePermission(_ context.Context, id uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
p, ok := s.perms[id]
if !ok {
return ErrPermissionNotFound
}
delete(s.perms, id)
delete(s.permsByNm, p.Name)
for _, m := range s.rolePerms {
delete(m, id)
}
return nil
}
func (s *memPermStore) AssignPermissionToRole(_ context.Context, rid, pid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
m, ok := s.rolePerms[rid]
if !ok {
m = map[uuid.UUID]struct{}{}
s.rolePerms[rid] = m
}
m[pid] = struct{}{}
return nil
}
func (s *memPermStore) RemovePermissionFromRole(_ context.Context, rid, pid uuid.UUID) error {
s.mu.Lock()
defer s.mu.Unlock()
if m, ok := s.rolePerms[rid]; ok {
delete(m, pid)
}
return nil
}
func (s *memPermStore) GetRolePermissions(_ context.Context, rid uuid.UUID) ([]*Permission, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := []*Permission{}
for pid := range s.rolePerms[rid] {
if p, ok := s.perms[pid]; ok {
cp := *p
out = append(out, &cp)
}
}
return out, nil
}
func (s *memPermStore) GetUserPermissions(_ context.Context, uid uuid.UUID) ([]*Permission, error) {
s.roles.mu.Lock()
roleIDs := make([]uuid.UUID, 0)
for rid := range s.roles.userRoles[uid] {
roleIDs = append(roleIDs, rid)
}
s.roles.mu.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
seen := map[uuid.UUID]struct{}{}
out := []*Permission{}
for _, rid := range roleIDs {
for pid := range s.rolePerms[rid] {
if _, dup := seen[pid]; dup {
continue
}
seen[pid] = struct{}{}
if p, ok := s.perms[pid]; ok {
cp := *p
out = append(out, &cp)
}
}
}
return out, nil
}
// Stub Hasher: stores plaintext for trivial verify. Tests of hashing live in
// the hasher package itself; we just need *something* callable here.
type stubHasher struct{}
func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil }
func (stubHasher) Verify(p, encoded string) (bool, bool, error) {
want := "stub:" + p
if !bytes.Equal([]byte(want), []byte(encoded)) {
return false, false, nil
}
return true, false, nil
}
// newTestAuth wires the fakes into Auth with deterministic config.
func newTestAuth(t interface{ Helper() }) *Auth {
if h, ok := t.(interface{ Helper() }); ok {
h.Helper()
}
roles := newMemRoleStore()
return New(Deps{
Users: newMemUserStore(),
Sessions: newMemSessionStore(),
Tokens: newMemTokenStore(),
APIKeys: newMemAPIKeyStore(),
Roles: roles,
Permissions: newMemPermStore(roles),
Hasher: stubHasher{},
}, Config{
JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"),
JWTIssuer: "authkit-test",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: 1 * time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,
PasswordResetTTL: time.Hour,
MagicLinkTTL: time.Minute,
})
}

71
middleware/authz.go Normal file
View file

@ -0,0 +1,71 @@
package middleware
import (
"net/http"
"git.juancwu.dev/juancwu/authkit"
)
// authzGuard wraps the common pattern of "look up the Principal, run a
// predicate, succeed or 403". onForbidden defaults to JSON 403.
func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler {
if onForbidden == nil {
onForbidden = defaultJSONError(http.StatusForbidden)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p, ok := PrincipalFrom(r.Context())
if !ok {
// No auth middleware ran upstream; treat as forbidden
// rather than crashing — composition is the caller's
// responsibility but a 403 is the safer default.
onForbidden(w, r, authkit.ErrPermissionDenied)
return
}
if !pred(p) {
onForbidden(w, r, authkit.ErrPermissionDenied)
return
}
next.ServeHTTP(w, r)
})
}
}
// RequireRole permits requests whose Principal holds the named role.
func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
return p.HasRole(name)
})
}
// RequireAnyRole permits requests whose Principal holds at least one of the
// named roles.
func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
return p.HasAnyRole(names...)
})
}
// RequirePermission permits requests whose Principal holds the named
// permission (resolved via roles at auth time).
func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
return p.HasPermission(name)
})
}
// RequireAbility permits requests whose Principal carries the named ability.
// Abilities are populated only for API-key authentication; this middleware
// will reject session/JWT-authenticated requests by design.
func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
return p.HasAbility(name)
})
}
func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) {
if len(s) == 0 {
return nil
}
return s[0]
}

39
middleware/context.go Normal file
View file

@ -0,0 +1,39 @@
// Package middleware provides framework-neutral HTTP middleware for authkit.
// Every middleware function returns the standard func(http.Handler)
// http.Handler type, so it composes with lightmux's Use/Group/Handle as well
// as any net/http stack that uses the same signature.
package middleware
import (
"context"
"net/http"
"git.juancwu.dev/juancwu/authkit"
)
// principalKey is an unexported context key. Using a distinct empty struct
// type guarantees no collision with caller-defined keys.
type principalKey struct{}
// withPrincipal stashes p on the request context for downstream handlers.
func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context {
return context.WithValue(ctx, principalKey{}, p)
}
// PrincipalFrom retrieves the authenticated Principal placed by RequireSession,
// RequireJWT, or RequireAPIKey. The boolean is false if no auth middleware
// ran for this request.
func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) {
p, ok := ctx.Value(principalKey{}).(*authkit.Principal)
return p, ok
}
// MustPrincipal panics if no Principal is on the context. Use only on
// handlers known to be behind a Require* middleware.
func MustPrincipal(r *http.Request) *authkit.Principal {
p, ok := PrincipalFrom(r.Context())
if !ok {
panic("authkit/middleware: no principal on request context")
}
return p
}

138
middleware/middleware.go Normal file
View file

@ -0,0 +1,138 @@
package middleware
import (
"encoding/json"
"net/http"
"git.juancwu.dev/juancwu/authkit"
)
// Options configures auth middleware. Auth is required; the rest fall back
// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403
// on authz failure.
type Options struct {
Auth *authkit.Auth
Extractor authkit.Extractor
OnUnauth func(w http.ResponseWriter, r *http.Request, err error)
OnForbidden func(w http.ResponseWriter, r *http.Request, err error)
}
func (o Options) extractor() authkit.Extractor {
if o.Extractor != nil {
return o.Extractor
}
return authkit.BearerExtractor()
}
func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) {
if o.OnUnauth != nil {
return o.OnUnauth
}
return defaultJSONError(http.StatusUnauthorized)
}
func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) {
if o.OnForbidden != nil {
return o.OnForbidden
}
return defaultJSONError(http.StatusForbidden)
}
func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) {
return func(w http.ResponseWriter, _ *http.Request, err error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(map[string]string{
"error": http.StatusText(status),
})
}
}
// RequireSession authenticates the request via an opaque session string. The
// extractor is consulted first; if no extractor is set the default Bearer
// extractor is used. For cookie-based session lookup, set
// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName).
func RequireSession(opts Options) func(http.Handler) http.Handler {
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
return opts.Auth.AuthenticateSession(r.Context(), raw)
})
}
// RequireJWT authenticates the request via an HS256 JWT.
func RequireJWT(opts Options) func(http.Handler) http.Handler {
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
return opts.Auth.AuthenticateJWT(r.Context(), raw)
})
}
// RequireAPIKey authenticates the request via an opaque API secret.
func RequireAPIKey(opts Options) func(http.Handler) http.Handler {
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
return opts.Auth.AuthenticateAPIKey(r.Context(), raw)
})
}
// RequireAny tries each method in order until one succeeds. Useful for routes
// that accept either a session cookie or an API key.
func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler {
if len(methods) == 0 {
methods = []authkit.AuthMethod{
authkit.AuthMethodSession,
authkit.AuthMethodJWT,
authkit.AuthMethodAPIKey,
}
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, ok := opts.extractor()(r)
if !ok || raw == "" {
opts.onUnauth()(w, r, authkit.ErrSessionInvalid)
return
}
var (
p *authkit.Principal
lastErr error
)
for _, m := range methods {
switch m {
case authkit.AuthMethodSession:
p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw)
case authkit.AuthMethodJWT:
p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw)
case authkit.AuthMethodAPIKey:
p, lastErr = opts.Auth.AuthenticateAPIKey(r.Context(), raw)
}
if lastErr == nil && p != nil {
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
return
}
}
opts.onUnauth()(w, r, lastErr)
})
}
}
// requireWith is the shared scaffolding for the single-method Require*
// middlewares.
func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler {
if opts.Auth == nil {
panic("authkit/middleware: Options.Auth is required")
}
extractor := opts.extractor()
onUnauth := opts.onUnauth()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, ok := extractor(r)
if !ok || raw == "" {
onUnauth(w, r, authkit.ErrSessionInvalid)
return
}
p, err := authn(r, raw)
if err != nil {
onUnauth(w, r, err)
return
}
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
})
}
}

75
models.go Normal file
View file

@ -0,0 +1,75 @@
package authkit
import (
"net/netip"
"time"
"github.com/google/uuid"
)
type User struct {
ID uuid.UUID
Email string
EmailNormalized string
EmailVerifiedAt *time.Time
PasswordHash string
SessionVersion int
FailedLogins int
LastLoginAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
type Session struct {
IDHash []byte
UserID uuid.UUID
UserAgent string
IP netip.Addr
CreatedAt time.Time
LastSeenAt time.Time
ExpiresAt time.Time
}
type TokenKind string
const (
TokenEmailVerify TokenKind = "email_verify"
TokenPasswordReset TokenKind = "password_reset"
TokenMagicLink TokenKind = "magic_link"
TokenRefresh TokenKind = "refresh"
)
type Token struct {
Hash []byte
Kind TokenKind
UserID uuid.UUID
ChainID *string
ConsumedAt *time.Time
CreatedAt time.Time
ExpiresAt time.Time
}
type APIKey struct {
IDHash []byte
OwnerID uuid.UUID
Name string
Abilities []string
LastUsedAt *time.Time
CreatedAt time.Time
ExpiresAt *time.Time
RevokedAt *time.Time
}
type Role struct {
ID uuid.UUID
Name string
Description string
CreatedAt time.Time
}
type Permission struct {
ID uuid.UUID
Name string
Description string
CreatedAt time.Time
}

63
principal.go Normal file
View file

@ -0,0 +1,63 @@
package authkit
import (
"time"
"github.com/google/uuid"
)
type AuthMethod string
const (
AuthMethodSession AuthMethod = "session"
AuthMethodJWT AuthMethod = "jwt"
AuthMethodAPIKey AuthMethod = "api_key"
)
type Principal struct {
UserID uuid.UUID
Method AuthMethod
SessionID []byte
APIKeyID []byte
Roles []string
Permissions []string
Abilities []string
IssuedAt time.Time
ExpiresAt time.Time
}
func (p *Principal) HasRole(name string) bool {
for _, r := range p.Roles {
if r == name {
return true
}
}
return false
}
func (p *Principal) HasAnyRole(names ...string) bool {
for _, n := range names {
if p.HasRole(n) {
return true
}
}
return false
}
func (p *Principal) HasPermission(name string) bool {
for _, perm := range p.Permissions {
if perm == name {
return true
}
}
return false
}
func (p *Principal) HasAbility(name string) bool {
for _, a := range p.Abilities {
if a == name {
return true
}
}
return false
}

92
service_apikey.go Normal file
View file

@ -0,0 +1,92 @@
package authkit
import (
"context"
"time"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
// IssueAPIKey mints a fresh API secret with the given abilities and an
// optional TTL. The plaintext is returned to the caller (show-once) and the
// SHA-256 lookup hash is stored. Pass ttl=nil for a non-expiring key.
func (a *Auth) IssueAPIKey(ctx context.Context, ownerID uuid.UUID, name string, abilities []string, ttl *time.Duration) (string, *APIKey, error) {
const op = "authkit.Auth.IssueAPIKey"
plaintext, hash, err := mintSecret(prefixAPIKey, a.cfg.Random)
if err != nil {
return "", nil, errx.Wrap(op, err)
}
now := a.now()
k := &APIKey{
IDHash: hash,
OwnerID: ownerID,
Name: name,
Abilities: append([]string(nil), abilities...),
CreatedAt: now,
}
if ttl != nil {
exp := now.Add(*ttl)
k.ExpiresAt = &exp
}
if err := a.deps.APIKeys.CreateAPIKey(ctx, k); err != nil {
return "", nil, errx.Wrap(op, err)
}
return plaintext, k, nil
}
// AuthenticateAPIKey validates an API secret string, touches last_used_at
// (best-effort), and returns a Principal carrying the key's abilities. The
// owning user's roles+permissions are also resolved so the same Principal
// can satisfy RequireRole / RequirePermission middleware.
func (a *Auth) AuthenticateAPIKey(ctx context.Context, plaintext string) (*Principal, error) {
const op = "authkit.Auth.AuthenticateAPIKey"
hash, ok := parseSecret(prefixAPIKey, plaintext)
if !ok {
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
}
k, err := a.deps.APIKeys.GetAPIKey(ctx, hash)
if err != nil {
return nil, errx.Wrap(op, err)
}
now := a.now()
if k.RevokedAt != nil {
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
}
if k.ExpiresAt != nil && !k.ExpiresAt.After(now) {
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
}
_ = a.deps.APIKeys.TouchAPIKey(ctx, hash, now)
roles, perms, err := a.resolveRolesAndPermissions(ctx, k.OwnerID)
if err != nil {
return nil, errx.Wrap(op, err)
}
expires := now
if k.ExpiresAt != nil {
expires = *k.ExpiresAt
}
return &Principal{
UserID: k.OwnerID,
Method: AuthMethodAPIKey,
APIKeyID: hash,
Roles: roles,
Permissions: perms,
Abilities: append([]string(nil), k.Abilities...),
IssuedAt: k.CreatedAt,
ExpiresAt: expires,
}, nil
}
// RevokeAPIKey marks a key revoked. Idempotent on already-revoked keys.
func (a *Auth) RevokeAPIKey(ctx context.Context, plaintext string) error {
const op = "authkit.Auth.RevokeAPIKey"
hash, ok := parseSecret(prefixAPIKey, plaintext)
if !ok {
return errx.Wrap(op, ErrAPIKeyInvalid)
}
if err := a.deps.APIKeys.RevokeAPIKey(ctx, hash, a.now()); err != nil {
return errx.Wrap(op, err)
}
return nil
}

85
service_authz.go Normal file
View file

@ -0,0 +1,85 @@
package authkit
import (
"context"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
// UserPermissions returns the union of permission names a user holds via
// their assigned roles. Resolved at call time; v1 does not cache.
func (a *Auth) UserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
const op = "authkit.Auth.UserPermissions"
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
if err != nil {
return nil, errx.Wrap(op, err)
}
out := make([]string, len(perms))
for i, p := range perms {
out[i] = p.Name
}
return out, nil
}
// HasPermission checks whether a user holds the named permission via any
// assigned role.
func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
const op = "authkit.Auth.HasPermission"
perms, err := a.UserPermissions(ctx, userID)
if err != nil {
return false, errx.Wrap(op, err)
}
for _, p := range perms {
if p == name {
return true, nil
}
}
return false, nil
}
// HasRole checks whether a user is assigned the named role.
func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
const op = "authkit.Auth.HasRole"
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, []string{name})
if err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// HasAnyRole checks whether a user holds at least one of the named roles.
func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
const op = "authkit.Auth.HasAnyRole"
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, names)
if err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// AssignRole is a convenience that looks up a role by name and assigns it.
func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleName string) error {
const op = "authkit.Auth.AssignRole"
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Roles.AssignRoleToUser(ctx, userID, r.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RemoveRole is the symmetric helper for AssignRole.
func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleName string) error {
const op = "authkit.Auth.RemoveRole"
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Roles.RemoveRoleFromUser(ctx, userID, r.ID); err != nil {
return errx.Wrap(op, err)
}
return nil
}

147
service_jwt.go Normal file
View file

@ -0,0 +1,147 @@
package authkit
import (
"context"
"errors"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
// IssueJWT issues a fresh access JWT and a rotating opaque refresh token.
// The refresh token is bound to a chain via Token.ChainID; rotation
// preserves that chain so reuse-detection can revoke the whole family.
func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) {
const op = "authkit.Auth.IssueJWT"
u, err := a.deps.Users.GetUserByID(ctx, userID)
if err != nil {
return "", "", errx.Wrap(op, err)
}
access, err = a.signAccessToken(u.ID, u.SessionVersion)
if err != nil {
return "", "", errx.Wrap(op, err)
}
refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString())
if err != nil {
return "", "", errx.Wrap(op, err)
}
return access, refresh, nil
}
// AuthenticateJWT validates the access JWT, cross-checks the user's
// session_version (instant revocation), and resolves a Principal.
func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, error) {
const op = "authkit.Auth.AuthenticateJWT"
claims, err := a.parseAccessToken(access)
if err != nil {
return nil, errx.Wrap(op, err)
}
uid, err := uuid.Parse(claims.Subject)
if err != nil {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
u, err := a.deps.Users.GetUserByID(ctx, uid)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
return nil, errx.Wrap(op, err)
}
if u.SessionVersion != claims.SessionVersion {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
roles, perms, err := a.resolveRolesAndPermissions(ctx, u.ID)
if err != nil {
return nil, errx.Wrap(op, err)
}
return &Principal{
UserID: u.ID,
Method: AuthMethodJWT,
Roles: roles,
Permissions: perms,
IssuedAt: claims.IssuedAt.Time,
ExpiresAt: claims.ExpiresAt.Time,
}, nil
}
// RefreshJWT consumes the presented refresh token and mints a new access +
// refresh pair. Reuse of an already-consumed refresh token deletes the
// entire chain (logout-everywhere on that device family) and returns
// ErrTokenReused.
func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) {
const op = "authkit.Auth.RefreshJWT"
hash, ok := parseSecret(prefixRefresh, plaintextRefresh)
if !ok {
return "", "", errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
consumed, err := a.deps.Tokens.ConsumeToken(ctx, TokenRefresh, hash, now)
if err != nil {
// Differentiate plain-invalid (never existed / expired) from
// reuse (existed, already consumed). The presence-check below is
// the reuse signal.
if errors.Is(err, ErrTokenInvalid) {
if existing, gerr := a.deps.Tokens.GetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil {
if existing.ChainID != nil && *existing.ChainID != "" {
_, _ = a.deps.Tokens.DeleteByChain(ctx, *existing.ChainID)
}
return "", "", errx.Wrap(op, ErrTokenReused)
}
}
return "", "", errx.Wrap(op, err)
}
var chainID string
if consumed.ChainID != nil {
chainID = *consumed.ChainID
}
if chainID == "" {
// Defensive: every refresh token should be chain-bound. Fall back
// to a fresh chain so we never throw on missing metadata.
chainID = uuid.NewString()
}
access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID))
if err != nil {
return "", "", errx.Wrap(op, err)
}
refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID)
if err != nil {
return "", "", errx.Wrap(op, err)
}
return access, refresh, nil
}
// mintRefreshToken stores a fresh refresh token bound to chainID and returns
// the plaintext.
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) {
const op = "authkit.Auth.mintRefreshToken"
plaintext, hash, err := mintSecret(prefixRefresh, a.cfg.Random)
if err != nil {
return "", errx.Wrap(op, err)
}
now := a.now()
t := &Token{
Hash: hash,
Kind: TokenRefresh,
UserID: userID,
ChainID: &chainID,
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.RefreshTokenTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// userSessionVersion fetches the current session_version. Errors collapse to
// 0 on the assumption that AuthenticateJWT will reject stale tokens cleanly
// — but we still need a value to embed in the freshly-minted access token.
func (a *Auth) userSessionVersion(ctx context.Context, userID uuid.UUID) int {
if u, err := a.deps.Users.GetUserByID(ctx, userID); err == nil {
return u.SessionVersion
}
return 0
}

62
service_magic.go Normal file
View file

@ -0,0 +1,62 @@
package authkit
import (
"context"
"git.juancwu.dev/juancwu/errx"
)
// RequestMagicLink mints a single-use magic-link token for the email and
// returns the plaintext for delivery. ErrUserNotFound is returned for
// unregistered emails.
func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, error) {
const op = "authkit.Auth.RequestMagicLink"
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
return "", errx.Wrap(op, err)
}
plaintext, hash, err := mintSecret(prefixMagicLink, a.cfg.Random)
if err != nil {
return "", errx.Wrap(op, err)
}
now := a.now()
t := &Token{
Hash: hash,
Kind: TokenMagicLink,
UserID: u.ID,
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.MagicLinkTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// ConsumeMagicLink consumes the magic-link token and returns the
// authenticated user. Callers typically follow this with IssueSession or
// IssueJWT to actually log the user in.
func (a *Auth) ConsumeMagicLink(ctx context.Context, plaintextToken string) (*User, error) {
const op = "authkit.Auth.ConsumeMagicLink"
hash, ok := parseSecret(prefixMagicLink, plaintextToken)
if !ok {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenMagicLink, hash, now)
if err != nil {
return nil, errx.Wrap(op, err)
}
u, err := a.deps.Users.GetUserByID(ctx, t.UserID)
if err != nil {
return nil, errx.Wrap(op, err)
}
// A successful magic-link login also implicitly verifies the email
// (the user demonstrably controls the inbox).
if u.EmailVerifiedAt == nil {
if err := a.deps.Users.SetEmailVerified(ctx, u.ID, now); err == nil {
u.EmailVerifiedAt = &now
}
}
return u, nil
}

69
service_reset.go Normal file
View file

@ -0,0 +1,69 @@
package authkit
import (
"context"
"errors"
"git.juancwu.dev/juancwu/errx"
)
// RequestPasswordReset mints a single-use password-reset token for the user
// behind email and returns the plaintext for the caller to deliver via email.
// Returns ErrUserNotFound when the email isn't registered (per project
// policy of distinct errors over anti-enumeration).
func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, error) {
const op = "authkit.Auth.RequestPasswordReset"
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
return "", errx.Wrap(op, err)
}
plaintext, hash, err := mintSecret(prefixPasswordRset, a.cfg.Random)
if err != nil {
return "", errx.Wrap(op, err)
}
now := a.now()
t := &Token{
Hash: hash,
Kind: TokenPasswordReset,
UserID: u.ID,
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.PasswordResetTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// ConfirmPasswordReset consumes the reset token, sets the new password,
// bumps the user's session_version, and revokes outstanding sessions so the
// reset constitutes a global logout.
func (a *Auth) ConfirmPasswordReset(ctx context.Context, plaintextToken, newPassword string) error {
const op = "authkit.Auth.ConfirmPasswordReset"
hash, ok := parseSecret(prefixPasswordRset, plaintextToken)
if !ok {
return errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenPasswordReset, hash, now)
if err != nil {
return errx.Wrap(op, err)
}
newHash, err := a.deps.Hasher.Hash(newPassword)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Users.SetPassword(ctx, t.UserID, newHash); err != nil {
if errors.Is(err, ErrUserNotFound) {
return errx.Wrap(op, ErrUserNotFound)
}
return errx.Wrap(op, err)
}
if _, err := a.deps.Users.BumpSessionVersion(ctx, t.UserID); err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Sessions.DeleteUserSessions(ctx, t.UserID); err != nil {
return errx.Wrap(op, err)
}
return nil
}

154
service_session.go Normal file
View file

@ -0,0 +1,154 @@
package authkit
import (
"context"
"net/http"
"net/netip"
"time"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
// IssueSession mints an opaque session ID, persists the session record, and
// returns the plaintext (for the cookie) plus the stored Session.
func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent string, ip netip.Addr) (string, *Session, error) {
const op = "authkit.Auth.IssueSession"
plaintext, hash, err := mintSecret(prefixSession, a.cfg.Random)
if err != nil {
return "", nil, errx.Wrap(op, err)
}
now := a.now()
expires := now.Add(a.cfg.SessionIdleTTL)
if cap := now.Add(a.cfg.SessionAbsoluteTTL); expires.After(cap) {
expires = cap
}
s := &Session{
IDHash: hash,
UserID: userID,
UserAgent: userAgent,
IP: ip,
CreatedAt: now,
LastSeenAt: now,
ExpiresAt: expires,
}
if err := a.deps.Sessions.CreateSession(ctx, s); err != nil {
return "", nil, errx.Wrap(op, err)
}
return plaintext, s, nil
}
// AuthenticateSession validates an opaque session string, slides the TTL,
// resolves the user's roles+permissions, and returns a Principal. Expired or
// unknown sessions return ErrSessionInvalid.
func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Principal, error) {
const op = "authkit.Auth.AuthenticateSession"
hash, ok := parseSecret(prefixSession, plaintext)
if !ok {
return nil, errx.Wrap(op, ErrSessionInvalid)
}
s, err := a.deps.Sessions.GetSession(ctx, hash)
if err != nil {
return nil, errx.Wrap(op, err)
}
now := a.now()
if !s.ExpiresAt.After(now) {
_ = a.deps.Sessions.DeleteSession(ctx, hash)
return nil, errx.Wrap(op, ErrSessionInvalid)
}
// Slide the idle TTL, capped at created_at + AbsoluteTTL so an active
// session still expires at the absolute boundary.
newExpires := now.Add(a.cfg.SessionIdleTTL)
if cap := s.CreatedAt.Add(a.cfg.SessionAbsoluteTTL); newExpires.After(cap) {
newExpires = cap
}
if err := a.deps.Sessions.TouchSession(ctx, hash, now, newExpires); err != nil {
return nil, errx.Wrap(op, err)
}
roles, perms, err := a.resolveRolesAndPermissions(ctx, s.UserID)
if err != nil {
return nil, errx.Wrap(op, err)
}
return &Principal{
UserID: s.UserID,
Method: AuthMethodSession,
SessionID: hash,
Roles: roles,
Permissions: perms,
IssuedAt: s.CreatedAt,
ExpiresAt: newExpires,
}, nil
}
// RevokeSession deletes a single session by its plaintext id. Idempotent:
// missing sessions are not an error (logout twice should not 500).
func (a *Auth) RevokeSession(ctx context.Context, plaintext string) error {
const op = "authkit.Auth.RevokeSession"
hash, ok := parseSecret(prefixSession, plaintext)
if !ok {
return nil
}
if err := a.deps.Sessions.DeleteSession(ctx, hash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RevokeAllUserSessions kills every active session for the user and bumps
// the user's session_version (invalidating outstanding JWT access tokens).
func (a *Auth) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.Auth.RevokeAllUserSessions"
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// SessionCookie builds an *http.Cookie pre-configured from Config. Pass the
// plaintext returned by IssueSession; pass the matching ExpiresAt from the
// returned *Session as `expires`. To clear a cookie at logout, pass an empty
// plaintext and a past expiry.
func (a *Auth) SessionCookie(plaintext string, expires time.Time) *http.Cookie {
c := &http.Cookie{
Name: a.cfg.SessionCookieName,
Value: plaintext,
Path: a.cfg.SessionCookiePath,
Domain: a.cfg.SessionCookieDomain,
Secure: a.cfg.SessionCookieSecure,
HttpOnly: a.cfg.SessionCookieHTTPOnly,
SameSite: a.cfg.SessionCookieSameSite,
Expires: expires,
}
if plaintext == "" {
c.MaxAge = -1
}
return c
}
// resolveRolesAndPermissions fetches the user's role names and the union of
// their permission names. Both are returned as flat string slices for cheap
// containment checks on the Principal.
func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) {
roles, err := a.deps.Roles.GetUserRoles(ctx, userID)
if err != nil {
return nil, nil, err
}
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
if err != nil {
return nil, nil, err
}
rNames := make([]string, len(roles))
for i, r := range roles {
rNames[i] = r.Name
}
pNames := make([]string, len(perms))
for i, p := range perms {
pNames[i] = p.Name
}
return rNames, pNames, nil
}

220
service_test.go Normal file
View file

@ -0,0 +1,220 @@
package authkit
import (
"context"
"errors"
"net/netip"
"testing"
)
func TestRegisterAndLogin(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "Alice@Example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
if u.EmailNormalized != "alice@example.com" {
t.Fatalf("email_normalized = %q", u.EmailNormalized)
}
got, err := a.LoginPassword(context.Background(), "alice@example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("LoginPassword: %v", err)
}
if got.ID != u.ID {
t.Fatalf("login user mismatch")
}
}
func TestRegisterDuplicateEmail(t *testing.T) {
a := newTestAuth(t)
if _, err := a.Register(context.Background(), "x@y.com", "abc"); err != nil {
t.Fatalf("Register: %v", err)
}
_, err := a.Register(context.Background(), "X@Y.COM", "abc")
if !errors.Is(err, ErrEmailTaken) {
t.Fatalf("expected ErrEmailTaken, got %v", err)
}
}
func TestLoginWrongPassword(t *testing.T) {
a := newTestAuth(t)
if _, err := a.Register(context.Background(), "p@q.com", "right"); err != nil {
t.Fatalf("Register: %v", err)
}
if _, err := a.LoginPassword(context.Background(), "p@q.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
}
}
func TestSessionIssueAuthenticateRevoke(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "s@s.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
plain, sess, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
if sess == nil || plain == "" {
t.Fatalf("missing session or plaintext")
}
p, err := a.AuthenticateSession(context.Background(), plain)
if err != nil {
t.Fatalf("AuthenticateSession: %v", err)
}
if p.UserID != u.ID {
t.Fatalf("principal user id mismatch")
}
if err := a.RevokeSession(context.Background(), plain); err != nil {
t.Fatalf("RevokeSession: %v", err)
}
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
}
}
func TestEmailVerificationFlow(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "ev@e.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
tok, err := a.RequestEmailVerification(context.Background(), u.ID)
if err != nil {
t.Fatalf("RequestEmailVerification: %v", err)
}
confirmed, err := a.ConfirmEmail(context.Background(), tok)
if err != nil {
t.Fatalf("ConfirmEmail: %v", err)
}
if confirmed.EmailVerifiedAt == nil {
t.Fatalf("email_verified_at not set")
}
// Re-using the token must fail.
if _, err := a.ConfirmEmail(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err)
}
}
func TestPasswordResetFlow(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "r@r.com", "old")
if err != nil {
t.Fatalf("Register: %v", err)
}
// Issue a session that should be invalidated by the reset.
plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
tok, err := a.RequestPasswordReset(context.Background(), "r@r.com")
if err != nil {
t.Fatalf("RequestPasswordReset: %v", err)
}
if err := a.ConfirmPasswordReset(context.Background(), tok, "new"); err != nil {
t.Fatalf("ConfirmPasswordReset: %v", err)
}
if _, err := a.LoginPassword(context.Background(), "r@r.com", "old"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("old password should fail post-reset, got %v", err)
}
if _, err := a.LoginPassword(context.Background(), "r@r.com", "new"); err != nil {
t.Fatalf("new password should work post-reset, got %v", err)
}
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
t.Fatalf("session should be invalidated by reset, got %v", err)
}
}
func TestMagicLinkFlow(t *testing.T) {
a := newTestAuth(t)
if _, err := a.Register(context.Background(), "m@m.com", "pw"); err != nil {
t.Fatalf("Register: %v", err)
}
tok, err := a.RequestMagicLink(context.Background(), "m@m.com")
if err != nil {
t.Fatalf("RequestMagicLink: %v", err)
}
u, err := a.ConsumeMagicLink(context.Background(), tok)
if err != nil {
t.Fatalf("ConsumeMagicLink: %v", err)
}
if u.EmailVerifiedAt == nil {
t.Fatalf("magic link should imply email verification")
}
if _, err := a.ConsumeMagicLink(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid on magic link reuse, got %v", err)
}
}
func TestAPIKeyFlowWithAbilities(t *testing.T) {
a := newTestAuth(t)
u, err := a.Register(context.Background(), "k@k.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
plaintext, k, err := a.IssueAPIKey(context.Background(), u.ID, "ci", []string{"billing:read", "users:list"}, nil)
if err != nil {
t.Fatalf("IssueAPIKey: %v", err)
}
if k == nil || plaintext == "" {
t.Fatalf("missing api key")
}
p, err := a.AuthenticateAPIKey(context.Background(), plaintext)
if err != nil {
t.Fatalf("AuthenticateAPIKey: %v", err)
}
if !p.HasAbility("billing:read") || !p.HasAbility("users:list") {
t.Fatalf("abilities missing on principal: %+v", p.Abilities)
}
if p.HasAbility("admin:nuke") {
t.Fatalf("unexpected ability granted")
}
if err := a.RevokeAPIKey(context.Background(), plaintext); err != nil {
t.Fatalf("RevokeAPIKey: %v", err)
}
if _, err := a.AuthenticateAPIKey(context.Background(), plaintext); !errors.Is(err, ErrAPIKeyInvalid) {
t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err)
}
}
func TestRBACRolesAndPermissions(t *testing.T) {
ctx := context.Background()
a := newTestAuth(t)
u, err := a.Register(ctx, "rb@a.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
// Create role + permission, hook them up.
role := &Role{Name: "editor"}
if err := a.deps.Roles.CreateRole(ctx, role); err != nil {
t.Fatalf("CreateRole: %v", err)
}
perm := &Permission{Name: "posts:write"}
if err := a.deps.Permissions.CreatePermission(ctx, perm); err != nil {
t.Fatalf("CreatePermission: %v", err)
}
if err := a.deps.Permissions.AssignPermissionToRole(ctx, role.ID, perm.ID); err != nil {
t.Fatalf("AssignPermissionToRole: %v", err)
}
if err := a.AssignRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("AssignRole: %v", err)
}
ok, err := a.HasPermission(ctx, u.ID, "posts:write")
if err != nil || !ok {
t.Fatalf("HasPermission posts:write should be true, got %v %v", ok, err)
}
ok, err = a.HasRole(ctx, u.ID, "editor")
if err != nil || !ok {
t.Fatalf("HasRole editor should be true, got %v %v", ok, err)
}
if err := a.RemoveRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("RemoveRole: %v", err)
}
ok, _ = a.HasPermission(ctx, u.ID, "posts:write")
if ok {
t.Fatalf("HasPermission should be false after RemoveRole")
}
}

178
service_user.go Normal file
View file

@ -0,0 +1,178 @@
package authkit
import (
"context"
"errors"
"strings"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
// normalizeEmail produces the lookup form used by UserStore.GetUserByEmail
// and the email_normalized column. Trim + lowercase is intentional; we do
// not collapse Gmail-style "+" addressing or strip dots — that's a policy
// decision callers can layer on top.
func normalizeEmail(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}
// Register creates a new user with an Argon2id-hashed password. Returns
// ErrEmailTaken if the normalized email is already registered.
func (a *Auth) Register(ctx context.Context, email, password string) (*User, error) {
const op = "authkit.Auth.Register"
if email == "" || password == "" {
return nil, errx.Wrap(op, ErrInvalidCredentials)
}
hash, err := a.deps.Hasher.Hash(password)
if err != nil {
return nil, errx.Wrap(op, err)
}
now := a.now()
u := &User{
ID: uuid.New(),
Email: email,
EmailNormalized: normalizeEmail(email),
PasswordHash: hash,
CreatedAt: now,
UpdatedAt: now,
}
if err := a.deps.Users.CreateUser(ctx, u); err != nil {
return nil, errx.Wrap(op, err)
}
return u, nil
}
// LoginPassword verifies the password and returns the authenticated user.
// Failure increments failed_logins; success resets it and stamps last_login_at.
// LoginHook (if configured) is invoked with the success outcome — use this to
// hook in rate limiting or audit logging.
func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User, error) {
const op = "authkit.Auth.LoginPassword"
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
if err != nil {
_ = a.fireLoginHook(ctx, email, false)
if errors.Is(err, ErrUserNotFound) {
return nil, errx.Wrap(op, ErrInvalidCredentials)
}
return nil, errx.Wrap(op, err)
}
if u.PasswordHash == "" {
_ = a.fireLoginHook(ctx, email, false)
return nil, errx.Wrap(op, ErrInvalidCredentials)
}
ok, needsRehash, err := a.deps.Hasher.Verify(password, u.PasswordHash)
if err != nil {
return nil, errx.Wrap(op, err)
}
if !ok {
_, _ = a.deps.Users.IncrementFailedLogins(ctx, u.ID)
_ = a.fireLoginHook(ctx, email, false)
return nil, errx.Wrap(op, ErrInvalidCredentials)
}
now := a.now()
u.LastLoginAt = &now
u.FailedLogins = 0
if err := a.deps.Users.ResetFailedLogins(ctx, u.ID); err != nil {
return nil, errx.Wrap(op, err)
}
if err := a.deps.Users.UpdateUser(ctx, u); err != nil {
return nil, errx.Wrap(op, err)
}
if needsRehash {
if newHash, herr := a.deps.Hasher.Hash(password); herr == nil {
_ = a.deps.Users.SetPassword(ctx, u.ID, newHash)
u.PasswordHash = newHash
}
}
_ = a.fireLoginHook(ctx, email, true)
return u, nil
}
// ChangePassword verifies the current password, sets the new one, and bumps
// the user's session_version so all outstanding JWT access tokens are
// instantly invalidated. Outstanding opaque sessions are also revoked.
func (a *Auth) ChangePassword(ctx context.Context, userID uuid.UUID, oldPassword, newPassword string) error {
const op = "authkit.Auth.ChangePassword"
u, err := a.deps.Users.GetUserByID(ctx, userID)
if err != nil {
return errx.Wrap(op, err)
}
if u.PasswordHash == "" {
return errx.Wrap(op, ErrInvalidCredentials)
}
ok, _, err := a.deps.Hasher.Verify(oldPassword, u.PasswordHash)
if err != nil {
return errx.Wrap(op, err)
}
if !ok {
return errx.Wrap(op, ErrInvalidCredentials)
}
newHash, err := a.deps.Hasher.Hash(newPassword)
if err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Users.SetPassword(ctx, userID, newHash); err != nil {
return errx.Wrap(op, err)
}
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
return errx.Wrap(op, err)
}
return nil
}
// RequestEmailVerification mints a single-use email-verify token for the
// user. Return the plaintext to the caller so they can put it in an email
// link; the lookup hash is stored in TokenStore.
func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (string, error) {
const op = "authkit.Auth.RequestEmailVerification"
plaintext, hash, err := mintSecret(prefixEmailVerify, a.cfg.Random)
if err != nil {
return "", errx.Wrap(op, err)
}
now := a.now()
t := &Token{
Hash: hash,
Kind: TokenEmailVerify,
UserID: userID,
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.EmailVerifyTTL),
}
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
return "", errx.Wrap(op, err)
}
return plaintext, nil
}
// ConfirmEmail consumes the verification token and marks the user's email
// verified. Returns ErrTokenInvalid if the token is missing/expired/used.
func (a *Auth) ConfirmEmail(ctx context.Context, plaintextToken string) (*User, error) {
const op = "authkit.Auth.ConfirmEmail"
hash, ok := parseSecret(prefixEmailVerify, plaintextToken)
if !ok {
return nil, errx.Wrap(op, ErrTokenInvalid)
}
now := a.now()
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenEmailVerify, hash, now)
if err != nil {
return nil, errx.Wrap(op, err)
}
if err := a.deps.Users.SetEmailVerified(ctx, t.UserID, now); err != nil {
return nil, errx.Wrap(op, err)
}
return a.deps.Users.GetUserByID(ctx, t.UserID)
}
// fireLoginHook is a thin wrapper that suppresses panics from caller-supplied
// hooks; we never want a misbehaving telemetry hook to break login.
func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) error {
if a.cfg.LoginHook == nil {
return nil
}
return a.cfg.LoginHook(ctx, email, success)
}

124
sqlstore/apikeys.go Normal file
View file

@ -0,0 +1,124 @@
package sqlstore
import (
"context"
"database/sql"
"encoding/json"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type apiKeyStore struct{ storeBase }
func (s *apiKeyStore) CreateAPIKey(ctx context.Context, k *authkit.APIKey) error {
const op = "authkit.sqlstore.APIKeyStore.CreateAPIKey"
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now().UTC()
}
if k.Abilities == nil {
k.Abilities = []string{}
}
abilities, err := json.Marshal(k.Abilities)
if err != nil {
return errx.Wrap(op, err)
}
_, err = s.db.ExecContext(ctx, s.q.CreateAPIKey,
k.IDHash, uuidArg(k.OwnerID), k.Name, abilities,
nullableTime(k.LastUsedAt), k.CreatedAt,
nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt))
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *apiKeyStore) GetAPIKey(ctx context.Context, idHash []byte) (*authkit.APIKey, error) {
const op = "authkit.sqlstore.APIKeyStore.GetAPIKey"
k, err := scanAPIKey(s.db.QueryRowContext(ctx, s.q.GetAPIKey, idHash))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrAPIKeyInvalid))
}
return k, nil
}
func (s *apiKeyStore) ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*authkit.APIKey, error) {
const op = "authkit.sqlstore.APIKeyStore.ListAPIKeysByOwner"
rows, err := s.db.QueryContext(ctx, s.q.ListAPIKeysByOwner, uuidArg(ownerID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.APIKey
for rows.Next() {
k, err := scanAPIKey(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, k)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *apiKeyStore) TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error {
const op = "authkit.sqlstore.APIKeyStore.TouchAPIKey"
if _, err := s.db.ExecContext(ctx, s.q.TouchAPIKey, at, idHash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *apiKeyStore) RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error {
const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKey"
tag, err := s.db.ExecContext(ctx, s.q.RevokeAPIKey, at, idHash)
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrAPIKeyInvalid)
}
return nil
}
func (s *apiKeyStore) RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error {
const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKeysByOwner"
if _, err := s.db.ExecContext(ctx, s.q.RevokeAPIKeysByOwner, at, uuidArg(ownerID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func scanAPIKey(row rowScanner) (*authkit.APIKey, error) {
var (
k authkit.APIKey
ownerIDStr string
abilitiesRaw []byte
lastUsed sql.NullTime
expires sql.NullTime
revoked sql.NullTime
)
if err := row.Scan(&k.IDHash, &ownerIDStr, &k.Name, &abilitiesRaw,
&lastUsed, &k.CreatedAt, &expires, &revoked); err != nil {
return nil, err
}
owner, err := scanUUID(ownerIDStr)
if err != nil {
return nil, err
}
k.OwnerID = owner
if len(abilitiesRaw) > 0 {
if err := json.Unmarshal(abilitiesRaw, &k.Abilities); err != nil {
return nil, err
}
}
if k.Abilities == nil {
k.Abilities = []string{}
}
k.LastUsedAt = scanNullTimePtr(lastUsed)
k.ExpiresAt = scanNullTimePtr(expires)
k.RevokedAt = scanNullTimePtr(revoked)
return &k, nil
}

118
sqlstore/dialect.go Normal file
View file

@ -0,0 +1,118 @@
package sqlstore
import (
"context"
"database/sql"
"io/fs"
)
// Dialect describes how a particular SQL backend renders the queries authkit
// runs at runtime, bootstraps its session for migrations, and reports
// driver-specific errors. v1 ships the Postgres dialect; future MySQL or
// SQLite implementations satisfy the same interface without changes to the
// store code.
type Dialect interface {
// Name returns a stable short identifier ("postgres", "mysql", ...).
Name() string
// BuildQueries renders every templated SQL string from a validated
// Schema. Called once at New() and at Migrate() so identifiers and
// placeholder styles are baked in by the time stores execute queries.
BuildQueries(s Schema) Queries
// Bootstrap runs once per Migrate() before the migration lock is taken.
// It is the place for things that must happen outside any transaction
// (CREATE EXTENSION on Postgres, for example). Must be idempotent and
// may be a no-op.
Bootstrap(ctx context.Context, db *sql.DB) error
// AcquireMigrationLock takes a session-scoped lock on conn so concurrent
// migrators serialise. The returned release function never returns an
// error — implementations may log internally.
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
// Migrations returns the dialect's embedded .sql files, rooted such
// that fs.ReadDir(".") lists them lex-sorted by filename.
Migrations() fs.FS
// IsUniqueViolation maps a duplicate-key error from this driver to true
// so insert paths can return the matching authkit sentinel without
// depending on driver internals.
IsUniqueViolation(err error) bool
// Placeholder returns the placeholder for parameter index n (1-based)
// in this dialect — "$1" for Postgres, "?" for MySQL/SQLite.
Placeholder(n int) string
// PlaceholderList returns a comma-separated placeholder list for `count`
// parameters starting at position `start` (1-based), suitable for
// dynamic IN-clauses. The second return is the same length, prefilled
// with nil so the caller can append actual values.
PlaceholderList(start, count int) string
}
// Queries is the full set of statement templates authkit issues. Field names
// match the store method that consumes the query.
type Queries struct {
// users
CreateUser string
GetUserByID string
GetUserByEmail string
UpdateUser string
DeleteUser string
SetPassword string
SetEmailVerified string
BumpSessionVersion string
IncrementFailedLogins string
ResetFailedLogins string
// sessions
CreateSession string
GetSession string
TouchSession string
DeleteSession string
DeleteUserSessions string
DeleteExpiredSessions string
// tokens
CreateToken string
ConsumeToken string
GetToken string
DeleteByChain string
DeleteExpiredTokens string
// api keys
CreateAPIKey string
GetAPIKey string
ListAPIKeysByOwner string
TouchAPIKey string
RevokeAPIKey string
RevokeAPIKeysByOwner string
// roles
CreateRole string
GetRoleByID string
GetRoleByName string
ListRoles string
DeleteRole string
AssignRoleToUser string
RemoveRoleFromUser string
GetUserRoles string
// HasAnyRole is built at call time because the placeholder count varies.
// permissions
CreatePermission string
GetPermissionByID string
GetPermissionByName string
ListPermissions string
DeletePermission string
AssignPermissionToRole string
RemovePermissionFromRole string
GetRolePermissions string
GetUserPermissions string
// migrations
CreateMigrationsTable string
SelectAppliedVersions string
InsertAppliedVersion string
}

View file

@ -0,0 +1,33 @@
package postgres
import (
"errors"
"github.com/jackc/pgx/v5/pgconn"
)
// pgUniqueViolation is the SQLSTATE for unique_violation. Both pgx-stdlib
// and lib/pq surface this code, but only pgx-stdlib uses *pgconn.PgError.
// lib/pq uses *pq.Error which has a Code field of the same value.
const pgUniqueViolation = "23505"
// isUniqueViolation inspects err for a Postgres unique-violation, regardless
// of which driver registered the connection. We match on either the pgx
// error type or any error implementing a Code() string method (lib/pq's
// pq.Error has SQLState and Code fields; we check via reflection-free
// duck-typing through an interface).
func isUniqueViolation(err error) bool {
if err == nil {
return false
}
var pgxErr *pgconn.PgError
if errors.As(err, &pgxErr) {
return pgxErr.Code == pgUniqueViolation
}
type sqlStater interface{ SQLState() string }
var s sqlStater
if errors.As(err, &s) {
return s.SQLState() == pgUniqueViolation
}
return false
}

View file

@ -0,0 +1,99 @@
-- 0001_init.sql
-- Initial authkit schema for Postgres. Tables are prefixed authkit_ so the
-- library can be embedded in an existing application database. Each
-- migration owns its own transaction and inserts its version row at the
-- bottom; the runner only orchestrates file discovery and concurrency.
BEGIN;
CREATE TABLE IF NOT EXISTS authkit_schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_users (
id UUID PRIMARY KEY,
email TEXT NOT NULL,
email_normalized TEXT NOT NULL,
email_verified_at TIMESTAMPTZ,
password_hash TEXT,
session_version INTEGER NOT NULL DEFAULT 0,
failed_logins INTEGER NOT NULL DEFAULT 0,
last_login_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq
ON authkit_users (email_normalized);
CREATE TABLE IF NOT EXISTS authkit_sessions (
id_hash BYTEA PRIMARY KEY,
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
user_agent TEXT NOT NULL DEFAULT '',
ip TEXT,
created_at TIMESTAMPTZ NOT NULL,
last_seen_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL
);
CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id);
CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at);
CREATE TABLE IF NOT EXISTS authkit_tokens (
hash BYTEA NOT NULL,
kind TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
chain_id TEXT,
consumed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (kind, hash)
);
CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id);
CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at);
CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx
ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL;
CREATE TABLE IF NOT EXISTS authkit_api_keys (
id_hash BYTEA PRIMARY KEY,
owner_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
name TEXT NOT NULL,
abilities JSONB NOT NULL DEFAULT '[]'::jsonb,
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL,
expires_at TIMESTAMPTZ,
revoked_at TIMESTAMPTZ
);
CREATE INDEX IF NOT EXISTS authkit_api_keys_owner_id_idx ON authkit_api_keys(owner_id);
CREATE TABLE IF NOT EXISTS authkit_roles (
id UUID PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
description TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_permissions (
id UUID PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
description TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS authkit_role_permissions (
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE,
PRIMARY KEY (role_id, permission_id)
);
CREATE TABLE IF NOT EXISTS authkit_user_roles (
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
granted_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (user_id, role_id)
);
CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id);
INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now())
ON CONFLICT (version) DO NOTHING;
COMMIT;

View file

@ -0,0 +1,272 @@
// Package postgres is the Postgres dialect for authkit/sqlstore. Importing
// it does not register a driver — callers do `_ "github.com/jackc/pgx/v5/stdlib"`
// or `_ "github.com/lib/pq"` themselves and then `sql.Open(...)`.
package postgres
import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"log"
"strings"
"git.juancwu.dev/juancwu/authkit/sqlstore"
)
//go:embed migrations/*.sql
var migrationsFS embed.FS
// Dialect implements sqlstore.Dialect for Postgres. The zero value is the
// only required form; New() returns a pointer for clarity at call sites.
type Dialect struct{}
// New returns a Postgres dialect ready to pass to sqlstore.New / Migrate.
func New() *Dialect { return &Dialect{} }
func (Dialect) Name() string { return "postgres" }
// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. Stable
// across rollouts and unlikely to clash with caller advisory locks.
const advisoryLockKey int64 = 0x617574686b6974
func (Dialect) Bootstrap(ctx context.Context, db *sql.DB) error {
// Nothing to do — schema avoids gen_random_uuid()/pgcrypto. Kept as a
// hook for future migration prerequisites.
return nil
}
func (Dialect) AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (func(), error) {
if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil {
return nil, err
}
release := func() {
if _, err := conn.ExecContext(context.Background(),
"SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil {
log.Printf("authkit/postgres: pg_advisory_unlock failed: %v", err)
}
}
return release, nil
}
func (Dialect) Migrations() fs.FS {
sub, err := fs.Sub(migrationsFS, "migrations")
if err != nil {
// migrationsFS is statically populated; this can only fail if the
// embed directive is removed, which would be a build-time error.
panic(err)
}
return sub
}
func (Dialect) IsUniqueViolation(err error) bool { return isUniqueViolation(err) }
func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) }
// PlaceholderList renders a comma-separated `$start,$start+1,...` list of
// `count` placeholders. Used by HasAnyRole's dynamic IN-clause expansion.
func (Dialect) PlaceholderList(start, count int) string {
if count <= 0 {
return ""
}
var b strings.Builder
for i := 0; i < count; i++ {
if i > 0 {
b.WriteByte(',')
}
fmt.Fprintf(&b, "$%d", start+i)
}
return b.String()
}
// BuildQueries renders every query authkit issues, with table identifiers
// taken from s and `?` placeholders rewritten to `$N`. Identifiers are
// already validated by Schema.Validate — this is interpolated with
// fmt.Sprintf, so the validation gate is load-bearing.
func (Dialect) BuildQueries(s sqlstore.Schema) sqlstore.Queries {
t := s.Tables
q := sqlstore.Queries{
// users
CreateUser: `INSERT INTO ` + t.Users + `
(id, email, email_normalized, email_verified_at, password_hash,
session_version, failed_logins, last_login_at, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
GetUserByID: `SELECT id, email, email_normalized, email_verified_at,
password_hash, session_version, failed_logins, last_login_at,
created_at, updated_at FROM ` + t.Users + ` WHERE id = ?`,
GetUserByEmail: `SELECT id, email, email_normalized, email_verified_at,
password_hash, session_version, failed_logins, last_login_at,
created_at, updated_at FROM ` + t.Users + ` WHERE email_normalized = ?`,
UpdateUser: `UPDATE ` + t.Users + ` SET
email = ?, email_normalized = ?, email_verified_at = ?,
password_hash = ?, session_version = ?, failed_logins = ?,
last_login_at = ?, updated_at = ?
WHERE id = ?`,
DeleteUser: `DELETE FROM ` + t.Users + ` WHERE id = ?`,
SetPassword: `UPDATE ` + t.Users + ` SET password_hash = ?, updated_at = ? WHERE id = ?`,
SetEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = ?, updated_at = ? WHERE id = ?`,
BumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = ? WHERE id = ? RETURNING session_version`,
IncrementFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = failed_logins + 1, updated_at = ? WHERE id = ? RETURNING failed_logins`,
ResetFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = 0, updated_at = ? WHERE id = ?`,
// sessions
CreateSession: `INSERT INTO ` + t.Sessions + `
(id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
GetSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at
FROM ` + t.Sessions + ` WHERE id_hash = ?`,
TouchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = ?, expires_at = ? WHERE id_hash = ?`,
DeleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = ?`,
DeleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = ?`,
DeleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= ?`,
// tokens
CreateToken: `INSERT INTO ` + t.Tokens + `
(hash, kind, user_id, chain_id, consumed_at, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`,
ConsumeToken: `UPDATE ` + t.Tokens + `
SET consumed_at = ?
WHERE kind = ? AND hash = ? AND consumed_at IS NULL AND expires_at > ?
RETURNING hash, kind, user_id, chain_id, consumed_at, created_at, expires_at`,
GetToken: `SELECT hash, kind, user_id, chain_id, consumed_at, created_at, expires_at
FROM ` + t.Tokens + ` WHERE kind = ? AND hash = ?`,
DeleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = ?`,
DeleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= ?`,
// api keys
CreateAPIKey: `INSERT INTO ` + t.APIKeys + `
(id_hash, owner_id, name, abilities, last_used_at, created_at, expires_at, revoked_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
GetAPIKey: `SELECT id_hash, owner_id, name, abilities, last_used_at,
created_at, expires_at, revoked_at
FROM ` + t.APIKeys + ` WHERE id_hash = ?`,
ListAPIKeysByOwner: `SELECT id_hash, owner_id, name, abilities, last_used_at,
created_at, expires_at, revoked_at
FROM ` + t.APIKeys + ` WHERE owner_id = ? ORDER BY created_at DESC`,
TouchAPIKey: `UPDATE ` + t.APIKeys + ` SET last_used_at = ? WHERE id_hash = ?`,
RevokeAPIKey: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE id_hash = ? AND revoked_at IS NULL`,
RevokeAPIKeysByOwner: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE owner_id = ? AND revoked_at IS NULL`,
// roles
CreateRole: `INSERT INTO ` + t.Roles + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
GetRoleByID: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE id = ?`,
GetRoleByName: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE name = ?`,
ListRoles: `SELECT id, name, description, created_at FROM ` + t.Roles + ` ORDER BY name`,
DeleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = ?`,
AssignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`,
RemoveRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = ? AND role_id = ?`,
GetUserRoles: `SELECT r.id, r.name, r.description, r.created_at
FROM ` + t.Roles + ` r
JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id
WHERE ur.user_id = ? ORDER BY r.name`,
// permissions
CreatePermission: `INSERT INTO ` + t.Permissions + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
GetPermissionByID: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE id = ?`,
GetPermissionByName: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE name = ?`,
ListPermissions: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` ORDER BY name`,
DeletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = ?`,
AssignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) VALUES (?, ?)
ON CONFLICT DO NOTHING`,
RemovePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = ? AND permission_id = ?`,
GetRolePermissions: `SELECT p.id, p.name, p.description, p.created_at
FROM ` + t.Permissions + ` p
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
WHERE rp.role_id = ? ORDER BY p.name`,
GetUserPermissions: `SELECT DISTINCT p.id, p.name, p.description, p.created_at
FROM ` + t.Permissions + ` p
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id
WHERE ur.user_id = ? ORDER BY p.name`,
// migrations
CreateMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL
)`,
SelectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations,
InsertAppliedVersion: `INSERT INTO ` + t.SchemaMigrations + ` (version, applied_at) VALUES (?, ?)`,
}
// Rewrite `?` placeholders to `$N`. Each query is independent; numbering
// resets per query.
q.CreateUser = rebind(q.CreateUser)
q.GetUserByID = rebind(q.GetUserByID)
q.GetUserByEmail = rebind(q.GetUserByEmail)
q.UpdateUser = rebind(q.UpdateUser)
q.DeleteUser = rebind(q.DeleteUser)
q.SetPassword = rebind(q.SetPassword)
q.SetEmailVerified = rebind(q.SetEmailVerified)
q.BumpSessionVersion = rebind(q.BumpSessionVersion)
q.IncrementFailedLogins = rebind(q.IncrementFailedLogins)
q.ResetFailedLogins = rebind(q.ResetFailedLogins)
q.CreateSession = rebind(q.CreateSession)
q.GetSession = rebind(q.GetSession)
q.TouchSession = rebind(q.TouchSession)
q.DeleteSession = rebind(q.DeleteSession)
q.DeleteUserSessions = rebind(q.DeleteUserSessions)
q.DeleteExpiredSessions = rebind(q.DeleteExpiredSessions)
q.CreateToken = rebind(q.CreateToken)
q.ConsumeToken = rebind(q.ConsumeToken)
q.GetToken = rebind(q.GetToken)
q.DeleteByChain = rebind(q.DeleteByChain)
q.DeleteExpiredTokens = rebind(q.DeleteExpiredTokens)
q.CreateAPIKey = rebind(q.CreateAPIKey)
q.GetAPIKey = rebind(q.GetAPIKey)
q.ListAPIKeysByOwner = rebind(q.ListAPIKeysByOwner)
q.TouchAPIKey = rebind(q.TouchAPIKey)
q.RevokeAPIKey = rebind(q.RevokeAPIKey)
q.RevokeAPIKeysByOwner = rebind(q.RevokeAPIKeysByOwner)
q.CreateRole = rebind(q.CreateRole)
q.GetRoleByID = rebind(q.GetRoleByID)
q.GetRoleByName = rebind(q.GetRoleByName)
q.ListRoles = rebind(q.ListRoles)
q.DeleteRole = rebind(q.DeleteRole)
q.AssignRoleToUser = rebind(q.AssignRoleToUser)
q.RemoveRoleFromUser = rebind(q.RemoveRoleFromUser)
q.GetUserRoles = rebind(q.GetUserRoles)
q.CreatePermission = rebind(q.CreatePermission)
q.GetPermissionByID = rebind(q.GetPermissionByID)
q.GetPermissionByName = rebind(q.GetPermissionByName)
q.ListPermissions = rebind(q.ListPermissions)
q.DeletePermission = rebind(q.DeletePermission)
q.AssignPermissionToRole = rebind(q.AssignPermissionToRole)
q.RemovePermissionFromRole = rebind(q.RemovePermissionFromRole)
q.GetRolePermissions = rebind(q.GetRolePermissions)
q.GetUserPermissions = rebind(q.GetUserPermissions)
q.SelectAppliedVersions = rebind(q.SelectAppliedVersions)
q.InsertAppliedVersion = rebind(q.InsertAppliedVersion)
// CreateMigrationsTable contains no parameters.
return q
}
// rebind walks s and replaces each unquoted `?` with $1, $2, ... in order.
// Our query strings contain no string literals that include `?` (verified
// by inspection of every query in BuildQueries); a literal-aware rewriter
// would be more defensive but is not needed for v1.
func rebind(s string) string {
var b strings.Builder
b.Grow(len(s) + 16)
n := 1
for i := 0; i < len(s); i++ {
c := s[i]
if c == '?' {
fmt.Fprintf(&b, "$%d", n)
n++
continue
}
b.WriteByte(c)
}
return b.String()
}
// Compile-time interface compliance check.
var _ sqlstore.Dialect = (*Dialect)(nil)

116
sqlstore/migrate.go Normal file
View file

@ -0,0 +1,116 @@
package sqlstore
import (
"context"
"database/sql"
"io/fs"
"sort"
"strings"
"time"
"git.juancwu.dev/juancwu/errx"
)
// Migrate applies every embedded migration the dialect ships that has not
// yet been recorded in the schema-migrations table. It is safe to call
// repeatedly and concurrently across processes — the dialect's session
// lock serialises rollouts.
//
// Each migration .sql file is responsible for owning its own
// BEGIN/COMMIT and inserting a row into the schema-migrations table on
// success. The runner only handles file discovery, version tracking, and
// concurrency.
func Migrate(ctx context.Context, db *sql.DB, dialect Dialect, schema Schema) error {
const op = "authkit.sqlstore.Migrate"
if db == nil {
return errx.New(op, "db is required")
}
if dialect == nil {
return errx.New(op, "dialect is required")
}
if err := schema.Validate(); err != nil {
return errx.Wrap(op, err)
}
if err := dialect.Bootstrap(ctx, db); err != nil {
return errx.Wrap(op, err)
}
conn, err := db.Conn(ctx)
if err != nil {
return errx.Wrap(op, err)
}
defer conn.Close()
release, err := dialect.AcquireMigrationLock(ctx, conn)
if err != nil {
return errx.Wrap(op, err)
}
defer release()
q := dialect.BuildQueries(schema)
if _, err := conn.ExecContext(ctx, q.CreateMigrationsTable); err != nil {
return errx.Wrap(op, err)
}
applied, err := loadAppliedVersions(ctx, conn, q.SelectAppliedVersions)
if err != nil {
return errx.Wrap(op, err)
}
migs := dialect.Migrations()
files, err := fs.ReadDir(migs, ".")
if err != nil {
return errx.Wrap(op, err)
}
names := make([]string, 0, len(files))
for _, f := range files {
if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
names = append(names, f.Name())
}
}
sort.Strings(names)
for _, name := range names {
version := strings.TrimSuffix(name, ".sql")
if _, ok := applied[version]; ok {
continue
}
body, err := fs.ReadFile(migs, name)
if err != nil {
return errx.Wrapf(op, err, "read %s", name)
}
if _, err := conn.ExecContext(ctx, string(body)); err != nil {
return errx.Wrapf(op, err, "apply %s", version)
}
}
return nil
}
// applyVersionRow is intentionally not exposed: migration files own their
// own version-row insert. We keep this helper around in case a dialect ever
// needs to backfill versions from outside a migration body — currently
// unused.
var _ = applyVersionRow
func applyVersionRow(ctx context.Context, conn *sql.Conn, insertQ, version string, at time.Time) error {
_, err := conn.ExecContext(ctx, insertQ, version, at)
return err
}
func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) {
rows, err := conn.QueryContext(ctx, q)
if err != nil {
return nil, err
}
defer rows.Close()
out := make(map[string]struct{})
for rows.Next() {
var v string
if err := rows.Scan(&v); err != nil {
return nil, err
}
out[v] = struct{}{}
}
return out, rows.Err()
}

301
sqlstore/rbac.go Normal file
View file

@ -0,0 +1,301 @@
package sqlstore
import (
"context"
"fmt"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type roleStore struct{ storeBase }
type permissionStore struct{ storeBase }
// ----- roleStore ------------------------------------------------------------
func (s *roleStore) CreateRole(ctx context.Context, r *authkit.Role) error {
const op = "authkit.sqlstore.RoleStore.CreateRole"
if r.ID == uuid.Nil {
r.ID = uuid.New()
}
if r.CreatedAt.IsZero() {
r.CreatedAt = time.Now().UTC()
}
if _, err := s.db.ExecContext(ctx, s.q.CreateRole,
uuidArg(r.ID), r.Name, r.Description, r.CreatedAt); err != nil {
if s.d.IsUniqueViolation(err) {
return errx.Wrapf(op, err, "role %q already exists", r.Name)
}
return errx.Wrap(op, err)
}
return nil
}
func (s *roleStore) GetRoleByID(ctx context.Context, id uuid.UUID) (*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.GetRoleByID"
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
}
return r, nil
}
func (s *roleStore) GetRoleByName(ctx context.Context, name string) (*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.GetRoleByName"
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByName, name))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
}
return r, nil
}
func (s *roleStore) ListRoles(ctx context.Context) ([]*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.ListRoles"
rows, err := s.db.QueryContext(ctx, s.q.ListRoles)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Role
for rows.Next() {
r, err := scanRole(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, r)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *roleStore) DeleteRole(ctx context.Context, id uuid.UUID) error {
const op = "authkit.sqlstore.RoleStore.DeleteRole"
tag, err := s.db.ExecContext(ctx, s.q.DeleteRole, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrRoleNotFound)
}
return nil
}
func (s *roleStore) AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error {
const op = "authkit.sqlstore.RoleStore.AssignRoleToUser"
if _, err := s.db.ExecContext(ctx, s.q.AssignRoleToUser,
uuidArg(userID), uuidArg(roleID), time.Now().UTC()); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *roleStore) RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error {
const op = "authkit.sqlstore.RoleStore.RemoveRoleFromUser"
if _, err := s.db.ExecContext(ctx, s.q.RemoveRoleFromUser,
uuidArg(userID), uuidArg(roleID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *roleStore) GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*authkit.Role, error) {
const op = "authkit.sqlstore.RoleStore.GetUserRoles"
rows, err := s.db.QueryContext(ctx, s.q.GetUserRoles, uuidArg(userID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Role
for rows.Next() {
r, err := scanRole(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, r)
}
return out, errx.Wrap(op, rows.Err())
}
// HasAnyRole builds the IN-clause at call time because the placeholder count
// depends on len(names). Identifier substitution comes from the validated
// Schema; values are bound, never interpolated.
func (s *roleStore) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
const op = "authkit.sqlstore.RoleStore.HasAnyRole"
if len(names) == 0 {
return false, nil
}
// Placeholder $1 is the user_id; $2..$N+1 cover the names slice.
listSQL := s.d.PlaceholderList(2, len(names))
q := fmt.Sprintf(`SELECT EXISTS (
SELECT 1 FROM %s ur JOIN %s r ON r.id = ur.role_id
WHERE ur.user_id = %s AND r.name IN (%s)
)`, s.s.Tables.UserRoles, s.s.Tables.Roles, s.d.Placeholder(1), listSQL)
args := make([]any, 0, 1+len(names))
args = append(args, uuidArg(userID))
for _, n := range names {
args = append(args, n)
}
var ok bool
if err := s.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil {
return false, errx.Wrap(op, err)
}
return ok, nil
}
// ----- permissionStore ------------------------------------------------------
func (s *permissionStore) CreatePermission(ctx context.Context, p *authkit.Permission) error {
const op = "authkit.sqlstore.PermissionStore.CreatePermission"
if p.ID == uuid.Nil {
p.ID = uuid.New()
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now().UTC()
}
if _, err := s.db.ExecContext(ctx, s.q.CreatePermission,
uuidArg(p.ID), p.Name, p.Description, p.CreatedAt); err != nil {
if s.d.IsUniqueViolation(err) {
return errx.Wrapf(op, err, "permission %q already exists", p.Name)
}
return errx.Wrap(op, err)
}
return nil
}
func (s *permissionStore) GetPermissionByID(ctx context.Context, id uuid.UUID) (*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetPermissionByID"
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
}
return p, nil
}
func (s *permissionStore) GetPermissionByName(ctx context.Context, name string) (*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetPermissionByName"
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByName, name))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
}
return p, nil
}
func (s *permissionStore) ListPermissions(ctx context.Context) ([]*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.ListPermissions"
rows, err := s.db.QueryContext(ctx, s.q.ListPermissions)
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *permissionStore) DeletePermission(ctx context.Context, id uuid.UUID) error {
const op = "authkit.sqlstore.PermissionStore.DeletePermission"
tag, err := s.db.ExecContext(ctx, s.q.DeletePermission, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrPermissionNotFound)
}
return nil
}
func (s *permissionStore) AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error {
const op = "authkit.sqlstore.PermissionStore.AssignPermissionToRole"
if _, err := s.db.ExecContext(ctx, s.q.AssignPermissionToRole,
uuidArg(roleID), uuidArg(permID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *permissionStore) RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error {
const op = "authkit.sqlstore.PermissionStore.RemovePermissionFromRole"
if _, err := s.db.ExecContext(ctx, s.q.RemovePermissionFromRole,
uuidArg(roleID), uuidArg(permID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *permissionStore) GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetRolePermissions"
rows, err := s.db.QueryContext(ctx, s.q.GetRolePermissions, uuidArg(roleID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func (s *permissionStore) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*authkit.Permission, error) {
const op = "authkit.sqlstore.PermissionStore.GetUserPermissions"
rows, err := s.db.QueryContext(ctx, s.q.GetUserPermissions, uuidArg(userID))
if err != nil {
return nil, errx.Wrap(op, err)
}
defer rows.Close()
var out []*authkit.Permission
for rows.Next() {
p, err := scanPermission(rows)
if err != nil {
return nil, errx.Wrap(op, err)
}
out = append(out, p)
}
return out, errx.Wrap(op, rows.Err())
}
func scanRole(row rowScanner) (*authkit.Role, error) {
var (
r authkit.Role
idStr string
)
if err := row.Scan(&idStr, &r.Name, &r.Description, &r.CreatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
r.ID = id
return &r, nil
}
func scanPermission(row rowScanner) (*authkit.Permission, error) {
var (
p authkit.Permission
idStr string
)
if err := row.Scan(&idStr, &p.Name, &p.Description, &p.CreatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
p.ID = id
return &p, nil
}

86
sqlstore/scan.go Normal file
View file

@ -0,0 +1,86 @@
package sqlstore
import (
"database/sql"
"net/netip"
"time"
"github.com/google/uuid"
)
// rowScanner is the lowest-common-denominator interface satisfied by both
// *sql.Row and *sql.Rows so scanXxx helpers serve QueryRow and Query loops
// uniformly.
type rowScanner interface {
Scan(dest ...any) error
}
// nullableTime returns nil when t is the zero time, otherwise &t. Callers
// should usually accept *time.Time on the model side and bind via this.
func nullableTime(t *time.Time) any {
if t == nil {
return nil
}
return *t
}
// nullableString turns "" into nil so columns store NULL rather than an
// empty string. Used for password hashes and similar optional text columns.
func nullableString(s string) any {
if s == "" {
return nil
}
return s
}
// nullableAddrString returns the string form of a netip.Addr when valid, or
// nil to bind as SQL NULL. Pairs with scanAddr.
func nullableAddrString(a netip.Addr) any {
if !a.IsValid() {
return nil
}
return a.String()
}
// scanAddr parses a *string column into a netip.Addr. The zero Addr value is
// returned when the column was NULL or empty.
func scanAddr(s *string) (netip.Addr, error) {
if s == nil || *s == "" {
return netip.Addr{}, nil
}
return netip.ParseAddr(*s)
}
// uuidArg returns the canonical string form of a UUID for binding. Every
// store uses this rather than passing uuid.UUID directly to keep behaviour
// identical across drivers (some accept driver.Valuer, some don't).
func uuidArg(id uuid.UUID) any { return id.String() }
// scanUUID reads a string column and parses it back to a uuid.UUID.
func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) }
// chainArg returns either a *string or nil for binding the chain_id column.
func chainArg(c *string) any {
if c == nil {
return nil
}
return *c
}
// scanNullStringPtr converts sql.NullString to *string for the model.
func scanNullStringPtr(ns sql.NullString) *string {
if !ns.Valid {
return nil
}
v := ns.String
return &v
}
// scanNullTimePtr converts sql.NullTime to *time.Time for the model.
func scanNullTimePtr(nt sql.NullTime) *time.Time {
if !nt.Valid {
return nil
}
t := nt.Time
return &t
}

78
sqlstore/schema.go Normal file
View file

@ -0,0 +1,78 @@
package sqlstore
import (
"regexp"
"git.juancwu.dev/juancwu/errx"
)
// Schema lets consumers map authkit storage to their own table names.
// Column overrides are intentionally not present in v1 — adding them later
// is purely additive.
type Schema struct {
Tables Tables
}
// Tables is the per-table identifier override set. Every field must be a
// valid unquoted SQL identifier (matching identifierRE). Validation runs at
// New() and Migrate() time so SQL injection through Schema is impossible
// past that gate.
type Tables struct {
Users string
Sessions string
Tokens string
APIKeys string
Roles string
Permissions string
UserRoles string
RolePermissions string
SchemaMigrations string
}
// DefaultSchema returns the stock authkit_* names used by the embedded
// migration files.
func DefaultSchema() Schema {
return Schema{Tables: Tables{
Users: "authkit_users",
Sessions: "authkit_sessions",
Tokens: "authkit_tokens",
APIKeys: "authkit_api_keys",
Roles: "authkit_roles",
Permissions: "authkit_permissions",
UserRoles: "authkit_user_roles",
RolePermissions: "authkit_role_permissions",
SchemaMigrations: "authkit_schema_migrations",
}}
}
// identifierRE matches the safe ASCII identifier subset shared by Postgres,
// MySQL and SQLite when not quoted. Anything outside this set is rejected
// rather than escaped — Schema is not the place to support exotic names.
var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// Validate ensures every Schema.Tables field is a non-empty, safe identifier.
func (s Schema) Validate() error {
const op = "authkit.sqlstore.Schema.Validate"
checks := []struct {
field, value string
}{
{"Users", s.Tables.Users},
{"Sessions", s.Tables.Sessions},
{"Tokens", s.Tables.Tokens},
{"APIKeys", s.Tables.APIKeys},
{"Roles", s.Tables.Roles},
{"Permissions", s.Tables.Permissions},
{"UserRoles", s.Tables.UserRoles},
{"RolePermissions", s.Tables.RolePermissions},
{"SchemaMigrations", s.Tables.SchemaMigrations},
}
for _, c := range checks {
if c.value == "" {
return errx.Newf(op, "Schema.Tables.%s is empty", c.field)
}
if !identifierRE.MatchString(c.value) {
return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value)
}
}
return nil
}

98
sqlstore/sessions.go Normal file
View file

@ -0,0 +1,98 @@
package sqlstore
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type sessionStore struct{ storeBase }
func (s *sessionStore) CreateSession(ctx context.Context, ses *authkit.Session) error {
const op = "authkit.sqlstore.SessionStore.CreateSession"
now := time.Now().UTC()
if ses.CreatedAt.IsZero() {
ses.CreatedAt = now
}
if ses.LastSeenAt.IsZero() {
ses.LastSeenAt = ses.CreatedAt
}
_, err := s.db.ExecContext(ctx, s.q.CreateSession,
ses.IDHash, uuidArg(ses.UserID), ses.UserAgent, nullableAddrString(ses.IP),
ses.CreatedAt, ses.LastSeenAt, ses.ExpiresAt)
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *sessionStore) GetSession(ctx context.Context, idHash []byte) (*authkit.Session, error) {
const op = "authkit.sqlstore.SessionStore.GetSession"
var (
ses authkit.Session
uidStr string
ipStr sql.NullString
)
err := s.db.QueryRowContext(ctx, s.q.GetSession, idHash).Scan(
&ses.IDHash, &uidStr, &ses.UserAgent, &ipStr,
&ses.CreatedAt, &ses.LastSeenAt, &ses.ExpiresAt)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrSessionInvalid))
}
uid, err := scanUUID(uidStr)
if err != nil {
return nil, errx.Wrap(op, err)
}
ses.UserID = uid
if ipStr.Valid {
addr, err := scanAddr(&ipStr.String)
if err != nil {
return nil, errx.Wrap(op, err)
}
ses.IP = addr
}
return &ses, nil
}
func (s *sessionStore) TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error {
const op = "authkit.sqlstore.SessionStore.TouchSession"
tag, err := s.db.ExecContext(ctx, s.q.TouchSession, lastSeenAt, newExpiresAt, idHash)
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrSessionInvalid)
}
return nil
}
func (s *sessionStore) DeleteSession(ctx context.Context, idHash []byte) error {
const op = "authkit.sqlstore.SessionStore.DeleteSession"
if _, err := s.db.ExecContext(ctx, s.q.DeleteSession, idHash); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *sessionStore) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.sqlstore.SessionStore.DeleteUserSessions"
if _, err := s.db.ExecContext(ctx, s.q.DeleteUserSessions, uuidArg(userID)); err != nil {
return errx.Wrap(op, err)
}
return nil
}
func (s *sessionStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
const op = "authkit.sqlstore.SessionStore.DeleteExpired"
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredSessions, now)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}

59
sqlstore/sqlstore.go Normal file
View file

@ -0,0 +1,59 @@
// Package sqlstore provides database/sql-backed implementations of every
// authkit store interface. It works with any driver (pgx-stdlib, lib/pq,
// sqlx wrapping *sql.DB, ...) and any consumer-supplied table naming via
// Schema. Driver-specific behaviour lives in a Dialect; v1 ships
// dialect/postgres.
package sqlstore
import (
"database/sql"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
)
// Stores bundles every store implementation produced by New, ready to drop
// into authkit.Deps.
type Stores struct {
Users authkit.UserStore
Sessions authkit.SessionStore
Tokens authkit.TokenStore
APIKeys authkit.APIKeyStore
Roles authkit.RoleStore
Permissions authkit.PermissionStore
}
// New validates the Schema, builds the dialect's templated queries, and
// returns store implementations bound to db.
func New(db *sql.DB, dialect Dialect, schema Schema) (*Stores, error) {
const op = "authkit.sqlstore.New"
if db == nil {
return nil, errx.New(op, "db is required")
}
if dialect == nil {
return nil, errx.New(op, "dialect is required")
}
if err := schema.Validate(); err != nil {
return nil, errx.Wrap(op, err)
}
q := dialect.BuildQueries(schema)
base := storeBase{db: db, q: q, d: dialect, s: schema}
return &Stores{
Users: &userStore{storeBase: base},
Sessions: &sessionStore{storeBase: base},
Tokens: &tokenStore{storeBase: base},
APIKeys: &apiKeyStore{storeBase: base},
Roles: &roleStore{storeBase: base},
Permissions: &permissionStore{storeBase: base},
}, nil
}
// storeBase carries the shared dependencies every store needs. Embedded into
// each concrete store struct.
type storeBase struct {
db *sql.DB
q Queries
d Dialect
s Schema
}

258
sqlstore/sqlstore_test.go Normal file
View file

@ -0,0 +1,258 @@
package sqlstore_test
// Integration tests against a real Postgres. Skipped unless
// AUTHKIT_TEST_DATABASE_URL is set. Each test acquires a fresh schema by
// running Migrate against a randomly-named set of tables — no cleanup
// fixture, no external Docker dependency.
import (
"context"
"database/sql"
"errors"
"fmt"
"net/netip"
"os"
"testing"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/hasher"
"git.juancwu.dev/juancwu/authkit/sqlstore"
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
_ "github.com/jackc/pgx/v5/stdlib"
)
func envURL(t *testing.T) string {
t.Helper()
url := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
if url == "" {
t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
}
return url
}
// freshDB opens a connection, runs Migrate against the default schema, and
// schedules a teardown that drops every authkit_* table. Tests must run
// sequentially against a single database (the package's tests do, by
// default — go test serialises within a package unless t.Parallel is
// called).
func freshDB(t *testing.T) (*authkit.Auth, *sql.DB, sqlstore.Schema) {
t.Helper()
url := envURL(t)
db, err := sql.Open("pgx", url)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("ping: %v", err)
}
schema := sqlstore.DefaultSchema()
// Drop first to start clean — previous failed tests may have left rows.
dropAuthkitTables(t, db, schema)
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
t.Fatalf("Migrate: %v", err)
}
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
stores, err := sqlstore.New(db, pgdialect.New(), schema)
if err != nil {
t.Fatalf("sqlstore.New: %v", err)
}
auth := authkit.New(authkit.Deps{
Users: stores.Users,
Sessions: stores.Sessions,
Tokens: stores.Tokens,
APIKeys: stores.APIKeys,
Roles: stores.Roles,
Permissions: stores.Permissions,
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
}, authkit.Config{
JWTSecret: []byte("integration-secret-thirty-two!!!!"),
JWTIssuer: "authkit-int",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: 1 * time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,
PasswordResetTTL: time.Hour,
MagicLinkTTL: time.Minute,
})
return auth, db, schema
}
func dropAuthkitTables(t *testing.T, db *sql.DB, s sqlstore.Schema) {
t.Helper()
tables := []string{
s.Tables.UserRoles, s.Tables.RolePermissions,
s.Tables.Roles, s.Tables.Permissions,
s.Tables.APIKeys, s.Tables.Tokens,
s.Tables.Sessions, s.Tables.Users,
s.Tables.SchemaMigrations,
}
for _, name := range tables {
_, _ = db.ExecContext(context.Background(),
fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name))
}
}
func TestIntegration_MigrateIdempotent(t *testing.T) {
url := envURL(t)
db, err := sql.Open("pgx", url)
if err != nil {
t.Fatalf("sql.Open: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
schema := sqlstore.DefaultSchema()
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
for i := 0; i < 3; i++ {
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
t.Fatalf("Migrate iter %d: %v", i, err)
}
}
}
func TestIntegration_RegisterAndLogin(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("Register: %v", err)
}
got, err := auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2")
if err != nil {
t.Fatalf("LoginPassword: %v", err)
}
if got.ID != u.ID {
t.Fatalf("login user id mismatch")
}
if _, err := auth.Register(ctx, "alice@example.com", "x"); !errors.Is(err, authkit.ErrEmailTaken) {
t.Fatalf("expected ErrEmailTaken, got %v", err)
}
if _, err := auth.LoginPassword(ctx, "alice@example.com", "wrong"); !errors.Is(err, authkit.ErrInvalidCredentials) {
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
}
}
func TestIntegration_SessionLifecycle(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "s@s.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
plain, sess, err := auth.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
if err != nil {
t.Fatalf("IssueSession: %v", err)
}
if sess.ExpiresAt.Before(time.Now()) {
t.Fatalf("session already expired at issue")
}
if _, err := auth.AuthenticateSession(ctx, plain); err != nil {
t.Fatalf("AuthenticateSession: %v", err)
}
if err := auth.RevokeSession(ctx, plain); err != nil {
t.Fatalf("RevokeSession: %v", err)
}
if _, err := auth.AuthenticateSession(ctx, plain); !errors.Is(err, authkit.ErrSessionInvalid) {
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
}
}
func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "j@j.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
_, refresh1, err := auth.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
_, refresh2, err := auth.RefreshJWT(ctx, refresh1)
if err != nil {
t.Fatalf("first RefreshJWT: %v", err)
}
if refresh1 == refresh2 {
t.Fatalf("refresh token did not rotate")
}
if _, _, err := auth.RefreshJWT(ctx, refresh1); !errors.Is(err, authkit.ErrTokenReused) {
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
}
if _, _, err := auth.RefreshJWT(ctx, refresh2); !errors.Is(err, authkit.ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err)
}
}
func TestIntegration_APIKeyWithAbilities(t *testing.T) {
auth, _, _ := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "k@k.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
plain, _, err := auth.IssueAPIKey(ctx, u.ID, "ci",
[]string{"billing:read", "users:list"}, nil)
if err != nil {
t.Fatalf("IssueAPIKey: %v", err)
}
p, err := auth.AuthenticateAPIKey(ctx, plain)
if err != nil {
t.Fatalf("AuthenticateAPIKey: %v", err)
}
if !p.HasAbility("billing:read") || !p.HasAbility("users:list") {
t.Fatalf("abilities missing: %+v", p.Abilities)
}
if err := auth.RevokeAPIKey(ctx, plain); err != nil {
t.Fatalf("RevokeAPIKey: %v", err)
}
if _, err := auth.AuthenticateAPIKey(ctx, plain); !errors.Is(err, authkit.ErrAPIKeyInvalid) {
t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err)
}
}
func TestIntegration_RBAC(t *testing.T) {
auth, db, schema := freshDB(t)
ctx := context.Background()
u, err := auth.Register(ctx, "rb@b.com", "pw")
if err != nil {
t.Fatalf("Register: %v", err)
}
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
r := &authkit.Role{Name: "editor"}
if err := stores.Roles.CreateRole(ctx, r); err != nil {
t.Fatalf("CreateRole: %v", err)
}
p := &authkit.Permission{Name: "posts:write"}
if err := stores.Permissions.CreatePermission(ctx, p); err != nil {
t.Fatalf("CreatePermission: %v", err)
}
if err := stores.Permissions.AssignPermissionToRole(ctx, r.ID, p.ID); err != nil {
t.Fatalf("AssignPermissionToRole: %v", err)
}
if err := auth.AssignRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("AssignRole: %v", err)
}
ok, err := auth.HasPermission(ctx, u.ID, "posts:write")
if err != nil || !ok {
t.Fatalf("HasPermission: %v %v", ok, err)
}
ok, err = auth.HasAnyRole(ctx, u.ID, []string{"editor", "admin"})
if err != nil || !ok {
t.Fatalf("HasAnyRole: %v %v", ok, err)
}
if err := auth.RemoveRole(ctx, u.ID, "editor"); err != nil {
t.Fatalf("RemoveRole: %v", err)
}
ok, _ = auth.HasPermission(ctx, u.ID, "posts:write")
if ok {
t.Fatalf("HasPermission should be false after RemoveRole")
}
}

92
sqlstore/tokens.go Normal file
View file

@ -0,0 +1,92 @@
package sqlstore
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
)
type tokenStore struct{ storeBase }
func (s *tokenStore) CreateToken(ctx context.Context, t *authkit.Token) error {
const op = "authkit.sqlstore.TokenStore.CreateToken"
if t.CreatedAt.IsZero() {
t.CreatedAt = time.Now().UTC()
}
_, err := s.db.ExecContext(ctx, s.q.CreateToken,
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
nullableTime(t.ConsumedAt), t.CreatedAt, t.ExpiresAt)
if err != nil {
return errx.Wrap(op, err)
}
return nil
}
// ConsumeToken does the find-and-mark-consumed in a single statement so two
// concurrent callers cannot both successfully consume the same token. The
// row is returned for inspection (e.g. ChainID for refresh rotation).
func (s *tokenStore) ConsumeToken(ctx context.Context, kind authkit.TokenKind, hash []byte, now time.Time) (*authkit.Token, error) {
const op = "authkit.sqlstore.TokenStore.ConsumeToken"
row := s.db.QueryRowContext(ctx, s.q.ConsumeToken, now, string(kind), hash, now)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
}
return t, nil
}
func (s *tokenStore) GetToken(ctx context.Context, kind authkit.TokenKind, hash []byte) (*authkit.Token, error) {
const op = "authkit.sqlstore.TokenStore.GetToken"
row := s.db.QueryRowContext(ctx, s.q.GetToken, string(kind), hash)
t, err := scanToken(row)
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
}
return t, nil
}
func (s *tokenStore) DeleteByChain(ctx context.Context, chainID string) (int64, error) {
const op = "authkit.sqlstore.TokenStore.DeleteByChain"
tag, err := s.db.ExecContext(ctx, s.q.DeleteByChain, chainID)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}
func (s *tokenStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
const op = "authkit.sqlstore.TokenStore.DeleteExpired"
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredTokens, now)
if err != nil {
return 0, errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
return n, nil
}
func scanToken(row rowScanner) (*authkit.Token, error) {
var (
t authkit.Token
kind string
userIDStr string
chainID sql.NullString
consumedAt sql.NullTime
)
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID,
&consumedAt, &t.CreatedAt, &t.ExpiresAt); err != nil {
return nil, err
}
t.Kind = authkit.TokenKind(kind)
uid, err := scanUUID(userIDStr)
if err != nil {
return nil, err
}
t.UserID = uid
t.ChainID = scanNullStringPtr(chainID)
t.ConsumedAt = scanNullTimePtr(consumedAt)
return &t, nil
}

186
sqlstore/users.go Normal file
View file

@ -0,0 +1,186 @@
package sqlstore
import (
"context"
"database/sql"
"errors"
"strings"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
)
type userStore struct{ storeBase }
func (s *userStore) CreateUser(ctx context.Context, u *authkit.User) error {
const op = "authkit.sqlstore.UserStore.CreateUser"
if u.ID == uuid.Nil {
u.ID = uuid.New()
}
if u.EmailNormalized == "" {
u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email))
}
now := time.Now().UTC()
if u.CreatedAt.IsZero() {
u.CreatedAt = now
}
if u.UpdatedAt.IsZero() {
u.UpdatedAt = now
}
_, err := s.db.ExecContext(ctx, s.q.CreateUser,
uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt)
if err != nil {
if s.d.IsUniqueViolation(err) {
return errx.Wrap(op, authkit.ErrEmailTaken)
}
return errx.Wrap(op, err)
}
return nil
}
func (s *userStore) GetUserByID(ctx context.Context, id uuid.UUID) (*authkit.User, error) {
const op = "authkit.sqlstore.UserStore.GetUserByID"
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByID, uuidArg(id)))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return u, nil
}
func (s *userStore) GetUserByEmail(ctx context.Context, normalizedEmail string) (*authkit.User, error) {
const op = "authkit.sqlstore.UserStore.GetUserByEmail"
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByEmail, normalizedEmail))
if err != nil {
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return u, nil
}
func (s *userStore) UpdateUser(ctx context.Context, u *authkit.User) error {
const op = "authkit.sqlstore.UserStore.UpdateUser"
u.UpdatedAt = time.Now().UTC()
tag, err := s.db.ExecContext(ctx, s.q.UpdateUser,
u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
const op = "authkit.sqlstore.UserStore.DeleteUser"
tag, err := s.db.ExecContext(ctx, s.q.DeleteUser, uuidArg(id))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error {
const op = "authkit.sqlstore.UserStore.SetPassword"
tag, err := s.db.ExecContext(ctx, s.q.SetPassword,
nullableString(encodedHash), time.Now().UTC(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error {
const op = "authkit.sqlstore.UserStore.SetEmailVerified"
tag, err := s.db.ExecContext(ctx, s.q.SetEmailVerified, at, time.Now().UTC(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func (s *userStore) BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) {
const op = "authkit.sqlstore.UserStore.BumpSessionVersion"
var v int
if err := s.db.QueryRowContext(ctx, s.q.BumpSessionVersion,
time.Now().UTC(), uuidArg(userID)).Scan(&v); err != nil {
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return v, nil
}
func (s *userStore) IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) {
const op = "authkit.sqlstore.UserStore.IncrementFailedLogins"
var n int
if err := s.db.QueryRowContext(ctx, s.q.IncrementFailedLogins,
time.Now().UTC(), uuidArg(userID)).Scan(&n); err != nil {
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
}
return n, nil
}
func (s *userStore) ResetFailedLogins(ctx context.Context, userID uuid.UUID) error {
const op = "authkit.sqlstore.UserStore.ResetFailedLogins"
tag, err := s.db.ExecContext(ctx, s.q.ResetFailedLogins, time.Now().UTC(), uuidArg(userID))
if err != nil {
return errx.Wrap(op, err)
}
n, _ := tag.RowsAffected()
if n == 0 {
return errx.Wrap(op, authkit.ErrUserNotFound)
}
return nil
}
func scanUser(row rowScanner) (*authkit.User, error) {
var (
u authkit.User
idStr string
passwordHash sql.NullString
emailVerified sql.NullTime
lastLogin sql.NullTime
)
if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified,
&passwordHash, &u.SessionVersion, &u.FailedLogins, &lastLogin,
&u.CreatedAt, &u.UpdatedAt); err != nil {
return nil, err
}
id, err := scanUUID(idStr)
if err != nil {
return nil, err
}
u.ID = id
if passwordHash.Valid {
u.PasswordHash = passwordHash.String
}
u.EmailVerifiedAt = scanNullTimePtr(emailVerified)
u.LastLoginAt = scanNullTimePtr(lastLogin)
return &u, nil
}
// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so
// callers get reliable errors.Is targets through errx wrapping.
func mapNotFound(err error, sentinel error) error {
if errors.Is(err, sql.ErrNoRows) {
return sentinel
}
return err
}

83
stores.go Normal file
View file

@ -0,0 +1,83 @@
package authkit
import (
"context"
"time"
"github.com/google/uuid"
)
type UserStore interface {
CreateUser(ctx context.Context, u *User) error
GetUserByID(ctx context.Context, id uuid.UUID) (*User, error)
GetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error)
UpdateUser(ctx context.Context, u *User) error
DeleteUser(ctx context.Context, id uuid.UUID) error
SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error
SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error
BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error)
IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error)
ResetFailedLogins(ctx context.Context, userID uuid.UUID) error
}
type SessionStore interface {
CreateSession(ctx context.Context, s *Session) error
GetSession(ctx context.Context, idHash []byte) (*Session, error)
TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error
DeleteSession(ctx context.Context, idHash []byte) error
DeleteUserSessions(ctx context.Context, userID uuid.UUID) error
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
}
type TokenStore interface {
CreateToken(ctx context.Context, t *Token) error
// ConsumeToken atomically marks the matching unexpired, unconsumed token
// as consumed and returns it. Returns ErrTokenInvalid if no row matched.
// Implementations MUST do this in one statement (UPDATE ... RETURNING)
// to prevent double-spend under concurrent callers.
ConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error)
// GetToken returns a token without consuming it. Used for refresh-token
// reuse detection: a token that exists with consumed_at != nil is a
// replay signal.
GetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error)
DeleteByChain(ctx context.Context, chainID string) (int64, error)
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
}
type APIKeyStore interface {
CreateAPIKey(ctx context.Context, k *APIKey) error
GetAPIKey(ctx context.Context, idHash []byte) (*APIKey, error)
ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*APIKey, error)
TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error
RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error
RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error
}
type RoleStore interface {
CreateRole(ctx context.Context, r *Role) error
GetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error)
GetRoleByName(ctx context.Context, name string) (*Role, error)
ListRoles(ctx context.Context) ([]*Role, error)
DeleteRole(ctx context.Context, id uuid.UUID) error
AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error
RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error
GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error)
HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error)
}
type PermissionStore interface {
CreatePermission(ctx context.Context, p *Permission) error
GetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error)
GetPermissionByName(ctx context.Context, name string) (*Permission, error)
ListPermissions(ctx context.Context) ([]*Permission, error)
DeletePermission(ctx context.Context, id uuid.UUID) error
AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error
RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error
GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error)
GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error)
}
type Hasher interface {
Hash(password string) (string, error)
Verify(password, encoded string) (ok bool, needsRehash bool, err error)
}

67
tokens.go Normal file
View file

@ -0,0 +1,67 @@
package authkit
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"io"
"strings"
"git.juancwu.dev/juancwu/errx"
)
const secretRandomBytes = 32
// Secret prefixes. The leading token namespaces the secret so callers can
// route them to the right verifier without trial-and-error.
const (
prefixSession = "sess"
prefixRefresh = "rfr"
prefixAPIKey = "ak"
prefixEmailVerify = "evr"
prefixPasswordRset = "pwr"
prefixMagicLink = "mlnk"
)
// mintSecret generates a new opaque secret of the given prefix kind and
// returns the plaintext (to be returned to the user, never stored) and the
// SHA-256 lookup hash (to be stored).
func mintSecret(prefix string, rng io.Reader) (plaintext string, hash []byte, err error) {
const op = "authkit.mintSecret"
if rng == nil {
rng = rand.Reader
}
buf := make([]byte, secretRandomBytes)
if _, err := io.ReadFull(rng, buf); err != nil {
return "", nil, errx.Wrap(op, err)
}
body := base64.RawURLEncoding.EncodeToString(buf)
plaintext = prefix + "_" + body
hash = hashSecret(plaintext)
return plaintext, hash, nil
}
// hashSecret returns sha256(plaintext) — the lookup key for opaque secrets.
// Plaintexts have full entropy from a CSPRNG so a plain hash is sufficient
// (no per-record salt needed; the random body is the salt).
func hashSecret(plaintext string) []byte {
sum := sha256.Sum256([]byte(plaintext))
return sum[:]
}
// parseSecret validates that a plaintext starts with the expected prefix
// and returns the lookup hash. The prefix check is constant-time relative
// to a fixed-length comparison.
func parseSecret(prefix, plaintext string) (hash []byte, ok bool) {
want := prefix + "_"
if !strings.HasPrefix(plaintext, want) {
return nil, false
}
return hashSecret(plaintext), true
}
// constantTimeEqual is a thin wrapper for readability at call sites.
func constantTimeEqual(a, b []byte) bool {
return subtle.ConstantTimeCompare(a, b) == 1
}

56
tokens_test.go Normal file
View file

@ -0,0 +1,56 @@
package authkit
import (
"bytes"
"crypto/sha256"
"strings"
"testing"
)
func TestMintSecretRoundtrip(t *testing.T) {
plaintext, hash, err := mintSecret(prefixSession, nil)
if err != nil {
t.Fatalf("mintSecret: %v", err)
}
if !strings.HasPrefix(plaintext, prefixSession+"_") {
t.Fatalf("missing prefix: %q", plaintext)
}
parsed, ok := parseSecret(prefixSession, plaintext)
if !ok {
t.Fatalf("parseSecret rejected our own mint")
}
if !bytes.Equal(hash, parsed) {
t.Fatalf("hash mismatch")
}
want := sha256.Sum256([]byte(plaintext))
if !bytes.Equal(hash, want[:]) {
t.Fatalf("hashSecret != sha256(plaintext)")
}
}
func TestParseSecretWrongPrefix(t *testing.T) {
plaintext, _, err := mintSecret(prefixSession, nil)
if err != nil {
t.Fatalf("mintSecret: %v", err)
}
if _, ok := parseSecret(prefixAPIKey, plaintext); ok {
t.Fatalf("parseSecret should reject mismatched prefix")
}
if _, ok := parseSecret(prefixSession, "sessXXXX"); ok {
t.Fatalf("parseSecret should require trailing underscore")
}
}
func TestMintSecretUniqueness(t *testing.T) {
seen := make(map[string]struct{}, 100)
for i := 0; i < 100; i++ {
p, _, err := mintSecret(prefixAPIKey, nil)
if err != nil {
t.Fatalf("mintSecret: %v", err)
}
if _, dup := seen[p]; dup {
t.Fatalf("duplicate mint: %s", p)
}
seen[p] = struct{}{}
}
}