budgit/internal/middleware/ratelimit.go

151 lines
3.2 KiB
Go

package middleware
import (
	"log/slog"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)
// RateLimiter is a sliding-window request limiter keyed by client IP.
//
// For every key it records the timestamps of the requests it admitted and
// rejects a request once limit timestamps fall inside the trailing window.
// Allow, cleanup and Stop are safe for concurrent use. Create instances with
// NewRateLimiter and call Stop when the limiter is no longer needed so the
// background cleanup goroutine can exit.
type RateLimiter struct {
	mu       sync.RWMutex
	requests map[string][]time.Time // admitted request times per IP, oldest first
	limit    int                    // max requests allowed per window
	window   time.Duration          // length of the sliding window

	done     chan struct{} // closed by Stop to end cleanupLoop
	stopOnce sync.Once     // makes Stop idempotent
}

// NewRateLimiter returns a limiter that admits at most limit requests per IP
// within any trailing window.
//
// It starts a background goroutine that periodically drops idle IPs so the
// map does not grow without bound. Call Stop to terminate that goroutine.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
		done:     make(chan struct{}),
	}
	go rl.cleanupLoop()
	return rl
}

// Allow reports whether a request from ip may proceed right now.
//
// Timestamps older than the window are discarded first. If the number of
// remaining timestamps has already reached the limit the request is rejected
// and NOT recorded; otherwise the current time is recorded and the request is
// admitted.
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	cutoff := now.Add(-rl.window)

	// Filter in place: the slice is owned by the map entry and we hold the
	// lock, so reusing its backing array avoids an allocation per request.
	history := rl.requests[ip]
	recent := history[:0]
	for _, t := range history {
		if t.After(cutoff) {
			recent = append(recent, t)
		}
	}

	if len(recent) >= rl.limit {
		rl.requests[ip] = recent
		return false
	}

	rl.requests[ip] = append(recent, now)
	return true
}

// Stop terminates the background cleanup goroutine started by
// NewRateLimiter. It is safe to call more than once and on a limiter that was
// not built by NewRateLimiter. Allow keeps working after Stop; only the
// periodic removal of idle IPs ceases.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.done != nil {
			close(rl.done)
		}
	})
}

// cleanupLoop runs cleanup every five minutes until Stop is called.
func (rl *RateLimiter) cleanupLoop() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			rl.cleanup()
		case <-rl.done:
			return
		}
	}
}

// cleanup deletes every IP whose recorded requests are all older than twice
// the window (IPs with no recorded requests are deleted too).
func (rl *RateLimiter) cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Keep data for 2x the window so an IP is only forgotten well after its
	// last request stopped counting towards the limit.
	cutoff := time.Now().Add(-2 * rl.window)

	for ip, times := range rl.requests {
		stale := true
		for _, t := range times {
			if t.After(cutoff) {
				stale = false
				break
			}
		}
		if stale {
			delete(rl.requests, ip)
		}
	}
}
// RateLimitAuth builds middleware that throttles auth endpoints to
// 5 requests per 15 minutes for each client IP.
//
// Each call creates its own limiter; every handler wrapped by the returned
// middleware shares that limiter. A throttled request is logged and answered
// with 429 Too Many Requests without invoking the wrapped handler.
func RateLimitAuth() func(http.HandlerFunc) http.HandlerFunc {
	authLimiter := NewRateLimiter(5, 15*time.Minute)

	wrap := func(next http.HandlerFunc) http.HandlerFunc {
		handler := func(w http.ResponseWriter, r *http.Request) {
			// Resolve the caller's address, honouring proxy headers.
			clientIP := getClientIP(r)

			// Within budget: hand over to the wrapped handler.
			if authLimiter.Allow(clientIP) {
				next(w, r)
				return
			}

			slog.Warn("rate limit exceeded", "ip", clientIP, "path", r.URL.Path)
			http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
		}
		return handler
	}
	return wrap
}
// getClientIP returns the address used as the rate-limit key for r.
//
// Sources are consulted in this order, skipping any that are blank:
//  1. the first entry of the X-Forwarded-For header
//  2. the X-Real-IP header
//  3. r.RemoteAddr with its port removed
//
// NOTE(security): both headers are whatever the sender put in the request
// unless a trusted reverse proxy overwrites them, so a client that reaches
// this server directly can choose its own key and sidestep per-IP limits.
// Confirm the deployment always sits behind such a proxy.
func getClientIP(r *http.Request) string {
	// X-Forwarded-For is a comma-separated chain; the first entry is the
	// originating client as reported by the outermost proxy.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := strings.TrimSpace(first); ip != "" {
		return ip
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	// SplitHostPort understands bracketed IPv6 ("[::1]:8080" -> "::1"),
	// which cutting at the last ':' does not.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (or not in host:port form): use the value as is.
		return r.RemoteAddr
	}
	return host
}