Merge branch 'get-real-ip'
This commit is contained in:
commit
d983cca25e
7 changed files with 298 additions and 5 deletions
10
README.md
10
README.md
|
|
@ -84,15 +84,23 @@ Middleware values are plain `func(http.Handler) http.Handler`, so any stdlib-com
|
||||||
|
|
||||||
The `pkg/middleware` package ships:
|
The `pkg/middleware` package ships:
|
||||||
|
|
||||||
- **`Logger`** — emits a structured `http.request` record (method, path, status, duration) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger.
|
- **`Logger`** — emits a structured `http.request` record (method, path, status, duration, client) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. The `client` field is `r.RemoteAddr`, so pairing with `RealIP` makes it the resolved client IP.
|
||||||
- **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default.
|
- **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default.
|
||||||
- **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response.
|
- **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response.
|
||||||
|
- **`RealIP`** — replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order). Always trusts these headers — only register it when the service sits behind a trusted proxy.
|
||||||
|
- **`RealIPWith(trusted ...netip.Prefix)`** — same, but only honors the headers when the immediate peer's IP falls within one of the trusted prefixes. Requests from outside the allowlist pass through untouched.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
custom := splinter.New(splinter.WithStream(...))
|
custom := splinter.New(splinter.WithStream(...))
|
||||||
mux.Use(middleware.Recoverer, middleware.LoggerWith(custom))
|
mux.Use(middleware.Recoverer, middleware.LoggerWith(custom))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
When using `RealIP` together with `Logger`, register `RealIP` first so the logged `client` field is the resolved client IP rather than the proxy's peer address:
|
||||||
|
|
||||||
|
```go
|
||||||
|
mux.Use(middleware.RealIP, middleware.Logger)
|
||||||
|
```
|
||||||
|
|
||||||
## Path parameters
|
## Path parameters
|
||||||
|
|
||||||
lightmux is a thin wrapper, so path parameters work the stdlib way:
|
lightmux is a thin wrapper, so path parameters work the stdlib way:
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
|
||||||
"path", r.URL.Path,
|
"path", r.URL.Path,
|
||||||
"status", rec.status,
|
"status", rec.status,
|
||||||
"duration", time.Since(start),
|
"duration", time.Since(start),
|
||||||
|
"client", r.RemoteAddr,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -61,6 +62,7 @@ func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
|
||||||
"path", r.URL.Path,
|
"path", r.URL.Path,
|
||||||
"status", rec.status,
|
"status", rec.status,
|
||||||
"duration", time.Since(start),
|
"duration", time.Since(start),
|
||||||
|
"client", r.RemoteAddr,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ func TestLogger(t *testing.T) {
|
||||||
t.Errorf("status code = %d, want 418", rr.Code)
|
t.Errorf("status code = %d, want 418", rr.Code)
|
||||||
}
|
}
|
||||||
out := buf.String()
|
out := buf.String()
|
||||||
for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`} {
|
for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} {
|
||||||
if !strings.Contains(out, want) {
|
if !strings.Contains(out, want) {
|
||||||
t.Errorf("log output missing %s\nfull output: %s", want, out)
|
t.Errorf("log output missing %s\nfull output: %s", want, out)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
86
pkg/middleware/realip.go
Normal file
86
pkg/middleware/realip.go
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var realIPHeaders = []string{
|
||||||
|
"CF-Connecting-IP",
|
||||||
|
"True-Client-IP",
|
||||||
|
"X-Real-IP",
|
||||||
|
"X-Forwarded-For",
|
||||||
|
}
|
||||||
|
|
||||||
|
// RealIP rewrites r.RemoteAddr with the originating client IP found in common
|
||||||
|
// reverse-proxy headers (Cloudflare, nginx). It always trusts these headers —
|
||||||
|
// only register it when the service sits behind a trusted proxy. Use
|
||||||
|
// RealIPWith for an allowlist-gated variant.
|
||||||
|
func RealIP(next http.Handler) http.Handler {
|
||||||
|
return realIPHandler(nil, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RealIPWith returns a RealIP middleware that only honors the proxy headers
|
||||||
|
// when the immediate peer (parsed from r.RemoteAddr) falls within one of the
|
||||||
|
// trusted prefixes. Requests from outside the allowlist are passed through
|
||||||
|
// untouched.
|
||||||
|
func RealIPWith(trusted ...netip.Prefix) Middleware {
|
||||||
|
if trusted == nil {
|
||||||
|
trusted = []netip.Prefix{}
|
||||||
|
}
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return realIPHandler(trusted, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func realIPHandler(trusted []netip.Prefix, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if trusted != nil && !peerTrusted(r.RemoteAddr, trusted) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ip := extractRealIP(r); ip != "" {
|
||||||
|
r2 := *r
|
||||||
|
r2.RemoteAddr = ip
|
||||||
|
next.ServeHTTP(w, &r2)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractRealIP(r *http.Request) string {
|
||||||
|
for _, h := range realIPHeaders {
|
||||||
|
v := r.Header.Get(h)
|
||||||
|
if v == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if i := strings.IndexByte(v, ','); i >= 0 {
|
||||||
|
v = v[:i]
|
||||||
|
}
|
||||||
|
v = strings.TrimSpace(v)
|
||||||
|
if net.ParseIP(v) != nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool {
|
||||||
|
var peer netip.Addr
|
||||||
|
if ap, err := netip.ParseAddrPort(remoteAddr); err == nil {
|
||||||
|
peer = ap.Addr()
|
||||||
|
} else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil {
|
||||||
|
peer = a
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, p := range trusted {
|
||||||
|
if p.Contains(peer) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
198
pkg/middleware/realip_test.go
Normal file
198
pkg/middleware/realip_test.go
Normal file
|
|
@ -0,0 +1,198 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTestRemoteAddr = "192.0.2.1:1234"
|
||||||
|
|
||||||
|
func captureRemoteAddr(got *string) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
*got = r.RemoteAddr
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealIP(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
headers map[string]string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no headers",
|
||||||
|
want: defaultTestRemoteAddr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CF-Connecting-IP",
|
||||||
|
headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"},
|
||||||
|
want: "203.0.113.5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "True-Client-IP",
|
||||||
|
headers: map[string]string{"True-Client-IP": "203.0.113.6"},
|
||||||
|
want: "203.0.113.6",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Real-IP",
|
||||||
|
headers: map[string]string{"X-Real-IP": "203.0.113.7"},
|
||||||
|
want: "203.0.113.7",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For single",
|
||||||
|
headers: map[string]string{"X-Forwarded-For": "203.0.113.8"},
|
||||||
|
want: "203.0.113.8",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For list",
|
||||||
|
headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"},
|
||||||
|
want: "203.0.113.9",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For with spaces",
|
||||||
|
headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"},
|
||||||
|
want: "203.0.113.10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "precedence CF over XFF",
|
||||||
|
headers: map[string]string{
|
||||||
|
"CF-Connecting-IP": "203.0.113.11",
|
||||||
|
"X-Forwarded-For": "198.51.100.1",
|
||||||
|
},
|
||||||
|
want: "203.0.113.11",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid then valid",
|
||||||
|
headers: map[string]string{
|
||||||
|
"CF-Connecting-IP": "not-an-ip",
|
||||||
|
"X-Real-IP": "203.0.113.12",
|
||||||
|
},
|
||||||
|
want: "203.0.113.12",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6",
|
||||||
|
headers: map[string]string{"X-Real-IP": "2001:db8::1"},
|
||||||
|
want: "2001:db8::1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all invalid falls through",
|
||||||
|
headers: map[string]string{"CF-Connecting-IP": "garbage"},
|
||||||
|
want: defaultTestRemoteAddr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var got string
|
||||||
|
h := RealIP(captureRemoteAddr(&got))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
for k, v := range tc.headers {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealIPDoesNotMutateCallerRequest(t *testing.T) {
|
||||||
|
var seen string
|
||||||
|
h := RealIP(captureRemoteAddr(&seen))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Real-IP", "203.0.113.13")
|
||||||
|
original := req.RemoteAddr
|
||||||
|
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
if seen != "203.0.113.13" {
|
||||||
|
t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13")
|
||||||
|
}
|
||||||
|
if req.RemoteAddr != original {
|
||||||
|
t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealIPWithTrustedPeer(t *testing.T) {
|
||||||
|
var got string
|
||||||
|
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "10.1.2.3:55555"
|
||||||
|
req.Header.Set("X-Real-IP", "203.0.113.20")
|
||||||
|
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
if got != "203.0.113.20" {
|
||||||
|
t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealIPWithUntrustedPeer(t *testing.T) {
|
||||||
|
var got string
|
||||||
|
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "203.0.113.99:55555"
|
||||||
|
req.Header.Set("X-Real-IP", "198.51.100.1")
|
||||||
|
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
if got != "203.0.113.99:55555" {
|
||||||
|
t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealIPWithMultiplePrefixesIPv6(t *testing.T) {
|
||||||
|
var got string
|
||||||
|
h := RealIPWith(
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
netip.MustParsePrefix("2001:db8::/32"),
|
||||||
|
)(captureRemoteAddr(&got))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "[2001:db8::abcd]:55555"
|
||||||
|
req.Header.Set("X-Real-IP", "203.0.113.30")
|
||||||
|
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
if got != "203.0.113.30" {
|
||||||
|
t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealIPWithZeroPrefixesNoOp(t *testing.T) {
|
||||||
|
var got string
|
||||||
|
h := RealIPWith()(captureRemoteAddr(&got))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("CF-Connecting-IP", "203.0.113.40")
|
||||||
|
original := req.RemoteAddr
|
||||||
|
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
if got != original {
|
||||||
|
t.Errorf("RealIPWith() with no prefixes should be a no-op: r.RemoteAddr = %q, want %q", got, original)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealIPWithUnparseableRemoteAddr(t *testing.T) {
|
||||||
|
var got string
|
||||||
|
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "not-an-addr"
|
||||||
|
req.Header.Set("CF-Connecting-IP", "203.0.113.50")
|
||||||
|
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
if got != "not-an-addr" {
|
||||||
|
t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -166,4 +166,3 @@ func TestConflictingPatternsPanic(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue