Merge branch 'get-real-ip'
This commit is contained in:
commit
d983cca25e
7 changed files with 298 additions and 5 deletions
10
README.md
10
README.md
|
|
@ -84,15 +84,23 @@ Middleware values are plain `func(http.Handler) http.Handler`, so any stdlib-com
|
|||
|
||||
The `pkg/middleware` package ships:
|
||||
|
||||
- **`Logger`** — emits a structured `http.request` record (method, path, status, duration) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger.
|
||||
- **`Logger`** — emits a structured `http.request` record (method, path, status, duration, client) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. The `client` field is `r.RemoteAddr`, so pairing with `RealIP` makes it the resolved client IP.
|
||||
- **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default.
|
||||
- **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response.
|
||||
- **`RealIP`** — replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order). Always trusts these headers — only register it when the service sits behind a trusted proxy.
|
||||
- **`RealIPWith(trusted ...netip.Prefix)`** — same, but only honors the headers when the immediate peer's IP falls within one of the trusted prefixes. Requests from outside the allowlist pass through untouched.
|
||||
|
||||
```go
|
||||
custom := splinter.New(splinter.WithStream(...))
|
||||
mux.Use(middleware.Recoverer, middleware.LoggerWith(custom))
|
||||
```
|
||||
|
||||
When using `RealIP` together with `Logger`, register `RealIP` first so the logged `client` field is the resolved client IP rather than the proxy's peer address:
|
||||
|
||||
```go
|
||||
mux.Use(middleware.RealIP, middleware.Logger)
|
||||
```
|
||||
|
||||
## Path parameters
|
||||
|
||||
lightmux is a thin wrapper, so path parameters work the stdlib way:
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
|
|||
"path", r.URL.Path,
|
||||
"status", rec.status,
|
||||
"duration", time.Since(start),
|
||||
"client", r.RemoteAddr,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
|
@ -61,6 +62,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
|
|||
"path", r.URL.Path,
|
||||
"status", rec.status,
|
||||
"duration", time.Since(start),
|
||||
"client", r.RemoteAddr,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ func TestLogger(t *testing.T) {
|
|||
t.Errorf("status code = %d, want 418", rr.Code)
|
||||
}
|
||||
out := buf.String()
|
||||
for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`} {
|
||||
for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Errorf("log output missing %s\nfull output: %s", want, out)
|
||||
}
|
||||
|
|
|
|||
86
pkg/middleware/realip.go
Normal file
86
pkg/middleware/realip.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var realIPHeaders = []string{
|
||||
"CF-Connecting-IP",
|
||||
"True-Client-IP",
|
||||
"X-Real-IP",
|
||||
"X-Forwarded-For",
|
||||
}
|
||||
|
||||
// RealIP rewrites r.RemoteAddr with the originating client IP found in common
|
||||
// reverse-proxy headers (Cloudflare, nginx). It always trusts these headers —
|
||||
// only register it when the service sits behind a trusted proxy. Use
|
||||
// RealIPWith for an allowlist-gated variant.
|
||||
func RealIP(next http.Handler) http.Handler {
|
||||
return realIPHandler(nil, next)
|
||||
}
|
||||
|
||||
// RealIPWith returns a RealIP middleware that only honors the proxy headers
|
||||
// when the immediate peer (parsed from r.RemoteAddr) falls within one of the
|
||||
// trusted prefixes. Requests from outside the allowlist are passed through
|
||||
// untouched.
|
||||
func RealIPWith(trusted ...netip.Prefix) Middleware {
|
||||
if trusted == nil {
|
||||
trusted = []netip.Prefix{}
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return realIPHandler(trusted, next)
|
||||
}
|
||||
}
|
||||
|
||||
func realIPHandler(trusted []netip.Prefix, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if trusted != nil && !peerTrusted(r.RemoteAddr, trusted) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if ip := extractRealIP(r); ip != "" {
|
||||
r2 := *r
|
||||
r2.RemoteAddr = ip
|
||||
next.ServeHTTP(w, &r2)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func extractRealIP(r *http.Request) string {
|
||||
for _, h := range realIPHeaders {
|
||||
v := r.Header.Get(h)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
if i := strings.IndexByte(v, ','); i >= 0 {
|
||||
v = v[:i]
|
||||
}
|
||||
v = strings.TrimSpace(v)
|
||||
if net.ParseIP(v) != nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool {
|
||||
var peer netip.Addr
|
||||
if ap, err := netip.ParseAddrPort(remoteAddr); err == nil {
|
||||
peer = ap.Addr()
|
||||
} else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil {
|
||||
peer = a
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
for _, p := range trusted {
|
||||
if p.Contains(peer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
198
pkg/middleware/realip_test.go
Normal file
198
pkg/middleware/realip_test.go
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const defaultTestRemoteAddr = "192.0.2.1:1234"
|
||||
|
||||
func captureRemoteAddr(got *string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
*got = r.RemoteAddr
|
||||
})
|
||||
}
|
||||
|
||||
func TestRealIP(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no headers",
|
||||
want: defaultTestRemoteAddr,
|
||||
},
|
||||
{
|
||||
name: "CF-Connecting-IP",
|
||||
headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"},
|
||||
want: "203.0.113.5",
|
||||
},
|
||||
{
|
||||
name: "True-Client-IP",
|
||||
headers: map[string]string{"True-Client-IP": "203.0.113.6"},
|
||||
want: "203.0.113.6",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP",
|
||||
headers: map[string]string{"X-Real-IP": "203.0.113.7"},
|
||||
want: "203.0.113.7",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For single",
|
||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.8"},
|
||||
want: "203.0.113.8",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For list",
|
||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"},
|
||||
want: "203.0.113.9",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For with spaces",
|
||||
headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"},
|
||||
want: "203.0.113.10",
|
||||
},
|
||||
{
|
||||
name: "precedence CF over XFF",
|
||||
headers: map[string]string{
|
||||
"CF-Connecting-IP": "203.0.113.11",
|
||||
"X-Forwarded-For": "198.51.100.1",
|
||||
},
|
||||
want: "203.0.113.11",
|
||||
},
|
||||
{
|
||||
name: "invalid then valid",
|
||||
headers: map[string]string{
|
||||
"CF-Connecting-IP": "not-an-ip",
|
||||
"X-Real-IP": "203.0.113.12",
|
||||
},
|
||||
want: "203.0.113.12",
|
||||
},
|
||||
{
|
||||
name: "IPv6",
|
||||
headers: map[string]string{"X-Real-IP": "2001:db8::1"},
|
||||
want: "2001:db8::1",
|
||||
},
|
||||
{
|
||||
name: "all invalid falls through",
|
||||
headers: map[string]string{"CF-Connecting-IP": "garbage"},
|
||||
want: defaultTestRemoteAddr,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got string
|
||||
h := RealIP(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
for k, v := range tc.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPDoesNotMutateCallerRequest(t *testing.T) {
|
||||
var seen string
|
||||
h := RealIP(captureRemoteAddr(&seen))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Real-IP", "203.0.113.13")
|
||||
original := req.RemoteAddr
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if seen != "203.0.113.13" {
|
||||
t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13")
|
||||
}
|
||||
if req.RemoteAddr != original {
|
||||
t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPWithTrustedPeer(t *testing.T) {
|
||||
var got string
|
||||
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.1.2.3:55555"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.20")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "203.0.113.20" {
|
||||
t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPWithUntrustedPeer(t *testing.T) {
|
||||
var got string
|
||||
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "203.0.113.99:55555"
|
||||
req.Header.Set("X-Real-IP", "198.51.100.1")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "203.0.113.99:55555" {
|
||||
t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPWithMultiplePrefixesIPv6(t *testing.T) {
|
||||
var got string
|
||||
h := RealIPWith(
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
)(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "[2001:db8::abcd]:55555"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.30")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "203.0.113.30" {
|
||||
t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPWithZeroPrefixesNoOp(t *testing.T) {
|
||||
var got string
|
||||
h := RealIPWith()(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("CF-Connecting-IP", "203.0.113.40")
|
||||
original := req.RemoteAddr
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != original {
|
||||
t.Errorf("RealIPWith() with no prefixes should be a no-op: r.RemoteAddr = %q, want %q", got, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealIPWithUnparseableRemoteAddr(t *testing.T) {
|
||||
var got string
|
||||
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "not-an-addr"
|
||||
req.Header.Set("CF-Connecting-IP", "203.0.113.50")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "not-an-addr" {
|
||||
t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -166,4 +166,3 @@ func TestConflictingPatternsPanic(t *testing.T) {
|
|||
}()
|
||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import "testing"
|
|||
|
||||
func TestSplitPattern(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
method, host, path string
|
||||
in string
|
||||
method, host, path string
|
||||
}{
|
||||
{"/foo", "", "", "/foo"},
|
||||
{"GET /foo", "GET", "", "/foo"},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue