diff --git a/README.md b/README.md index d843b39..5a1370f 100644 --- a/README.md +++ b/README.md @@ -84,15 +84,23 @@ Middleware values are plain `func(http.Handler) http.Handler`, so any stdlib-com The `pkg/middleware` package ships: -- **`Logger`** — emits a structured `http.request` record (method, path, status, duration) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. +- **`Logger`** — emits a structured `http.request` record (method, path, status, duration, client) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. The `client` field is `r.RemoteAddr`, so pairing with `RealIP` makes it the resolved client IP. - **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default. - **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response. +- **`RealIP`** — replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order). Always trusts these headers — only register it when the service sits behind a trusted proxy. +- **`RealIPWith(trusted ...netip.Prefix)`** — same, but only honors the headers when the immediate peer's IP falls within one of the trusted prefixes. Requests from outside the allowlist pass through untouched. ```go custom := splinter.New(splinter.WithStream(...)) mux.Use(middleware.Recoverer, middleware.LoggerWith(custom)) ``` +When using `RealIP` together with `Logger`, register `RealIP` first so the logged `client` field is the resolved client IP rather than the proxy's peer address: + +```go +mux.Use(middleware.RealIP, middleware.Logger) +``` + ## Path parameters lightmux is a thin wrapper, so path parameters work the stdlib way: diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go index 4628a5e..4d5e754 100644 --- a/pkg/middleware/logger.go +++ b/pkg/middleware/logger.go @@ -53,6 +53,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler { "path", r.URL.Path, "status", rec.status, "duration", time.Since(start), + "client", r.RemoteAddr, ) return } @@ -61,6 +62,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler { "path", r.URL.Path, "status", rec.status, "duration", time.Since(start), + "client", r.RemoteAddr, ) }) } diff --git a/pkg/middleware/logger_test.go b/pkg/middleware/logger_test.go index 3e5ffa5..a35b3d8 100644 --- a/pkg/middleware/logger_test.go +++ b/pkg/middleware/logger_test.go @@ -37,7 +37,7 @@ func TestLogger(t *testing.T) { t.Errorf("status code = %d, want 418", rr.Code) } out := buf.String() - for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`} { + for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} { if !strings.Contains(out, want) { t.Errorf("log output missing %s\nfull output: %s", want, out) } diff --git a/pkg/middleware/realip.go b/pkg/middleware/realip.go new file mode 100644 index 0000000..a0d323c --- /dev/null +++ b/pkg/middleware/realip.go @@ -0,0 +1,86 @@ +package middleware + +import ( + "net" + "net/http" + "net/netip" + "strings" +) + +var realIPHeaders = []string{ + "CF-Connecting-IP", + "True-Client-IP", + "X-Real-IP", + "X-Forwarded-For", +} + +// RealIP rewrites r.RemoteAddr with the originating client IP found in common +// reverse-proxy headers (Cloudflare, nginx). It always trusts these headers — +// only register it when the service sits behind a trusted proxy. Use +// RealIPWith for an allowlist-gated variant. +func RealIP(next http.Handler) http.Handler { + return realIPHandler(nil, next) +} + +// RealIPWith returns a RealIP middleware that only honors the proxy headers +// when the immediate peer (parsed from r.RemoteAddr) falls within one of the +// trusted prefixes. Requests from outside the allowlist are passed through +// untouched. +func RealIPWith(trusted ...netip.Prefix) Middleware { + if trusted == nil { + trusted = []netip.Prefix{} + } + return func(next http.Handler) http.Handler { + return realIPHandler(trusted, next) + } +} + +func realIPHandler(trusted []netip.Prefix, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if trusted != nil && !peerTrusted(r.RemoteAddr, trusted) { + next.ServeHTTP(w, r) + return + } + if ip := extractRealIP(r); ip != "" { + r2 := *r + r2.RemoteAddr = ip + next.ServeHTTP(w, &r2) + return + } + next.ServeHTTP(w, r) + }) +} + +func extractRealIP(r *http.Request) string { + for _, h := range realIPHeaders { + v := r.Header.Get(h) + if v == "" { + continue + } + if i := strings.IndexByte(v, ','); i >= 0 { + v = v[:i] + } + v = strings.TrimSpace(v) + if net.ParseIP(v) != nil { + return v + } + } + return "" +} + +func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool { + var peer netip.Addr + if ap, err := netip.ParseAddrPort(remoteAddr); err == nil { + peer = ap.Addr() + } else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil { + peer = a + } else { + return false + } + for _, p := range trusted { + if p.Contains(peer) { + return true + } + } + return false +} diff --git a/pkg/middleware/realip_test.go b/pkg/middleware/realip_test.go new file mode 100644 index 0000000..39897e0 --- /dev/null +++ b/pkg/middleware/realip_test.go @@ -0,0 +1,198 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "net/netip" + "testing" +) + +const defaultTestRemoteAddr = "192.0.2.1:1234" + +func captureRemoteAddr(got *string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + *got = r.RemoteAddr + }) +} + +func TestRealIP(t *testing.T) { + cases := []struct { + name string + headers map[string]string + want string + }{ + { + name: "no headers", + want: defaultTestRemoteAddr, + }, + { + name: "CF-Connecting-IP", + headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"}, + want: "203.0.113.5", + }, + { + name: "True-Client-IP", + headers: map[string]string{"True-Client-IP": "203.0.113.6"}, + want: "203.0.113.6", + }, + { + name: "X-Real-IP", + headers: map[string]string{"X-Real-IP": "203.0.113.7"}, + want: "203.0.113.7", + }, + { + name: "X-Forwarded-For single", + headers: map[string]string{"X-Forwarded-For": "203.0.113.8"}, + want: "203.0.113.8", + }, + { + name: "X-Forwarded-For list", + headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"}, + want: "203.0.113.9", + }, + { + name: "X-Forwarded-For with spaces", + headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"}, + want: "203.0.113.10", + }, + { + name: "precedence CF over XFF", + headers: map[string]string{ + "CF-Connecting-IP": "203.0.113.11", + "X-Forwarded-For": "198.51.100.1", + }, + want: "203.0.113.11", + }, + { + name: "invalid then valid", + headers: map[string]string{ + "CF-Connecting-IP": "not-an-ip", + "X-Real-IP": "203.0.113.12", + }, + want: "203.0.113.12", + }, + { + name: "IPv6", + headers: map[string]string{"X-Real-IP": "2001:db8::1"}, + want: "2001:db8::1", + }, + { + name: "all invalid falls through", + headers: map[string]string{"CF-Connecting-IP": "garbage"}, + want: defaultTestRemoteAddr, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got string + h := RealIP(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + for k, v := range tc.headers { + req.Header.Set(k, v) + } + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != tc.want { + t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want) + } + }) + } +} + +func TestRealIPDoesNotMutateCallerRequest(t *testing.T) { + var seen string + h := RealIP(captureRemoteAddr(&seen)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Real-IP", "203.0.113.13") + original := req.RemoteAddr + + h.ServeHTTP(httptest.NewRecorder(), req) + + if seen != "203.0.113.13" { + t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13") + } + if req.RemoteAddr != original { + t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original) + } +} + +func TestRealIPWithTrustedPeer(t *testing.T) { + var got string + h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "10.1.2.3:55555" + req.Header.Set("X-Real-IP", "203.0.113.20") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "203.0.113.20" { + t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20") + } +} + +func TestRealIPWithUntrustedPeer(t *testing.T) { + var got string + h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.99:55555" + req.Header.Set("X-Real-IP", "198.51.100.1") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "203.0.113.99:55555" { + t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555") + } +} + +func TestRealIPWithMultiplePrefixesIPv6(t *testing.T) { + var got string + h := RealIPWith( + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("2001:db8::/32"), + )(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "[2001:db8::abcd]:55555" + req.Header.Set("X-Real-IP", "203.0.113.30") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "203.0.113.30" { + t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30") + } +} + +func TestRealIPWithZeroPrefixesNoOp(t *testing.T) { + var got string + h := RealIPWith()(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("CF-Connecting-IP", "203.0.113.40") + original := req.RemoteAddr + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != original { + t.Errorf("RealIPWith() with no prefixes should be a no-op: r.RemoteAddr = %q, want %q", got, original) + } +} + +func TestRealIPWithUnparseableRemoteAddr(t *testing.T) { + var got string + h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "not-an-addr" + req.Header.Set("CF-Connecting-IP", "203.0.113.50") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "not-an-addr" { + t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got) + } +} diff --git a/pkg/router/mux_test.go b/pkg/router/mux_test.go index 1af57d1..1cc2fd1 100644 --- a/pkg/router/mux_test.go +++ b/pkg/router/mux_test.go @@ -166,4 +166,3 @@ func TestConflictingPatternsPanic(t *testing.T) { }() m.Get("/x", func(w http.ResponseWriter, r *http.Request) {}) } - diff --git a/pkg/router/pattern_test.go b/pkg/router/pattern_test.go index 1ef240a..d93e8f7 100644 --- a/pkg/router/pattern_test.go +++ b/pkg/router/pattern_test.go @@ -4,8 +4,8 @@ import "testing" func TestSplitPattern(t *testing.T) { cases := []struct { - in string - method, host, path string + in string + method, host, path string }{ {"/foo", "", "", "/foo"}, {"GET /foo", "GET", "", "/foo"},