feat: use router to group routes
This commit is contained in:
parent
a3f4661456
commit
280cb93648
12 changed files with 222 additions and 198 deletions
|
|
@ -8,9 +8,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthMiddleware checks for JWT token and adds user to context if valid
|
// AuthMiddleware checks for JWT token and adds user to context if valid
|
||||||
func AuthMiddleware(authService *service.AuthService, userService *service.UserService) func(http.Handler) http.Handler {
|
func AuthMiddleware(authService *service.AuthService, userService *service.UserService) Middleware {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Get JWT from cookie
|
// Get JWT from cookie
|
||||||
cookie, err := r.Cookie("auth_token")
|
cookie, err := r.Cookie("auth_token")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -50,12 +50,12 @@ func AuthMiddleware(authService *service.AuthService, userService *service.UserS
|
||||||
// Add user to context
|
// Add user to context
|
||||||
ctx := ctxkeys.WithUser(r.Context(), user)
|
ctx := ctxkeys.WithUser(r.Context(), user)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireGuest ensures request is not authenticated
|
// RequireGuest ensures request is not authenticated
|
||||||
func RequireGuest(next http.HandlerFunc) http.HandlerFunc {
|
func RequireGuest(next http.Handler) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := ctxkeys.User(r.Context())
|
user := ctxkeys.User(r.Context())
|
||||||
if user != nil {
|
if user != nil {
|
||||||
|
|
@ -67,7 +67,7 @@ func RequireGuest(next http.HandlerFunc) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireAuth ensures the user is authenticated and has completed onboarding
|
// RequireAuth ensures the user is authenticated and has completed onboarding
|
||||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
func RequireAuth(next http.Handler) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := ctxkeys.User(r.Context())
|
user := ctxkeys.User(r.Context())
|
||||||
if user == nil {
|
if user == nil {
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import "net/http"
|
||||||
// CacheStatic wraps a handler to set long-lived cache headers for static assets.
|
// CacheStatic wraps a handler to set long-lived cache headers for static assets.
|
||||||
// Assets use query-string cache busting (?v=<timestamp>), so it's safe to cache
|
// Assets use query-string cache busting (?v=<timestamp>), so it's safe to cache
|
||||||
// them indefinitely — the URL changes when the content changes.
|
// them indefinitely — the URL changes when the content changes.
|
||||||
func CacheStatic(h http.Handler) http.Handler {
|
func CacheStatic(h http.Handler) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
|
|
@ -15,7 +15,7 @@ func CacheStatic(h http.Handler) http.Handler {
|
||||||
// NoCacheDynamic sets Cache-Control: no-cache on responses so browsers always
|
// NoCacheDynamic sets Cache-Control: no-cache on responses so browsers always
|
||||||
// revalidate with the server. This prevents stale HTML from being shown after
|
// revalidate with the server. This prevents stale HTML from being shown after
|
||||||
// navigation (e.g. back button) while still allowing conditional requests.
|
// navigation (e.g. back button) while still allowing conditional requests.
|
||||||
func NoCacheDynamic(next http.Handler) http.Handler {
|
func NoCacheDynamic(next http.Handler) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Skip static assets — they're handled by CacheStatic.
|
// Skip static assets — they're handled by CacheStatic.
|
||||||
if len(r.URL.Path) >= 8 && r.URL.Path[:8] == "/assets/" {
|
if len(r.URL.Path) >= 8 && r.URL.Path[:8] == "/assets/" {
|
||||||
|
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
// Chain applies multiple middleware in order (first to last)
|
|
||||||
// The middleware are executed in the order they are provided
|
|
||||||
//
|
|
||||||
// Example:
|
|
||||||
//
|
|
||||||
// handler := Chain(mux,
|
|
||||||
// AuthMiddleware(...), // Executes first
|
|
||||||
// WithURLPath, // Executes second
|
|
||||||
// Config(...), // Executes third
|
|
||||||
// )
|
|
||||||
func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
|
|
||||||
// Apply middleware in reverse order so they execute in the order provided
|
|
||||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
|
||||||
h = middlewares[i](h)
|
|
||||||
}
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
@ -9,11 +9,11 @@ import (
|
||||||
|
|
||||||
// Config middleware adds the sanitized app configuration to the request context.
|
// Config middleware adds the sanitized app configuration to the request context.
|
||||||
// Sensitive values like JWTSecret and DBPath are excluded for security.
|
// Sensitive values like JWTSecret and DBPath are excluded for security.
|
||||||
func Config(cfg *config.Config) func(http.Handler) http.Handler {
|
func Config(cfg *config.Config) Middleware {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := ctxkeys.WithConfig(r.Context(), cfg.Sanitized())
|
ctx := ctxkeys.WithConfig(r.Context(), cfg.Sanitized())
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// CSRFProtection validates CSRF tokens on all state-changing requests
|
// CSRFProtection validates CSRF tokens on all state-changing requests
|
||||||
func CSRFProtection(next http.Handler) http.Handler {
|
func CSRFProtection(next http.Handler) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Skip CSRF check for safe methods (GET, HEAD, OPTIONS)
|
// Skip CSRF check for safe methods (GET, HEAD, OPTIONS)
|
||||||
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ var skipLoggingPaths = []string{
|
||||||
|
|
||||||
// RequestLogging logs HTTP requests with method, path, status, and duration
|
// RequestLogging logs HTTP requests with method, path, status, and duration
|
||||||
// Skips logging for paths defined in skipLoggingPaths
|
// Skips logging for paths defined in skipLoggingPaths
|
||||||
func RequestLogging(next http.Handler) http.Handler {
|
func RequestLogging(next http.Handler) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Skip logging for configured paths
|
// Skip logging for configured paths
|
||||||
for _, prefix := range skipLoggingPaths {
|
for _, prefix := range skipLoggingPaths {
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,26 @@ import (
|
||||||
"git.juancwu.dev/juancwu/budgit/internal/ui/pages"
|
"git.juancwu.dev/juancwu/budgit/internal/ui/pages"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Middleware func(http.Handler) http.HandlerFunc
|
||||||
|
|
||||||
|
// Chain applies multiple middleware in order (first to last)
|
||||||
|
// The middleware are executed in the order they are provided
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// handler := Chain(mux,
|
||||||
|
// AuthMiddleware(...), // Executes first
|
||||||
|
// WithURLPath, // Executes second
|
||||||
|
// Config(...), // Executes third
|
||||||
|
// )
|
||||||
|
func Chain(h http.Handler, middlewares ...Middleware) http.Handler {
|
||||||
|
// Apply middleware in reverse order so they execute in the order provided
|
||||||
|
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||||
|
h = middlewares[i](h)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
// Redirect handles both HTMX and regular HTTP redirects.
|
// Redirect handles both HTMX and regular HTTP redirects.
|
||||||
// For HTMX requests, it sets the HX-Redirect header; for regular requests,
|
// For HTMX requests, it sets the HX-Redirect header; for regular requests,
|
||||||
// it uses http.Redirect.
|
// it uses http.Redirect.
|
||||||
|
|
@ -97,18 +97,12 @@ func (rl *RateLimiter) cleanup() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RateLimitAuth creates middleware for auth endpoints
|
// Middleware returns a Middleware that enforces this rate limiter per client IP.
|
||||||
// Limits: 5 requests per 15 minutes per IP
|
func (rl *RateLimiter) Middleware() Middleware {
|
||||||
func RateLimitAuth() func(http.HandlerFunc) http.HandlerFunc {
|
return func(next http.Handler) http.HandlerFunc {
|
||||||
limiter := NewRateLimiter(5, 15*time.Minute)
|
|
||||||
|
|
||||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Get real IP (handle proxies)
|
|
||||||
ip := getClientIP(r)
|
ip := getClientIP(r)
|
||||||
|
if !rl.Allow(ip) {
|
||||||
// Check rate limit
|
|
||||||
if !limiter.Allow(ip) {
|
|
||||||
slog.Warn("rate limit exceeded",
|
slog.Warn("rate limit exceeded",
|
||||||
"ip", ip,
|
"ip", ip,
|
||||||
"path", r.URL.Path,
|
"path", r.URL.Path,
|
||||||
|
|
@ -116,32 +110,8 @@ func RateLimitAuth() func(http.HandlerFunc) http.HandlerFunc {
|
||||||
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RateLimitCRUD creates middleware for state-changing CRUD endpoints.
|
|
||||||
// Limits: 60 requests per minute per IP.
|
|
||||||
func RateLimitCRUD() func(http.Handler) http.Handler {
|
|
||||||
limiter := NewRateLimiter(60, 1*time.Minute)
|
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ip := getClientIP(r)
|
|
||||||
|
|
||||||
if !limiter.Allow(ip) {
|
|
||||||
slog.Warn("CRUD rate limit exceeded",
|
|
||||||
"ip", ip,
|
|
||||||
"path", r.URL.Path,
|
|
||||||
)
|
|
||||||
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,26 +6,24 @@ import (
|
||||||
|
|
||||||
// SecurityHeaders sets common security response headers on every response.
|
// SecurityHeaders sets common security response headers on every response.
|
||||||
// Note: HSTS is handled by Caddy at the reverse proxy layer.
|
// Note: HSTS is handled by Caddy at the reverse proxy layer.
|
||||||
func SecurityHeaders() func(http.Handler) http.Handler {
|
func SecurityHeaders(next http.Handler) http.HandlerFunc {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
h := w.Header()
|
||||||
h := w.Header()
|
|
||||||
|
|
||||||
h.Set("Content-Security-Policy",
|
h.Set("Content-Security-Policy",
|
||||||
"default-src 'self'; "+
|
"default-src 'self'; "+
|
||||||
"style-src 'self' 'unsafe-inline'; "+
|
"style-src 'self' 'unsafe-inline'; "+
|
||||||
"img-src 'self' data:; "+
|
"img-src 'self' data:; "+
|
||||||
"font-src 'self'; "+
|
"font-src 'self'; "+
|
||||||
"frame-ancestors 'none'; "+
|
"frame-ancestors 'none'; "+
|
||||||
"base-uri 'self'; "+
|
"base-uri 'self'; "+
|
||||||
"form-action 'self'")
|
"form-action 'self'")
|
||||||
|
|
||||||
h.Set("X-Frame-Options", "DENY")
|
h.Set("X-Frame-Options", "DENY")
|
||||||
h.Set("X-Content-Type-Options", "nosniff")
|
h.Set("X-Content-Type-Options", "nosniff")
|
||||||
h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||||
h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()")
|
h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()")
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// WithURLPath adds the current URL's path to the context
|
// WithURLPath adds the current URL's path to the context
|
||||||
func WithURLPath(next http.Handler) http.Handler {
|
func WithURLPath(next http.Handler) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctxWithPath := ctxkeys.WithURLPath(r.Context(), r.URL.Path)
|
ctxWithPath := ctxkeys.WithURLPath(r.Context(), r.URL.Path)
|
||||||
next.ServeHTTP(w, r.WithContext(ctxWithPath))
|
next.ServeHTTP(w, r.WithContext(ctxWithPath))
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,29 @@
|
||||||
package router
|
package router
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
type Middleware func(http.Handler) http.Handler
|
"git.juancwu.dev/juancwu/budgit/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
type Group struct {
|
type Group struct {
|
||||||
prefix string
|
prefix string
|
||||||
middleware []Middleware
|
middleware []middleware.Middleware
|
||||||
|
limiter *middleware.RateLimiter
|
||||||
|
parent *Group
|
||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGroup(mux *http.ServeMux, prefix string, mw []Middleware) *Group {
|
func (g *Group) Use(mw ...middleware.Middleware) {
|
||||||
return &Group{prefix: prefix, middleware: mw, mux: mux}
|
g.middleware = append(g.middleware, mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) Use(mw ...Middleware) {
|
// RateLimit sets a rate limit on this group. It runs before any middleware
|
||||||
g.middleware = append(g.middleware, mw...)
|
// in the chain, including inherited middleware from parent groups.
|
||||||
|
// Parent group rate limits are checked first (root → leaf order).
|
||||||
|
func (g *Group) RateLimit(limit int, window time.Duration) {
|
||||||
|
g.limiter = middleware.NewRateLimiter(limit, window)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Method string
|
type Method string
|
||||||
|
|
@ -28,78 +36,120 @@ const (
|
||||||
MethodPatch Method = "PATCH"
|
MethodPatch Method = "PATCH"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (g *Group) Handle(method Method, path string, handler http.HandlerFunc) {
|
func (g *Group) Handle(method Method, path string, handler http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
// Build chain: [rate limiters root→self] → [middleware root→self] → [route mw] → handler
|
||||||
|
rateLimiters := g.collectRateLimiters()
|
||||||
|
middlewares := g.collectMiddleware()
|
||||||
|
middlewares = append(middlewares, mw...)
|
||||||
|
|
||||||
|
chain := append(rateLimiters, middlewares...)
|
||||||
|
|
||||||
pattern := string(method) + " " + g.prefix + path
|
pattern := string(method) + " " + g.prefix + path
|
||||||
wrapped := chain(handler, g.middleware)
|
wrapped := middleware.Chain(handler, chain...)
|
||||||
g.mux.Handle(pattern, wrapped)
|
g.mux.Handle(pattern, wrapped)
|
||||||
}
|
}
|
||||||
func (g *Group) Get(path string, h http.HandlerFunc) { g.Handle(MethodGet, path, h) }
|
|
||||||
func (g *Group) Post(path string, h http.HandlerFunc) { g.Handle(MethodPost, path, h) }
|
|
||||||
func (g *Group) Put(path string, h http.HandlerFunc) { g.Handle(MethodPut, path, h) }
|
|
||||||
func (g *Group) Patch(path string, h http.HandlerFunc) { g.Handle(MethodPatch, path, h) }
|
|
||||||
func (g *Group) Delete(path string, h http.HandlerFunc) { g.Handle(MethodDelete, path, h) }
|
|
||||||
|
|
||||||
// SubGroup creates a nested group with accumulated prefix and middleware.
|
func (g *Group) Get(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
// Middleware added inside fn does not affect the parent group.
|
g.Handle(MethodGet, path, h, mw...)
|
||||||
|
}
|
||||||
|
func (g *Group) Post(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
g.Handle(MethodPost, path, h, mw...)
|
||||||
|
}
|
||||||
|
func (g *Group) Put(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
g.Handle(MethodPut, path, h, mw...)
|
||||||
|
}
|
||||||
|
func (g *Group) Patch(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
g.Handle(MethodPatch, path, h, mw...)
|
||||||
|
}
|
||||||
|
func (g *Group) Delete(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
g.Handle(MethodDelete, path, h, mw...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubGroup creates a nested group. It inherits rate limits and middleware
|
||||||
|
// from the parent via the parent pointer (not by copying).
|
||||||
func (g *Group) SubGroup(prefix string, fn func(*Group)) {
|
func (g *Group) SubGroup(prefix string, fn func(*Group)) {
|
||||||
mw := make([]Middleware, len(g.middleware))
|
sub := &Group{
|
||||||
copy(mw, g.middleware)
|
prefix: g.prefix + prefix,
|
||||||
sub := newGroup(g.mux, g.prefix+prefix, mw)
|
parent: g,
|
||||||
|
mux: g.mux,
|
||||||
|
}
|
||||||
fn(sub)
|
fn(sub)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteGroup is implemented by feature modules to register their routes.
|
// collectRateLimiters walks up the parent chain and returns rate limit
|
||||||
type RouteGroup interface {
|
// middleware in root → leaf order.
|
||||||
Prefix() string
|
func (g *Group) collectRateLimiters() []middleware.Middleware {
|
||||||
Register(g *Group)
|
var result []middleware.Middleware
|
||||||
|
if g.parent != nil {
|
||||||
|
result = g.parent.collectRateLimiters()
|
||||||
|
}
|
||||||
|
if g.limiter != nil {
|
||||||
|
result = append(result, g.limiter.Middleware())
|
||||||
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// MiddlewareProvider is optionally implemented by RouteGroups that need
|
// collectMiddleware walks up the parent chain and returns middleware
|
||||||
// group-level middleware.
|
// in root → leaf order.
|
||||||
type MiddlewareProvider interface {
|
func (g *Group) collectMiddleware() []middleware.Middleware {
|
||||||
Middlewares() []Middleware
|
var result []middleware.Middleware
|
||||||
|
if g.parent != nil {
|
||||||
|
result = g.parent.collectMiddleware()
|
||||||
|
}
|
||||||
|
result = append(result, g.middleware...)
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
type Router struct {
|
type Router struct {
|
||||||
mux *http.ServeMux
|
root *Group
|
||||||
middleware []Middleware
|
mux *http.ServeMux
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() *Router {
|
func New() *Router {
|
||||||
return &Router{mux: http.NewServeMux()}
|
mux := http.NewServeMux()
|
||||||
|
return &Router{
|
||||||
|
mux: mux,
|
||||||
|
root: &Group{
|
||||||
|
mux: mux,
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) Mux() *http.ServeMux {
|
func (r *Router) Mux() *http.ServeMux {
|
||||||
return r.mux
|
return r.mux
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) Use(mw ...Middleware) {
|
func (r *Router) Use(mw ...middleware.Middleware) {
|
||||||
r.middleware = append(r.middleware, mw...)
|
r.root.Use(mw...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mount registers one or more RouteGroups.
|
// Group creates a route group that inherits global middleware from the router.
|
||||||
func (r *Router) Mount(groups ...RouteGroup) {
|
func (r *Router) Group(prefix string, fn func(*Group)) {
|
||||||
for _, rg := range groups {
|
r.root.SubGroup(prefix, fn)
|
||||||
var mw []Middleware
|
|
||||||
if mp, ok := rg.(MiddlewareProvider); ok {
|
|
||||||
mw = mp.Middlewares()
|
|
||||||
}
|
|
||||||
g := newGroup(r.mux, rg.Prefix(), mw)
|
|
||||||
rg.Register(g)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler returns the final http.Handler with global middleware applied.
|
// Handler returns the final http.Handler. All middleware is already applied
|
||||||
|
// per-route through the group hierarchy, so this just returns the mux.
|
||||||
func (r *Router) Handler() http.Handler {
|
func (r *Router) Handler() http.Handler {
|
||||||
if len(r.middleware) == 0 {
|
return r.mux
|
||||||
return r.mux
|
|
||||||
}
|
|
||||||
return chain(r.mux, r.middleware)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func chain(base http.Handler, mws []Middleware) http.Handler {
|
func (r *Router) Handle(method Method, path string, handler http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
for i := len(mws) - 1; i >= 0; i-- {
|
r.root.Handle(method, path, handler, mw...)
|
||||||
base = mws[i](base)
|
}
|
||||||
}
|
|
||||||
return base
|
func (r *Router) Get(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
r.root.Get(path, h, mw...)
|
||||||
|
}
|
||||||
|
func (r *Router) Post(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
r.root.Post(path, h, mw...)
|
||||||
|
}
|
||||||
|
func (r *Router) Put(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
r.root.Put(path, h, mw...)
|
||||||
|
}
|
||||||
|
func (r *Router) Patch(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
r.root.Patch(path, h, mw...)
|
||||||
|
}
|
||||||
|
func (r *Router) Delete(path string, h http.HandlerFunc, mw ...middleware.Middleware) {
|
||||||
|
r.root.Delete(path, h, mw...)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,70 +3,25 @@ package routes
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/budgit/assets"
|
"git.juancwu.dev/juancwu/budgit/assets"
|
||||||
"git.juancwu.dev/juancwu/budgit/internal/app"
|
"git.juancwu.dev/juancwu/budgit/internal/app"
|
||||||
"git.juancwu.dev/juancwu/budgit/internal/handler"
|
"git.juancwu.dev/juancwu/budgit/internal/handler"
|
||||||
"git.juancwu.dev/juancwu/budgit/internal/middleware"
|
"git.juancwu.dev/juancwu/budgit/internal/middleware"
|
||||||
|
"git.juancwu.dev/juancwu/budgit/internal/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupRoutes(a *app.App) http.Handler {
|
func SetupRoutes(a *app.App) http.Handler {
|
||||||
auth := handler.NewAuthHandler(a.AuthService, a.InviteService, a.SpaceService)
|
authH := handler.NewAuthHandler(a.AuthService, a.InviteService, a.SpaceService)
|
||||||
home := handler.NewHomeHandler()
|
homeH := handler.NewHomeHandler()
|
||||||
settings := handler.NewSettingsHandler(a.AuthService, a.UserService)
|
settingsH := handler.NewSettingsHandler(a.AuthService, a.UserService)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
r := router.New()
|
||||||
|
|
||||||
// ====================================================================================
|
// Global middleware
|
||||||
// PUBLIC ROUTES
|
r.Use(
|
||||||
// ====================================================================================
|
middleware.SecurityHeaders,
|
||||||
|
|
||||||
// Static assets with long-lived cache (cache-busted via ?v=<timestamp>)
|
|
||||||
sub, _ := fs.Sub(assets.AssetsFS, ".")
|
|
||||||
mux.Handle("GET /assets/", middleware.CacheStatic(http.StripPrefix("/assets/", http.FileServer(http.FS(sub)))))
|
|
||||||
|
|
||||||
// Home
|
|
||||||
mux.HandleFunc("GET /{$}", home.HomePage)
|
|
||||||
mux.HandleFunc("GET /forbidden", home.ForbiddenPage)
|
|
||||||
mux.HandleFunc("GET /privacy", home.PrivacyPage)
|
|
||||||
mux.HandleFunc("GET /terms", home.TermsPage)
|
|
||||||
|
|
||||||
// Auth pages
|
|
||||||
authRateLimiter := middleware.RateLimitAuth()
|
|
||||||
|
|
||||||
mux.HandleFunc("GET /auth", middleware.RequireGuest(auth.AuthPage))
|
|
||||||
mux.HandleFunc("GET /auth/password", middleware.RequireGuest(auth.PasswordPage))
|
|
||||||
|
|
||||||
// Token Verifications
|
|
||||||
mux.HandleFunc("GET /auth/magic-link/{token}", auth.VerifyMagicLink)
|
|
||||||
|
|
||||||
// Auth Actions
|
|
||||||
mux.HandleFunc("POST /auth/magic-link", authRateLimiter(middleware.RequireGuest(auth.SendMagicLink)))
|
|
||||||
mux.HandleFunc("POST /auth/password", authRateLimiter(middleware.RequireGuest(auth.LoginWithPassword)))
|
|
||||||
mux.HandleFunc("POST /auth/logout", auth.Logout)
|
|
||||||
|
|
||||||
// Join via invite
|
|
||||||
mux.HandleFunc("GET /join/{token}", auth.JoinSpace)
|
|
||||||
|
|
||||||
// ====================================================================================
|
|
||||||
// PRIVATE ROUTES
|
|
||||||
// ====================================================================================
|
|
||||||
|
|
||||||
crudLimiter := middleware.RateLimitCRUD()
|
|
||||||
|
|
||||||
mux.HandleFunc("GET /auth/onboarding", middleware.RequireAuth(auth.OnboardingPage))
|
|
||||||
mux.Handle("POST /auth/onboarding", crudLimiter(http.HandlerFunc(middleware.RequireAuth(auth.CompleteOnboarding))))
|
|
||||||
|
|
||||||
mux.HandleFunc("GET /app/settings", middleware.RequireAuth(settings.SettingsPage))
|
|
||||||
mux.HandleFunc("POST /app/settings/password", authRateLimiter(middleware.RequireAuth(settings.SetPassword)))
|
|
||||||
|
|
||||||
// 404
|
|
||||||
mux.HandleFunc("/{path...}", home.NotFoundPage)
|
|
||||||
|
|
||||||
// Global middlewares
|
|
||||||
handler := middleware.Chain(
|
|
||||||
mux,
|
|
||||||
middleware.SecurityHeaders(),
|
|
||||||
middleware.Config(a.Cfg),
|
middleware.Config(a.Cfg),
|
||||||
middleware.RequestLogging,
|
middleware.RequestLogging,
|
||||||
middleware.NoCacheDynamic,
|
middleware.NoCacheDynamic,
|
||||||
|
|
@ -75,5 +30,57 @@ func SetupRoutes(a *app.App) http.Handler {
|
||||||
middleware.WithURLPath,
|
middleware.WithURLPath,
|
||||||
)
|
)
|
||||||
|
|
||||||
return handler
|
// Static assets (bypass router groups — registered directly on mux)
|
||||||
|
sub, _ := fs.Sub(assets.AssetsFS, ".")
|
||||||
|
r.Mux().Handle("GET /assets/",
|
||||||
|
middleware.CacheStatic(http.StripPrefix("/assets/", http.FileServer(http.FS(sub)))),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Public pages
|
||||||
|
r.Get("/{$}", homeH.HomePage)
|
||||||
|
r.Get("/forbidden", homeH.ForbiddenPage)
|
||||||
|
r.Get("/privacy", homeH.PrivacyPage)
|
||||||
|
r.Get("/terms", homeH.TermsPage)
|
||||||
|
r.Get("/join/{token}", authH.JoinSpace)
|
||||||
|
|
||||||
|
// Auth - guest routes
|
||||||
|
r.Group("/auth", func(g *router.Group) {
|
||||||
|
g.Use(middleware.RequireGuest)
|
||||||
|
g.Get("", authH.AuthPage)
|
||||||
|
g.Get("/password", authH.PasswordPage)
|
||||||
|
g.Get("/magic-link/{token}", authH.VerifyMagicLink)
|
||||||
|
|
||||||
|
g.SubGroup("", func(g *router.Group) {
|
||||||
|
g.RateLimit(5, 15*time.Minute)
|
||||||
|
g.Post("/magic-link", authH.SendMagicLink)
|
||||||
|
g.Post("/password", authH.LoginWithPassword)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Auth - authenticated routes
|
||||||
|
r.Group("/auth", func(g *router.Group) {
|
||||||
|
g.Use(middleware.RequireAuth)
|
||||||
|
g.Get("/onboarding", authH.OnboardingPage)
|
||||||
|
g.Post("/onboarding", authH.CompleteOnboarding)
|
||||||
|
})
|
||||||
|
r.Post("/auth/logout", authH.Logout)
|
||||||
|
|
||||||
|
// App routes
|
||||||
|
r.Group("/app", func(g *router.Group) {
|
||||||
|
g.Use(middleware.RequireAuth)
|
||||||
|
|
||||||
|
g.SubGroup("/settings", func(g *router.Group) {
|
||||||
|
g.Get("", settingsH.SettingsPage)
|
||||||
|
|
||||||
|
g.SubGroup("", func(g *router.Group) {
|
||||||
|
g.RateLimit(5, 15*time.Minute)
|
||||||
|
g.Post("/password", settingsH.SetPassword)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 404 catch-all
|
||||||
|
r.Get("/{path...}", homeH.NotFoundPage)
|
||||||
|
|
||||||
|
return r.Handler()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue