Rebuild for v1.0.0: postgres-only, slug-keyed authz, predicate API
Drops the Dialect/Queries abstraction in favor of a single PostgreSQL 16+ implementation collapsed into the root authkit package, removes the public store interfaces, and reshapes the authorization model around seeded slugs (roles, permissions, abilities) with optional labels. Schema is now squashed into one migrations/0001_init.sql and applied automatically on authkit.New (opt-out via Config.SkipAutoMigrate). A schema verifier checks tables/columns/types/nullability on startup, tolerates extra columns, and falls back to default table names when a configured override is missing. Auth API: CreateUser + SetPassword replace Register; password is nullable. Email OTP (RequestEmailOTP/ConsumeEmailOTP) joins magic links and password reset, all with anti-enumeration silent-success defaults and a Config.RevealUnknownEmail opt-in. Service tokens drop owner columns and validate ability slugs against authkit_abilities at issue. Direct user permissions live alongside role-derived ones; queries return their UNION. Predicate API: HasRole/HasPermission/HasAbility leaves with AnyLogin/AllLogin/AnyServiceKey/AllServiceKey combinators. Validate runs at middleware construction, panicking on unknown slugs. Middleware collapses to RequireLogin (cookie + JWT), RequireGuest (configurable OnAuthenticated), and RequireServiceKey. UserIDFromCtx / UserFromCtx (lazy) / RefreshUserInCtx provide request-lifetime user caching. Cookie defaults flip to Secure=true and HttpOnly=true via *bool with BoolPtr opt-out. CLIs ship under cmd/perms, cmd/roles, cmd/abilities for seeding the authorization vocabulary; the library never seeds rows itself. Tests cover unit-level (slug validation + fuzz, opaque secrets, email normalization, extractors, predicates, OTP generator) and integration flows gated on AUTHKIT_TEST_DATABASE_URL (every Auth method, schema drift detection, migration idempotency, lazy user cache, all middleware paths). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
7f1db871bc
commit
d3c5367492
80 changed files with 5605 additions and 4565 deletions
|
|
@ -1,84 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
)
|
||||
|
||||
// authzGuard wraps the common pattern of "look up the Principal, run a
|
||||
// predicate, succeed or 403". onForbidden defaults to JSON 403.
|
||||
func authzGuard(onForbidden func(http.ResponseWriter, *http.Request, error), pred func(*authkit.Principal) bool) func(http.Handler) http.Handler {
|
||||
if onForbidden == nil {
|
||||
onForbidden = defaultJSONError(http.StatusForbidden)
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p, ok := PrincipalFrom(r.Context())
|
||||
if !ok {
|
||||
// No auth middleware ran upstream; treat as forbidden
|
||||
// rather than crashing — composition is the caller's
|
||||
// responsibility but a 403 is the safer default.
|
||||
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
if !pred(p) {
|
||||
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole permits requests whose Principal holds the named role.
|
||||
func RequireRole(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||
return p.HasRole(name)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAnyRole permits requests whose Principal holds at least one of the
|
||||
// named roles.
|
||||
func RequireAnyRole(names []string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||
return p.HasAnyRole(names...)
|
||||
})
|
||||
}
|
||||
|
||||
// RequirePermission permits requests whose Principal holds the named
|
||||
// permission (resolved via roles at auth time).
|
||||
func RequirePermission(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
return authzGuard(firstOrNil(onForbidden), func(p *authkit.Principal) bool {
|
||||
return p.HasPermission(name)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAbility permits requests whose ServiceKey carries the named ability.
|
||||
// Abilities live only on service tokens — this middleware reads
|
||||
// *authkit.ServiceKey from the request context (placed by RequireServiceKey
|
||||
// or RequireAnyOrServiceKey) and 403s any request authenticated as a user
|
||||
// (session or JWT), which by definition has no abilities.
|
||||
func RequireAbility(name string, onForbidden ...func(http.ResponseWriter, *http.Request, error)) func(http.Handler) http.Handler {
|
||||
onForb := firstOrNil(onForbidden)
|
||||
if onForb == nil {
|
||||
onForb = defaultJSONError(http.StatusForbidden)
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
k, ok := ServiceKeyFrom(r.Context())
|
||||
if !ok || !k.HasAbility(name) {
|
||||
onForb(w, r, authkit.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func firstOrNil(s []func(http.ResponseWriter, *http.Request, error)) func(http.ResponseWriter, *http.Request, error) {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s[0]
|
||||
}
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
// Package middleware provides framework-neutral HTTP middleware for authkit.
|
||||
// Every middleware function returns the standard func(http.Handler)
|
||||
// http.Handler type, so it composes with lightmux's Use/Group/Handle as well
|
||||
// as any net/http stack that uses the same signature.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
)
|
||||
|
||||
// principalKey and serviceKeyKey are unexported context keys. Using distinct
|
||||
// empty struct types guarantees no collision with caller-defined keys.
|
||||
type principalKey struct{}
|
||||
type serviceKeyKey struct{}
|
||||
|
||||
// withPrincipal stashes p on the request context for downstream handlers.
|
||||
func withPrincipal(ctx context.Context, p *authkit.Principal) context.Context {
|
||||
return context.WithValue(ctx, principalKey{}, p)
|
||||
}
|
||||
|
||||
// PrincipalFrom retrieves the authenticated Principal placed by RequireSession
|
||||
// or RequireJWT. The boolean is false if no user-bound auth middleware ran for
|
||||
// this request (e.g. the request was authenticated via service key instead).
|
||||
func PrincipalFrom(ctx context.Context) (*authkit.Principal, bool) {
|
||||
p, ok := ctx.Value(principalKey{}).(*authkit.Principal)
|
||||
return p, ok
|
||||
}
|
||||
|
||||
// MustPrincipal panics if no Principal is on the context. Use only on
|
||||
// handlers known to be behind a Require* middleware that authenticates a
|
||||
// user (RequireSession or RequireJWT).
|
||||
func MustPrincipal(r *http.Request) *authkit.Principal {
|
||||
p, ok := PrincipalFrom(r.Context())
|
||||
if !ok {
|
||||
panic("authkit/middleware: no principal on request context")
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// withServiceKey stashes k on the request context for downstream handlers.
|
||||
func withServiceKey(ctx context.Context, k *authkit.ServiceKey) context.Context {
|
||||
return context.WithValue(ctx, serviceKeyKey{}, k)
|
||||
}
|
||||
|
||||
// ServiceKeyFrom retrieves the authenticated ServiceKey placed by
|
||||
// RequireServiceKey. The boolean is false if no service-key middleware ran
|
||||
// for this request.
|
||||
func ServiceKeyFrom(ctx context.Context) (*authkit.ServiceKey, bool) {
|
||||
k, ok := ctx.Value(serviceKeyKey{}).(*authkit.ServiceKey)
|
||||
return k, ok
|
||||
}
|
||||
|
||||
// MustServiceKey panics if no ServiceKey is on the context. Use only on
|
||||
// handlers known to be behind RequireServiceKey.
|
||||
func MustServiceKey(r *http.Request) *authkit.ServiceKey {
|
||||
k, ok := ServiceKeyFrom(r.Context())
|
||||
if !ok {
|
||||
panic("authkit/middleware: no service key on request context")
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
|
@ -1,45 +1,236 @@
|
|||
// Package middleware provides framework-neutral HTTP middleware for authkit.
|
||||
// Every middleware function returns the standard func(http.Handler)
|
||||
// http.Handler shape so it composes with lightmux.Mux.Use/Group/Handle, with
|
||||
// chi/gorilla, or with any net/http stack accepting that signature.
|
||||
//
|
||||
// Three primitives:
|
||||
// - RequireLogin — accept session OR JWT, optionally constrain by Authz
|
||||
// - RequireGuest — reject authenticated requests
|
||||
// - RequireServiceKey — accept a service token, optionally constrain by Authz
|
||||
//
|
||||
// All three attach the relevant subject to the request context via
|
||||
// authkit.WithUserContext or authkit.WithServiceKey, so handlers can read
|
||||
// it via authkit.UserIDFromCtx / authkit.UserFromCtx /
|
||||
// authkit.ServiceKeyFromCtx.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
)
|
||||
|
||||
// Options configures auth middleware. Auth is required; the rest fall back
|
||||
// to defaults: BearerExtractor, a JSON 401 on auth failure, and a JSON 403
|
||||
// on authz failure.
|
||||
type Options struct {
|
||||
Auth *authkit.Auth
|
||||
Extractor authkit.Extractor
|
||||
// LoginOptions configures RequireLogin.
|
||||
type LoginOptions struct {
|
||||
// Auth is required.
|
||||
Auth *authkit.Auth
|
||||
|
||||
// SessionExtractor reads the session plaintext from the request.
|
||||
// Defaults to a cookie extractor using Auth.SessionCookieName().
|
||||
SessionExtractor authkit.Extractor
|
||||
|
||||
// JWTExtractor reads the JWT access token from the request. Defaults
|
||||
// to BearerExtractor.
|
||||
JWTExtractor authkit.Extractor
|
||||
|
||||
// Authz, if non-nil, gates the request on a predicate over the
|
||||
// resolved *Principal. Validate is called once at construction; an
|
||||
// invalid predicate (unknown slug) panics.
|
||||
Authz authkit.LoginAuthz
|
||||
|
||||
// OnUnauth handles "no credential / bad credential" failures (HTTP
|
||||
// 401). Default: JSON {"error":"Unauthorized"}.
|
||||
OnUnauth func(w http.ResponseWriter, r *http.Request, err error)
|
||||
|
||||
// OnForbidden handles "credential ok but Authz failed" (HTTP 403).
|
||||
// Default: JSON {"error":"Forbidden"}.
|
||||
OnForbidden func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
// RequireLogin returns middleware that authenticates the request via either
|
||||
// a session cookie or a JWT (in that order) and, if Authz is set, gates the
|
||||
// resolved Principal against the predicate.
|
||||
//
|
||||
// Panics at construction time if Auth is nil or Authz references unknown
|
||||
// slugs.
|
||||
func RequireLogin(opts LoginOptions) func(http.Handler) http.Handler {
|
||||
if opts.Auth == nil {
|
||||
panic("authkit/middleware: LoginOptions.Auth is required")
|
||||
}
|
||||
if opts.Authz != nil {
|
||||
if err := opts.Authz.Validate(context.Background(), opts.Auth); err != nil {
|
||||
panic(fmt.Sprintf("authkit/middleware: %v", err))
|
||||
}
|
||||
}
|
||||
sessionEx := opts.SessionExtractor
|
||||
if sessionEx == nil {
|
||||
sessionEx = authkit.CookieExtractor(opts.Auth.SessionCookieName())
|
||||
}
|
||||
jwtEx := opts.JWTExtractor
|
||||
if jwtEx == nil {
|
||||
jwtEx = authkit.BearerExtractor()
|
||||
}
|
||||
onUnauth := opts.OnUnauth
|
||||
if onUnauth == nil {
|
||||
onUnauth = defaultJSONError(http.StatusUnauthorized)
|
||||
}
|
||||
onForbidden := opts.OnForbidden
|
||||
if onForbidden == nil {
|
||||
onForbidden = defaultJSONError(http.StatusForbidden)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p, err := authenticatePrincipal(r, opts.Auth, sessionEx, jwtEx)
|
||||
if err != nil {
|
||||
onUnauth(w, r, err)
|
||||
return
|
||||
}
|
||||
if opts.Authz != nil && !opts.Authz.Match(p) {
|
||||
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
ctx := authkit.WithUserContext(r.Context(), opts.Auth, p.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// authenticatePrincipal tries the session extractor first, then the JWT
|
||||
// extractor. Returns the first successful Principal or the last error.
|
||||
func authenticatePrincipal(r *http.Request, a *authkit.Auth, sessionEx, jwtEx authkit.Extractor) (*authkit.Principal, error) {
|
||||
if v, ok := sessionEx(r); ok && v != "" {
|
||||
if p, err := a.AuthenticateSession(r.Context(), v); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
if v, ok := jwtEx(r); ok && v != "" {
|
||||
if p, err := a.AuthenticateJWT(r.Context(), v); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, authkit.ErrSessionInvalid
|
||||
}
|
||||
|
||||
// GuestOptions configures RequireGuest.
|
||||
type GuestOptions struct {
|
||||
// Auth is required.
|
||||
Auth *authkit.Auth
|
||||
|
||||
SessionExtractor authkit.Extractor
|
||||
JWTExtractor authkit.Extractor
|
||||
|
||||
// OnAuthenticated handles requests that present a valid credential
|
||||
// (where a guest was expected). Default: JSON 403.
|
||||
OnAuthenticated func(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// RequireGuest returns middleware that rejects requests carrying a valid
|
||||
// session or JWT. Useful for /login or /register pages where authenticated
|
||||
// users should be redirected away.
|
||||
//
|
||||
// Default rejection is HTTP 403 JSON. Pass Options.OnAuthenticated to
|
||||
// implement a redirect or custom response.
|
||||
func RequireGuest(opts GuestOptions) func(http.Handler) http.Handler {
|
||||
if opts.Auth == nil {
|
||||
panic("authkit/middleware: GuestOptions.Auth is required")
|
||||
}
|
||||
sessionEx := opts.SessionExtractor
|
||||
if sessionEx == nil {
|
||||
sessionEx = authkit.CookieExtractor(opts.Auth.SessionCookieName())
|
||||
}
|
||||
jwtEx := opts.JWTExtractor
|
||||
if jwtEx == nil {
|
||||
jwtEx = authkit.BearerExtractor()
|
||||
}
|
||||
onAuthenticated := opts.OnAuthenticated
|
||||
if onAuthenticated == nil {
|
||||
onAuthenticated = func(w http.ResponseWriter, r *http.Request) {
|
||||
defaultJSONError(http.StatusForbidden)(w, r, authkit.ErrPermissionDenied)
|
||||
}
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := authenticatePrincipal(r, opts.Auth, sessionEx, jwtEx); err == nil {
|
||||
onAuthenticated(w, r)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceKeyOptions configures RequireServiceKey.
|
||||
type ServiceKeyOptions struct {
|
||||
Auth *authkit.Auth
|
||||
|
||||
// Extractor reads the service token plaintext. Defaults to
|
||||
// BearerExtractor.
|
||||
Extractor authkit.Extractor
|
||||
|
||||
// Authz, if non-nil, gates the request on a predicate over the
|
||||
// resolved *ServiceKey.
|
||||
Authz authkit.ServiceKeyAuthz
|
||||
|
||||
OnUnauth func(w http.ResponseWriter, r *http.Request, err error)
|
||||
OnForbidden func(w http.ResponseWriter, r *http.Request, err error)
|
||||
}
|
||||
|
||||
func (o Options) extractor() authkit.Extractor {
|
||||
if o.Extractor != nil {
|
||||
return o.Extractor
|
||||
// RequireServiceKey returns middleware that authenticates the request via a
|
||||
// service token and, if Authz is set, gates on the predicate.
|
||||
//
|
||||
// Panics at construction time if Auth is nil or Authz references unknown
|
||||
// ability slugs.
|
||||
func RequireServiceKey(opts ServiceKeyOptions) func(http.Handler) http.Handler {
|
||||
if opts.Auth == nil {
|
||||
panic("authkit/middleware: ServiceKeyOptions.Auth is required")
|
||||
}
|
||||
if opts.Authz != nil {
|
||||
if err := opts.Authz.Validate(context.Background(), opts.Auth); err != nil {
|
||||
panic(fmt.Sprintf("authkit/middleware: %v", err))
|
||||
}
|
||||
}
|
||||
ex := opts.Extractor
|
||||
if ex == nil {
|
||||
ex = authkit.BearerExtractor()
|
||||
}
|
||||
onUnauth := opts.OnUnauth
|
||||
if onUnauth == nil {
|
||||
onUnauth = defaultJSONError(http.StatusUnauthorized)
|
||||
}
|
||||
onForbidden := opts.OnForbidden
|
||||
if onForbidden == nil {
|
||||
onForbidden = defaultJSONError(http.StatusForbidden)
|
||||
}
|
||||
return authkit.BearerExtractor()
|
||||
}
|
||||
|
||||
func (o Options) onUnauth() func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if o.OnUnauth != nil {
|
||||
return o.OnUnauth
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, ok := ex(r)
|
||||
if !ok || raw == "" {
|
||||
onUnauth(w, r, authkit.ErrServiceKeyInvalid)
|
||||
return
|
||||
}
|
||||
k, err := opts.Auth.AuthenticateServiceKey(r.Context(), raw)
|
||||
if err != nil {
|
||||
onUnauth(w, r, err)
|
||||
return
|
||||
}
|
||||
if opts.Authz != nil && !opts.Authz.Match(k) {
|
||||
onForbidden(w, r, authkit.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
ctx := authkit.WithServiceKey(r.Context(), k)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
return defaultJSONError(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (o Options) onForbidden() func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if o.OnForbidden != nil {
|
||||
return o.OnForbidden
|
||||
}
|
||||
return defaultJSONError(http.StatusForbidden)
|
||||
}
|
||||
|
||||
func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
return func(w http.ResponseWriter, _ *http.Request, err error) {
|
||||
return func(w http.ResponseWriter, _ *http.Request, _ error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
|
|
@ -47,169 +238,3 @@ func defaultJSONError(status int) func(w http.ResponseWriter, r *http.Request, e
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSession authenticates the request via an opaque session string. The
|
||||
// extractor is consulted first; if no extractor is set the default Bearer
|
||||
// extractor is used. For cookie-based session lookup, set
|
||||
// Options.Extractor = authkit.CookieExtractor(cfg.SessionCookieName).
|
||||
func RequireSession(opts Options) func(http.Handler) http.Handler {
|
||||
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||
return opts.Auth.AuthenticateSession(r.Context(), raw)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireJWT authenticates the request via an HS256 JWT.
|
||||
func RequireJWT(opts Options) func(http.Handler) http.Handler {
|
||||
return requireWith(opts, func(r *http.Request, raw string) (*authkit.Principal, error) {
|
||||
return opts.Auth.AuthenticateJWT(r.Context(), raw)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireServiceKey authenticates the request via an opaque service token
|
||||
// secret. On success the resolved *authkit.ServiceKey is placed on the
|
||||
// request context; downstream handlers retrieve it via ServiceKeyFrom. Note
|
||||
// that this middleware does NOT place a *Principal on the context — service
|
||||
// tokens have no user — so user-bound authz middleware (RequireRole,
|
||||
// RequirePermission) will reject service-key requests with 403.
|
||||
func RequireServiceKey(opts Options) func(http.Handler) http.Handler {
|
||||
return requireWithServiceKey(opts, func(r *http.Request, raw string) (*authkit.ServiceKey, error) {
|
||||
return opts.Auth.AuthenticateServiceKey(r.Context(), raw)
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAny tries each user-bound method in order until one succeeds. The
|
||||
// default set is [Session, JWT]; service tokens are NOT included because
|
||||
// they yield a different subject type. For routes that accept either a user
|
||||
// credential or a service token, use RequireAnyOrServiceKey.
|
||||
func RequireAny(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler {
|
||||
if len(methods) == 0 {
|
||||
methods = []authkit.AuthMethod{
|
||||
authkit.AuthMethodSession,
|
||||
authkit.AuthMethodJWT,
|
||||
}
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, ok := opts.extractor()(r)
|
||||
if !ok || raw == "" {
|
||||
opts.onUnauth()(w, r, authkit.ErrSessionInvalid)
|
||||
return
|
||||
}
|
||||
var (
|
||||
p *authkit.Principal
|
||||
lastErr error
|
||||
)
|
||||
for _, m := range methods {
|
||||
switch m {
|
||||
case authkit.AuthMethodSession:
|
||||
p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw)
|
||||
case authkit.AuthMethodJWT:
|
||||
p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw)
|
||||
}
|
||||
if lastErr == nil && p != nil {
|
||||
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
|
||||
return
|
||||
}
|
||||
}
|
||||
opts.onUnauth()(w, r, lastErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAnyOrServiceKey tries the user-bound methods first (default
|
||||
// [Session, JWT]); on failure, falls through to a service-key lookup. The
|
||||
// downstream handler sees either a *Principal or a *ServiceKey on context —
|
||||
// retrieve via PrincipalFrom or ServiceKeyFrom and dispatch accordingly.
|
||||
func RequireAnyOrServiceKey(opts Options, methods ...authkit.AuthMethod) func(http.Handler) http.Handler {
|
||||
if opts.Auth == nil {
|
||||
panic("authkit/middleware: Options.Auth is required")
|
||||
}
|
||||
if len(methods) == 0 {
|
||||
methods = []authkit.AuthMethod{
|
||||
authkit.AuthMethodSession,
|
||||
authkit.AuthMethodJWT,
|
||||
}
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, ok := opts.extractor()(r)
|
||||
if !ok || raw == "" {
|
||||
opts.onUnauth()(w, r, authkit.ErrSessionInvalid)
|
||||
return
|
||||
}
|
||||
var lastErr error
|
||||
for _, m := range methods {
|
||||
var p *authkit.Principal
|
||||
switch m {
|
||||
case authkit.AuthMethodSession:
|
||||
p, lastErr = opts.Auth.AuthenticateSession(r.Context(), raw)
|
||||
case authkit.AuthMethodJWT:
|
||||
p, lastErr = opts.Auth.AuthenticateJWT(r.Context(), raw)
|
||||
}
|
||||
if lastErr == nil && p != nil {
|
||||
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
|
||||
return
|
||||
}
|
||||
}
|
||||
k, err := opts.Auth.AuthenticateServiceKey(r.Context(), raw)
|
||||
if err == nil && k != nil {
|
||||
next.ServeHTTP(w, r.WithContext(withServiceKey(r.Context(), k)))
|
||||
return
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = err
|
||||
}
|
||||
opts.onUnauth()(w, r, lastErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// requireWith is the shared scaffolding for the single-method user-bound
|
||||
// Require* middlewares.
|
||||
func requireWith(opts Options, authn func(r *http.Request, raw string) (*authkit.Principal, error)) func(http.Handler) http.Handler {
|
||||
if opts.Auth == nil {
|
||||
panic("authkit/middleware: Options.Auth is required")
|
||||
}
|
||||
extractor := opts.extractor()
|
||||
onUnauth := opts.onUnauth()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, ok := extractor(r)
|
||||
if !ok || raw == "" {
|
||||
onUnauth(w, r, authkit.ErrSessionInvalid)
|
||||
return
|
||||
}
|
||||
p, err := authn(r, raw)
|
||||
if err != nil {
|
||||
onUnauth(w, r, err)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(withPrincipal(r.Context(), p)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// requireWithServiceKey is the service-key analogue of requireWith. It places
|
||||
// a *ServiceKey (not a *Principal) on the request context.
|
||||
func requireWithServiceKey(opts Options, authn func(r *http.Request, raw string) (*authkit.ServiceKey, error)) func(http.Handler) http.Handler {
|
||||
if opts.Auth == nil {
|
||||
panic("authkit/middleware: Options.Auth is required")
|
||||
}
|
||||
extractor := opts.extractor()
|
||||
onUnauth := opts.onUnauth()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, ok := extractor(r)
|
||||
if !ok || raw == "" {
|
||||
onUnauth(w, r, authkit.ErrServiceKeyInvalid)
|
||||
return
|
||||
}
|
||||
k, err := authn(r, raw)
|
||||
if err != nil {
|
||||
onUnauth(w, r, err)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(withServiceKey(r.Context(), k)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,319 +1,92 @@
|
|||
package middleware_test
|
||||
|
||||
// Integration tests for the middleware package. Skipped when
|
||||
// AUTHKIT_TEST_DATABASE_URL is not set.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/authkit"
|
||||
"git.juancwu.dev/juancwu/authkit/hasher"
|
||||
"git.juancwu.dev/juancwu/authkit/middleware"
|
||||
"github.com/google/uuid"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
// ─── minimal in-memory stores ──────────────────────────────────────────────
|
||||
//
|
||||
// The middleware package can't import the parent's _test stores, so we wire
|
||||
// up a fresh-but-minimal set here. Only the methods actually exercised by
|
||||
// the middleware tests below have meaningful bodies; unused store methods
|
||||
// panic to surface unexpected call paths.
|
||||
|
||||
type memUserStore struct {
|
||||
mu sync.Mutex
|
||||
m map[uuid.UUID]*authkit.User
|
||||
}
|
||||
|
||||
func newMemUserStore() *memUserStore { return &memUserStore{m: map[uuid.UUID]*authkit.User{}} }
|
||||
|
||||
func (s *memUserStore) CreateUser(_ context.Context, u *authkit.User) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, existing := range s.m {
|
||||
if existing.EmailNormalized == u.EmailNormalized {
|
||||
return authkit.ErrEmailTaken
|
||||
}
|
||||
}
|
||||
cp := *u
|
||||
s.m[u.ID] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) GetUserByID(_ context.Context, id uuid.UUID) (*authkit.User, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u, ok := s.m[id]
|
||||
if !ok {
|
||||
return nil, authkit.ErrUserNotFound
|
||||
}
|
||||
cp := *u
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memUserStore) GetUserByEmail(_ context.Context, normalized string) (*authkit.User, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, u := range s.m {
|
||||
if u.EmailNormalized == normalized {
|
||||
cp := *u
|
||||
return &cp, nil
|
||||
}
|
||||
}
|
||||
return nil, authkit.ErrUserNotFound
|
||||
}
|
||||
func (s *memUserStore) UpdateUser(_ context.Context, u *authkit.User) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cp := *u
|
||||
s.m[u.ID] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) DeleteUser(_ context.Context, id uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.m, id)
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) SetPassword(_ context.Context, id uuid.UUID, encoded string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if u, ok := s.m[id]; ok {
|
||||
u.PasswordHash = encoded
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) SetEmailVerified(_ context.Context, id uuid.UUID, at time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if u, ok := s.m[id]; ok {
|
||||
u.EmailVerifiedAt = &at
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memUserStore) BumpSessionVersion(_ context.Context, id uuid.UUID) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if u, ok := s.m[id]; ok {
|
||||
u.SessionVersion++
|
||||
return u.SessionVersion, nil
|
||||
}
|
||||
return 0, authkit.ErrUserNotFound
|
||||
}
|
||||
func (s *memUserStore) IncrementFailedLogins(_ context.Context, id uuid.UUID) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if u, ok := s.m[id]; ok {
|
||||
u.FailedLogins++
|
||||
return u.FailedLogins, nil
|
||||
}
|
||||
return 0, authkit.ErrUserNotFound
|
||||
}
|
||||
func (s *memUserStore) ResetFailedLogins(_ context.Context, id uuid.UUID) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if u, ok := s.m[id]; ok {
|
||||
u.FailedLogins = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type memSessionStore struct {
|
||||
mu sync.Mutex
|
||||
m map[string]*authkit.Session
|
||||
}
|
||||
|
||||
func newMemSessionStore() *memSessionStore {
|
||||
return &memSessionStore{m: map[string]*authkit.Session{}}
|
||||
}
|
||||
func (s *memSessionStore) CreateSession(_ context.Context, sess *authkit.Session) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cp := *sess
|
||||
s.m[string(sess.IDHash)] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memSessionStore) GetSession(_ context.Context, h []byte) (*authkit.Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess, ok := s.m[string(h)]
|
||||
if !ok {
|
||||
return nil, authkit.ErrSessionInvalid
|
||||
}
|
||||
cp := *sess
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memSessionStore) TouchSession(_ context.Context, h []byte, lastSeen, newExp time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if sess, ok := s.m[string(h)]; ok {
|
||||
sess.LastSeenAt = lastSeen
|
||||
sess.ExpiresAt = newExp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memSessionStore) DeleteSession(_ context.Context, h []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.m, string(h))
|
||||
return nil
|
||||
}
|
||||
func (s *memSessionStore) DeleteUserSessions(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
func (s *memSessionStore) DeleteExpired(_ context.Context, _ time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type memTokenStore struct{}
|
||||
|
||||
func (memTokenStore) CreateToken(_ context.Context, _ *authkit.Token) error { return nil }
|
||||
func (memTokenStore) ConsumeToken(_ context.Context, _ authkit.TokenKind, _ []byte, _ time.Time) (*authkit.Token, error) {
|
||||
return nil, authkit.ErrTokenInvalid
|
||||
}
|
||||
func (memTokenStore) GetToken(_ context.Context, _ authkit.TokenKind, _ []byte) (*authkit.Token, error) {
|
||||
return nil, authkit.ErrTokenInvalid
|
||||
}
|
||||
func (memTokenStore) DeleteByChain(_ context.Context, _ string) (int64, error) { return 0, nil }
|
||||
func (memTokenStore) DeleteExpired(_ context.Context, _ time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type memServiceKeyStore struct {
|
||||
mu sync.Mutex
|
||||
m map[string]*authkit.ServiceKey
|
||||
}
|
||||
|
||||
func newMemServiceKeyStore() *memServiceKeyStore {
|
||||
return &memServiceKeyStore{m: map[string]*authkit.ServiceKey{}}
|
||||
}
|
||||
func (s *memServiceKeyStore) CreateServiceKey(_ context.Context, k *authkit.ServiceKey) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cp := *k
|
||||
cp.Abilities = append([]string(nil), k.Abilities...)
|
||||
s.m[string(k.IDHash)] = &cp
|
||||
return nil
|
||||
}
|
||||
func (s *memServiceKeyStore) GetServiceKey(_ context.Context, h []byte) (*authkit.ServiceKey, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
k, ok := s.m[string(h)]
|
||||
if !ok {
|
||||
return nil, authkit.ErrServiceKeyInvalid
|
||||
}
|
||||
cp := *k
|
||||
cp.Abilities = append([]string(nil), k.Abilities...)
|
||||
return &cp, nil
|
||||
}
|
||||
func (s *memServiceKeyStore) ListServiceKeysByOwner(_ context.Context, kind string, owner uuid.UUID) ([]*authkit.ServiceKey, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var out []*authkit.ServiceKey
|
||||
for _, k := range s.m {
|
||||
if k.OwnerKind == kind && k.OwnerID == owner {
|
||||
cp := *k
|
||||
cp.Abilities = append([]string(nil), k.Abilities...)
|
||||
out = append(out, &cp)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *memServiceKeyStore) TouchServiceKey(_ context.Context, h []byte, at time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if k, ok := s.m[string(h)]; ok {
|
||||
k.LastUsedAt = &at
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memServiceKeyStore) RevokeServiceKey(_ context.Context, h []byte, at time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
k, ok := s.m[string(h)]
|
||||
if !ok {
|
||||
return authkit.ErrServiceKeyInvalid
|
||||
}
|
||||
if k.RevokedAt != nil {
|
||||
return authkit.ErrServiceKeyInvalid
|
||||
}
|
||||
k.RevokedAt = &at
|
||||
return nil
|
||||
}
|
||||
|
||||
type memRoleStore struct{}
|
||||
|
||||
func (memRoleStore) CreateRole(_ context.Context, _ *authkit.Role) error { return nil }
|
||||
func (memRoleStore) GetRoleByID(_ context.Context, _ uuid.UUID) (*authkit.Role, error) {
|
||||
return nil, authkit.ErrRoleNotFound
|
||||
}
|
||||
func (memRoleStore) GetRoleByName(_ context.Context, _ string) (*authkit.Role, error) {
|
||||
return nil, authkit.ErrRoleNotFound
|
||||
}
|
||||
func (memRoleStore) ListRoles(_ context.Context) ([]*authkit.Role, error) { return nil, nil }
|
||||
func (memRoleStore) DeleteRole(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
func (memRoleStore) AssignRoleToUser(_ context.Context, _, _ uuid.UUID) error { return nil }
|
||||
func (memRoleStore) RemoveRoleFromUser(_ context.Context, _, _ uuid.UUID) error { return nil }
|
||||
func (memRoleStore) GetUserRoles(_ context.Context, _ uuid.UUID) ([]*authkit.Role, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (memRoleStore) HasAnyRole(_ context.Context, _ uuid.UUID, _ []string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type memPermStore struct{}
|
||||
|
||||
func (memPermStore) CreatePermission(_ context.Context, _ *authkit.Permission) error { return nil }
|
||||
func (memPermStore) GetPermissionByID(_ context.Context, _ uuid.UUID) (*authkit.Permission, error) {
|
||||
return nil, authkit.ErrPermissionNotFound
|
||||
}
|
||||
func (memPermStore) GetPermissionByName(_ context.Context, _ string) (*authkit.Permission, error) {
|
||||
return nil, authkit.ErrPermissionNotFound
|
||||
}
|
||||
func (memPermStore) ListPermissions(_ context.Context) ([]*authkit.Permission, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (memPermStore) DeletePermission(_ context.Context, _ uuid.UUID) error { return nil }
|
||||
func (memPermStore) AssignPermissionToRole(_ context.Context, _, _ uuid.UUID) error { return nil }
|
||||
func (memPermStore) RemovePermissionFromRole(_ context.Context, _, _ uuid.UUID) error { return nil }
|
||||
func (memPermStore) GetRolePermissions(_ context.Context, _ uuid.UUID) ([]*authkit.Permission, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (memPermStore) GetUserPermissions(_ context.Context, _ uuid.UUID) ([]*authkit.Permission, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubHasher struct{}
|
||||
|
||||
func (stubHasher) Hash(p string) (string, error) { return "stub:" + p, nil }
|
||||
func (stubHasher) Verify(p, encoded string) (bool, bool, error) {
|
||||
return encoded == "stub:"+p, false, nil
|
||||
}
|
||||
|
||||
func newTestAuth(t *testing.T) *authkit.Auth {
|
||||
func freshAuth(t *testing.T) *authkit.Auth {
|
||||
t.Helper()
|
||||
return authkit.New(authkit.Deps{
|
||||
Users: newMemUserStore(),
|
||||
Sessions: newMemSessionStore(),
|
||||
Tokens: memTokenStore{},
|
||||
ServiceKeys: newMemServiceKeyStore(),
|
||||
Roles: memRoleStore{},
|
||||
Permissions: memPermStore{},
|
||||
Hasher: stubHasher{},
|
||||
url := os.Getenv("AUTHKIT_TEST_DATABASE_URL")
|
||||
if url == "" {
|
||||
t.Skip("AUTHKIT_TEST_DATABASE_URL not set; skipping integration test")
|
||||
}
|
||||
db, err := sql.Open("pgx", url)
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
if err := db.PingContext(context.Background()); err != nil {
|
||||
t.Fatalf("ping: %v", err)
|
||||
}
|
||||
dropAuthkitTables(t, db)
|
||||
t.Cleanup(func() { dropAuthkitTables(t, db) })
|
||||
|
||||
a, err := authkit.New(context.Background(), authkit.Deps{
|
||||
DB: db,
|
||||
Hasher: hasher.NewArgon2id(hasher.DefaultArgon2idParams(), nil),
|
||||
}, authkit.Config{
|
||||
JWTSecret: []byte("test-secret-thirty-two-bytes!!!!"),
|
||||
JWTIssuer: "mw-test",
|
||||
AccessTokenTTL: 2 * time.Minute,
|
||||
RefreshTokenTTL: 1 * time.Hour,
|
||||
SessionIdleTTL: time.Hour,
|
||||
SessionAbsoluteTTL: 24 * time.Hour,
|
||||
EmailVerifyTTL: time.Hour,
|
||||
PasswordResetTTL: time.Hour,
|
||||
MagicLinkTTL: time.Minute,
|
||||
JWTSecret: []byte("integration-secret-thirty-two!!!"),
|
||||
JWTIssuer: "authkit-mw-int",
|
||||
AccessTokenTTL: 2 * time.Minute,
|
||||
RefreshTokenTTL: time.Hour,
|
||||
SessionIdleTTL: time.Hour,
|
||||
SessionAbsoluteTTL: 24 * time.Hour,
|
||||
EmailVerifyTTL: time.Hour,
|
||||
PasswordResetTTL: time.Hour,
|
||||
MagicLinkTTL: time.Minute,
|
||||
EmailOTPTTL: time.Minute,
|
||||
EmailOTPMaxAttempts: 3,
|
||||
// Plain HTTP for tests so secure-cookie defaults don't interfere
|
||||
// with httptest's HTTP server.
|
||||
SessionCookieSecure: authkit.BoolPtr(false),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("authkit.New: %v", err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// Bearer-style request helper.
|
||||
func req(token string) *http.Request {
|
||||
func dropAuthkitTables(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
tables := []string{
|
||||
"authkit_service_key_abilities",
|
||||
"authkit_user_permissions",
|
||||
"authkit_user_roles",
|
||||
"authkit_role_permissions",
|
||||
"authkit_service_keys",
|
||||
"authkit_abilities",
|
||||
"authkit_roles",
|
||||
"authkit_permissions",
|
||||
"authkit_tokens",
|
||||
"authkit_sessions",
|
||||
"authkit_users",
|
||||
"authkit_schema_migrations",
|
||||
}
|
||||
ctx := context.Background()
|
||||
for _, name := range tables {
|
||||
_, _ = db.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", name))
|
||||
}
|
||||
}
|
||||
|
||||
// reqWithBearer issues a request carrying Authorization: Bearer <token>.
|
||||
func reqWithBearer(token string) *http.Request {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if token != "" {
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
|
|
@ -323,191 +96,263 @@ func req(token string) *http.Request {
|
|||
|
||||
func ok200(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }
|
||||
|
||||
// ─── tests ─────────────────────────────────────────────────────────────────
|
||||
// ─── RequireLogin ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestRequireServiceKey_Authenticates(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
plain, _, err := a.IssueServiceKey(context.Background(),
|
||||
"application", uuid.New(), "ci", []string{"events:write"}, nil)
|
||||
func TestRequireLogin_AcceptsSessionCookie(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
u, err := a.CreateUser(ctx, "alice@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("IssueServiceKey: %v", err)
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
plain, _, err := a.IssueSession(ctx, u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("IssueSession: %v", err)
|
||||
}
|
||||
|
||||
var seen *authkit.ServiceKey
|
||||
handler := middleware.RequireServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
k, ok := middleware.ServiceKeyFrom(r.Context())
|
||||
if !ok {
|
||||
t.Fatalf("no ServiceKey on context")
|
||||
handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
uid, ok := authkit.UserIDFromCtx(r.Context())
|
||||
if !ok || uid != u.ID {
|
||||
t.Fatalf("user_id missing or wrong on context: ok=%v id=%v", ok, uid)
|
||||
}
|
||||
seen = k
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.AddCookie(a.SessionCookie(plain, time.Now().Add(time.Hour)))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req(plain))
|
||||
handler.ServeHTTP(rr, r)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", rr.Code)
|
||||
}
|
||||
if seen == nil || !seen.HasAbility("events:write") {
|
||||
t.Fatalf("expected ServiceKey with events:write ability; got %+v", seen)
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireServiceKey_RejectsRevoked(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
plain, _, err := a.IssueServiceKey(context.Background(),
|
||||
"application", uuid.New(), "ci", nil, nil)
|
||||
func TestRequireLogin_AcceptsJWT(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
u, err := a.CreateUser(ctx, "j@j.com")
|
||||
if err != nil {
|
||||
t.Fatalf("IssueServiceKey: %v", err)
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
if err := a.RevokeServiceKey(context.Background(), plain); err != nil {
|
||||
t.Fatalf("RevokeServiceKey: %v", err)
|
||||
access, _, err := a.IssueJWT(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
|
||||
handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(ok200))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, reqWithBearer(access))
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireLogin_RejectsUnauthenticated(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
handler := middleware.RequireLogin(middleware.LoginOptions{Auth: a})(http.HandlerFunc(ok200))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireLogin_AuthzRoleGate(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
if _, err := a.CreateRole(ctx, "admin", ""); err != nil {
|
||||
t.Fatalf("CreateRole: %v", err)
|
||||
}
|
||||
u, err := a.CreateUser(ctx, "noadmin@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
access, _, err := a.IssueJWT(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
handler := middleware.RequireLogin(middleware.LoginOptions{
|
||||
Auth: a,
|
||||
Authz: authkit.HasRole("admin"),
|
||||
})(http.HandlerFunc(ok200))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, reqWithBearer(access))
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("non-admin should get 403, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// Promote the user to admin and retry.
|
||||
if err := a.AssignRole(ctx, u.ID, "admin"); err != nil {
|
||||
t.Fatalf("AssignRole: %v", err)
|
||||
}
|
||||
access2, _, err := a.IssueJWT(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
rr = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, reqWithBearer(access2))
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("admin should get 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireLogin_PanicsOnUnknownSlug(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatalf("expected panic on unknown role slug")
|
||||
}
|
||||
}()
|
||||
middleware.RequireLogin(middleware.LoginOptions{
|
||||
Auth: a,
|
||||
Authz: authkit.HasRole("never-registered"),
|
||||
})
|
||||
}
|
||||
|
||||
// ─── RequireGuest ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestRequireGuest_LetsUnauthenticatedThrough(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
called := false
|
||||
handler := middleware.RequireServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := middleware.RequireGuest(middleware.GuestOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req(plain))
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d, want 401", rr.Code)
|
||||
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if !called {
|
||||
t.Fatalf("guest middleware should pass through unauthenticated request")
|
||||
}
|
||||
if called {
|
||||
t.Fatalf("handler should not have been invoked for revoked key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAbility_AcceptsServiceKeyWithAbility(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
plain, _, err := a.IssueServiceKey(context.Background(),
|
||||
"application", uuid.New(), "ci", []string{"events:write"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueServiceKey: %v", err)
|
||||
}
|
||||
chain := middleware.RequireServiceKey(middleware.Options{Auth: a})(
|
||||
middleware.RequireAbility("events:write")(http.HandlerFunc(ok200)))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
chain.ServeHTTP(rr, req(plain))
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
// Same chain but ability the key does not carry → 403.
|
||||
chainBad := middleware.RequireServiceKey(middleware.Options{Auth: a})(
|
||||
middleware.RequireAbility("admin:nuke")(http.HandlerFunc(ok200)))
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
chainBad.ServeHTTP(rr, req(plain))
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("missing-ability status = %d, want 403", rr.Code)
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAbility_RejectsUserPrincipal(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "alice@example.com", "hunter2hunter2")
|
||||
func TestRequireGuest_BlocksAuthenticated(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
u, err := a.CreateUser(ctx, "g@g.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
plain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
|
||||
access, _, err := a.IssueJWT(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueSession: %v", err)
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
chain := middleware.RequireSession(middleware.Options{Auth: a})(
|
||||
middleware.RequireAbility("events:write")(http.HandlerFunc(ok200)))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
chain.ServeHTTP(rr, req(plain))
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want 403 (user principal carries no abilities)", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRole_RejectsServiceKey(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
plain, _, err := a.IssueServiceKey(context.Background(),
|
||||
"application", uuid.New(), "ci", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueServiceKey: %v", err)
|
||||
}
|
||||
chain := middleware.RequireServiceKey(middleware.Options{Auth: a})(
|
||||
middleware.RequireRole("admin")(http.HandlerFunc(ok200)))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
chain.ServeHTTP(rr, req(plain))
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want 403 (service key carries no Principal/role)", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAnyOrServiceKey(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
u, err := a.Register(context.Background(), "alice@example.com", "hunter2hunter2")
|
||||
if err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
sessionPlain, _, err := a.IssueSession(context.Background(), u.ID, "ua", netip.MustParseAddr("127.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("IssueSession: %v", err)
|
||||
}
|
||||
servicePlain, _, err := a.IssueServiceKey(context.Background(),
|
||||
"application", uuid.New(), "ci", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueServiceKey: %v", err)
|
||||
}
|
||||
|
||||
type subject struct {
|
||||
hasPrincipal bool
|
||||
hasServiceKey bool
|
||||
}
|
||||
var got subject
|
||||
handler := middleware.RequireAnyOrServiceKey(middleware.Options{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, hp := middleware.PrincipalFrom(r.Context())
|
||||
_, hs := middleware.ServiceKeyFrom(r.Context())
|
||||
got = subject{hp, hs}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
handlerCalled := false
|
||||
handler := middleware.RequireGuest(middleware.GuestOptions{Auth: a})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
}))
|
||||
|
||||
// Session token → Principal in context, no ServiceKey.
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req(sessionPlain))
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("session: status = %d, want 200", rr.Code)
|
||||
handler.ServeHTTP(rr, reqWithBearer(access))
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", rr.Code)
|
||||
}
|
||||
if !got.hasPrincipal || got.hasServiceKey {
|
||||
t.Fatalf("session: ctx subject = %+v, want principal-only", got)
|
||||
}
|
||||
|
||||
// Service token → ServiceKey in context, no Principal.
|
||||
rr = httptest.NewRecorder()
|
||||
got = subject{}
|
||||
handler.ServeHTTP(rr, req(servicePlain))
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("service: status = %d, want 200", rr.Code)
|
||||
}
|
||||
if got.hasPrincipal || !got.hasServiceKey {
|
||||
t.Fatalf("service: ctx subject = %+v, want servicekey-only", got)
|
||||
}
|
||||
|
||||
// Garbage token → 401, neither subject set.
|
||||
rr = httptest.NewRecorder()
|
||||
got = subject{}
|
||||
handler.ServeHTTP(rr, req(strings.Repeat("x", 50)))
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("garbage: status = %d, want 401", rr.Code)
|
||||
if handlerCalled {
|
||||
t.Fatalf("handler should not run for authenticated request")
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check: the constructed *authkit.Auth should satisfy errors.Is on the
|
||||
// canonical sentinels — ensures our minimal stores are wired correctly.
|
||||
func TestSentinelsReachable(t *testing.T) {
|
||||
a := newTestAuth(t)
|
||||
_, err := a.AuthenticateServiceKey(context.Background(), "sk_not-real")
|
||||
if !errors.Is(err, authkit.ErrServiceKeyInvalid) {
|
||||
t.Fatalf("expected ErrServiceKeyInvalid, got %v", err)
|
||||
func TestRequireGuest_CustomOnAuthenticated(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
u, err := a.CreateUser(ctx, "custom@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
access, _, err := a.IssueJWT(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueJWT: %v", err)
|
||||
}
|
||||
handler := middleware.RequireGuest(middleware.GuestOptions{
|
||||
Auth: a,
|
||||
OnAuthenticated: func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||||
},
|
||||
})(http.HandlerFunc(ok200))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, reqWithBearer(access))
|
||||
if rr.Code != http.StatusSeeOther {
|
||||
t.Fatalf("expected 303, got %d", rr.Code)
|
||||
}
|
||||
if got := rr.Header().Get("Location"); got != "/dashboard" {
|
||||
t.Fatalf("expected Location=/dashboard, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── RequireServiceKey ─────────────────────────────────────────────────────
|
||||
|
||||
func TestRequireServiceKey_AbilityGate(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
if _, err := a.CreateAbility(ctx, "events:write", ""); err != nil {
|
||||
t.Fatalf("CreateAbility: %v", err)
|
||||
}
|
||||
plain, _, err := a.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{
|
||||
Name: "ci",
|
||||
Abilities: []string{"events:write"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("IssueServiceKey: %v", err)
|
||||
}
|
||||
|
||||
handler := middleware.RequireServiceKey(middleware.ServiceKeyOptions{
|
||||
Auth: a,
|
||||
Authz: authkit.HasAbility("events:write"),
|
||||
})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
k, ok := authkit.ServiceKeyFromCtx(r.Context())
|
||||
if !ok || !k.HasAbility("events:write") {
|
||||
t.Fatalf("expected ServiceKey with events:write on context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, reqWithBearer(plain))
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireServiceKey_AbilityGateRejectsMissing(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
ctx := context.Background()
|
||||
if _, err := a.CreateAbility(ctx, "events:write", ""); err != nil {
|
||||
t.Fatalf("CreateAbility events:write: %v", err)
|
||||
}
|
||||
if _, err := a.CreateAbility(ctx, "admin:nuke", ""); err != nil {
|
||||
t.Fatalf("CreateAbility admin:nuke: %v", err)
|
||||
}
|
||||
plain, _, err := a.IssueServiceKey(ctx, authkit.IssueServiceKeyParams{
|
||||
Name: "ci",
|
||||
Abilities: []string{"events:write"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("IssueServiceKey: %v", err)
|
||||
}
|
||||
handler := middleware.RequireServiceKey(middleware.ServiceKeyOptions{
|
||||
Auth: a,
|
||||
Authz: authkit.HasAbility("admin:nuke"),
|
||||
})(http.HandlerFunc(ok200))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, reqWithBearer(plain))
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireServiceKey_PanicsOnUnknownAbility(t *testing.T) {
|
||||
a := freshAuth(t)
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatalf("expected panic on unknown ability slug")
|
||||
}
|
||||
}()
|
||||
middleware.RequireServiceKey(middleware.ServiceKeyOptions{
|
||||
Auth: a,
|
||||
Authz: authkit.HasAbility("never-registered"),
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue