Cap refresh chain lifetime via RefreshChainAbsoluteTTL

Sessions had an absolute cap (created_at + SessionAbsoluteTTL) but the
JWT path only had per-token TTL on the refresh row, letting a
well-behaved client refresh indefinitely. Add chain_started_at to
authkit_tokens, copy it forward on every rotation, and reject in
RefreshJWT when now > chainStartedAt + RefreshChainAbsoluteTTL.
Default 30d, mirroring SessionAbsoluteTTL.

Schema, verifier, queries, model, and integration test updated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
juancwu 2026-04-26 23:41:02 +00:00
commit ca5525d4bd
11 changed files with 129 additions and 53 deletions

View file

@ -257,6 +257,7 @@ custom names back to defaults without manual intervention.
| `SessionCookieSameSite` | `Lax` | |
| `JWTSecret` | — (required) | HS256 key |
| `AccessTokenTTL` / `RefreshTokenTTL` | 15m / 30d | |
| `RefreshChainAbsoluteTTL` | 30d | Hard cap from chain start. Refresh fails past this even if the per-token TTL hasn't elapsed; user must re-authenticate. Mirrors `SessionAbsoluteTTL`. |
| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | |
| `EmailOTPTTL` / `EmailOTPDigits` / `EmailOTPMaxAttempts` | 10m / 6 / 5 | |
| `RevealUnknownEmail` | `false` | Default anti-enumeration: silent success on unknown email |

View file

