Cap refresh chain lifetime via RefreshChainAbsoluteTTL
Sessions had an absolute cap (created_at + SessionAbsoluteTTL) but the JWT path only had per-token TTL on the refresh row, letting a well-behaved client refresh indefinitely. Add chain_started_at to authkit_tokens, copy it forward on every rotation, and reject in RefreshJWT when now > chainStartedAt + RefreshChainAbsoluteTTL. Default 30d, mirroring SessionAbsoluteTTL. Schema, verifier, queries, model, and integration test updated. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
d3c5367492
commit
ca5525d4bd
11 changed files with 129 additions and 53 deletions
|
|
@ -257,6 +257,7 @@ custom names back to defaults without manual intervention.
|
|||
| `SessionCookieSameSite` | `Lax` | |
|
||||
| `JWTSecret` | — (required) | HS256 key |
|
||||
| `AccessTokenTTL` / `RefreshTokenTTL` | 15m / 30d | |
|
||||
| `RefreshChainAbsoluteTTL` | 30d | Hard cap from chain start. Refresh fails past this even if the per-token TTL hasn't elapsed; user must re-authenticate. Mirrors `SessionAbsoluteTTL`. |
|
||||
| `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | |
|
||||
| `EmailOTPTTL` / `EmailOTPDigits` / `EmailOTPMaxAttempts` | 10m / 6 / 5 | |
|
||||
| `RevealUnknownEmail` | `false` | Default anti-enumeration: silent success on unknown email |
|
||||
|
|
|
|||
|
|
@ -66,6 +66,12 @@ type Config struct {
|
|||
JWTAudience string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
// RefreshChainAbsoluteTTL caps the maximum lifetime of a refresh chain.
|
||||
// A user can refresh as often as they want within RefreshTokenTTL of the
|
||||
// last rotation, but the chain itself dies once now > chainStartedAt +
|
||||
// RefreshChainAbsoluteTTL — at which point the user must re-authenticate.
|
||||
// Mirrors SessionAbsoluteTTL on the session path.
|
||||
RefreshChainAbsoluteTTL time.Duration
|
||||
|
||||
// Single-use tokens
|
||||
EmailVerifyTTL time.Duration
|
||||
|
|
@ -173,6 +179,9 @@ func applyDefaults(cfg Config) Config {
|
|||
if cfg.RefreshTokenTTL == 0 {
|
||||
cfg.RefreshTokenTTL = 30 * 24 * time.Hour
|
||||
}
|
||||
if cfg.RefreshChainAbsoluteTTL == 0 {
|
||||
cfg.RefreshChainAbsoluteTTL = 30 * 24 * time.Hour
|
||||
}
|
||||
if cfg.EmailVerifyTTL == 0 {
|
||||
cfg.EmailVerifyTTL = 48 * time.Hour
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ func freshAuth(t *testing.T) *authkit.Auth {
|
|||
JWTIssuer: "authkit-mw-int",
|
||||
AccessTokenTTL: 2 * time.Minute,
|
||||
RefreshTokenTTL: time.Hour,
|
||||
RefreshChainAbsoluteTTL: 24 * time.Hour,
|
||||
SessionIdleTTL: time.Hour,
|
||||
SessionAbsoluteTTL: 24 * time.Hour,
|
||||
EmailVerifyTTL: time.Hour,
|
||||
|
|
|
|||
|
|
@ -48,6 +48,10 @@ CREATE TABLE IF NOT EXISTS authkit_tokens (
|
|||
kind TEXT NOT NULL,
|
||||
user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE,
|
||||
chain_id TEXT,
|
||||
-- chain_started_at is the timestamp of the first refresh in a chain.
|
||||
-- Copied forward on every rotation so the absolute-cap check in
|
||||
-- RefreshJWT is O(1). Non-null only for refresh-token rows.
|
||||
chain_started_at TIMESTAMPTZ,
|
||||
consumed_at TIMESTAMPTZ,
|
||||
attempts_remaining INTEGER,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
|
|
|
|||
|
|
@ -46,12 +46,15 @@ const (
|
|||
|
||||
// Token is one row in authkit_tokens. AttemptsRemaining is non-nil only for
|
||||
// tokens that allow retry on incorrect input (email OTPs); other kinds are
|
||||
// strictly one-shot via ConsumeToken.
|
||||
// strictly one-shot via ConsumeToken. ChainStartedAt is non-nil only for
|
||||
// refresh-token rows; copied forward on every rotation so the absolute-cap
|
||||
// check in RefreshJWT is O(1).
|
||||
type Token struct {
|
||||
Hash []byte
|
||||
Kind TokenKind
|
||||
UserID uuid.UUID
|
||||
ChainID *string
|
||||
ChainStartedAt *time.Time
|
||||
ConsumedAt *time.Time
|
||||
AttemptsRemaining *int
|
||||
CreatedAt time.Time
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package authkit
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
"github.com/google/uuid"
|
||||
|
|
@ -11,6 +12,8 @@ import (
|
|||
// IssueJWT issues a fresh access JWT and a rotating opaque refresh token.
|
||||
// The refresh token is bound to a chain via Token.ChainID; rotation
|
||||
// preserves that chain so reuse-detection can revoke the whole family.
|
||||
// chainStartedAt is stamped on this row and copied forward on every
|
||||
// rotation so RefreshJWT can enforce RefreshChainAbsoluteTTL in O(1).
|
||||
func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) {
|
||||
const op = "authkit.Auth.IssueJWT"
|
||||
u, err := a.storeGetUserByID(ctx, userID)
|
||||
|
|
@ -21,7 +24,7 @@ func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh
|
|||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString())
|
||||
refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString(), a.now())
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
|
|
@ -67,7 +70,9 @@ func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal,
|
|||
// RefreshJWT consumes the presented refresh token and mints a new
|
||||
// access+refresh pair. Reuse of an already-consumed refresh token deletes
|
||||
// the entire chain (logout-everywhere on that device family) and returns
|
||||
// ErrTokenReused.
|
||||
// ErrTokenReused. The chain itself is capped at RefreshChainAbsoluteTTL
|
||||
// from chain_started_at — past that, refresh fails with ErrTokenInvalid
|
||||
// and the chain is deleted, forcing the user to re-authenticate.
|
||||
func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) {
|
||||
const op = "authkit.Auth.RefreshJWT"
|
||||
hash, ok := ParseOpaqueSecret(prefixRefresh, plaintextRefresh)
|
||||
|
|
@ -92,6 +97,20 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access,
|
|||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
|
||||
// Enforce the absolute chain cap. If the chain is older than the
|
||||
// configured ceiling, kill the whole chain rather than minting a
|
||||
// successor — re-authentication is required.
|
||||
chainStartedAt := now
|
||||
if consumed.ChainStartedAt != nil {
|
||||
chainStartedAt = *consumed.ChainStartedAt
|
||||
}
|
||||
if now.After(chainStartedAt.Add(a.cfg.RefreshChainAbsoluteTTL)) {
|
||||
if consumed.ChainID != nil && *consumed.ChainID != "" {
|
||||
_, _ = a.storeDeleteByChain(ctx, *consumed.ChainID)
|
||||
}
|
||||
return "", "", errx.Wrap(op, ErrTokenInvalid)
|
||||
}
|
||||
|
||||
var chainID string
|
||||
if consumed.ChainID != nil {
|
||||
chainID = *consumed.ChainID
|
||||
|
|
@ -100,20 +119,21 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access,
|
|||
// Defensive: every refresh token should be chain-bound. Fall back
|
||||
// to a fresh chain rather than throwing on missing metadata.
|
||||
chainID = uuid.NewString()
|
||||
chainStartedAt = now
|
||||
}
|
||||
|
||||
access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID))
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID)
|
||||
refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID, chainStartedAt)
|
||||
if err != nil {
|
||||
return "", "", errx.Wrap(op, err)
|
||||
}
|
||||
return access, refresh, nil
|
||||
}
|
||||
|
||||
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) {
|
||||
func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string, chainStartedAt time.Time) (string, error) {
|
||||
const op = "authkit.Auth.mintRefreshToken"
|
||||
plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixRefresh)
|
||||
if err != nil {
|
||||
|
|
@ -125,6 +145,7 @@ func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID s
|
|||
Kind: TokenRefresh,
|
||||
UserID: userID,
|
||||
ChainID: &chainID,
|
||||
ChainStartedAt: &chainStartedAt,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(a.cfg.RefreshTokenTTL),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIntegration_JWTIssueAuthenticate(t *testing.T) {
|
||||
|
|
@ -65,3 +66,35 @@ func TestIntegration_JWTInvalidPrefix(t *testing.T) {
|
|||
t.Fatalf("expected ErrTokenInvalid for malformed input, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_JWTRefreshChainAbsoluteCap(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
u, err := a.CreateUser(ctx, "cap@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
// freshAuth sets RefreshChainAbsoluteTTL = 24h. Pin clock so the issue
|
||||
// time is fixed, then advance past the cap on refresh.
|
||||
t0 := time.Now().UTC()
|
||||
a.cfg.Clock = func() time.Time { return t0 }
|
||||
_, refresh1, err := a.IssueJWT(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
|
||||
// One rotation well within the cap should still work.
|
||||
a.cfg.Clock = func() time.Time { return t0.Add(time.Hour) }
|
||||
_, refresh2, err := a.RefreshJWT(ctx, refresh1)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshJWT within cap: %v", err)
|
||||
}
|
||||
|
||||
// Advance past the absolute cap. The refresh succeeds at the consume
|
||||
// step (token row is not yet expired by RefreshTokenTTL — only the
|
||||
// chain cap kicks in), but the chain cap rejects the rotation.
|
||||
a.cfg.Clock = func() time.Time { return t0.Add(25 * time.Hour) }
|
||||
if _, _, err := a.RefreshJWT(ctx, refresh2); !errors.Is(err, ErrTokenInvalid) {
|
||||
t.Fatalf("expected ErrTokenInvalid past chain absolute cap, got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -121,17 +121,17 @@ func buildQueries(t Tables) queries {
|
|||
|
||||
// tokens
|
||||
createToken: `INSERT INTO ` + t.Tokens + `
|
||||
(hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
(hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
consumeToken: `UPDATE ` + t.Tokens + `
|
||||
SET consumed_at = $1
|
||||
WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $4
|
||||
RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`,
|
||||
getToken: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at
|
||||
RETURNING hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at`,
|
||||
getToken: `SELECT hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at
|
||||
FROM ` + t.Tokens + ` WHERE kind = $1 AND hash = $2`,
|
||||
// getOTPForUser returns the most recent unconsumed, unexpired OTP for
|
||||
// the user, used to verify a code by hash-comparing client input.
|
||||
getOTPForUser: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at
|
||||
getOTPForUser: `SELECT hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at
|
||||
FROM ` + t.Tokens + `
|
||||
WHERE kind = $1 AND user_id = $2 AND consumed_at IS NULL AND expires_at > $3
|
||||
ORDER BY created_at DESC LIMIT 1`,
|
||||
|
|
@ -146,7 +146,7 @@ func buildQueries(t Tables) queries {
|
|||
consumeOTPByID: `UPDATE ` + t.Tokens + `
|
||||
SET consumed_at = $1
|
||||
WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $1
|
||||
RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`,
|
||||
RETURNING hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at`,
|
||||
deleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = $1`,
|
||||
deleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= $1`,
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ func (a *Auth) storeCreateToken(ctx context.Context, t *Token) error {
|
|||
}
|
||||
_, err := a.db.ExecContext(ctx, a.q.createToken,
|
||||
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
|
||||
nullableTime(t.ConsumedAt), nullableInt(t.AttemptsRemaining),
|
||||
t.CreatedAt, t.ExpiresAt)
|
||||
nullableTime(t.ChainStartedAt), nullableTime(t.ConsumedAt),
|
||||
nullableInt(t.AttemptsRemaining), t.CreatedAt, t.ExpiresAt)
|
||||
if err != nil {
|
||||
return errx.Wrap(op, err)
|
||||
}
|
||||
|
|
@ -114,10 +114,11 @@ func scanToken(row rowScanner) (*Token, error) {
|
|||
kind string
|
||||
userIDStr string
|
||||
chainID sql.NullString
|
||||
chainStartedAt sql.NullTime
|
||||
consumedAt sql.NullTime
|
||||
attempts sql.NullInt32
|
||||
)
|
||||
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID,
|
||||
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, &chainStartedAt,
|
||||
&consumedAt, &attempts, &t.CreatedAt, &t.ExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -128,6 +129,7 @@ func scanToken(row rowScanner) (*Token, error) {
|
|||
}
|
||||
t.UserID = uid
|
||||
t.ChainID = scanNullStringPtr(chainID)
|
||||
t.ChainStartedAt = scanNullTimePtr(chainStartedAt)
|
||||
t.ConsumedAt = scanNullTimePtr(consumedAt)
|
||||
t.AttemptsRemaining = scanNullIntPtr(attempts)
|
||||
return &t, nil
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ func expectedSchema(s Schema) []tableSpec {
|
|||
{"kind", "text", false},
|
||||
{"user_id", "uuid", false},
|
||||
{"chain_id", "text", true},
|
||||
{"chain_started_at", "timestamp with time zone", true},
|
||||
{"consumed_at", "timestamp with time zone", true},
|
||||
{"attempts_remaining", "integer", true},
|
||||
{"created_at", "timestamp with time zone", false},
|
||||
|
|
|
|||
|
|
@ -55,7 +55,8 @@ func freshAuth(t *testing.T) *Auth {
|
|||
JWTSecret: []byte("integration-secret-thirty-two!!!"),
|
||||
JWTIssuer: "authkit-int",
|
||||
AccessTokenTTL: 2 * time.Minute,
|
||||
RefreshTokenTTL: time.Hour,
|
||||
RefreshTokenTTL: 48 * time.Hour,
|
||||
RefreshChainAbsoluteTTL: 24 * time.Hour,
|
||||
SessionIdleTTL: time.Hour,
|
||||
SessionAbsoluteTTL: 24 * time.Hour,
|
||||
EmailVerifyTTL: time.Hour,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue