add iplimit middleware

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
juancwu 2026-04-26 23:28:40 +00:00
commit 522ac09cdc
5 changed files with 371 additions and 0 deletions

134
iplimit/iplimit.go Normal file
View file

@ -0,0 +1,134 @@
// Package iplimit provides a per-IP rate-limiting HTTP middleware backed by
// a token bucket from golang.org/x/time/rate. Each unique IP gets its own
// bucket; rejected requests are served a 429 response with a Retry-After
// header.
package iplimit
import (
"math"
"net"
"net/http"
"strconv"
"sync"
"time"
"git.juancwu.dev/juancwu/errx"
"golang.org/x/time/rate"
)
const (
op = "iplimit"
// idleTTL is how long a per-IP bucket survives without activity before
// being evicted from the in-memory map.
idleTTL = 10 * time.Minute
// sweepEvery bounds how often the eviction pass runs (lazily, on the
// next request after the interval expires).
sweepEvery = idleTTL / 10
)
// timeNow is the clock the middleware reads. Tests override it to drive the
// limiter and the eviction sweep deterministically.
var timeNow = time.Now
type entry struct {
limiter *rate.Limiter
seen time.Time
}
// New returns a per-IP rate-limiting middleware. r is the steady-state token
// refill rate; burst is the maximum burst size. Use rate.Every(d) to express
// the rate as one event per duration d.
//
// Pair with realip.New() upstream so the limiter keys on the originating
// client IP rather than the proxy peer. Loopback, private, link-local, and
// unspecified addresses pass through unlimited — they are typically internal
// callers (health checks, dev) that should not be rate limited.
//
// Idle buckets are evicted after 10 minutes of inactivity. The sweep runs
// inline on the next request after the interval expires, so the only state
// retained between requests is the entries map itself — no background
// goroutines.
//
// Rejected requests receive a 429 Too Many Requests response with a
// Retry-After header (in whole seconds, ceiling).
func New(r rate.Limit, burst int) func(http.Handler) http.Handler {
if r <= 0 {
panic(errx.New(op, "rate must be > 0"))
}
if burst <= 0 {
panic(errx.New(op, "burst must be > 0"))
}
var (
mu sync.Mutex
entries = make(map[string]*entry)
nextSweep time.Time
)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ip := parseClientIP(req.RemoteAddr)
if ip == nil || isLocal(ip) {
next.ServeHTTP(w, req)
return
}
now := timeNow()
key := ip.String()
mu.Lock()
if now.After(nextSweep) {
for k, e := range entries {
if now.Sub(e.seen) > idleTTL {
delete(entries, k)
}
}
nextSweep = now.Add(sweepEvery)
}
e, ok := entries[key]
if !ok {
e = &entry{limiter: rate.NewLimiter(r, burst)}
entries[key] = e
}
e.seen = now
lim := e.limiter
mu.Unlock()
res := lim.ReserveN(now, 1)
delay := res.DelayFrom(now)
if delay > 0 {
res.CancelAt(now)
w.Header().Set("Retry-After", strconv.Itoa(retryAfterSeconds(delay)))
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, req)
})
}
}
func retryAfterSeconds(d time.Duration) int {
s := int(math.Ceil(d.Seconds()))
if s < 1 {
return 1
}
return s
}
func parseClientIP(remoteAddr string) net.IP {
host := remoteAddr
if h, _, err := net.SplitHostPort(remoteAddr); err == nil {
host = h
}
return net.ParseIP(host)
}
func isLocal(ip net.IP) bool {
return ip.IsLoopback() ||
ip.IsPrivate() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsUnspecified()
}

215
iplimit/iplimit_test.go Normal file
View file

@ -0,0 +1,215 @@
package iplimit
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"golang.org/x/time/rate"
)
func okHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
}
func serve(h http.Handler, remoteAddr string) *httptest.ResponseRecorder {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = remoteAddr
h.ServeHTTP(rr, req)
return rr
}
// fakeClock returns a closure that advances the package-level timeNow when
// called. The cleanup restores the real clock.
func fakeClock(t *testing.T, start time.Time) func(time.Duration) {
t.Helper()
now := start
timeNow = func() time.Time { return now }
t.Cleanup(func() { timeNow = time.Now })
return func(d time.Duration) { now = now.Add(d) }
}
func TestNewAllowsBurstThenBlocks(t *testing.T) {
h := New(rate.Every(time.Second), 3)(okHandler())
const ip = "203.0.113.1:1234"
for i := 0; i < 3; i++ {
if rr := serve(h, ip); rr.Code != http.StatusOK {
t.Errorf("request %d: status = %d, want 200", i, rr.Code)
}
}
rr := serve(h, ip)
if rr.Code != http.StatusTooManyRequests {
t.Errorf("4th request: status = %d, want 429", rr.Code)
}
if rr.Header().Get("Retry-After") == "" {
t.Error("expected Retry-After header on 429")
}
}
func TestNewSeparatesBucketsPerIP(t *testing.T) {
// burst=1, near-zero refill — second request from the same IP would 429.
h := New(rate.Every(time.Hour), 1)(okHandler())
if rr := serve(h, "203.0.113.1:1234"); rr.Code != http.StatusOK {
t.Errorf("first IP: status = %d, want 200", rr.Code)
}
if rr := serve(h, "203.0.113.2:1234"); rr.Code != http.StatusOK {
t.Errorf("second IP: status = %d, want 200 (separate bucket)", rr.Code)
}
}
func TestNewIgnoresPort(t *testing.T) {
// Same IP from two source ports must share a bucket.
h := New(rate.Every(time.Hour), 1)(okHandler())
if rr := serve(h, "203.0.113.1:1111"); rr.Code != http.StatusOK {
t.Errorf("port 1111: status = %d, want 200", rr.Code)
}
if rr := serve(h, "203.0.113.1:2222"); rr.Code != http.StatusTooManyRequests {
t.Errorf("port 2222 (same IP): status = %d, want 429", rr.Code)
}
}
func TestNewSkipsLocalAddresses(t *testing.T) {
h := New(rate.Every(time.Hour), 1)(okHandler())
cases := []string{
"127.0.0.1:1234", // loopback
"10.0.0.1:1234", // private
"192.168.1.1:1234", // private
"169.254.0.1:1234", // link-local
"[::1]:1234", // IPv6 loopback
"[fe80::1]:1234", // IPv6 link-local
}
for _, ra := range cases {
t.Run(ra, func(t *testing.T) {
for i := 0; i < 5; i++ {
if rr := serve(h, ra); rr.Code != http.StatusOK {
t.Errorf("local %s req %d: status = %d, want 200", ra, i, rr.Code)
}
}
})
}
}
func TestNewPassesThroughOnUnparseableRemoteAddr(t *testing.T) {
h := New(rate.Every(time.Hour), 1)(okHandler())
for i := 0; i < 5; i++ {
if rr := serve(h, "not-an-addr"); rr.Code != http.StatusOK {
t.Errorf("unparseable req %d: status = %d, want 200", i, rr.Code)
}
}
}
func TestNewRefillsAfterDelay(t *testing.T) {
advance := fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC))
h := New(rate.Every(time.Second), 1)(okHandler())
const ip = "203.0.113.5:1234"
if rr := serve(h, ip); rr.Code != http.StatusOK {
t.Fatalf("first: status = %d, want 200", rr.Code)
}
if rr := serve(h, ip); rr.Code != http.StatusTooManyRequests {
t.Fatalf("second (same instant): status = %d, want 429", rr.Code)
}
advance(time.Second)
if rr := serve(h, ip); rr.Code != http.StatusOK {
t.Errorf("third (after refill): status = %d, want 200", rr.Code)
}
}
func TestNewRetryAfterReflectsDelay(t *testing.T) {
fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC))
h := New(rate.Every(2*time.Second), 1)(okHandler())
const ip = "203.0.113.6:1234"
if rr := serve(h, ip); rr.Code != http.StatusOK {
t.Fatalf("first: status = %d, want 200", rr.Code)
}
rr := serve(h, ip)
if rr.Code != http.StatusTooManyRequests {
t.Fatalf("second: status = %d, want 429", rr.Code)
}
got, err := strconv.Atoi(rr.Header().Get("Retry-After"))
if err != nil {
t.Fatalf("Retry-After: %v", err)
}
if got != 2 {
t.Errorf("Retry-After = %d, want 2", got)
}
}
func TestNewCancelDoesNotConsumeTokenOnReject(t *testing.T) {
// burst=1, rate=1/sec. After the burst is used, cancelled rejections
// should not push the next-allowed time further into the future.
advance := fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC))
h := New(rate.Every(time.Second), 1)(okHandler())
const ip = "203.0.113.7:1234"
if rr := serve(h, ip); rr.Code != http.StatusOK {
t.Fatalf("first: status = %d, want 200", rr.Code)
}
// Pile on rejections at the same instant.
for i := 0; i < 5; i++ {
if rr := serve(h, ip); rr.Code != http.StatusTooManyRequests {
t.Fatalf("reject %d: status = %d, want 429", i, rr.Code)
}
}
advance(time.Second)
if rr := serve(h, ip); rr.Code != http.StatusOK {
t.Errorf("after 1s with cancelled rejections: status = %d, want 200", rr.Code)
}
}
func TestNewEvictsIdleEntries(t *testing.T) {
advance := fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC))
h := New(rate.Every(time.Hour), 1)(okHandler())
const idle = "203.0.113.8:1234"
const fresh = "203.0.113.9:1234"
// Burn the idle IP's bucket — its limiter is now empty.
if rr := serve(h, idle); rr.Code != http.StatusOK {
t.Fatalf("idle setup: status = %d, want 200", rr.Code)
}
// Advance past the eviction TTL and trigger a sweep via a different IP.
advance(idleTTL + time.Minute)
if rr := serve(h, fresh); rr.Code != http.StatusOK {
t.Fatalf("fresh sweep trigger: status = %d, want 200", rr.Code)
}
// Idle's old limiter should have been evicted; a brand-new bucket gets
// the full burst back.
if rr := serve(h, idle); rr.Code != http.StatusOK {
t.Errorf("after eviction: status = %d, want 200 (new bucket)", rr.Code)
}
}
func TestNewPanicsOnInvalidRate(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic on rate=0")
}
}()
_ = New(0, 1)
}
func TestNewPanicsOnInvalidBurst(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic on burst=0")
}
}()
_ = New(rate.Every(time.Second), 0)
}