@ -66,6 +66,12 @@ type Config struct {
JWTAudience string
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
// RefreshChainAbsoluteTTL caps the maximum lifetime of a refresh chain.
// A user can refresh as often as they want within RefreshTokenTTL of the
// last rotation, but the chain itself dies once now > chainStartedAt +
// RefreshChainAbsoluteTTL — at which point the user must re-authenticate.
// Mirrors SessionAbsoluteTTL on the session path.
RefreshChainAbsoluteTTL time.Duration
// Single-use tokens
EmailVerifyTTL time.Duration
@ -173,6 +179,9 @@ func applyDefaults(cfg Config) Config {
if cfg.RefreshTokenTTL == 0 {
cfg.RefreshTokenTTL = 30 * 24 * time.Hour
}
if cfg.RefreshChainAbsoluteTTL == 0 {
cfg.RefreshChainAbsoluteTTL = 30 * 24 * time.Hour
}
if cfg.EmailVerifyTTL == 0 {
cfg.EmailVerifyTTL = 48 * time.Hour
}

View file

@ -46,6 +46,7 @@ func freshAuth(t *testing.T) *authkit.Auth {
JWTIssuer: "authkit-mw-int",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: time.Hour,
RefreshChainAbsoluteTTL: 24 * time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,

View file

@ -48,6 +48,10 @@ CREATE TABLE IF NOT EXISTS authkit_tokens (
kind TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
chain_id TEXT,
-- chain_started_at is the timestamp of the first refresh in a chain.
-- Copied forward on every rotation so the absolute-cap check in
-- RefreshJWT is O(1). Non-null only for refresh-token rows.
chain_started_at TIMESTAMPTZ,
consumed_at TIMESTAMPTZ,
attempts_remaining INTEGER,
created_at TIMESTAMPTZ NOT NULL,

View file

@ -46,12 +46,15 @@ const (
// Token is one row in authkit_tokens. AttemptsRemaining is non-nil only for
// tokens that allow retry on incorrect input (email OTPs); other kinds are
// strictly one-shot via ConsumeToken.
// strictly one-shot via ConsumeToken. ChainStartedAt is non-nil only for
// refresh-token rows; copied forward on every rotation so the absolute-cap
// check in RefreshJWT is O(1).
type Token struct {
Hash []byte
Kind TokenKind
UserID uuid.UUID
ChainID *string
ChainStartedAt *time.Time
ConsumedAt *time.Time
AttemptsRemaining *int
CreatedAt time.Time

View file

@ -3,6 +3,7 @@ package authkit
import (
"context"
"errors"
"time"
"git.juancwu.dev/juancwu/errx"
"github.com/google/uuid"
@ -11,6 +12,8 @@ import (
// IssueJWT issues a fresh access JWT and a rotating opaque refresh token.
// The refresh token is bound to a chain via Token.ChainID; rotation
// preserves that chain so reuse-detection can revoke the whole family.
// chainStartedAt is stamped on this row and copied forward on every
// rotation so RefreshJWT can enforce RefreshChainAbsoluteTTL in O(1).
func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) {
const op = "authkit.Auth.IssueJWT"
u, err := a.storeGetUserByID(ctx, userID)
@ -21,7 +24,7 @@ func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh
if err != nil {
return "", "", errx.Wrap(op, err)
}
refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString())
refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString(), a.now())
if err != nil {
return "", "", errx.Wrap(op, err)
}
@ -67,7 +70,9 @@ func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal,
// RefreshJWT consumes the presented refresh token and mints a new
// access+refresh pair. Reuse of an already-consumed refresh token deletes
// the entire chain (logout-everywhere on that device family) and returns
// ErrTokenReused.
// ErrTokenReused. The chain itself is capped at RefreshChainAbsoluteTTL
// from chain_started_at — past that, refresh fails with ErrTokenInvalid
// and the chain is deleted, forcing the user to re-authenticate.
func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) {
const op = "authkit.Auth.RefreshJWT"
hash, ok := ParseOpaqueSecret(prefixRefresh, plaintextRefresh)
@ -92,6 +97,20 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access,
return "", "", errx.Wrap(op, err)
}
// Enforce the absolute chain cap. If the chain is older than the
// configured ceiling, kill the whole chain rather than minting a
// successor — re-authentication is required.
chainStartedAt := now
if consumed.ChainStartedAt != nil {
chainStartedAt = *consumed.ChainStartedAt
}
if now.After(chainStartedAt.Add(a.cfg.RefreshChainAbsoluteTTL)) {
if consumed.ChainID != nil && *consumed.ChainID != "" {
_, _ = a.storeDeleteByChain(ctx, *consumed.ChainID)
}
return "", "", errx.Wrap(op, ErrTokenInvalid)
}
var chainID string
if consumed.ChainID != nil {
chainID = *consumed.ChainID
@ -100,20 +119,21 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access,
// Defensive: every refresh token should be chain-bound. Fall back
// to a fresh chain rather than throwing on missing metadata.
chainID = uuid.NewString()
chainStartedAt = now
}
access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID))
if err != nil {
return "", "", errx.Wrap(op, err)
}
refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID)
refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID, chainStartedAt)
if err != nil {
return "", "", errx.Wrap(op, err)
}
return access, refresh, nil
}
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) {
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string, chainStartedAt time.Time) (string, error) {
const op = "authkit.Auth.mintRefreshToken"
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixRefresh)
if err != nil {
@ -125,6 +145,7 @@ func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID s
Kind: TokenRefresh,
UserID: userID,
ChainID: &chainID,
ChainStartedAt: &chainStartedAt,
CreatedAt: now,
ExpiresAt: now.Add(a.cfg.RefreshTokenTTL),
}

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
"time"
)
func TestIntegration_JWTIssueAuthenticate(t *testing.T) {
@ -65,3 +66,35 @@ func TestIntegration_JWTInvalidPrefix(t *testing.T) {
t.Fatalf("expected ErrTokenInvalid for malformed input, got %v", err)
}
}
func TestIntegration_JWTRefreshChainAbsoluteCap(t *testing.T) {
a := freshAuth(t)
ctx := context.Background()
u, err := a.CreateUser(ctx, "cap@example.com")
if err != nil {
t.Fatalf("CreateUser: %v", err)
}
// freshAuth sets RefreshChainAbsoluteTTL = 24h. Pin clock so the issue
// time is fixed, then advance past the cap on refresh.
t0 := time.Now().UTC()
a.cfg.Clock = func() time.Time { return t0 }
_, refresh1, err := a.IssueJWT(ctx, u.ID)
if err != nil {
t.Fatalf("IssueJWT: %v", err)
}
// One rotation well within the cap should still work.
a.cfg.Clock = func() time.Time { return t0.Add(time.Hour) }
_, refresh2, err := a.RefreshJWT(ctx, refresh1)
if err != nil {
t.Fatalf("RefreshJWT within cap: %v", err)
}
// Advance past the absolute cap. The refresh succeeds at the consume
// step (token row is not yet expired by RefreshTokenTTL — only the
// chain cap kicks in), but the chain cap rejects the rotation.
a.cfg.Clock = func() time.Time { return t0.Add(25 * time.Hour) }
if _, _, err := a.RefreshJWT(ctx, refresh2); !errors.Is(err, ErrTokenInvalid) {
t.Fatalf("expected ErrTokenInvalid past chain absolute cap, got %v", err)
}
}

