Sessions had an absolute cap (created_at + SessionAbsoluteTTL) but the JWT path only had per-token TTL on the refresh row, letting a well-behaved client refresh indefinitely. Add chain_started_at to authkit_tokens, copy it forward on every rotation, and reject in RefreshJWT when now > chainStartedAt + RefreshChainAbsoluteTTL. Default 30d, mirroring SessionAbsoluteTTL. Schema, verifier, queries, model, and integration test updated. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
136 lines
4.4 KiB
Go
136 lines
4.4 KiB
Go
package authkit
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"time"
|
|
|
|
"git.juancwu.dev/juancwu/errx"
|
|
)
|
|
|
|
func (a *Auth) storeCreateToken(ctx context.Context, t *Token) error {
|
|
const op = "authkit.storeCreateToken"
|
|
if t.CreatedAt.IsZero() {
|
|
t.CreatedAt = a.now()
|
|
}
|
|
_, err := a.db.ExecContext(ctx, a.q.createToken,
|
|
t.Hash, string(t.Kind), uuidArg(t.UserID), chainArg(t.ChainID),
|
|
nullableTime(t.ChainStartedAt), nullableTime(t.ConsumedAt),
|
|
nullableInt(t.AttemptsRemaining), t.CreatedAt, t.ExpiresAt)
|
|
if err != nil {
|
|
return errx.Wrap(op, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// storeConsumeToken atomically marks the matching unexpired, unconsumed token
|
|
// as consumed and returns it. Returns ErrTokenInvalid if no row matched.
|
|
// Implementations MUST do this in one statement to prevent double-spend
|
|
// under concurrent callers.
|
|
func (a *Auth) storeConsumeToken(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) {
|
|
const op = "authkit.storeConsumeToken"
|
|
row := a.db.QueryRowContext(ctx, a.q.consumeToken, now, string(kind), hash, now)
|
|
t, err := scanToken(row)
|
|
if err != nil {
|
|
return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
func (a *Auth) storeGetToken(ctx context.Context, kind TokenKind, hash []byte) (*Token, error) {
|
|
const op = "authkit.storeGetToken"
|
|
row := a.db.QueryRowContext(ctx, a.q.getToken, string(kind), hash)
|
|
t, err := scanToken(row)
|
|
if err != nil {
|
|
return nil, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// storeGetActiveOTPForUser returns the most recent unconsumed, unexpired OTP
|
|
// row for a user. Used by ConsumeEmailOTP to verify a code by hash-comparing
|
|
// client input.
|
|
func (a *Auth) storeGetActiveOTPForUser(ctx context.Context, kind TokenKind, userID any, now time.Time) (*Token, error) {
|
|
const op = "authkit.storeGetActiveOTPForUser"
|
|
row := a.db.QueryRowContext(ctx, a.q.getOTPForUser, string(kind), userID, now)
|
|
t, err := scanToken(row)
|
|
if err != nil {
|
|
return nil, errx.Wrap(op, mapNotFound(err, ErrOTPInvalid))
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// storeDecrementOTPAttempt drops attempts_remaining by 1 on the matched
|
|
// (kind, hash) row, consuming it when zero. Returns the new
|
|
// attempts_remaining (0 = consumed). ErrTokenInvalid when no row matched.
|
|
func (a *Auth) storeDecrementOTPAttempt(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (int, error) {
|
|
const op = "authkit.storeDecrementOTPAttempt"
|
|
var remaining sql.NullInt32
|
|
if err := a.db.QueryRowContext(ctx, a.q.decrementOTPAttempt,
|
|
now, string(kind), hash).Scan(&remaining); err != nil {
|
|
return 0, errx.Wrap(op, mapNotFound(err, ErrTokenInvalid))
|
|
}
|
|
if !remaining.Valid {
|
|
return 0, nil
|
|
}
|
|
return int(remaining.Int32), nil
|
|
}
|
|
|
|
// storeConsumeOTPByHash marks an OTP row consumed by direct hash match. Used
|
|
// on the success path of ConsumeEmailOTP.
|
|
func (a *Auth) storeConsumeOTPByHash(ctx context.Context, kind TokenKind, hash []byte, now time.Time) (*Token, error) {
|
|
const op = "authkit.storeConsumeOTPByHash"
|
|
row := a.db.QueryRowContext(ctx, a.q.consumeOTPByID, now, string(kind), hash)
|
|
t, err := scanToken(row)
|
|
if err != nil {
|
|
return nil, errx.Wrap(op, mapNotFound(err, ErrOTPInvalid))
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
func (a *Auth) storeDeleteByChain(ctx context.Context, chainID string) (int64, error) {
|
|
const op = "authkit.storeDeleteByChain"
|
|
tag, err := a.db.ExecContext(ctx, a.q.deleteByChain, chainID)
|
|
if err != nil {
|
|
return 0, errx.Wrap(op, err)
|
|
}
|
|
n, _ := tag.RowsAffected()
|
|
return n, nil
|
|
}
|
|
|
|
func (a *Auth) storeDeleteExpiredTokens(ctx context.Context, now time.Time) (int64, error) {
|
|
const op = "authkit.storeDeleteExpiredTokens"
|
|
tag, err := a.db.ExecContext(ctx, a.q.deleteExpiredTokens, now)
|
|
if err != nil {
|
|
return 0, errx.Wrap(op, err)
|
|
}
|
|
n, _ := tag.RowsAffected()
|
|
return n, nil
|
|
}
|
|
|
|
func scanToken(row rowScanner) (*Token, error) {
|
|
var (
|
|
t Token
|
|
kind string
|
|
userIDStr string
|
|
chainID sql.NullString
|
|
chainStartedAt sql.NullTime
|
|
consumedAt sql.NullTime
|
|
attempts sql.NullInt32
|
|
)
|
|
if err := row.Scan(&t.Hash, &kind, &userIDStr, &chainID, &chainStartedAt,
|
|
&consumedAt, &attempts, &t.CreatedAt, &t.ExpiresAt); err != nil {
|
|
return nil, err
|
|
}
|
|
t.Kind = TokenKind(kind)
|
|
uid, err := scanUUID(userIDStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.UserID = uid
|
|
t.ChainID = scanNullStringPtr(chainID)
|
|
t.ChainStartedAt = scanNullTimePtr(chainStartedAt)
|
|
t.ConsumedAt = scanNullTimePtr(consumedAt)
|
|
t.AttemptsRemaining = scanNullIntPtr(attempts)
|
|
return &t, nil
|
|
}
|