diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 7d0e735..73a380d 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -8,9 +8,9 @@ import ( ) // AuthMiddleware checks for JWT token and adds user to context if valid -func AuthMiddleware(authService *service.AuthService, userService *service.UserService) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func AuthMiddleware(authService *service.AuthService, userService *service.UserService) Middleware { + return func(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { // Get JWT from cookie cookie, err := r.Cookie("auth_token") if err != nil { @@ -50,12 +50,12 @@ func AuthMiddleware(authService *service.AuthService, userService *service.UserS // Add user to context ctx := ctxkeys.WithUser(r.Context(), user) next.ServeHTTP(w, r.WithContext(ctx)) - }) + } } } // RequireGuest ensures request is not authenticated -func RequireGuest(next http.HandlerFunc) http.HandlerFunc { +func RequireGuest(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := ctxkeys.User(r.Context()) if user != nil { @@ -67,7 +67,7 @@ func RequireGuest(next http.HandlerFunc) http.HandlerFunc { } // RequireAuth ensures the user is authenticated and has completed onboarding -func RequireAuth(next http.HandlerFunc) http.HandlerFunc { +func RequireAuth(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := ctxkeys.User(r.Context()) if user == nil { diff --git a/internal/middleware/cache.go b/internal/middleware/cache.go index 0291f3c..53889e2 100644 --- a/internal/middleware/cache.go +++ b/internal/middleware/cache.go @@ -5,7 +5,7 @@ import "net/http" // CacheStatic wraps a handler to set long-lived cache headers for static assets. // Assets use query-string cache busting (?v=), so it's safe to cache // them indefinitely — the URL changes when the content changes. -func CacheStatic(h http.Handler) http.Handler { +func CacheStatic(h http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") h.ServeHTTP(w, r) @@ -15,7 +15,7 @@ func CacheStatic(h http.Handler) http.Handler { // NoCacheDynamic sets Cache-Control: no-cache on responses so browsers always // revalidate with the server. This prevents stale HTML from being shown after // navigation (e.g. back button) while still allowing conditional requests. -func NoCacheDynamic(next http.Handler) http.Handler { +func NoCacheDynamic(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip static assets — they're handled by CacheStatic. if len(r.URL.Path) >= 8 && r.URL.Path[:8] == "/assets/" { diff --git a/internal/middleware/chain.go b/internal/middleware/chain.go deleted file mode 100644 index 5dbf70a..0000000 --- a/internal/middleware/chain.go +++ /dev/null @@ -1,21 +0,0 @@ -package middleware - -import "net/http" - -// Chain applies multiple middleware in order (first to last) -// The middleware are executed in the order they are provided -// -// Example: -// -// handler := Chain(mux, -// AuthMiddleware(...), // Executes first -// WithURLPath, // Executes second -// Config(...), // Executes third -// ) -func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler { - // Apply middleware in reverse order so they execute in the order provided - for i := len(middlewares) - 1; i >= 0; i-- { - h = middlewares[i](h) - } - return h -} diff --git a/internal/middleware/config.go b/internal/middleware/config.go index 9dd7f07..a02f245 100644 --- a/internal/middleware/config.go +++ b/internal/middleware/config.go @@ -9,11 +9,11 @@ import ( // Config middleware adds the sanitized app configuration to the request context. // Sensitive values like JWTSecret and DBPath are excluded for security. -func Config(cfg *config.Config) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func Config(cfg *config.Config) Middleware { + return func(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { ctx := ctxkeys.WithConfig(r.Context(), cfg.Sanitized()) next.ServeHTTP(w, r.WithContext(ctx)) - }) + } } } diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index 4485b38..d23f088 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -20,7 +20,7 @@ const ( ) // CSRFProtection validates CSRF tokens on all state-changing requests -func CSRFProtection(next http.Handler) http.Handler { +func CSRFProtection(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip CSRF check for safe methods (GET, HEAD, OPTIONS) if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" { diff --git a/internal/middleware/logging.go b/internal/middleware/logging.go index 2285dea..14ed726 100644 --- a/internal/middleware/logging.go +++ b/internal/middleware/logging.go @@ -38,7 +38,7 @@ var skipLoggingPaths = []string{ // RequestLogging logs HTTP requests with method, path, status, and duration // Skips logging for paths defined in skipLoggingPaths -func RequestLogging(next http.Handler) http.Handler { +func RequestLogging(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip logging for configured paths for _, prefix := range skipLoggingPaths { diff --git a/internal/middleware/utils.go b/internal/middleware/middleware.go similarity index 51% rename from internal/middleware/utils.go rename to internal/middleware/middleware.go index bd6bfb2..f38c881 100644 --- a/internal/middleware/utils.go +++ b/internal/middleware/middleware.go @@ -7,6 +7,26 @@ import ( "git.juancwu.dev/juancwu/budgit/internal/ui/pages" ) +type Middleware func(http.Handler) http.HandlerFunc + +// Chain applies multiple middleware in order (first to last) +// The middleware are executed in the order they are provided +// +// Example: +// +// handler := Chain(mux, +// AuthMiddleware(...), // Executes first +// WithURLPath, // Executes second +// Config(...), // Executes third +// ) +func Chain(h http.Handler, middlewares ...Middleware) http.Handler { + // Apply middleware in reverse order so they execute in the order provided + for i := len(middlewares) - 1; i >= 0; i-- { + h = middlewares[i](h) + } + return h +} + // Redirect handles both HTMX and regular HTTP redirects. // For HTMX requests, it sets the HX-Redirect header; for regular requests, // it uses http.Redirect. diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index e8fa8e1..5a39331 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -97,18 +97,12 @@ func (rl *RateLimiter) cleanup() { } } -// RateLimitAuth creates middleware for auth endpoints -// Limits: 5 requests per 15 minutes per IP -func RateLimitAuth() func(http.HandlerFunc) http.HandlerFunc { - limiter := NewRateLimiter(5, 15*time.Minute) - - return func(next http.HandlerFunc) http.HandlerFunc { +// Middleware returns a Middleware that enforces this rate limiter per client IP. +func (rl *RateLimiter) Middleware() Middleware { + return func(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // Get real IP (handle proxies) ip := getClientIP(r) - - // Check rate limit - if !limiter.Allow(ip) { + if !rl.Allow(ip) { slog.Warn("rate limit exceeded", "ip", ip, "path", r.URL.Path, @@ -116,32 +110,8 @@ func RateLimitAuth() func(http.HandlerFunc) http.HandlerFunc { http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests) return } - - next(w, r) - } - } -} - -// RateLimitCRUD creates middleware for state-changing CRUD endpoints. -// Limits: 60 requests per minute per IP. -func RateLimitCRUD() func(http.Handler) http.Handler { - limiter := NewRateLimiter(60, 1*time.Minute) - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip := getClientIP(r) - - if !limiter.Allow(ip) { - slog.Warn("CRUD rate limit exceeded", - "ip", ip, - "path", r.URL.Path, - ) - http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests) - return - } - next.ServeHTTP(w, r) - }) + } } } diff --git a/internal/middleware/security_headers.go b/internal/middleware/security_headers.go index 1a83911..2942e37 100644 --- a/internal/middleware/security_headers.go +++ b/internal/middleware/security_headers.go @@ -6,26 +6,24 @@ import ( // SecurityHeaders sets common security response headers on every response. // Note: HSTS is handled by Caddy at the reverse proxy layer. -func SecurityHeaders() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := w.Header() +func SecurityHeaders(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + h := w.Header() - h.Set("Content-Security-Policy", - "default-src 'self'; "+ - "style-src 'self' 'unsafe-inline'; "+ - "img-src 'self' data:; "+ - "font-src 'self'; "+ - "frame-ancestors 'none'; "+ - "base-uri 'self'; "+ - "form-action 'self'") + h.Set("Content-Security-Policy", + "default-src 'self'; "+ + "style-src 'self' 'unsafe-inline'; "+ + "img-src 'self' data:; "+ + "font-src 'self'; "+ + "frame-ancestors 'none'; "+ + "base-uri 'self'; "+ + "form-action 'self'") - h.Set("X-Frame-Options", "DENY") - h.Set("X-Content-Type-Options", "nosniff") - h.Set("Referrer-Policy", "strict-origin-when-cross-origin") - h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()") + h.Set("X-Frame-Options", "DENY") + h.Set("X-Content-Type-Options", "nosniff") + h.Set("Referrer-Policy", "strict-origin-when-cross-origin") + h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()") - next.ServeHTTP(w, r) - }) + next.ServeHTTP(w, r) } } diff --git a/internal/middleware/urlpath.go b/internal/middleware/urlpath.go index d565f43..710b5fc 100644 --- a/internal/middleware/urlpath.go +++ b/internal/middleware/urlpath.go @@ -7,9 +7,9 @@ import ( ) // WithURLPath adds the current URL's path to the context -func WithURLPath(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func WithURLPath(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { ctxWithPath := ctxkeys.WithURLPath(r.Context(), r.URL.Path) next.ServeHTTP(w, r.WithContext(ctxWithPath)) - }) + } } diff --git a/internal/router/router.go b/internal/router/router.go index c0def3e..ccc92a9 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,21 +1,29 @@ package router -import "net/http" +import ( + "net/http" + "time" -type Middleware func(http.Handler) http.Handler + "git.juancwu.dev/juancwu/budgit/internal/middleware" +) type Group struct { prefix string - middleware []Middleware + middleware []middleware.Middleware + limiter *middleware.RateLimiter + parent *Group mux *http.ServeMux } -func newGroup(mux *http.ServeMux, prefix string, mw []Middleware) *Group { - return &Group{prefix: prefix, middleware: mw, mux: mux} +func (g *Group) Use(mw ...middleware.Middleware) { + g.middleware = append(g.middleware, mw...) } -func (g *Group) Use(mw ...Middleware) { - g.middleware = append(g.middleware, mw...) +// RateLimit sets a rate limit on this group. It runs before any middleware +// in the chain, including inherited middleware from parent groups. +// Parent group rate limits are checked first (root → leaf order). +func (g *Group) RateLimit(limit int, window time.Duration) { + g.limiter = middleware.NewRateLimiter(limit, window) } type Method string @@ -28,78 +36,120 @@ const ( MethodPatch Method = "PATCH" ) -func (g *Group) Handle(method Method, path string, handler http.HandlerFunc) { +func (g *Group) Handle(method Method, path string, handler http.HandlerFunc, mw ...middleware.Middleware) { + // Build chain: [rate limiters root→self] → [middleware root→self] → [route mw] → handler + rateLimiters := g.collectRateLimiters() + middlewares := g.collectMiddleware() + middlewares = append(middlewares, mw...) + + chain := append(rateLimiters, middlewares...) + pattern := string(method) + " " + g.prefix + path - wrapped := chain(handler, g.middleware) + wrapped := middleware.Chain(handler, chain...) g.mux.Handle(pattern, wrapped) } -func (g *Group) Get(path string, h http.HandlerFunc) { g.Handle(MethodGet, path, h) } -func (g *Group) Post(path string, h http.HandlerFunc) { g.Handle(MethodPost, path, h) } -func (g *Group) Put(path string, h http.HandlerFunc) { g.Handle(MethodPut, path, h) } -func (g *Group) Patch(path string, h http.HandlerFunc) { g.Handle(MethodPatch, path, h) } -func (g *Group) Delete(path string, h http.HandlerFunc) { g.Handle(MethodDelete, path, h) } -// SubGroup creates a nested group with accumulated prefix and middleware. -// Middleware added inside fn does not affect the parent group. +func (g *Group) Get(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + g.Handle(MethodGet, path, h, mw...) +} +func (g *Group) Post(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + g.Handle(MethodPost, path, h, mw...) +} +func (g *Group) Put(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + g.Handle(MethodPut, path, h, mw...) +} +func (g *Group) Patch(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + g.Handle(MethodPatch, path, h, mw...) +} +func (g *Group) Delete(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + g.Handle(MethodDelete, path, h, mw...) +} + +// SubGroup creates a nested group. It inherits rate limits and middleware +// from the parent via the parent pointer (not by copying). func (g *Group) SubGroup(prefix string, fn func(*Group)) { - mw := make([]Middleware, len(g.middleware)) - copy(mw, g.middleware) - sub := newGroup(g.mux, g.prefix+prefix, mw) + sub := &Group{ + prefix: g.prefix + prefix, + parent: g, + mux: g.mux, + } fn(sub) } -// RouteGroup is implemented by feature modules to register their routes. -type RouteGroup interface { - Prefix() string - Register(g *Group) +// collectRateLimiters walks up the parent chain and returns rate limit +// middleware in root → leaf order. +func (g *Group) collectRateLimiters() []middleware.Middleware { + var result []middleware.Middleware + if g.parent != nil { + result = g.parent.collectRateLimiters() + } + if g.limiter != nil { + result = append(result, g.limiter.Middleware()) + } + return result } -// MiddlewareProvider is optionally implemented by RouteGroups that need -// group-level middleware. -type MiddlewareProvider interface { - Middlewares() []Middleware +// collectMiddleware walks up the parent chain and returns middleware +// in root → leaf order. +func (g *Group) collectMiddleware() []middleware.Middleware { + var result []middleware.Middleware + if g.parent != nil { + result = g.parent.collectMiddleware() + } + result = append(result, g.middleware...) + return result } type Router struct { - mux *http.ServeMux - middleware []Middleware + root *Group + mux *http.ServeMux } func New() *Router { - return &Router{mux: http.NewServeMux()} + mux := http.NewServeMux() + return &Router{ + mux: mux, + root: &Group{ + mux: mux, + }, + } } func (r *Router) Mux() *http.ServeMux { return r.mux } -func (r *Router) Use(mw ...Middleware) { - r.middleware = append(r.middleware, mw...) +func (r *Router) Use(mw ...middleware.Middleware) { + r.root.Use(mw...) } -// Mount registers one or more RouteGroups. -func (r *Router) Mount(groups ...RouteGroup) { - for _, rg := range groups { - var mw []Middleware - if mp, ok := rg.(MiddlewareProvider); ok { - mw = mp.Middlewares() - } - g := newGroup(r.mux, rg.Prefix(), mw) - rg.Register(g) - } +// Group creates a route group that inherits global middleware from the router. +func (r *Router) Group(prefix string, fn func(*Group)) { + r.root.SubGroup(prefix, fn) } -// Handler returns the final http.Handler with global middleware applied. +// Handler returns the final http.Handler. All middleware is already applied +// per-route through the group hierarchy, so this just returns the mux. func (r *Router) Handler() http.Handler { - if len(r.middleware) == 0 { - return r.mux - } - return chain(r.mux, r.middleware) + return r.mux } -func chain(base http.Handler, mws []Middleware) http.Handler { - for i := len(mws) - 1; i >= 0; i-- { - base = mws[i](base) - } - return base +func (r *Router) Handle(method Method, path string, handler http.HandlerFunc, mw ...middleware.Middleware) { + r.root.Handle(method, path, handler, mw...) +} + +func (r *Router) Get(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + r.root.Get(path, h, mw...) +} +func (r *Router) Post(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + r.root.Post(path, h, mw...) +} +func (r *Router) Put(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + r.root.Put(path, h, mw...) +} +func (r *Router) Patch(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + r.root.Patch(path, h, mw...) +} +func (r *Router) Delete(path string, h http.HandlerFunc, mw ...middleware.Middleware) { + r.root.Delete(path, h, mw...) } diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 25504c7..843dfab 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -3,70 +3,25 @@ package routes import ( "io/fs" "net/http" + "time" "git.juancwu.dev/juancwu/budgit/assets" "git.juancwu.dev/juancwu/budgit/internal/app" "git.juancwu.dev/juancwu/budgit/internal/handler" "git.juancwu.dev/juancwu/budgit/internal/middleware" + "git.juancwu.dev/juancwu/budgit/internal/router" ) func SetupRoutes(a *app.App) http.Handler { - auth := handler.NewAuthHandler(a.AuthService, a.InviteService, a.SpaceService) - home := handler.NewHomeHandler() - settings := handler.NewSettingsHandler(a.AuthService, a.UserService) + authH := handler.NewAuthHandler(a.AuthService, a.InviteService, a.SpaceService) + homeH := handler.NewHomeHandler() + settingsH := handler.NewSettingsHandler(a.AuthService, a.UserService) - mux := http.NewServeMux() + r := router.New() - // ==================================================================================== - // PUBLIC ROUTES - // ==================================================================================== - - // Static assets with long-lived cache (cache-busted via ?v=) - sub, _ := fs.Sub(assets.AssetsFS, ".") - mux.Handle("GET /assets/", middleware.CacheStatic(http.StripPrefix("/assets/", http.FileServer(http.FS(sub))))) - - // Home - mux.HandleFunc("GET /{$}", home.HomePage) - mux.HandleFunc("GET /forbidden", home.ForbiddenPage) - mux.HandleFunc("GET /privacy", home.PrivacyPage) - mux.HandleFunc("GET /terms", home.TermsPage) - - // Auth pages - authRateLimiter := middleware.RateLimitAuth() - - mux.HandleFunc("GET /auth", middleware.RequireGuest(auth.AuthPage)) - mux.HandleFunc("GET /auth/password", middleware.RequireGuest(auth.PasswordPage)) - - // Token Verifications - mux.HandleFunc("GET /auth/magic-link/{token}", auth.VerifyMagicLink) - - // Auth Actions - mux.HandleFunc("POST /auth/magic-link", authRateLimiter(middleware.RequireGuest(auth.SendMagicLink))) - mux.HandleFunc("POST /auth/password", authRateLimiter(middleware.RequireGuest(auth.LoginWithPassword))) - mux.HandleFunc("POST /auth/logout", auth.Logout) - - // Join via invite - mux.HandleFunc("GET /join/{token}", auth.JoinSpace) - - // ==================================================================================== - // PRIVATE ROUTES - // ==================================================================================== - - crudLimiter := middleware.RateLimitCRUD() - - mux.HandleFunc("GET /auth/onboarding", middleware.RequireAuth(auth.OnboardingPage)) - mux.Handle("POST /auth/onboarding", crudLimiter(http.HandlerFunc(middleware.RequireAuth(auth.CompleteOnboarding)))) - - mux.HandleFunc("GET /app/settings", middleware.RequireAuth(settings.SettingsPage)) - mux.HandleFunc("POST /app/settings/password", authRateLimiter(middleware.RequireAuth(settings.SetPassword))) - - // 404 - mux.HandleFunc("/{path...}", home.NotFoundPage) - - // Global middlewares - handler := middleware.Chain( - mux, - middleware.SecurityHeaders(), + // Global middleware + r.Use( + middleware.SecurityHeaders, middleware.Config(a.Cfg), middleware.RequestLogging, middleware.NoCacheDynamic, @@ -75,5 +30,57 @@ func SetupRoutes(a *app.App) http.Handler { middleware.WithURLPath, ) - return handler + // Static assets (bypass router groups — registered directly on mux) + sub, _ := fs.Sub(assets.AssetsFS, ".") + r.Mux().Handle("GET /assets/", + middleware.CacheStatic(http.StripPrefix("/assets/", http.FileServer(http.FS(sub)))), + ) + + // Public pages + r.Get("/{$}", homeH.HomePage) + r.Get("/forbidden", homeH.ForbiddenPage) + r.Get("/privacy", homeH.PrivacyPage) + r.Get("/terms", homeH.TermsPage) + r.Get("/join/{token}", authH.JoinSpace) + + // Auth - guest routes + r.Group("/auth", func(g *router.Group) { + g.Use(middleware.RequireGuest) + g.Get("", authH.AuthPage) + g.Get("/password", authH.PasswordPage) + g.Get("/magic-link/{token}", authH.VerifyMagicLink) + + g.SubGroup("", func(g *router.Group) { + g.RateLimit(5, 15*time.Minute) + g.Post("/magic-link", authH.SendMagicLink) + g.Post("/password", authH.LoginWithPassword) + }) + }) + + // Auth - authenticated routes + r.Group("/auth", func(g *router.Group) { + g.Use(middleware.RequireAuth) + g.Get("/onboarding", authH.OnboardingPage) + g.Post("/onboarding", authH.CompleteOnboarding) + }) + r.Post("/auth/logout", authH.Logout) + + // App routes + r.Group("/app", func(g *router.Group) { + g.Use(middleware.RequireAuth) + + g.SubGroup("/settings", func(g *router.Group) { + g.Get("", settingsH.SettingsPage) + + g.SubGroup("", func(g *router.Group) { + g.RateLimit(5, 15*time.Minute) + g.Post("/password", settingsH.SetPassword) + }) + }) + }) + + // 404 catch-all + r.Get("/{path...}", homeH.NotFoundPage) + + return r.Handler() }