Merge branch 'get-real-ip'

This commit is contained in:
juancwu 2026-04-26 13:14:40 +00:00
commit d983cca25e
7 changed files with 298 additions and 5 deletions

View file

@ -84,15 +84,23 @@ Middleware values are plain `func(http.Handler) http.Handler`, so any stdlib-com
The `pkg/middleware` package ships: The `pkg/middleware` package ships:
- **`Logger`** — emits a structured `http.request` record (method, path, status, duration) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. - **`Logger`** — emits a structured `http.request` record (method, path, status, duration, client) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. The `client` field is `r.RemoteAddr`, so pairing with `RealIP` makes it the resolved client IP.
- **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default. - **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default.
- **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response. - **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response.
- **`RealIP`** — replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order). Always trusts these headers — only register it when the service sits behind a trusted proxy.
- **`RealIPWith(trusted ...netip.Prefix)`** — same, but only honors the headers when the immediate peer's IP falls within one of the trusted prefixes. Requests from outside the allowlist pass through untouched.
```go ```go
custom := splinter.New(splinter.WithStream(...)) custom := splinter.New(splinter.WithStream(...))
mux.Use(middleware.Recoverer, middleware.LoggerWith(custom)) mux.Use(middleware.Recoverer, middleware.LoggerWith(custom))
``` ```
When using `RealIP` together with `Logger`, register `RealIP` first so the logged `client` field is the resolved client IP rather than the proxy's peer address:
```go
mux.Use(middleware.RealIP, middleware.Logger)
```
## Path parameters ## Path parameters
lightmux is a thin wrapper, so path parameters work the stdlib way: lightmux is a thin wrapper, so path parameters work the stdlib way:

View file

@ -53,6 +53,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
"path", r.URL.Path, "path", r.URL.Path,
"status", rec.status, "status", rec.status,
"duration", time.Since(start), "duration", time.Since(start),
"client", r.RemoteAddr,
) )
return return
} }
@ -61,6 +62,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
"path", r.URL.Path, "path", r.URL.Path,
"status", rec.status, "status", rec.status,
"duration", time.Since(start), "duration", time.Since(start),
"client", r.RemoteAddr,
) )
}) })
} }

View file

