authkit initial
This commit is contained in:
parent
5173b0a43d
commit
134393fbca
43 changed files with 5188 additions and 1 deletions
344
README.md
344
README.md
|
|
@ -1,3 +1,345 @@
|
||||||
# authkit
|
# authkit
|
||||||
|
|
||||||
Just a concoction of auth stuff in one place.
|
A pragmatic authentication and authorization toolkit for Go web services.
|
||||||
|
|
||||||
|
`authkit` ships interfaces for users, sessions, tokens, API keys, roles, and
|
||||||
|
permissions, plus default `database/sql` Postgres implementations and
|
||||||
|
framework-neutral HTTP middleware. It supports both opaque server-side
|
||||||
|
sessions and JWT access tokens with rotating refresh tokens, hashes passwords
|
||||||
|
with Argon2id, and pairs naturally with [`lightmux`](https://git.juancwu.dev/juancwu/lightmux)
|
||||||
|
or any `net/http` stack.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```
|
||||||
|
go get git.juancwu.dev/juancwu/authkit
|
||||||
|
```
|
||||||
|
|
||||||
|
`authkit` itself depends only on `database/sql` and the Go standard library
|
||||||
|
plus `golang-jwt`, `google/uuid`, `golang.org/x/crypto`, and `errx`. Bring
|
||||||
|
your own driver: `pgx`, `lib/pq`, or anything else that registers a
|
||||||
|
`database/sql` driver.
|
||||||
|
|
||||||
|
```go
|
||||||
|
import _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
|
||||||
|
```
|
||||||
|
|
||||||
|
PostgreSQL 12+ is sufficient — the schema avoids `gen_random_uuid()` and
|
||||||
|
`pgcrypto` so no extensions are required.
|
||||||
|
|
||||||
|
## What's included
|
||||||
|
|
||||||
|
**Authentication flows**
|
||||||
|
- Email + password registration and login (Argon2id PHC-encoded hashes)
|
||||||
|
- Opaque server-side sessions with sliding TTL bounded by an absolute cap
|
||||||
|
- JWT access tokens (HS256) with rotating refresh tokens and reuse detection
|
||||||
|
- Email verification, password reset, and magic-link passwordless login
|
||||||
|
|
||||||
|
**Authorization**
|
||||||
|
- Roles and permissions with many-to-many wiring
|
||||||
|
- API keys with custom abilities for per-endpoint scoping
|
||||||
|
- A unified `Principal` type so middleware works the same regardless of which
|
||||||
|
authentication method ran
|
||||||
|
|
||||||
|
**Storage**
|
||||||
|
- Interfaces for every store so callers can plug in their own backends
|
||||||
|
- Default Postgres implementation built on `*sql.DB` (`sqlstore` package)
|
||||||
|
- Override table names via `Schema` without forking — useful when authkit
|
||||||
|
lives alongside an existing application schema
|
||||||
|
- A `Dialect` abstraction so future MySQL / SQLite implementations slot in
|
||||||
|
without changes to store code
|
||||||
|
- Embedded versioned migrations applied by a `Migrate(ctx, db, dialect, schema)`
|
||||||
|
helper that takes a session-scoped advisory lock
|
||||||
|
|
||||||
|
**HTTP**
|
||||||
|
- `middleware.RequireSession`, `RequireJWT`, `RequireAPIKey`, `RequireAny`
|
||||||
|
- `middleware.RequireRole`, `RequireAnyRole`, `RequirePermission`,
|
||||||
|
`RequireAbility`
|
||||||
|
- `middleware.PrincipalFrom(ctx)` to read the authenticated principal in
|
||||||
|
handlers
|
||||||
|
|
||||||
|
**Errors**
|
||||||
|
- Sentinel errors (`ErrEmailTaken`, `ErrInvalidCredentials`, `ErrTokenInvalid`,
|
||||||
|
`ErrTokenReused`, `ErrSessionInvalid`, `ErrAPIKeyInvalid`,
|
||||||
|
`ErrPermissionDenied`, ...) compatible with `errors.Is`
|
||||||
|
- All internal errors wrap with [`errx`](https://git.juancwu.dev/juancwu/errx)
|
||||||
|
for op tags
|
||||||
|
|
||||||
|
## Out of scope (v1)
|
||||||
|
|
||||||
|
MFA/TOTP, OAuth/social login, soft-delete, in-memory permission caching,
|
||||||
|
pluggable JWT signers (HS256 only), built-in HTTP handlers, MySQL/SQLite
|
||||||
|
dialects (architecture supports them; only Postgres ships in v1), and
|
||||||
|
column-name overrides in `Schema` (table-name overrides only).
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
### 1. Open a database and run migrations
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit/sqlstore"
|
||||||
|
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
db, err := sql.Open("pgx", os.Getenv("DATABASE_URL"))
|
||||||
|
if err != nil { /* ... */ }
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err := sqlstore.Migrate(ctx, db, pgdialect.New(), sqlstore.DefaultSchema()); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`Migrate` is idempotent and safe to call from multiple processes — it takes
|
||||||
|
a session-scoped `pg_advisory_lock` to serialise rollouts.
|
||||||
|
|
||||||
|
`sqlx` users can pass `sqlxDB.DB` (the underlying `*sql.DB`) to the same
|
||||||
|
calls — the library only cares about `*sql.DB`.
|
||||||
|
|
||||||
|
### 2. Wire the service
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/authkit/hasher"
|
||||||
|
)
|
||||||
|
|
||||||
|
stores, err := sqlstore.New(db, pgdialect.New(), sqlstore.DefaultSchema())
|
||||||
|
if err != nil { /* ... */ }
|
||||||
|
|
||||||
|
auth := authkit.New(authkit.Deps{
|
||||||
|
Users: stores.Users,
|
||||||
|
Sessions: stores.Sessions,
|
||||||
|
Tokens: stores.Tokens,
|
||||||
|
APIKeys: stores.APIKeys,
|
||||||
|
Roles: stores.Roles,
|
||||||
|
Permissions: stores.Permissions,
|
||||||
|
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
|
||||||
|
}, authkit.Config{
|
||||||
|
JWTSecret: []byte(os.Getenv("JWT_SECRET")),
|
||||||
|
JWTIssuer: "myapp",
|
||||||
|
SessionCookieSecure: true,
|
||||||
|
SessionCookieHTTPOnly: true,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
`Config` zero values fall back to sensible defaults (24h idle / 30d absolute
|
||||||
|
session TTL, 15m access tokens, 30d refresh tokens, 48h email-verify, 1h
|
||||||
|
password-reset, 15m magic-link). `JWTSecret` and the seven `Deps` fields are
|
||||||
|
required; `New` panics on a misconfiguration.
|
||||||
|
|
||||||
|
### 3. Use the service
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Registration + password login
|
||||||
|
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
|
||||||
|
u, err = auth.LoginPassword(ctx, "alice@example.com", "hunter2hunter2")
|
||||||
|
|
||||||
|
// Opaque session (cookie-friendly)
|
||||||
|
plaintext, sess, err := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP)
|
||||||
|
http.SetCookie(w, auth.SessionCookie(plaintext, sess.ExpiresAt))
|
||||||
|
|
||||||
|
// JWT + rotating refresh
|
||||||
|
access, refresh, err := auth.IssueJWT(ctx, u.ID)
|
||||||
|
access, refresh, err = auth.RefreshJWT(ctx, refresh) // old refresh is consumed
|
||||||
|
|
||||||
|
// API key with abilities
|
||||||
|
plaintext, key, err := auth.IssueAPIKey(ctx, u.ID, "ci",
|
||||||
|
[]string{"billing:read", "users:list"}, nil)
|
||||||
|
|
||||||
|
// Email verification + password reset + magic link
|
||||||
|
tok, err := auth.RequestEmailVerification(ctx, u.ID)
|
||||||
|
_, err = auth.ConfirmEmail(ctx, tok)
|
||||||
|
|
||||||
|
tok, err = auth.RequestPasswordReset(ctx, "alice@example.com")
|
||||||
|
err = auth.ConfirmPasswordReset(ctx, tok, "new-password")
|
||||||
|
|
||||||
|
tok, err = auth.RequestMagicLink(ctx, "alice@example.com")
|
||||||
|
u, err = auth.ConsumeMagicLink(ctx, tok)
|
||||||
|
```
|
||||||
|
|
||||||
|
The plaintext returned by `IssueSession`, `IssueJWT`, `IssueAPIKey`, and the
|
||||||
|
token-minting flows is **show-once** — only its SHA-256 hash is stored. Show
|
||||||
|
it to the user immediately; you cannot recover it later.
|
||||||
|
|
||||||
|
### 4. Wire middleware
|
||||||
|
|
||||||
|
`authkit/middleware` returns standard `func(http.Handler) http.Handler`
|
||||||
|
values, so it composes with `lightmux.Mux.Use`/`Group`/`Handle` and any
|
||||||
|
`net/http` mux that accepts the same shape.
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
authkitmw "git.juancwu.dev/juancwu/authkit/middleware"
|
||||||
|
"git.juancwu.dev/juancwu/lightmux"
|
||||||
|
)
|
||||||
|
|
||||||
|
mux := lightmux.New()
|
||||||
|
|
||||||
|
cookieAuth := authkitmw.RequireSession(authkitmw.Options{
|
||||||
|
Auth: auth,
|
||||||
|
Extractor: authkit.ChainExtractors(
|
||||||
|
authkit.CookieExtractor("authkit_session"),
|
||||||
|
authkit.BearerExtractor(),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
me := mux.Group("/me", cookieAuth)
|
||||||
|
me.Get("", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
p := authkitmw.MustPrincipal(r)
|
||||||
|
json.NewEncoder(w).Encode(p)
|
||||||
|
})
|
||||||
|
|
||||||
|
// RBAC: stack authz on top of any auth method
|
||||||
|
admin := mux.Group("/admin", cookieAuth, authkitmw.RequireRole("admin"))
|
||||||
|
|
||||||
|
// API-key-only route with a per-endpoint ability check
|
||||||
|
api := mux.Group("/api/v1", authkitmw.RequireAPIKey(authkitmw.Options{Auth: auth}))
|
||||||
|
api.Get("/billing", billingHandler, authkitmw.RequireAbility("billing:read"))
|
||||||
|
```
|
||||||
|
|
||||||
|
`Options.Extractor` defaults to `BearerExtractor`; pass `CookieExtractor` (or
|
||||||
|
chain extractors) when reading session cookies. `Options.OnUnauth` and
|
||||||
|
`Options.OnForbidden` default to a JSON `401` / `403`; override them to match
|
||||||
|
your error envelope.
|
||||||
|
|
||||||
|
## Custom table names
|
||||||
|
|
||||||
|
Pass a non-default `Schema` to use your own table names. Identifiers must
|
||||||
|
match `^[a-zA-Z_][a-zA-Z0-9_]*$`; anything else is rejected at `New()` and
|
||||||
|
`Migrate()` time, so SQL injection through the schema is impossible.
|
||||||
|
|
||||||
|
```go
|
||||||
|
schema := sqlstore.DefaultSchema()
|
||||||
|
schema.Tables.Users = "accounts"
|
||||||
|
schema.Tables.APIKeys = "api_credentials"
|
||||||
|
|
||||||
|
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
|
||||||
|
```
|
||||||
|
|
||||||
|
The bundled migration files use the default `authkit_*` names. If you
|
||||||
|
override, you're responsible for matching DDL (most consumers with custom
|
||||||
|
naming already have their own DDL pipeline).
|
||||||
|
|
||||||
|
Column-name overrides are not exposed in v1 — the column set is fixed for
|
||||||
|
each table. Adding column overrides later is purely additive.
|
||||||
|
|
||||||
|
## How things work
|
||||||
|
|
||||||
|
### Secret token format
|
||||||
|
|
||||||
|
Sessions, refresh tokens, API keys, email-verify tokens, password-reset tokens,
|
||||||
|
and magic-link tokens all share one format:
|
||||||
|
|
||||||
|
```
|
||||||
|
plaintext = "<prefix>_" + base64url(32 random bytes, no padding)
|
||||||
|
lookup = sha256(plaintext)
|
||||||
|
```
|
||||||
|
|
||||||
|
Plaintext is returned to the caller exactly once and never persisted; the
|
||||||
|
SHA-256 is the database lookup key. Random bytes come from `crypto/rand` (or
|
||||||
|
`Config.Random` for tests).
|
||||||
|
|
||||||
|
### JWT revocation
|
||||||
|
|
||||||
|
Access tokens carry `sv` (`session_version`) in their claims. When you call
|
||||||
|
`RevokeAllUserSessions` or `ChangePassword`, the user's `session_version`
|
||||||
|
column increments and every outstanding access token fails the next
|
||||||
|
`AuthenticateJWT`. This is the only way to invalidate a JWT before its
|
||||||
|
`exp`.
|
||||||
|
|
||||||
|
### Refresh token rotation
|
||||||
|
|
||||||
|
Each `RefreshJWT` consumes the presented refresh token and issues a new one
|
||||||
|
on the same chain (the `chain_id` column on `authkit_tokens`). If a
|
||||||
|
*consumed* refresh token is ever presented again — a strong replay signal —
|
||||||
|
the entire chain is deleted via `TokenStore.DeleteByChain` and the call
|
||||||
|
returns `ErrTokenReused`.
|
||||||
|
|
||||||
|
### Sliding session TTL
|
||||||
|
|
||||||
|
Each authenticated request via `AuthenticateSession` slides `expires_at` to
|
||||||
|
`now + Config.SessionIdleTTL`, capped at `created_at + Config.SessionAbsoluteTTL`.
|
||||||
|
Long-lived idle sessions still hit the absolute boundary.
|
||||||
|
|
||||||
|
### Schema and migrations
|
||||||
|
|
||||||
|
`sqlstore.Migrate` applies every embedded `.sql` file under
|
||||||
|
`sqlstore/dialect/postgres/migrations/` whose version (filename without
|
||||||
|
`.sql`) is not in `authkit_schema_migrations`. The dialect's
|
||||||
|
`AcquireMigrationLock` (Postgres uses `pg_advisory_lock`) serialises
|
||||||
|
concurrent migrators. Each migration owns its own transaction so future
|
||||||
|
migrations can use statements like `CREATE INDEX CONCURRENTLY`.
|
||||||
|
|
||||||
|
Every default table is prefixed `authkit_` so the schema can live alongside
|
||||||
|
your application's own tables in a shared database.
|
||||||
|
|
||||||
|
### Driver and dialect architecture
|
||||||
|
|
||||||
|
The `sqlstore` package speaks `database/sql` only. Driver-specific behaviour
|
||||||
|
lives behind a small `Dialect` interface:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Dialect interface {
|
||||||
|
Name() string
|
||||||
|
BuildQueries(s Schema) Queries
|
||||||
|
Bootstrap(ctx context.Context, db *sql.DB) error
|
||||||
|
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
|
||||||
|
Migrations() fs.FS
|
||||||
|
IsUniqueViolation(err error) bool
|
||||||
|
Placeholder(n int) string
|
||||||
|
PlaceholderList(start, count int) string
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
v1 ships `dialect/postgres`. A future MySQL or SQLite dialect adds a new
|
||||||
|
implementation; no changes to store code.
|
||||||
|
|
||||||
|
## Configuration reference
|
||||||
|
|
||||||
|
| Field | Default | Notes |
|
||||||
|
|---|---|---|
|
||||||
|
| `SessionIdleTTL` | 24h | Sliding window applied on each authenticated request |
|
||||||
|
| `SessionAbsoluteTTL` | 30d | Cap from `created_at`; sliding never exceeds this |
|
||||||
|
| `SessionCookieName` | `authkit_session` | |
|
||||||
|
| `SessionCookieSameSite` | `Lax` | |
|
||||||
|
| `SessionCookieSecure` / `HTTPOnly` | `false` / `false` | Set both to `true` in production |
|
||||||
|
| `JWTSecret` | — (required) | HS256 key |
|
||||||
|
| `JWTIssuer` / `JWTAudience` | empty | When set, parser enforces them |
|
||||||
|
| `AccessTokenTTL` | 15m | |
|
||||||
|
| `RefreshTokenTTL` | 30d | |
|
||||||
|
| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | |
|
||||||
|
| `Clock` | `time.Now().UTC` | Controls every observable timestamp; override for deterministic tests |
|
||||||
|
| `Random` | `crypto/rand.Reader` | Override for deterministic tests |
|
||||||
|
| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit |
|
||||||
|
|
||||||
|
## Implementing your own store
|
||||||
|
|
||||||
|
Every store is a small interface with explicit semantics — see `stores.go`.
|
||||||
|
The most subtle contract is `TokenStore.ConsumeToken`: it MUST mark the token
|
||||||
|
consumed and return it in a single statement (`UPDATE ... RETURNING` on
|
||||||
|
Postgres / SQLite 3.35+) so two concurrent callers cannot both succeed.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
```
|
||||||
|
go test ./... # unit tests, no DB
|
||||||
|
AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./sqlstore... # integration tests
|
||||||
|
```
|
||||||
|
|
||||||
|
Unit tests cover token mint/parse, Argon2id encode/verify (including
|
||||||
|
`needsRehash` on parameter change), JWT issue/parse (incl. expired,
|
||||||
|
`sv`-mismatch, refresh rotation, reuse detection), session lifecycle, email
|
||||||
|
verification, password reset cascading session invalidation, magic-link
|
||||||
|
self-verification, API keys with abilities, and RBAC role-permission
|
||||||
|
resolution. Integration tests run the full `sqlstore` contract against a
|
||||||
|
real Postgres when `AUTHKIT_TEST_DATABASE_URL` is set.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT. See `LICENSE`.
|
||||||
|
|
|
||||||
53
Taskfile.yml
Normal file
53
Taskfile.yml
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
version: '3'
|
||||||
|
|
||||||
|
tasks:
|
||||||
|
default:
|
||||||
|
desc: Run vet and unit tests
|
||||||
|
cmds:
|
||||||
|
- task: check
|
||||||
|
|
||||||
|
install:tools:
|
||||||
|
desc: Install development tools (tparse for prettier test output)
|
||||||
|
cmds:
|
||||||
|
- go install github.com/mfridman/tparse@latest
|
||||||
|
|
||||||
|
build:
|
||||||
|
desc: Build all packages
|
||||||
|
cmds:
|
||||||
|
- go build ./...
|
||||||
|
|
||||||
|
test:
|
||||||
|
desc: Run unit tests with prettier output via tparse
|
||||||
|
cmds:
|
||||||
|
- set -o pipefail && go test ./... -json -cover | tparse -all
|
||||||
|
|
||||||
|
test:race:
|
||||||
|
desc: Run unit tests under the race detector
|
||||||
|
cmds:
|
||||||
|
- set -o pipefail && go test ./... -race -json | tparse -all
|
||||||
|
|
||||||
|
test:integration:
|
||||||
|
desc: Run sqlstore integration tests against a real Postgres (requires AUTHKIT_TEST_DATABASE_URL)
|
||||||
|
cmds:
|
||||||
|
- set -o pipefail && go test ./sqlstore/... -run Integration -count=1 -json | tparse -all
|
||||||
|
|
||||||
|
vet:
|
||||||
|
desc: Run go vet
|
||||||
|
cmds:
|
||||||
|
- go vet ./...
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
desc: Format Go source files
|
||||||
|
cmds:
|
||||||
|
- gofmt -w .
|
||||||
|
|
||||||
|
tidy:
|
||||||
|
desc: Tidy go.mod
|
||||||
|
cmds:
|
||||||
|
- go mod tidy
|
||||||
|
|
||||||
|
check:
|
||||||
|
desc: Run vet and unit tests
|
||||||
|
cmds:
|
||||||
|
- task: vet
|
||||||
|
- task: test
|
||||||
121
authkit.go
Normal file
121
authkit.go
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Deps bundles every backing store and the password hasher the Auth service
|
||||||
|
// depends on. All fields are required; New panics on a nil dep so misuse is
|
||||||
|
// caught at boot rather than under load.
|
||||||
|
type Deps struct {
|
||||||
|
Users UserStore
|
||||||
|
Sessions SessionStore
|
||||||
|
Tokens TokenStore
|
||||||
|
APIKeys APIKeyStore
|
||||||
|
Roles RoleStore
|
||||||
|
Permissions PermissionStore
|
||||||
|
Hasher Hasher
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config tunes session/JWT/token TTLs, cookie shape, JWT signing material,
|
||||||
|
// and optional hooks. Any zero-valued duration is replaced with a sane
|
||||||
|
// default in New; required fields (notably JWTSecret) cause New to panic.
|
||||||
|
type Config struct {
|
||||||
|
// Session (opaque) cookies + DB-backed lifetime
|
||||||
|
SessionIdleTTL time.Duration
|
||||||
|
SessionAbsoluteTTL time.Duration
|
||||||
|
SessionCookieName string
|
||||||
|
SessionCookieDomain string
|
||||||
|
SessionCookiePath string
|
||||||
|
SessionCookieSecure bool
|
||||||
|
SessionCookieHTTPOnly bool
|
||||||
|
SessionCookieSameSite http.SameSite
|
||||||
|
|
||||||
|
// JWT (HS256)
|
||||||
|
JWTSecret []byte
|
||||||
|
JWTIssuer string
|
||||||
|
JWTAudience string
|
||||||
|
AccessTokenTTL time.Duration
|
||||||
|
RefreshTokenTTL time.Duration
|
||||||
|
|
||||||
|
// Single-use tokens
|
||||||
|
EmailVerifyTTL time.Duration
|
||||||
|
PasswordResetTTL time.Duration
|
||||||
|
MagicLinkTTL time.Duration
|
||||||
|
|
||||||
|
// Hooks (optional)
|
||||||
|
Clock func() time.Time
|
||||||
|
Random io.Reader
|
||||||
|
LoginHook func(ctx context.Context, email string, success bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth is the high-level service that composes the stores and hasher into the
|
||||||
|
// flows callers use: registration, login, sessions, JWTs, magic links, API
|
||||||
|
// keys, and authz checks. It is safe for concurrent use; method receivers
|
||||||
|
// never mutate Auth state after construction.
|
||||||
|
type Auth struct {
|
||||||
|
deps Deps
|
||||||
|
cfg Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// New validates Deps and Config, fills in defaults, and returns a ready
|
||||||
|
// service. It panics on missing deps or missing JWT secret rather than
|
||||||
|
// returning an error — these are programmer errors, not runtime ones.
|
||||||
|
func New(deps Deps, cfg Config) *Auth {
|
||||||
|
if deps.Users == nil || deps.Sessions == nil || deps.Tokens == nil ||
|
||||||
|
deps.APIKeys == nil || deps.Roles == nil || deps.Permissions == nil ||
|
||||||
|
deps.Hasher == nil {
|
||||||
|
panic(errx.New("authkit.New", "all Deps fields are required"))
|
||||||
|
}
|
||||||
|
if len(cfg.JWTSecret) == 0 {
|
||||||
|
panic(errx.New("authkit.New", "Config.JWTSecret is required"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.SessionIdleTTL == 0 {
|
||||||
|
cfg.SessionIdleTTL = 24 * time.Hour
|
||||||
|
}
|
||||||
|
if cfg.SessionAbsoluteTTL == 0 {
|
||||||
|
cfg.SessionAbsoluteTTL = 30 * 24 * time.Hour
|
||||||
|
}
|
||||||
|
if cfg.SessionCookieName == "" {
|
||||||
|
cfg.SessionCookieName = "authkit_session"
|
||||||
|
}
|
||||||
|
if cfg.SessionCookiePath == "" {
|
||||||
|
cfg.SessionCookiePath = "/"
|
||||||
|
}
|
||||||
|
if cfg.SessionCookieSameSite == 0 {
|
||||||
|
cfg.SessionCookieSameSite = http.SameSiteLaxMode
|
||||||
|
}
|
||||||
|
if cfg.AccessTokenTTL == 0 {
|
||||||
|
cfg.AccessTokenTTL = 15 * time.Minute
|
||||||
|
}
|
||||||
|
if cfg.RefreshTokenTTL == 0 {
|
||||||
|
cfg.RefreshTokenTTL = 30 * 24 * time.Hour
|
||||||
|
}
|
||||||
|
if cfg.EmailVerifyTTL == 0 {
|
||||||
|
cfg.EmailVerifyTTL = 48 * time.Hour
|
||||||
|
}
|
||||||
|
if cfg.PasswordResetTTL == 0 {
|
||||||
|
cfg.PasswordResetTTL = time.Hour
|
||||||
|
}
|
||||||
|
if cfg.MagicLinkTTL == 0 {
|
||||||
|
cfg.MagicLinkTTL = 15 * time.Minute
|
||||||
|
}
|
||||||
|
if cfg.Clock == nil {
|
||||||
|
cfg.Clock = func() time.Time { return time.Now().UTC() }
|
||||||
|
}
|
||||||
|
if cfg.Random == nil {
|
||||||
|
cfg.Random = rand.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Auth{deps: deps, cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// now returns the configured wall clock, defaulting to time.Now in UTC.
|
||||||
|
func (a *Auth) now() time.Time { return a.cfg.Clock() }
|
||||||
13
doc.go
Normal file
13
doc.go
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
// Package authkit is an authentication and authorization toolkit for Go web
|
||||||
|
// services. It defines storage interfaces (UserStore, SessionStore, TokenStore,
|
||||||
|
// APIKeyStore, RoleStore, PermissionStore) and a high-level Auth service that
|
||||||
|
// composes them to support registration, password login, opaque server-side
|
||||||
|
// sessions, JWT access plus rotating refresh tokens, email verification,
|
||||||
|
// password resets, magic-link passwordless login, role-based access control,
|
||||||
|
// and API keys with custom abilities.
|
||||||
|
//
|
||||||
|
// Default Postgres implementations of every store live in the pgstore
|
||||||
|
// subpackage. Argon2id password hashing lives in hasher. Framework-neutral
|
||||||
|
// HTTP middleware (compatible with lightmux and any net/http stack) lives in
|
||||||
|
// middleware.
|
||||||
|
package authkit
|
||||||
18
errors.go
Normal file
18
errors.go
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrEmailTaken = errors.New("authkit: email already registered")
|
||||||
|
ErrUserNotFound = errors.New("authkit: user not found")
|
||||||
|
ErrInvalidCredentials = errors.New("authkit: invalid credentials")
|
||||||
|
ErrEmailNotVerified = errors.New("authkit: email not verified")
|
||||||
|
ErrTokenInvalid = errors.New("authkit: invalid or expired token")
|
||||||
|
ErrTokenReused = errors.New("authkit: token reuse detected")
|
||||||
|
ErrSessionInvalid = errors.New("authkit: invalid or expired session")
|
||||||
|
ErrAPIKeyInvalid = errors.New("authkit: invalid or expired api key")
|
||||||
|
ErrPermissionDenied = errors.New("authkit: permission denied")
|
||||||
|
ErrRoleNotFound = errors.New("authkit: role not found")
|
||||||
|
ErrPermissionNotFound = errors.New("authkit: permission not found")
|
||||||
|
ErrConfigInvalid = errors.New("authkit: invalid configuration")
|
||||||
|
)
|
||||||
64
extractor.go
Normal file
64
extractor.go
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extractor pulls a credential string out of an HTTP request. It returns
|
||||||
|
// (value, true) when a value was found, otherwise ("", false).
|
||||||
|
type Extractor func(r *http.Request) (string, bool)
|
||||||
|
|
||||||
|
// BearerExtractor reads the value following "Bearer " in the Authorization
|
||||||
|
// header. Comparison is case-insensitive on the scheme.
|
||||||
|
func BearerExtractor() Extractor {
|
||||||
|
return func(r *http.Request) (string, bool) {
|
||||||
|
h := r.Header.Get("Authorization")
|
||||||
|
if h == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
const prefix = "bearer "
|
||||||
|
if len(h) <= len(prefix) || !strings.EqualFold(h[:len(prefix)], prefix) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
v := strings.TrimSpace(h[len(prefix):])
|
||||||
|
if v == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return v, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CookieExtractor reads the named cookie's value.
|
||||||
|
func CookieExtractor(name string) Extractor {
|
||||||
|
return func(r *http.Request) (string, bool) {
|
||||||
|
c, err := r.Cookie(name)
|
||||||
|
if err != nil || c.Value == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return c.Value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeaderExtractor reads a custom header verbatim.
|
||||||
|
func HeaderExtractor(name string) Extractor {
|
||||||
|
return func(r *http.Request) (string, bool) {
|
||||||
|
v := strings.TrimSpace(r.Header.Get(name))
|
||||||
|
if v == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return v, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChainExtractors tries each extractor in order, returning the first hit.
|
||||||
|
func ChainExtractors(es ...Extractor) Extractor {
|
||||||
|
return func(r *http.Request) (string, bool) {
|
||||||
|
for _, e := range es {
|
||||||
|
if v, ok := e(r); ok {
|
||||||
|
return v, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
20
go.mod
Normal file
20
go.mod
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
module git.juancwu.dev/juancwu/authkit
|
||||||
|
|
||||||
|
go 1.26.2
|
||||||
|
|
||||||
|
require (
|
||||||
|
git.juancwu.dev/juancwu/errx v0.1.0
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2
|
||||||
|
golang.org/x/crypto v0.50.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
|
golang.org/x/text v0.36.0 // indirect
|
||||||
|
)
|
||||||
36
go.sum
Normal file
36
go.sum
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w=
|
||||||
|
git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
136
hasher/argon2id.go
Normal file
136
hasher/argon2id.go
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
// Package hasher provides authkit.Hasher implementations. The default
|
||||||
|
// implementation, Argon2id, encodes hashes in the standard PHC string format
|
||||||
|
// (https://github.com/P-H-C/phc-string-format) so callers can introspect
|
||||||
|
// parameters and migrate.
|
||||||
|
package hasher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Argon2idParams configures the Argon2id KDF.
|
||||||
|
type Argon2idParams struct {
|
||||||
|
Memory uint32 // KB; OWASP 2024 baseline: 19 MiB minimum, 64 MiB recommended
|
||||||
|
Iterations uint32 // time cost
|
||||||
|
Parallelism uint8 // lanes
|
||||||
|
SaltLen uint32 // bytes
|
||||||
|
KeyLen uint32 // bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultArgon2idParams returns sensible defaults: 64 MiB memory, 3
|
||||||
|
// iterations, 2 lanes, 16-byte salt, 32-byte key. Tune up Memory/Iterations
|
||||||
|
// over time and rely on Verify's needsRehash signal to migrate stored hashes.
|
||||||
|
func DefaultArgon2idParams() Argon2idParams {
|
||||||
|
return Argon2idParams{
|
||||||
|
Memory: 64 * 1024,
|
||||||
|
Iterations: 3,
|
||||||
|
Parallelism: 2,
|
||||||
|
SaltLen: 16,
|
||||||
|
KeyLen: 32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type argon2idHasher struct {
|
||||||
|
params Argon2idParams
|
||||||
|
rng io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewArgon2id builds an authkit.Hasher backed by Argon2id. If params is the
|
||||||
|
// zero value, DefaultArgon2idParams() is used. rng defaults to crypto/rand.
|
||||||
|
func NewArgon2id(params Argon2idParams, rng io.Reader) authkit.Hasher {
|
||||||
|
if params == (Argon2idParams{}) {
|
||||||
|
params = DefaultArgon2idParams()
|
||||||
|
}
|
||||||
|
if rng == nil {
|
||||||
|
rng = rand.Reader
|
||||||
|
}
|
||||||
|
return &argon2idHasher{params: params, rng: rng}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *argon2idHasher) Hash(password string) (string, error) {
|
||||||
|
const op = "authkit.hasher.Argon2id.Hash"
|
||||||
|
if password == "" {
|
||||||
|
return "", errx.New(op, "password is empty")
|
||||||
|
}
|
||||||
|
salt := make([]byte, h.params.SaltLen)
|
||||||
|
if _, err := io.ReadFull(h.rng, salt); err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
key := argon2.IDKey([]byte(password), salt,
|
||||||
|
h.params.Iterations, h.params.Memory, h.params.Parallelism, h.params.KeyLen)
|
||||||
|
return encodePHC(h.params, salt, key), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *argon2idHasher) Verify(password, encoded string) (bool, bool, error) {
|
||||||
|
const op = "authkit.hasher.Argon2id.Verify"
|
||||||
|
got, salt, key, err := decodePHC(encoded)
|
||||||
|
if err != nil {
|
||||||
|
return false, false, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
want := argon2.IDKey([]byte(password), salt,
|
||||||
|
got.Iterations, got.Memory, got.Parallelism, uint32(len(key)))
|
||||||
|
if subtle.ConstantTimeCompare(want, key) != 1 {
|
||||||
|
return false, false, nil
|
||||||
|
}
|
||||||
|
needsRehash := got.Memory != h.params.Memory ||
|
||||||
|
got.Iterations != h.params.Iterations ||
|
||||||
|
got.Parallelism != h.params.Parallelism ||
|
||||||
|
uint32(len(salt)) != h.params.SaltLen ||
|
||||||
|
uint32(len(key)) != h.params.KeyLen
|
||||||
|
return true, needsRehash, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PHC string format:
|
||||||
|
//
|
||||||
|
// $argon2id$v=19$m=<mem>,t=<iters>,p=<lanes>$<salt b64>$<key b64>
|
||||||
|
//
|
||||||
|
// base64 here is the unpadded standard alphabet, per the spec.
|
||||||
|
func encodePHC(p Argon2idParams, salt, key []byte) string {
|
||||||
|
enc := base64.RawStdEncoding
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||||
|
argon2.Version, p.Memory, p.Iterations, p.Parallelism,
|
||||||
|
enc.EncodeToString(salt), enc.EncodeToString(key),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodePHC(s string) (Argon2idParams, []byte, []byte, error) {
|
||||||
|
parts := strings.Split(s, "$")
|
||||||
|
// Expect: ["", "argon2id", "v=19", "m=...,t=...,p=...", "<salt>", "<key>"]
|
||||||
|
if len(parts) != 6 || parts[0] != "" || parts[1] != "argon2id" {
|
||||||
|
return Argon2idParams{}, nil, nil, errors.New("hasher: not an argon2id phc string")
|
||||||
|
}
|
||||||
|
var version int
|
||||||
|
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||||
|
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad version segment: %w", err)
|
||||||
|
}
|
||||||
|
if version != argon2.Version {
|
||||||
|
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: unsupported argon2 version %d", version)
|
||||||
|
}
|
||||||
|
var p Argon2idParams
|
||||||
|
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism); err != nil {
|
||||||
|
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad params segment: %w", err)
|
||||||
|
}
|
||||||
|
enc := base64.RawStdEncoding
|
||||||
|
salt, err := enc.DecodeString(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad salt: %w", err)
|
||||||
|
}
|
||||||
|
key, err := enc.DecodeString(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad key: %w", err)
|
||||||
|
}
|
||||||
|
p.SaltLen = uint32(len(salt))
|
||||||
|
p.KeyLen = uint32(len(key))
|
||||||
|
return p, salt, key, nil
|
||||||
|
}
|
||||||
78
hasher/argon2id_test.go
Normal file
78
hasher/argon2id_test.go
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
package hasher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestArgon2idHashVerifyRoundtrip(t *testing.T) {
|
||||||
|
h := NewArgon2id(DefaultArgon2idParams(), nil)
|
||||||
|
encoded, err := h.Hash("hunter2hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Hash: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(encoded, "$argon2id$") {
|
||||||
|
t.Fatalf("encoded hash not in PHC form: %s", encoded)
|
||||||
|
}
|
||||||
|
ok, needsRehash, err := h.Verify("hunter2hunter2", encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Verify: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Verify rejected the original password")
|
||||||
|
}
|
||||||
|
if needsRehash {
|
||||||
|
t.Fatalf("Verify with default params should not signal rehash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArgon2idVerifyWrongPassword(t *testing.T) {
|
||||||
|
h := NewArgon2id(DefaultArgon2idParams(), nil)
|
||||||
|
encoded, err := h.Hash("correct horse battery staple")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Hash: %v", err)
|
||||||
|
}
|
||||||
|
ok, _, err := h.Verify("nope", encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Verify: %v", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("Verify should reject wrong password")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArgon2idNeedsRehashOnParamChange(t *testing.T) {
|
||||||
|
// Hash with light params...
|
||||||
|
light := Argon2idParams{Memory: 8 * 1024, Iterations: 1, Parallelism: 1, SaltLen: 16, KeyLen: 32}
|
||||||
|
encoded, err := NewArgon2id(light, nil).Hash("hello world")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Hash: %v", err)
|
||||||
|
}
|
||||||
|
// ...verify with stronger params should still match but flag rehash.
|
||||||
|
heavier := DefaultArgon2idParams()
|
||||||
|
ok, needsRehash, err := NewArgon2id(heavier, nil).Verify("hello world", encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Verify: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Verify rejected legitimate password across params")
|
||||||
|
}
|
||||||
|
if !needsRehash {
|
||||||
|
t.Fatalf("Verify should flag rehash when stored params differ from current")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArgon2idRejectsMalformed(t *testing.T) {
|
||||||
|
h := NewArgon2id(DefaultArgon2idParams(), nil)
|
||||||
|
cases := []string{
|
||||||
|
"",
|
||||||
|
"not-a-phc",
|
||||||
|
"$argon2i$v=19$m=64,t=1,p=1$abc$def",
|
||||||
|
"$argon2id$v=99$m=64,t=1,p=1$YWJj$ZGVm",
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
if _, _, err := h.Verify("x", c); err == nil {
|
||||||
|
t.Fatalf("Verify should reject malformed encoding: %q", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
69
jwt.go
Normal file
69
jwt.go
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// accessClaims is the JWT shape issued by IssueJWT. The session_version
|
||||||
|
// field carries the User.SessionVersion at issue time so AuthenticateJWT
|
||||||
|
// can detect global revocations (logout-everywhere, password change).
|
||||||
|
type accessClaims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
SessionVersion int `json:"sv"`
|
||||||
|
Method string `json:"m"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) signAccessToken(userID uuid.UUID, sessionVersion int) (string, error) {
|
||||||
|
const op = "authkit.signAccessToken"
|
||||||
|
now := a.now()
|
||||||
|
claims := accessClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: userID.String(),
|
||||||
|
Issuer: a.cfg.JWTIssuer,
|
||||||
|
Audience: jwt.ClaimStrings{a.cfg.JWTAudience},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(a.cfg.AccessTokenTTL)),
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
},
|
||||||
|
SessionVersion: sessionVersion,
|
||||||
|
Method: string(AuthMethodJWT),
|
||||||
|
}
|
||||||
|
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
signed, err := tok.SignedString(a.cfg.JWTSecret)
|
||||||
|
if err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return signed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAccessToken validates the signature and returns the parsed claims.
|
||||||
|
func (a *Auth) parseAccessToken(token string) (*accessClaims, error) {
|
||||||
|
const op = "authkit.parseAccessToken"
|
||||||
|
opts := []jwt.ParserOption{
|
||||||
|
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||||
|
jwt.WithExpirationRequired(),
|
||||||
|
jwt.WithIssuedAt(),
|
||||||
|
jwt.WithTimeFunc(a.cfg.Clock),
|
||||||
|
}
|
||||||
|
if a.cfg.JWTIssuer != "" {
|
||||||
|
opts = append(opts, jwt.WithIssuer(a.cfg.JWTIssuer))
|
||||||
|
}
|
||||||
|
if a.cfg.JWTAudience != "" {
|
||||||
|
opts = append(opts, jwt.WithAudience(a.cfg.JWTAudience))
|
||||||
|
}
|
||||||
|
parser := jwt.NewParser(opts...)
|
||||||
|
|
||||||
|
parsed, err := parser.ParseWithClaims(token, &accessClaims{}, func(t *jwt.Token) (any, error) {
|
||||||
|
return a.cfg.JWTSecret, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
claims, ok := parsed.Claims.(*accessClaims)
|
||||||
|
if !ok || !parsed.Valid {
|
||||||
|
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
99
jwt_test.go
Normal file
99
jwt_test.go
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJWTIssueAndAuthenticate(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "alice@example.com", "hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
access, refresh, err := a.IssueJWT(context.Background(), u.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueJWT: %v", err)
|
||||||
|
}
|
||||||
|
if access == "" || refresh == "" {
|
||||||
|
t.Fatalf("empty tokens")
|
||||||
|
}
|
||||||
|
p, err := a.AuthenticateJWT(context.Background(), access)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AuthenticateJWT: %v", err)
|
||||||
|
}
|
||||||
|
if p.UserID != u.ID {
|
||||||
|
t.Fatalf("principal user id mismatch")
|
||||||
|
}
|
||||||
|
if p.Method != AuthMethodJWT {
|
||||||
|
t.Fatalf("principal method = %s, want jwt", p.Method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTSessionVersionMismatchRejected(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "bob@example.com", "hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
access, _, err := a.IssueJWT(context.Background(), u.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueJWT: %v", err)
|
||||||
|
}
|
||||||
|
if err := a.RevokeAllUserSessions(context.Background(), u.ID); err != nil {
|
||||||
|
t.Fatalf("RevokeAllUserSessions: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
|
||||||
|
t.Fatalf("expected ErrTokenInvalid after session bump, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTRefreshRotationAndReuseDetection(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "carol@example.com", "hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
_, refresh1, err := a.IssueJWT(context.Background(), u.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueJWT: %v", err)
|
||||||
|
}
|
||||||
|
_, refresh2, err := a.RefreshJWT(context.Background(), refresh1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first RefreshJWT: %v", err)
|
||||||
|
}
|
||||||
|
if refresh1 == refresh2 {
|
||||||
|
t.Fatalf("refresh token did not rotate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replaying refresh1 must surface ErrTokenReused and revoke the chain.
|
||||||
|
if _, _, err := a.RefreshJWT(context.Background(), refresh1); !errors.Is(err, ErrTokenReused) {
|
||||||
|
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
|
||||||
|
}
|
||||||
|
// After chain revocation, even refresh2 (the legitimate next one) must
|
||||||
|
// be rejected.
|
||||||
|
if _, _, err := a.RefreshJWT(context.Background(), refresh2); !errors.Is(err, ErrTokenInvalid) {
|
||||||
|
t.Fatalf("expected ErrTokenInvalid on post-revoke refresh, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTExpiredTokenRejected(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
a.cfg.Clock = func() time.Time { return now }
|
||||||
|
u, err := a.Register(context.Background(), "dan@example.com", "hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
access, _, err := a.IssueJWT(context.Background(), u.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueJWT: %v", err)
|
||||||
|
}
|
||||||
|
// Advance clock past TTL.
|
||||||
|
a.cfg.Clock = func() time.Time { return now.Add(10 * time.Minute) }
|
||||||
|
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
|
||||||
|
t.Fatalf("expected ErrTokenInvalid for expired token, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
619
memstore_test.go
Normal file
619
memstore_test.go
Normal file
|
|
@ -0,0 +1,619 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
// In-memory store fakes used by service-level tests. Kept in a _test.go file
|
||||||
|
// so they don't ship in the public API. Each fake is intentionally minimal —
|
||||||
|
// it satisfies the interface and supports the flows the tests exercise; not
|
||||||
|
// every method is wired up (a few panic to flag accidental use).
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memUserStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
byID map[uuid.UUID]*User
|
||||||
|
byEml map[string]uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemUserStore() *memUserStore {
|
||||||
|
return &memUserStore{byID: map[uuid.UUID]*User{}, byEml: map[string]uuid.UUID{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memUserStore) CreateUser(_ context.Context, u *User) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if _, exists := s.byEml[u.EmailNormalized]; exists {
|
||||||
|
return ErrEmailTaken
|
||||||
|
}
|
||||||
|
if u.ID == uuid.Nil {
|
||||||
|
u.ID = uuid.New()
|
||||||
|
}
|
||||||
|
cp := *u
|
||||||
|
s.byID[u.ID] = &cp
|
||||||
|
s.byEml[u.EmailNormalized] = u.ID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*User, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u, ok := s.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
cp := *u
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memUserStore) GetUserByEmail(_ context.Context, ne string) (*User, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
id, ok := s.byEml[ne]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
u := *s.byID[id]
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memUserStore) UpdateUser(_ context.Context, u *User) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if _, ok := s.byID[u.ID]; !ok {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
cp := *u
|
||||||
|
s.byID[u.ID] = &cp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u, ok := s.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
delete(s.byID, id)
|
||||||
|
delete(s.byEml, u.EmailNormalized)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, h string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u, ok := s.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
u.PasswordHash = h
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u, ok := s.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
u.EmailVerifiedAt = &at
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u, ok := s.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return 0, ErrUserNotFound
|
||||||
|
}
|
||||||
|
u.SessionVersion++
|
||||||
|
return u.SessionVersion, nil
|
||||||
|
}
|
||||||
|
func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u, ok := s.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return 0, ErrUserNotFound
|
||||||
|
}
|
||||||
|
u.FailedLogins++
|
||||||
|
return u.FailedLogins, nil
|
||||||
|
}
|
||||||
|
func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u, ok := s.byID[id]
|
||||||
|
if !ok {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
u.FailedLogins = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type memSessionStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
m map[string]*Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemSessionStore() *memSessionStore { return &memSessionStore{m: map[string]*Session{}} }
|
||||||
|
func (s *memSessionStore) CreateSession(_ context.Context, ses *Session) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
cp := *ses
|
||||||
|
s.m[string(ses.IDHash)] = &cp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*Session, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
v, ok := s.m[string(h)]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrSessionInvalid
|
||||||
|
}
|
||||||
|
cp := *v
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, exp time.Time) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
v, ok := s.m[string(h)]
|
||||||
|
if !ok {
|
||||||
|
return ErrSessionInvalid
|
||||||
|
}
|
||||||
|
v.LastSeenAt = lastSeen
|
||||||
|
v.ExpiresAt = exp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.m, string(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memSessionStore) DeleteUserSessions(_ context.Context, uid uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, v := range s.m {
|
||||||
|
if v.UserID == uid {
|
||||||
|
delete(s.m, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memSessionStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var n int64
|
||||||
|
for k, v := range s.m {
|
||||||
|
if !v.ExpiresAt.After(now) {
|
||||||
|
delete(s.m, k)
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type memTokenStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
// Keyed by kind+hex(hash) so identical hashes across kinds don't collide.
|
||||||
|
m map[string]*Token
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemTokenStore() *memTokenStore { return &memTokenStore{m: map[string]*Token{}} }
|
||||||
|
func tokKey(kind TokenKind, h []byte) string {
|
||||||
|
return string(kind) + ":" + string(h)
|
||||||
|
}
|
||||||
|
func (s *memTokenStore) CreateToken(_ context.Context, t *Token) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
cp := *t
|
||||||
|
s.m[tokKey(t.Kind, t.Hash)] = &cp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memTokenStore) ConsumeToken(_ context.Context, kind TokenKind, h []byte, now time.Time) (*Token, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
t, ok := s.m[tokKey(kind, h)]
|
||||||
|
if !ok || t.ConsumedAt != nil || !t.ExpiresAt.After(now) {
|
||||||
|
return nil, ErrTokenInvalid
|
||||||
|
}
|
||||||
|
t.ConsumedAt = &now
|
||||||
|
cp := *t
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
func (s *memTokenStore) GetToken(_ context.Context, kind TokenKind, h []byte) (*Token, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
t, ok := s.m[tokKey(kind, h)]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrTokenInvalid
|
||||||
|
}
|
||||||
|
cp := *t
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
func (s *memTokenStore) DeleteByChain(_ context.Context, chainID string) (int64, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var n int64
|
||||||
|
for k, t := range s.m {
|
||||||
|
if t.ChainID != nil && *t.ChainID == chainID {
|
||||||
|
delete(s.m, k)
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
func (s *memTokenStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var n int64
|
||||||
|
for k, t := range s.m {
|
||||||
|
if !t.ExpiresAt.After(now) {
|
||||||
|
delete(s.m, k)
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type memAPIKeyStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
m map[string]*APIKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemAPIKeyStore() *memAPIKeyStore { return &memAPIKeyStore{m: map[string]*APIKey{}} }
|
||||||
|
func (s *memAPIKeyStore) CreateAPIKey(_ context.Context, k *APIKey) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
cp := *k
|
||||||
|
cp.Abilities = append([]string(nil), k.Abilities...)
|
||||||
|
s.m[string(k.IDHash)] = &cp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memAPIKeyStore) GetAPIKey(_ context.Context, h []byte) (*APIKey, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
k, ok := s.m[string(h)]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrAPIKeyInvalid
|
||||||
|
}
|
||||||
|
cp := *k
|
||||||
|
cp.Abilities = append([]string(nil), k.Abilities...)
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
func (s *memAPIKeyStore) ListAPIKeysByOwner(_ context.Context, owner uuid.UUID) ([]*APIKey, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var out []*APIKey
|
||||||
|
for _, k := range s.m {
|
||||||
|
if k.OwnerID == owner {
|
||||||
|
cp := *k
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
func (s *memAPIKeyStore) TouchAPIKey(_ context.Context, h []byte, at time.Time) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if k, ok := s.m[string(h)]; ok {
|
||||||
|
k.LastUsedAt = &at
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memAPIKeyStore) RevokeAPIKey(_ context.Context, h []byte, at time.Time) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
k, ok := s.m[string(h)]
|
||||||
|
if !ok {
|
||||||
|
return ErrAPIKeyInvalid
|
||||||
|
}
|
||||||
|
if k.RevokedAt != nil {
|
||||||
|
return ErrAPIKeyInvalid
|
||||||
|
}
|
||||||
|
k.RevokedAt = &at
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memAPIKeyStore) RevokeAPIKeysByOwner(_ context.Context, owner uuid.UUID, at time.Time) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for _, k := range s.m {
|
||||||
|
if k.OwnerID == owner && k.RevokedAt == nil {
|
||||||
|
k.RevokedAt = &at
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type memRoleStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
roles map[uuid.UUID]*Role
|
||||||
|
rolesByNm map[string]uuid.UUID
|
||||||
|
userRoles map[uuid.UUID]map[uuid.UUID]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemRoleStore() *memRoleStore {
|
||||||
|
return &memRoleStore{
|
||||||
|
roles: map[uuid.UUID]*Role{},
|
||||||
|
rolesByNm: map[string]uuid.UUID{},
|
||||||
|
userRoles: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) CreateRole(_ context.Context, r *Role) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if _, dup := s.rolesByNm[r.Name]; dup {
|
||||||
|
return errors.New("role exists")
|
||||||
|
}
|
||||||
|
if r.ID == uuid.Nil {
|
||||||
|
r.ID = uuid.New()
|
||||||
|
}
|
||||||
|
cp := *r
|
||||||
|
s.roles[r.ID] = &cp
|
||||||
|
s.rolesByNm[r.Name] = r.ID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) GetRoleByID(_ context.Context, id uuid.UUID) (*Role, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
r, ok := s.roles[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrRoleNotFound
|
||||||
|
}
|
||||||
|
cp := *r
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) GetRoleByName(_ context.Context, n string) (*Role, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
id, ok := s.rolesByNm[n]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrRoleNotFound
|
||||||
|
}
|
||||||
|
r := *s.roles[id]
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) ListRoles(_ context.Context) ([]*Role, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
out := make([]*Role, 0, len(s.roles))
|
||||||
|
for _, r := range s.roles {
|
||||||
|
cp := *r
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) DeleteRole(_ context.Context, id uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
r, ok := s.roles[id]
|
||||||
|
if !ok {
|
||||||
|
return ErrRoleNotFound
|
||||||
|
}
|
||||||
|
delete(s.roles, id)
|
||||||
|
delete(s.rolesByNm, r.Name)
|
||||||
|
for _, m := range s.userRoles {
|
||||||
|
delete(m, id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) AssignRoleToUser(_ context.Context, uid, rid uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
m, ok := s.userRoles[uid]
|
||||||
|
if !ok {
|
||||||
|
m = map[uuid.UUID]struct{}{}
|
||||||
|
s.userRoles[uid] = m
|
||||||
|
}
|
||||||
|
m[rid] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) RemoveRoleFromUser(_ context.Context, uid, rid uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if m, ok := s.userRoles[uid]; ok {
|
||||||
|
delete(m, rid)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) GetUserRoles(_ context.Context, uid uuid.UUID) ([]*Role, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
out := []*Role{}
|
||||||
|
for rid := range s.userRoles[uid] {
|
||||||
|
if r, ok := s.roles[rid]; ok {
|
||||||
|
cp := *r
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
func (s *memRoleStore) HasAnyRole(_ context.Context, uid uuid.UUID, names []string) (bool, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for rid := range s.userRoles[uid] {
|
||||||
|
r := s.roles[rid]
|
||||||
|
if slices.Contains(names, r.Name) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type memPermStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
perms map[uuid.UUID]*Permission
|
||||||
|
permsByNm map[string]uuid.UUID
|
||||||
|
rolePerms map[uuid.UUID]map[uuid.UUID]struct{}
|
||||||
|
roles *memRoleStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemPermStore(rs *memRoleStore) *memPermStore {
|
||||||
|
return &memPermStore{
|
||||||
|
perms: map[uuid.UUID]*Permission{},
|
||||||
|
permsByNm: map[string]uuid.UUID{},
|
||||||
|
rolePerms: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||||
|
roles: rs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (s *memPermStore) CreatePermission(_ context.Context, p *Permission) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if _, dup := s.permsByNm[p.Name]; dup {
|
||||||
|
return errors.New("perm exists")
|
||||||
|
}
|
||||||
|
if p.ID == uuid.Nil {
|
||||||
|
p.ID = uuid.New()
|
||||||
|
}
|
||||||
|
cp := *p
|
||||||
|
s.perms[p.ID] = &cp
|
||||||
|
s.permsByNm[p.Name] = p.ID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) GetPermissionByID(_ context.Context, id uuid.UUID) (*Permission, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
p, ok := s.perms[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrPermissionNotFound
|
||||||
|
}
|
||||||
|
cp := *p
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) GetPermissionByName(_ context.Context, n string) (*Permission, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
id, ok := s.permsByNm[n]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrPermissionNotFound
|
||||||
|
}
|
||||||
|
p := *s.perms[id]
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) ListPermissions(_ context.Context) ([]*Permission, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
out := make([]*Permission, 0, len(s.perms))
|
||||||
|
for _, p := range s.perms {
|
||||||
|
cp := *p
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) DeletePermission(_ context.Context, id uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
p, ok := s.perms[id]
|
||||||
|
if !ok {
|
||||||
|
return ErrPermissionNotFound
|
||||||
|
}
|
||||||
|
delete(s.perms, id)
|
||||||
|
delete(s.permsByNm, p.Name)
|
||||||
|
for _, m := range s.rolePerms {
|
||||||
|
delete(m, id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) AssignPermissionToRole(_ context.Context, rid, pid uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
m, ok := s.rolePerms[rid]
|
||||||
|
if !ok {
|
||||||
|
m = map[uuid.UUID]struct{}{}
|
||||||
|
s.rolePerms[rid] = m
|
||||||
|
}
|
||||||
|
m[pid] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) RemovePermissionFromRole(_ context.Context, rid, pid uuid.UUID) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if m, ok := s.rolePerms[rid]; ok {
|
||||||
|
delete(m, pid)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) GetRolePermissions(_ context.Context, rid uuid.UUID) ([]*Permission, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
out := []*Permission{}
|
||||||
|
for pid := range s.rolePerms[rid] {
|
||||||
|
if p, ok := s.perms[pid]; ok {
|
||||||
|
cp := *p
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
func (s *memPermStore) GetUserPermissions(_ context.Context, uid uuid.UUID) ([]*Permission, error) {
|
||||||
|
s.roles.mu.Lock()
|
||||||
|
roleIDs := make([]uuid.UUID, 0)
|
||||||
|
for rid := range s.roles.userRoles[uid] {
|
||||||
|
roleIDs = append(roleIDs, rid)
|
||||||
|
}
|
||||||
|
s.roles.mu.Unlock()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
seen := map[uuid.UUID]struct{}{}
|
||||||
|
out := []*Permission{}
|
||||||
|
for _, rid := range roleIDs {
|
||||||
|
for pid := range s.rolePerms[rid] {
|
||||||
|
if _, dup := seen[pid]; dup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[pid] = struct{}{}
|
||||||
|
if p, ok := s.perms[pid]; ok {
|
||||||
|
cp := *p
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stub Hasher: stores plaintext for trivial verify. Tests of hashing live in
|
||||||
|
// the hasher package itself; we just need *something* callable here.
|
||||||
|
type stubHasher struct{}
|
||||||
|
|
||||||
|
func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil }
|
||||||
|
func (stubHasher) Verify(p, encoded string) (bool, bool, error) {
|
||||||
|
want := "stub:" + p
|
||||||
|
if !bytes.Equal([]byte(want), []byte(encoded)) {
|
||||||
|
return false, false, nil
|
||||||
|
}
|
||||||
|
return true, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestAuth wires the fakes into Auth with deterministic config.
|
||||||
|
func newTestAuth(t interface{ Helper() }) *Auth {
|
||||||
|
if h, ok := t.(interface{ Helper() }); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
roles := newMemRoleStore()
|
||||||
|
return New(Deps{
|
||||||
|
Users: newMemUserStore(),
|
||||||
|
Sessions: newMemSessionStore(),
|
||||||
|
Tokens: newMemTokenStore(),
|
||||||
|
APIKeys: newMemAPIKeyStore(),
|
||||||
|
Roles: roles,
|
||||||
|
Permissions: newMemPermStore(roles),
|
||||||
|
Hasher: stubHasher{},
|
||||||
|
}, Config{
|
||||||
|
JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"),
|
||||||
|
JWTIssuer: "authkit-test",
|
||||||
|
AccessTokenTTL: 2 * time.Minute,
|
||||||
|
RefreshTokenTTL: 1 * time.Hour,
|
||||||
|
SessionIdleTTL: time.Hour,
|
||||||
|
SessionAbsoluteTTL: 24 * time.Hour,
|
||||||
|
EmailVerifyTTL: time.Hour,
|
||||||
|
PasswordResetTTL: time.Hour,
|
||||||
|
MagicLinkTTL: time.Minute,
|
||||||
|
})
|
||||||
|
}
|
||||||
71
middleware/authz.go
Normal file
71
middleware/authz.go
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
)
|
||||||
|
|
||||||
|
// authzGuard wraps the common pattern of "look up the Principal, run a
|
||||||
|
// predicate, succeed or 403". onForbidden defaults to JSON 403.
|
||||||
|
func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler {
|
||||||
|
if onForbidden == nil {
|
||||||
|
onForbidden = defaultJSONError(http.StatusForbidden)
|
||||||
|
}
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
p, ok := PrincipalFrom(r.Context())
|
||||||
|
if !ok {
|
||||||
|
// No auth middleware ran upstream; treat as forbidden
|
||||||
|
// rather than crashing — composition is the caller's
|
||||||
|
// responsibility but a 403 is the safer default.
|
||||||
|
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !pred(p) {
|
||||||
|
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireRole permits requests whose Principal holds the named role.
|
||||||
|
func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||||
|
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||||
|
return p.HasRole(name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAnyRole permits requests whose Principal holds at least one of the
|
||||||
|
// named roles.
|
||||||
|
func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||||
|
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||||
|
return p.HasAnyRole(names...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequirePermission permits requests whose Principal holds the named
|
||||||
|
// permission (resolved via roles at auth time).
|
||||||
|
func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||||
|
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||||
|
return p.HasPermission(name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAbility permits requests whose Principal carries the named ability.
|
||||||
|
// Abilities are populated only for API-key authentication; this middleware
|
||||||
|
// will reject session/JWT-authenticated requests by design.
|
||||||
|
func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||||
|
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||||
|
return p.HasAbility(name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s[0]
|
||||||
|
}
|
||||||
39
middleware/context.go
Normal file
39
middleware/context.go
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Package middleware provides framework-neutral HTTP middleware for authkit.
|
||||||
|
// Every middleware function returns the standard func(http.Handler)
|
||||||
|
// http.Handler type, so it composes with lightmux's Use/Group/Handle as well
|
||||||
|
// as any net/http stack that uses the same signature.
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
)
|
||||||
|
|
||||||
|
// principalKey is an unexported context key. Using a distinct empty struct
|
||||||
|
// type guarantees no collision with caller-defined keys.
|
||||||
|
type principalKey struct{}
|
||||||
|
|
||||||
|
// withPrincipal stashes p on the request context for downstream handlers.
|
||||||
|
func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context {
|
||||||
|
return context.WithValue(ctx, principalKey{}, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrincipalFrom retrieves the authenticated Principal placed by RequireSession,
|
||||||
|
// RequireJWT, or RequireAPIKey. The boolean is false if no auth middleware
|
||||||
|
// ran for this request.
|
||||||
|
func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) {
|
||||||
|
p, ok := ctx.Value(principalKey{}).(*authkit.Principal)
|
||||||
|
return p, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustPrincipal panics if no Principal is on the context. Use only on
|
||||||
|
// handlers known to be behind a Require* middleware.
|
||||||
|
func MustPrincipal(r *http.Request) *authkit.Principal {
|
||||||
|
p, ok := PrincipalFrom(r.Context())
|
||||||
|
if !ok {
|
||||||
|
panic("authkit/middleware: no principal on request context")
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
138
middleware/middleware.go
Normal file
138
middleware/middleware.go
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options configures auth middleware. Auth is required; the rest fall back
|
||||||
|
// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403
|
||||||
|
// on authz failure.
|
||||||
|
type Options struct {
|
||||||
|
Auth *authkit.Auth
|
||||||
|
Extractor authkit.Extractor
|
||||||
|
OnUnauth func(w http.ResponseWriter, r *http.Request, err error)
|
||||||
|
OnForbidden func(w http.ResponseWriter, r *http.Request, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) extractor() authkit.Extractor {
|
||||||
|
if o.Extractor != nil {
|
||||||
|
return o.Extractor
|
||||||
|
}
|
||||||
|
return authkit.BearerExtractor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
if o.OnUnauth != nil {
|
||||||
|
return o.OnUnauth
|
||||||
|
}
|
||||||
|
return defaultJSONError(http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
if o.OnForbidden != nil {
|
||||||
|
return o.OnForbidden
|
||||||
|
}
|
||||||
|
return defaultJSONError(http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
return func(w http.ResponseWriter, _ *http.Request, err error) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": http.StatusText(status),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireSession authenticates the request via an opaque session string. The
|
||||||
|
// extractor is consulted first; if no extractor is set the default Bearer
|
||||||
|
// extractor is used. For cookie-based session lookup, set
|
||||||
|
// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName).
|
||||||
|
func RequireSession(opts Options) func(http.Handler) http.Handler {
|
||||||
|
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||||
|
return opts.Auth.AuthenticateSession(r.Context(), raw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireJWT authenticates the request via an HS256 JWT.
|
||||||
|
func RequireJWT(opts Options) func(http.Handler) http.Handler {
|
||||||
|
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||||
|
return opts.Auth.AuthenticateJWT(r.Context(), raw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAPIKey authenticates the request via an opaque API secret.
|
||||||
|
func RequireAPIKey(opts Options) func(http.Handler) http.Handler {
|
||||||
|
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||||
|
return opts.Auth.AuthenticateAPIKey(r.Context(), raw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAny tries each method in order until one succeeds. Useful for routes
|
||||||
|
// that accept either a session cookie or an API key.
|
||||||
|
func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler {
|
||||||
|
if len(methods) == 0 {
|
||||||
|
methods = []authkit.AuthMethod{
|
||||||
|
authkit.AuthMethodSession,
|
||||||
|
authkit.AuthMethodJWT,
|
||||||
|
authkit.AuthMethodAPIKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
raw, ok := opts.extractor()(r)
|
||||||
|
if !ok || raw == "" {
|
||||||
|
opts.onUnauth()(w, r, authkit.ErrSessionInvalid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
p *authkit.Principal
|
||||||
|
lastErr error
|
||||||
|
)
|
||||||
|
for _, m := range methods {
|
||||||
|
switch m {
|
||||||
|
case authkit.AuthMethodSession:
|
||||||
|
p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw)
|
||||||
|
case authkit.AuthMethodJWT:
|
||||||
|
p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw)
|
||||||
|
case authkit.AuthMethodAPIKey:
|
||||||
|
p, lastErr = opts.Auth.AuthenticateAPIKey(r.Context(), raw)
|
||||||
|
}
|
||||||
|
if lastErr == nil && p != nil {
|
||||||
|
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opts.onUnauth()(w, r, lastErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requireWith is the shared scaffolding for the single-method Require*
|
||||||
|
// middlewares.
|
||||||
|
func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler {
|
||||||
|
if opts.Auth == nil {
|
||||||
|
panic("authkit/middleware: Options.Auth is required")
|
||||||
|
}
|
||||||
|
extractor := opts.extractor()
|
||||||
|
onUnauth := opts.onUnauth()
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
raw, ok := extractor(r)
|
||||||
|
if !ok || raw == "" {
|
||||||
|
onUnauth(w, r, authkit.ErrSessionInvalid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p, err := authn(r, raw)
|
||||||
|
if err != nil {
|
||||||
|
onUnauth(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
75
models.go
Normal file
75
models.go
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Email string
|
||||||
|
EmailNormalized string
|
||||||
|
EmailVerifiedAt *time.Time
|
||||||
|
PasswordHash string
|
||||||
|
SessionVersion int
|
||||||
|
FailedLogins int
|
||||||
|
LastLoginAt *time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
IDHash []byte
|
||||||
|
UserID uuid.UUID
|
||||||
|
UserAgent string
|
||||||
|
IP netip.Addr
|
||||||
|
CreatedAt time.Time
|
||||||
|
LastSeenAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenEmailVerify TokenKind = "email_verify"
|
||||||
|
TokenPasswordReset TokenKind = "password_reset"
|
||||||
|
TokenMagicLink TokenKind = "magic_link"
|
||||||
|
TokenRefresh TokenKind = "refresh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
Hash []byte
|
||||||
|
Kind TokenKind
|
||||||
|
UserID uuid.UUID
|
||||||
|
ChainID *string
|
||||||
|
ConsumedAt *time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type APIKey struct {
|
||||||
|
IDHash []byte
|
||||||
|
OwnerID uuid.UUID
|
||||||
|
Name string
|
||||||
|
Abilities []string
|
||||||
|
LastUsedAt *time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
ExpiresAt *time.Time
|
||||||
|
RevokedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Role struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Permission struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
63
principal.go
Normal file
63
principal.go
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthMethod string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthMethodSession AuthMethod = "session"
|
||||||
|
AuthMethodJWT AuthMethod = "jwt"
|
||||||
|
AuthMethodAPIKey AuthMethod = "api_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Principal struct {
|
||||||
|
UserID uuid.UUID
|
||||||
|
Method AuthMethod
|
||||||
|
SessionID []byte
|
||||||
|
APIKeyID []byte
|
||||||
|
Roles []string
|
||||||
|
Permissions []string
|
||||||
|
Abilities []string
|
||||||
|
IssuedAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Principal) HasRole(name string) bool {
|
||||||
|
for _, r := range p.Roles {
|
||||||
|
if r == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Principal) HasAnyRole(names ...string) bool {
|
||||||
|
for _, n := range names {
|
||||||
|
if p.HasRole(n) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Principal) HasPermission(name string) bool {
|
||||||
|
for _, perm := range p.Permissions {
|
||||||
|
if perm == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Principal) HasAbility(name string) bool {
|
||||||
|
for _, a := range p.Abilities {
|
||||||
|
if a == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
92
service_apikey.go
Normal file
92
service_apikey.go
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IssueAPIKey mints a fresh API secret with the given abilities and an
|
||||||
|
// optional TTL. The plaintext is returned to the caller (show-once) and the
|
||||||
|
// SHA-256 lookup hash is stored. Pass ttl=nil for a non-expiring key.
|
||||||
|
func (a *Auth) IssueAPIKey(ctx context.Context, ownerID uuid.UUID, name string, abilities []string, ttl *time.Duration) (string, *APIKey, error) {
|
||||||
|
const op = "authkit.Auth.IssueAPIKey"
|
||||||
|
plaintext, hash, err := mintSecret(prefixAPIKey, a.cfg.Random)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
k := &APIKey{
|
||||||
|
IDHash: hash,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
Name: name,
|
||||||
|
Abilities: append([]string(nil), abilities...),
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
if ttl != nil {
|
||||||
|
exp := now.Add(*ttl)
|
||||||
|
k.ExpiresAt = &exp
|
||||||
|
}
|
||||||
|
if err := a.deps.APIKeys.CreateAPIKey(ctx, k); err != nil {
|
||||||
|
return "", nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return plaintext, k, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateAPIKey validates an API secret string, touches last_used_at
|
||||||
|
// (best-effort), and returns a Principal carrying the key's abilities. The
|
||||||
|
// owning user's roles+permissions are also resolved so the same Principal
|
||||||
|
// can satisfy RequireRole / RequirePermission middleware.
|
||||||
|
func (a *Auth) AuthenticateAPIKey(ctx context.Context, plaintext string) (*Principal, error) {
|
||||||
|
const op = "authkit.Auth.AuthenticateAPIKey"
|
||||||
|
hash, ok := parseSecret(prefixAPIKey, plaintext)
|
||||||
|
if !ok {
|
||||||
|
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
|
||||||
|
}
|
||||||
|
k, err := a.deps.APIKeys.GetAPIKey(ctx, hash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
if k.RevokedAt != nil {
|
||||||
|
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
|
||||||
|
}
|
||||||
|
if k.ExpiresAt != nil && !k.ExpiresAt.After(now) {
|
||||||
|
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
|
||||||
|
}
|
||||||
|
_ = a.deps.APIKeys.TouchAPIKey(ctx, hash, now)
|
||||||
|
|
||||||
|
roles, perms, err := a.resolveRolesAndPermissions(ctx, k.OwnerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
expires := now
|
||||||
|
if k.ExpiresAt != nil {
|
||||||
|
expires = *k.ExpiresAt
|
||||||
|
}
|
||||||
|
return &Principal{
|
||||||
|
UserID: k.OwnerID,
|
||||||
|
Method: AuthMethodAPIKey,
|
||||||
|
APIKeyID: hash,
|
||||||
|
Roles: roles,
|
||||||
|
Permissions: perms,
|
||||||
|
Abilities: append([]string(nil), k.Abilities...),
|
||||||
|
IssuedAt: k.CreatedAt,
|
||||||
|
ExpiresAt: expires,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeAPIKey marks a key revoked. Idempotent on already-revoked keys.
|
||||||
|
func (a *Auth) RevokeAPIKey(ctx context.Context, plaintext string) error {
|
||||||
|
const op = "authkit.Auth.RevokeAPIKey"
|
||||||
|
hash, ok := parseSecret(prefixAPIKey, plaintext)
|
||||||
|
if !ok {
|
||||||
|
return errx.Wrap(op, ErrAPIKeyInvalid)
|
||||||
|
}
|
||||||
|
if err := a.deps.APIKeys.RevokeAPIKey(ctx, hash, a.now()); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
85
service_authz.go
Normal file
85
service_authz.go
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserPermissions returns the union of permission names a user holds via
|
||||||
|
// their assigned roles. Resolved at call time; v1 does not cache.
|
||||||
|
func (a *Auth) UserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
|
||||||
|
const op = "authkit.Auth.UserPermissions"
|
||||||
|
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
out := make([]string, len(perms))
|
||||||
|
for i, p := range perms {
|
||||||
|
out[i] = p.Name
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasPermission checks whether a user holds the named permission via any
|
||||||
|
// assigned role.
|
||||||
|
func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
|
||||||
|
const op = "authkit.Auth.HasPermission"
|
||||||
|
perms, err := a.UserPermissions(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return false, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
for _, p := range perms {
|
||||||
|
if p == name {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRole checks whether a user is assigned the named role.
|
||||||
|
func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
|
||||||
|
const op = "authkit.Auth.HasRole"
|
||||||
|
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, []string{name})
|
||||||
|
if err != nil {
|
||||||
|
return false, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasAnyRole checks whether a user holds at least one of the named roles.
|
||||||
|
func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
|
||||||
|
const op = "authkit.Auth.HasAnyRole"
|
||||||
|
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, names)
|
||||||
|
if err != nil {
|
||||||
|
return false, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssignRole is a convenience that looks up a role by name and assigns it.
|
||||||
|
func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleName string) error {
|
||||||
|
const op = "authkit.Auth.AssignRole"
|
||||||
|
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Roles.AssignRoleToUser(ctx, userID, r.ID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveRole is the symmetric helper for AssignRole.
|
||||||
|
func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleName string) error {
|
||||||
|
const op = "authkit.Auth.RemoveRole"
|
||||||
|
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Roles.RemoveRoleFromUser(ctx, userID, r.ID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
147
service_jwt.go
Normal file
147
service_jwt.go
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IssueJWT issues a fresh access JWT and a rotating opaque refresh token.
|
||||||
|
// The refresh token is bound to a chain via Token.ChainID; rotation
|
||||||
|
// preserves that chain so reuse-detection can revoke the whole family.
|
||||||
|
func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) {
|
||||||
|
const op = "authkit.Auth.IssueJWT"
|
||||||
|
u, err := a.deps.Users.GetUserByID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
access, err = a.signAccessToken(u.ID, u.SessionVersion)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString())
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return access, refresh, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateJWT validates the access JWT, cross-checks the user's
|
||||||
|
// session_version (instant revocation), and resolves a Principal.
|
||||||
|
func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, error) {
|
||||||
|
const op = "authkit.Auth.AuthenticateJWT"
|
||||||
|
claims, err := a.parseAccessToken(access)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
uid, err := uuid.Parse(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
u, err := a.deps.Users.GetUserByID(ctx, uid)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if u.SessionVersion != claims.SessionVersion {
|
||||||
|
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
roles, perms, err := a.resolveRolesAndPermissions(ctx, u.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return &Principal{
|
||||||
|
UserID: u.ID,
|
||||||
|
Method: AuthMethodJWT,
|
||||||
|
Roles: roles,
|
||||||
|
Permissions: perms,
|
||||||
|
IssuedAt: claims.IssuedAt.Time,
|
||||||
|
ExpiresAt: claims.ExpiresAt.Time,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshJWT consumes the presented refresh token and mints a new access +
|
||||||
|
// refresh pair. Reuse of an already-consumed refresh token deletes the
|
||||||
|
// entire chain (logout-everywhere on that device family) and returns
|
||||||
|
// ErrTokenReused.
|
||||||
|
func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) {
|
||||||
|
const op = "authkit.Auth.RefreshJWT"
|
||||||
|
hash, ok := parseSecret(prefixRefresh, plaintextRefresh)
|
||||||
|
if !ok {
|
||||||
|
return "", "", errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
|
||||||
|
consumed, err := a.deps.Tokens.ConsumeToken(ctx, TokenRefresh, hash, now)
|
||||||
|
if err != nil {
|
||||||
|
// Differentiate plain-invalid (never existed / expired) from
|
||||||
|
// reuse (existed, already consumed). The presence-check below is
|
||||||
|
// the reuse signal.
|
||||||
|
if errors.Is(err, ErrTokenInvalid) {
|
||||||
|
if existing, gerr := a.deps.Tokens.GetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil {
|
||||||
|
if existing.ChainID != nil && *existing.ChainID != "" {
|
||||||
|
_, _ = a.deps.Tokens.DeleteByChain(ctx, *existing.ChainID)
|
||||||
|
}
|
||||||
|
return "", "", errx.Wrap(op, ErrTokenReused)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chainID string
|
||||||
|
if consumed.ChainID != nil {
|
||||||
|
chainID = *consumed.ChainID
|
||||||
|
}
|
||||||
|
if chainID == "" {
|
||||||
|
// Defensive: every refresh token should be chain-bound. Fall back
|
||||||
|
// to a fresh chain so we never throw on missing metadata.
|
||||||
|
chainID = uuid.NewString()
|
||||||
|
}
|
||||||
|
|
||||||
|
access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID))
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return access, refresh, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mintRefreshToken stores a fresh refresh token bound to chainID and returns
|
||||||
|
// the plaintext.
|
||||||
|
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) {
|
||||||
|
const op = "authkit.Auth.mintRefreshToken"
|
||||||
|
plaintext, hash, err := mintSecret(prefixRefresh, a.cfg.Random)
|
||||||
|
if err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
t := &Token{
|
||||||
|
Hash: hash,
|
||||||
|
Kind: TokenRefresh,
|
||||||
|
UserID: userID,
|
||||||
|
ChainID: &chainID,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(a.cfg.RefreshTokenTTL),
|
||||||
|
}
|
||||||
|
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// userSessionVersion fetches the current session_version. Errors collapse to
|
||||||
|
// 0 on the assumption that AuthenticateJWT will reject stale tokens cleanly
|
||||||
|
// — but we still need a value to embed in the freshly-minted access token.
|
||||||
|
func (a *Auth) userSessionVersion(ctx context.Context, userID uuid.UUID) int {
|
||||||
|
if u, err := a.deps.Users.GetUserByID(ctx, userID); err == nil {
|
||||||
|
return u.SessionVersion
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
62
service_magic.go
Normal file
62
service_magic.go
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestMagicLink mints a single-use magic-link token for the email and
|
||||||
|
// returns the plaintext for delivery. ErrUserNotFound is returned for
|
||||||
|
// unregistered emails.
|
||||||
|
func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, error) {
|
||||||
|
const op = "authkit.Auth.RequestMagicLink"
|
||||||
|
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
|
||||||
|
if err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
plaintext, hash, err := mintSecret(prefixMagicLink, a.cfg.Random)
|
||||||
|
if err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
t := &Token{
|
||||||
|
Hash: hash,
|
||||||
|
Kind: TokenMagicLink,
|
||||||
|
UserID: u.ID,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(a.cfg.MagicLinkTTL),
|
||||||
|
}
|
||||||
|
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsumeMagicLink consumes the magic-link token and returns the
|
||||||
|
// authenticated user. Callers typically follow this with IssueSession or
|
||||||
|
// IssueJWT to actually log the user in.
|
||||||
|
func (a *Auth) ConsumeMagicLink(ctx context.Context, plaintextToken string) (*User, error) {
|
||||||
|
const op = "authkit.Auth.ConsumeMagicLink"
|
||||||
|
hash, ok := parseSecret(prefixMagicLink, plaintextToken)
|
||||||
|
if !ok {
|
||||||
|
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenMagicLink, hash, now)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
u, err := a.deps.Users.GetUserByID(ctx, t.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
// A successful magic-link login also implicitly verifies the email
|
||||||
|
// (the user demonstrably controls the inbox).
|
||||||
|
if u.EmailVerifiedAt == nil {
|
||||||
|
if err := a.deps.Users.SetEmailVerified(ctx, u.ID, now); err == nil {
|
||||||
|
u.EmailVerifiedAt = &now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
69
service_reset.go
Normal file
69
service_reset.go
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestPasswordReset mints a single-use password-reset token for the user
|
||||||
|
// behind email and returns the plaintext for the caller to deliver via email.
|
||||||
|
// Returns ErrUserNotFound when the email isn't registered (per project
|
||||||
|
// policy of distinct errors over anti-enumeration).
|
||||||
|
func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, error) {
|
||||||
|
const op = "authkit.Auth.RequestPasswordReset"
|
||||||
|
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
|
||||||
|
if err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
plaintext, hash, err := mintSecret(prefixPasswordRset, a.cfg.Random)
|
||||||
|
if err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
t := &Token{
|
||||||
|
Hash: hash,
|
||||||
|
Kind: TokenPasswordReset,
|
||||||
|
UserID: u.ID,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(a.cfg.PasswordResetTTL),
|
||||||
|
}
|
||||||
|
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfirmPasswordReset consumes the reset token, sets the new password,
|
||||||
|
// bumps the user's session_version, and revokes outstanding sessions so the
|
||||||
|
// reset constitutes a global logout.
|
||||||
|
func (a *Auth) ConfirmPasswordReset(ctx context.Context, plaintextToken, newPassword string) error {
|
||||||
|
const op = "authkit.Auth.ConfirmPasswordReset"
|
||||||
|
hash, ok := parseSecret(prefixPasswordRset, plaintextToken)
|
||||||
|
if !ok {
|
||||||
|
return errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenPasswordReset, hash, now)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
newHash, err := a.deps.Hasher.Hash(newPassword)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Users.SetPassword(ctx, t.UserID, newHash); err != nil {
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
return errx.Wrap(op, ErrUserNotFound)
|
||||||
|
}
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if _, err := a.deps.Users.BumpSessionVersion(ctx, t.UserID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Sessions.DeleteUserSessions(ctx, t.UserID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
154
service_session.go
Normal file
154
service_session.go
Normal file
|
|
@ -0,0 +1,154 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IssueSession mints an opaque session ID, persists the session record, and
|
||||||
|
// returns the plaintext (for the cookie) plus the stored Session.
|
||||||
|
func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent string, ip netip.Addr) (string, *Session, error) {
|
||||||
|
const op = "authkit.Auth.IssueSession"
|
||||||
|
plaintext, hash, err := mintSecret(prefixSession, a.cfg.Random)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
expires := now.Add(a.cfg.SessionIdleTTL)
|
||||||
|
if cap := now.Add(a.cfg.SessionAbsoluteTTL); expires.After(cap) {
|
||||||
|
expires = cap
|
||||||
|
}
|
||||||
|
s := &Session{
|
||||||
|
IDHash: hash,
|
||||||
|
UserID: userID,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IP: ip,
|
||||||
|
CreatedAt: now,
|
||||||
|
LastSeenAt: now,
|
||||||
|
ExpiresAt: expires,
|
||||||
|
}
|
||||||
|
if err := a.deps.Sessions.CreateSession(ctx, s); err != nil {
|
||||||
|
return "", nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return plaintext, s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateSession validates an opaque session string, slides the TTL,
|
||||||
|
// resolves the user's roles+permissions, and returns a Principal. Expired or
|
||||||
|
// unknown sessions return ErrSessionInvalid.
|
||||||
|
func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Principal, error) {
|
||||||
|
const op = "authkit.Auth.AuthenticateSession"
|
||||||
|
hash, ok := parseSecret(prefixSession, plaintext)
|
||||||
|
if !ok {
|
||||||
|
return nil, errx.Wrap(op, ErrSessionInvalid)
|
||||||
|
}
|
||||||
|
s, err := a.deps.Sessions.GetSession(ctx, hash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
if !s.ExpiresAt.After(now) {
|
||||||
|
_ = a.deps.Sessions.DeleteSession(ctx, hash)
|
||||||
|
return nil, errx.Wrap(op, ErrSessionInvalid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Slide the idle TTL, capped at created_at + AbsoluteTTL so an active
|
||||||
|
// session still expires at the absolute boundary.
|
||||||
|
newExpires := now.Add(a.cfg.SessionIdleTTL)
|
||||||
|
if cap := s.CreatedAt.Add(a.cfg.SessionAbsoluteTTL); newExpires.After(cap) {
|
||||||
|
newExpires = cap
|
||||||
|
}
|
||||||
|
if err := a.deps.Sessions.TouchSession(ctx, hash, now, newExpires); err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
roles, perms, err := a.resolveRolesAndPermissions(ctx, s.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return &Principal{
|
||||||
|
UserID: s.UserID,
|
||||||
|
Method: AuthMethodSession,
|
||||||
|
SessionID: hash,
|
||||||
|
Roles: roles,
|
||||||
|
Permissions: perms,
|
||||||
|
IssuedAt: s.CreatedAt,
|
||||||
|
ExpiresAt: newExpires,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeSession deletes a single session by its plaintext id. Idempotent:
|
||||||
|
// missing sessions are not an error (logout twice should not 500).
|
||||||
|
func (a *Auth) RevokeSession(ctx context.Context, plaintext string) error {
|
||||||
|
const op = "authkit.Auth.RevokeSession"
|
||||||
|
hash, ok := parseSecret(prefixSession, plaintext)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := a.deps.Sessions.DeleteSession(ctx, hash); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeAllUserSessions kills every active session for the user and bumps
|
||||||
|
// the user's session_version (invalidating outstanding JWT access tokens).
|
||||||
|
func (a *Auth) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) error {
|
||||||
|
const op = "authkit.Auth.RevokeAllUserSessions"
|
||||||
|
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionCookie builds an *http.Cookie pre-configured from Config. Pass the
|
||||||
|
// plaintext returned by IssueSession; pass the matching ExpiresAt from the
|
||||||
|
// returned *Session as `expires`. To clear a cookie at logout, pass an empty
|
||||||
|
// plaintext and a past expiry.
|
||||||
|
func (a *Auth) SessionCookie(plaintext string, expires time.Time) *http.Cookie {
|
||||||
|
c := &http.Cookie{
|
||||||
|
Name: a.cfg.SessionCookieName,
|
||||||
|
Value: plaintext,
|
||||||
|
Path: a.cfg.SessionCookiePath,
|
||||||
|
Domain: a.cfg.SessionCookieDomain,
|
||||||
|
Secure: a.cfg.SessionCookieSecure,
|
||||||
|
HttpOnly: a.cfg.SessionCookieHTTPOnly,
|
||||||
|
SameSite: a.cfg.SessionCookieSameSite,
|
||||||
|
Expires: expires,
|
||||||
|
}
|
||||||
|
if plaintext == "" {
|
||||||
|
c.MaxAge = -1
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRolesAndPermissions fetches the user's role names and the union of
|
||||||
|
// their permission names. Both are returned as flat string slices for cheap
|
||||||
|
// containment checks on the Principal.
|
||||||
|
func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) {
|
||||||
|
roles, err := a.deps.Roles.GetUserRoles(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
rNames := make([]string, len(roles))
|
||||||
|
for i, r := range roles {
|
||||||
|
rNames[i] = r.Name
|
||||||
|
}
|
||||||
|
pNames := make([]string, len(perms))
|
||||||
|
for i, p := range perms {
|
||||||
|
pNames[i] = p.Name
|
||||||
|
}
|
||||||
|
return rNames, pNames, nil
|
||||||
|
}
|
||||||
220
service_test.go
Normal file
220
service_test.go
Normal file
|
|
@ -0,0 +1,220 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRegisterAndLogin(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "Alice@Example.com", "hunter2hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
if u.EmailNormalized != "alice@example.com" {
|
||||||
|
t.Fatalf("email_normalized = %q", u.EmailNormalized)
|
||||||
|
}
|
||||||
|
got, err := a.LoginPassword(context.Background(), "alice@example.com", "hunter2hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoginPassword: %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != u.ID {
|
||||||
|
t.Fatalf("login user mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterDuplicateEmail(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
if _, err := a.Register(context.Background(), "x@y.com", "abc"); err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
_, err := a.Register(context.Background(), "X@Y.COM", "abc")
|
||||||
|
if !errors.Is(err, ErrEmailTaken) {
|
||||||
|
t.Fatalf("expected ErrEmailTaken, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginWrongPassword(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
if _, err := a.Register(context.Background(), "p@q.com", "right"); err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := a.LoginPassword(context.Background(), "p@q.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) {
|
||||||
|
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionIssueAuthenticateRevoke(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "s@s.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
plain, sess, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueSession: %v", err)
|
||||||
|
}
|
||||||
|
if sess == nil || plain == "" {
|
||||||
|
t.Fatalf("missing session or plaintext")
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := a.AuthenticateSession(context.Background(), plain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AuthenticateSession: %v", err)
|
||||||
|
}
|
||||||
|
if p.UserID != u.ID {
|
||||||
|
t.Fatalf("principal user id mismatch")
|
||||||
|
}
|
||||||
|
if err := a.RevokeSession(context.Background(), plain); err != nil {
|
||||||
|
t.Fatalf("RevokeSession: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
|
||||||
|
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmailVerificationFlow(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "ev@e.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
tok, err := a.RequestEmailVerification(context.Background(), u.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RequestEmailVerification: %v", err)
|
||||||
|
}
|
||||||
|
confirmed, err := a.ConfirmEmail(context.Background(), tok)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConfirmEmail: %v", err)
|
||||||
|
}
|
||||||
|
if confirmed.EmailVerifiedAt == nil {
|
||||||
|
t.Fatalf("email_verified_at not set")
|
||||||
|
}
|
||||||
|
// Re-using the token must fail.
|
||||||
|
if _, err := a.ConfirmEmail(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
|
||||||
|
t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasswordResetFlow(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "r@r.com", "old")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
// Issue a session that should be invalidated by the reset.
|
||||||
|
plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueSession: %v", err)
|
||||||
|
}
|
||||||
|
tok, err := a.RequestPasswordReset(context.Background(), "r@r.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RequestPasswordReset: %v", err)
|
||||||
|
}
|
||||||
|
if err := a.ConfirmPasswordReset(context.Background(), tok, "new"); err != nil {
|
||||||
|
t.Fatalf("ConfirmPasswordReset: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := a.LoginPassword(context.Background(), "r@r.com", "old"); !errors.Is(err, ErrInvalidCredentials) {
|
||||||
|
t.Fatalf("old password should fail post-reset, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := a.LoginPassword(context.Background(), "r@r.com", "new"); err != nil {
|
||||||
|
t.Fatalf("new password should work post-reset, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
|
||||||
|
t.Fatalf("session should be invalidated by reset, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMagicLinkFlow(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
if _, err := a.Register(context.Background(), "m@m.com", "pw"); err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
tok, err := a.RequestMagicLink(context.Background(), "m@m.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RequestMagicLink: %v", err)
|
||||||
|
}
|
||||||
|
u, err := a.ConsumeMagicLink(context.Background(), tok)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConsumeMagicLink: %v", err)
|
||||||
|
}
|
||||||
|
if u.EmailVerifiedAt == nil {
|
||||||
|
t.Fatalf("magic link should imply email verification")
|
||||||
|
}
|
||||||
|
if _, err := a.ConsumeMagicLink(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
|
||||||
|
t.Fatalf("expected ErrTokenInvalid on magic link reuse, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyFlowWithAbilities(t *testing.T) {
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(context.Background(), "k@k.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
plaintext, k, err := a.IssueAPIKey(context.Background(), u.ID, "ci", []string{"billing:read", "users:list"}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueAPIKey: %v", err)
|
||||||
|
}
|
||||||
|
if k == nil || plaintext == "" {
|
||||||
|
t.Fatalf("missing api key")
|
||||||
|
}
|
||||||
|
p, err := a.AuthenticateAPIKey(context.Background(), plaintext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AuthenticateAPIKey: %v", err)
|
||||||
|
}
|
||||||
|
if !p.HasAbility("billing:read") || !p.HasAbility("users:list") {
|
||||||
|
t.Fatalf("abilities missing on principal: %+v", p.Abilities)
|
||||||
|
}
|
||||||
|
if p.HasAbility("admin:nuke") {
|
||||||
|
t.Fatalf("unexpected ability granted")
|
||||||
|
}
|
||||||
|
if err := a.RevokeAPIKey(context.Background(), plaintext); err != nil {
|
||||||
|
t.Fatalf("RevokeAPIKey: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := a.AuthenticateAPIKey(context.Background(), plaintext); !errors.Is(err, ErrAPIKeyInvalid) {
|
||||||
|
t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRBACRolesAndPermissions(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
a := newTestAuth(t)
|
||||||
|
u, err := a.Register(ctx, "rb@a.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create role + permission, hook them up.
|
||||||
|
role := &Role{Name: "editor"}
|
||||||
|
if err := a.deps.Roles.CreateRole(ctx, role); err != nil {
|
||||||
|
t.Fatalf("CreateRole: %v", err)
|
||||||
|
}
|
||||||
|
perm := &Permission{Name: "posts:write"}
|
||||||
|
if err := a.deps.Permissions.CreatePermission(ctx, perm); err != nil {
|
||||||
|
t.Fatalf("CreatePermission: %v", err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Permissions.AssignPermissionToRole(ctx, role.ID, perm.ID); err != nil {
|
||||||
|
t.Fatalf("AssignPermissionToRole: %v", err)
|
||||||
|
}
|
||||||
|
if err := a.AssignRole(ctx, u.ID, "editor"); err != nil {
|
||||||
|
t.Fatalf("AssignRole: %v", err)
|
||||||
|
}
|
||||||
|
ok, err := a.HasPermission(ctx, u.ID, "posts:write")
|
||||||
|
if err != nil || !ok {
|
||||||
|
t.Fatalf("HasPermission posts:write should be true, got %v %v", ok, err)
|
||||||
|
}
|
||||||
|
ok, err = a.HasRole(ctx, u.ID, "editor")
|
||||||
|
if err != nil || !ok {
|
||||||
|
t.Fatalf("HasRole editor should be true, got %v %v", ok, err)
|
||||||
|
}
|
||||||
|
if err := a.RemoveRole(ctx, u.ID, "editor"); err != nil {
|
||||||
|
t.Fatalf("RemoveRole: %v", err)
|
||||||
|
}
|
||||||
|
ok, _ = a.HasPermission(ctx, u.ID, "posts:write")
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("HasPermission should be false after RemoveRole")
|
||||||
|
}
|
||||||
|
}
|
||||||
178
service_user.go
Normal file
178
service_user.go
Normal file
|
|
@ -0,0 +1,178 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// normalizeEmail produces the lookup form used by UserStore.GetUserByEmail
|
||||||
|
// and the email_normalized column. Trim + lowercase is intentional; we do
|
||||||
|
// not collapse Gmail-style "+" addressing or strip dots — that's a policy
|
||||||
|
// decision callers can layer on top.
|
||||||
|
func normalizeEmail(s string) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register creates a new user with an Argon2id-hashed password. Returns
|
||||||
|
// ErrEmailTaken if the normalized email is already registered.
|
||||||
|
func (a *Auth) Register(ctx context.Context, email, password string) (*User, error) {
|
||||||
|
const op = "authkit.Auth.Register"
|
||||||
|
if email == "" || password == "" {
|
||||||
|
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||||
|
}
|
||||||
|
hash, err := a.deps.Hasher.Hash(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
u := &User{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Email: email,
|
||||||
|
EmailNormalized: normalizeEmail(email),
|
||||||
|
PasswordHash: hash,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
if err := a.deps.Users.CreateUser(ctx, u); err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginPassword verifies the password and returns the authenticated user.
|
||||||
|
// Failure increments failed_logins; success resets it and stamps last_login_at.
|
||||||
|
// LoginHook (if configured) is invoked with the success outcome — use this to
|
||||||
|
// hook in rate limiting or audit logging.
|
||||||
|
func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User, error) {
|
||||||
|
const op = "authkit.Auth.LoginPassword"
|
||||||
|
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
|
||||||
|
if err != nil {
|
||||||
|
_ = a.fireLoginHook(ctx, email, false)
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||||
|
}
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if u.PasswordHash == "" {
|
||||||
|
_ = a.fireLoginHook(ctx, email, false)
|
||||||
|
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||||
|
}
|
||||||
|
ok, needsRehash, err := a.deps.Hasher.Verify(password, u.PasswordHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
_, _ = a.deps.Users.IncrementFailedLogins(ctx, u.ID)
|
||||||
|
_ = a.fireLoginHook(ctx, email, false)
|
||||||
|
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := a.now()
|
||||||
|
u.LastLoginAt = &now
|
||||||
|
u.FailedLogins = 0
|
||||||
|
if err := a.deps.Users.ResetFailedLogins(ctx, u.ID); err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Users.UpdateUser(ctx, u); err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsRehash {
|
||||||
|
if newHash, herr := a.deps.Hasher.Hash(password); herr == nil {
|
||||||
|
_ = a.deps.Users.SetPassword(ctx, u.ID, newHash)
|
||||||
|
u.PasswordHash = newHash
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = a.fireLoginHook(ctx, email, true)
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangePassword verifies the current password, sets the new one, and bumps
|
||||||
|
// the user's session_version so all outstanding JWT access tokens are
|
||||||
|
// instantly invalidated. Outstanding opaque sessions are also revoked.
|
||||||
|
func (a *Auth) ChangePassword(ctx context.Context, userID uuid.UUID, oldPassword, newPassword string) error {
|
||||||
|
const op = "authkit.Auth.ChangePassword"
|
||||||
|
u, err := a.deps.Users.GetUserByID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if u.PasswordHash == "" {
|
||||||
|
return errx.Wrap(op, ErrInvalidCredentials)
|
||||||
|
}
|
||||||
|
ok, _, err := a.deps.Hasher.Verify(oldPassword, u.PasswordHash)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return errx.Wrap(op, ErrInvalidCredentials)
|
||||||
|
}
|
||||||
|
newHash, err := a.deps.Hasher.Hash(newPassword)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Users.SetPassword(ctx, userID, newHash); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestEmailVerification mints a single-use email-verify token for the
|
||||||
|
// user. Return the plaintext to the caller so they can put it in an email
|
||||||
|
// link; the lookup hash is stored in TokenStore.
|
||||||
|
func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||||
|
const op = "authkit.Auth.RequestEmailVerification"
|
||||||
|
plaintext, hash, err := mintSecret(prefixEmailVerify, a.cfg.Random)
|
||||||
|
if err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
t := &Token{
|
||||||
|
Hash: hash,
|
||||||
|
Kind: TokenEmailVerify,
|
||||||
|
UserID: userID,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(a.cfg.EmailVerifyTTL),
|
||||||
|
}
|
||||||
|
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||||
|
return "", errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfirmEmail consumes the verification token and marks the user's email
|
||||||
|
// verified. Returns ErrTokenInvalid if the token is missing/expired/used.
|
||||||
|
func (a *Auth) ConfirmEmail(ctx context.Context, plaintextToken string) (*User, error) {
|
||||||
|
const op = "authkit.Auth.ConfirmEmail"
|
||||||
|
hash, ok := parseSecret(prefixEmailVerify, plaintextToken)
|
||||||
|
if !ok {
|
||||||
|
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||||
|
}
|
||||||
|
now := a.now()
|
||||||
|
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenEmailVerify, hash, now)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
if err := a.deps.Users.SetEmailVerified(ctx, t.UserID, now); err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return a.deps.Users.GetUserByID(ctx, t.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fireLoginHook is a thin wrapper that suppresses panics from caller-supplied
|
||||||
|
// hooks; we never want a misbehaving telemetry hook to break login.
|
||||||
|
func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) error {
|
||||||
|
if a.cfg.LoginHook == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.cfg.LoginHook(ctx, email, success)
|
||||||
|
}
|
||||||
124
sqlstore/apikeys.go
Normal file
124
sqlstore/apikeys.go
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyStore struct{ storeBase }
|
||||||
|
|
||||||
|
func (s *apiKeyStore) CreateAPIKey(ctx context.Context, k *authkit.APIKey) error {
|
||||||
|
const op = "authkit.sqlstore.APIKeyStore.CreateAPIKey"
|
||||||
|
if k.CreatedAt.IsZero() {
|
||||||
|
k.CreatedAt = time.Now().UTC()
|
||||||
|
}
|
||||||
|
if k.Abilities == nil {
|
||||||
|
k.Abilities = []string{}
|
||||||
|
}
|
||||||
|
abilities, err := json.Marshal(k.Abilities)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
_, err = s.db.ExecContext(ctx, s.q.CreateAPIKey,
|
||||||
|
k.IDHash, uuidArg(k.OwnerID), k.Name, abilities,
|
||||||
|
nullableTime(k.LastUsedAt), k.CreatedAt,
|
||||||
|
nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyStore) GetAPIKey(ctx context.Context, idHash []byte) (*authkit.APIKey, error) {
|
||||||
|
const op = "authkit.sqlstore.APIKeyStore.GetAPIKey"
|
||||||
|
k, err := scanAPIKey(s.db.QueryRowContext(ctx, s.q.GetAPIKey, idHash))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrAPIKeyInvalid))
|
||||||
|
}
|
||||||
|
return k, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyStore) ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*authkit.APIKey, error) {
|
||||||
|
const op = "authkit.sqlstore.APIKeyStore.ListAPIKeysByOwner"
|
||||||
|
rows, err := s.db.QueryContext(ctx, s.q.ListAPIKeysByOwner, uuidArg(ownerID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []*authkit.APIKey
|
||||||
|
for rows.Next() {
|
||||||
|
k, err := scanAPIKey(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
out = append(out, k)
|
||||||
|
}
|
||||||
|
return out, errx.Wrap(op, rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyStore) TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error {
|
||||||
|
const op = "authkit.sqlstore.APIKeyStore.TouchAPIKey"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.TouchAPIKey, at, idHash); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyStore) RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error {
|
||||||
|
const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKey"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.RevokeAPIKey, at, idHash)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrAPIKeyInvalid)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyStore) RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error {
|
||||||
|
const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKeysByOwner"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.RevokeAPIKeysByOwner, at, uuidArg(ownerID)); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanAPIKey(row rowScanner) (*authkit.APIKey, error) {
|
||||||
|
var (
|
||||||
|
k authkit.APIKey
|
||||||
|
ownerIDStr string
|
||||||
|
abilitiesRaw []byte
|
||||||
|
lastUsed sql.NullTime
|
||||||
|
expires sql.NullTime
|
||||||
|
revoked sql.NullTime
|
||||||
|
)
|
||||||
|
if err := row.Scan(&k.IDHash, &ownerIDStr, &k.Name, &abilitiesRaw,
|
||||||
|
&lastUsed, &k.CreatedAt, &expires, &revoked); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
owner, err := scanUUID(ownerIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
k.OwnerID = owner
|
||||||
|
if len(abilitiesRaw) > 0 {
|
||||||
|
if err := json.Unmarshal(abilitiesRaw, &k.Abilities); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if k.Abilities == nil {
|
||||||
|
k.Abilities = []string{}
|
||||||
|
}
|
||||||
|
k.LastUsedAt = scanNullTimePtr(lastUsed)
|
||||||
|
k.ExpiresAt = scanNullTimePtr(expires)
|
||||||
|
k.RevokedAt = scanNullTimePtr(revoked)
|
||||||
|
return &k, nil
|
||||||
|
}
|
||||||
118
sqlstore/dialect.go
Normal file
118
sqlstore/dialect.go
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"io/fs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialect describes how a particular SQL backend renders the queries authkit
|
||||||
|
// runs at runtime, bootstraps its session for migrations, and reports
|
||||||
|
// driver-specific errors. v1 ships the Postgres dialect; future MySQL or
|
||||||
|
// SQLite implementations satisfy the same interface without changes to the
|
||||||
|
// store code.
|
||||||
|
type Dialect interface {
|
||||||
|
// Name returns a stable short identifier ("postgres", "mysql", ...).
|
||||||
|
Name() string
|
||||||
|
|
||||||
|
// BuildQueries renders every templated SQL string from a validated
|
||||||
|
// Schema. Called once at New() and at Migrate() so identifiers and
|
||||||
|
// placeholder styles are baked in by the time stores execute queries.
|
||||||
|
BuildQueries(s Schema) Queries
|
||||||
|
|
||||||
|
// Bootstrap runs once per Migrate() before the migration lock is taken.
|
||||||
|
// It is the place for things that must happen outside any transaction
|
||||||
|
// (CREATE EXTENSION on Postgres, for example). Must be idempotent and
|
||||||
|
// may be a no-op.
|
||||||
|
Bootstrap(ctx context.Context, db *sql.DB) error
|
||||||
|
|
||||||
|
// AcquireMigrationLock takes a session-scoped lock on conn so concurrent
|
||||||
|
// migrators serialise. The returned release function never returns an
|
||||||
|
// error — implementations may log internally.
|
||||||
|
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
|
||||||
|
|
||||||
|
// Migrations returns the dialect's embedded .sql files, rooted such
|
||||||
|
// that fs.ReadDir(".") lists them lex-sorted by filename.
|
||||||
|
Migrations() fs.FS
|
||||||
|
|
||||||
|
// IsUniqueViolation maps a duplicate-key error from this driver to true
|
||||||
|
// so insert paths can return the matching authkit sentinel without
|
||||||
|
// depending on driver internals.
|
||||||
|
IsUniqueViolation(err error) bool
|
||||||
|
|
||||||
|
// Placeholder returns the placeholder for parameter index n (1-based)
|
||||||
|
// in this dialect — "$1" for Postgres, "?" for MySQL/SQLite.
|
||||||
|
Placeholder(n int) string
|
||||||
|
|
||||||
|
// PlaceholderList returns a comma-separated placeholder list for `count`
|
||||||
|
// parameters starting at position `start` (1-based), suitable for
|
||||||
|
// dynamic IN-clauses. The second return is the same length, prefilled
|
||||||
|
// with nil so the caller can append actual values.
|
||||||
|
PlaceholderList(start, count int) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queries is the full set of statement templates authkit issues. Field names
|
||||||
|
// match the store method that consumes the query.
|
||||||
|
type Queries struct {
|
||||||
|
// users
|
||||||
|
CreateUser string
|
||||||
|
GetUserByID string
|
||||||
|
GetUserByEmail string
|
||||||
|
UpdateUser string
|
||||||
|
DeleteUser string
|
||||||
|
SetPassword string
|
||||||
|
SetEmailVerified string
|
||||||
|
BumpSessionVersion string
|
||||||
|
IncrementFailedLogins string
|
||||||
|
ResetFailedLogins string
|
||||||
|
|
||||||
|
// sessions
|
||||||
|
CreateSession string
|
||||||
|
GetSession string
|
||||||
|
TouchSession string
|
||||||
|
DeleteSession string
|
||||||
|
DeleteUserSessions string
|
||||||
|
DeleteExpiredSessions string
|
||||||
|
|
||||||
|
// tokens
|
||||||
|
CreateToken string
|
||||||
|
ConsumeToken string
|
||||||
|
GetToken string
|
||||||
|
DeleteByChain string
|
||||||
|
DeleteExpiredTokens string
|
||||||
|
|
||||||
|
// api keys
|
||||||
|
CreateAPIKey string
|
||||||
|
GetAPIKey string
|
||||||
|
ListAPIKeysByOwner string
|
||||||
|
TouchAPIKey string
|
||||||
|
RevokeAPIKey string
|
||||||
|
RevokeAPIKeysByOwner string
|
||||||
|
|
||||||
|
// roles
|
||||||
|
CreateRole string
|
||||||
|
GetRoleByID string
|
||||||
|
GetRoleByName string
|
||||||
|
ListRoles string
|
||||||
|
DeleteRole string
|
||||||
|
AssignRoleToUser string
|
||||||
|
RemoveRoleFromUser string
|
||||||
|
GetUserRoles string
|
||||||
|
// HasAnyRole is built at call time because the placeholder count varies.
|
||||||
|
|
||||||
|
// permissions
|
||||||
|
CreatePermission string
|
||||||
|
GetPermissionByID string
|
||||||
|
GetPermissionByName string
|
||||||
|
ListPermissions string
|
||||||
|
DeletePermission string
|
||||||
|
AssignPermissionToRole string
|
||||||
|
RemovePermissionFromRole string
|
||||||
|
GetRolePermissions string
|
||||||
|
GetUserPermissions string
|
||||||
|
|
||||||
|
// migrations
|
||||||
|
CreateMigrationsTable string
|
||||||
|
SelectAppliedVersions string
|
||||||
|
InsertAppliedVersion string
|
||||||
|
}
|
||||||
33
sqlstore/dialect/postgres/errors.go
Normal file
33
sqlstore/dialect/postgres/errors.go
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pgUniqueViolation is the SQLSTATE for unique_violation. Both pgx-stdlib
|
||||||
|
// and lib/pq surface this code, but only pgx-stdlib uses *pgconn.PgError.
|
||||||
|
// lib/pq uses *pq.Error which has a Code field of the same value.
|
||||||
|
const pgUniqueViolation = "23505"
|
||||||
|
|
||||||
|
// isUniqueViolation inspects err for a Postgres unique-violation, regardless
|
||||||
|
// of which driver registered the connection. We match on either the pgx
|
||||||
|
// error type or any error implementing a Code() string method (lib/pq's
|
||||||
|
// pq.Error has SQLState and Code fields; we check via reflection-free
|
||||||
|
// duck-typing through an interface).
|
||||||
|
func isUniqueViolation(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var pgxErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgxErr) {
|
||||||
|
return pgxErr.Code == pgUniqueViolation
|
||||||
|
}
|
||||||
|
type sqlStater interface{ SQLState() string }
|
||||||
|
var s sqlStater
|
||||||
|
if errors.As(err, &s) {
|
||||||
|
return s.SQLState() == pgUniqueViolation
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
99
sqlstore/dialect/postgres/migrations/0001_init.sql
Normal file
99
sqlstore/dialect/postgres/migrations/0001_init.sql
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
-- 0001_init.sql
|
||||||
|
-- Initial authkit schema for Postgres. Tables are prefixed authkit_ so the
|
||||||
|
-- library can be embedded in an existing application database. Each
|
||||||
|
-- migration owns its own transaction and inserts its version row at the
|
||||||
|
-- bottom; the runner only orchestrates file discovery and concurrency.
|
||||||
|
|
||||||
|
BEGIN;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_schema_migrations (
|
||||||
|
version TEXT PRIMARY KEY,
|
||||||
|
applied_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_users (
|
||||||
|
id UUID PRIMARY KEY,
|
||||||
|
email TEXT NOT NULL,
|
||||||
|
email_normalized TEXT NOT NULL,
|
||||||
|
email_verified_at TIMESTAMPTZ,
|
||||||
|
password_hash TEXT,
|
||||||
|
session_version INTEGER NOT NULL DEFAULT 0,
|
||||||
|
failed_logins INTEGER NOT NULL DEFAULT 0,
|
||||||
|
last_login_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL,
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq
|
||||||
|
ON authkit_users (email_normalized);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_sessions (
|
||||||
|
id_hash BYTEA PRIMARY KEY,
|
||||||
|
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||||
|
user_agent TEXT NOT NULL DEFAULT '',
|
||||||
|
ip TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL,
|
||||||
|
last_seen_at TIMESTAMPTZ NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_tokens (
|
||||||
|
hash BYTEA NOT NULL,
|
||||||
|
kind TEXT NOT NULL,
|
||||||
|
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||||
|
chain_id TEXT,
|
||||||
|
consumed_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
PRIMARY KEY (kind, hash)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx
|
||||||
|
ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_api_keys (
|
||||||
|
id_hash BYTEA PRIMARY KEY,
|
||||||
|
owner_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
abilities JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
last_used_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ,
|
||||||
|
revoked_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS authkit_api_keys_owner_id_idx ON authkit_api_keys(owner_id);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_roles (
|
||||||
|
id UUID PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL UNIQUE,
|
||||||
|
description TEXT NOT NULL DEFAULT '',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_permissions (
|
||||||
|
id UUID PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL UNIQUE,
|
||||||
|
description TEXT NOT NULL DEFAULT '',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_role_permissions (
|
||||||
|
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
|
||||||
|
permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE,
|
||||||
|
PRIMARY KEY (role_id, permission_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS authkit_user_roles (
|
||||||
|
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||||
|
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
|
||||||
|
granted_at TIMESTAMPTZ NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, role_id)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id);
|
||||||
|
|
||||||
|
INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now())
|
||||||
|
ON CONFLICT (version) DO NOTHING;
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
272
sqlstore/dialect/postgres/postgres.go
Normal file
272
sqlstore/dialect/postgres/postgres.go
Normal file
|
|
@ -0,0 +1,272 @@
|
||||||
|
// Package postgres is the Postgres dialect for authkit/sqlstore. Importing
|
||||||
|
// it does not register a driver — callers do `_ "github.com/jackc/pgx/v5/stdlib"`
|
||||||
|
// or `_ "github.com/lib/pq"` themselves and then `sql.Open(...)`.
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit/sqlstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed migrations/*.sql
|
||||||
|
var migrationsFS embed.FS
|
||||||
|
|
||||||
|
// Dialect implements sqlstore.Dialect for Postgres. The zero value is the
|
||||||
|
// only required form; New() returns a pointer for clarity at call sites.
|
||||||
|
type Dialect struct{}
|
||||||
|
|
||||||
|
// New returns a Postgres dialect ready to pass to sqlstore.New / Migrate.
|
||||||
|
func New() *Dialect { return &Dialect{} }
|
||||||
|
|
||||||
|
func (Dialect) Name() string { return "postgres" }
|
||||||
|
|
||||||
|
// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. Stable
|
||||||
|
// across rollouts and unlikely to clash with caller advisory locks.
|
||||||
|
const advisoryLockKey int64 = 0x617574686b6974
|
||||||
|
|
||||||
|
func (Dialect) Bootstrap(ctx context.Context, db *sql.DB) error {
|
||||||
|
// Nothing to do — schema avoids gen_random_uuid()/pgcrypto. Kept as a
|
||||||
|
// hook for future migration prerequisites.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Dialect) AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (func(), error) {
|
||||||
|
if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
release := func() {
|
||||||
|
if _, err := conn.ExecContext(context.Background(),
|
||||||
|
"SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil {
|
||||||
|
log.Printf("authkit/postgres: pg_advisory_unlock failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Dialect) Migrations() fs.FS {
|
||||||
|
sub, err := fs.Sub(migrationsFS, "migrations")
|
||||||
|
if err != nil {
|
||||||
|
// migrationsFS is statically populated; this can only fail if the
|
||||||
|
// embed directive is removed, which would be a build-time error.
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return sub
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Dialect) IsUniqueViolation(err error) bool { return isUniqueViolation(err) }
|
||||||
|
|
||||||
|
func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) }
|
||||||
|
|
||||||
|
// PlaceholderList renders a comma-separated `$start,$start+1,...` list of
|
||||||
|
// `count` placeholders. Used by HasAnyRole's dynamic IN-clause expansion.
|
||||||
|
func (Dialect) PlaceholderList(start, count int) string {
|
||||||
|
if count <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
b.WriteByte(',')
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, "$%d", start+i)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildQueries renders every query authkit issues, with table identifiers
|
||||||
|
// taken from s and `?` placeholders rewritten to `$N`. Identifiers are
|
||||||
|
// already validated by Schema.Validate — this is interpolated with
|
||||||
|
// fmt.Sprintf, so the validation gate is load-bearing.
|
||||||
|
func (Dialect) BuildQueries(s sqlstore.Schema) sqlstore.Queries {
|
||||||
|
t := s.Tables
|
||||||
|
q := sqlstore.Queries{
|
||||||
|
// users
|
||||||
|
CreateUser: `INSERT INTO ` + t.Users + `
|
||||||
|
(id, email, email_normalized, email_verified_at, password_hash,
|
||||||
|
session_version, failed_logins, last_login_at, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
GetUserByID: `SELECT id, email, email_normalized, email_verified_at,
|
||||||
|
password_hash, session_version, failed_logins, last_login_at,
|
||||||
|
created_at, updated_at FROM ` + t.Users + ` WHERE id = ?`,
|
||||||
|
GetUserByEmail: `SELECT id, email, email_normalized, email_verified_at,
|
||||||
|
password_hash, session_version, failed_logins, last_login_at,
|
||||||
|
created_at, updated_at FROM ` + t.Users + ` WHERE email_normalized = ?`,
|
||||||
|
UpdateUser: `UPDATE ` + t.Users + ` SET
|
||||||
|
email = ?, email_normalized = ?, email_verified_at = ?,
|
||||||
|
password_hash = ?, session_version = ?, failed_logins = ?,
|
||||||
|
last_login_at = ?, updated_at = ?
|
||||||
|
WHERE id = ?`,
|
||||||
|
DeleteUser: `DELETE FROM ` + t.Users + ` WHERE id = ?`,
|
||||||
|
SetPassword: `UPDATE ` + t.Users + ` SET password_hash = ?, updated_at = ? WHERE id = ?`,
|
||||||
|
SetEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = ?, updated_at = ? WHERE id = ?`,
|
||||||
|
BumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = ? WHERE id = ? RETURNING session_version`,
|
||||||
|
IncrementFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = failed_logins + 1, updated_at = ? WHERE id = ? RETURNING failed_logins`,
|
||||||
|
ResetFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = 0, updated_at = ? WHERE id = ?`,
|
||||||
|
|
||||||
|
// sessions
|
||||||
|
CreateSession: `INSERT INTO ` + t.Sessions + `
|
||||||
|
(id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
GetSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at
|
||||||
|
FROM ` + t.Sessions + ` WHERE id_hash = ?`,
|
||||||
|
TouchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = ?, expires_at = ? WHERE id_hash = ?`,
|
||||||
|
DeleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = ?`,
|
||||||
|
DeleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = ?`,
|
||||||
|
DeleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= ?`,
|
||||||
|
|
||||||
|
// tokens
|
||||||
|
CreateToken: `INSERT INTO ` + t.Tokens + `
|
||||||
|
(hash, kind, user_id, chain_id, consumed_at, created_at, expires_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
ConsumeToken: `UPDATE ` + t.Tokens + `
|
||||||
|
SET consumed_at = ?
|
||||||
|
WHERE kind = ? AND hash = ? AND consumed_at IS NULL AND expires_at > ?
|
||||||
|
RETURNING hash, kind, user_id, chain_id, consumed_at, created_at, expires_at`,
|
||||||
|
GetToken: `SELECT hash, kind, user_id, chain_id, consumed_at, created_at, expires_at
|
||||||
|
FROM ` + t.Tokens + ` WHERE kind = ? AND hash = ?`,
|
||||||
|
DeleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = ?`,
|
||||||
|
DeleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= ?`,
|
||||||
|
|
||||||
|
// api keys
|
||||||
|
CreateAPIKey: `INSERT INTO ` + t.APIKeys + `
|
||||||
|
(id_hash, owner_id, name, abilities, last_used_at, created_at, expires_at, revoked_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
GetAPIKey: `SELECT id_hash, owner_id, name, abilities, last_used_at,
|
||||||
|
created_at, expires_at, revoked_at
|
||||||
|
FROM ` + t.APIKeys + ` WHERE id_hash = ?`,
|
||||||
|
ListAPIKeysByOwner: `SELECT id_hash, owner_id, name, abilities, last_used_at,
|
||||||
|
created_at, expires_at, revoked_at
|
||||||
|
FROM ` + t.APIKeys + ` WHERE owner_id = ? ORDER BY created_at DESC`,
|
||||||
|
TouchAPIKey: `UPDATE ` + t.APIKeys + ` SET last_used_at = ? WHERE id_hash = ?`,
|
||||||
|
RevokeAPIKey: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE id_hash = ? AND revoked_at IS NULL`,
|
||||||
|
RevokeAPIKeysByOwner: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE owner_id = ? AND revoked_at IS NULL`,
|
||||||
|
|
||||||
|
// roles
|
||||||
|
CreateRole: `INSERT INTO ` + t.Roles + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
|
||||||
|
GetRoleByID: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE id = ?`,
|
||||||
|
GetRoleByName: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE name = ?`,
|
||||||
|
ListRoles: `SELECT id, name, description, created_at FROM ` + t.Roles + ` ORDER BY name`,
|
||||||
|
DeleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = ?`,
|
||||||
|
AssignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`,
|
||||||
|
RemoveRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = ? AND role_id = ?`,
|
||||||
|
GetUserRoles: `SELECT r.id, r.name, r.description, r.created_at
|
||||||
|
FROM ` + t.Roles + ` r
|
||||||
|
JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id
|
||||||
|
WHERE ur.user_id = ? ORDER BY r.name`,
|
||||||
|
|
||||||
|
// permissions
|
||||||
|
CreatePermission: `INSERT INTO ` + t.Permissions + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
|
||||||
|
GetPermissionByID: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE id = ?`,
|
||||||
|
GetPermissionByName: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE name = ?`,
|
||||||
|
ListPermissions: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` ORDER BY name`,
|
||||||
|
DeletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = ?`,
|
||||||
|
AssignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) VALUES (?, ?)
|
||||||
|
ON CONFLICT DO NOTHING`,
|
||||||
|
RemovePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = ? AND permission_id = ?`,
|
||||||
|
GetRolePermissions: `SELECT p.id, p.name, p.description, p.created_at
|
||||||
|
FROM ` + t.Permissions + ` p
|
||||||
|
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
|
||||||
|
WHERE rp.role_id = ? ORDER BY p.name`,
|
||||||
|
GetUserPermissions: `SELECT DISTINCT p.id, p.name, p.description, p.created_at
|
||||||
|
FROM ` + t.Permissions + ` p
|
||||||
|
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
|
||||||
|
JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id
|
||||||
|
WHERE ur.user_id = ? ORDER BY p.name`,
|
||||||
|
|
||||||
|
// migrations
|
||||||
|
CreateMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` (
|
||||||
|
version TEXT PRIMARY KEY,
|
||||||
|
applied_at TIMESTAMPTZ NOT NULL
|
||||||
|
)`,
|
||||||
|
SelectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations,
|
||||||
|
InsertAppliedVersion: `INSERT INTO ` + t.SchemaMigrations + ` (version, applied_at) VALUES (?, ?)`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rewrite `?` placeholders to `$N`. Each query is independent; numbering
|
||||||
|
// resets per query.
|
||||||
|
q.CreateUser = rebind(q.CreateUser)
|
||||||
|
q.GetUserByID = rebind(q.GetUserByID)
|
||||||
|
q.GetUserByEmail = rebind(q.GetUserByEmail)
|
||||||
|
q.UpdateUser = rebind(q.UpdateUser)
|
||||||
|
q.DeleteUser = rebind(q.DeleteUser)
|
||||||
|
q.SetPassword = rebind(q.SetPassword)
|
||||||
|
q.SetEmailVerified = rebind(q.SetEmailVerified)
|
||||||
|
q.BumpSessionVersion = rebind(q.BumpSessionVersion)
|
||||||
|
q.IncrementFailedLogins = rebind(q.IncrementFailedLogins)
|
||||||
|
q.ResetFailedLogins = rebind(q.ResetFailedLogins)
|
||||||
|
|
||||||
|
q.CreateSession = rebind(q.CreateSession)
|
||||||
|
q.GetSession = rebind(q.GetSession)
|
||||||
|
q.TouchSession = rebind(q.TouchSession)
|
||||||
|
q.DeleteSession = rebind(q.DeleteSession)
|
||||||
|
q.DeleteUserSessions = rebind(q.DeleteUserSessions)
|
||||||
|
q.DeleteExpiredSessions = rebind(q.DeleteExpiredSessions)
|
||||||
|
|
||||||
|
q.CreateToken = rebind(q.CreateToken)
|
||||||
|
q.ConsumeToken = rebind(q.ConsumeToken)
|
||||||
|
q.GetToken = rebind(q.GetToken)
|
||||||
|
q.DeleteByChain = rebind(q.DeleteByChain)
|
||||||
|
q.DeleteExpiredTokens = rebind(q.DeleteExpiredTokens)
|
||||||
|
|
||||||
|
q.CreateAPIKey = rebind(q.CreateAPIKey)
|
||||||
|
q.GetAPIKey = rebind(q.GetAPIKey)
|
||||||
|
q.ListAPIKeysByOwner = rebind(q.ListAPIKeysByOwner)
|
||||||
|
q.TouchAPIKey = rebind(q.TouchAPIKey)
|
||||||
|
q.RevokeAPIKey = rebind(q.RevokeAPIKey)
|
||||||
|
q.RevokeAPIKeysByOwner = rebind(q.RevokeAPIKeysByOwner)
|
||||||
|
|
||||||
|
q.CreateRole = rebind(q.CreateRole)
|
||||||
|
q.GetRoleByID = rebind(q.GetRoleByID)
|
||||||
|
q.GetRoleByName = rebind(q.GetRoleByName)
|
||||||
|
q.ListRoles = rebind(q.ListRoles)
|
||||||
|
q.DeleteRole = rebind(q.DeleteRole)
|
||||||
|
q.AssignRoleToUser = rebind(q.AssignRoleToUser)
|
||||||
|
q.RemoveRoleFromUser = rebind(q.RemoveRoleFromUser)
|
||||||
|
q.GetUserRoles = rebind(q.GetUserRoles)
|
||||||
|
|
||||||
|
q.CreatePermission = rebind(q.CreatePermission)
|
||||||
|
q.GetPermissionByID = rebind(q.GetPermissionByID)
|
||||||
|
q.GetPermissionByName = rebind(q.GetPermissionByName)
|
||||||
|
q.ListPermissions = rebind(q.ListPermissions)
|
||||||
|
q.DeletePermission = rebind(q.DeletePermission)
|
||||||
|
q.AssignPermissionToRole = rebind(q.AssignPermissionToRole)
|
||||||
|
q.RemovePermissionFromRole = rebind(q.RemovePermissionFromRole)
|
||||||
|
q.GetRolePermissions = rebind(q.GetRolePermissions)
|
||||||
|
q.GetUserPermissions = rebind(q.GetUserPermissions)
|
||||||
|
|
||||||
|
q.SelectAppliedVersions = rebind(q.SelectAppliedVersions)
|
||||||
|
q.InsertAppliedVersion = rebind(q.InsertAppliedVersion)
|
||||||
|
// CreateMigrationsTable contains no parameters.
|
||||||
|
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// rebind walks s and replaces each unquoted `?` with $1, $2, ... in order.
|
||||||
|
// Our query strings contain no string literals that include `?` (verified
|
||||||
|
// by inspection of every query in BuildQueries); a literal-aware rewriter
|
||||||
|
// would be more defensive but is not needed for v1.
|
||||||
|
func rebind(s string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(s) + 16)
|
||||||
|
n := 1
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
c := s[i]
|
||||||
|
if c == '?' {
|
||||||
|
fmt.Fprintf(&b, "$%d", n)
|
||||||
|
n++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteByte(c)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time interface compliance check.
|
||||||
|
var _ sqlstore.Dialect = (*Dialect)(nil)
|
||||||
116
sqlstore/migrate.go
Normal file
116
sqlstore/migrate.go
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"io/fs"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Migrate applies every embedded migration the dialect ships that has not
|
||||||
|
// yet been recorded in the schema-migrations table. It is safe to call
|
||||||
|
// repeatedly and concurrently across processes — the dialect's session
|
||||||
|
// lock serialises rollouts.
|
||||||
|
//
|
||||||
|
// Each migration .sql file is responsible for owning its own
|
||||||
|
// BEGIN/COMMIT and inserting a row into the schema-migrations table on
|
||||||
|
// success. The runner only handles file discovery, version tracking, and
|
||||||
|
// concurrency.
|
||||||
|
func Migrate(ctx context.Context, db *sql.DB, dialect Dialect, schema Schema) error {
|
||||||
|
const op = "authkit.sqlstore.Migrate"
|
||||||
|
if db == nil {
|
||||||
|
return errx.New(op, "db is required")
|
||||||
|
}
|
||||||
|
if dialect == nil {
|
||||||
|
return errx.New(op, "dialect is required")
|
||||||
|
}
|
||||||
|
if err := schema.Validate(); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := dialect.Bootstrap(ctx, db); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := db.Conn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
release, err := dialect.AcquireMigrationLock(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer release()
|
||||||
|
|
||||||
|
q := dialect.BuildQueries(schema)
|
||||||
|
if _, err := conn.ExecContext(ctx, q.CreateMigrationsTable); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
applied, err := loadAppliedVersions(ctx, conn, q.SelectAppliedVersions)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migs := dialect.Migrations()
|
||||||
|
files, err := fs.ReadDir(migs, ".")
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
names := make([]string, 0, len(files))
|
||||||
|
for _, f := range files {
|
||||||
|
if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
|
||||||
|
names = append(names, f.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(names)
|
||||||
|
|
||||||
|
for _, name := range names {
|
||||||
|
version := strings.TrimSuffix(name, ".sql")
|
||||||
|
if _, ok := applied[version]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
body, err := fs.ReadFile(migs, name)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrapf(op, err, "read %s", name)
|
||||||
|
}
|
||||||
|
if _, err := conn.ExecContext(ctx, string(body)); err != nil {
|
||||||
|
return errx.Wrapf(op, err, "apply %s", version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyVersionRow is intentionally not exposed: migration files own their
|
||||||
|
// own version-row insert. We keep this helper around in case a dialect ever
|
||||||
|
// needs to backfill versions from outside a migration body — currently
|
||||||
|
// unused.
|
||||||
|
var _ = applyVersionRow
|
||||||
|
|
||||||
|
func applyVersionRow(ctx context.Context, conn *sql.Conn, insertQ, version string, at time.Time) error {
|
||||||
|
_, err := conn.ExecContext(ctx, insertQ, version, at)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) {
|
||||||
|
rows, err := conn.QueryContext(ctx, q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
out := make(map[string]struct{})
|
||||||
|
for rows.Next() {
|
||||||
|
var v string
|
||||||
|
if err := rows.Scan(&v); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out[v] = struct{}{}
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
301
sqlstore/rbac.go
Normal file
301
sqlstore/rbac.go
Normal file
|
|
@ -0,0 +1,301 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roleStore struct{ storeBase }
|
||||||
|
type permissionStore struct{ storeBase }
|
||||||
|
|
||||||
|
// ----- roleStore ------------------------------------------------------------
|
||||||
|
|
||||||
|
func (s *roleStore) CreateRole(ctx context.Context, r *authkit.Role) error {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.CreateRole"
|
||||||
|
if r.ID == uuid.Nil {
|
||||||
|
r.ID = uuid.New()
|
||||||
|
}
|
||||||
|
if r.CreatedAt.IsZero() {
|
||||||
|
r.CreatedAt = time.Now().UTC()
|
||||||
|
}
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.CreateRole,
|
||||||
|
uuidArg(r.ID), r.Name, r.Description, r.CreatedAt); err != nil {
|
||||||
|
if s.d.IsUniqueViolation(err) {
|
||||||
|
return errx.Wrapf(op, err, "role %q already exists", r.Name)
|
||||||
|
}
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roleStore) GetRoleByID(ctx context.Context, id uuid.UUID) (*authkit.Role, error) {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.GetRoleByID"
|
||||||
|
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByID, uuidArg(id)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roleStore) GetRoleByName(ctx context.Context, name string) (*authkit.Role, error) {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.GetRoleByName"
|
||||||
|
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByName, name))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roleStore) ListRoles(ctx context.Context) ([]*authkit.Role, error) {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.ListRoles"
|
||||||
|
rows, err := s.db.QueryContext(ctx, s.q.ListRoles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []*authkit.Role
|
||||||
|
for rows.Next() {
|
||||||
|
r, err := scanRole(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
out = append(out, r)
|
||||||
|
}
|
||||||
|
return out, errx.Wrap(op, rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roleStore) DeleteRole(ctx context.Context, id uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.DeleteRole"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.DeleteRole, uuidArg(id))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrRoleNotFound)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roleStore) AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.AssignRoleToUser"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.AssignRoleToUser,
|
||||||
|
uuidArg(userID), uuidArg(roleID), time.Now().UTC()); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roleStore) RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.RemoveRoleFromUser"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.RemoveRoleFromUser,
|
||||||
|
uuidArg(userID), uuidArg(roleID)); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roleStore) GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*authkit.Role, error) {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.GetUserRoles"
|
||||||
|
rows, err := s.db.QueryContext(ctx, s.q.GetUserRoles, uuidArg(userID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []*authkit.Role
|
||||||
|
for rows.Next() {
|
||||||
|
r, err := scanRole(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
out = append(out, r)
|
||||||
|
}
|
||||||
|
return out, errx.Wrap(op, rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasAnyRole builds the IN-clause at call time because the placeholder count
|
||||||
|
// depends on len(names). Identifier substitution comes from the validated
|
||||||
|
// Schema; values are bound, never interpolated.
|
||||||
|
func (s *roleStore) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
|
||||||
|
const op = "authkit.sqlstore.RoleStore.HasAnyRole"
|
||||||
|
if len(names) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
// Placeholder $1 is the user_id; $2..$N+1 cover the names slice.
|
||||||
|
listSQL := s.d.PlaceholderList(2, len(names))
|
||||||
|
q := fmt.Sprintf(`SELECT EXISTS (
|
||||||
|
SELECT 1 FROM %s ur JOIN %s r ON r.id = ur.role_id
|
||||||
|
WHERE ur.user_id = %s AND r.name IN (%s)
|
||||||
|
)`, s.s.Tables.UserRoles, s.s.Tables.Roles, s.d.Placeholder(1), listSQL)
|
||||||
|
|
||||||
|
args := make([]any, 0, 1+len(names))
|
||||||
|
args = append(args, uuidArg(userID))
|
||||||
|
for _, n := range names {
|
||||||
|
args = append(args, n)
|
||||||
|
}
|
||||||
|
var ok bool
|
||||||
|
if err := s.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil {
|
||||||
|
return false, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- permissionStore ------------------------------------------------------
|
||||||
|
|
||||||
|
func (s *permissionStore) CreatePermission(ctx context.Context, p *authkit.Permission) error {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.CreatePermission"
|
||||||
|
if p.ID == uuid.Nil {
|
||||||
|
p.ID = uuid.New()
|
||||||
|
}
|
||||||
|
if p.CreatedAt.IsZero() {
|
||||||
|
p.CreatedAt = time.Now().UTC()
|
||||||
|
}
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.CreatePermission,
|
||||||
|
uuidArg(p.ID), p.Name, p.Description, p.CreatedAt); err != nil {
|
||||||
|
if s.d.IsUniqueViolation(err) {
|
||||||
|
return errx.Wrapf(op, err, "permission %q already exists", p.Name)
|
||||||
|
}
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) GetPermissionByID(ctx context.Context, id uuid.UUID) (*authkit.Permission, error) {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.GetPermissionByID"
|
||||||
|
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByID, uuidArg(id)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) GetPermissionByName(ctx context.Context, name string) (*authkit.Permission, error) {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.GetPermissionByName"
|
||||||
|
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByName, name))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) ListPermissions(ctx context.Context) ([]*authkit.Permission, error) {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.ListPermissions"
|
||||||
|
rows, err := s.db.QueryContext(ctx, s.q.ListPermissions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []*authkit.Permission
|
||||||
|
for rows.Next() {
|
||||||
|
p, err := scanPermission(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
return out, errx.Wrap(op, rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) DeletePermission(ctx context.Context, id uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.DeletePermission"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.DeletePermission, uuidArg(id))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrPermissionNotFound)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.AssignPermissionToRole"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.AssignPermissionToRole,
|
||||||
|
uuidArg(roleID), uuidArg(permID)); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.RemovePermissionFromRole"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.RemovePermissionFromRole,
|
||||||
|
uuidArg(roleID), uuidArg(permID)); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*authkit.Permission, error) {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.GetRolePermissions"
|
||||||
|
rows, err := s.db.QueryContext(ctx, s.q.GetRolePermissions, uuidArg(roleID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []*authkit.Permission
|
||||||
|
for rows.Next() {
|
||||||
|
p, err := scanPermission(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
return out, errx.Wrap(op, rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *permissionStore) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*authkit.Permission, error) {
|
||||||
|
const op = "authkit.sqlstore.PermissionStore.GetUserPermissions"
|
||||||
|
rows, err := s.db.QueryContext(ctx, s.q.GetUserPermissions, uuidArg(userID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var out []*authkit.Permission
|
||||||
|
for rows.Next() {
|
||||||
|
p, err := scanPermission(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
return out, errx.Wrap(op, rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanRole(row rowScanner) (*authkit.Role, error) {
|
||||||
|
var (
|
||||||
|
r authkit.Role
|
||||||
|
idStr string
|
||||||
|
)
|
||||||
|
if err := row.Scan(&idStr, &r.Name, &r.Description, &r.CreatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
id, err := scanUUID(idStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r.ID = id
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanPermission(row rowScanner) (*authkit.Permission, error) {
|
||||||
|
var (
|
||||||
|
p authkit.Permission
|
||||||
|
idStr string
|
||||||
|
)
|
||||||
|
if err := row.Scan(&idStr, &p.Name, &p.Description, &p.CreatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
id, err := scanUUID(idStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p.ID = id
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
86
sqlstore/scan.go
Normal file
86
sqlstore/scan.go
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// rowScanner is the lowest-common-denominator interface satisfied by both
|
||||||
|
// *sql.Row and *sql.Rows so scanXxx helpers serve QueryRow and Query loops
|
||||||
|
// uniformly.
|
||||||
|
type rowScanner interface {
|
||||||
|
Scan(dest ...any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// nullableTime returns nil when t is the zero time, otherwise &t. Callers
|
||||||
|
// should usually accept *time.Time on the model side and bind via this.
|
||||||
|
func nullableTime(t *time.Time) any {
|
||||||
|
if t == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *t
|
||||||
|
}
|
||||||
|
|
||||||
|
// nullableString turns "" into nil so columns store NULL rather than an
|
||||||
|
// empty string. Used for password hashes and similar optional text columns.
|
||||||
|
func nullableString(s string) any {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// nullableAddrString returns the string form of a netip.Addr when valid, or
|
||||||
|
// nil to bind as SQL NULL. Pairs with scanAddr.
|
||||||
|
func nullableAddrString(a netip.Addr) any {
|
||||||
|
if !a.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanAddr parses a *string column into a netip.Addr. The zero Addr value is
|
||||||
|
// returned when the column was NULL or empty.
|
||||||
|
func scanAddr(s *string) (netip.Addr, error) {
|
||||||
|
if s == nil || *s == "" {
|
||||||
|
return netip.Addr{}, nil
|
||||||
|
}
|
||||||
|
return netip.ParseAddr(*s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// uuidArg returns the canonical string form of a UUID for binding. Every
|
||||||
|
// store uses this rather than passing uuid.UUID directly to keep behaviour
|
||||||
|
// identical across drivers (some accept driver.Valuer, some don't).
|
||||||
|
func uuidArg(id uuid.UUID) any { return id.String() }
|
||||||
|
|
||||||
|
// scanUUID reads a string column and parses it back to a uuid.UUID.
|
||||||
|
func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) }
|
||||||
|
|
||||||
|
// chainArg returns either a *string or nil for binding the chain_id column.
|
||||||
|
func chainArg(c *string) any {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *c
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanNullStringPtr converts sql.NullString to *string for the model.
|
||||||
|
func scanNullStringPtr(ns sql.NullString) *string {
|
||||||
|
if !ns.Valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
v := ns.String
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanNullTimePtr converts sql.NullTime to *time.Time for the model.
|
||||||
|
func scanNullTimePtr(nt sql.NullTime) *time.Time {
|
||||||
|
if !nt.Valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
t := nt.Time
|
||||||
|
return &t
|
||||||
|
}
|
||||||
78
sqlstore/schema.go
Normal file
78
sqlstore/schema.go
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema lets consumers map authkit storage to their own table names.
|
||||||
|
// Column overrides are intentionally not present in v1 — adding them later
|
||||||
|
// is purely additive.
|
||||||
|
type Schema struct {
|
||||||
|
Tables Tables
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tables is the per-table identifier override set. Every field must be a
|
||||||
|
// valid unquoted SQL identifier (matching identifierRE). Validation runs at
|
||||||
|
// New() and Migrate() time so SQL injection through Schema is impossible
|
||||||
|
// past that gate.
|
||||||
|
type Tables struct {
|
||||||
|
Users string
|
||||||
|
Sessions string
|
||||||
|
Tokens string
|
||||||
|
APIKeys string
|
||||||
|
Roles string
|
||||||
|
Permissions string
|
||||||
|
UserRoles string
|
||||||
|
RolePermissions string
|
||||||
|
SchemaMigrations string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSchema returns the stock authkit_* names used by the embedded
|
||||||
|
// migration files.
|
||||||
|
func DefaultSchema() Schema {
|
||||||
|
return Schema{Tables: Tables{
|
||||||
|
Users: "authkit_users",
|
||||||
|
Sessions: "authkit_sessions",
|
||||||
|
Tokens: "authkit_tokens",
|
||||||
|
APIKeys: "authkit_api_keys",
|
||||||
|
Roles: "authkit_roles",
|
||||||
|
Permissions: "authkit_permissions",
|
||||||
|
UserRoles: "authkit_user_roles",
|
||||||
|
RolePermissions: "authkit_role_permissions",
|
||||||
|
SchemaMigrations: "authkit_schema_migrations",
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// identifierRE matches the safe ASCII identifier subset shared by Postgres,
|
||||||
|
// MySQL and SQLite when not quoted. Anything outside this set is rejected
|
||||||
|
// rather than escaped — Schema is not the place to support exotic names.
|
||||||
|
var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||||
|
|
||||||
|
// Validate ensures every Schema.Tables field is a non-empty, safe identifier.
|
||||||
|
func (s Schema) Validate() error {
|
||||||
|
const op = "authkit.sqlstore.Schema.Validate"
|
||||||
|
checks := []struct {
|
||||||
|
field, value string
|
||||||
|
}{
|
||||||
|
{"Users", s.Tables.Users},
|
||||||
|
{"Sessions", s.Tables.Sessions},
|
||||||
|
{"Tokens", s.Tables.Tokens},
|
||||||
|
{"APIKeys", s.Tables.APIKeys},
|
||||||
|
{"Roles", s.Tables.Roles},
|
||||||
|
{"Permissions", s.Tables.Permissions},
|
||||||
|
{"UserRoles", s.Tables.UserRoles},
|
||||||
|
{"RolePermissions", s.Tables.RolePermissions},
|
||||||
|
{"SchemaMigrations", s.Tables.SchemaMigrations},
|
||||||
|
}
|
||||||
|
for _, c := range checks {
|
||||||
|
if c.value == "" {
|
||||||
|
return errx.Newf(op, "Schema.Tables.%s is empty", c.field)
|
||||||
|
}
|
||||||
|
if !identifierRE.MatchString(c.value) {
|
||||||
|
return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
98
sqlstore/sessions.go
Normal file
98
sqlstore/sessions.go
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sessionStore struct{ storeBase }
|
||||||
|
|
||||||
|
func (s *sessionStore) CreateSession(ctx context.Context, ses *authkit.Session) error {
|
||||||
|
const op = "authkit.sqlstore.SessionStore.CreateSession"
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if ses.CreatedAt.IsZero() {
|
||||||
|
ses.CreatedAt = now
|
||||||
|
}
|
||||||
|
if ses.LastSeenAt.IsZero() {
|
||||||
|
ses.LastSeenAt = ses.CreatedAt
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx, s.q.CreateSession,
|
||||||
|
ses.IDHash, uuidArg(ses.UserID), ses.UserAgent, nullableAddrString(ses.IP),
|
||||||
|
ses.CreatedAt, ses.LastSeenAt, ses.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionStore) GetSession(ctx context.Context, idHash []byte) (*authkit.Session, error) {
|
||||||
|
const op = "authkit.sqlstore.SessionStore.GetSession"
|
||||||
|
var (
|
||||||
|
ses authkit.Session
|
||||||
|
uidStr string
|
||||||
|
ipStr sql.NullString
|
||||||
|
)
|
||||||
|
err := s.db.QueryRowContext(ctx, s.q.GetSession, idHash).Scan(
|
||||||
|
&ses.IDHash, &uidStr, &ses.UserAgent, &ipStr,
|
||||||
|
&ses.CreatedAt, &ses.LastSeenAt, &ses.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrSessionInvalid))
|
||||||
|
}
|
||||||
|
uid, err := scanUUID(uidStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
ses.UserID = uid
|
||||||
|
if ipStr.Valid {
|
||||||
|
addr, err := scanAddr(&ipStr.String)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
ses.IP = addr
|
||||||
|
}
|
||||||
|
return &ses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionStore) TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error {
|
||||||
|
const op = "authkit.sqlstore.SessionStore.TouchSession"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.TouchSession, lastSeenAt, newExpiresAt, idHash)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrSessionInvalid)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionStore) DeleteSession(ctx context.Context, idHash []byte) error {
|
||||||
|
const op = "authkit.sqlstore.SessionStore.DeleteSession"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.DeleteSession, idHash); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionStore) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.SessionStore.DeleteUserSessions"
|
||||||
|
if _, err := s.db.ExecContext(ctx, s.q.DeleteUserSessions, uuidArg(userID)); err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
|
||||||
|
const op = "authkit.sqlstore.SessionStore.DeleteExpired"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredSessions, now)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
59
sqlstore/sqlstore.go
Normal file
59
sqlstore/sqlstore.go
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
// Package sqlstore provides database/sql-backed implementations of every
|
||||||
|
// authkit store interface. It works with any driver (pgx-stdlib, lib/pq,
|
||||||
|
// sqlx wrapping *sql.DB, ...) and any consumer-supplied table naming via
|
||||||
|
// Schema. Driver-specific behaviour lives in a Dialect; v1 ships
|
||||||
|
// dialect/postgres.
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Stores bundles every store implementation produced by New, ready to drop
|
||||||
|
// into authkit.Deps.
|
||||||
|
type Stores struct {
|
||||||
|
Users authkit.UserStore
|
||||||
|
Sessions authkit.SessionStore
|
||||||
|
Tokens authkit.TokenStore
|
||||||
|
APIKeys authkit.APIKeyStore
|
||||||
|
Roles authkit.RoleStore
|
||||||
|
Permissions authkit.PermissionStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// New validates the Schema, builds the dialect's templated queries, and
|
||||||
|
// returns store implementations bound to db.
|
||||||
|
func New(db *sql.DB, dialect Dialect, schema Schema) (*Stores, error) {
|
||||||
|
const op = "authkit.sqlstore.New"
|
||||||
|
if db == nil {
|
||||||
|
return nil, errx.New(op, "db is required")
|
||||||
|
}
|
||||||
|
if dialect == nil {
|
||||||
|
return nil, errx.New(op, "dialect is required")
|
||||||
|
}
|
||||||
|
if err := schema.Validate(); err != nil {
|
||||||
|
return nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
q := dialect.BuildQueries(schema)
|
||||||
|
|
||||||
|
base := storeBase{db: db, q: q, d: dialect, s: schema}
|
||||||
|
return &Stores{
|
||||||
|
Users: &userStore{storeBase: base},
|
||||||
|
Sessions: &sessionStore{storeBase: base},
|
||||||
|
Tokens: &tokenStore{storeBase: base},
|
||||||
|
APIKeys: &apiKeyStore{storeBase: base},
|
||||||
|
Roles: &roleStore{storeBase: base},
|
||||||
|
Permissions: &permissionStore{storeBase: base},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeBase carries the shared dependencies every store needs. Embedded into
|
||||||
|
// each concrete store struct.
|
||||||
|
type storeBase struct {
|
||||||
|
db *sql.DB
|
||||||
|
q Queries
|
||||||
|
d Dialect
|
||||||
|
s Schema
|
||||||
|
}
|
||||||
258
sqlstore/sqlstore_test.go
Normal file
258
sqlstore/sqlstore_test.go
Normal file
|
|
@ -0,0 +1,258 @@
|
||||||
|
package sqlstore_test
|
||||||
|
|
||||||
|
// Integration tests against a real Postgres. Skipped unless
|
||||||
|
// AUTHKIT_TEST_DATABASE_URL is set. Each test acquires a fresh schema by
|
||||||
|
// running Migrate against a randomly-named set of tables — no cleanup
|
||||||
|
// fixture, no external Docker dependency.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/authkit/hasher"
|
||||||
|
"git.juancwu.dev/juancwu/authkit/sqlstore"
|
||||||
|
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
func envURL(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
url := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
|
||||||
|
if url == "" {
|
||||||
|
t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
|
||||||
|
}
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
|
||||||
|
// freshDB opens a connection, runs Migrate against the default schema, and
|
||||||
|
// schedules a teardown that drops every authkit_* table. Tests must run
|
||||||
|
// sequentially against a single database (the package's tests do, by
|
||||||
|
// default — go test serialises within a package unless t.Parallel is
|
||||||
|
// called).
|
||||||
|
func freshDB(t *testing.T) (*authkit.Auth, *sql.DB, sqlstore.Schema) {
|
||||||
|
t.Helper()
|
||||||
|
url := envURL(t)
|
||||||
|
db, err := sql.Open("pgx", url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
if err := db.PingContext(context.Background()); err != nil {
|
||||||
|
t.Fatalf("ping: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := sqlstore.DefaultSchema()
|
||||||
|
// Drop first to start clean — previous failed tests may have left rows.
|
||||||
|
dropAuthkitTables(t, db, schema)
|
||||||
|
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
|
||||||
|
|
||||||
|
stores, err := sqlstore.New(db, pgdialect.New(), schema)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sqlstore.New: %v", err)
|
||||||
|
}
|
||||||
|
auth := authkit.New(authkit.Deps{
|
||||||
|
Users: stores.Users,
|
||||||
|
Sessions: stores.Sessions,
|
||||||
|
Tokens: stores.Tokens,
|
||||||
|
APIKeys: stores.APIKeys,
|
||||||
|
Roles: stores.Roles,
|
||||||
|
Permissions: stores.Permissions,
|
||||||
|
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
|
||||||
|
}, authkit.Config{
|
||||||
|
JWTSecret: []byte("integration-secret-thirty-two!!!!"),
|
||||||
|
JWTIssuer: "authkit-int",
|
||||||
|
AccessTokenTTL: 2 * time.Minute,
|
||||||
|
RefreshTokenTTL: 1 * time.Hour,
|
||||||
|
SessionIdleTTL: time.Hour,
|
||||||
|
SessionAbsoluteTTL: 24 * time.Hour,
|
||||||
|
EmailVerifyTTL: time.Hour,
|
||||||
|
PasswordResetTTL: time.Hour,
|
||||||
|
MagicLinkTTL: time.Minute,
|
||||||
|
})
|
||||||
|
return auth, db, schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func dropAuthkitTables(t *testing.T, db *sql.DB, s sqlstore.Schema) {
|
||||||
|
t.Helper()
|
||||||
|
tables := []string{
|
||||||
|
s.Tables.UserRoles, s.Tables.RolePermissions,
|
||||||
|
s.Tables.Roles, s.Tables.Permissions,
|
||||||
|
s.Tables.APIKeys, s.Tables.Tokens,
|
||||||
|
s.Tables.Sessions, s.Tables.Users,
|
||||||
|
s.Tables.SchemaMigrations,
|
||||||
|
}
|
||||||
|
for _, name := range tables {
|
||||||
|
_, _ = db.ExecContext(context.Background(),
|
||||||
|
fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_MigrateIdempotent(t *testing.T) {
|
||||||
|
url := envURL(t)
|
||||||
|
db, err := sql.Open("pgx", url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
schema := sqlstore.DefaultSchema()
|
||||||
|
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
|
||||||
|
t.Fatalf("Migrate iter %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_RegisterAndLogin(t *testing.T) {
|
||||||
|
auth, _, _ := freshDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
got, err := auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoginPassword: %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != u.ID {
|
||||||
|
t.Fatalf("login user id mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := auth.Register(ctx, "alice@example.com", "x"); !errors.Is(err, authkit.ErrEmailTaken) {
|
||||||
|
t.Fatalf("expected ErrEmailTaken, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := auth.LoginPassword(ctx, "alice@example.com", "wrong"); !errors.Is(err, authkit.ErrInvalidCredentials) {
|
||||||
|
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_SessionLifecycle(t *testing.T) {
|
||||||
|
auth, _, _ := freshDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
u, err := auth.Register(ctx, "s@s.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
plain, sess, err := auth.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueSession: %v", err)
|
||||||
|
}
|
||||||
|
if sess.ExpiresAt.Before(time.Now()) {
|
||||||
|
t.Fatalf("session already expired at issue")
|
||||||
|
}
|
||||||
|
if _, err := auth.AuthenticateSession(ctx, plain); err != nil {
|
||||||
|
t.Fatalf("AuthenticateSession: %v", err)
|
||||||
|
}
|
||||||
|
if err := auth.RevokeSession(ctx, plain); err != nil {
|
||||||
|
t.Fatalf("RevokeSession: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := auth.AuthenticateSession(ctx, plain); !errors.Is(err, authkit.ErrSessionInvalid) {
|
||||||
|
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) {
|
||||||
|
auth, _, _ := freshDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
u, err := auth.Register(ctx, "j@j.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
_, refresh1, err := auth.IssueJWT(ctx, u.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueJWT: %v", err)
|
||||||
|
}
|
||||||
|
_, refresh2, err := auth.RefreshJWT(ctx, refresh1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first RefreshJWT: %v", err)
|
||||||
|
}
|
||||||
|
if refresh1 == refresh2 {
|
||||||
|
t.Fatalf("refresh token did not rotate")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, _, err := auth.RefreshJWT(ctx, refresh1); !errors.Is(err, authkit.ErrTokenReused) {
|
||||||
|
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
|
||||||
|
}
|
||||||
|
if _, _, err := auth.RefreshJWT(ctx, refresh2); !errors.Is(err, authkit.ErrTokenInvalid) {
|
||||||
|
t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_APIKeyWithAbilities(t *testing.T) {
|
||||||
|
auth, _, _ := freshDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
u, err := auth.Register(ctx, "k@k.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
plain, _, err := auth.IssueAPIKey(ctx, u.ID, "ci",
|
||||||
|
[]string{"billing:read", "users:list"}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueAPIKey: %v", err)
|
||||||
|
}
|
||||||
|
p, err := auth.AuthenticateAPIKey(ctx, plain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AuthenticateAPIKey: %v", err)
|
||||||
|
}
|
||||||
|
if !p.HasAbility("billing:read") || !p.HasAbility("users:list") {
|
||||||
|
t.Fatalf("abilities missing: %+v", p.Abilities)
|
||||||
|
}
|
||||||
|
if err := auth.RevokeAPIKey(ctx, plain); err != nil {
|
||||||
|
t.Fatalf("RevokeAPIKey: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := auth.AuthenticateAPIKey(ctx, plain); !errors.Is(err, authkit.ErrAPIKeyInvalid) {
|
||||||
|
t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_RBAC(t *testing.T) {
|
||||||
|
auth, db, schema := freshDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
u, err := auth.Register(ctx, "rb@b.com", "pw")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
|
||||||
|
r := &authkit.Role{Name: "editor"}
|
||||||
|
if err := stores.Roles.CreateRole(ctx, r); err != nil {
|
||||||
|
t.Fatalf("CreateRole: %v", err)
|
||||||
|
}
|
||||||
|
p := &authkit.Permission{Name: "posts:write"}
|
||||||
|
if err := stores.Permissions.CreatePermission(ctx, p); err != nil {
|
||||||
|
t.Fatalf("CreatePermission: %v", err)
|
||||||
|
}
|
||||||
|
if err := stores.Permissions.AssignPermissionToRole(ctx, r.ID, p.ID); err != nil {
|
||||||
|
t.Fatalf("AssignPermissionToRole: %v", err)
|
||||||
|
}
|
||||||
|
if err := auth.AssignRole(ctx, u.ID, "editor"); err != nil {
|
||||||
|
t.Fatalf("AssignRole: %v", err)
|
||||||
|
}
|
||||||
|
ok, err := auth.HasPermission(ctx, u.ID, "posts:write")
|
||||||
|
if err != nil || !ok {
|
||||||
|
t.Fatalf("HasPermission: %v %v", ok, err)
|
||||||
|
}
|
||||||
|
ok, err = auth.HasAnyRole(ctx, u.ID, []string{"editor", "admin"})
|
||||||
|
if err != nil || !ok {
|
||||||
|
t.Fatalf("HasAnyRole: %v %v", ok, err)
|
||||||
|
}
|
||||||
|
if err := auth.RemoveRole(ctx, u.ID, "editor"); err != nil {
|
||||||
|
t.Fatalf("RemoveRole: %v", err)
|
||||||
|
}
|
||||||
|
ok, _ = auth.HasPermission(ctx, u.ID, "posts:write")
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("HasPermission should be false after RemoveRole")
|
||||||
|
}
|
||||||
|
}
|
||||||
92
sqlstore/tokens.go
Normal file
92
sqlstore/tokens.go
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokenStore struct{ storeBase }
|
||||||
|
|
||||||
|
func (s *tokenStore) CreateToken(ctx context.Context, t *authkit.Token) error {
|
||||||
|
const op = "authkit.sqlstore.TokenStore.CreateToken"
|
||||||
|
if t.CreatedAt.IsZero() {
|
||||||
|
t.CreatedAt = time.Now().UTC()
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx, s.q.CreateToken,
|
||||||
|
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
|
||||||
|
nullableTime(t.ConsumedAt), t.CreatedAt, t.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsumeToken does the find-and-mark-consumed in a single statement so two
|
||||||
|
// concurrent callers cannot both successfully consume the same token. The
|
||||||
|
// row is returned for inspection (e.g. ChainID for refresh rotation).
|
||||||
|
func (s *tokenStore) ConsumeToken(ctx context.Context, kind authkit.TokenKind, hash []byte, now time.Time) (*authkit.Token, error) {
|
||||||
|
const op = "authkit.sqlstore.TokenStore.ConsumeToken"
|
||||||
|
row := s.db.QueryRowContext(ctx, s.q.ConsumeToken, now, string(kind), hash, now)
|
||||||
|
t, err := scanToken(row)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStore) GetToken(ctx context.Context, kind authkit.TokenKind, hash []byte) (*authkit.Token, error) {
|
||||||
|
const op = "authkit.sqlstore.TokenStore.GetToken"
|
||||||
|
row := s.db.QueryRowContext(ctx, s.q.GetToken, string(kind), hash)
|
||||||
|
t, err := scanToken(row)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStore) DeleteByChain(ctx context.Context, chainID string) (int64, error) {
|
||||||
|
const op = "authkit.sqlstore.TokenStore.DeleteByChain"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.DeleteByChain, chainID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
|
||||||
|
const op = "authkit.sqlstore.TokenStore.DeleteExpired"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredTokens, now)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanToken(row rowScanner) (*authkit.Token, error) {
|
||||||
|
var (
|
||||||
|
t authkit.Token
|
||||||
|
kind string
|
||||||
|
userIDStr string
|
||||||
|
chainID sql.NullString
|
||||||
|
consumedAt sql.NullTime
|
||||||
|
)
|
||||||
|
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID,
|
||||||
|
&consumedAt, &t.CreatedAt, &t.ExpiresAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.Kind = authkit.TokenKind(kind)
|
||||||
|
uid, err := scanUUID(userIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.UserID = uid
|
||||||
|
t.ChainID = scanNullStringPtr(chainID)
|
||||||
|
t.ConsumedAt = scanNullTimePtr(consumedAt)
|
||||||
|
return &t, nil
|
||||||
|
}
|
||||||
186
sqlstore/users.go
Normal file
186
sqlstore/users.go
Normal file
|
|
@ -0,0 +1,186 @@
|
||||||
|
package sqlstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/authkit"
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type userStore struct{ storeBase }
|
||||||
|
|
||||||
|
func (s *userStore) CreateUser(ctx context.Context, u *authkit.User) error {
|
||||||
|
const op = "authkit.sqlstore.UserStore.CreateUser"
|
||||||
|
if u.ID == uuid.Nil {
|
||||||
|
u.ID = uuid.New()
|
||||||
|
}
|
||||||
|
if u.EmailNormalized == "" {
|
||||||
|
u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email))
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if u.CreatedAt.IsZero() {
|
||||||
|
u.CreatedAt = now
|
||||||
|
}
|
||||||
|
if u.UpdatedAt.IsZero() {
|
||||||
|
u.UpdatedAt = now
|
||||||
|
}
|
||||||
|
_, err := s.db.ExecContext(ctx, s.q.CreateUser,
|
||||||
|
uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
|
||||||
|
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
|
||||||
|
nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
if s.d.IsUniqueViolation(err) {
|
||||||
|
return errx.Wrap(op, authkit.ErrEmailTaken)
|
||||||
|
}
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) GetUserByID(ctx context.Context, id uuid.UUID) (*authkit.User, error) {
|
||||||
|
const op = "authkit.sqlstore.UserStore.GetUserByID"
|
||||||
|
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByID, uuidArg(id)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) GetUserByEmail(ctx context.Context, normalizedEmail string) (*authkit.User, error) {
|
||||||
|
const op = "authkit.sqlstore.UserStore.GetUserByEmail"
|
||||||
|
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByEmail, normalizedEmail))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) UpdateUser(ctx context.Context, u *authkit.User) error {
|
||||||
|
const op = "authkit.sqlstore.UserStore.UpdateUser"
|
||||||
|
u.UpdatedAt = time.Now().UTC()
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.UpdateUser,
|
||||||
|
u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
|
||||||
|
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
|
||||||
|
nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.UserStore.DeleteUser"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.DeleteUser, uuidArg(id))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error {
|
||||||
|
const op = "authkit.sqlstore.UserStore.SetPassword"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.SetPassword,
|
||||||
|
nullableString(encodedHash), time.Now().UTC(), uuidArg(userID))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error {
|
||||||
|
const op = "authkit.sqlstore.UserStore.SetEmailVerified"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.SetEmailVerified, at, time.Now().UTC(), uuidArg(userID))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) {
|
||||||
|
const op = "authkit.sqlstore.UserStore.BumpSessionVersion"
|
||||||
|
var v int
|
||||||
|
if err := s.db.QueryRowContext(ctx, s.q.BumpSessionVersion,
|
||||||
|
time.Now().UTC(), uuidArg(userID)).Scan(&v); err != nil {
|
||||||
|
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) {
|
||||||
|
const op = "authkit.sqlstore.UserStore.IncrementFailedLogins"
|
||||||
|
var n int
|
||||||
|
if err := s.db.QueryRowContext(ctx, s.q.IncrementFailedLogins,
|
||||||
|
time.Now().UTC(), uuidArg(userID)).Scan(&n); err != nil {
|
||||||
|
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userStore) ResetFailedLogins(ctx context.Context, userID uuid.UUID) error {
|
||||||
|
const op = "authkit.sqlstore.UserStore.ResetFailedLogins"
|
||||||
|
tag, err := s.db.ExecContext(ctx, s.q.ResetFailedLogins, time.Now().UTC(), uuidArg(userID))
|
||||||
|
if err != nil {
|
||||||
|
return errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
n, _ := tag.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanUser(row rowScanner) (*authkit.User, error) {
|
||||||
|
var (
|
||||||
|
u authkit.User
|
||||||
|
idStr string
|
||||||
|
passwordHash sql.NullString
|
||||||
|
emailVerified sql.NullTime
|
||||||
|
lastLogin sql.NullTime
|
||||||
|
)
|
||||||
|
if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified,
|
||||||
|
&passwordHash, &u.SessionVersion, &u.FailedLogins, &lastLogin,
|
||||||
|
&u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
id, err := scanUUID(idStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
u.ID = id
|
||||||
|
if passwordHash.Valid {
|
||||||
|
u.PasswordHash = passwordHash.String
|
||||||
|
}
|
||||||
|
u.EmailVerifiedAt = scanNullTimePtr(emailVerified)
|
||||||
|
u.LastLoginAt = scanNullTimePtr(lastLogin)
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so
|
||||||
|
// callers get reliable errors.Is targets through errx wrapping.
|
||||||
|
func mapNotFound(err error, sentinel error) error {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return sentinel
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
83
stores.go
Normal file
83
stores.go
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserStore interface {
|
||||||
|
CreateUser(ctx context.Context, u *User) error
|
||||||
|
GetUserByID(ctx context.Context, id uuid.UUID) (*User, error)
|
||||||
|
GetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error)
|
||||||
|
UpdateUser(ctx context.Context, u *User) error
|
||||||
|
DeleteUser(ctx context.Context, id uuid.UUID) error
|
||||||
|
SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error
|
||||||
|
SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error
|
||||||
|
BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error)
|
||||||
|
IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error)
|
||||||
|
ResetFailedLogins(ctx context.Context, userID uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionStore interface {
|
||||||
|
CreateSession(ctx context.Context, s *Session) error
|
||||||
|
GetSession(ctx context.Context, idHash []byte) (*Session, error)
|
||||||
|
TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error
|
||||||
|
DeleteSession(ctx context.Context, idHash []byte) error
|
||||||
|
DeleteUserSessions(ctx context.Context, userID uuid.UUID) error
|
||||||
|
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenStore interface {
|
||||||
|
CreateToken(ctx context.Context, t *Token) error
|
||||||
|
// ConsumeToken atomically marks the matching unexpired, unconsumed token
|
||||||
|
// as consumed and returns it. Returns ErrTokenInvalid if no row matched.
|
||||||
|
// Implementations MUST do this in one statement (UPDATE ... RETURNING)
|
||||||
|
// to prevent double-spend under concurrent callers.
|
||||||
|
ConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error)
|
||||||
|
// GetToken returns a token without consuming it. Used for refresh-token
|
||||||
|
// reuse detection: a token that exists with consumed_at != nil is a
|
||||||
|
// replay signal.
|
||||||
|
GetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error)
|
||||||
|
DeleteByChain(ctx context.Context, chainID string) (int64, error)
|
||||||
|
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type APIKeyStore interface {
|
||||||
|
CreateAPIKey(ctx context.Context, k *APIKey) error
|
||||||
|
GetAPIKey(ctx context.Context, idHash []byte) (*APIKey, error)
|
||||||
|
ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*APIKey, error)
|
||||||
|
TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error
|
||||||
|
RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error
|
||||||
|
RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type RoleStore interface {
|
||||||
|
CreateRole(ctx context.Context, r *Role) error
|
||||||
|
GetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error)
|
||||||
|
GetRoleByName(ctx context.Context, name string) (*Role, error)
|
||||||
|
ListRoles(ctx context.Context) ([]*Role, error)
|
||||||
|
DeleteRole(ctx context.Context, id uuid.UUID) error
|
||||||
|
AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error
|
||||||
|
RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error
|
||||||
|
GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error)
|
||||||
|
HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PermissionStore interface {
|
||||||
|
CreatePermission(ctx context.Context, p *Permission) error
|
||||||
|
GetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error)
|
||||||
|
GetPermissionByName(ctx context.Context, name string) (*Permission, error)
|
||||||
|
ListPermissions(ctx context.Context) ([]*Permission, error)
|
||||||
|
DeletePermission(ctx context.Context, id uuid.UUID) error
|
||||||
|
AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error
|
||||||
|
RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error
|
||||||
|
GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error)
|
||||||
|
GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Hasher interface {
|
||||||
|
Hash(password string) (string, error)
|
||||||
|
Verify(password, encoded string) (ok bool, needsRehash bool, err error)
|
||||||
|
}
|
||||||
67
tokens.go
Normal file
67
tokens.go
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/errx"
|
||||||
|
)
|
||||||
|
|
||||||
|
const secretRandomBytes = 32
|
||||||
|
|
||||||
|
// Secret prefixes. The leading token namespaces the secret so callers can
|
||||||
|
// route them to the right verifier without trial-and-error.
|
||||||
|
const (
|
||||||
|
prefixSession = "sess"
|
||||||
|
prefixRefresh = "rfr"
|
||||||
|
prefixAPIKey = "ak"
|
||||||
|
prefixEmailVerify = "evr"
|
||||||
|
prefixPasswordRset = "pwr"
|
||||||
|
prefixMagicLink = "mlnk"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mintSecret generates a new opaque secret of the given prefix kind and
|
||||||
|
// returns the plaintext (to be returned to the user, never stored) and the
|
||||||
|
// SHA-256 lookup hash (to be stored).
|
||||||
|
func mintSecret(prefix string, rng io.Reader) (plaintext string, hash []byte, err error) {
|
||||||
|
const op = "authkit.mintSecret"
|
||||||
|
if rng == nil {
|
||||||
|
rng = rand.Reader
|
||||||
|
}
|
||||||
|
buf := make([]byte, secretRandomBytes)
|
||||||
|
if _, err := io.ReadFull(rng, buf); err != nil {
|
||||||
|
return "", nil, errx.Wrap(op, err)
|
||||||
|
}
|
||||||
|
body := base64.RawURLEncoding.EncodeToString(buf)
|
||||||
|
plaintext = prefix + "_" + body
|
||||||
|
hash = hashSecret(plaintext)
|
||||||
|
return plaintext, hash, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashSecret returns sha256(plaintext) — the lookup key for opaque secrets.
|
||||||
|
// Plaintexts have full entropy from a CSPRNG so a plain hash is sufficient
|
||||||
|
// (no per-record salt needed; the random body is the salt).
|
||||||
|
func hashSecret(plaintext string) []byte {
|
||||||
|
sum := sha256.Sum256([]byte(plaintext))
|
||||||
|
return sum[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSecret validates that a plaintext starts with the expected prefix
|
||||||
|
// and returns the lookup hash. The prefix check is constant-time relative
|
||||||
|
// to a fixed-length comparison.
|
||||||
|
func parseSecret(prefix, plaintext string) (hash []byte, ok bool) {
|
||||||
|
want := prefix + "_"
|
||||||
|
if !strings.HasPrefix(plaintext, want) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return hashSecret(plaintext), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// constantTimeEqual is a thin wrapper for readability at call sites.
|
||||||
|
func constantTimeEqual(a, b []byte) bool {
|
||||||
|
return subtle.ConstantTimeCompare(a, b) == 1
|
||||||
|
}
|
||||||
56
tokens_test.go
Normal file
56
tokens_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
package authkit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMintSecretRoundtrip(t *testing.T) {
|
||||||
|
plaintext, hash, err := mintSecret(prefixSession, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("mintSecret: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(plaintext, prefixSession+"_") {
|
||||||
|
t.Fatalf("missing prefix: %q", plaintext)
|
||||||
|
}
|
||||||
|
parsed, ok := parseSecret(prefixSession, plaintext)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("parseSecret rejected our own mint")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(hash, parsed) {
|
||||||
|
t.Fatalf("hash mismatch")
|
||||||
|
}
|
||||||
|
want := sha256.Sum256([]byte(plaintext))
|
||||||
|
if !bytes.Equal(hash, want[:]) {
|
||||||
|
t.Fatalf("hashSecret != sha256(plaintext)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSecretWrongPrefix(t *testing.T) {
|
||||||
|
plaintext, _, err := mintSecret(prefixSession, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("mintSecret: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := parseSecret(prefixAPIKey, plaintext); ok {
|
||||||
|
t.Fatalf("parseSecret should reject mismatched prefix")
|
||||||
|
}
|
||||||
|
if _, ok := parseSecret(prefixSession, "sessXXXX"); ok {
|
||||||
|
t.Fatalf("parseSecret should require trailing underscore")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMintSecretUniqueness(t *testing.T) {
|
||||||
|
seen := make(map[string]struct{}, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
p, _, err := mintSecret(prefixAPIKey, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("mintSecret: %v", err)
|
||||||
|
}
|
||||||
|
if _, dup := seen[p]; dup {
|
||||||
|
t.Fatalf("duplicate mint: %s", p)
|
||||||
|
}
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue