diff --git a/README.md b/README.md index ddee04c..ce159c4 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,7 @@ custom names back to defaults without manual intervention. | `SessionCookieSameSite` | `Lax` | | | `JWTSecret` | — (required) | HS256 key | | `AccessTokenTTL` / `RefreshTokenTTL` | 15m / 30d | | +| `RefreshChainAbsoluteTTL` | 30d | Hard cap from chain start. Refresh fails past this even if the per-token TTL hasn't elapsed; user must re-authenticate. Mirrors `SessionAbsoluteTTL`. | | `EmailVerifyTTL` / `PasswordResetTTL` / `MagicLinkTTL` | 48h / 1h / 15m | | | `EmailOTPTTL` / `EmailOTPDigits` / `EmailOTPMaxAttempts` | 10m / 6 / 5 | | | `RevealUnknownEmail` | `false` | Default anti-enumeration: silent success on unknown email | diff --git a/authkit.go b/authkit.go index b357915..3e4b3ef 100644 --- a/authkit.go +++ b/authkit.go @@ -61,11 +61,17 @@ type Config struct { SessionCookieSameSite http.SameSite // JWT (HS256) - JWTSecret []byte - JWTIssuer string - JWTAudience string - AccessTokenTTL time.Duration - RefreshTokenTTL time.Duration + JWTSecret []byte + JWTIssuer string + JWTAudience string + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration + // RefreshChainAbsoluteTTL caps the maximum lifetime of a refresh chain. + // A user can refresh as often as they want within RefreshTokenTTL of the + // last rotation, but the chain itself dies once now > chainStartedAt + + // RefreshChainAbsoluteTTL — at which point the user must re-authenticate. + // Mirrors SessionAbsoluteTTL on the session path. + RefreshChainAbsoluteTTL time.Duration // Single-use tokens EmailVerifyTTL time.Duration @@ -173,6 +179,9 @@ func applyDefaults(cfg Config) Config { if cfg.RefreshTokenTTL == 0 { cfg.RefreshTokenTTL = 30 * 24 * time.Hour } + if cfg.RefreshChainAbsoluteTTL == 0 { + cfg.RefreshChainAbsoluteTTL = 30 * 24 * time.Hour + } if cfg.EmailVerifyTTL == 0 { cfg.EmailVerifyTTL = 48 * time.Hour } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 8d45490..51e2f15 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -42,17 +42,18 @@ func freshAuth(t *testing.T) *authkit.Auth { DB: db, Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), }, authkit.Config{ - JWTSecret: []byte("integration-secret-thirty-two!!!"), - JWTIssuer: "authkit-mw-int", - AccessTokenTTL: 2 * time.Minute, - RefreshTokenTTL: time.Hour, - SessionIdleTTL: time.Hour, - SessionAbsoluteTTL: 24 * time.Hour, - EmailVerifyTTL: time.Hour, - PasswordResetTTL: time.Hour, - MagicLinkTTL: time.Minute, - EmailOTPTTL: time.Minute, - EmailOTPMaxAttempts: 3, + JWTSecret: []byte("integration-secret-thirty-two!!!"), + JWTIssuer: "authkit-mw-int", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: time.Hour, + RefreshChainAbsoluteTTL: 24 * time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + EmailOTPTTL: time.Minute, + EmailOTPMaxAttempts: 3, // Plain HTTP for tests so secure-cookie defaults don't interfere // with httptest's HTTP server. SessionCookieSecure: authkit.BoolPtr(false), diff --git a/migrations/0001_init.sql b/migrations/0001_init.sql index 0f9ae3e..6f8b648 100644 --- a/migrations/0001_init.sql +++ b/migrations/0001_init.sql @@ -48,6 +48,10 @@ CREATE TABLE IF NOT EXISTS authkit_tokens ( kind TEXT NOT NULL, user_id UUID NOT NULL REFERENCES authkit_users(id) ON DELETE CASCADE, chain_id TEXT, + -- chain_started_at is the timestamp of the first refresh in a chain. + -- Copied forward on every rotation so the absolute-cap check in + -- RefreshJWT is O(1). Non-null only for refresh-token rows. + chain_started_at TIMESTAMPTZ, consumed_at TIMESTAMPTZ, attempts_remaining INTEGER, created_at TIMESTAMPTZ NOT NULL, diff --git a/models.go b/models.go index 108c4ed..f76ba7c 100644 --- a/models.go +++ b/models.go @@ -46,12 +46,15 @@ const ( // Token is one row in authkit_tokens. AttemptsRemaining is non-nil only for // tokens that allow retry on incorrect input (email OTPs); other kinds are -// strictly one-shot via ConsumeToken. +// strictly one-shot via ConsumeToken. ChainStartedAt is non-nil only for +// refresh-token rows; copied forward on every rotation so the absolute-cap +// check in RefreshJWT is O(1). type Token struct { Hash []byte Kind TokenKind UserID uuid.UUID ChainID *string + ChainStartedAt *time.Time ConsumedAt *time.Time AttemptsRemaining *int CreatedAt time.Time diff --git a/service_jwt.go b/service_jwt.go index f79ba5e..a3de6ca 100644 --- a/service_jwt.go +++ b/service_jwt.go @@ -3,6 +3,7 @@ package authkit import ( "context" "errors" + "time" "git.juancwu.dev/juancwu/errx" "github.com/google/uuid" @@ -11,6 +12,8 @@ import ( // IssueJWT issues a fresh access JWT and a rotating opaque refresh token. // The refresh token is bound to a chain via Token.ChainID; rotation // preserves that chain so reuse-detection can revoke the whole family. +// chainStartedAt is stamped on this row and copied forward on every +// rotation so RefreshJWT can enforce RefreshChainAbsoluteTTL in O(1). func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh string, err error) { const op = "authkit.Auth.IssueJWT" u, err := a.storeGetUserByID(ctx, userID) @@ -21,7 +24,7 @@ func (a *Auth) IssueJWT(ctx context.Context, userID uuid.UUID) (access, refresh if err != nil { return "", "", errx.Wrap(op, err) } - refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString()) + refresh, err = a.mintRefreshToken(ctx, u.ID, uuid.NewString(), a.now()) if err != nil { return "", "", errx.Wrap(op, err) } @@ -67,7 +70,9 @@ func (a *Auth) AuthenticateJWT(ctx context.Context, access string) (*Principal, // RefreshJWT consumes the presented refresh token and mints a new // access+refresh pair. Reuse of an already-consumed refresh token deletes // the entire chain (logout-everywhere on that device family) and returns -// ErrTokenReused. +// ErrTokenReused. The chain itself is capped at RefreshChainAbsoluteTTL +// from chain_started_at — past that, refresh fails with ErrTokenInvalid +// and the chain is deleted, forcing the user to re-authenticate. func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, refresh string, err error) { const op = "authkit.Auth.RefreshJWT" hash, ok := ParseOpaqueSecret(prefixRefresh, plaintextRefresh) @@ -92,6 +97,20 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, return "", "", errx.Wrap(op, err) } + // Enforce the absolute chain cap. If the chain is older than the + // configured ceiling, kill the whole chain rather than minting a + // successor — re-authentication is required. + chainStartedAt := now + if consumed.ChainStartedAt != nil { + chainStartedAt = *consumed.ChainStartedAt + } + if now.After(chainStartedAt.Add(a.cfg.RefreshChainAbsoluteTTL)) { + if consumed.ChainID != nil && *consumed.ChainID != "" { + _, _ = a.storeDeleteByChain(ctx, *consumed.ChainID) + } + return "", "", errx.Wrap(op, ErrTokenInvalid) + } + var chainID string if consumed.ChainID != nil { chainID = *consumed.ChainID @@ -100,20 +119,21 @@ func (a *Auth) RefreshJWT(ctx context.Context, plaintextRefresh string) (access, // Defensive: every refresh token should be chain-bound. Fall back // to a fresh chain rather than throwing on missing metadata. chainID = uuid.NewString() + chainStartedAt = now } access, err = a.signAccessToken(consumed.UserID, a.userSessionVersion(ctx, consumed.UserID)) if err != nil { return "", "", errx.Wrap(op, err) } - refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID) + refresh, err = a.mintRefreshToken(ctx, consumed.UserID, chainID, chainStartedAt) if err != nil { return "", "", errx.Wrap(op, err) } return access, refresh, nil } -func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string) (string, error) { +func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID string, chainStartedAt time.Time) (string, error) { const op = "authkit.Auth.mintRefreshToken" plaintext, hash, err := MintOpaqueSecret(a.cfg.Random, prefixRefresh) if err != nil { @@ -121,12 +141,13 @@ func (a *Auth) mintRefreshToken(ctx context.Context, userID uuid.UUID, chainID s } now := a.now() t := &Token{ - Hash: hash, - Kind: TokenRefresh, - UserID: userID, - ChainID: &chainID, - CreatedAt: now, - ExpiresAt: now.Add(a.cfg.RefreshTokenTTL), + Hash: hash, + Kind: TokenRefresh, + UserID: userID, + ChainID: &chainID, + ChainStartedAt: &chainStartedAt, + CreatedAt: now, + ExpiresAt: now.Add(a.cfg.RefreshTokenTTL), } if err := a.storeCreateToken(ctx, t); err != nil { return "", errx.Wrap(op, err) diff --git a/service_jwt_test.go b/service_jwt_test.go index 3f94d4d..3fc25af 100644 --- a/service_jwt_test.go +++ b/service_jwt_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" ) func TestIntegration_JWTIssueAuthenticate(t *testing.T) { @@ -65,3 +66,35 @@ func TestIntegration_JWTInvalidPrefix(t *testing.T) { t.Fatalf("expected ErrTokenInvalid for malformed input, got %v", err) } } + +func TestIntegration_JWTRefreshChainAbsoluteCap(t *testing.T) { + a := freshAuth(t) + ctx := context.Background() + u, err := a.CreateUser(ctx, "cap@example.com") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + // freshAuth sets RefreshChainAbsoluteTTL = 24h. Pin clock so the issue + // time is fixed, then advance past the cap on refresh. + t0 := time.Now().UTC() + a.cfg.Clock = func() time.Time { return t0 } + _, refresh1, err := a.IssueJWT(ctx, u.ID) + if err != nil { + t.Fatalf("IssueJWT: %v", err) + } + + // One rotation well within the cap should still work. + a.cfg.Clock = func() time.Time { return t0.Add(time.Hour) } + _, refresh2, err := a.RefreshJWT(ctx, refresh1) + if err != nil { + t.Fatalf("RefreshJWT within cap: %v", err) + } + + // Advance past the absolute cap. The refresh succeeds at the consume + // step (token row is not yet expired by RefreshTokenTTL — only the + // chain cap kicks in), but the chain cap rejects the rotation. + a.cfg.Clock = func() time.Time { return t0.Add(25 * time.Hour) } + if _, _, err := a.RefreshJWT(ctx, refresh2); !errors.Is(err, ErrTokenInvalid) { + t.Fatalf("expected ErrTokenInvalid past chain absolute cap, got %v", err) + } +} diff --git a/store_queries.go b/store_queries.go index 9e4fb32..c18fa66 100644 --- a/store_queries.go +++ b/store_queries.go @@ -121,17 +121,17 @@ func buildQueries(t Tables) queries { // tokens createToken: `INSERT INTO ` + t.Tokens + ` - (hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + (hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, consumeToken: `UPDATE ` + t.Tokens + ` SET consumed_at = $1 WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $4 - RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`, - getToken: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at + RETURNING hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at`, + getToken: `SELECT hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at FROM ` + t.Tokens + ` WHERE kind = $1 AND hash = $2`, // getOTPForUser returns the most recent unconsumed, unexpired OTP for // the user, used to verify a code by hash-comparing client input. - getOTPForUser: `SELECT hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at + getOTPForUser: `SELECT hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at FROM ` + t.Tokens + ` WHERE kind = $1 AND user_id = $2 AND consumed_at IS NULL AND expires_at > $3 ORDER BY created_at DESC LIMIT 1`, @@ -146,7 +146,7 @@ func buildQueries(t Tables) queries { consumeOTPByID: `UPDATE ` + t.Tokens + ` SET consumed_at = $1 WHERE kind = $2 AND hash = $3 AND consumed_at IS NULL AND expires_at > $1 - RETURNING hash, kind, user_id, chain_id, consumed_at, attempts_remaining, created_at, expires_at`, + RETURNING hash, kind, user_id, chain_id, chain_started_at, consumed_at, attempts_remaining, created_at, expires_at`, deleteByChain: `DELETE FROM ` + t.Tokens + ` WHERE chain_id = $1`, deleteExpiredTokens: `DELETE FROM ` + t.Tokens + ` WHERE expires_at <= $1`, diff --git a/store_tokens.go b/store_tokens.go index f9273d7..1cabdd9 100644 --- a/store_tokens.go +++ b/store_tokens.go @@ -15,8 +15,8 @@ func (a *Auth) storeCreateToken(ctx context.Context, t *Token) error { } _, err := a.db.ExecContext(ctx, a.q.createToken, t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID), - nullableTime(t.ConsumedAt), nullableInt(t.AttemptsRemaining), - t.CreatedAt, t.ExpiresAt) + nullableTime(t.ChainStartedAt), nullableTime(t.ConsumedAt), + nullableInt(t.AttemptsRemaining), t.CreatedAt, t.ExpiresAt) if err != nil { return errx.Wrap(op, err) } @@ -110,14 +110,15 @@ func (a *Auth) storeDeleteExpiredTokens(ctx context.Context, now time.Time) (int func scanToken(row rowScanner) (*Token, error) { var ( - t Token - kind string - userIDStr string - chainID sql.NullString - consumedAt sql.NullTime - attempts sql.NullInt32 + t Token + kind string + userIDStr string + chainID sql.NullString + chainStartedAt sql.NullTime + consumedAt sql.NullTime + attempts sql.NullInt32 ) - if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, + if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, &chainStartedAt, &consumedAt, &attempts, &t.CreatedAt, &t.ExpiresAt); err != nil { return nil, err } @@ -128,6 +129,7 @@ func scanToken(row rowScanner) (*Token, error) { } t.UserID = uid t.ChainID = scanNullStringPtr(chainID) + t.ChainStartedAt = scanNullTimePtr(chainStartedAt) t.ConsumedAt = scanNullTimePtr(consumedAt) t.AttemptsRemaining = scanNullIntPtr(attempts) return &t, nil diff --git a/store_verify.go b/store_verify.go index 3310bde..217cd46 100644 --- a/store_verify.go +++ b/store_verify.go @@ -75,6 +75,7 @@ func expectedSchema(s Schema) []tableSpec { {"kind", "text", false}, {"user_id", "uuid", false}, {"chain_id", "text", true}, + {"chain_started_at", "timestamp with time zone", true}, {"consumed_at", "timestamp with time zone", true}, {"attempts_remaining", "integer", true}, {"created_at", "timestamp with time zone", false}, diff --git a/testdb_test.go b/testdb_test.go index 85d0140..df1b42f 100644 --- a/testdb_test.go +++ b/testdb_test.go @@ -52,17 +52,18 @@ func freshAuth(t *testing.T) *Auth { DB: db, Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil), }, Config{ - JWTSecret: []byte("integration-secret-thirty-two!!!"), - JWTIssuer: "authkit-int", - AccessTokenTTL: 2 * time.Minute, - RefreshTokenTTL: time.Hour, - SessionIdleTTL: time.Hour, - SessionAbsoluteTTL: 24 * time.Hour, - EmailVerifyTTL: time.Hour, - PasswordResetTTL: time.Hour, - MagicLinkTTL: time.Minute, - EmailOTPTTL: time.Minute, - EmailOTPMaxAttempts: 3, + JWTSecret: []byte("integration-secret-thirty-two!!!"), + JWTIssuer: "authkit-int", + AccessTokenTTL: 2 * time.Minute, + RefreshTokenTTL: 48 * time.Hour, + RefreshChainAbsoluteTTL: 24 * time.Hour, + SessionIdleTTL: time.Hour, + SessionAbsoluteTTL: 24 * time.Hour, + EmailVerifyTTL: time.Hour, + PasswordResetTTL: time.Hour, + MagicLinkTTL: time.Minute, + EmailOTPTTL: time.Minute, + EmailOTPMaxAttempts: 3, }) if err != nil { t.Fatalf("authkit.New: %v", err)