@ -37,7 +37,7 @@ func TestLogger(t *testing.T) {
t.Errorf("status code = %d, want 418", rr.Code) t.Errorf("status code = %d, want 418", rr.Code)
} }
out := buf.String() out := buf.String()
for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`} { for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} {
if !strings.Contains(out, want) { if !strings.Contains(out, want) {
t.Errorf("log output missing %s\nfull output: %s", want, out) t.Errorf("log output missing %s\nfull output: %s", want, out)
} }

86
pkg/middleware/realip.go Normal file
View file

@ -0,0 +1,86 @@
package middleware
import (
"net"
"net/http"
"net/netip"
"strings"
)
var realIPHeaders = []string{
"CF-Connecting-IP",
"True-Client-IP",
"X-Real-IP",
"X-Forwarded-For",
}
// RealIP rewrites r.RemoteAddr with the originating client IP found in common
// reverse-proxy headers (Cloudflare, nginx). It always trusts these headers —
// only register it when the service sits behind a trusted proxy. Use
// RealIPWith for an allowlist-gated variant.
func RealIP(next http.Handler) http.Handler {
return realIPHandler(nil, next)
}
// RealIPWith returns a RealIP middleware that only honors the proxy headers
// when the immediate peer (parsed from r.RemoteAddr) falls within one of the
// trusted prefixes. Requests from outside the allowlist are passed through
// untouched.
func RealIPWith(trusted ...netip.Prefix) Middleware {
if trusted == nil {
trusted = []netip.Prefix{}
}
return func(next http.Handler) http.Handler {
return realIPHandler(trusted, next)
}
}
func realIPHandler(trusted []netip.Prefix, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if trusted != nil && !peerTrusted(r.RemoteAddr, trusted) {
next.ServeHTTP(w, r)
return
}
if ip := extractRealIP(r); ip != "" {
r2 := *r
r2.RemoteAddr = ip
next.ServeHTTP(w, &r2)
return
}
next.ServeHTTP(w, r)
})
}
func extractRealIP(r *http.Request) string {
for _, h := range realIPHeaders {
v := r.Header.Get(h)
if v == "" {
continue
}
if i := strings.IndexByte(v, ','); i >= 0 {
v = v[:i]
}
v = strings.TrimSpace(v)
if net.ParseIP(v) != nil {
return v
}
}
return ""
}
func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool {
var peer netip.Addr
if ap, err := netip.ParseAddrPort(remoteAddr); err == nil {
peer = ap.Addr()
} else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil {
peer = a
} else {
return false
}
for _, p := range trusted {
if p.Contains(peer) {
return true
}
}
return false
}

View file

@ -0,0 +1,198 @@
package middleware
import (
"net/http"
"net/http/httptest"
"net/netip"
"testing"
)
const defaultTestRemoteAddr = "192.0.2.1:1234"
func captureRemoteAddr(got *string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
*got = r.RemoteAddr
})
}
func TestRealIP(t *testing.T) {
cases := []struct {
name string
headers map[string]string
want string
}{
{
name: "no headers",
want: defaultTestRemoteAddr,
},
{
name: "CF-Connecting-IP",
headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"},
want: "203.0.113.5",
},
{
name: "True-Client-IP",
headers: map[string]string{"True-Client-IP": "203.0.113.6"},
want: "203.0.113.6",
},
{
name: "X-Real-IP",
headers: map[string]string{"X-Real-IP": "203.0.113.7"},
want: "203.0.113.7",
},
{
name: "X-Forwarded-For single",
headers: map[string]string{"X-Forwarded-For": "203.0.113.8"},
want: "203.0.113.8",
},
{
name: "X-Forwarded-For list",
headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"},
want: "203.0.113.9",
},
{
name: "X-Forwarded-For with spaces",
headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"},
want: "203.0.113.10",
},
{
name: "precedence CF over XFF",
headers: map[string]string{
"CF-Connecting-IP": "203.0.113.11",
"X-Forwarded-For": "198.51.100.1",
},
want: "203.0.113.11",
},
{
name: "invalid then valid",
headers: map[string]string{
"CF-Connecting-IP": "not-an-ip",
"X-Real-IP": "203.0.113.12",
},
want: "203.0.113.12",
},
{
name: "IPv6",
headers: map[string]string{"X-Real-IP": "2001:db8::1"},
want: "2001:db8::1",
},
{
name: "all invalid falls through",
headers: map[string]string{"CF-Connecting-IP": "garbage"},
want: defaultTestRemoteAddr,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
var got string
h := RealIP(captureRemoteAddr(&got))
req := httptest.NewRequest(http.MethodGet, "/", nil)
for k, v := range tc.headers {
req.Header.Set(k, v)
}
h.ServeHTTP(httptest.NewRecorder(), req)
if got != tc.want {
t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want)
}
})
}
}
func TestRealIPDoesNotMutateCallerRequest(t *testing.T) {
var seen string
h := RealIP(captureRemoteAddr(&seen))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Real-IP", "203.0.113.13")
original := req.RemoteAddr
h.ServeHTTP(httptest.NewRecorder(), req)
if seen != "203.0.113.13" {
t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13")
}
if req.RemoteAddr != original {
t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original)
}
}
func TestRealIPWithTrustedPeer(t *testing.T) {
var got string
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "10.1.2.3:55555"
req.Header.Set("X-Real-IP", "203.0.113.20")
h.ServeHTTP(httptest.NewRecorder(), req)
if got != "203.0.113.20" {
t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20")
}
}
func TestRealIPWithUntrustedPeer(t *testing.T) {
var got string
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "203.0.113.99:55555"
req.Header.Set("X-Real-IP", "198.51.100.1")
h.ServeHTTP(httptest.NewRecorder(), req)
if got != "203.0.113.99:55555" {
t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555")
}
}
func TestRealIPWithMultiplePrefixesIPv6(t *testing.T) {
var got string
h := RealIPWith(
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("2001:db8::/32"),
)(captureRemoteAddr(&got))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "[2001:db8::abcd]:55555"
req.Header.Set("X-Real-IP", "203.0.113.30")
h.ServeHTTP(httptest.NewRecorder(), req)
if got != "203.0.113.30" {
t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30")
}
}
func TestRealIPWithZeroPrefixesNoOp(t *testing.T) {
var got string
h := RealIPWith()(captureRemoteAddr(&got))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("CF-Connecting-IP", "203.0.113.40")
original := req.RemoteAddr
h.ServeHTTP(httptest.NewRecorder(), req)
if got != original {
t.Errorf("RealIPWith() with no prefixes should be a no-op: r.RemoteAddr = %q, want %q", got, original)
}
}
func TestRealIPWithUnparseableRemoteAddr(t *testing.T) {
var got string
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "not-an-addr"
req.Header.Set("CF-Connecting-IP", "203.0.113.50")
h.ServeHTTP(httptest.NewRecorder(), req)
if got != "not-an-addr" {
t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got)
}
}

View file

@ -166,4 +166,3 @@ func TestConflictingPatternsPanic(t *testing.T) {
}() }()
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {}) m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
} }

View file

@ -4,8 +4,8 @@ import "testing"
func TestSplitPattern(t *testing.T) { func TestSplitPattern(t *testing.T) {
cases := []struct { cases := []struct {
in string in string
method, host, path string method, host, path string
}{ }{
{"/foo", "", "", "/foo"}, {"/foo", "", "", "/foo"},
{"GET /foo", "GET", "", "/foo"}, {"GET /foo", "GET", "", "/foo"},