authkit/middleware/middleware_test.go
juancwu ca5525d4bd Cap refresh chain lifetime via RefreshChainAbsoluteTTL
Sessions had an absolute cap (created_at + SessionAbsoluteTTL) but the
JWT path only had per-token TTL on the refresh row, letting a
well-behaved client refresh indefinitely. Add chain_started_at to
authkit_tokens, copy it forward on every rotation, and reject in
RefreshJWT when now > chainStartedAt + RefreshChainAbsoluteTTL.
Default 30d, mirroring SessionAbsoluteTTL.

Schema, verifier, queries, model, and integration test updated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-26 23:41:02 +00:00

359 lines
10 KiB
Go

package middleware_test
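// Integration tests for the middleware package. They are skipped when
// AUTHKIT_TEST_DATABASE_URL is not set. When it is set, each test drops every
// authkit_* table in that database before and after it runs, so point the
// variable at a disposable Postgres database.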
import (
"context"
"database/sql"
"fmt"
"net/http"
"net/http/httptest"
"net/netip"
"os"
"testing"
"time"
"git.juancwu.dev/juancwu/authkit"
"git.juancwu.dev/juancwu/authkit/hasher"
"git.juancwu.dev/juancwu/authkit/middleware"
_ "github.com/jackc/pgx/v5/stdlib"
)
// freshAuth returns an *authkit.Auth backed by the Postgres database named in
// AUTHKIT_TEST_DATABASE_URL, and skips the calling test when that variable is
// unset.
//
// All authkit tables in the database are dropped before authkit.New runs, so
// New presumably rebuilds the schema from scratch. They are dropped again when
// the test finishes, so use a disposable database. Cleanups run LIFO, so the
// tables are dropped before the connection pool is closed.
//
// NOTE(review): the tables are shared state. Other test packages pointed at
// the same database should not run concurrently with this one (e.g. use
// `go test -p 1 ./...`). Confirm against the other integration suites.
func freshAuth(t *testing.T) *authkit.Auth {
	t.Helper()
	dsn := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
	if dsn == "" {
		t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
	}
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		t.Fatalf("sql.Open: %v", err)
	}
	t.Cleanup(func() { _ = db.Close() })

	// sql.Open does not connect. Ping with a deadline so an unreachable
	// database fails fast with a clear message instead of stalling the test
	// binary until its global timeout.
	pingCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := db.PingContext(pingCtx); err != nil {
		t.Fatalf("ping: %v", err)
	}

	dropAuthkitTables(t, db)
	t.Cleanup(func() { dropAuthkitTables(t, db) })

	a, err := authkit.New(context.Background(), authkit.Deps{
		DB:     db,
		Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
	}, authkit.Config{
		JWTSecret:               []byte("integration-secret-thirty-two!!!"),
		JWTIssuer:               "authkit-mw-int",
		AccessTokenTTL:          2 * time.Minute,
		RefreshTokenTTL:         time.Hour,
		RefreshChainAbsoluteTTL: 24 * time.Hour,
		SessionIdleTTL:          time.Hour,
		SessionAbsoluteTTL:      24 * time.Hour,
		EmailVerifyTTL:          time.Hour,
		PasswordResetTTL:        time.Hour,
		MagicLinkTTL:            time.Minute,
		EmailOTPTTL:             time.Minute,
		EmailOTPMaxAttempts:     3,
		// Plain HTTP for tests so secure-cookie defaults don't interfere
		// with httptest's HTTP server.
		SessionCookieSecure: authkit.BoolPtr(false),
	})
	if err != nil {
		t.Fatalf("authkit.New: %v", err)
	}
	return a
}
// dropAuthkitTables issues `DROP TABLE IF EXISTS <name> CASCADE` for every
// authkit table, so the next authkit.New sees an empty schema.
//
// Join tables are listed before the tables they appear to reference, and
// CASCADE covers any remaining dependencies. The drop is best-effort: it also
// runs from t.Cleanup, where a failure should not fail an otherwise passing
// test. Errors are therefore logged rather than fatal, which keeps a leftover
// schema diagnosable.
func dropAuthkitTables(t *testing.T, db *sql.DB) {
	t.Helper()
	tables := []string{
		"authkit_service_key_abilities",
		"authkit_user_permissions",
		"authkit_user_roles",
		"authkit_role_permissions",
		"authkit_service_keys",
		"authkit_abilities",
		"authkit_roles",
		"authkit_permissions",
		"authkit_tokens",
		"authkit_sessions",
		"authkit_users",
		"authkit_schema_migrations",
	}
	ctx := context.Background()
	for _, name := range tables {
		// Identifiers cannot be bound as query parameters. Sprintf is safe
		// here because every name is a fixed literal from the list above.
		stmt := fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name)
		if _, err := db.ExecContext(ctx, stmt); err != nil {
			t.Logf("dropAuthkitTables: %s: %v", name, err)
		}
	}
}
// reqWithBearer builds a test GET request for "/". A non-empty token is sent as
// "Authorization: Bearer <token>". An empty token produces a request with no
// Authorization header at all, i.e. an anonymous caller.
func reqWithBearer(token string) *http.Request {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	if token == "" {
		return req
	}
	req.Header.Set("Authorization", "Bearer "+token)
	return req
}
// ok200 is a terminal handler that ignores the request and replies with an
// empty 200 OK. Tests wrap it to observe whether a middleware let the request
// through.
func ok200(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
}
// ─── RequireLogin ──────────────────────────────────────────────────────────
// TestRequireLogin_AcceptsSessionCookie checks that a valid session cookie
// gets through RequireLogin and that the session's user ID is placed on the
// request context.
func TestRequireLogin_AcceptsSessionCookie(t *testing.T) {
	auth := freshAuth(t)
	ctx := context.Background()

	user, err := auth.CreateUser(ctx, "alice@example.com")
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}
	loopback := netip.MustParseAddr("127.0.0.1")
	sessionToken, _, err := auth.IssueSession(ctx, user.ID, "ua", loopback)
	if err != nil {
		t.Fatalf("IssueSession: %v", err)
	}

	// The inner handler verifies the context that the middleware produced.
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotID, found := authkit.UserIDFromCtx(r.Context())
		if !found || gotID != user.ID {
			t.Fatalf("user_id missing or wrong on context: ok=%v id=%v", found, gotID)
		}
		w.WriteHeader(http.StatusOK)
	})
	guarded := middleware.RequireLogin(middleware.LoginOptions{Auth: auth})(inner)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.AddCookie(auth.SessionCookie(sessionToken, time.Now().Add(time.Hour)))
	rec := httptest.NewRecorder()
	guarded.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected 200, got %d", got)
	}
}
// TestRequireLogin_AcceptsJWT checks that an access token presented as a
// Bearer header gets through RequireLogin. Like the session-cookie test, it
// also verifies that the token's user ID is placed on the request context, so
// the JWT path is held to the same contract as the cookie path.
func TestRequireLogin_AcceptsJWT(t *testing.T) {
	a := freshAuth(t)
	ctx := context.Background()
	u, err := a.CreateUser(ctx, "j@j.com")
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}
	access, _, err := a.IssueJWT(ctx, u.ID)
	if err != nil {
		t.Fatalf("IssueJWT: %v", err)
	}
	handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		uid, ok := authkit.UserIDFromCtx(r.Context())
		if !ok || uid != u.ID {
			t.Fatalf("user_id missing or wrong on context: ok=%v id=%v", ok, uid)
		}
		w.WriteHeader(http.StatusOK)
	}))
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, reqWithBearer(access))
	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rr.Code)
	}
}
// TestRequireLogin_RejectsUnauthenticated checks that a request carrying
// neither a session cookie nor a Bearer token gets 401 and never reaches the
// wrapped handler. The second check matches TestRequireGuest_BlocksAuthenticated:
// a status code alone would not reveal a middleware that writes 401 and then
// still calls next.
func TestRequireLogin_RejectsUnauthenticated(t *testing.T) {
	a := freshAuth(t)
	handlerCalled := false
	handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
	if rr.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401, got %d", rr.Code)
	}
	if handlerCalled {
		t.Fatalf("handler should not run for unauthenticated request")
	}
}
// TestRequireLogin_AuthzRoleGate exercises the Authz hook of RequireLogin.
// A user without the "admin" role gets 403. After the role is assigned and a
// fresh access token is issued, the same user gets 200.
func TestRequireLogin_AuthzRoleGate(t *testing.T) {
	auth := freshAuth(t)
	ctx := context.Background()

	// RequireLogin panics on unknown role slugs (see the PanicsOnUnknownSlug
	// test), so the role must exist before the middleware is built.
	if _, err := auth.CreateRole(ctx, "admin", ""); err != nil {
		t.Fatalf("CreateRole: %v", err)
	}
	user, err := auth.CreateUser(ctx, "noadmin@example.com")
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}
	firstToken, _, err := auth.IssueJWT(ctx, user.ID)
	if err != nil {
		t.Fatalf("IssueJWT: %v", err)
	}

	opts := middleware.LoginOptions{Auth: auth, Authz: authkit.HasRole("admin")}
	gated := middleware.RequireLogin(opts)(http.HandlerFunc(ok200))

	before := httptest.NewRecorder()
	gated.ServeHTTP(before, reqWithBearer(firstToken))
	if before.Code != http.StatusForbidden {
		t.Fatalf("non-admin should get 403, got %d", before.Code)
	}

	// Promote the user, then issue a new token for the retry.
	if err := auth.AssignRole(ctx, user.ID, "admin"); err != nil {
		t.Fatalf("AssignRole: %v", err)
	}
	secondToken, _, err := auth.IssueJWT(ctx, user.ID)
	if err != nil {
		t.Fatalf("IssueJWT: %v", err)
	}

	after := httptest.NewRecorder()
	gated.ServeHTTP(after, reqWithBearer(secondToken))
	if after.Code != http.StatusOK {
		t.Fatalf("admin should get 200, got %d", after.Code)
	}
}
// TestRequireLogin_PanicsOnUnknownSlug checks that building RequireLogin with
// an Authz rule that names an unregistered role panics at construction time
// rather than failing later on a request.
func TestRequireLogin_PanicsOnUnknownSlug(t *testing.T) {
	auth := freshAuth(t)
	defer func() {
		if recover() != nil {
			return
		}
		t.Fatalf("expected panic on unknown role slug")
	}()
	opts := middleware.LoginOptions{
		Auth:  auth,
		Authz: authkit.HasRole("never-registered"),
	}
	_ = middleware.RequireLogin(opts)
}
// ─── RequireGuest ──────────────────────────────────────────────────────────
// TestRequireGuest_LetsUnauthenticatedThrough checks that RequireGuest hands
// an anonymous request to the next handler unchanged.
func TestRequireGuest_LetsUnauthenticatedThrough(t *testing.T) {
	auth := freshAuth(t)

	reached := false
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	guestOnly := middleware.RequireGuest(middleware.GuestOptions{Auth: auth})(next)

	rec := httptest.NewRecorder()
	guestOnly.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	switch {
	case !reached:
		t.Fatalf("guest middleware should pass through unauthenticated request")
	case rec.Code != http.StatusOK:
		t.Fatalf("expected 200, got %d", rec.Code)
	}
}
// TestRequireGuest_BlocksAuthenticated checks that, with no OnAuthenticated
// override, RequireGuest answers a request carrying a valid access token with
// 403 and does not run the wrapped handler.
func TestRequireGuest_BlocksAuthenticated(t *testing.T) {
	auth := freshAuth(t)
	ctx := context.Background()

	user, err := auth.CreateUser(ctx, "g@g.com")
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}
	token, _, err := auth.IssueJWT(ctx, user.ID)
	if err != nil {
		t.Fatalf("IssueJWT: %v", err)
	}

	ran := false
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { ran = true })
	guestOnly := middleware.RequireGuest(middleware.GuestOptions{Auth: auth})(next)

	rec := httptest.NewRecorder()
	guestOnly.ServeHTTP(rec, reqWithBearer(token))

	if got := rec.Code; got != http.StatusForbidden {
		t.Fatalf("expected 403, got %d", got)
	}
	if ran {
		t.Fatalf("handler should not run for authenticated request")
	}
}
// TestRequireGuest_CustomOnAuthenticated checks that an OnAuthenticated
// callback replaces the default rejection. Here it redirects an already
// logged-in caller to /dashboard with 303 See Other.
func TestRequireGuest_CustomOnAuthenticated(t *testing.T) {
	auth := freshAuth(t)
	ctx := context.Background()

	user, err := auth.CreateUser(ctx, "custom@example.com")
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}
	token, _, err := auth.IssueJWT(ctx, user.ID)
	if err != nil {
		t.Fatalf("IssueJWT: %v", err)
	}

	toDashboard := func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
	}
	opts := middleware.GuestOptions{Auth: auth, OnAuthenticated: toDashboard}
	guestOnly := middleware.RequireGuest(opts)(http.HandlerFunc(ok200))

	rec := httptest.NewRecorder()
	guestOnly.ServeHTTP(rec, reqWithBearer(token))

	if rec.Code != http.StatusSeeOther {
		t.Fatalf("expected 303, got %d", rec.Code)
	}
	location := rec.Header().Get("Location")
	if location != "/dashboard" {
		t.Fatalf("expected Location=/dashboard, got %q", location)
	}
}
// ─── RequireServiceKey ─────────────────────────────────────────────────────
// TestRequireServiceKey_AbilityGate covers the happy path of RequireServiceKey.
// A key issued with "events:write" passes a HasAbility("events:write") gate,
// and the authenticated key is available on the request context.
func TestRequireServiceKey_AbilityGate(t *testing.T) {
	auth := freshAuth(t)
	ctx := context.Background()

	// The gate panics on unregistered abilities (see the
	// PanicsOnUnknownAbility test), so register this one first.
	if _, err := auth.CreateAbility(ctx, "events:write", ""); err != nil {
		t.Fatalf("CreateAbility: %v", err)
	}
	params := authkit.IssueServiceKeyParams{
		Name:      "ci",
		Abilities: []string{"events:write"},
	}
	keyToken, _, err := auth.IssueServiceKey(ctx, params)
	if err != nil {
		t.Fatalf("IssueServiceKey: %v", err)
	}

	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		key, found := authkit.ServiceKeyFromCtx(r.Context())
		if !found || !key.HasAbility("events:write") {
			t.Fatalf("expected ServiceKey with events:write on context")
		}
		w.WriteHeader(http.StatusOK)
	})
	gated := middleware.RequireServiceKey(middleware.ServiceKeyOptions{
		Auth:  auth,
		Authz: authkit.HasAbility("events:write"),
	})(inner)

	rec := httptest.NewRecorder()
	gated.ServeHTTP(rec, reqWithBearer(keyToken))
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected 200, got %d", got)
	}
}
// TestRequireServiceKey_AbilityGateRejectsMissing checks that a service key
// lacking the ability required by the gate ("admin:nuke") gets 403, even
// though the key itself is valid.
func TestRequireServiceKey_AbilityGateRejectsMissing(t *testing.T) {
	auth := freshAuth(t)
	ctx := context.Background()

	// Register both abilities: the one the key holds and the one the gate
	// demands.
	for _, slug := range []string{"events:write", "admin:nuke"} {
		if _, err := auth.CreateAbility(ctx, slug, ""); err != nil {
			t.Fatalf("CreateAbility %s: %v", slug, err)
		}
	}

	params := authkit.IssueServiceKeyParams{
		Name:      "ci",
		Abilities: []string{"events:write"},
	}
	keyToken, _, err := auth.IssueServiceKey(ctx, params)
	if err != nil {
		t.Fatalf("IssueServiceKey: %v", err)
	}

	gated := middleware.RequireServiceKey(middleware.ServiceKeyOptions{
		Auth:  auth,
		Authz: authkit.HasAbility("admin:nuke"),
	})(http.HandlerFunc(ok200))

	rec := httptest.NewRecorder()
	gated.ServeHTTP(rec, reqWithBearer(keyToken))
	if got := rec.Code; got != http.StatusForbidden {
		t.Fatalf("expected 403, got %d", got)
	}
}
// TestRequireServiceKey_PanicsOnUnknownAbility checks that building
// RequireServiceKey with an Authz rule that names an unregistered ability
// panics at construction time.
func TestRequireServiceKey_PanicsOnUnknownAbility(t *testing.T) {
	auth := freshAuth(t)
	defer func() {
		if recover() != nil {
			return
		}
		t.Fatalf("expected panic on unknown ability slug")
	}()
	opts := middleware.ServiceKeyOptions{
		Auth:  auth,
		Authz: authkit.HasAbility("never-registered"),
	}
	_ = middleware.RequireServiceKey(opts)
}