authkit/store_tokens.go
juancwu ca5525d4bd Cap refresh chain lifetime via RefreshChainAbsoluteTTL
Sessions had an absolute cap (created_at + SessionAbsoluteTTL) but the
JWT path only had per-token TTL on the refresh row, letting a
well-behaved client refresh indefinitely. Add chain_started_at to
authkit_tokens, copy it forward on every rotation, and reject in
RefreshJWT when now > chainStartedAt + RefreshChainAbsoluteTTL.
Default 30d, mirroring SessionAbsoluteTTL.

Schema, verifier, queries, model, and integration test updated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-26 23:41:02 +00:00

136 lines
4.4 KiB
Go

package authkit
import (
"context"
"database/sql"
"time"
"git.juancwu.dev/juancwu/errx"
)
// storeCreateToken inserts t using the createToken query.
//
// If t.CreatedAt is the zero time it is stamped with a.now() first; the
// caller's struct is updated in place, even if the insert later fails.
// The optional columns (ChainID, ChainStartedAt, ConsumedAt,
// AttemptsRemaining) go through the chainArg/nullableTime/nullableInt
// helpers — presumably so nil pointers become SQL NULL; confirm in those
// helpers. Any driver error is wrapped with the op name.
func (a *Auth) storeCreateToken(ctx context.Context, t *Token) error {
	const op = "authkit.storeCreateToken"

	if t.CreatedAt.IsZero() {
		t.CreatedAt = a.now()
	}

	// Positional arguments, in the column order createToken expects.
	args := []any{
		t.Hash,
		string(t.Kind),
		uuidArg(t.UserID),
		chainArg(t.ChainID),
		nullableTime(t.ChainStartedAt),
		nullableTime(t.ConsumedAt),
		nullableInt(t.AttemptsRemaining),
		t.CreatedAt,
		t.ExpiresAt,
	}

	if _, err := a.db.ExecContext(ctx, a.q.createToken, args...); err != nil {
		return errx.Wrap(op, err)
	}
	return nil
}
// storeConsumeToken marks the unexpired, unconsumed token identified by
// (kind, hash) as consumed and returns the resulting row.
//
// The whole operation is a single consumeToken statement issued through
// QueryRowContext. It MUST stay a single statement: that is what keeps
// two concurrent callers from both spending the same token. When no row
// matches, the error is mapped to ErrTokenInvalid and wrapped with the op
// name.
func (a *Auth) storeConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) {
	const op = "authkit.storeConsumeToken"

	// now is bound twice — presumably once as the consumed_at value and once
	// for the expiry comparison; confirm against the consumeToken query.
	tok, err := scanToken(a.db.QueryRowContext(ctx, a.q.consumeToken,
		now, string(kind), hash, now))
	if err != nil {
		return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
	}
	return tok, nil
}
// storeGetToken reads the token row identified by (kind, hash) using the
// getToken query, without modifying it.
//
// Whether expired or consumed rows are filtered out is decided by the
// getToken query itself, not here. A missing row is mapped to
// ErrTokenInvalid; every error is wrapped with the op name.
func (a *Auth) storeGetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error) {
	const op = "authkit.storeGetToken"

	tok, err := scanToken(a.db.QueryRowContext(ctx, a.q.getToken, string(kind), hash))
	if err != nil {
		return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
	}
	return tok, nil
}
// storeGetActiveOTPForUser fetches the most recent unconsumed, unexpired OTP
// row of the given kind for userID, as selected by the getOTPForUser query.
//
// ConsumeEmailOTP uses it to verify a code by hash-comparing the client's
// input against the returned row. A missing row is mapped to ErrOTPInvalid;
// every error is wrapped with the op name.
//
// NOTE(review): userID is bound as-is, whereas storeCreateToken passes its
// user ID through uuidArg — confirm callers already supply a driver-ready
// value here.
func (a *Auth) storeGetActiveOTPForUser(ctx context.Context, kind TokenKind, userID any, now time.Time) (*Token, error) {
	const op = "authkit.storeGetActiveOTPForUser"

	row := a.db.QueryRowContext(ctx, a.q.getOTPForUser, string(kind), userID, now)
	otp, scanErr := scanToken(row)
	if scanErr != nil {
		return nil, errx.Wrap(op, mapNotFound(scanErr, ErrOTPInvalid))
	}
	return otp, nil
}
// storeDecrementOTPAttempt runs the decrementOTPAttempt query against the
// (kind, hash) row. Per that query's contract, this lowers
// attempts_remaining by one and consumes the row once it reaches zero.
//
// It returns the attempts_remaining value the query reports back, where 0
// means the row is now consumed. A NULL value is also reported as 0. When no
// row matches, the error is mapped to ErrTokenInvalid and wrapped with the
// op name.
func (a *Auth) storeDecrementOTPAttempt(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (int, error) {
	const op = "authkit.storeDecrementOTPAttempt"

	row := a.db.QueryRowContext(ctx, a.q.decrementOTPAttempt, now, string(kind), hash)

	var left sql.NullInt32
	if err := row.Scan(&left); err != nil {
		return 0, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
	}

	if left.Valid {
		return int(left.Int32), nil
	}
	// NULL is folded into the same 0 the caller already treats as "consumed".
	return 0, nil
}
// storeConsumeOTPByHash marks the OTP row matching (kind, hash) as consumed
// and returns it. It is used on the success path of ConsumeEmailOTP.
//
// A missing row is mapped to ErrOTPInvalid; every error is wrapped with the
// op name.
//
// NOTE(review): the query field is named consumeOTPByID even though the row
// is keyed by hash here. Also, now is bound only once, so it is not visible
// from this file whether the query re-checks expiry — confirm both against
// the query definition.
func (a *Auth) storeConsumeOTPByHash(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) {
	const op = "authkit.storeConsumeOTPByHash"

	consumed, err := scanToken(a.db.QueryRowContext(ctx, a.q.consumeOTPByID,
		now, string(kind), hash))
	if err != nil {
		return nil, errx.Wrap(op, mapNotFound(err, ErrOTPInvalid))
	}
	return consumed, nil
}
// storeDeleteByChain executes the deleteByChain statement for chainID and
// returns the number of rows the driver reports as affected.
//
// An execution error is wrapped with the op name and returned with a
// count of 0.
func (a *Auth) storeDeleteByChain(ctx context.Context, chainID string) (int64, error) {
	const op = "authkit.storeDeleteByChain"

	res, err := a.db.ExecContext(ctx, a.q.deleteByChain, chainID)
	if err != nil {
		return 0, errx.Wrap(op, err)
	}

	// The statement has already succeeded at this point, so an error from
	// RowsAffected (e.g. a driver that cannot report counts) is deliberately
	// dropped rather than failing the call. Whatever count the driver
	// returns alongside that error (typically 0) is passed through.
	deleted, _ := res.RowsAffected()
	return deleted, nil
}
// storeDeleteExpiredTokens executes the deleteExpiredTokens statement with
// now as its cutoff and returns the number of rows the driver reports as
// affected.
//
// An execution error is wrapped with the op name and returned with a
// count of 0.
func (a *Auth) storeDeleteExpiredTokens(ctx context.Context, now time.Time) (int64, error) {
	const op = "authkit.storeDeleteExpiredTokens"

	res, err := a.db.ExecContext(ctx, a.q.deleteExpiredTokens, now)
	if err != nil {
		return 0, errx.Wrap(op, err)
	}

	// The statement has already succeeded at this point, so an error from
	// RowsAffected (e.g. a driver that cannot report counts) is deliberately
	// dropped rather than failing the call. Whatever count the driver
	// returns alongside that error (typically 0) is passed through.
	purged, _ := res.RowsAffected()
	return purged, nil
}
// scanToken decodes one token row into a *Token.
//
// The row must yield these columns, in this order:
//
//	hash, kind, user_id, chain_id, chain_started_at, consumed_at,
//	attempts_remaining, created_at, expires_at
//
// The nullable columns (chain_id, chain_started_at, consumed_at,
// attempts_remaining) are converted through the scanNull*Ptr helpers, and
// user_id is parsed from its string form with scanUUID. Scan and parse
// errors are returned unwrapped, so callers can still recognise a
// not-found condition via mapNotFound.
func scanToken(row rowScanner) (*Token, error) {
	var (
		tok         Token
		kindCol     string
		userCol     string
		chainCol    sql.NullString
		chainStart  sql.NullTime
		consumedCol sql.NullTime
		attemptsCol sql.NullInt32
	)

	err := row.Scan(
		&tok.Hash,
		&kindCol,
		&userCol,
		&chainCol,
		&chainStart,
		&consumedCol,
		&attemptsCol,
		&tok.CreatedAt,
		&tok.ExpiresAt,
	)
	if err != nil {
		return nil, err
	}

	userID, err := scanUUID(userCol)
	if err != nil {
		return nil, err
	}

	tok.Kind = TokenKind(kindCol)
	tok.UserID = userID
	tok.ChainID = scanNullStringPtr(chainCol)
	tok.ChainStartedAt = scanNullTimePtr(chainStart)
	tok.ConsumedAt = scanNullTimePtr(consumedCol)
	tok.AttemptsRemaining = scanNullIntPtr(attemptsCol)

	return &tok, nil
}