authkit initial
This commit is contained in:
parent
5173b0a43d
commit
134393fbca
43 changed files with 5188 additions and 1 deletions
344
README.md
344
README.md
|
|
@ -1,3 +1,345 @@
|
|||
# authkit
|
||||
|
||||
Just a concoction of auth stuff in one place.
|
||||
A pragmatic authentication and authorization toolkit for Go web services.
|
||||
|
||||
`authkit` ships interfaces for users, sessions, tokens, API keys, roles, and
|
||||
permissions, plus default `database/sql` Postgres implementations and
|
||||
framework-neutral HTTP middleware. It supports both opaque server-side
|
||||
sessions and JWT access tokens with rotating refresh tokens, hashes passwords
|
||||
with Argon2id, and pairs naturally with [`lightmux`](https://git.juancwu.dev/juancwu/lightmux)
|
||||
or any `net/http` stack.
|
||||
|
||||
## Install
|
||||
|
||||
```
|
||||
go get git.juancwu.dev/juancwu/authkit
|
||||
```
|
||||
|
||||
`authkit` itself depends only on `database/sql` and the Go standard library
|
||||
plus `golang-jwt`, `google/uuid`, `golang.org/x/crypto`, and `errx`. Bring
|
||||
your own driver: `pgx`, `lib/pq`, or anything else that registers a
|
||||
`database/sql` driver.
|
||||
|
||||
```go
|
||||
import _ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
|
||||
```
|
||||
|
||||
PostgreSQL 12+ is sufficient — the schema avoids `gen_random_uuid()` and
|
||||
`pgcrypto` so no extensions are required.
|
||||
|
||||
## What's included
|
||||
|
||||
**Authentication flows**
|
||||
- Email + password registration and login (Argon2id PHC-encoded hashes)
|
||||
- Opaque server-side sessions with sliding TTL bounded by an absolute cap
|
||||
- JWT access tokens (HS256) with rotating refresh tokens and reuse detection
|
||||
- Email verification, password reset, and magic-link passwordless login
|
||||
|
||||
**Authorization**
|
||||
- Roles and permissions with many-to-many wiring
|
||||
- API keys with custom abilities for per-endpoint scoping
|
||||
- A unified `Principal` type so middleware works the same regardless of which
|
||||
authentication method ran
|
||||
|
||||
**Storage**
|
||||
- Interfaces for every store so callers can plug in their own backends
|
||||
- Default Postgres implementation built on `*sql.DB` (`sqlstore` package)
|
||||
- Override table names via `Schema` without forking — useful when authkit
|
||||
lives alongside an existing application schema
|
||||
- A `Dialect` abstraction so future MySQL / SQLite implementations slot in
|
||||
without changes to store code
|
||||
- Embedded versioned migrations applied by a `Migrate(ctx, db, dialect, schema)`
|
||||
helper that takes a session-scoped advisory lock
|
||||
|
||||
**HTTP**
|
||||
- `middleware.RequireSession`, `RequireJWT`, `RequireAPIKey`, `RequireAny`
|
||||
- `middleware.RequireRole`, `RequireAnyRole`, `RequirePermission`,
|
||||
`RequireAbility`
|
||||
- `middleware.PrincipalFrom(ctx)` to read the authenticated principal in
|
||||
handlers
|
||||
|
||||
**Errors**
|
||||
- Sentinel errors (`ErrEmailTaken`, `ErrInvalidCredentials`, `ErrTokenInvalid`,
|
||||
`ErrTokenReused`, `ErrSessionInvalid`, `ErrAPIKeyInvalid`,
|
||||
`ErrPermissionDenied`, ...) compatible with `errors.Is`
|
||||
- All internal errors wrap with [`errx`](https://git.juancwu.dev/juancwu/errx)
|
||||
for op tags
|
||||
|
||||
## Out of scope (v1)
|
||||
|
||||
MFA/TOTP, OAuth/social login, soft-delete, in-memory permission caching,
|
||||
pluggable JWT signers (HS256 only), built-in HTTP handlers, MySQL/SQLite
|
||||
dialects (architecture supports them; only Postgres ships in v1), and
|
||||
column-name overrides in `Schema` (table-name overrides only).
|
||||
|
||||
## Quick start
|
||||
|
||||
### 1. Open a database and run migrations
|
||||
|
||||
```go
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit/sqlstore"
|
||||
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // or _ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
db, err := sql.Open("pgx", os.Getenv("DATABASE_URL"))
|
||||
if err != nil { /* ... */ }
|
||||
defer db.Close()
|
||||
|
||||
if err := sqlstore.Migrate(ctx, db, pgdialect.New(), sqlstore.DefaultSchema()); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
`Migrate` is idempotent and safe to call from multiple processes — it takes
|
||||
a session-scoped `pg_advisory_lock` to serialise rollouts.
|
||||
|
||||
`sqlx` users can pass `sqlxDB.DB` (the underlying `*sql.DB`) to the same
|
||||
calls — the library only cares about `*sql.DB`.
|
||||
|
||||
### 2. Wire the service
|
||||
|
||||
```go
|
||||
import (
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/authkit/hasher"
|
||||
)
|
||||
|
||||
stores, err := sqlstore.New(db, pgdialect.New(), sqlstore.DefaultSchema())
|
||||
if err != nil { /* ... */ }
|
||||
|
||||
auth := authkit.New(authkit.Deps{
|
||||
Users: stores.Users,
|
||||
Sessions: stores.Sessions,
|
||||
Tokens: stores.Tokens,
|
||||
APIKeys: stores.APIKeys,
|
||||
Roles: stores.Roles,
|
||||
Permissions: stores.Permissions,
|
||||
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
|
||||
}, authkit.Config{
|
||||
JWTSecret: []byte(os.Getenv("JWT_SECRET")),
|
||||
JWTIssuer: "myapp",
|
||||
SessionCookieSecure: true,
|
||||
SessionCookieHTTPOnly: true,
|
||||
})
|
||||
```
|
||||
|
||||
`Config` zero values fall back to sensible defaults (24h idle / 30d absolute
|
||||
session TTL, 15m access tokens, 30d refresh tokens, 48h email-verify, 1h
|
||||
password-reset, 15m magic-link). `JWTSecret` and the seven `Deps` fields are
|
||||
required; `New` panics on a misconfiguration.
|
||||
|
||||
### 3. Use the service
|
||||
|
||||
```go
|
||||
// Registration + password login
|
||||
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
|
||||
u, err = auth.LoginPassword(ctx, "alice@example.com", "hunter2hunter2")
|
||||
|
||||
// Opaque session (cookie-friendly)
|
||||
plaintext, sess, err := auth.IssueSession(ctx, u.ID, r.UserAgent(), clientIP)
|
||||
http.SetCookie(w, auth.SessionCookie(plaintext, sess.ExpiresAt))
|
||||
|
||||
// JWT + rotating refresh
|
||||
access, refresh, err := auth.IssueJWT(ctx, u.ID)
|
||||
access, refresh, err = auth.RefreshJWT(ctx, refresh) // old refresh is consumed
|
||||
|
||||
// API key with abilities
|
||||
plaintext, key, err := auth.IssueAPIKey(ctx, u.ID, "ci",
|
||||
[]string{"billing:read", "users:list"}, nil)
|
||||
|
||||
// Email verification + password reset + magic link
|
||||
tok, err := auth.RequestEmailVerification(ctx, u.ID)
|
||||
_, err = auth.ConfirmEmail(ctx, tok)
|
||||
|
||||
tok, err = auth.RequestPasswordReset(ctx, "alice@example.com")
|
||||
err = auth.ConfirmPasswordReset(ctx, tok, "new-password")
|
||||
|
||||
tok, err = auth.RequestMagicLink(ctx, "alice@example.com")
|
||||
u, err = auth.ConsumeMagicLink(ctx, tok)
|
||||
```
|
||||
|
||||
The plaintext returned by `IssueSession`, `IssueJWT`, `IssueAPIKey`, and the
|
||||
token-minting flows is **show-once** — only its SHA-256 hash is stored. Show
|
||||
it to the user immediately; you cannot recover it later.
|
||||
|
||||
### 4. Wire middleware
|
||||
|
||||
`authkit/middleware` returns standard `func(http.Handler) http.Handler`
|
||||
values, so it composes with `lightmux.Mux.Use`/`Group`/`Handle` and any
|
||||
`net/http` mux that accepts the same shape.
|
||||
|
||||
```go
|
||||
import (
|
||||
authkitmw "git.juancwu.dev/juancwu/authkit/middleware"
|
||||
"git.juancwu.dev/juancwu/lightmux"
|
||||
)
|
||||
|
||||
mux := lightmux.New()
|
||||
|
||||
cookieAuth := authkitmw.RequireSession(authkitmw.Options{
|
||||
Auth: auth,
|
||||
Extractor: authkit.ChainExtractors(
|
||||
authkit.CookieExtractor("authkit_session"),
|
||||
authkit.BearerExtractor(),
|
||||
),
|
||||
})
|
||||
|
||||
me := mux.Group("/me", cookieAuth)
|
||||
me.Get("", func(w http.ResponseWriter, r *http.Request) {
|
||||
p := authkitmw.MustPrincipal(r)
|
||||
json.NewEncoder(w).Encode(p)
|
||||
})
|
||||
|
||||
// RBAC: stack authz on top of any auth method
|
||||
admin := mux.Group("/admin", cookieAuth, authkitmw.RequireRole("admin"))
|
||||
|
||||
// API-key-only route with a per-endpoint ability check
|
||||
api := mux.Group("/api/v1", authkitmw.RequireAPIKey(authkitmw.Options{Auth: auth}))
|
||||
api.Get("/billing", billingHandler, authkitmw.RequireAbility("billing:read"))
|
||||
```
|
||||
|
||||
`Options.Extractor` defaults to `BearerExtractor`; pass `CookieExtractor` (or
|
||||
chain extractors) when reading session cookies. `Options.OnUnauth` and
|
||||
`Options.OnForbidden` default to a JSON `401` / `403`; override them to match
|
||||
your error envelope.
|
||||
|
||||
## Custom table names
|
||||
|
||||
Pass a non-default `Schema` to use your own table names. Identifiers must
|
||||
match `^[a-zA-Z_][a-zA-Z0-9_]*$`; anything else is rejected at `New()` and
|
||||
`Migrate()` time, so SQL injection through the schema is impossible.
|
||||
|
||||
```go
|
||||
schema := sqlstore.DefaultSchema()
|
||||
schema.Tables.Users = "accounts"
|
||||
schema.Tables.APIKeys = "api_credentials"
|
||||
|
||||
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
|
||||
```
|
||||
|
||||
The bundled migration files use the default `authkit_*` names. If you
|
||||
override, you're responsible for matching DDL (most consumers with custom
|
||||
naming already have their own DDL pipeline).
|
||||
|
||||
Column-name overrides are not exposed in v1 — the column set is fixed for
|
||||
each table. Adding column overrides later is purely additive.
|
||||
|
||||
## How things work
|
||||
|
||||
### Secret token format
|
||||
|
||||
Sessions, refresh tokens, API keys, email-verify tokens, password-reset tokens,
|
||||
and magic-link tokens all share one format:
|
||||
|
||||
```
|
||||
plaintext = "<prefix>_" + base64url(32 random bytes, no padding)
|
||||
lookup = sha256(plaintext)
|
||||
```
|
||||
|
||||
Plaintext is returned to the caller exactly once and never persisted; the
|
||||
SHA-256 is the database lookup key. Random bytes come from `crypto/rand` (or
|
||||
`Config.Random` for tests).
|
||||
|
||||
### JWT revocation
|
||||
|
||||
Access tokens carry `sv` (`session_version`) in their claims. When you call
|
||||
`RevokeAllUserSessions` or `ChangePassword`, the user's `session_version`
|
||||
column increments and every outstanding access token fails the next
|
||||
`AuthenticateJWT`. This is the only way to invalidate a JWT before its
|
||||
`exp`.
|
||||
|
||||
### Refresh token rotation
|
||||
|
||||
Each `RefreshJWT` consumes the presented refresh token and issues a new one
|
||||
on the same chain (the `chain_id` column on `authkit_tokens`). If a
|
||||
*consumed* refresh token is ever presented again — a strong replay signal —
|
||||
the entire chain is deleted via `TokenStore.DeleteByChain` and the call
|
||||
returns `ErrTokenReused`.
|
||||
|
||||
### Sliding session TTL
|
||||
|
||||
Each authenticated request via `AuthenticateSession` slides `expires_at` to
|
||||
`now + Config.SessionIdleTTL`, capped at `created_at + Config.SessionAbsoluteTTL`.
|
||||
Long-lived idle sessions still hit the absolute boundary.
|
||||
|
||||
### Schema and migrations
|
||||
|
||||
`sqlstore.Migrate` applies every embedded `.sql` file under
|
||||
`sqlstore/dialect/postgres/migrations/` whose version (filename without
|
||||
`.sql`) is not in `authkit_schema_migrations`. The dialect's
|
||||
`AcquireMigrationLock` (Postgres uses `pg_advisory_lock`) serialises
|
||||
concurrent migrators. Each migration owns its own transaction so future
|
||||
migrations can use statements like `CREATE INDEX CONCURRENTLY`.
|
||||
|
||||
Every default table is prefixed `authkit_` so the schema can live alongside
|
||||
your application's own tables in a shared database.
|
||||
|
||||
### Driver and dialect architecture
|
||||
|
||||
The `sqlstore` package speaks `database/sql` only. Driver-specific behaviour
|
||||
lives behind a small `Dialect` interface:
|
||||
|
||||
```go
|
||||
type Dialect interface {
|
||||
Name() string
|
||||
BuildQueries(s Schema) Queries
|
||||
Bootstrap(ctx context.Context, db *sql.DB) error
|
||||
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
|
||||
Migrations() fs.FS
|
||||
IsUniqueViolation(err error) bool
|
||||
Placeholder(n int) string
|
||||
PlaceholderList(start, count int) string
|
||||
}
|
||||
```
|
||||
|
||||
v1 ships `dialect/postgres`. A future MySQL or SQLite dialect adds a new
|
||||
implementation; no changes to store code.
|
||||
|
||||
## Configuration reference
|
||||
|
||||
| Field | Default | Notes |
|
||||
|---|---|---|
|
||||
| `SessionIdleTTL` | 24h | Sliding window applied on each authenticated request |
|
||||
| `SessionAbsoluteTTL` | 30d | Cap from `created_at`; sliding never exceeds this |
|
||||
| `SessionCookieName` | `authkit_session` | |
|
||||
| `SessionCookieSameSite` | `Lax` | |
|
||||
| `SessionCookieSecure` / `HTTPOnly` | `false` / `false` | Set both to `true` in production |
|
||||
| `JWTSecret` | — (required) | HS256 key |
|
||||
| `JWTIssuer` / `JWTAudience` | empty | When set, parser enforces them |
|
||||
| `AccessTokenTTL` | 15m | |
|
||||
| `RefreshTokenTTL` | 30d | |
|
||||
| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | |
|
||||
| `Clock` | `time.Now().UTC` | Controls every observable timestamp; override for deterministic tests |
|
||||
| `Random` | `crypto/rand.Reader` | Override for deterministic tests |
|
||||
| `LoginHook` | nil | `func(ctx, email, success) error`; integration point for rate limiting / audit |
|
||||
|
||||
## Implementing your own store
|
||||
|
||||
Every store is a small interface with explicit semantics — see `stores.go`.
|
||||
The most subtle contract is `TokenStore.ConsumeToken`: it MUST mark the token
|
||||
consumed and return it in a single statement (`UPDATE ... RETURNING` on
|
||||
Postgres / SQLite 3.35+) so two concurrent callers cannot both succeed.
|
||||
|
||||
## Testing
|
||||
|
||||
```
|
||||
go test ./... # unit tests, no DB
|
||||
AUTHKIT_TEST_DATABASE_URL=postgres://... go test ./sqlstore... # integration tests
|
||||
```
|
||||
|
||||
Unit tests cover token mint/parse, Argon2id encode/verify (including
|
||||
`needsRehash` on parameter change), JWT issue/parse (incl. expired,
|
||||
`sv`-mismatch, refresh rotation, reuse detection), session lifecycle, email
|
||||
verification, password reset cascading session invalidation, magic-link
|
||||
self-verification, API keys with abilities, and RBAC role-permission
|
||||
resolution. Integration tests run the full `sqlstore` contract against a
|
||||
real Postgres when `AUTHKIT_TEST_DATABASE_URL` is set.
|
||||
|
||||
## License
|
||||
|
||||
MIT. See `LICENSE`.
|
||||
|
|
|
|||
53
Taskfile.yml
Normal file
53
Taskfile.yml
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
version: '3'
|
||||
|
||||
tasks:
|
||||
default:
|
||||
desc: Run vet and unit tests
|
||||
cmds:
|
||||
- task: check
|
||||
|
||||
install:tools:
|
||||
desc: Install development tools (tparse for prettier test output)
|
||||
cmds:
|
||||
- go install github.com/mfridman/tparse@latest
|
||||
|
||||
build:
|
||||
desc: Build all packages
|
||||
cmds:
|
||||
- go build ./...
|
||||
|
||||
test:
|
||||
desc: Run unit tests with prettier output via tparse
|
||||
cmds:
|
||||
- set -o pipefail && go test ./... -json -cover | tparse -all
|
||||
|
||||
test:race:
|
||||
desc: Run unit tests under the race detector
|
||||
cmds:
|
||||
- set -o pipefail && go test ./... -race -json | tparse -all
|
||||
|
||||
test:integration:
|
||||
desc: Run sqlstore integration tests against a real Postgres (requires AUTHKIT_TEST_DATABASE_URL)
|
||||
cmds:
|
||||
- set -o pipefail && go test ./sqlstore/... -run Integration -count=1 -json | tparse -all
|
||||
|
||||
vet:
|
||||
desc: Run go vet
|
||||
cmds:
|
||||
- go vet ./...
|
||||
|
||||
fmt:
|
||||
desc: Format Go source files
|
||||
cmds:
|
||||
- gofmt -w .
|
||||
|
||||
tidy:
|
||||
desc: Tidy go.mod
|
||||
cmds:
|
||||
- go mod tidy
|
||||
|
||||
check:
|
||||
desc: Run vet and unit tests
|
||||
cmds:
|
||||
- task: vet
|
||||
- task: test
|
||||
121
authkit.go
Normal file
121
authkit.go
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// Deps bundles every backing store and the password hasher the Auth service
|
||||
// depends on. All fields are required; New panics on a nil dep so misuse is
|
||||
// caught at boot rather than under load.
|
||||
type Deps struct {
|
||||
Users UserStore
|
||||
Sessions SessionStore
|
||||
Tokens TokenStore
|
||||
APIKeys APIKeyStore
|
||||
Roles RoleStore
|
||||
Permissions PermissionStore
|
||||
Hasher Hasher
|
||||
}
|
||||
|
||||
// Config tunes session/JWT/token TTLs, cookie shape, JWT signing material,
|
||||
// and optional hooks. Any zero-valued duration is replaced with a sane
|
||||
// default in New; required fields (notably JWTSecret) cause New to panic.
|
||||
type Config struct {
|
||||
// Session (opaque) cookies + DB-backed lifetime
|
||||
SessionIdleTTL time.Duration
|
||||
SessionAbsoluteTTL time.Duration
|
||||
SessionCookieName string
|
||||
SessionCookieDomain string
|
||||
SessionCookiePath string
|
||||
SessionCookieSecure bool
|
||||
SessionCookieHTTPOnly bool
|
||||
SessionCookieSameSite http.SameSite
|
||||
|
||||
// JWT (HS256)
|
||||
JWTSecret []byte
|
||||
JWTIssuer string
|
||||
JWTAudience string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
|
||||
// Single-use tokens
|
||||
EmailVerifyTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
MagicLinkTTL time.Duration
|
||||
|
||||
// Hooks (optional)
|
||||
Clock func() time.Time
|
||||
Random io.Reader
|
||||
LoginHook func(ctx context.Context, email string, success bool) error
|
||||
}
|
||||
|
||||
// Auth is the high-level service that composes the stores and hasher into the
|
||||
// flows callers use: registration, login, sessions, JWTs, magic links, API
|
||||
// keys, and authz checks. It is safe for concurrent use; method receivers
|
||||
// never mutate Auth state after construction.
|
||||
type Auth struct {
|
||||
deps Deps
|
||||
cfg Config
|
||||
}
|
||||
|
||||
// New validates Deps and Config, fills in defaults, and returns a ready
|
||||
// service. It panics on missing deps or missing JWT secret rather than
|
||||
// returning an error — these are programmer errors, not runtime ones.
|
||||
func New(deps Deps, cfg Config) *Auth {
|
||||
if deps.Users == nil || deps.Sessions == nil || deps.Tokens == nil ||
|
||||
deps.APIKeys == nil || deps.Roles == nil || deps.Permissions == nil ||
|
||||
deps.Hasher == nil {
|
||||
panic(errx.New("authkit.New", "all Deps fields are required"))
|
||||
}
|
||||
if len(cfg.JWTSecret) == 0 {
|
||||
panic(errx.New("authkit.New", "Config.JWTSecret is required"))
|
||||
}
|
||||
|
||||
if cfg.SessionIdleTTL == 0 {
|
||||
cfg.SessionIdleTTL = 24 * time.Hour
|
||||
}
|
||||
if cfg.SessionAbsoluteTTL == 0 {
|
||||
cfg.SessionAbsoluteTTL = 30 * 24 * time.Hour
|
||||
}
|
||||
if cfg.SessionCookieName == "" {
|
||||
cfg.SessionCookieName = "authkit_session"
|
||||
}
|
||||
if cfg.SessionCookiePath == "" {
|
||||
cfg.SessionCookiePath = "/"
|
||||
}
|
||||
if cfg.SessionCookieSameSite == 0 {
|
||||
cfg.SessionCookieSameSite = http.SameSiteLaxMode
|
||||
}
|
||||
if cfg.AccessTokenTTL == 0 {
|
||||
cfg.AccessTokenTTL = 15 * time.Minute
|
||||
}
|
||||
if cfg.RefreshTokenTTL == 0 {
|
||||
cfg.RefreshTokenTTL = 30 * 24 * time.Hour
|
||||
}
|
||||
if cfg.EmailVerifyTTL == 0 {
|
||||
cfg.EmailVerifyTTL = 48 * time.Hour
|
||||
}
|
||||
if cfg.PasswordResetTTL == 0 {
|
||||
cfg.PasswordResetTTL = time.Hour
|
||||
}
|
||||
if cfg.MagicLinkTTL == 0 {
|
||||
cfg.MagicLinkTTL = 15 * time.Minute
|
||||
}
|
||||
if cfg.Clock == nil {
|
||||
cfg.Clock = func() time.Time { return time.Now().UTC() }
|
||||
}
|
||||
if cfg.Random == nil {
|
||||
cfg.Random = rand.Reader
|
||||
}
|
||||
|
||||
return &Auth{deps: deps, cfg: cfg}
|
||||
}
|
||||
|
||||
// now returns the configured wall clock, defaulting to time.Now in UTC.
|
||||
func (a *Auth) now() time.Time { return a.cfg.Clock() }
|
||||
13
doc.go
Normal file
13
doc.go
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
// Package authkit is an authentication and authorization toolkit for Go web
|
||||
// services. It defines storage interfaces (UserStore, SessionStore, TokenStore,
|
||||
// APIKeyStore, RoleStore, PermissionStore) and a high-level Auth service that
|
||||
// composes them to support registration, password login, opaque server-side
|
||||
// sessions, JWT access plus rotating refresh tokens, email verification,
|
||||
// password resets, magic-link passwordless login, role-based access control,
|
||||
// and API keys with custom abilities.
|
||||
//
|
||||
// Default Postgres implementations of every store live in the pgstore
|
||||
// subpackage. Argon2id password hashing lives in hasher. Framework-neutral
|
||||
// HTTP middleware (compatible with lightmux and any net/http stack) lives in
|
||||
// middleware.
|
||||
package authkit
|
||||
18
errors.go
Normal file
18
errors.go
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
package authkit
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrEmailTaken = errors.New("authkit: email already registered")
|
||||
ErrUserNotFound = errors.New("authkit: user not found")
|
||||
ErrInvalidCredentials = errors.New("authkit: invalid credentials")
|
||||
ErrEmailNotVerified = errors.New("authkit: email not verified")
|
||||
ErrTokenInvalid = errors.New("authkit: invalid or expired token")
|
||||
ErrTokenReused = errors.New("authkit: token reuse detected")
|
||||
ErrSessionInvalid = errors.New("authkit: invalid or expired session")
|
||||
ErrAPIKeyInvalid = errors.New("authkit: invalid or expired api key")
|
||||
ErrPermissionDenied = errors.New("authkit: permission denied")
|
||||
ErrRoleNotFound = errors.New("authkit: role not found")
|
||||
ErrPermissionNotFound = errors.New("authkit: permission not found")
|
||||
ErrConfigInvalid = errors.New("authkit: invalid configuration")
|
||||
)
|
||||
64
extractor.go
Normal file
64
extractor.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Extractor pulls a credential string out of an HTTP request. It returns
|
||||
// (value, true) when a value was found, otherwise ("", false).
|
||||
type Extractor func(r *http.Request) (string, bool)
|
||||
|
||||
// BearerExtractor reads the value following "Bearer " in the Authorization
|
||||
// header. Comparison is case-insensitive on the scheme.
|
||||
func BearerExtractor() Extractor {
|
||||
return func(r *http.Request) (string, bool) {
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "" {
|
||||
return "", false
|
||||
}
|
||||
const prefix = "bearer "
|
||||
if len(h) <= len(prefix) || !strings.EqualFold(h[:len(prefix)], prefix) {
|
||||
return "", false
|
||||
}
|
||||
v := strings.TrimSpace(h[len(prefix):])
|
||||
if v == "" {
|
||||
return "", false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
|
||||
// CookieExtractor reads the named cookie's value.
|
||||
func CookieExtractor(name string) Extractor {
|
||||
return func(r *http.Request) (string, bool) {
|
||||
c, err := r.Cookie(name)
|
||||
if err != nil || c.Value == "" {
|
||||
return "", false
|
||||
}
|
||||
return c.Value, true
|
||||
}
|
||||
}
|
||||
|
||||
// HeaderExtractor reads a custom header verbatim.
|
||||
func HeaderExtractor(name string) Extractor {
|
||||
return func(r *http.Request) (string, bool) {
|
||||
v := strings.TrimSpace(r.Header.Get(name))
|
||||
if v == "" {
|
||||
return "", false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
|
||||
// ChainExtractors tries each extractor in order, returning the first hit.
|
||||
func ChainExtractors(es ...Extractor) Extractor {
|
||||
return func(r *http.Request) (string, bool) {
|
||||
for _, e := range es {
|
||||
if v, ok := e(r); ok {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
20
go.mod
Normal file
20
go.mod
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
module git.juancwu.dev/juancwu/authkit
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
git.juancwu.dev/juancwu/errx v0.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.9.2
|
||||
golang.org/x/crypto v0.50.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
)
|
||||
36
go.sum
Normal file
36
go.sum
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w=
|
||||
git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
136
hasher/argon2id.go
Normal file
136
hasher/argon2id.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
// Package hasher provides authkit.Hasher implementations. The default
|
||||
// implementation, Argon2id, encodes hashes in the standard PHC string format
|
||||
// (https://github.com/P-H-C/phc-string-format) so callers can introspect
|
||||
// parameters and migrate.
|
||||
package hasher
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// Argon2idParams configures the Argon2id KDF.
|
||||
type Argon2idParams struct {
|
||||
Memory uint32 // KB; OWASP 2024 baseline: 19 MiB minimum, 64 MiB recommended
|
||||
Iterations uint32 // time cost
|
||||
Parallelism uint8 // lanes
|
||||
SaltLen uint32 // bytes
|
||||
KeyLen uint32 // bytes
|
||||
}
|
||||
|
||||
// DefaultArgon2idParams returns sensible defaults: 64 MiB memory, 3
|
||||
// iterations, 2 lanes, 16-byte salt, 32-byte key. Tune up Memory/Iterations
|
||||
// over time and rely on Verify's needsRehash signal to migrate stored hashes.
|
||||
func DefaultArgon2idParams() Argon2idParams {
|
||||
return Argon2idParams{
|
||||
Memory: 64 * 1024,
|
||||
Iterations: 3,
|
||||
Parallelism: 2,
|
||||
SaltLen: 16,
|
||||
KeyLen: 32,
|
||||
}
|
||||
}
|
||||
|
||||
type argon2idHasher struct {
|
||||
params Argon2idParams
|
||||
rng io.Reader
|
||||
}
|
||||
|
||||
// NewArgon2id builds an authkit.Hasher backed by Argon2id. If params is the
|
||||
// zero value, DefaultArgon2idParams() is used. rng defaults to crypto/rand.
|
||||
func NewArgon2id(params Argon2idParams, rng io.Reader) authkit.Hasher {
|
||||
if params == (Argon2idParams{}) {
|
||||
params = DefaultArgon2idParams()
|
||||
}
|
||||
if rng == nil {
|
||||
rng = rand.Reader
|
||||
}
|
||||
return &argon2idHasher{params: params, rng: rng}
|
||||
}
|
||||
|
||||
func (h *argon2idHasher) Hash(password string) (string, error) {
|
||||
const op = "authkit.hasher.Argon2id.Hash"
|
||||
if password == "" {
|
||||
return "", errx.New(op, "password is empty")
|
||||
}
|
||||
salt := make([]byte, h.params.SaltLen)
|
||||
if _, err := io.ReadFull(h.rng, salt); err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
key := argon2.IDKey([]byte(password), salt,
|
||||
h.params.Iterations, h.params.Memory, h.params.Parallelism, h.params.KeyLen)
|
||||
return encodePHC(h.params, salt, key), nil
|
||||
}
|
||||
|
||||
func (h *argon2idHasher) Verify(password, encoded string) (bool, bool, error) {
|
||||
const op = "authkit.hasher.Argon2id.Verify"
|
||||
got, salt, key, err := decodePHC(encoded)
|
||||
if err != nil {
|
||||
return false, false, errx.Wrap(op, err)
|
||||
}
|
||||
want := argon2.IDKey([]byte(password), salt,
|
||||
got.Iterations, got.Memory, got.Parallelism, uint32(len(key)))
|
||||
if subtle.ConstantTimeCompare(want, key) != 1 {
|
||||
return false, false, nil
|
||||
}
|
||||
needsRehash := got.Memory != h.params.Memory ||
|
||||
got.Iterations != h.params.Iterations ||
|
||||
got.Parallelism != h.params.Parallelism ||
|
||||
uint32(len(salt)) != h.params.SaltLen ||
|
||||
uint32(len(key)) != h.params.KeyLen
|
||||
return true, needsRehash, nil
|
||||
}
|
||||
|
||||
// PHC string format:
|
||||
//
|
||||
// $argon2id$v=19$m=<mem>,t=<iters>,p=<lanes>$<salt b64>$<key b64>
|
||||
//
|
||||
// base64 here is the unpadded standard alphabet, per the spec.
|
||||
func encodePHC(p Argon2idParams, salt, key []byte) string {
|
||||
enc := base64.RawStdEncoding
|
||||
return fmt.Sprintf(
|
||||
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, p.Memory, p.Iterations, p.Parallelism,
|
||||
enc.EncodeToString(salt), enc.EncodeToString(key),
|
||||
)
|
||||
}
|
||||
|
||||
func decodePHC(s string) (Argon2idParams, []byte, []byte, error) {
|
||||
parts := strings.Split(s, "$")
|
||||
// Expect: ["", "argon2id", "v=19", "m=...,t=...,p=...", "<salt>", "<key>"]
|
||||
if len(parts) != 6 || parts[0] != "" || parts[1] != "argon2id" {
|
||||
return Argon2idParams{}, nil, nil, errors.New("hasher: not an argon2id phc string")
|
||||
}
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad version segment: %w", err)
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: unsupported argon2 version %d", version)
|
||||
}
|
||||
var p Argon2idParams
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism); err != nil {
|
||||
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad params segment: %w", err)
|
||||
}
|
||||
enc := base64.RawStdEncoding
|
||||
salt, err := enc.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad salt: %w", err)
|
||||
}
|
||||
key, err := enc.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return Argon2idParams{}, nil, nil, fmt.Errorf("hasher: bad key: %w", err)
|
||||
}
|
||||
p.SaltLen = uint32(len(salt))
|
||||
p.KeyLen = uint32(len(key))
|
||||
return p, salt, key, nil
|
||||
}
|
||||
78
hasher/argon2id_test.go
Normal file
78
hasher/argon2id_test.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package hasher
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestArgon2idHashVerifyRoundtrip(t *testing.T) {
|
||||
h := NewArgon2id(DefaultArgon2idParams(), nil)
|
||||
encoded, err := h.Hash("hunter2hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Hash: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(encoded, "$argon2id$") {
|
||||
t.Fatalf("encoded hash not in PHC form: %s", encoded)
|
||||
}
|
||||
ok, needsRehash, err := h.Verify("hunter2hunter2", encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("Verify rejected the original password")
|
||||
}
|
||||
if needsRehash {
|
||||
t.Fatalf("Verify with default params should not signal rehash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2idVerifyWrongPassword(t *testing.T) {
|
||||
h := NewArgon2id(DefaultArgon2idParams(), nil)
|
||||
encoded, err := h.Hash("correct horse battery staple")
|
||||
if err != nil {
|
||||
t.Fatalf("Hash: %v", err)
|
||||
}
|
||||
ok, _, err := h.Verify("nope", encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("Verify should reject wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2idNeedsRehashOnParamChange(t *testing.T) {
|
||||
// Hash with light params...
|
||||
light := Argon2idParams{Memory: 8 * 1024, Iterations: 1, Parallelism: 1, SaltLen: 16, KeyLen: 32}
|
||||
encoded, err := NewArgon2id(light, nil).Hash("hello world")
|
||||
if err != nil {
|
||||
t.Fatalf("Hash: %v", err)
|
||||
}
|
||||
// ...verify with stronger params should still match but flag rehash.
|
||||
heavier := DefaultArgon2idParams()
|
||||
ok, needsRehash, err := NewArgon2id(heavier, nil).Verify("hello world", encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("Verify rejected legitimate password across params")
|
||||
}
|
||||
if !needsRehash {
|
||||
t.Fatalf("Verify should flag rehash when stored params differ from current")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2idRejectsMalformed(t *testing.T) {
|
||||
h := NewArgon2id(DefaultArgon2idParams(), nil)
|
||||
cases := []string{
|
||||
"",
|
||||
"not-a-phc",
|
||||
"$argon2i$v=19$m=64,t=1,p=1$abc$def",
|
||||
"$argon2id$v=99$m=64,t=1,p=1$YWJj$ZGVm",
|
||||
}
|
||||
for _, c := range cases {
|
||||
if _, _, err := h.Verify("x", c); err == nil {
|
||||
t.Fatalf("Verify should reject malformed encoding: %q", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
69
jwt.go
Normal file
69
jwt.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// accessClaims is the JWT shape issued by IssueJWT. The session_version
|
||||
// field carries the User.SessionVersion at issue time so AuthenticateJWT
|
||||
// can detect global revocations (logout-everywhere, password change).
|
||||
type accessClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
SessionVersion int `json:"sv"`
|
||||
Method string `json:"m"`
|
||||
}
|
||||
|
||||
func (a *Auth) signAccessToken(userID uuid.UUID, sessionVersion int) (string, error) {
|
||||
const op = "authkit.signAccessToken"
|
||||
now := a.now()
|
||||
claims := accessClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID.String(),
|
||||
Issuer: a.cfg.JWTIssuer,
|
||||
Audience: jwt.ClaimStrings{a.cfg.JWTAudience},
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(a.cfg.AccessTokenTTL)),
|
||||
ID: uuid.NewString(),
|
||||
},
|
||||
SessionVersion: sessionVersion,
|
||||
Method: string(AuthMethodJWT),
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := tok.SignedString(a.cfg.JWTSecret)
|
||||
if err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
// parseAccessToken validates the signature and returns the parsed claims.
|
||||
func (a *Auth) parseAccessToken(token string) (*accessClaims, error) {
|
||||
const op = "authkit.parseAccessToken"
|
||||
opts := []jwt.ParserOption{
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithExpirationRequired(),
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithTimeFunc(a.cfg.Clock),
|
||||
}
|
||||
if a.cfg.JWTIssuer != "" {
|
||||
opts = append(opts, jwt.WithIssuer(a.cfg.JWTIssuer))
|
||||
}
|
||||
if a.cfg.JWTAudience != "" {
|
||||
opts = append(opts, jwt.WithAudience(a.cfg.JWTAudience))
|
||||
}
|
||||
parser := jwt.NewParser(opts...)
|
||||
|
||||
parsed, err := parser.ParseWithClaims(token, &accessClaims{}, func(t *jwt.Token) (any, error) {
|
||||
return a.cfg.JWTSecret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
claims, ok := parsed.Claims.(*accessClaims)
|
||||
if !ok || !parsed.Valid {
|
||||
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
99
jwt_test.go
Normal file
99
jwt_test.go
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestJWTIssueAndAuthenticate(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "alice@example.com", "hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
access, refresh, err := a.IssueJWT(context.Background(), u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
if access == "" || refresh == "" {
|
||||
t.Fatalf("empty tokens")
|
||||
}
|
||||
p, err := a.AuthenticateJWT(context.Background(), access)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthenticateJWT: %v", err)
|
||||
}
|
||||
if p.UserID != u.ID {
|
||||
t.Fatalf("principal user id mismatch")
|
||||
}
|
||||
if p.Method != AuthMethodJWT {
|
||||
t.Fatalf("principal method = %s, want jwt", p.Method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTSessionVersionMismatchRejected(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "bob@example.com", "hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
access, _, err := a.IssueJWT(context.Background(), u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
if err := a.RevokeAllUserSessions(context.Background(), u.ID); err != nil {
|
||||
t.Fatalf("RevokeAllUserSessions: %v", err)
|
||||
}
|
||||
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
|
||||
t.Fatalf("expected ErrTokenInvalid after session bump, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTRefreshRotationAndReuseDetection(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "carol@example.com", "hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
_, refresh1, err := a.IssueJWT(context.Background(), u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
_, refresh2, err := a.RefreshJWT(context.Background(), refresh1)
|
||||
if err != nil {
|
||||
t.Fatalf("first RefreshJWT: %v", err)
|
||||
}
|
||||
if refresh1 == refresh2 {
|
||||
t.Fatalf("refresh token did not rotate")
|
||||
}
|
||||
|
||||
// Replaying refresh1 must surface ErrTokenReused and revoke the chain.
|
||||
if _, _, err := a.RefreshJWT(context.Background(), refresh1); !errors.Is(err, ErrTokenReused) {
|
||||
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
|
||||
}
|
||||
// After chain revocation, even refresh2 (the legitimate next one) must
|
||||
// be rejected.
|
||||
if _, _, err := a.RefreshJWT(context.Background(), refresh2); !errors.Is(err, ErrTokenInvalid) {
|
||||
t.Fatalf("expected ErrTokenInvalid on post-revoke refresh, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTExpiredTokenRejected(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
now := time.Now().UTC()
|
||||
a.cfg.Clock = func() time.Time { return now }
|
||||
u, err := a.Register(context.Background(), "dan@example.com", "hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
access, _, err := a.IssueJWT(context.Background(), u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
// Advance clock past TTL.
|
||||
a.cfg.Clock = func() time.Time { return now.Add(10 * time.Minute) }
|
||||
if _, err := a.AuthenticateJWT(context.Background(), access); !errors.Is(err, ErrTokenInvalid) {
|
||||
t.Fatalf("expected ErrTokenInvalid for expired token, got %v", err)
|
||||
}
|
||||
}
|
||||
619
memstore_test.go
Normal file
619
memstore_test.go
Normal file
|
|
@ -0,0 +1,619 @@
|
|||
package authkit
|
||||
|
||||
// In-memory store fakes used by service-level tests. Kept in a _test.go file
|
||||
// so they don't ship in the public API. Each fake is intentionally minimal —
|
||||
// it satisfies the interface and supports the flows the tests exercise; not
|
||||
// every method is wired up (a few panic to flag accidental use).
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type memUserStore struct {
|
||||
mu sync.Mutex
|
||||
byID map[uuid.UUID]*User
|
||||
byEml map[string]uuid.UUID
|
||||
}
|
||||
|
||||
func newMemUserStore() *memUserStore {
|
||||
return &memUserStore{byID: map[uuid.UUID]*User{}, byEml: map[string]uuid.UUID{}}
|
||||
}
|
||||
|
||||
func (s *memUserStore) CreateUser(_ context.Context, u *User) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, exists := s.byEml[u.EmailNormalized]; exists {
|
||||
return ErrEmailTaken
|
||||
}
|
||||
if u.ID == uuid.Nil {
|
||||
u.ID = uuid.New()
|
||||
}
|
||||
cp := *u
|
||||
s.byID[u.ID] = &cp
|
||||
s.byEml[u.EmailNormalized] = u.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*User, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.byID[id]
|
||||
if !ok {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
cp := *u
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
func (s *memUserStore) GetUserByEmail(_ context.Context, ne string) (*User, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
id, ok := s.byEml[ne]
|
||||
if !ok {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
u := *s.byID[id]
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func (s *memUserStore) UpdateUser(_ context.Context, u *User) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.byID[u.ID]; !ok {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
cp := *u
|
||||
s.byID[u.ID] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.byID[id]
|
||||
if !ok {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
delete(s.byID, id)
|
||||
delete(s.byEml, u.EmailNormalized)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, h string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.byID[id]
|
||||
if !ok {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
u.PasswordHash = h
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.byID[id]
|
||||
if !ok {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
u.EmailVerifiedAt = &at
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.byID[id]
|
||||
if !ok {
|
||||
return 0, ErrUserNotFound
|
||||
}
|
||||
u.SessionVersion++
|
||||
return u.SessionVersion, nil
|
||||
}
|
||||
func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.byID[id]
|
||||
if !ok {
|
||||
return 0, ErrUserNotFound
|
||||
}
|
||||
u.FailedLogins++
|
||||
return u.FailedLogins, nil
|
||||
}
|
||||
func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.byID[id]
|
||||
if !ok {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
u.FailedLogins = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
type memSessionStore struct {
|
||||
mu sync.Mutex
|
||||
m map[string]*Session
|
||||
}
|
||||
|
||||
func newMemSessionStore() *memSessionStore { return &memSessionStore{m: map[string]*Session{}} }
|
||||
func (s *memSessionStore) CreateSession(_ context.Context, ses *Session) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cp := *ses
|
||||
s.m[string(ses.IDHash)] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, ok := s.m[string(h)]
|
||||
if !ok {
|
||||
return nil, ErrSessionInvalid
|
||||
}
|
||||
cp := *v
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, exp time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, ok := s.m[string(h)]
|
||||
if !ok {
|
||||
return ErrSessionInvalid
|
||||
}
|
||||
v.LastSeenAt = lastSeen
|
||||
v.ExpiresAt = exp
|
||||
return nil
|
||||
}
|
||||
func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.m, string(h))
|
||||
return nil
|
||||
}
|
||||
func (s *memSessionStore) DeleteUserSessions(_ context.Context, uid uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, v := range s.m {
|
||||
if v.UserID == uid {
|
||||
delete(s.m, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memSessionStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var n int64
|
||||
for k, v := range s.m {
|
||||
if !v.ExpiresAt.After(now) {
|
||||
delete(s.m, k)
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type memTokenStore struct {
|
||||
mu sync.Mutex
|
||||
// Keyed by kind+hex(hash) so identical hashes across kinds don't collide.
|
||||
m map[string]*Token
|
||||
}
|
||||
|
||||
func newMemTokenStore() *memTokenStore { return &memTokenStore{m: map[string]*Token{}} }
|
||||
func tokKey(kind TokenKind, h []byte) string {
|
||||
return string(kind) + ":" + string(h)
|
||||
}
|
||||
func (s *memTokenStore) CreateToken(_ context.Context, t *Token) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cp := *t
|
||||
s.m[tokKey(t.Kind, t.Hash)] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memTokenStore) ConsumeToken(_ context.Context, kind TokenKind, h []byte, now time.Time) (*Token, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
t, ok := s.m[tokKey(kind, h)]
|
||||
if !ok || t.ConsumedAt != nil || !t.ExpiresAt.After(now) {
|
||||
return nil, ErrTokenInvalid
|
||||
}
|
||||
t.ConsumedAt = &now
|
||||
cp := *t
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memTokenStore) GetToken(_ context.Context, kind TokenKind, h []byte) (*Token, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
t, ok := s.m[tokKey(kind, h)]
|
||||
if !ok {
|
||||
return nil, ErrTokenInvalid
|
||||
}
|
||||
cp := *t
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memTokenStore) DeleteByChain(_ context.Context, chainID string) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var n int64
|
||||
for k, t := range s.m {
|
||||
if t.ChainID != nil && *t.ChainID == chainID {
|
||||
delete(s.m, k)
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
func (s *memTokenStore) DeleteExpired(_ context.Context, now time.Time) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var n int64
|
||||
for k, t := range s.m {
|
||||
if !t.ExpiresAt.After(now) {
|
||||
delete(s.m, k)
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type memAPIKeyStore struct {
|
||||
mu sync.Mutex
|
||||
m map[string]*APIKey
|
||||
}
|
||||
|
||||
func newMemAPIKeyStore() *memAPIKeyStore { return &memAPIKeyStore{m: map[string]*APIKey{}} }
|
||||
func (s *memAPIKeyStore) CreateAPIKey(_ context.Context, k *APIKey) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cp := *k
|
||||
cp.Abilities = append([]string(nil), k.Abilities...)
|
||||
s.m[string(k.IDHash)] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memAPIKeyStore) GetAPIKey(_ context.Context, h []byte) (*APIKey, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
k, ok := s.m[string(h)]
|
||||
if !ok {
|
||||
return nil, ErrAPIKeyInvalid
|
||||
}
|
||||
cp := *k
|
||||
cp.Abilities = append([]string(nil), k.Abilities...)
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memAPIKeyStore) ListAPIKeysByOwner(_ context.Context, owner uuid.UUID) ([]*APIKey, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var out []*APIKey
|
||||
for _, k := range s.m {
|
||||
if k.OwnerID == owner {
|
||||
cp := *k
|
||||
out = append(out, &cp)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *memAPIKeyStore) TouchAPIKey(_ context.Context, h []byte, at time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if k, ok := s.m[string(h)]; ok {
|
||||
k.LastUsedAt = &at
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memAPIKeyStore) RevokeAPIKey(_ context.Context, h []byte, at time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
k, ok := s.m[string(h)]
|
||||
if !ok {
|
||||
return ErrAPIKeyInvalid
|
||||
}
|
||||
if k.RevokedAt != nil {
|
||||
return ErrAPIKeyInvalid
|
||||
}
|
||||
k.RevokedAt = &at
|
||||
return nil
|
||||
}
|
||||
func (s *memAPIKeyStore) RevokeAPIKeysByOwner(_ context.Context, owner uuid.UUID, at time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, k := range s.m {
|
||||
if k.OwnerID == owner && k.RevokedAt == nil {
|
||||
k.RevokedAt = &at
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type memRoleStore struct {
|
||||
mu sync.Mutex
|
||||
roles map[uuid.UUID]*Role
|
||||
rolesByNm map[string]uuid.UUID
|
||||
userRoles map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
}
|
||||
|
||||
func newMemRoleStore() *memRoleStore {
|
||||
return &memRoleStore{
|
||||
roles: map[uuid.UUID]*Role{},
|
||||
rolesByNm: map[string]uuid.UUID{},
|
||||
userRoles: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
}
|
||||
}
|
||||
func (s *memRoleStore) CreateRole(_ context.Context, r *Role) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, dup := s.rolesByNm[r.Name]; dup {
|
||||
return errors.New("role exists")
|
||||
}
|
||||
if r.ID == uuid.Nil {
|
||||
r.ID = uuid.New()
|
||||
}
|
||||
cp := *r
|
||||
s.roles[r.ID] = &cp
|
||||
s.rolesByNm[r.Name] = r.ID
|
||||
return nil
|
||||
}
|
||||
func (s *memRoleStore) GetRoleByID(_ context.Context, id uuid.UUID) (*Role, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
r, ok := s.roles[id]
|
||||
if !ok {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
cp := *r
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memRoleStore) GetRoleByName(_ context.Context, n string) (*Role, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
id, ok := s.rolesByNm[n]
|
||||
if !ok {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
r := *s.roles[id]
|
||||
return &r, nil
|
||||
}
|
||||
func (s *memRoleStore) ListRoles(_ context.Context) ([]*Role, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*Role, 0, len(s.roles))
|
||||
for _, r := range s.roles {
|
||||
cp := *r
|
||||
out = append(out, &cp)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *memRoleStore) DeleteRole(_ context.Context, id uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
r, ok := s.roles[id]
|
||||
if !ok {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
delete(s.roles, id)
|
||||
delete(s.rolesByNm, r.Name)
|
||||
for _, m := range s.userRoles {
|
||||
delete(m, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memRoleStore) AssignRoleToUser(_ context.Context, uid, rid uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
m, ok := s.userRoles[uid]
|
||||
if !ok {
|
||||
m = map[uuid.UUID]struct{}{}
|
||||
s.userRoles[uid] = m
|
||||
}
|
||||
m[rid] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
func (s *memRoleStore) RemoveRoleFromUser(_ context.Context, uid, rid uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if m, ok := s.userRoles[uid]; ok {
|
||||
delete(m, rid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memRoleStore) GetUserRoles(_ context.Context, uid uuid.UUID) ([]*Role, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := []*Role{}
|
||||
for rid := range s.userRoles[uid] {
|
||||
if r, ok := s.roles[rid]; ok {
|
||||
cp := *r
|
||||
out = append(out, &cp)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *memRoleStore) HasAnyRole(_ context.Context, uid uuid.UUID, names []string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for rid := range s.userRoles[uid] {
|
||||
r := s.roles[rid]
|
||||
if slices.Contains(names, r.Name) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type memPermStore struct {
|
||||
mu sync.Mutex
|
||||
perms map[uuid.UUID]*Permission
|
||||
permsByNm map[string]uuid.UUID
|
||||
rolePerms map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
roles *memRoleStore
|
||||
}
|
||||
|
||||
func newMemPermStore(rs *memRoleStore) *memPermStore {
|
||||
return &memPermStore{
|
||||
perms: map[uuid.UUID]*Permission{},
|
||||
permsByNm: map[string]uuid.UUID{},
|
||||
rolePerms: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
roles: rs,
|
||||
}
|
||||
}
|
||||
func (s *memPermStore) CreatePermission(_ context.Context, p *Permission) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, dup := s.permsByNm[p.Name]; dup {
|
||||
return errors.New("perm exists")
|
||||
}
|
||||
if p.ID == uuid.Nil {
|
||||
p.ID = uuid.New()
|
||||
}
|
||||
cp := *p
|
||||
s.perms[p.ID] = &cp
|
||||
s.permsByNm[p.Name] = p.ID
|
||||
return nil
|
||||
}
|
||||
func (s *memPermStore) GetPermissionByID(_ context.Context, id uuid.UUID) (*Permission, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
p, ok := s.perms[id]
|
||||
if !ok {
|
||||
return nil, ErrPermissionNotFound
|
||||
}
|
||||
cp := *p
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memPermStore) GetPermissionByName(_ context.Context, n string) (*Permission, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
id, ok := s.permsByNm[n]
|
||||
if !ok {
|
||||
return nil, ErrPermissionNotFound
|
||||
}
|
||||
p := *s.perms[id]
|
||||
return &p, nil
|
||||
}
|
||||
func (s *memPermStore) ListPermissions(_ context.Context) ([]*Permission, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*Permission, 0, len(s.perms))
|
||||
for _, p := range s.perms {
|
||||
cp := *p
|
||||
out = append(out, &cp)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *memPermStore) DeletePermission(_ context.Context, id uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
p, ok := s.perms[id]
|
||||
if !ok {
|
||||
return ErrPermissionNotFound
|
||||
}
|
||||
delete(s.perms, id)
|
||||
delete(s.permsByNm, p.Name)
|
||||
for _, m := range s.rolePerms {
|
||||
delete(m, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memPermStore) AssignPermissionToRole(_ context.Context, rid, pid uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
m, ok := s.rolePerms[rid]
|
||||
if !ok {
|
||||
m = map[uuid.UUID]struct{}{}
|
||||
s.rolePerms[rid] = m
|
||||
}
|
||||
m[pid] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
func (s *memPermStore) RemovePermissionFromRole(_ context.Context, rid, pid uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if m, ok := s.rolePerms[rid]; ok {
|
||||
delete(m, pid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memPermStore) GetRolePermissions(_ context.Context, rid uuid.UUID) ([]*Permission, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := []*Permission{}
|
||||
for pid := range s.rolePerms[rid] {
|
||||
if p, ok := s.perms[pid]; ok {
|
||||
cp := *p
|
||||
out = append(out, &cp)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *memPermStore) GetUserPermissions(_ context.Context, uid uuid.UUID) ([]*Permission, error) {
|
||||
s.roles.mu.Lock()
|
||||
roleIDs := make([]uuid.UUID, 0)
|
||||
for rid := range s.roles.userRoles[uid] {
|
||||
roleIDs = append(roleIDs, rid)
|
||||
}
|
||||
s.roles.mu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
seen := map[uuid.UUID]struct{}{}
|
||||
out := []*Permission{}
|
||||
for _, rid := range roleIDs {
|
||||
for pid := range s.rolePerms[rid] {
|
||||
if _, dup := seen[pid]; dup {
|
||||
continue
|
||||
}
|
||||
seen[pid] = struct{}{}
|
||||
if p, ok := s.perms[pid]; ok {
|
||||
cp := *p
|
||||
out = append(out, &cp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Stub Hasher: stores plaintext for trivial verify. Tests of hashing live in
|
||||
// the hasher package itself; we just need *something* callable here.
|
||||
type stubHasher struct{}
|
||||
|
||||
func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil }
|
||||
func (stubHasher) Verify(p, encoded string) (bool, bool, error) {
|
||||
want := "stub:" + p
|
||||
if !bytes.Equal([]byte(want), []byte(encoded)) {
|
||||
return false, false, nil
|
||||
}
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// newTestAuth wires the fakes into Auth with deterministic config.
|
||||
func newTestAuth(t interface{ Helper() }) *Auth {
|
||||
if h, ok := t.(interface{ Helper() }); ok {
|
||||
h.Helper()
|
||||
}
|
||||
roles := newMemRoleStore()
|
||||
return New(Deps{
|
||||
Users: newMemUserStore(),
|
||||
Sessions: newMemSessionStore(),
|
||||
Tokens: newMemTokenStore(),
|
||||
APIKeys: newMemAPIKeyStore(),
|
||||
Roles: roles,
|
||||
Permissions: newMemPermStore(roles),
|
||||
Hasher: stubHasher{},
|
||||
}, Config{
|
||||
JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"),
|
||||
JWTIssuer: "authkit-test",
|
||||
AccessTokenTTL: 2 * time.Minute,
|
||||
RefreshTokenTTL: 1 * time.Hour,
|
||||
SessionIdleTTL: time.Hour,
|
||||
SessionAbsoluteTTL: 24 * time.Hour,
|
||||
EmailVerifyTTL: time.Hour,
|
||||
PasswordResetTTL: time.Hour,
|
||||
MagicLinkTTL: time.Minute,
|
||||
})
|
||||
}
|
||||
71
middleware/authz.go
Normal file
71
middleware/authz.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
)
|
||||
|
||||
// authzGuard wraps the common pattern of "look up the Principal, run a
|
||||
// predicate, succeed or 403". onForbidden defaults to JSON 403.
|
||||
func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler {
|
||||
if onForbidden == nil {
|
||||
onForbidden = defaultJSONError(http.StatusForbidden)
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p, ok := PrincipalFrom(r.Context())
|
||||
if !ok {
|
||||
// No auth middleware ran upstream; treat as forbidden
|
||||
// rather than crashing — composition is the caller's
|
||||
// responsibility but a 403 is the safer default.
|
||||
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
if !pred(p) {
|
||||
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole permits requests whose Principal holds the named role.
|
||||
func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||
return p.HasRole(name)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAnyRole permits requests whose Principal holds at least one of the
|
||||
// named roles.
|
||||
func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||
return p.HasAnyRole(names...)
|
||||
})
|
||||
}
|
||||
|
||||
// RequirePermission permits requests whose Principal holds the named
|
||||
// permission (resolved via roles at auth time).
|
||||
func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||
return p.HasPermission(name)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAbility permits requests whose Principal carries the named ability.
|
||||
// Abilities are populated only for API-key authentication; this middleware
|
||||
// will reject session/JWT-authenticated requests by design.
|
||||
func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||
return p.HasAbility(name)
|
||||
})
|
||||
}
|
||||
|
||||
func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s[0]
|
||||
}
|
||||
39
middleware/context.go
Normal file
39
middleware/context.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Package middleware provides framework-neutral HTTP middleware for authkit.
|
||||
// Every middleware function returns the standard func(http.Handler)
|
||||
// http.Handler type, so it composes with lightmux's Use/Group/Handle as well
|
||||
// as any net/http stack that uses the same signature.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
)
|
||||
|
||||
// principalKey is an unexported context key. Using a distinct empty struct
|
||||
// type guarantees no collision with caller-defined keys.
|
||||
type principalKey struct{}
|
||||
|
||||
// withPrincipal stashes p on the request context for downstream handlers.
|
||||
func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context {
|
||||
return context.WithValue(ctx, principalKey{}, p)
|
||||
}
|
||||
|
||||
// PrincipalFrom retrieves the authenticated Principal placed by RequireSession,
|
||||
// RequireJWT, or RequireAPIKey. The boolean is false if no auth middleware
|
||||
// ran for this request.
|
||||
func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) {
|
||||
p, ok := ctx.Value(principalKey{}).(*authkit.Principal)
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// MustPrincipal panics if no Principal is on the context. Use only on
|
||||
// handlers known to be behind a Require* middleware.
|
||||
func MustPrincipal(r *http.Request) *authkit.Principal {
|
||||
p, ok := PrincipalFrom(r.Context())
|
||||
if !ok {
|
||||
panic("authkit/middleware: no principal on request context")
|
||||
}
|
||||
return p
|
||||
}
|
||||
138
middleware/middleware.go
Normal file
138
middleware/middleware.go
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
)
|
||||
|
||||
// Options configures auth middleware. Auth is required; the rest fall back
|
||||
// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403
|
||||
// on authz failure.
|
||||
type Options struct {
|
||||
Auth *authkit.Auth
|
||||
Extractor authkit.Extractor
|
||||
OnUnauth func(w http.ResponseWriter, r *http.Request, err error)
|
||||
OnForbidden func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
func (o Options) extractor() authkit.Extractor {
|
||||
if o.Extractor != nil {
|
||||
return o.Extractor
|
||||
}
|
||||
return authkit.BearerExtractor()
|
||||
}
|
||||
|
||||
func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if o.OnUnauth != nil {
|
||||
return o.OnUnauth
|
||||
}
|
||||
return defaultJSONError(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if o.OnForbidden != nil {
|
||||
return o.OnForbidden
|
||||
}
|
||||
return defaultJSONError(http.StatusForbidden)
|
||||
}
|
||||
|
||||
func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
return func(w http.ResponseWriter, _ *http.Request, err error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": http.StatusText(status),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSession authenticates the request via an opaque session string. The
|
||||
// extractor is consulted first; if no extractor is set the default Bearer
|
||||
// extractor is used. For cookie-based session lookup, set
|
||||
// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName).
|
||||
func RequireSession(opts Options) func(http.Handler) http.Handler {
|
||||
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||
return opts.Auth.AuthenticateSession(r.Context(), raw)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireJWT authenticates the request via an HS256 JWT.
|
||||
func RequireJWT(opts Options) func(http.Handler) http.Handler {
|
||||
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||
return opts.Auth.AuthenticateJWT(r.Context(), raw)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAPIKey authenticates the request via an opaque API secret.
|
||||
func RequireAPIKey(opts Options) func(http.Handler) http.Handler {
|
||||
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||
return opts.Auth.AuthenticateAPIKey(r.Context(), raw)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAny tries each method in order until one succeeds. Useful for routes
|
||||
// that accept either a session cookie or an API key.
|
||||
func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler {
|
||||
if len(methods) == 0 {
|
||||
methods = []authkit.AuthMethod{
|
||||
authkit.AuthMethodSession,
|
||||
authkit.AuthMethodJWT,
|
||||
authkit.AuthMethodAPIKey,
|
||||
}
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, ok := opts.extractor()(r)
|
||||
if !ok || raw == "" {
|
||||
opts.onUnauth()(w, r, authkit.ErrSessionInvalid)
|
||||
return
|
||||
}
|
||||
var (
|
||||
p *authkit.Principal
|
||||
lastErr error
|
||||
)
|
||||
for _, m := range methods {
|
||||
switch m {
|
||||
case authkit.AuthMethodSession:
|
||||
p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw)
|
||||
case authkit.AuthMethodJWT:
|
||||
p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw)
|
||||
case authkit.AuthMethodAPIKey:
|
||||
p, lastErr = opts.Auth.AuthenticateAPIKey(r.Context(), raw)
|
||||
}
|
||||
if lastErr == nil && p != nil {
|
||||
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
|
||||
return
|
||||
}
|
||||
}
|
||||
opts.onUnauth()(w, r, lastErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// requireWith is the shared scaffolding for the single-method Require*
|
||||
// middlewares.
|
||||
func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler {
|
||||
if opts.Auth == nil {
|
||||
panic("authkit/middleware: Options.Auth is required")
|
||||
}
|
||||
extractor := opts.extractor()
|
||||
onUnauth := opts.onUnauth()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, ok := extractor(r)
|
||||
if !ok || raw == "" {
|
||||
onUnauth(w, r, authkit.ErrSessionInvalid)
|
||||
return
|
||||
}
|
||||
p, err := authn(r, raw)
|
||||
if err != nil {
|
||||
onUnauth(w, r, err)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
|
||||
})
|
||||
}
|
||||
}
|
||||
75
models.go
Normal file
75
models.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID
|
||||
Email string
|
||||
EmailNormalized string
|
||||
EmailVerifiedAt *time.Time
|
||||
PasswordHash string
|
||||
SessionVersion int
|
||||
FailedLogins int
|
||||
LastLoginAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
IDHash []byte
|
||||
UserID uuid.UUID
|
||||
UserAgent string
|
||||
IP netip.Addr
|
||||
CreatedAt time.Time
|
||||
LastSeenAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type TokenKind string
|
||||
|
||||
const (
|
||||
TokenEmailVerify TokenKind = "email_verify"
|
||||
TokenPasswordReset TokenKind = "password_reset"
|
||||
TokenMagicLink TokenKind = "magic_link"
|
||||
TokenRefresh TokenKind = "refresh"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Hash []byte
|
||||
Kind TokenKind
|
||||
UserID uuid.UUID
|
||||
ChainID *string
|
||||
ConsumedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
IDHash []byte
|
||||
OwnerID uuid.UUID
|
||||
Name string
|
||||
Abilities []string
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
ExpiresAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
ID uuid.UUID
|
||||
Name string
|
||||
Description string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Permission struct {
|
||||
ID uuid.UUID
|
||||
Name string
|
||||
Description string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
63
principal.go
Normal file
63
principal.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuthMethod string
|
||||
|
||||
const (
|
||||
AuthMethodSession AuthMethod = "session"
|
||||
AuthMethodJWT AuthMethod = "jwt"
|
||||
AuthMethodAPIKey AuthMethod = "api_key"
|
||||
)
|
||||
|
||||
type Principal struct {
|
||||
UserID uuid.UUID
|
||||
Method AuthMethod
|
||||
SessionID []byte
|
||||
APIKeyID []byte
|
||||
Roles []string
|
||||
Permissions []string
|
||||
Abilities []string
|
||||
IssuedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func (p *Principal) HasRole(name string) bool {
|
||||
for _, r := range p.Roles {
|
||||
if r == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Principal) HasAnyRole(names ...string) bool {
|
||||
for _, n := range names {
|
||||
if p.HasRole(n) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Principal) HasPermission(name string) bool {
|
||||
for _, perm := range p.Permissions {
|
||||
if perm == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Principal) HasAbility(name string) bool {
|
||||
for _, a := range p.Abilities {
|
||||
if a == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
92
service_apikey.go
Normal file
92
service_apikey.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// IssueAPIKey mints a fresh API secret with the given abilities and an
|
||||
// optional TTL. The plaintext is returned to the caller (show-once) and the
|
||||
// SHA-256 lookup hash is stored. Pass ttl=nil for a non-expiring key.
|
||||
func (a *Auth) IssueAPIKey(ctx context.Context, ownerID uuid.UUID, name string, abilities []string, ttl *time.Duration) (string, *APIKey, error) {
|
||||
const op = "authkit.Auth.IssueAPIKey"
|
||||
plaintext, hash, err := mintSecret(prefixAPIKey, a.cfg.Random)
|
||||
if err != nil {
|
||||
return "", nil, errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
k := &APIKey{
|
||||
IDHash: hash,
|
||||
OwnerID: ownerID,
|
||||
Name: name,
|
||||
Abilities: append([]string(nil), abilities...),
|
||||
CreatedAt: now,
|
||||
}
|
||||
if ttl != nil {
|
||||
exp := now.Add(*ttl)
|
||||
k.ExpiresAt = &exp
|
||||
}
|
||||
if err := a.deps.APIKeys.CreateAPIKey(ctx, k); err != nil {
|
||||
return "", nil, errx.Wrap(op, err)
|
||||
}
|
||||
return plaintext, k, nil
|
||||
}
|
||||
|
||||
// AuthenticateAPIKey validates an API secret string, touches last_used_at
|
||||
// (best-effort), and returns a Principal carrying the key's abilities. The
|
||||
// owning user's roles+permissions are also resolved so the same Principal
|
||||
// can satisfy RequireRole / RequirePermission middleware.
|
||||
func (a *Auth) AuthenticateAPIKey(ctx context.Context, plaintext string) (*Principal, error) {
|
||||
const op = "authkit.Auth.AuthenticateAPIKey"
|
||||
hash, ok := parseSecret(prefixAPIKey, plaintext)
|
||||
if !ok {
|
||||
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
|
||||
}
|
||||
k, err := a.deps.APIKeys.GetAPIKey(ctx, hash)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
if k.RevokedAt != nil {
|
||||
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
|
||||
}
|
||||
if k.ExpiresAt != nil && !k.ExpiresAt.After(now) {
|
||||
return nil, errx.Wrap(op, ErrAPIKeyInvalid)
|
||||
}
|
||||
_ = a.deps.APIKeys.TouchAPIKey(ctx, hash, now)
|
||||
|
||||
roles, perms, err := a.resolveRolesAndPermissions(ctx, k.OwnerID)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
expires := now
|
||||
if k.ExpiresAt != nil {
|
||||
expires = *k.ExpiresAt
|
||||
}
|
||||
return &Principal{
|
||||
UserID: k.OwnerID,
|
||||
Method: AuthMethodAPIKey,
|
||||
APIKeyID: hash,
|
||||
Roles: roles,
|
||||
Permissions: perms,
|
||||
Abilities: append([]string(nil), k.Abilities...),
|
||||
IssuedAt: k.CreatedAt,
|
||||
ExpiresAt: expires,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeAPIKey marks a key revoked. Idempotent on already-revoked keys.
|
||||
func (a *Auth) RevokeAPIKey(ctx context.Context, plaintext string) error {
|
||||
const op = "authkit.Auth.RevokeAPIKey"
|
||||
hash, ok := parseSecret(prefixAPIKey, plaintext)
|
||||
if !ok {
|
||||
return errx.Wrap(op, ErrAPIKeyInvalid)
|
||||
}
|
||||
if err := a.deps.APIKeys.RevokeAPIKey(ctx, hash, a.now()); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
85
service_authz.go
Normal file
85
service_authz.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserPermissions returns the union of permission names a user holds via
|
||||
// their assigned roles. Resolved at call time; v1 does not cache.
|
||||
func (a *Auth) UserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
|
||||
const op = "authkit.Auth.UserPermissions"
|
||||
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
out := make([]string, len(perms))
|
||||
for i, p := range perms {
|
||||
out[i] = p.Name
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// HasPermission checks whether a user holds the named permission via any
|
||||
// assigned role.
|
||||
func (a *Auth) HasPermission(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
|
||||
const op = "authkit.Auth.HasPermission"
|
||||
perms, err := a.UserPermissions(ctx, userID)
|
||||
if err != nil {
|
||||
return false, errx.Wrap(op, err)
|
||||
}
|
||||
for _, p := range perms {
|
||||
if p == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// HasRole checks whether a user is assigned the named role.
|
||||
func (a *Auth) HasRole(ctx context.Context, userID uuid.UUID, name string) (bool, error) {
|
||||
const op = "authkit.Auth.HasRole"
|
||||
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, []string{name})
|
||||
if err != nil {
|
||||
return false, errx.Wrap(op, err)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// HasAnyRole checks whether a user holds at least one of the named roles.
|
||||
func (a *Auth) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
|
||||
const op = "authkit.Auth.HasAnyRole"
|
||||
ok, err := a.deps.Roles.HasAnyRole(ctx, userID, names)
|
||||
if err != nil {
|
||||
return false, errx.Wrap(op, err)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// AssignRole is a convenience that looks up a role by name and assigns it.
|
||||
func (a *Auth) AssignRole(ctx context.Context, userID uuid.UUID, roleName string) error {
|
||||
const op = "authkit.Auth.AssignRole"
|
||||
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Roles.AssignRoleToUser(ctx, userID, r.ID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRole is the symmetric helper for AssignRole.
|
||||
func (a *Auth) RemoveRole(ctx context.Context, userID uuid.UUID, roleName string) error {
|
||||
const op = "authkit.Auth.RemoveRole"
|
||||
r, err := a.deps.Roles.GetRoleByName(ctx, roleName)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Roles.RemoveRoleFromUser(ctx, userID, r.ID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
147
service_jwt.go
Normal file
147
service_jwt.go
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// IssueJWT issues a fresh access JWT and a rotating opaque refresh token.
|
||||
// The refresh token is bound to a chain via Token.ChainID; rotation
|
||||
// preserves that chain so reuse-detection can revoke the whole family.
|
||||
func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) {
|
||||
const op = "authkit.Auth.IssueJWT"
|
||||
u, err := a.deps.Users.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
access, err = a.signAccessToken(u.ID, u.SessionVersion)
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString())
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
return access, refresh, nil
|
||||
}
|
||||
|
||||
// AuthenticateJWT validates the access JWT, cross-checks the user's
|
||||
// session_version (instant revocation), and resolves a Principal.
|
||||
func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, error) {
|
||||
const op = "authkit.Auth.AuthenticateJWT"
|
||||
claims, err := a.parseAccessToken(access)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
uid, err := uuid.Parse(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
u, err := a.deps.Users.GetUserByID(ctx, uid)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
if u.SessionVersion != claims.SessionVersion {
|
||||
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
roles, perms, err := a.resolveRolesAndPermissions(ctx, u.ID)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return &Principal{
|
||||
UserID: u.ID,
|
||||
Method: AuthMethodJWT,
|
||||
Roles: roles,
|
||||
Permissions: perms,
|
||||
IssuedAt: claims.IssuedAt.Time,
|
||||
ExpiresAt: claims.ExpiresAt.Time,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshJWT consumes the presented refresh token and mints a new access +
|
||||
// refresh pair. Reuse of an already-consumed refresh token deletes the
|
||||
// entire chain (logout-everywhere on that device family) and returns
|
||||
// ErrTokenReused.
|
||||
func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) {
|
||||
const op = "authkit.Auth.RefreshJWT"
|
||||
hash, ok := parseSecret(prefixRefresh, plaintextRefresh)
|
||||
if !ok {
|
||||
return "", "", errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
now := a.now()
|
||||
|
||||
consumed, err := a.deps.Tokens.ConsumeToken(ctx, TokenRefresh, hash, now)
|
||||
if err != nil {
|
||||
// Differentiate plain-invalid (never existed / expired) from
|
||||
// reuse (existed, already consumed). The presence-check below is
|
||||
// the reuse signal.
|
||||
if errors.Is(err, ErrTokenInvalid) {
|
||||
if existing, gerr := a.deps.Tokens.GetToken(ctx, TokenRefresh, hash); gerr == nil && existing.ConsumedAt != nil {
|
||||
if existing.ChainID != nil && *existing.ChainID != "" {
|
||||
_, _ = a.deps.Tokens.DeleteByChain(ctx, *existing.ChainID)
|
||||
}
|
||||
return "", "", errx.Wrap(op, ErrTokenReused)
|
||||
}
|
||||
}
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
var chainID string
|
||||
if consumed.ChainID != nil {
|
||||
chainID = *consumed.ChainID
|
||||
}
|
||||
if chainID == "" {
|
||||
// Defensive: every refresh token should be chain-bound. Fall back
|
||||
// to a fresh chain so we never throw on missing metadata.
|
||||
chainID = uuid.NewString()
|
||||
}
|
||||
|
||||
access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID))
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID)
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
return access, refresh, nil
|
||||
}
|
||||
|
||||
// mintRefreshToken stores a fresh refresh token bound to chainID and returns
|
||||
// the plaintext.
|
||||
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) {
|
||||
const op = "authkit.Auth.mintRefreshToken"
|
||||
plaintext, hash, err := mintSecret(prefixRefresh, a.cfg.Random)
|
||||
if err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
t := &Token{
|
||||
Hash: hash,
|
||||
Kind: TokenRefresh,
|
||||
UserID: userID,
|
||||
ChainID: &chainID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(a.cfg.RefreshTokenTTL),
|
||||
}
|
||||
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// userSessionVersion fetches the current session_version. Errors collapse to
|
||||
// 0 on the assumption that AuthenticateJWT will reject stale tokens cleanly
|
||||
// — but we still need a value to embed in the freshly-minted access token.
|
||||
func (a *Auth) userSessionVersion(ctx context.Context, userID uuid.UUID) int {
|
||||
if u, err := a.deps.Users.GetUserByID(ctx, userID); err == nil {
|
||||
return u.SessionVersion
|
||||
}
|
||||
return 0
|
||||
}
|
||||
62
service_magic.go
Normal file
62
service_magic.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// RequestMagicLink mints a single-use magic-link token for the email and
|
||||
// returns the plaintext for delivery. ErrUserNotFound is returned for
|
||||
// unregistered emails.
|
||||
func (a *Auth) RequestMagicLink(ctx context.Context, email string) (string, error) {
|
||||
const op = "authkit.Auth.RequestMagicLink"
|
||||
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
|
||||
if err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
plaintext, hash, err := mintSecret(prefixMagicLink, a.cfg.Random)
|
||||
if err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
t := &Token{
|
||||
Hash: hash,
|
||||
Kind: TokenMagicLink,
|
||||
UserID: u.ID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(a.cfg.MagicLinkTTL),
|
||||
}
|
||||
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// ConsumeMagicLink consumes the magic-link token and returns the
|
||||
// authenticated user. Callers typically follow this with IssueSession or
|
||||
// IssueJWT to actually log the user in.
|
||||
func (a *Auth) ConsumeMagicLink(ctx context.Context, plaintextToken string) (*User, error) {
|
||||
const op = "authkit.Auth.ConsumeMagicLink"
|
||||
hash, ok := parseSecret(prefixMagicLink, plaintextToken)
|
||||
if !ok {
|
||||
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
now := a.now()
|
||||
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenMagicLink, hash, now)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
u, err := a.deps.Users.GetUserByID(ctx, t.UserID)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
// A successful magic-link login also implicitly verifies the email
|
||||
// (the user demonstrably controls the inbox).
|
||||
if u.EmailVerifiedAt == nil {
|
||||
if err := a.deps.Users.SetEmailVerified(ctx, u.ID, now); err == nil {
|
||||
u.EmailVerifiedAt = &now
|
||||
}
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
69
service_reset.go
Normal file
69
service_reset.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// RequestPasswordReset mints a single-use password-reset token for the user
|
||||
// behind email and returns the plaintext for the caller to deliver via email.
|
||||
// Returns ErrUserNotFound when the email isn't registered (per project
|
||||
// policy of distinct errors over anti-enumeration).
|
||||
func (a *Auth) RequestPasswordReset(ctx context.Context, email string) (string, error) {
|
||||
const op = "authkit.Auth.RequestPasswordReset"
|
||||
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
|
||||
if err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
plaintext, hash, err := mintSecret(prefixPasswordRset, a.cfg.Random)
|
||||
if err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
t := &Token{
|
||||
Hash: hash,
|
||||
Kind: TokenPasswordReset,
|
||||
UserID: u.ID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(a.cfg.PasswordResetTTL),
|
||||
}
|
||||
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// ConfirmPasswordReset consumes the reset token, sets the new password,
|
||||
// bumps the user's session_version, and revokes outstanding sessions so the
|
||||
// reset constitutes a global logout.
|
||||
func (a *Auth) ConfirmPasswordReset(ctx context.Context, plaintextToken, newPassword string) error {
|
||||
const op = "authkit.Auth.ConfirmPasswordReset"
|
||||
hash, ok := parseSecret(prefixPasswordRset, plaintextToken)
|
||||
if !ok {
|
||||
return errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
now := a.now()
|
||||
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenPasswordReset, hash, now)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
newHash, err := a.deps.Hasher.Hash(newPassword)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Users.SetPassword(ctx, t.UserID, newHash); err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return errx.Wrap(op, ErrUserNotFound)
|
||||
}
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if _, err := a.deps.Users.BumpSessionVersion(ctx, t.UserID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Sessions.DeleteUserSessions(ctx, t.UserID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
154
service_session.go
Normal file
154
service_session.go
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// IssueSession mints an opaque session ID, persists the session record, and
|
||||
// returns the plaintext (for the cookie) plus the stored Session.
|
||||
func (a *Auth) IssueSession(ctx context.Context, userID uuid.UUID, userAgent string, ip netip.Addr) (string, *Session, error) {
|
||||
const op = "authkit.Auth.IssueSession"
|
||||
plaintext, hash, err := mintSecret(prefixSession, a.cfg.Random)
|
||||
if err != nil {
|
||||
return "", nil, errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
expires := now.Add(a.cfg.SessionIdleTTL)
|
||||
if cap := now.Add(a.cfg.SessionAbsoluteTTL); expires.After(cap) {
|
||||
expires = cap
|
||||
}
|
||||
s := &Session{
|
||||
IDHash: hash,
|
||||
UserID: userID,
|
||||
UserAgent: userAgent,
|
||||
IP: ip,
|
||||
CreatedAt: now,
|
||||
LastSeenAt: now,
|
||||
ExpiresAt: expires,
|
||||
}
|
||||
if err := a.deps.Sessions.CreateSession(ctx, s); err != nil {
|
||||
return "", nil, errx.Wrap(op, err)
|
||||
}
|
||||
return plaintext, s, nil
|
||||
}
|
||||
|
||||
// AuthenticateSession validates an opaque session string, slides the TTL,
|
||||
// resolves the user's roles+permissions, and returns a Principal. Expired or
|
||||
// unknown sessions return ErrSessionInvalid.
|
||||
func (a *Auth) AuthenticateSession(ctx context.Context, plaintext string) (*Principal, error) {
|
||||
const op = "authkit.Auth.AuthenticateSession"
|
||||
hash, ok := parseSecret(prefixSession, plaintext)
|
||||
if !ok {
|
||||
return nil, errx.Wrap(op, ErrSessionInvalid)
|
||||
}
|
||||
s, err := a.deps.Sessions.GetSession(ctx, hash)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
if !s.ExpiresAt.After(now) {
|
||||
_ = a.deps.Sessions.DeleteSession(ctx, hash)
|
||||
return nil, errx.Wrap(op, ErrSessionInvalid)
|
||||
}
|
||||
|
||||
// Slide the idle TTL, capped at created_at + AbsoluteTTL so an active
|
||||
// session still expires at the absolute boundary.
|
||||
newExpires := now.Add(a.cfg.SessionIdleTTL)
|
||||
if cap := s.CreatedAt.Add(a.cfg.SessionAbsoluteTTL); newExpires.After(cap) {
|
||||
newExpires = cap
|
||||
}
|
||||
if err := a.deps.Sessions.TouchSession(ctx, hash, now, newExpires); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
roles, perms, err := a.resolveRolesAndPermissions(ctx, s.UserID)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return &Principal{
|
||||
UserID: s.UserID,
|
||||
Method: AuthMethodSession,
|
||||
SessionID: hash,
|
||||
Roles: roles,
|
||||
Permissions: perms,
|
||||
IssuedAt: s.CreatedAt,
|
||||
ExpiresAt: newExpires,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeSession deletes a single session by its plaintext id. Idempotent:
|
||||
// missing sessions are not an error (logout twice should not 500).
|
||||
func (a *Auth) RevokeSession(ctx context.Context, plaintext string) error {
|
||||
const op = "authkit.Auth.RevokeSession"
|
||||
hash, ok := parseSecret(prefixSession, plaintext)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := a.deps.Sessions.DeleteSession(ctx, hash); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllUserSessions kills every active session for the user and bumps
|
||||
// the user's session_version (invalidating outstanding JWT access tokens).
|
||||
func (a *Auth) RevokeAllUserSessions(ctx context.Context, userID uuid.UUID) error {
|
||||
const op = "authkit.Auth.RevokeAllUserSessions"
|
||||
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionCookie builds an *http.Cookie pre-configured from Config. Pass the
|
||||
// plaintext returned by IssueSession; pass the matching ExpiresAt from the
|
||||
// returned *Session as `expires`. To clear a cookie at logout, pass an empty
|
||||
// plaintext and a past expiry.
|
||||
func (a *Auth) SessionCookie(plaintext string, expires time.Time) *http.Cookie {
|
||||
c := &http.Cookie{
|
||||
Name: a.cfg.SessionCookieName,
|
||||
Value: plaintext,
|
||||
Path: a.cfg.SessionCookiePath,
|
||||
Domain: a.cfg.SessionCookieDomain,
|
||||
Secure: a.cfg.SessionCookieSecure,
|
||||
HttpOnly: a.cfg.SessionCookieHTTPOnly,
|
||||
SameSite: a.cfg.SessionCookieSameSite,
|
||||
Expires: expires,
|
||||
}
|
||||
if plaintext == "" {
|
||||
c.MaxAge = -1
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// resolveRolesAndPermissions fetches the user's role names and the union of
|
||||
// their permission names. Both are returned as flat string slices for cheap
|
||||
// containment checks on the Principal.
|
||||
func (a *Auth) resolveRolesAndPermissions(ctx context.Context, userID uuid.UUID) ([]string, []string, error) {
|
||||
roles, err := a.deps.Roles.GetUserRoles(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
perms, err := a.deps.Permissions.GetUserPermissions(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rNames := make([]string, len(roles))
|
||||
for i, r := range roles {
|
||||
rNames[i] = r.Name
|
||||
}
|
||||
pNames := make([]string, len(perms))
|
||||
for i, p := range perms {
|
||||
pNames[i] = p.Name
|
||||
}
|
||||
return rNames, pNames, nil
|
||||
}
|
||||
220
service_test.go
Normal file
220
service_test.go
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegisterAndLogin(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "Alice@Example.com", "hunter2hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
if u.EmailNormalized != "alice@example.com" {
|
||||
t.Fatalf("email_normalized = %q", u.EmailNormalized)
|
||||
}
|
||||
got, err := a.LoginPassword(context.Background(), "alice@example.com", "hunter2hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("LoginPassword: %v", err)
|
||||
}
|
||||
if got.ID != u.ID {
|
||||
t.Fatalf("login user mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterDuplicateEmail(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
if _, err := a.Register(context.Background(), "x@y.com", "abc"); err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
_, err := a.Register(context.Background(), "X@Y.COM", "abc")
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Fatalf("expected ErrEmailTaken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginWrongPassword(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
if _, err := a.Register(context.Background(), "p@q.com", "right"); err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
if _, err := a.LoginPassword(context.Background(), "p@q.com", "wrong"); !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionIssueAuthenticateRevoke(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "s@s.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
plain, sess, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("IssueSession: %v", err)
|
||||
}
|
||||
if sess == nil || plain == "" {
|
||||
t.Fatalf("missing session or plaintext")
|
||||
}
|
||||
|
||||
p, err := a.AuthenticateSession(context.Background(), plain)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthenticateSession: %v", err)
|
||||
}
|
||||
if p.UserID != u.ID {
|
||||
t.Fatalf("principal user id mismatch")
|
||||
}
|
||||
if err := a.RevokeSession(context.Background(), plain); err != nil {
|
||||
t.Fatalf("RevokeSession: %v", err)
|
||||
}
|
||||
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
|
||||
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailVerificationFlow(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "ev@e.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
tok, err := a.RequestEmailVerification(context.Background(), u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("RequestEmailVerification: %v", err)
|
||||
}
|
||||
confirmed, err := a.ConfirmEmail(context.Background(), tok)
|
||||
if err != nil {
|
||||
t.Fatalf("ConfirmEmail: %v", err)
|
||||
}
|
||||
if confirmed.EmailVerifiedAt == nil {
|
||||
t.Fatalf("email_verified_at not set")
|
||||
}
|
||||
// Re-using the token must fail.
|
||||
if _, err := a.ConfirmEmail(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
|
||||
t.Fatalf("expected ErrTokenInvalid on token reuse, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetFlow(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "r@r.com", "old")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
// Issue a session that should be invalidated by the reset.
|
||||
plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.Addr{})
|
||||
if err != nil {
|
||||
t.Fatalf("IssueSession: %v", err)
|
||||
}
|
||||
tok, err := a.RequestPasswordReset(context.Background(), "r@r.com")
|
||||
if err != nil {
|
||||
t.Fatalf("RequestPasswordReset: %v", err)
|
||||
}
|
||||
if err := a.ConfirmPasswordReset(context.Background(), tok, "new"); err != nil {
|
||||
t.Fatalf("ConfirmPasswordReset: %v", err)
|
||||
}
|
||||
if _, err := a.LoginPassword(context.Background(), "r@r.com", "old"); !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Fatalf("old password should fail post-reset, got %v", err)
|
||||
}
|
||||
if _, err := a.LoginPassword(context.Background(), "r@r.com", "new"); err != nil {
|
||||
t.Fatalf("new password should work post-reset, got %v", err)
|
||||
}
|
||||
if _, err := a.AuthenticateSession(context.Background(), plain); !errors.Is(err, ErrSessionInvalid) {
|
||||
t.Fatalf("session should be invalidated by reset, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMagicLinkFlow(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
if _, err := a.Register(context.Background(), "m@m.com", "pw"); err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
tok, err := a.RequestMagicLink(context.Background(), "m@m.com")
|
||||
if err != nil {
|
||||
t.Fatalf("RequestMagicLink: %v", err)
|
||||
}
|
||||
u, err := a.ConsumeMagicLink(context.Background(), tok)
|
||||
if err != nil {
|
||||
t.Fatalf("ConsumeMagicLink: %v", err)
|
||||
}
|
||||
if u.EmailVerifiedAt == nil {
|
||||
t.Fatalf("magic link should imply email verification")
|
||||
}
|
||||
if _, err := a.ConsumeMagicLink(context.Background(), tok); !errors.Is(err, ErrTokenInvalid) {
|
||||
t.Fatalf("expected ErrTokenInvalid on magic link reuse, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyFlowWithAbilities(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "k@k.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
plaintext, k, err := a.IssueAPIKey(context.Background(), u.ID, "ci", []string{"billing:read", "users:list"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueAPIKey: %v", err)
|
||||
}
|
||||
if k == nil || plaintext == "" {
|
||||
t.Fatalf("missing api key")
|
||||
}
|
||||
p, err := a.AuthenticateAPIKey(context.Background(), plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthenticateAPIKey: %v", err)
|
||||
}
|
||||
if !p.HasAbility("billing:read") || !p.HasAbility("users:list") {
|
||||
t.Fatalf("abilities missing on principal: %+v", p.Abilities)
|
||||
}
|
||||
if p.HasAbility("admin:nuke") {
|
||||
t.Fatalf("unexpected ability granted")
|
||||
}
|
||||
if err := a.RevokeAPIKey(context.Background(), plaintext); err != nil {
|
||||
t.Fatalf("RevokeAPIKey: %v", err)
|
||||
}
|
||||
if _, err := a.AuthenticateAPIKey(context.Background(), plaintext); !errors.Is(err, ErrAPIKeyInvalid) {
|
||||
t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRBACRolesAndPermissions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(ctx, "rb@a.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
|
||||
// Create role + permission, hook them up.
|
||||
role := &Role{Name: "editor"}
|
||||
if err := a.deps.Roles.CreateRole(ctx, role); err != nil {
|
||||
t.Fatalf("CreateRole: %v", err)
|
||||
}
|
||||
perm := &Permission{Name: "posts:write"}
|
||||
if err := a.deps.Permissions.CreatePermission(ctx, perm); err != nil {
|
||||
t.Fatalf("CreatePermission: %v", err)
|
||||
}
|
||||
if err := a.deps.Permissions.AssignPermissionToRole(ctx, role.ID, perm.ID); err != nil {
|
||||
t.Fatalf("AssignPermissionToRole: %v", err)
|
||||
}
|
||||
if err := a.AssignRole(ctx, u.ID, "editor"); err != nil {
|
||||
t.Fatalf("AssignRole: %v", err)
|
||||
}
|
||||
ok, err := a.HasPermission(ctx, u.ID, "posts:write")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("HasPermission posts:write should be true, got %v %v", ok, err)
|
||||
}
|
||||
ok, err = a.HasRole(ctx, u.ID, "editor")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("HasRole editor should be true, got %v %v", ok, err)
|
||||
}
|
||||
if err := a.RemoveRole(ctx, u.ID, "editor"); err != nil {
|
||||
t.Fatalf("RemoveRole: %v", err)
|
||||
}
|
||||
ok, _ = a.HasPermission(ctx, u.ID, "posts:write")
|
||||
if ok {
|
||||
t.Fatalf("HasPermission should be false after RemoveRole")
|
||||
}
|
||||
}
|
||||
178
service_user.go
Normal file
178
service_user.go
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// normalizeEmail produces the lookup form used by UserStore.GetUserByEmail
|
||||
// and the email_normalized column. Trim + lowercase is intentional; we do
|
||||
// not collapse Gmail-style "+" addressing or strip dots — that's a policy
|
||||
// decision callers can layer on top.
|
||||
func normalizeEmail(s string) string {
|
||||
return strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
|
||||
// Register creates a new user with an Argon2id-hashed password. Returns
|
||||
// ErrEmailTaken if the normalized email is already registered.
|
||||
func (a *Auth) Register(ctx context.Context, email, password string) (*User, error) {
|
||||
const op = "authkit.Auth.Register"
|
||||
if email == "" || password == "" {
|
||||
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||
}
|
||||
hash, err := a.deps.Hasher.Hash(password)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
u := &User{
|
||||
ID: uuid.New(),
|
||||
Email: email,
|
||||
EmailNormalized: normalizeEmail(email),
|
||||
PasswordHash: hash,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := a.deps.Users.CreateUser(ctx, u); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// LoginPassword verifies the password and returns the authenticated user.
|
||||
// Failure increments failed_logins; success resets it and stamps last_login_at.
|
||||
// LoginHook (if configured) is invoked with the success outcome — use this to
|
||||
// hook in rate limiting or audit logging.
|
||||
func (a *Auth) LoginPassword(ctx context.Context, email, password string) (*User, error) {
|
||||
const op = "authkit.Auth.LoginPassword"
|
||||
u, err := a.deps.Users.GetUserByEmail(ctx, normalizeEmail(email))
|
||||
if err != nil {
|
||||
_ = a.fireLoginHook(ctx, email, false)
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||
}
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
if u.PasswordHash == "" {
|
||||
_ = a.fireLoginHook(ctx, email, false)
|
||||
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||
}
|
||||
ok, needsRehash, err := a.deps.Hasher.Verify(password, u.PasswordHash)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
if !ok {
|
||||
_, _ = a.deps.Users.IncrementFailedLogins(ctx, u.ID)
|
||||
_ = a.fireLoginHook(ctx, email, false)
|
||||
return nil, errx.Wrap(op, ErrInvalidCredentials)
|
||||
}
|
||||
|
||||
now := a.now()
|
||||
u.LastLoginAt = &now
|
||||
u.FailedLogins = 0
|
||||
if err := a.deps.Users.ResetFailedLogins(ctx, u.ID); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Users.UpdateUser(ctx, u); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
if needsRehash {
|
||||
if newHash, herr := a.deps.Hasher.Hash(password); herr == nil {
|
||||
_ = a.deps.Users.SetPassword(ctx, u.ID, newHash)
|
||||
u.PasswordHash = newHash
|
||||
}
|
||||
}
|
||||
_ = a.fireLoginHook(ctx, email, true)
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// ChangePassword verifies the current password, sets the new one, and bumps
|
||||
// the user's session_version so all outstanding JWT access tokens are
|
||||
// instantly invalidated. Outstanding opaque sessions are also revoked.
|
||||
func (a *Auth) ChangePassword(ctx context.Context, userID uuid.UUID, oldPassword, newPassword string) error {
|
||||
const op = "authkit.Auth.ChangePassword"
|
||||
u, err := a.deps.Users.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if u.PasswordHash == "" {
|
||||
return errx.Wrap(op, ErrInvalidCredentials)
|
||||
}
|
||||
ok, _, err := a.deps.Hasher.Verify(oldPassword, u.PasswordHash)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if !ok {
|
||||
return errx.Wrap(op, ErrInvalidCredentials)
|
||||
}
|
||||
newHash, err := a.deps.Hasher.Hash(newPassword)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Users.SetPassword(ctx, userID, newHash); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if _, err := a.deps.Users.BumpSessionVersion(ctx, userID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Sessions.DeleteUserSessions(ctx, userID); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestEmailVerification mints a single-use email-verify token for the
|
||||
// user. Return the plaintext to the caller so they can put it in an email
|
||||
// link; the lookup hash is stored in TokenStore.
|
||||
func (a *Auth) RequestEmailVerification(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||
const op = "authkit.Auth.RequestEmailVerification"
|
||||
plaintext, hash, err := mintSecret(prefixEmailVerify, a.cfg.Random)
|
||||
if err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
now := a.now()
|
||||
t := &Token{
|
||||
Hash: hash,
|
||||
Kind: TokenEmailVerify,
|
||||
UserID: userID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(a.cfg.EmailVerifyTTL),
|
||||
}
|
||||
if err := a.deps.Tokens.CreateToken(ctx, t); err != nil {
|
||||
return "", errx.Wrap(op, err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// ConfirmEmail consumes the verification token and marks the user's email
|
||||
// verified. Returns ErrTokenInvalid if the token is missing/expired/used.
|
||||
func (a *Auth) ConfirmEmail(ctx context.Context, plaintextToken string) (*User, error) {
|
||||
const op = "authkit.Auth.ConfirmEmail"
|
||||
hash, ok := parseSecret(prefixEmailVerify, plaintextToken)
|
||||
if !ok {
|
||||
return nil, errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
now := a.now()
|
||||
t, err := a.deps.Tokens.ConsumeToken(ctx, TokenEmailVerify, hash, now)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
if err := a.deps.Users.SetEmailVerified(ctx, t.UserID, now); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
return a.deps.Users.GetUserByID(ctx, t.UserID)
|
||||
}
|
||||
|
||||
// fireLoginHook is a thin wrapper that suppresses panics from caller-supplied
|
||||
// hooks; we never want a misbehaving telemetry hook to break login.
|
||||
func (a *Auth) fireLoginHook(ctx context.Context, email string, success bool) error {
|
||||
if a.cfg.LoginHook == nil {
|
||||
return nil
|
||||
}
|
||||
return a.cfg.LoginHook(ctx, email, success)
|
||||
}
|
||||
124
sqlstore/apikeys.go
Normal file
124
sqlstore/apikeys.go
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type apiKeyStore struct{ storeBase }
|
||||
|
||||
func (s *apiKeyStore) CreateAPIKey(ctx context.Context, k *authkit.APIKey) error {
|
||||
const op = "authkit.sqlstore.APIKeyStore.CreateAPIKey"
|
||||
if k.CreatedAt.IsZero() {
|
||||
k.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if k.Abilities == nil {
|
||||
k.Abilities = []string{}
|
||||
}
|
||||
abilities, err := json.Marshal(k.Abilities)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
_, err = s.db.ExecContext(ctx, s.q.CreateAPIKey,
|
||||
k.IDHash, uuidArg(k.OwnerID), k.Name, abilities,
|
||||
nullableTime(k.LastUsedAt), k.CreatedAt,
|
||||
nullableTime(k.ExpiresAt), nullableTime(k.RevokedAt))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyStore) GetAPIKey(ctx context.Context, idHash []byte) (*authkit.APIKey, error) {
|
||||
const op = "authkit.sqlstore.APIKeyStore.GetAPIKey"
|
||||
k, err := scanAPIKey(s.db.QueryRowContext(ctx, s.q.GetAPIKey, idHash))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrAPIKeyInvalid))
|
||||
}
|
||||
return k, nil
|
||||
}
|
||||
|
||||
func (s *apiKeyStore) ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*authkit.APIKey, error) {
|
||||
const op = "authkit.sqlstore.APIKeyStore.ListAPIKeysByOwner"
|
||||
rows, err := s.db.QueryContext(ctx, s.q.ListAPIKeysByOwner, uuidArg(ownerID))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*authkit.APIKey
|
||||
for rows.Next() {
|
||||
k, err := scanAPIKey(rows)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
out = append(out, k)
|
||||
}
|
||||
return out, errx.Wrap(op, rows.Err())
|
||||
}
|
||||
|
||||
func (s *apiKeyStore) TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error {
|
||||
const op = "authkit.sqlstore.APIKeyStore.TouchAPIKey"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.TouchAPIKey, at, idHash); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyStore) RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error {
|
||||
const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKey"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.RevokeAPIKey, at, idHash)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrAPIKeyInvalid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyStore) RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error {
|
||||
const op = "authkit.sqlstore.APIKeyStore.RevokeAPIKeysByOwner"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.RevokeAPIKeysByOwner, at, uuidArg(ownerID)); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanAPIKey(row rowScanner) (*authkit.APIKey, error) {
|
||||
var (
|
||||
k authkit.APIKey
|
||||
ownerIDStr string
|
||||
abilitiesRaw []byte
|
||||
lastUsed sql.NullTime
|
||||
expires sql.NullTime
|
||||
revoked sql.NullTime
|
||||
)
|
||||
if err := row.Scan(&k.IDHash, &ownerIDStr, &k.Name, &abilitiesRaw,
|
||||
&lastUsed, &k.CreatedAt, &expires, &revoked); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
owner, err := scanUUID(ownerIDStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k.OwnerID = owner
|
||||
if len(abilitiesRaw) > 0 {
|
||||
if err := json.Unmarshal(abilitiesRaw, &k.Abilities); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if k.Abilities == nil {
|
||||
k.Abilities = []string{}
|
||||
}
|
||||
k.LastUsedAt = scanNullTimePtr(lastUsed)
|
||||
k.ExpiresAt = scanNullTimePtr(expires)
|
||||
k.RevokedAt = scanNullTimePtr(revoked)
|
||||
return &k, nil
|
||||
}
|
||||
118
sqlstore/dialect.go
Normal file
118
sqlstore/dialect.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
// Dialect describes how a particular SQL backend renders the queries authkit
|
||||
// runs at runtime, bootstraps its session for migrations, and reports
|
||||
// driver-specific errors. v1 ships the Postgres dialect; future MySQL or
|
||||
// SQLite implementations satisfy the same interface without changes to the
|
||||
// store code.
|
||||
type Dialect interface {
|
||||
// Name returns a stable short identifier ("postgres", "mysql", ...).
|
||||
Name() string
|
||||
|
||||
// BuildQueries renders every templated SQL string from a validated
|
||||
// Schema. Called once at New() and at Migrate() so identifiers and
|
||||
// placeholder styles are baked in by the time stores execute queries.
|
||||
BuildQueries(s Schema) Queries
|
||||
|
||||
// Bootstrap runs once per Migrate() before the migration lock is taken.
|
||||
// It is the place for things that must happen outside any transaction
|
||||
// (CREATE EXTENSION on Postgres, for example). Must be idempotent and
|
||||
// may be a no-op.
|
||||
Bootstrap(ctx context.Context, db *sql.DB) error
|
||||
|
||||
// AcquireMigrationLock takes a session-scoped lock on conn so concurrent
|
||||
// migrators serialise. The returned release function never returns an
|
||||
// error — implementations may log internally.
|
||||
AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (release func(), err error)
|
||||
|
||||
// Migrations returns the dialect's embedded .sql files, rooted such
|
||||
// that fs.ReadDir(".") lists them lex-sorted by filename.
|
||||
Migrations() fs.FS
|
||||
|
||||
// IsUniqueViolation maps a duplicate-key error from this driver to true
|
||||
// so insert paths can return the matching authkit sentinel without
|
||||
// depending on driver internals.
|
||||
IsUniqueViolation(err error) bool
|
||||
|
||||
// Placeholder returns the placeholder for parameter index n (1-based)
|
||||
// in this dialect — "$1" for Postgres, "?" for MySQL/SQLite.
|
||||
Placeholder(n int) string
|
||||
|
||||
// PlaceholderList returns a comma-separated placeholder list for `count`
|
||||
// parameters starting at position `start` (1-based), suitable for
|
||||
// dynamic IN-clauses. The second return is the same length, prefilled
|
||||
// with nil so the caller can append actual values.
|
||||
PlaceholderList(start, count int) string
|
||||
}
|
||||
|
||||
// Queries is the full set of statement templates authkit issues. Field names
|
||||
// match the store method that consumes the query.
|
||||
type Queries struct {
|
||||
// users
|
||||
CreateUser string
|
||||
GetUserByID string
|
||||
GetUserByEmail string
|
||||
UpdateUser string
|
||||
DeleteUser string
|
||||
SetPassword string
|
||||
SetEmailVerified string
|
||||
BumpSessionVersion string
|
||||
IncrementFailedLogins string
|
||||
ResetFailedLogins string
|
||||
|
||||
// sessions
|
||||
CreateSession string
|
||||
GetSession string
|
||||
TouchSession string
|
||||
DeleteSession string
|
||||
DeleteUserSessions string
|
||||
DeleteExpiredSessions string
|
||||
|
||||
// tokens
|
||||
CreateToken string
|
||||
ConsumeToken string
|
||||
GetToken string
|
||||
DeleteByChain string
|
||||
DeleteExpiredTokens string
|
||||
|
||||
// api keys
|
||||
CreateAPIKey string
|
||||
GetAPIKey string
|
||||
ListAPIKeysByOwner string
|
||||
TouchAPIKey string
|
||||
RevokeAPIKey string
|
||||
RevokeAPIKeysByOwner string
|
||||
|
||||
// roles
|
||||
CreateRole string
|
||||
GetRoleByID string
|
||||
GetRoleByName string
|
||||
ListRoles string
|
||||
DeleteRole string
|
||||
AssignRoleToUser string
|
||||
RemoveRoleFromUser string
|
||||
GetUserRoles string
|
||||
// HasAnyRole is built at call time because the placeholder count varies.
|
||||
|
||||
// permissions
|
||||
CreatePermission string
|
||||
GetPermissionByID string
|
||||
GetPermissionByName string
|
||||
ListPermissions string
|
||||
DeletePermission string
|
||||
AssignPermissionToRole string
|
||||
RemovePermissionFromRole string
|
||||
GetRolePermissions string
|
||||
GetUserPermissions string
|
||||
|
||||
// migrations
|
||||
CreateMigrationsTable string
|
||||
SelectAppliedVersions string
|
||||
InsertAppliedVersion string
|
||||
}
|
||||
33
sqlstore/dialect/postgres/errors.go
Normal file
33
sqlstore/dialect/postgres/errors.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// pgUniqueViolation is the SQLSTATE for unique_violation. Both pgx-stdlib
|
||||
// and lib/pq surface this code, but only pgx-stdlib uses *pgconn.PgError.
|
||||
// lib/pq uses *pq.Error which has a Code field of the same value.
|
||||
const pgUniqueViolation = "23505"
|
||||
|
||||
// isUniqueViolation inspects err for a Postgres unique-violation, regardless
|
||||
// of which driver registered the connection. We match on either the pgx
|
||||
// error type or any error implementing a Code() string method (lib/pq's
|
||||
// pq.Error has SQLState and Code fields; we check via reflection-free
|
||||
// duck-typing through an interface).
|
||||
func isUniqueViolation(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var pgxErr *pgconn.PgError
|
||||
if errors.As(err, &pgxErr) {
|
||||
return pgxErr.Code == pgUniqueViolation
|
||||
}
|
||||
type sqlStater interface{ SQLState() string }
|
||||
var s sqlStater
|
||||
if errors.As(err, &s) {
|
||||
return s.SQLState() == pgUniqueViolation
|
||||
}
|
||||
return false
|
||||
}
|
||||
99
sqlstore/dialect/postgres/migrations/0001_init.sql
Normal file
99
sqlstore/dialect/postgres/migrations/0001_init.sql
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
-- 0001_init.sql
|
||||
-- Initial authkit schema for Postgres. Tables are prefixed authkit_ so the
|
||||
-- library can be embedded in an existing application database. Each
|
||||
-- migration owns its own transaction and inserts its version row at the
|
||||
-- bottom; the runner only orchestrates file discovery and concurrency.
|
||||
|
||||
BEGIN;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_schema_migrations (
|
||||
version TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_users (
|
||||
id UUID PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
email_normalized TEXT NOT NULL,
|
||||
email_verified_at TIMESTAMPTZ,
|
||||
password_hash TEXT,
|
||||
session_version INTEGER NOT NULL DEFAULT 0,
|
||||
failed_logins INTEGER NOT NULL DEFAULT 0,
|
||||
last_login_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS authkit_users_email_normalized_uniq
|
||||
ON authkit_users (email_normalized);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_sessions (
|
||||
id_hash BYTEA PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||
user_agent TEXT NOT NULL DEFAULT '',
|
||||
ip TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
last_seen_at TIMESTAMPTZ NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS authkit_sessions_user_id_idx ON authkit_sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS authkit_sessions_expires_at_idx ON authkit_sessions(expires_at);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_tokens (
|
||||
hash BYTEA NOT NULL,
|
||||
kind TEXT NOT NULL,
|
||||
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||
chain_id TEXT,
|
||||
consumed_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
PRIMARY KEY (kind, hash)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS authkit_tokens_user_id_idx ON authkit_tokens(user_id);
|
||||
CREATE INDEX IF NOT EXISTS authkit_tokens_expires_at_idx ON authkit_tokens(expires_at);
|
||||
CREATE INDEX IF NOT EXISTS authkit_tokens_chain_id_idx
|
||||
ON authkit_tokens(chain_id) WHERE chain_id IS NOT NULL;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_api_keys (
|
||||
id_hash BYTEA PRIMARY KEY,
|
||||
owner_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
abilities JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
last_used_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
expires_at TIMESTAMPTZ,
|
||||
revoked_at TIMESTAMPTZ
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS authkit_api_keys_owner_id_idx ON authkit_api_keys(owner_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_roles (
|
||||
id UUID PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_permissions (
|
||||
id UUID PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_role_permissions (
|
||||
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
|
||||
permission_id UUID NOT NULL REFERENCES authkit_permissions(id) ON DELETE CASCADE,
|
||||
PRIMARY KEY (role_id, permission_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS authkit_user_roles (
|
||||
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||
role_id UUID NOT NULL REFERENCES authkit_roles(id) ON DELETE CASCADE,
|
||||
granted_at TIMESTAMPTZ NOT NULL,
|
||||
PRIMARY KEY (user_id, role_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS authkit_user_roles_role_id_idx ON authkit_user_roles(role_id);
|
||||
|
||||
INSERT INTO authkit_schema_migrations (version, applied_at) VALUES ('0001_init', now())
|
||||
ON CONFLICT (version) DO NOTHING;
|
||||
|
||||
COMMIT;
|
||||
272
sqlstore/dialect/postgres/postgres.go
Normal file
272
sqlstore/dialect/postgres/postgres.go
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
// Package postgres is the Postgres dialect for authkit/sqlstore. Importing
|
||||
// it does not register a driver — callers do `_ "github.com/jackc/pgx/v5/stdlib"`
|
||||
// or `_ "github.com/lib/pq"` themselves and then `sql.Open(...)`.
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit/sqlstore"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
// Dialect implements sqlstore.Dialect for Postgres. The zero value is the
|
||||
// only required form; New() returns a pointer for clarity at call sites.
|
||||
type Dialect struct{}
|
||||
|
||||
// New returns a Postgres dialect ready to pass to sqlstore.New / Migrate.
|
||||
func New() *Dialect { return &Dialect{} }
|
||||
|
||||
func (Dialect) Name() string { return "postgres" }
|
||||
|
||||
// advisoryLockKey is the ASCII bytes of "authkit" packed into an int64. Stable
|
||||
// across rollouts and unlikely to clash with caller advisory locks.
|
||||
const advisoryLockKey int64 = 0x617574686b6974
|
||||
|
||||
func (Dialect) Bootstrap(ctx context.Context, db *sql.DB) error {
|
||||
// Nothing to do — schema avoids gen_random_uuid()/pgcrypto. Kept as a
|
||||
// hook for future migration prerequisites.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Dialect) AcquireMigrationLock(ctx context.Context, conn *sql.Conn) (func(), error) {
|
||||
if _, err := conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
release := func() {
|
||||
if _, err := conn.ExecContext(context.Background(),
|
||||
"SELECT pg_advisory_unlock($1)", advisoryLockKey); err != nil {
|
||||
log.Printf("authkit/postgres: pg_advisory_unlock failed: %v", err)
|
||||
}
|
||||
}
|
||||
return release, nil
|
||||
}
|
||||
|
||||
func (Dialect) Migrations() fs.FS {
|
||||
sub, err := fs.Sub(migrationsFS, "migrations")
|
||||
if err != nil {
|
||||
// migrationsFS is statically populated; this can only fail if the
|
||||
// embed directive is removed, which would be a build-time error.
|
||||
panic(err)
|
||||
}
|
||||
return sub
|
||||
}
|
||||
|
||||
func (Dialect) IsUniqueViolation(err error) bool { return isUniqueViolation(err) }
|
||||
|
||||
func (Dialect) Placeholder(n int) string { return fmt.Sprintf("$%d", n) }
|
||||
|
||||
// PlaceholderList renders a comma-separated `$start,$start+1,...` list of
|
||||
// `count` placeholders. Used by HasAnyRole's dynamic IN-clause expansion.
|
||||
func (Dialect) PlaceholderList(start, count int) string {
|
||||
if count <= 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i := 0; i < count; i++ {
|
||||
if i > 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
fmt.Fprintf(&b, "$%d", start+i)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// BuildQueries renders every query authkit issues, with table identifiers
|
||||
// taken from s and `?` placeholders rewritten to `$N`. Identifiers are
|
||||
// already validated by Schema.Validate — this is interpolated with
|
||||
// fmt.Sprintf, so the validation gate is load-bearing.
|
||||
func (Dialect) BuildQueries(s sqlstore.Schema) sqlstore.Queries {
|
||||
t := s.Tables
|
||||
q := sqlstore.Queries{
|
||||
// users
|
||||
CreateUser: `INSERT INTO ` + t.Users + `
|
||||
(id, email, email_normalized, email_verified_at, password_hash,
|
||||
session_version, failed_logins, last_login_at, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
GetUserByID: `SELECT id, email, email_normalized, email_verified_at,
|
||||
password_hash, session_version, failed_logins, last_login_at,
|
||||
created_at, updated_at FROM ` + t.Users + ` WHERE id = ?`,
|
||||
GetUserByEmail: `SELECT id, email, email_normalized, email_verified_at,
|
||||
password_hash, session_version, failed_logins, last_login_at,
|
||||
created_at, updated_at FROM ` + t.Users + ` WHERE email_normalized = ?`,
|
||||
UpdateUser: `UPDATE ` + t.Users + ` SET
|
||||
email = ?, email_normalized = ?, email_verified_at = ?,
|
||||
password_hash = ?, session_version = ?, failed_logins = ?,
|
||||
last_login_at = ?, updated_at = ?
|
||||
WHERE id = ?`,
|
||||
DeleteUser: `DELETE FROM ` + t.Users + ` WHERE id = ?`,
|
||||
SetPassword: `UPDATE ` + t.Users + ` SET password_hash = ?, updated_at = ? WHERE id = ?`,
|
||||
SetEmailVerified: `UPDATE ` + t.Users + ` SET email_verified_at = ?, updated_at = ? WHERE id = ?`,
|
||||
BumpSessionVersion: `UPDATE ` + t.Users + ` SET session_version = session_version + 1, updated_at = ? WHERE id = ? RETURNING session_version`,
|
||||
IncrementFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = failed_logins + 1, updated_at = ? WHERE id = ? RETURNING failed_logins`,
|
||||
ResetFailedLogins: `UPDATE ` + t.Users + ` SET failed_logins = 0, updated_at = ? WHERE id = ?`,
|
||||
|
||||
// sessions
|
||||
CreateSession: `INSERT INTO ` + t.Sessions + `
|
||||
(id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
GetSession: `SELECT id_hash, user_id, user_agent, ip, created_at, last_seen_at, expires_at
|
||||
FROM ` + t.Sessions + ` WHERE id_hash = ?`,
|
||||
TouchSession: `UPDATE ` + t.Sessions + ` SET last_seen_at = ?, expires_at = ? WHERE id_hash = ?`,
|
||||
DeleteSession: `DELETE FROM ` + t.Sessions + ` WHERE id_hash = ?`,
|
||||
DeleteUserSessions: `DELETE FROM ` + t.Sessions + ` WHERE user_id = ?`,
|
||||
DeleteExpiredSessions: `DELETE FROM ` + t.Sessions + ` WHERE expires_at <= ?`,
|
||||
|
||||
// tokens
|
||||
CreateToken: `INSERT INTO ` + t.Tokens + `
|
||||
(hash, kind, user_id, chain_id, consumed_at, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
ConsumeToken: `UPDATE ` + t.Tokens + `
|
||||
SET consumed_at = ?
|
||||
WHERE kind = ? AND hash = ? AND consumed_at IS NULL AND expires_at > ?
|
||||
RETURNING hash, kind, user_id, chain_id, consumed_at, created_at, expires_at`,
|
||||
GetToken: `SELECT hash, kind, user_id, chain_id, consumed_at, created_at, expires_at
|
||||
FROM ` + t.Tokens + ` WHERE kind = ? AND hash = ?`,
|
||||
DeleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = ?`,
|
||||
DeleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= ?`,
|
||||
|
||||
// api keys
|
||||
CreateAPIKey: `INSERT INTO ` + t.APIKeys + `
|
||||
(id_hash, owner_id, name, abilities, last_used_at, created_at, expires_at, revoked_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
GetAPIKey: `SELECT id_hash, owner_id, name, abilities, last_used_at,
|
||||
created_at, expires_at, revoked_at
|
||||
FROM ` + t.APIKeys + ` WHERE id_hash = ?`,
|
||||
ListAPIKeysByOwner: `SELECT id_hash, owner_id, name, abilities, last_used_at,
|
||||
created_at, expires_at, revoked_at
|
||||
FROM ` + t.APIKeys + ` WHERE owner_id = ? ORDER BY created_at DESC`,
|
||||
TouchAPIKey: `UPDATE ` + t.APIKeys + ` SET last_used_at = ? WHERE id_hash = ?`,
|
||||
RevokeAPIKey: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE id_hash = ? AND revoked_at IS NULL`,
|
||||
RevokeAPIKeysByOwner: `UPDATE ` + t.APIKeys + ` SET revoked_at = ? WHERE owner_id = ? AND revoked_at IS NULL`,
|
||||
|
||||
// roles
|
||||
CreateRole: `INSERT INTO ` + t.Roles + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
|
||||
GetRoleByID: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE id = ?`,
|
||||
GetRoleByName: `SELECT id, name, description, created_at FROM ` + t.Roles + ` WHERE name = ?`,
|
||||
ListRoles: `SELECT id, name, description, created_at FROM ` + t.Roles + ` ORDER BY name`,
|
||||
DeleteRole: `DELETE FROM ` + t.Roles + ` WHERE id = ?`,
|
||||
AssignRoleToUser: `INSERT INTO ` + t.UserRoles + ` (user_id, role_id, granted_at) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`,
|
||||
RemoveRoleFromUser: `DELETE FROM ` + t.UserRoles + ` WHERE user_id = ? AND role_id = ?`,
|
||||
GetUserRoles: `SELECT r.id, r.name, r.description, r.created_at
|
||||
FROM ` + t.Roles + ` r
|
||||
JOIN ` + t.UserRoles + ` ur ON ur.role_id = r.id
|
||||
WHERE ur.user_id = ? ORDER BY r.name`,
|
||||
|
||||
// permissions
|
||||
CreatePermission: `INSERT INTO ` + t.Permissions + ` (id, name, description, created_at) VALUES (?, ?, ?, ?)`,
|
||||
GetPermissionByID: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE id = ?`,
|
||||
GetPermissionByName: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` WHERE name = ?`,
|
||||
ListPermissions: `SELECT id, name, description, created_at FROM ` + t.Permissions + ` ORDER BY name`,
|
||||
DeletePermission: `DELETE FROM ` + t.Permissions + ` WHERE id = ?`,
|
||||
AssignPermissionToRole: `INSERT INTO ` + t.RolePermissions + ` (role_id, permission_id) VALUES (?, ?)
|
||||
ON CONFLICT DO NOTHING`,
|
||||
RemovePermissionFromRole: `DELETE FROM ` + t.RolePermissions + ` WHERE role_id = ? AND permission_id = ?`,
|
||||
GetRolePermissions: `SELECT p.id, p.name, p.description, p.created_at
|
||||
FROM ` + t.Permissions + ` p
|
||||
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
|
||||
WHERE rp.role_id = ? ORDER BY p.name`,
|
||||
GetUserPermissions: `SELECT DISTINCT p.id, p.name, p.description, p.created_at
|
||||
FROM ` + t.Permissions + ` p
|
||||
JOIN ` + t.RolePermissions + ` rp ON rp.permission_id = p.id
|
||||
JOIN ` + t.UserRoles + ` ur ON ur.role_id = rp.role_id
|
||||
WHERE ur.user_id = ? ORDER BY p.name`,
|
||||
|
||||
// migrations
|
||||
CreateMigrationsTable: `CREATE TABLE IF NOT EXISTS ` + t.SchemaMigrations + ` (
|
||||
version TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL
|
||||
)`,
|
||||
SelectAppliedVersions: `SELECT version FROM ` + t.SchemaMigrations,
|
||||
InsertAppliedVersion: `INSERT INTO ` + t.SchemaMigrations + ` (version, applied_at) VALUES (?, ?)`,
|
||||
}
|
||||
|
||||
// Rewrite `?` placeholders to `$N`. Each query is independent; numbering
|
||||
// resets per query.
|
||||
q.CreateUser = rebind(q.CreateUser)
|
||||
q.GetUserByID = rebind(q.GetUserByID)
|
||||
q.GetUserByEmail = rebind(q.GetUserByEmail)
|
||||
q.UpdateUser = rebind(q.UpdateUser)
|
||||
q.DeleteUser = rebind(q.DeleteUser)
|
||||
q.SetPassword = rebind(q.SetPassword)
|
||||
q.SetEmailVerified = rebind(q.SetEmailVerified)
|
||||
q.BumpSessionVersion = rebind(q.BumpSessionVersion)
|
||||
q.IncrementFailedLogins = rebind(q.IncrementFailedLogins)
|
||||
q.ResetFailedLogins = rebind(q.ResetFailedLogins)
|
||||
|
||||
q.CreateSession = rebind(q.CreateSession)
|
||||
q.GetSession = rebind(q.GetSession)
|
||||
q.TouchSession = rebind(q.TouchSession)
|
||||
q.DeleteSession = rebind(q.DeleteSession)
|
||||
q.DeleteUserSessions = rebind(q.DeleteUserSessions)
|
||||
q.DeleteExpiredSessions = rebind(q.DeleteExpiredSessions)
|
||||
|
||||
q.CreateToken = rebind(q.CreateToken)
|
||||
q.ConsumeToken = rebind(q.ConsumeToken)
|
||||
q.GetToken = rebind(q.GetToken)
|
||||
q.DeleteByChain = rebind(q.DeleteByChain)
|
||||
q.DeleteExpiredTokens = rebind(q.DeleteExpiredTokens)
|
||||
|
||||
q.CreateAPIKey = rebind(q.CreateAPIKey)
|
||||
q.GetAPIKey = rebind(q.GetAPIKey)
|
||||
q.ListAPIKeysByOwner = rebind(q.ListAPIKeysByOwner)
|
||||
q.TouchAPIKey = rebind(q.TouchAPIKey)
|
||||
q.RevokeAPIKey = rebind(q.RevokeAPIKey)
|
||||
q.RevokeAPIKeysByOwner = rebind(q.RevokeAPIKeysByOwner)
|
||||
|
||||
q.CreateRole = rebind(q.CreateRole)
|
||||
q.GetRoleByID = rebind(q.GetRoleByID)
|
||||
q.GetRoleByName = rebind(q.GetRoleByName)
|
||||
q.ListRoles = rebind(q.ListRoles)
|
||||
q.DeleteRole = rebind(q.DeleteRole)
|
||||
q.AssignRoleToUser = rebind(q.AssignRoleToUser)
|
||||
q.RemoveRoleFromUser = rebind(q.RemoveRoleFromUser)
|
||||
q.GetUserRoles = rebind(q.GetUserRoles)
|
||||
|
||||
q.CreatePermission = rebind(q.CreatePermission)
|
||||
q.GetPermissionByID = rebind(q.GetPermissionByID)
|
||||
q.GetPermissionByName = rebind(q.GetPermissionByName)
|
||||
q.ListPermissions = rebind(q.ListPermissions)
|
||||
q.DeletePermission = rebind(q.DeletePermission)
|
||||
q.AssignPermissionToRole = rebind(q.AssignPermissionToRole)
|
||||
q.RemovePermissionFromRole = rebind(q.RemovePermissionFromRole)
|
||||
q.GetRolePermissions = rebind(q.GetRolePermissions)
|
||||
q.GetUserPermissions = rebind(q.GetUserPermissions)
|
||||
|
||||
q.SelectAppliedVersions = rebind(q.SelectAppliedVersions)
|
||||
q.InsertAppliedVersion = rebind(q.InsertAppliedVersion)
|
||||
// CreateMigrationsTable contains no parameters.
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// rebind walks s and replaces each unquoted `?` with $1, $2, ... in order.
|
||||
// Our query strings contain no string literals that include `?` (verified
|
||||
// by inspection of every query in BuildQueries); a literal-aware rewriter
|
||||
// would be more defensive but is not needed for v1.
|
||||
func rebind(s string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s) + 16)
|
||||
n := 1
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c == '?' {
|
||||
fmt.Fprintf(&b, "$%d", n)
|
||||
n++
|
||||
continue
|
||||
}
|
||||
b.WriteByte(c)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Compile-time interface compliance check.
|
||||
var _ sqlstore.Dialect = (*Dialect)(nil)
|
||||
116
sqlstore/migrate.go
Normal file
116
sqlstore/migrate.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// Migrate applies every embedded migration the dialect ships that has not
|
||||
// yet been recorded in the schema-migrations table. It is safe to call
|
||||
// repeatedly and concurrently across processes — the dialect's session
|
||||
// lock serialises rollouts.
|
||||
//
|
||||
// Each migration .sql file is responsible for owning its own
|
||||
// BEGIN/COMMIT and inserting a row into the schema-migrations table on
|
||||
// success. The runner only handles file discovery, version tracking, and
|
||||
// concurrency.
|
||||
func Migrate(ctx context.Context, db *sql.DB, dialect Dialect, schema Schema) error {
|
||||
const op = "authkit.sqlstore.Migrate"
|
||||
if db == nil {
|
||||
return errx.New(op, "db is required")
|
||||
}
|
||||
if dialect == nil {
|
||||
return errx.New(op, "dialect is required")
|
||||
}
|
||||
if err := schema.Validate(); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
if err := dialect.Bootstrap(ctx, db); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
release, err := dialect.AcquireMigrationLock(ctx, conn)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
defer release()
|
||||
|
||||
q := dialect.BuildQueries(schema)
|
||||
if _, err := conn.ExecContext(ctx, q.CreateMigrationsTable); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
applied, err := loadAppliedVersions(ctx, conn, q.SelectAppliedVersions)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
migs := dialect.Migrations()
|
||||
files, err := fs.ReadDir(migs, ".")
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
names := make([]string, 0, len(files))
|
||||
for _, f := range files {
|
||||
if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
|
||||
names = append(names, f.Name())
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
for _, name := range names {
|
||||
version := strings.TrimSuffix(name, ".sql")
|
||||
if _, ok := applied[version]; ok {
|
||||
continue
|
||||
}
|
||||
body, err := fs.ReadFile(migs, name)
|
||||
if err != nil {
|
||||
return errx.Wrapf(op, err, "read %s", name)
|
||||
}
|
||||
if _, err := conn.ExecContext(ctx, string(body)); err != nil {
|
||||
return errx.Wrapf(op, err, "apply %s", version)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyVersionRow is intentionally not exposed: migration files own their
|
||||
// own version-row insert. We keep this helper around in case a dialect ever
|
||||
// needs to backfill versions from outside a migration body — currently
|
||||
// unused.
|
||||
var _ = applyVersionRow
|
||||
|
||||
func applyVersionRow(ctx context.Context, conn *sql.Conn, insertQ, version string, at time.Time) error {
|
||||
_, err := conn.ExecContext(ctx, insertQ, version, at)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadAppliedVersions(ctx context.Context, conn *sql.Conn, q string) (map[string]struct{}, error) {
|
||||
rows, err := conn.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
out := make(map[string]struct{})
|
||||
for rows.Next() {
|
||||
var v string
|
||||
if err := rows.Scan(&v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[v] = struct{}{}
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
301
sqlstore/rbac.go
Normal file
301
sqlstore/rbac.go
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type roleStore struct{ storeBase }
|
||||
type permissionStore struct{ storeBase }
|
||||
|
||||
// ----- roleStore ------------------------------------------------------------
|
||||
|
||||
func (s *roleStore) CreateRole(ctx context.Context, r *authkit.Role) error {
|
||||
const op = "authkit.sqlstore.RoleStore.CreateRole"
|
||||
if r.ID == uuid.Nil {
|
||||
r.ID = uuid.New()
|
||||
}
|
||||
if r.CreatedAt.IsZero() {
|
||||
r.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if _, err := s.db.ExecContext(ctx, s.q.CreateRole,
|
||||
uuidArg(r.ID), r.Name, r.Description, r.CreatedAt); err != nil {
|
||||
if s.d.IsUniqueViolation(err) {
|
||||
return errx.Wrapf(op, err, "role %q already exists", r.Name)
|
||||
}
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *roleStore) GetRoleByID(ctx context.Context, id uuid.UUID) (*authkit.Role, error) {
|
||||
const op = "authkit.sqlstore.RoleStore.GetRoleByID"
|
||||
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByID, uuidArg(id)))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *roleStore) GetRoleByName(ctx context.Context, name string) (*authkit.Role, error) {
|
||||
const op = "authkit.sqlstore.RoleStore.GetRoleByName"
|
||||
r, err := scanRole(s.db.QueryRowContext(ctx, s.q.GetRoleByName, name))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrRoleNotFound))
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *roleStore) ListRoles(ctx context.Context) ([]*authkit.Role, error) {
|
||||
const op = "authkit.sqlstore.RoleStore.ListRoles"
|
||||
rows, err := s.db.QueryContext(ctx, s.q.ListRoles)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*authkit.Role
|
||||
for rows.Next() {
|
||||
r, err := scanRole(rows)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
out = append(out, r)
|
||||
}
|
||||
return out, errx.Wrap(op, rows.Err())
|
||||
}
|
||||
|
||||
func (s *roleStore) DeleteRole(ctx context.Context, id uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.RoleStore.DeleteRole"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.DeleteRole, uuidArg(id))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrRoleNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *roleStore) AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.RoleStore.AssignRoleToUser"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.AssignRoleToUser,
|
||||
uuidArg(userID), uuidArg(roleID), time.Now().UTC()); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *roleStore) RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.RoleStore.RemoveRoleFromUser"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.RemoveRoleFromUser,
|
||||
uuidArg(userID), uuidArg(roleID)); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *roleStore) GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*authkit.Role, error) {
|
||||
const op = "authkit.sqlstore.RoleStore.GetUserRoles"
|
||||
rows, err := s.db.QueryContext(ctx, s.q.GetUserRoles, uuidArg(userID))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*authkit.Role
|
||||
for rows.Next() {
|
||||
r, err := scanRole(rows)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
out = append(out, r)
|
||||
}
|
||||
return out, errx.Wrap(op, rows.Err())
|
||||
}
|
||||
|
||||
// HasAnyRole builds the IN-clause at call time because the placeholder count
|
||||
// depends on len(names). Identifier substitution comes from the validated
|
||||
// Schema; values are bound, never interpolated.
|
||||
func (s *roleStore) HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error) {
|
||||
const op = "authkit.sqlstore.RoleStore.HasAnyRole"
|
||||
if len(names) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
// Placeholder $1 is the user_id; $2..$N+1 cover the names slice.
|
||||
listSQL := s.d.PlaceholderList(2, len(names))
|
||||
q := fmt.Sprintf(`SELECT EXISTS (
|
||||
SELECT 1 FROM %s ur JOIN %s r ON r.id = ur.role_id
|
||||
WHERE ur.user_id = %s AND r.name IN (%s)
|
||||
)`, s.s.Tables.UserRoles, s.s.Tables.Roles, s.d.Placeholder(1), listSQL)
|
||||
|
||||
args := make([]any, 0, 1+len(names))
|
||||
args = append(args, uuidArg(userID))
|
||||
for _, n := range names {
|
||||
args = append(args, n)
|
||||
}
|
||||
var ok bool
|
||||
if err := s.db.QueryRowContext(ctx, q, args...).Scan(&ok); err != nil {
|
||||
return false, errx.Wrap(op, err)
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// ----- permissionStore ------------------------------------------------------
|
||||
|
||||
func (s *permissionStore) CreatePermission(ctx context.Context, p *authkit.Permission) error {
|
||||
const op = "authkit.sqlstore.PermissionStore.CreatePermission"
|
||||
if p.ID == uuid.Nil {
|
||||
p.ID = uuid.New()
|
||||
}
|
||||
if p.CreatedAt.IsZero() {
|
||||
p.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if _, err := s.db.ExecContext(ctx, s.q.CreatePermission,
|
||||
uuidArg(p.ID), p.Name, p.Description, p.CreatedAt); err != nil {
|
||||
if s.d.IsUniqueViolation(err) {
|
||||
return errx.Wrapf(op, err, "permission %q already exists", p.Name)
|
||||
}
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *permissionStore) GetPermissionByID(ctx context.Context, id uuid.UUID) (*authkit.Permission, error) {
|
||||
const op = "authkit.sqlstore.PermissionStore.GetPermissionByID"
|
||||
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByID, uuidArg(id)))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (s *permissionStore) GetPermissionByName(ctx context.Context, name string) (*authkit.Permission, error) {
|
||||
const op = "authkit.sqlstore.PermissionStore.GetPermissionByName"
|
||||
p, err := scanPermission(s.db.QueryRowContext(ctx, s.q.GetPermissionByName, name))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrPermissionNotFound))
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (s *permissionStore) ListPermissions(ctx context.Context) ([]*authkit.Permission, error) {
|
||||
const op = "authkit.sqlstore.PermissionStore.ListPermissions"
|
||||
rows, err := s.db.QueryContext(ctx, s.q.ListPermissions)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*authkit.Permission
|
||||
for rows.Next() {
|
||||
p, err := scanPermission(rows)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, errx.Wrap(op, rows.Err())
|
||||
}
|
||||
|
||||
func (s *permissionStore) DeletePermission(ctx context.Context, id uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.PermissionStore.DeletePermission"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.DeletePermission, uuidArg(id))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrPermissionNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *permissionStore) AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.PermissionStore.AssignPermissionToRole"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.AssignPermissionToRole,
|
||||
uuidArg(roleID), uuidArg(permID)); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *permissionStore) RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.PermissionStore.RemovePermissionFromRole"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.RemovePermissionFromRole,
|
||||
uuidArg(roleID), uuidArg(permID)); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *permissionStore) GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*authkit.Permission, error) {
|
||||
const op = "authkit.sqlstore.PermissionStore.GetRolePermissions"
|
||||
rows, err := s.db.QueryContext(ctx, s.q.GetRolePermissions, uuidArg(roleID))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*authkit.Permission
|
||||
for rows.Next() {
|
||||
p, err := scanPermission(rows)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, errx.Wrap(op, rows.Err())
|
||||
}
|
||||
|
||||
func (s *permissionStore) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*authkit.Permission, error) {
|
||||
const op = "authkit.sqlstore.PermissionStore.GetUserPermissions"
|
||||
rows, err := s.db.QueryContext(ctx, s.q.GetUserPermissions, uuidArg(userID))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*authkit.Permission
|
||||
for rows.Next() {
|
||||
p, err := scanPermission(rows)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, errx.Wrap(op, rows.Err())
|
||||
}
|
||||
|
||||
func scanRole(row rowScanner) (*authkit.Role, error) {
|
||||
var (
|
||||
r authkit.Role
|
||||
idStr string
|
||||
)
|
||||
if err := row.Scan(&idStr, &r.Name, &r.Description, &r.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := scanUUID(idStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.ID = id
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func scanPermission(row rowScanner) (*authkit.Permission, error) {
|
||||
var (
|
||||
p authkit.Permission
|
||||
idStr string
|
||||
)
|
||||
if err := row.Scan(&idStr, &p.Name, &p.Description, &p.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := scanUUID(idStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.ID = id
|
||||
return &p, nil
|
||||
}
|
||||
86
sqlstore/scan.go
Normal file
86
sqlstore/scan.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// rowScanner is the lowest-common-denominator interface satisfied by both
|
||||
// *sql.Row and *sql.Rows so scanXxx helpers serve QueryRow and Query loops
|
||||
// uniformly.
|
||||
type rowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
// nullableTime returns nil when t is the zero time, otherwise &t. Callers
|
||||
// should usually accept *time.Time on the model side and bind via this.
|
||||
func nullableTime(t *time.Time) any {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return *t
|
||||
}
|
||||
|
||||
// nullableString turns "" into nil so columns store NULL rather than an
|
||||
// empty string. Used for password hashes and similar optional text columns.
|
||||
func nullableString(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// nullableAddrString returns the string form of a netip.Addr when valid, or
|
||||
// nil to bind as SQL NULL. Pairs with scanAddr.
|
||||
func nullableAddrString(a netip.Addr) any {
|
||||
if !a.IsValid() {
|
||||
return nil
|
||||
}
|
||||
return a.String()
|
||||
}
|
||||
|
||||
// scanAddr parses a *string column into a netip.Addr. The zero Addr value is
|
||||
// returned when the column was NULL or empty.
|
||||
func scanAddr(s *string) (netip.Addr, error) {
|
||||
if s == nil || *s == "" {
|
||||
return netip.Addr{}, nil
|
||||
}
|
||||
return netip.ParseAddr(*s)
|
||||
}
|
||||
|
||||
// uuidArg returns the canonical string form of a UUID for binding. Every
|
||||
// store uses this rather than passing uuid.UUID directly to keep behaviour
|
||||
// identical across drivers (some accept driver.Valuer, some don't).
|
||||
func uuidArg(id uuid.UUID) any { return id.String() }
|
||||
|
||||
// scanUUID reads a string column and parses it back to a uuid.UUID.
|
||||
func scanUUID(s string) (uuid.UUID, error) { return uuid.Parse(s) }
|
||||
|
||||
// chainArg returns either a *string or nil for binding the chain_id column.
|
||||
func chainArg(c *string) any {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return *c
|
||||
}
|
||||
|
||||
// scanNullStringPtr converts sql.NullString to *string for the model.
|
||||
func scanNullStringPtr(ns sql.NullString) *string {
|
||||
if !ns.Valid {
|
||||
return nil
|
||||
}
|
||||
v := ns.String
|
||||
return &v
|
||||
}
|
||||
|
||||
// scanNullTimePtr converts sql.NullTime to *time.Time for the model.
|
||||
func scanNullTimePtr(nt sql.NullTime) *time.Time {
|
||||
if !nt.Valid {
|
||||
return nil
|
||||
}
|
||||
t := nt.Time
|
||||
return &t
|
||||
}
|
||||
78
sqlstore/schema.go
Normal file
78
sqlstore/schema.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// Schema lets consumers map authkit storage to their own table names.
|
||||
// Column overrides are intentionally not present in v1 — adding them later
|
||||
// is purely additive.
|
||||
type Schema struct {
|
||||
Tables Tables
|
||||
}
|
||||
|
||||
// Tables is the per-table identifier override set. Every field must be a
|
||||
// valid unquoted SQL identifier (matching identifierRE). Validation runs at
|
||||
// New() and Migrate() time so SQL injection through Schema is impossible
|
||||
// past that gate.
|
||||
type Tables struct {
|
||||
Users string
|
||||
Sessions string
|
||||
Tokens string
|
||||
APIKeys string
|
||||
Roles string
|
||||
Permissions string
|
||||
UserRoles string
|
||||
RolePermissions string
|
||||
SchemaMigrations string
|
||||
}
|
||||
|
||||
// DefaultSchema returns the stock authkit_* names used by the embedded
|
||||
// migration files.
|
||||
func DefaultSchema() Schema {
|
||||
return Schema{Tables: Tables{
|
||||
Users: "authkit_users",
|
||||
Sessions: "authkit_sessions",
|
||||
Tokens: "authkit_tokens",
|
||||
APIKeys: "authkit_api_keys",
|
||||
Roles: "authkit_roles",
|
||||
Permissions: "authkit_permissions",
|
||||
UserRoles: "authkit_user_roles",
|
||||
RolePermissions: "authkit_role_permissions",
|
||||
SchemaMigrations: "authkit_schema_migrations",
|
||||
}}
|
||||
}
|
||||
|
||||
// identifierRE matches the safe ASCII identifier subset shared by Postgres,
|
||||
// MySQL and SQLite when not quoted. Anything outside this set is rejected
|
||||
// rather than escaped — Schema is not the place to support exotic names.
|
||||
var identifierRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// Validate ensures every Schema.Tables field is a non-empty, safe identifier.
|
||||
func (s Schema) Validate() error {
|
||||
const op = "authkit.sqlstore.Schema.Validate"
|
||||
checks := []struct {
|
||||
field, value string
|
||||
}{
|
||||
{"Users", s.Tables.Users},
|
||||
{"Sessions", s.Tables.Sessions},
|
||||
{"Tokens", s.Tables.Tokens},
|
||||
{"APIKeys", s.Tables.APIKeys},
|
||||
{"Roles", s.Tables.Roles},
|
||||
{"Permissions", s.Tables.Permissions},
|
||||
{"UserRoles", s.Tables.UserRoles},
|
||||
{"RolePermissions", s.Tables.RolePermissions},
|
||||
{"SchemaMigrations", s.Tables.SchemaMigrations},
|
||||
}
|
||||
for _, c := range checks {
|
||||
if c.value == "" {
|
||||
return errx.Newf(op, "Schema.Tables.%s is empty", c.field)
|
||||
}
|
||||
if !identifierRE.MatchString(c.value) {
|
||||
return errx.Newf(op, "Schema.Tables.%s = %q is not a valid identifier", c.field, c.value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
98
sqlstore/sessions.go
Normal file
98
sqlstore/sessions.go
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type sessionStore struct{ storeBase }
|
||||
|
||||
func (s *sessionStore) CreateSession(ctx context.Context, ses *authkit.Session) error {
|
||||
const op = "authkit.sqlstore.SessionStore.CreateSession"
|
||||
now := time.Now().UTC()
|
||||
if ses.CreatedAt.IsZero() {
|
||||
ses.CreatedAt = now
|
||||
}
|
||||
if ses.LastSeenAt.IsZero() {
|
||||
ses.LastSeenAt = ses.CreatedAt
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx, s.q.CreateSession,
|
||||
ses.IDHash, uuidArg(ses.UserID), ses.UserAgent, nullableAddrString(ses.IP),
|
||||
ses.CreatedAt, ses.LastSeenAt, ses.ExpiresAt)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sessionStore) GetSession(ctx context.Context, idHash []byte) (*authkit.Session, error) {
|
||||
const op = "authkit.sqlstore.SessionStore.GetSession"
|
||||
var (
|
||||
ses authkit.Session
|
||||
uidStr string
|
||||
ipStr sql.NullString
|
||||
)
|
||||
err := s.db.QueryRowContext(ctx, s.q.GetSession, idHash).Scan(
|
||||
&ses.IDHash, &uidStr, &ses.UserAgent, &ipStr,
|
||||
&ses.CreatedAt, &ses.LastSeenAt, &ses.ExpiresAt)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrSessionInvalid))
|
||||
}
|
||||
uid, err := scanUUID(uidStr)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
ses.UserID = uid
|
||||
if ipStr.Valid {
|
||||
addr, err := scanAddr(&ipStr.String)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
ses.IP = addr
|
||||
}
|
||||
return &ses, nil
|
||||
}
|
||||
|
||||
func (s *sessionStore) TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error {
|
||||
const op = "authkit.sqlstore.SessionStore.TouchSession"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.TouchSession, lastSeenAt, newExpiresAt, idHash)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrSessionInvalid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sessionStore) DeleteSession(ctx context.Context, idHash []byte) error {
|
||||
const op = "authkit.sqlstore.SessionStore.DeleteSession"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.DeleteSession, idHash); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sessionStore) DeleteUserSessions(ctx context.Context, userID uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.SessionStore.DeleteUserSessions"
|
||||
if _, err := s.db.ExecContext(ctx, s.q.DeleteUserSessions, uuidArg(userID)); err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sessionStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
|
||||
const op = "authkit.sqlstore.SessionStore.DeleteExpired"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredSessions, now)
|
||||
if err != nil {
|
||||
return 0, errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
59
sqlstore/sqlstore.go
Normal file
59
sqlstore/sqlstore.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
// Package sqlstore provides database/sql-backed implementations of every
|
||||
// authkit store interface. It works with any driver (pgx-stdlib, lib/pq,
|
||||
// sqlx wrapping *sql.DB, ...) and any consumer-supplied table naming via
|
||||
// Schema. Driver-specific behaviour lives in a Dialect; v1 ships
|
||||
// dialect/postgres.
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
// Stores bundles every store implementation produced by New, ready to drop
|
||||
// into authkit.Deps.
|
||||
type Stores struct {
|
||||
Users authkit.UserStore
|
||||
Sessions authkit.SessionStore
|
||||
Tokens authkit.TokenStore
|
||||
APIKeys authkit.APIKeyStore
|
||||
Roles authkit.RoleStore
|
||||
Permissions authkit.PermissionStore
|
||||
}
|
||||
|
||||
// New validates the Schema, builds the dialect's templated queries, and
|
||||
// returns store implementations bound to db.
|
||||
func New(db *sql.DB, dialect Dialect, schema Schema) (*Stores, error) {
|
||||
const op = "authkit.sqlstore.New"
|
||||
if db == nil {
|
||||
return nil, errx.New(op, "db is required")
|
||||
}
|
||||
if dialect == nil {
|
||||
return nil, errx.New(op, "dialect is required")
|
||||
}
|
||||
if err := schema.Validate(); err != nil {
|
||||
return nil, errx.Wrap(op, err)
|
||||
}
|
||||
q := dialect.BuildQueries(schema)
|
||||
|
||||
base := storeBase{db: db, q: q, d: dialect, s: schema}
|
||||
return &Stores{
|
||||
Users: &userStore{storeBase: base},
|
||||
Sessions: &sessionStore{storeBase: base},
|
||||
Tokens: &tokenStore{storeBase: base},
|
||||
APIKeys: &apiKeyStore{storeBase: base},
|
||||
Roles: &roleStore{storeBase: base},
|
||||
Permissions: &permissionStore{storeBase: base},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// storeBase carries the shared dependencies every store needs. Embedded into
|
||||
// each concrete store struct.
|
||||
type storeBase struct {
|
||||
db *sql.DB
|
||||
q Queries
|
||||
d Dialect
|
||||
s Schema
|
||||
}
|
||||
258
sqlstore/sqlstore_test.go
Normal file
258
sqlstore/sqlstore_test.go
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
package sqlstore_test
|
||||
|
||||
// Integration tests against a real Postgres. Skipped unless
|
||||
// AUTHKIT_TEST_DATABASE_URL is set. Each test acquires a fresh schema by
|
||||
// running Migrate against a randomly-named set of tables — no cleanup
|
||||
// fixture, no external Docker dependency.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/authkit/hasher"
|
||||
"git.juancwu.dev/juancwu/authkit/sqlstore"
|
||||
pgdialect "git.juancwu.dev/juancwu/authkit/sqlstore/dialect/postgres"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
func envURL(t *testing.T) string {
|
||||
t.Helper()
|
||||
url := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
|
||||
if url == "" {
|
||||
t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
// freshDB opens a connection, runs Migrate against the default schema, and
|
||||
// schedules a teardown that drops every authkit_* table. Tests must run
|
||||
// sequentially against a single database (the package's tests do, by
|
||||
// default — go test serialises within a package unless t.Parallel is
|
||||
// called).
|
||||
func freshDB(t *testing.T) (*authkit.Auth, *sql.DB, sqlstore.Schema) {
|
||||
t.Helper()
|
||||
url := envURL(t)
|
||||
db, err := sql.Open("pgx", url)
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
if err := db.PingContext(context.Background()); err != nil {
|
||||
t.Fatalf("ping: %v", err)
|
||||
}
|
||||
|
||||
schema := sqlstore.DefaultSchema()
|
||||
// Drop first to start clean — previous failed tests may have left rows.
|
||||
dropAuthkitTables(t, db, schema)
|
||||
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
|
||||
t.Fatalf("Migrate: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
|
||||
|
||||
stores, err := sqlstore.New(db, pgdialect.New(), schema)
|
||||
if err != nil {
|
||||
t.Fatalf("sqlstore.New: %v", err)
|
||||
}
|
||||
auth := authkit.New(authkit.Deps{
|
||||
Users: stores.Users,
|
||||
Sessions: stores.Sessions,
|
||||
Tokens: stores.Tokens,
|
||||
APIKeys: stores.APIKeys,
|
||||
Roles: stores.Roles,
|
||||
Permissions: stores.Permissions,
|
||||
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
|
||||
}, authkit.Config{
|
||||
JWTSecret: []byte("integration-secret-thirty-two!!!!"),
|
||||
JWTIssuer: "authkit-int",
|
||||
AccessTokenTTL: 2 * time.Minute,
|
||||
RefreshTokenTTL: 1 * time.Hour,
|
||||
SessionIdleTTL: time.Hour,
|
||||
SessionAbsoluteTTL: 24 * time.Hour,
|
||||
EmailVerifyTTL: time.Hour,
|
||||
PasswordResetTTL: time.Hour,
|
||||
MagicLinkTTL: time.Minute,
|
||||
})
|
||||
return auth, db, schema
|
||||
}
|
||||
|
||||
func dropAuthkitTables(t *testing.T, db *sql.DB, s sqlstore.Schema) {
|
||||
t.Helper()
|
||||
tables := []string{
|
||||
s.Tables.UserRoles, s.Tables.RolePermissions,
|
||||
s.Tables.Roles, s.Tables.Permissions,
|
||||
s.Tables.APIKeys, s.Tables.Tokens,
|
||||
s.Tables.Sessions, s.Tables.Users,
|
||||
s.Tables.SchemaMigrations,
|
||||
}
|
||||
for _, name := range tables {
|
||||
_, _ = db.ExecContext(context.Background(),
|
||||
fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_MigrateIdempotent(t *testing.T) {
|
||||
url := envURL(t)
|
||||
db, err := sql.Open("pgx", url)
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
schema := sqlstore.DefaultSchema()
|
||||
t.Cleanup(func() { dropAuthkitTables(t, db, schema) })
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := sqlstore.Migrate(context.Background(), db, pgdialect.New(), schema); err != nil {
|
||||
t.Fatalf("Migrate iter %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RegisterAndLogin(t *testing.T) {
|
||||
auth, _, _ := freshDB(t)
|
||||
ctx := context.Background()
|
||||
u, err := auth.Register(ctx, "alice@example.com", "hunter2hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
got, err := auth.LoginPassword(ctx, "Alice@Example.com", "hunter2hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("LoginPassword: %v", err)
|
||||
}
|
||||
if got.ID != u.ID {
|
||||
t.Fatalf("login user id mismatch")
|
||||
}
|
||||
|
||||
if _, err := auth.Register(ctx, "alice@example.com", "x"); !errors.Is(err, authkit.ErrEmailTaken) {
|
||||
t.Fatalf("expected ErrEmailTaken, got %v", err)
|
||||
}
|
||||
if _, err := auth.LoginPassword(ctx, "alice@example.com", "wrong"); !errors.Is(err, authkit.ErrInvalidCredentials) {
|
||||
t.Fatalf("expected ErrInvalidCredentials, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_SessionLifecycle(t *testing.T) {
|
||||
auth, _, _ := freshDB(t)
|
||||
ctx := context.Background()
|
||||
u, err := auth.Register(ctx, "s@s.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
plain, sess, err := auth.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("IssueSession: %v", err)
|
||||
}
|
||||
if sess.ExpiresAt.Before(time.Now()) {
|
||||
t.Fatalf("session already expired at issue")
|
||||
}
|
||||
if _, err := auth.AuthenticateSession(ctx, plain); err != nil {
|
||||
t.Fatalf("AuthenticateSession: %v", err)
|
||||
}
|
||||
if err := auth.RevokeSession(ctx, plain); err != nil {
|
||||
t.Fatalf("RevokeSession: %v", err)
|
||||
}
|
||||
if _, err := auth.AuthenticateSession(ctx, plain); !errors.Is(err, authkit.ErrSessionInvalid) {
|
||||
t.Fatalf("expected ErrSessionInvalid post-revoke, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_JWTRefreshRotationAndReuse(t *testing.T) {
|
||||
auth, _, _ := freshDB(t)
|
||||
ctx := context.Background()
|
||||
u, err := auth.Register(ctx, "j@j.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
_, refresh1, err := auth.IssueJWT(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
_, refresh2, err := auth.RefreshJWT(ctx, refresh1)
|
||||
if err != nil {
|
||||
t.Fatalf("first RefreshJWT: %v", err)
|
||||
}
|
||||
if refresh1 == refresh2 {
|
||||
t.Fatalf("refresh token did not rotate")
|
||||
}
|
||||
|
||||
if _, _, err := auth.RefreshJWT(ctx, refresh1); !errors.Is(err, authkit.ErrTokenReused) {
|
||||
t.Fatalf("expected ErrTokenReused on replay, got %v", err)
|
||||
}
|
||||
if _, _, err := auth.RefreshJWT(ctx, refresh2); !errors.Is(err, authkit.ErrTokenInvalid) {
|
||||
t.Fatalf("expected ErrTokenInvalid after chain revocation, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_APIKeyWithAbilities(t *testing.T) {
|
||||
auth, _, _ := freshDB(t)
|
||||
ctx := context.Background()
|
||||
u, err := auth.Register(ctx, "k@k.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
plain, _, err := auth.IssueAPIKey(ctx, u.ID, "ci",
|
||||
[]string{"billing:read", "users:list"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueAPIKey: %v", err)
|
||||
}
|
||||
p, err := auth.AuthenticateAPIKey(ctx, plain)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthenticateAPIKey: %v", err)
|
||||
}
|
||||
if !p.HasAbility("billing:read") || !p.HasAbility("users:list") {
|
||||
t.Fatalf("abilities missing: %+v", p.Abilities)
|
||||
}
|
||||
if err := auth.RevokeAPIKey(ctx, plain); err != nil {
|
||||
t.Fatalf("RevokeAPIKey: %v", err)
|
||||
}
|
||||
if _, err := auth.AuthenticateAPIKey(ctx, plain); !errors.Is(err, authkit.ErrAPIKeyInvalid) {
|
||||
t.Fatalf("expected ErrAPIKeyInvalid post-revoke, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RBAC(t *testing.T) {
|
||||
auth, db, schema := freshDB(t)
|
||||
ctx := context.Background()
|
||||
u, err := auth.Register(ctx, "rb@b.com", "pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
|
||||
stores, _ := sqlstore.New(db, pgdialect.New(), schema)
|
||||
r := &authkit.Role{Name: "editor"}
|
||||
if err := stores.Roles.CreateRole(ctx, r); err != nil {
|
||||
t.Fatalf("CreateRole: %v", err)
|
||||
}
|
||||
p := &authkit.Permission{Name: "posts:write"}
|
||||
if err := stores.Permissions.CreatePermission(ctx, p); err != nil {
|
||||
t.Fatalf("CreatePermission: %v", err)
|
||||
}
|
||||
if err := stores.Permissions.AssignPermissionToRole(ctx, r.ID, p.ID); err != nil {
|
||||
t.Fatalf("AssignPermissionToRole: %v", err)
|
||||
}
|
||||
if err := auth.AssignRole(ctx, u.ID, "editor"); err != nil {
|
||||
t.Fatalf("AssignRole: %v", err)
|
||||
}
|
||||
ok, err := auth.HasPermission(ctx, u.ID, "posts:write")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("HasPermission: %v %v", ok, err)
|
||||
}
|
||||
ok, err = auth.HasAnyRole(ctx, u.ID, []string{"editor", "admin"})
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("HasAnyRole: %v %v", ok, err)
|
||||
}
|
||||
if err := auth.RemoveRole(ctx, u.ID, "editor"); err != nil {
|
||||
t.Fatalf("RemoveRole: %v", err)
|
||||
}
|
||||
ok, _ = auth.HasPermission(ctx, u.ID, "posts:write")
|
||||
if ok {
|
||||
t.Fatalf("HasPermission should be false after RemoveRole")
|
||||
}
|
||||
}
|
||||
92
sqlstore/tokens.go
Normal file
92
sqlstore/tokens.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
type tokenStore struct{ storeBase }
|
||||
|
||||
func (s *tokenStore) CreateToken(ctx context.Context, t *authkit.Token) error {
|
||||
const op = "authkit.sqlstore.TokenStore.CreateToken"
|
||||
if t.CreatedAt.IsZero() {
|
||||
t.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx, s.q.CreateToken,
|
||||
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
|
||||
nullableTime(t.ConsumedAt), t.CreatedAt, t.ExpiresAt)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConsumeToken does the find-and-mark-consumed in a single statement so two
|
||||
// concurrent callers cannot both successfully consume the same token. The
|
||||
// row is returned for inspection (e.g. ChainID for refresh rotation).
|
||||
func (s *tokenStore) ConsumeToken(ctx context.Context, kind authkit.TokenKind, hash []byte, now time.Time) (*authkit.Token, error) {
|
||||
const op = "authkit.sqlstore.TokenStore.ConsumeToken"
|
||||
row := s.db.QueryRowContext(ctx, s.q.ConsumeToken, now, string(kind), hash, now)
|
||||
t, err := scanToken(row)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (s *tokenStore) GetToken(ctx context.Context, kind authkit.TokenKind, hash []byte) (*authkit.Token, error) {
|
||||
const op = "authkit.sqlstore.TokenStore.GetToken"
|
||||
row := s.db.QueryRowContext(ctx, s.q.GetToken, string(kind), hash)
|
||||
t, err := scanToken(row)
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrTokenInvalid))
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (s *tokenStore) DeleteByChain(ctx context.Context, chainID string) (int64, error) {
|
||||
const op = "authkit.sqlstore.TokenStore.DeleteByChain"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.DeleteByChain, chainID)
|
||||
if err != nil {
|
||||
return 0, errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *tokenStore) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
|
||||
const op = "authkit.sqlstore.TokenStore.DeleteExpired"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.DeleteExpiredTokens, now)
|
||||
if err != nil {
|
||||
return 0, errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func scanToken(row rowScanner) (*authkit.Token, error) {
|
||||
var (
|
||||
t authkit.Token
|
||||
kind string
|
||||
userIDStr string
|
||||
chainID sql.NullString
|
||||
consumedAt sql.NullTime
|
||||
)
|
||||
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID,
|
||||
&consumedAt, &t.CreatedAt, &t.ExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Kind = authkit.TokenKind(kind)
|
||||
uid, err := scanUUID(userIDStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.UserID = uid
|
||||
t.ChainID = scanNullStringPtr(chainID)
|
||||
t.ConsumedAt = scanNullTimePtr(consumedAt)
|
||||
return &t, nil
|
||||
}
|
||||
186
sqlstore/users.go
Normal file
186
sqlstore/users.go
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type userStore struct{ storeBase }
|
||||
|
||||
func (s *userStore) CreateUser(ctx context.Context, u *authkit.User) error {
|
||||
const op = "authkit.sqlstore.UserStore.CreateUser"
|
||||
if u.ID == uuid.Nil {
|
||||
u.ID = uuid.New()
|
||||
}
|
||||
if u.EmailNormalized == "" {
|
||||
u.EmailNormalized = strings.ToLower(strings.TrimSpace(u.Email))
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if u.CreatedAt.IsZero() {
|
||||
u.CreatedAt = now
|
||||
}
|
||||
if u.UpdatedAt.IsZero() {
|
||||
u.UpdatedAt = now
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx, s.q.CreateUser,
|
||||
uuidArg(u.ID), u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
|
||||
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
|
||||
nullableTime(u.LastLoginAt), u.CreatedAt, u.UpdatedAt)
|
||||
if err != nil {
|
||||
if s.d.IsUniqueViolation(err) {
|
||||
return errx.Wrap(op, authkit.ErrEmailTaken)
|
||||
}
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userStore) GetUserByID(ctx context.Context, id uuid.UUID) (*authkit.User, error) {
|
||||
const op = "authkit.sqlstore.UserStore.GetUserByID"
|
||||
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByID, uuidArg(id)))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *userStore) GetUserByEmail(ctx context.Context, normalizedEmail string) (*authkit.User, error) {
|
||||
const op = "authkit.sqlstore.UserStore.GetUserByEmail"
|
||||
u, err := scanUser(s.db.QueryRowContext(ctx, s.q.GetUserByEmail, normalizedEmail))
|
||||
if err != nil {
|
||||
return nil, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *userStore) UpdateUser(ctx context.Context, u *authkit.User) error {
|
||||
const op = "authkit.sqlstore.UserStore.UpdateUser"
|
||||
u.UpdatedAt = time.Now().UTC()
|
||||
tag, err := s.db.ExecContext(ctx, s.q.UpdateUser,
|
||||
u.Email, u.EmailNormalized, nullableTime(u.EmailVerifiedAt),
|
||||
nullableString(u.PasswordHash), u.SessionVersion, u.FailedLogins,
|
||||
nullableTime(u.LastLoginAt), u.UpdatedAt, uuidArg(u.ID))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userStore) DeleteUser(ctx context.Context, id uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.UserStore.DeleteUser"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.DeleteUser, uuidArg(id))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userStore) SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error {
|
||||
const op = "authkit.sqlstore.UserStore.SetPassword"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.SetPassword,
|
||||
nullableString(encodedHash), time.Now().UTC(), uuidArg(userID))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userStore) SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error {
|
||||
const op = "authkit.sqlstore.UserStore.SetEmailVerified"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.SetEmailVerified, at, time.Now().UTC(), uuidArg(userID))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userStore) BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error) {
|
||||
const op = "authkit.sqlstore.UserStore.BumpSessionVersion"
|
||||
var v int
|
||||
if err := s.db.QueryRowContext(ctx, s.q.BumpSessionVersion,
|
||||
time.Now().UTC(), uuidArg(userID)).Scan(&v); err != nil {
|
||||
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *userStore) IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error) {
|
||||
const op = "authkit.sqlstore.UserStore.IncrementFailedLogins"
|
||||
var n int
|
||||
if err := s.db.QueryRowContext(ctx, s.q.IncrementFailedLogins,
|
||||
time.Now().UTC(), uuidArg(userID)).Scan(&n); err != nil {
|
||||
return 0, errx.Wrap(op, mapNotFound(err, authkit.ErrUserNotFound))
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *userStore) ResetFailedLogins(ctx context.Context, userID uuid.UUID) error {
|
||||
const op = "authkit.sqlstore.UserStore.ResetFailedLogins"
|
||||
tag, err := s.db.ExecContext(ctx, s.q.ResetFailedLogins, time.Now().UTC(), uuidArg(userID))
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
n, _ := tag.RowsAffected()
|
||||
if n == 0 {
|
||||
return errx.Wrap(op, authkit.ErrUserNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanUser(row rowScanner) (*authkit.User, error) {
|
||||
var (
|
||||
u authkit.User
|
||||
idStr string
|
||||
passwordHash sql.NullString
|
||||
emailVerified sql.NullTime
|
||||
lastLogin sql.NullTime
|
||||
)
|
||||
if err := row.Scan(&idStr, &u.Email, &u.EmailNormalized, &emailVerified,
|
||||
&passwordHash, &u.SessionVersion, &u.FailedLogins, &lastLogin,
|
||||
&u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := scanUUID(idStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.ID = id
|
||||
if passwordHash.Valid {
|
||||
u.PasswordHash = passwordHash.String
|
||||
}
|
||||
u.EmailVerifiedAt = scanNullTimePtr(emailVerified)
|
||||
u.LastLoginAt = scanNullTimePtr(lastLogin)
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// mapNotFound translates sql.ErrNoRows into the supplied authkit sentinel so
|
||||
// callers get reliable errors.Is targets through errx wrapping.
|
||||
func mapNotFound(err error, sentinel error) error {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return sentinel
|
||||
}
|
||||
return err
|
||||
}
|
||||
83
stores.go
Normal file
83
stores.go
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type UserStore interface {
|
||||
CreateUser(ctx context.Context, u *User) error
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (*User, error)
|
||||
GetUserByEmail(ctx context.Context, normalizedEmail string) (*User, error)
|
||||
UpdateUser(ctx context.Context, u *User) error
|
||||
DeleteUser(ctx context.Context, id uuid.UUID) error
|
||||
SetPassword(ctx context.Context, userID uuid.UUID, encodedHash string) error
|
||||
SetEmailVerified(ctx context.Context, userID uuid.UUID, at time.Time) error
|
||||
BumpSessionVersion(ctx context.Context, userID uuid.UUID) (int, error)
|
||||
IncrementFailedLogins(ctx context.Context, userID uuid.UUID) (int, error)
|
||||
ResetFailedLogins(ctx context.Context, userID uuid.UUID) error
|
||||
}
|
||||
|
||||
type SessionStore interface {
|
||||
CreateSession(ctx context.Context, s *Session) error
|
||||
GetSession(ctx context.Context, idHash []byte) (*Session, error)
|
||||
TouchSession(ctx context.Context, idHash []byte, lastSeenAt, newExpiresAt time.Time) error
|
||||
DeleteSession(ctx context.Context, idHash []byte) error
|
||||
DeleteUserSessions(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
|
||||
}
|
||||
|
||||
type TokenStore interface {
|
||||
CreateToken(ctx context.Context, t *Token) error
|
||||
// ConsumeToken atomically marks the matching unexpired, unconsumed token
|
||||
// as consumed and returns it. Returns ErrTokenInvalid if no row matched.
|
||||
// Implementations MUST do this in one statement (UPDATE ... RETURNING)
|
||||
// to prevent double-spend under concurrent callers.
|
||||
ConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error)
|
||||
// GetToken returns a token without consuming it. Used for refresh-token
|
||||
// reuse detection: a token that exists with consumed_at != nil is a
|
||||
// replay signal.
|
||||
GetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error)
|
||||
DeleteByChain(ctx context.Context, chainID string) (int64, error)
|
||||
DeleteExpired(ctx context.Context, now time.Time) (int64, error)
|
||||
}
|
||||
|
||||
type APIKeyStore interface {
|
||||
CreateAPIKey(ctx context.Context, k *APIKey) error
|
||||
GetAPIKey(ctx context.Context, idHash []byte) (*APIKey, error)
|
||||
ListAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID) ([]*APIKey, error)
|
||||
TouchAPIKey(ctx context.Context, idHash []byte, at time.Time) error
|
||||
RevokeAPIKey(ctx context.Context, idHash []byte, at time.Time) error
|
||||
RevokeAPIKeysByOwner(ctx context.Context, ownerID uuid.UUID, at time.Time) error
|
||||
}
|
||||
|
||||
type RoleStore interface {
|
||||
CreateRole(ctx context.Context, r *Role) error
|
||||
GetRoleByID(ctx context.Context, id uuid.UUID) (*Role, error)
|
||||
GetRoleByName(ctx context.Context, name string) (*Role, error)
|
||||
ListRoles(ctx context.Context) ([]*Role, error)
|
||||
DeleteRole(ctx context.Context, id uuid.UUID) error
|
||||
AssignRoleToUser(ctx context.Context, userID, roleID uuid.UUID) error
|
||||
RemoveRoleFromUser(ctx context.Context, userID, roleID uuid.UUID) error
|
||||
GetUserRoles(ctx context.Context, userID uuid.UUID) ([]*Role, error)
|
||||
HasAnyRole(ctx context.Context, userID uuid.UUID, names []string) (bool, error)
|
||||
}
|
||||
|
||||
type PermissionStore interface {
|
||||
CreatePermission(ctx context.Context, p *Permission) error
|
||||
GetPermissionByID(ctx context.Context, id uuid.UUID) (*Permission, error)
|
||||
GetPermissionByName(ctx context.Context, name string) (*Permission, error)
|
||||
ListPermissions(ctx context.Context) ([]*Permission, error)
|
||||
DeletePermission(ctx context.Context, id uuid.UUID) error
|
||||
AssignPermissionToRole(ctx context.Context, roleID, permID uuid.UUID) error
|
||||
RemovePermissionFromRole(ctx context.Context, roleID, permID uuid.UUID) error
|
||||
GetRolePermissions(ctx context.Context, roleID uuid.UUID) ([]*Permission, error)
|
||||
GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]*Permission, error)
|
||||
}
|
||||
|
||||
type Hasher interface {
|
||||
Hash(password string) (string, error)
|
||||
Verify(password, encoded string) (ok bool, needsRehash bool, err error)
|
||||
}
|
||||
67
tokens.go
Normal file
67
tokens.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
const secretRandomBytes = 32
|
||||
|
||||
// Secret prefixes. The leading token namespaces the secret so callers can
|
||||
// route them to the right verifier without trial-and-error.
|
||||
const (
|
||||
prefixSession = "sess"
|
||||
prefixRefresh = "rfr"
|
||||
prefixAPIKey = "ak"
|
||||
prefixEmailVerify = "evr"
|
||||
prefixPasswordRset = "pwr"
|
||||
prefixMagicLink = "mlnk"
|
||||
)
|
||||
|
||||
// mintSecret generates a new opaque secret of the given prefix kind and
|
||||
// returns the plaintext (to be returned to the user, never stored) and the
|
||||
// SHA-256 lookup hash (to be stored).
|
||||
func mintSecret(prefix string, rng io.Reader) (plaintext string, hash []byte, err error) {
|
||||
const op = "authkit.mintSecret"
|
||||
if rng == nil {
|
||||
rng = rand.Reader
|
||||
}
|
||||
buf := make([]byte, secretRandomBytes)
|
||||
if _, err := io.ReadFull(rng, buf); err != nil {
|
||||
return "", nil, errx.Wrap(op, err)
|
||||
}
|
||||
body := base64.RawURLEncoding.EncodeToString(buf)
|
||||
plaintext = prefix + "_" + body
|
||||
hash = hashSecret(plaintext)
|
||||
return plaintext, hash, nil
|
||||
}
|
||||
|
||||
// hashSecret returns sha256(plaintext) — the lookup key for opaque secrets.
|
||||
// Plaintexts have full entropy from a CSPRNG so a plain hash is sufficient
|
||||
// (no per-record salt needed; the random body is the salt).
|
||||
func hashSecret(plaintext string) []byte {
|
||||
sum := sha256.Sum256([]byte(plaintext))
|
||||
return sum[:]
|
||||
}
|
||||
|
||||
// parseSecret validates that a plaintext starts with the expected prefix
|
||||
// and returns the lookup hash. The prefix check is constant-time relative
|
||||
// to a fixed-length comparison.
|
||||
func parseSecret(prefix, plaintext string) (hash []byte, ok bool) {
|
||||
want := prefix + "_"
|
||||
if !strings.HasPrefix(plaintext, want) {
|
||||
return nil, false
|
||||
}
|
||||
return hashSecret(plaintext), true
|
||||
}
|
||||
|
||||
// constantTimeEqual is a thin wrapper for readability at call sites.
|
||||
func constantTimeEqual(a, b []byte) bool {
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
56
tokens_test.go
Normal file
56
tokens_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package authkit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMintSecretRoundtrip(t *testing.T) {
|
||||
plaintext, hash, err := mintSecret(prefixSession, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("mintSecret: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(plaintext, prefixSession+"_") {
|
||||
t.Fatalf("missing prefix: %q", plaintext)
|
||||
}
|
||||
parsed, ok := parseSecret(prefixSession, plaintext)
|
||||
if !ok {
|
||||
t.Fatalf("parseSecret rejected our own mint")
|
||||
}
|
||||
if !bytes.Equal(hash, parsed) {
|
||||
t.Fatalf("hash mismatch")
|
||||
}
|
||||
want := sha256.Sum256([]byte(plaintext))
|
||||
if !bytes.Equal(hash, want[:]) {
|
||||
t.Fatalf("hashSecret != sha256(plaintext)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSecretWrongPrefix(t *testing.T) {
|
||||
plaintext, _, err := mintSecret(prefixSession, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("mintSecret: %v", err)
|
||||
}
|
||||
if _, ok := parseSecret(prefixAPIKey, plaintext); ok {
|
||||
t.Fatalf("parseSecret should reject mismatched prefix")
|
||||
}
|
||||
if _, ok := parseSecret(prefixSession, "sessXXXX"); ok {
|
||||
t.Fatalf("parseSecret should require trailing underscore")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMintSecretUniqueness(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
p, _, err := mintSecret(prefixAPIKey, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("mintSecret: %v", err)
|
||||
}
|
||||
if _, dup := seen[p]; dup {
|
||||
t.Fatalf("duplicate mint: %s", p)
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue