From 522ac09cdc02a2bffd20388363ed5600a1824371 Mon Sep 17 00:00:00 2001 From: juancwu Date: Sun, 26 Apr 2026 23:28:40 +0000 Subject: [PATCH] add iplimit middleware Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 19 ++++ go.mod | 1 + go.sum | 2 + iplimit/iplimit.go | 134 +++++++++++++++++++++++++ iplimit/iplimit_test.go | 215 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 371 insertions(+) create mode 100644 iplimit/iplimit.go create mode 100644 iplimit/iplimit_test.go diff --git a/README.md b/README.md index 03e6447..1b0781e 100644 --- a/README.md +++ b/README.md @@ -72,3 +72,22 @@ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { ``` Loopback, private, link-local, and unspecified addresses are skipped to preserve API quota. Lookup failures are logged at warn level via [splinter](https://git.juancwu.dev/juancwu/splinter) (pass `nil` for `splinter.Default()` resolved at request time, or supply a custom `*splinter.Logger`) and let the request through with no context value — handlers should treat the `From` lookup as optional. + +### `iplimit` + +Per-IP token-bucket rate limiter backed by [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate). Each unique IP gets its own bucket; rejected requests get a 429 with a `Retry-After` header (whole seconds, ceiling). + +```go +import ( + "time" + + "git.juancwu.dev/juancwu/lightmux-contrib/iplimit" + "git.juancwu.dev/juancwu/lightmux-contrib/realip" + "golang.org/x/time/rate" +) + +// 5 req/s steady state, bursts of 10 +mux.Use(realip.New(), iplimit.New(rate.Every(time.Second/5), 10)) +``` + +Pair with `realip` upstream so the limiter keys on the originating client IP rather than the proxy peer. Loopback, private, link-local, and unspecified addresses pass through unlimited — they are typically internal callers (health checks, dev) that should not be rate limited. Idle buckets are evicted after 10 minutes of inactivity by an inline sweep, so no background goroutines are spawned. diff --git a/go.mod b/go.mod index 1c4fc88..7501e06 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( git.juancwu.dev/juancwu/errx v0.1.0 git.juancwu.dev/juancwu/splinter v0.1.0 github.com/ipinfo/go/v2 v2.14.0 + golang.org/x/time v0.15.0 ) require ( diff --git a/go.sum b/go.sum index 3b0d732..7b7a748 100644 --- a/go.sum +++ b/go.sum @@ -8,3 +8,5 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4= golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= diff --git a/iplimit/iplimit.go b/iplimit/iplimit.go new file mode 100644 index 0000000..027d193 --- /dev/null +++ b/iplimit/iplimit.go @@ -0,0 +1,134 @@ +// Package iplimit provides a per-IP rate-limiting HTTP middleware backed by +// a token bucket from golang.org/x/time/rate. Each unique IP gets its own +// bucket; rejected requests are served a 429 response with a Retry-After +// header. +package iplimit + +import ( + "math" + "net" + "net/http" + "strconv" + "sync" + "time" + + "git.juancwu.dev/juancwu/errx" + "golang.org/x/time/rate" +) + +const ( + op = "iplimit" + + // idleTTL is how long a per-IP bucket survives without activity before + // being evicted from the in-memory map. + idleTTL = 10 * time.Minute + + // sweepEvery bounds how often the eviction pass runs (lazily, on the + // next request after the interval expires). + sweepEvery = idleTTL / 10 +) + +// timeNow is the clock the middleware reads. Tests override it to drive the +// limiter and the eviction sweep deterministically. +var timeNow = time.Now + +type entry struct { + limiter *rate.Limiter + seen time.Time +} + +// New returns a per-IP rate-limiting middleware. r is the steady-state token +// refill rate; burst is the maximum burst size. Use rate.Every(d) to express +// the rate as one event per duration d. +// +// Pair with realip.New() upstream so the limiter keys on the originating +// client IP rather than the proxy peer. Loopback, private, link-local, and +// unspecified addresses pass through unlimited — they are typically internal +// callers (health checks, dev) that should not be rate limited. +// +// Idle buckets are evicted after 10 minutes of inactivity. The sweep runs +// inline on the next request after the interval expires, so the only state +// retained between requests is the entries map itself — no background +// goroutines. +// +// Rejected requests receive a 429 Too Many Requests response with a +// Retry-After header (in whole seconds, ceiling). +func New(r rate.Limit, burst int) func(http.Handler) http.Handler { + if r <= 0 { + panic(errx.New(op, "rate must be > 0")) + } + if burst <= 0 { + panic(errx.New(op, "burst must be > 0")) + } + + var ( + mu sync.Mutex + entries = make(map[string]*entry) + nextSweep time.Time + ) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ip := parseClientIP(req.RemoteAddr) + if ip == nil || isLocal(ip) { + next.ServeHTTP(w, req) + return + } + + now := timeNow() + key := ip.String() + + mu.Lock() + if now.After(nextSweep) { + for k, e := range entries { + if now.Sub(e.seen) > idleTTL { + delete(entries, k) + } + } + nextSweep = now.Add(sweepEvery) + } + e, ok := entries[key] + if !ok { + e = &entry{limiter: rate.NewLimiter(r, burst)} + entries[key] = e + } + e.seen = now + lim := e.limiter + mu.Unlock() + + res := lim.ReserveN(now, 1) + delay := res.DelayFrom(now) + if delay > 0 { + res.CancelAt(now) + w.Header().Set("Retry-After", strconv.Itoa(retryAfterSeconds(delay))) + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, req) + }) + } +} + +func retryAfterSeconds(d time.Duration) int { + s := int(math.Ceil(d.Seconds())) + if s < 1 { + return 1 + } + return s +} + +func parseClientIP(remoteAddr string) net.IP { + host := remoteAddr + if h, _, err := net.SplitHostPort(remoteAddr); err == nil { + host = h + } + return net.ParseIP(host) +} + +func isLocal(ip net.IP) bool { + return ip.IsLoopback() || + ip.IsPrivate() || + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || + ip.IsUnspecified() +} diff --git a/iplimit/iplimit_test.go b/iplimit/iplimit_test.go new file mode 100644 index 0000000..6f67760 --- /dev/null +++ b/iplimit/iplimit_test.go @@ -0,0 +1,215 @@ +package iplimit + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "golang.org/x/time/rate" +) + +func okHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} + +func serve(h http.Handler, remoteAddr string) *httptest.ResponseRecorder { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = remoteAddr + h.ServeHTTP(rr, req) + return rr +} + +// fakeClock returns a closure that advances the package-level timeNow when +// called. The cleanup restores the real clock. +func fakeClock(t *testing.T, start time.Time) func(time.Duration) { + t.Helper() + now := start + timeNow = func() time.Time { return now } + t.Cleanup(func() { timeNow = time.Now }) + return func(d time.Duration) { now = now.Add(d) } +} + +func TestNewAllowsBurstThenBlocks(t *testing.T) { + h := New(rate.Every(time.Second), 3)(okHandler()) + const ip = "203.0.113.1:1234" + + for i := 0; i < 3; i++ { + if rr := serve(h, ip); rr.Code != http.StatusOK { + t.Errorf("request %d: status = %d, want 200", i, rr.Code) + } + } + rr := serve(h, ip) + if rr.Code != http.StatusTooManyRequests { + t.Errorf("4th request: status = %d, want 429", rr.Code) + } + if rr.Header().Get("Retry-After") == "" { + t.Error("expected Retry-After header on 429") + } +} + +func TestNewSeparatesBucketsPerIP(t *testing.T) { + // burst=1, near-zero refill — second request from the same IP would 429. + h := New(rate.Every(time.Hour), 1)(okHandler()) + + if rr := serve(h, "203.0.113.1:1234"); rr.Code != http.StatusOK { + t.Errorf("first IP: status = %d, want 200", rr.Code) + } + if rr := serve(h, "203.0.113.2:1234"); rr.Code != http.StatusOK { + t.Errorf("second IP: status = %d, want 200 (separate bucket)", rr.Code) + } +} + +func TestNewIgnoresPort(t *testing.T) { + // Same IP from two source ports must share a bucket. + h := New(rate.Every(time.Hour), 1)(okHandler()) + + if rr := serve(h, "203.0.113.1:1111"); rr.Code != http.StatusOK { + t.Errorf("port 1111: status = %d, want 200", rr.Code) + } + if rr := serve(h, "203.0.113.1:2222"); rr.Code != http.StatusTooManyRequests { + t.Errorf("port 2222 (same IP): status = %d, want 429", rr.Code) + } +} + +func TestNewSkipsLocalAddresses(t *testing.T) { + h := New(rate.Every(time.Hour), 1)(okHandler()) + + cases := []string{ + "127.0.0.1:1234", // loopback + "10.0.0.1:1234", // private + "192.168.1.1:1234", // private + "169.254.0.1:1234", // link-local + "[::1]:1234", // IPv6 loopback + "[fe80::1]:1234", // IPv6 link-local + } + for _, ra := range cases { + t.Run(ra, func(t *testing.T) { + for i := 0; i < 5; i++ { + if rr := serve(h, ra); rr.Code != http.StatusOK { + t.Errorf("local %s req %d: status = %d, want 200", ra, i, rr.Code) + } + } + }) + } +} + +func TestNewPassesThroughOnUnparseableRemoteAddr(t *testing.T) { + h := New(rate.Every(time.Hour), 1)(okHandler()) + + for i := 0; i < 5; i++ { + if rr := serve(h, "not-an-addr"); rr.Code != http.StatusOK { + t.Errorf("unparseable req %d: status = %d, want 200", i, rr.Code) + } + } +} + +func TestNewRefillsAfterDelay(t *testing.T) { + advance := fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)) + + h := New(rate.Every(time.Second), 1)(okHandler()) + const ip = "203.0.113.5:1234" + + if rr := serve(h, ip); rr.Code != http.StatusOK { + t.Fatalf("first: status = %d, want 200", rr.Code) + } + if rr := serve(h, ip); rr.Code != http.StatusTooManyRequests { + t.Fatalf("second (same instant): status = %d, want 429", rr.Code) + } + advance(time.Second) + if rr := serve(h, ip); rr.Code != http.StatusOK { + t.Errorf("third (after refill): status = %d, want 200", rr.Code) + } +} + +func TestNewRetryAfterReflectsDelay(t *testing.T) { + fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)) + + h := New(rate.Every(2*time.Second), 1)(okHandler()) + const ip = "203.0.113.6:1234" + + if rr := serve(h, ip); rr.Code != http.StatusOK { + t.Fatalf("first: status = %d, want 200", rr.Code) + } + rr := serve(h, ip) + if rr.Code != http.StatusTooManyRequests { + t.Fatalf("second: status = %d, want 429", rr.Code) + } + got, err := strconv.Atoi(rr.Header().Get("Retry-After")) + if err != nil { + t.Fatalf("Retry-After: %v", err) + } + if got != 2 { + t.Errorf("Retry-After = %d, want 2", got) + } +} + +func TestNewCancelDoesNotConsumeTokenOnReject(t *testing.T) { + // burst=1, rate=1/sec. After the burst is used, cancelled rejections + // should not push the next-allowed time further into the future. + advance := fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)) + + h := New(rate.Every(time.Second), 1)(okHandler()) + const ip = "203.0.113.7:1234" + + if rr := serve(h, ip); rr.Code != http.StatusOK { + t.Fatalf("first: status = %d, want 200", rr.Code) + } + // Pile on rejections at the same instant. + for i := 0; i < 5; i++ { + if rr := serve(h, ip); rr.Code != http.StatusTooManyRequests { + t.Fatalf("reject %d: status = %d, want 429", i, rr.Code) + } + } + advance(time.Second) + if rr := serve(h, ip); rr.Code != http.StatusOK { + t.Errorf("after 1s with cancelled rejections: status = %d, want 200", rr.Code) + } +} + +func TestNewEvictsIdleEntries(t *testing.T) { + advance := fakeClock(t, time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)) + + h := New(rate.Every(time.Hour), 1)(okHandler()) + const idle = "203.0.113.8:1234" + const fresh = "203.0.113.9:1234" + + // Burn the idle IP's bucket — its limiter is now empty. + if rr := serve(h, idle); rr.Code != http.StatusOK { + t.Fatalf("idle setup: status = %d, want 200", rr.Code) + } + + // Advance past the eviction TTL and trigger a sweep via a different IP. + advance(idleTTL + time.Minute) + if rr := serve(h, fresh); rr.Code != http.StatusOK { + t.Fatalf("fresh sweep trigger: status = %d, want 200", rr.Code) + } + + // Idle's old limiter should have been evicted; a brand-new bucket gets + // the full burst back. + if rr := serve(h, idle); rr.Code != http.StatusOK { + t.Errorf("after eviction: status = %d, want 200 (new bucket)", rr.Code) + } +} + +func TestNewPanicsOnInvalidRate(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic on rate=0") + } + }() + _ = New(0, 1) +} + +func TestNewPanicsOnInvalidBurst(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic on burst=0") + } + }() + _ = New(rate.Every(time.Second), 0) +}