View file

@ -121,17 +121,17 @@ func buildQueries(t Tables) queries {
// tokens
createToken: `INSERT INTO ` + t.Tokens + `
(hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
(hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
consumeToken: `UPDATE ` + t.Tokens + `
SET consumed_at = $1
WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $4
RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`,
getToken: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at
RETURNING hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at`,
getToken: `SELECT hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at
FROM ` + t.Tokens + ` WHERE kind = $1 AND hash = $2`,
// getOTPForUser returns the most recent unconsumed, unexpired OTP for
// the user, used to verify a code by hash-comparing client input.
getOTPForUser: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at
getOTPForUser: `SELECT hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at
FROM ` + t.Tokens + `
WHERE kind = $1 AND user_id = $2 AND consumed_at IS NULL AND expires_at > $3
ORDER BY created_at DESC LIMIT 1`,
@ -146,7 +146,7 @@ func buildQueries(t Tables) queries {
consumeOTPByID: `UPDATE ` + t.Tokens + `
SET consumed_at = $1
WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $1
RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`,
RETURNING hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at`,
deleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = $1`,
deleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= $1`,

View file

@ -15,8 +15,8 @@ func (a *Auth) storeCreateToken(ctx context.Context, t *Token) error {
}
_, err := a.db.ExecContext(ctx, a.q.createToken,
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
nullableTime(t.ConsumedAt), nullableInt(t.AttemptsRemaining),
t.CreatedAt, t.ExpiresAt)
nullableTime(t.ChainStartedAt), nullableTime(t.ConsumedAt),
nullableInt(t.AttemptsRemaining), t.CreatedAt, t.ExpiresAt)
if err != nil {
return errx.Wrap(op, err)
}
@ -114,10 +114,11 @@ func scanToken(row rowScanner) (*Token, error) {
kind string
userIDStr string
chainID sql.NullString
chainStartedAt sql.NullTime
consumedAt sql.NullTime
attempts sql.NullInt32
)
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID,
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, &chainStartedAt,
&consumedAt, &attempts, &t.CreatedAt, &t.ExpiresAt); err != nil {
return nil, err
}
@ -128,6 +129,7 @@ func scanToken(row rowScanner) (*Token, error) {
}
t.UserID = uid
t.ChainID = scanNullStringPtr(chainID)
t.ChainStartedAt = scanNullTimePtr(chainStartedAt)
t.ConsumedAt = scanNullTimePtr(consumedAt)
t.AttemptsRemaining = scanNullIntPtr(attempts)
return &t, nil

View file

@ -75,6 +75,7 @@ func expectedSchema(s Schema) []tableSpec {
{"kind", "text", false},
{"user_id", "uuid", false},
{"chain_id", "text", true},
{"chain_started_at", "timestamp with time zone", true},
{"consumed_at", "timestamp with time zone", true},
{"attempts_remaining", "integer", true},
{"created_at", "timestamp with time zone", false},

View file

@ -55,7 +55,8 @@ func freshAuth(t *testing.T) *Auth {
JWTSecret: []byte("integration-secret-thirty-two!!!"),
JWTIssuer: "authkit-int",
AccessTokenTTL: 2 * time.Minute,
RefreshTokenTTL: time.Hour,
RefreshTokenTTL: 48 * time.Hour,
RefreshChainAbsoluteTTL: 24 * time.Hour,
SessionIdleTTL: time.Hour,
SessionAbsoluteTTL: 24 * time.Hour,
EmailVerifyTTL: time.Hour,