add middlewares, handlers and database models
This commit is contained in:
parent
979a415b95
commit
7e288ea67a
24 changed files with 1045 additions and 14 deletions
21
internal/middleware/chain.go
Normal file
21
internal/middleware/chain.go
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Chain applies multiple middleware in order (first to last)
|
||||
// The middleware are executed in the order they are provided
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// handler := Chain(mux,
|
||||
// AuthMiddleware(...), // Executes first
|
||||
// WithURLPath, // Executes second
|
||||
// Config(...), // Executes third
|
||||
// )
|
||||
func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
|
||||
// Apply middleware in reverse order so they execute in the order provided
|
||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||
h = middlewares[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
19
internal/middleware/config.go
Normal file
19
internal/middleware/config.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/budgething/internal/config"
|
||||
"git.juancwu.dev/juancwu/budgething/internal/ctxkeys"
|
||||
)
|
||||
|
||||
// Config middleware adds the sanitized app configuration to the request context.
|
||||
// Sensitive values like JWTSecret and DBPath are excluded for security.
|
||||
func Config(cfg *config.Config) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := ctxkeys.WithConfig(r.Context(), cfg.Sanitized())
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
108
internal/middleware/csrf.go
Normal file
108
internal/middleware/csrf.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.juancwu.dev/juancwu/budgething/internal/ctxkeys"
|
||||
)
|
||||
|
||||
const (
|
||||
csrfCookieName = "csrf_token"
|
||||
csrfFormField = "csrf_token"
|
||||
csrfHeader = "X-CSRF-Token"
|
||||
csrfTokenLen = 32
|
||||
)
|
||||
|
||||
// CSRFProtection validates CSRF tokens on all state-changing requests
|
||||
func CSRFProtection(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip CSRF check for safe methods (GET, HEAD, OPTIONS)
|
||||
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
|
||||
token := getOrGenerateCSRFToken(w, r)
|
||||
ctx := ctxkeys.WithCSRFToken(r.Context(), token)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Skip CSRF check for webhooks (external services)
|
||||
if strings.HasPrefix(r.URL.Path, "/webhooks/") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF token for state-changing methods (POST, PUT, PATCH, DELETE)
|
||||
token := getOrGenerateCSRFToken(w, r)
|
||||
ctx := ctxkeys.WithCSRFToken(r.Context(), token)
|
||||
|
||||
// Get submitted token - try multiple sources in priority order
|
||||
// 1. Header (HTMX automatic via meta tag)
|
||||
// 2. Form field (both application/x-www-form-urlencoded and multipart/form-data)
|
||||
// PostFormValue() automatically parses the request based on Content-Type
|
||||
submittedToken := r.Header.Get(csrfHeader)
|
||||
if submittedToken == "" {
|
||||
submittedToken = r.PostFormValue(csrfFormField)
|
||||
}
|
||||
|
||||
// Validate token using constant-time comparison
|
||||
if !validCSRFToken(token, submittedToken) {
|
||||
slog.Warn("csrf validation failed",
|
||||
"path", r.URL.Path,
|
||||
"method", r.Method,
|
||||
"ip", getClientIP(r),
|
||||
)
|
||||
http.Error(w, "Invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// getOrGenerateCSRFToken retrieves existing token or generates new one
|
||||
func getOrGenerateCSRFToken(w http.ResponseWriter, r *http.Request) string {
|
||||
cookie, err := r.Cookie(csrfCookieName)
|
||||
if err == nil && cookie.Value != "" && len(cookie.Value) == base64.RawURLEncoding.EncodedLen(csrfTokenLen) {
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
token := generateCSRFToken()
|
||||
|
||||
cfg := ctxkeys.Config(r.Context())
|
||||
isProduction := cfg != nil && cfg.IsProduction()
|
||||
|
||||
// Set cookie with SameSite=Lax for CSRF protection
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: csrfCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isProduction, // Secure flag based on APP_ENV (safer than r.TLS behind load balancers)
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 86400 * 7, // 7 days
|
||||
})
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// generateCSRFToken creates cryptographically secure random token
|
||||
func generateCSRFToken() string {
|
||||
bytes := make([]byte, csrfTokenLen)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
panic("failed to generate csrf token: " + err.Error())
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// validCSRFToken performs constant-time comparison of tokens
|
||||
func validCSRFToken(expected, actual string) bool {
|
||||
if expected == "" || actual == "" {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1
|
||||
}
|
||||
70
internal/middleware/logging.go
Normal file
70
internal/middleware/logging.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
written bool
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
if !rw.written {
|
||||
rw.statusCode = code
|
||||
rw.written = true
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
if !rw.written {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return rw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// Paths to skip logging (static assets, etc.)
|
||||
var skipLoggingPaths = []string{
|
||||
"/assets/",
|
||||
"/uploads/",
|
||||
"/favicon.ico",
|
||||
}
|
||||
|
||||
// RequestLogging logs HTTP requests with method, path, status, and duration
|
||||
// Skips logging for paths defined in skipLoggingPaths
|
||||
func RequestLogging(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip logging for configured paths
|
||||
for _, prefix := range skipLoggingPaths {
|
||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
written: false,
|
||||
}
|
||||
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
duration := time.Since(start)
|
||||
slog.Info("http request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", rw.statusCode,
|
||||
"duration_ms", duration.Milliseconds(),
|
||||
"remote_addr", getClientIP(r),
|
||||
)
|
||||
})
|
||||
}
|
||||
151
internal/middleware/ratelimit.go
Normal file
151
internal/middleware/ratelimit.go
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter tracks request counts per IP address
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
requests map[string][]time.Time
|
||||
limit int // Max requests allowed
|
||||
window time.Duration // Time window for rate limiting
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
requests: make(map[string][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine to prevent memory leak
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Allow checks if request from IP should be allowed
|
||||
func (rl *RateLimiter) Allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-rl.window)
|
||||
|
||||
// Get requests for this IP
|
||||
requests := rl.requests[ip]
|
||||
|
||||
// Remove old requests outside time window
|
||||
validRequests := []time.Time{}
|
||||
for _, reqTime := range requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if limit exceeded
|
||||
if len(validRequests) >= rl.limit {
|
||||
rl.requests[ip] = validRequests
|
||||
return false
|
||||
}
|
||||
|
||||
// Add current request
|
||||
validRequests = append(validRequests, now)
|
||||
rl.requests[ip] = validRequests
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes old entries to prevent memory leak
|
||||
func (rl *RateLimiter) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes IPs with no recent requests
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-rl.window * 2) // Keep data for 2x window
|
||||
|
||||
for ip, requests := range rl.requests {
|
||||
// Check if all requests are old
|
||||
allOld := true
|
||||
for _, reqTime := range requests {
|
||||
if reqTime.After(cutoff) {
|
||||
allOld = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Remove IP if all requests are old
|
||||
if allOld {
|
||||
delete(rl.requests, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitAuth creates middleware for auth endpoints
|
||||
// Limits: 5 requests per 15 minutes per IP
|
||||
func RateLimitAuth() func(http.HandlerFunc) http.HandlerFunc {
|
||||
limiter := NewRateLimiter(5, 15*time.Minute)
|
||||
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get real IP (handle proxies)
|
||||
ip := getClientIP(r)
|
||||
|
||||
// Check rate limit
|
||||
if !limiter.Allow(ip) {
|
||||
slog.Warn("rate limit exceeded",
|
||||
"ip", ip,
|
||||
"path", r.URL.Path,
|
||||
)
|
||||
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getClientIP extracts real client IP from request
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (proxy/load balancer)
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
// Take first IP in list
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
xri := r.Header.Get("X-Real-IP")
|
||||
if xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fallback to RemoteAddr
|
||||
ip := r.RemoteAddr
|
||||
// Remove port if present
|
||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
|
||||
return ip
|
||||
}
|
||||
15
internal/middleware/urlpath.go
Normal file
15
internal/middleware/urlpath.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/budgething/internal/ctxkeys"
|
||||
)
|
||||
|
||||
// WithURLPath adds the current URL's path to the context
|
||||
func WithURLPath(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctxWithPath := ctxkeys.WithURLPath(r.Context(), r.URL.Path)
|
||||
next.ServeHTTP(w, r.WithContext(ctxWithPath))
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue