From b26ef7439ebfc5486dc794305d86f91dfc8ba2f8 Mon Sep 17 00:00:00 2001 From: juancwu Date: Sun, 26 Apr 2026 14:03:04 +0000 Subject: [PATCH] add realip, requestlog, recoverer middlewares MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Initial implementation of lightmux-contrib, a sibling module to lightmux that hosts opinionated middlewares with one sub-package per middleware: - realip: resolves the originating client IP from CF-Connecting-IP, True-Client-IP, X-Real-IP, or X-Forwarded-For. Optional peer-CIDR allowlist via netip.Prefix. - requestlog: emits a structured http.request record (method, path, status, duration, client) per request via splinter. - recoverer: catches panics, wraps with errx under op "recoverer", logs with stack, and writes a 500 response. Each package exposes a single New(...) constructor returning func(http.Handler) http.Handler. The contrib module intentionally does not import lightmux — middlewares interoperate via the standard stdlib middleware shape. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 49 +++++++++ Taskfile.yml | 43 ++++++++ go.mod | 8 ++ go.sum | 4 + realip/realip.go | 78 +++++++++++++++ realip/realip_test.go | 183 ++++++++++++++++++++++++++++++++++ recoverer/recoverer.go | 36 +++++++ recoverer/recoverer_test.go | 70 +++++++++++++ requestlog/requestlog.go | 61 ++++++++++++ requestlog/requestlog_test.go | 92 +++++++++++++++++ 10 files changed, 624 insertions(+) create mode 100644 Taskfile.yml create mode 100644 go.mod create mode 100644 go.sum create mode 100644 realip/realip.go create mode 100644 realip/realip_test.go create mode 100644 recoverer/recoverer.go create mode 100644 recoverer/recoverer_test.go create mode 100644 requestlog/requestlog.go create mode 100644 requestlog/requestlog_test.go diff --git a/README.md b/README.md index 71c4fbc..9eee35d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,51 @@ # lightmux-contrib +Opinionated middleware collection for [lightmux](https://git.juancwu.dev/juancwu/lightmux). Each middleware lives in its own sub-package so consumers only pull in the dependencies they actually use. + +## Installation + +```sh +go get git.juancwu.dev/juancwu/lightmux-contrib +``` + +## Packages + +### `realip` + +Replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order). + +```go +import "git.juancwu.dev/juancwu/lightmux-contrib/realip" + +mux.Use(realip.New()) // always trust headers +mux.Use(realip.New(netip.MustParsePrefix("10.0.0.0/8"))) // gated by peer CIDR +``` + +With no arguments, `realip.New()` always honors the proxy headers — only register it when the service sits behind a trusted proxy. With one or more `netip.Prefix` arguments, the headers are honored only when the immediate peer's IP falls within one of them. + +### `requestlog` + +Emits a structured `http.request` record (method, path, status, duration, client) per request via [splinter](https://git.juancwu.dev/juancwu/splinter). + +```go +import "git.juancwu.dev/juancwu/lightmux-contrib/requestlog" + +mux.Use(requestlog.New(nil)) // splinter.Default() resolved at request time +mux.Use(requestlog.New(custom)) // custom *splinter.Logger +``` + +When pairing with `realip`, register `realip` first so the `client` field is the resolved client IP rather than the proxy peer: + +```go +mux.Use(realip.New(), requestlog.New(nil)) +``` + +### `recoverer` + +Catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `recoverer`, logs it with stack via the standard `log` package, and writes a 500 response. + +```go +import "git.juancwu.dev/juancwu/lightmux-contrib/recoverer" + +mux.Use(recoverer.New()) +``` diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..89710eb --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,43 @@ +version: '3' + +tasks: + default: + desc: Run vet and tests + cmds: + - task: check + + install:tools: + desc: Install development tools (tparse for prettier test output) + cmds: + - go install github.com/mfridman/tparse@latest + + test: + desc: Run tests with prettier output via tparse + cmds: + - set -o pipefail && go test ./... -json -cover | tparse -all + + test:race: + desc: Run tests under the race detector + cmds: + - set -o pipefail && go test ./... -race -json | tparse -all + + vet: + desc: Run go vet + cmds: + - go vet ./... + + fmt: + desc: Format Go source files + cmds: + - gofmt -w . + + tidy: + desc: Tidy go.mod + cmds: + - go mod tidy + + check: + desc: Run vet and tests + cmds: + - task: vet + - task: test diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..60781c4 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module git.juancwu.dev/juancwu/lightmux-contrib + +go 1.26.2 + +require ( + git.juancwu.dev/juancwu/errx v0.1.0 + git.juancwu.dev/juancwu/splinter v0.1.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..207ed5b --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w= +git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk= +git.juancwu.dev/juancwu/splinter v0.1.0 h1:ZGvvzyi24hZw/yFAwpUsHtj+q+fh9I2KIGmOAILWD5Q= +git.juancwu.dev/juancwu/splinter v0.1.0/go.mod h1:dAYsRQfS6tqWynEGz8xMCtIJUN7+KIp3jLE7kgO3yKE= diff --git a/realip/realip.go b/realip/realip.go new file mode 100644 index 0000000..72ceaed --- /dev/null +++ b/realip/realip.go @@ -0,0 +1,78 @@ +// Package realip resolves the originating client IP from common reverse-proxy +// and CDN headers (Cloudflare, nginx) and replaces r.RemoteAddr so downstream +// handlers and middlewares see the real client. +package realip + +import ( + "net" + "net/http" + "net/netip" + "strings" +) + +var headers = []string{ + "CF-Connecting-IP", + "True-Client-IP", + "X-Real-IP", + "X-Forwarded-For", +} + +// New returns a real-IP middleware. +// +// With no trusted prefixes, it always honors the proxy headers — only register +// it when the service sits behind a trusted proxy. +// +// With one or more prefixes, the headers are honored only when the immediate +// peer (parsed from r.RemoteAddr) falls within one of them; requests from +// outside the allowlist pass through untouched. +func New(trusted ...netip.Prefix) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(trusted) > 0 && !peerTrusted(r.RemoteAddr, trusted) { + next.ServeHTTP(w, r) + return + } + if ip := extract(r); ip != "" { + r2 := *r + r2.RemoteAddr = ip + next.ServeHTTP(w, &r2) + return + } + next.ServeHTTP(w, r) + }) + } +} + +func extract(r *http.Request) string { + for _, h := range headers { + v := r.Header.Get(h) + if v == "" { + continue + } + if i := strings.IndexByte(v, ','); i >= 0 { + v = v[:i] + } + v = strings.TrimSpace(v) + if net.ParseIP(v) != nil { + return v + } + } + return "" +} + +func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool { + var peer netip.Addr + if ap, err := netip.ParseAddrPort(remoteAddr); err == nil { + peer = ap.Addr() + } else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil { + peer = a + } else { + return false + } + for _, p := range trusted { + if p.Contains(peer) { + return true + } + } + return false +} diff --git a/realip/realip_test.go b/realip/realip_test.go new file mode 100644 index 0000000..34709a3 --- /dev/null +++ b/realip/realip_test.go @@ -0,0 +1,183 @@ +package realip + +import ( + "net/http" + "net/http/httptest" + "net/netip" + "testing" +) + +const defaultTestRemoteAddr = "192.0.2.1:1234" + +func captureRemoteAddr(got *string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + *got = r.RemoteAddr + }) +} + +func TestNew(t *testing.T) { + cases := []struct { + name string + headers map[string]string + want string + }{ + { + name: "no headers", + want: defaultTestRemoteAddr, + }, + { + name: "CF-Connecting-IP", + headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"}, + want: "203.0.113.5", + }, + { + name: "True-Client-IP", + headers: map[string]string{"True-Client-IP": "203.0.113.6"}, + want: "203.0.113.6", + }, + { + name: "X-Real-IP", + headers: map[string]string{"X-Real-IP": "203.0.113.7"}, + want: "203.0.113.7", + }, + { + name: "X-Forwarded-For single", + headers: map[string]string{"X-Forwarded-For": "203.0.113.8"}, + want: "203.0.113.8", + }, + { + name: "X-Forwarded-For list", + headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"}, + want: "203.0.113.9", + }, + { + name: "X-Forwarded-For with spaces", + headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"}, + want: "203.0.113.10", + }, + { + name: "precedence CF over XFF", + headers: map[string]string{ + "CF-Connecting-IP": "203.0.113.11", + "X-Forwarded-For": "198.51.100.1", + }, + want: "203.0.113.11", + }, + { + name: "invalid then valid", + headers: map[string]string{ + "CF-Connecting-IP": "not-an-ip", + "X-Real-IP": "203.0.113.12", + }, + want: "203.0.113.12", + }, + { + name: "IPv6", + headers: map[string]string{"X-Real-IP": "2001:db8::1"}, + want: "2001:db8::1", + }, + { + name: "all invalid falls through", + headers: map[string]string{"CF-Connecting-IP": "garbage"}, + want: defaultTestRemoteAddr, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got string + h := New()(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + for k, v := range tc.headers { + req.Header.Set(k, v) + } + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != tc.want { + t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want) + } + }) + } +} + +func TestNewDoesNotMutateCallerRequest(t *testing.T) { + var seen string + h := New()(captureRemoteAddr(&seen)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Real-IP", "203.0.113.13") + original := req.RemoteAddr + + h.ServeHTTP(httptest.NewRecorder(), req) + + if seen != "203.0.113.13" { + t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13") + } + if req.RemoteAddr != original { + t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original) + } +} + +func TestNewWithTrustedPeer(t *testing.T) { + var got string + h := New(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "10.1.2.3:55555" + req.Header.Set("X-Real-IP", "203.0.113.20") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "203.0.113.20" { + t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20") + } +} + +func TestNewWithUntrustedPeer(t *testing.T) { + var got string + h := New(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.99:55555" + req.Header.Set("X-Real-IP", "198.51.100.1") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "203.0.113.99:55555" { + t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555") + } +} + +func TestNewWithMultiplePrefixesIPv6(t *testing.T) { + var got string + h := New( + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("2001:db8::/32"), + )(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "[2001:db8::abcd]:55555" + req.Header.Set("X-Real-IP", "203.0.113.30") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "203.0.113.30" { + t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30") + } +} + +func TestNewWithUnparseableRemoteAddr(t *testing.T) { + var got string + h := New(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "not-an-addr" + req.Header.Set("CF-Connecting-IP", "203.0.113.50") + + h.ServeHTTP(httptest.NewRecorder(), req) + + if got != "not-an-addr" { + t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got) + } +} diff --git a/recoverer/recoverer.go b/recoverer/recoverer.go new file mode 100644 index 0000000..832e9ce --- /dev/null +++ b/recoverer/recoverer.go @@ -0,0 +1,36 @@ +// Package recoverer catches panics inside HTTP handlers, logs them with stack +// trace, and writes a 500 response. +package recoverer + +import ( + "log" + "net/http" + "runtime/debug" + + "git.juancwu.dev/juancwu/errx" +) + +const op = "recoverer" + +// New returns a panic-recovery middleware. +func New() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + rec := recover() + if rec == nil { + return + } + var err error + if e, ok := rec.(error); ok { + err = errx.Wrap(op, e) + } else { + err = errx.Newf(op, "panic: %v", rec) + } + log.Printf("%v\n%s", err, debug.Stack()) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + }() + next.ServeHTTP(w, r) + }) + } +} diff --git a/recoverer/recoverer_test.go b/recoverer/recoverer_test.go new file mode 100644 index 0000000..8e8c3d4 --- /dev/null +++ b/recoverer/recoverer_test.go @@ -0,0 +1,70 @@ +package recoverer + +import ( + "bytes" + "errors" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func captureLog(t *testing.T) *bytes.Buffer { + t.Helper() + var buf bytes.Buffer + orig := log.Default().Writer() + log.Default().SetOutput(&buf) + t.Cleanup(func() { log.Default().SetOutput(orig) }) + return &buf +} + +func TestNewCatchesStringPanic(t *testing.T) { + buf := captureLog(t) + + h := New()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("boom") + })) + + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", rr.Code) + } + out := buf.String() + for _, want := range []string{"recoverer", "panic: boom"} { + if !strings.Contains(out, want) { + t.Errorf("log missing %q\nfull: %s", want, out) + } + } +} + +func TestNewWrapsErrorPanic(t *testing.T) { + buf := captureLog(t) + + cause := errors.New("db down") + h := New()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic(cause) + })) + + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) + + out := buf.String() + if !strings.Contains(out, "recoverer: db down") { + t.Errorf("expected errx-wrapped breadcrumb, got: %s", out) + } +} + +func TestNewPassesThrough(t *testing.T) { + called := false + h := New()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + if !called || rr.Code != http.StatusOK { + t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code) + } +} diff --git a/requestlog/requestlog.go b/requestlog/requestlog.go new file mode 100644 index 0000000..5652d2a --- /dev/null +++ b/requestlog/requestlog.go @@ -0,0 +1,61 @@ +// Package requestlog emits a structured "http.request" record per request +// (method, path, status, duration, client) via splinter. +package requestlog + +import ( + "net/http" + "time" + + "git.juancwu.dev/juancwu/splinter" +) + +type statusRecorder struct { + http.ResponseWriter + status int + wrote bool +} + +func (s *statusRecorder) WriteHeader(code int) { + if !s.wrote { + s.status = code + s.wrote = true + } + s.ResponseWriter.WriteHeader(code) +} + +func (s *statusRecorder) Write(b []byte) (int, error) { + if !s.wrote { + s.status = http.StatusOK + s.wrote = true + } + return s.ResponseWriter.Write(b) +} + +// New returns a request-logging middleware. Pass nil to use splinter.Default() +// resolved at request time; otherwise records flow through the supplied logger. +func New(l *splinter.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rec, r) + if l == nil { + splinter.Info("http.request", + "method", r.Method, + "path", r.URL.Path, + "status", rec.status, + "duration", time.Since(start), + "client", r.RemoteAddr, + ) + return + } + l.Info("http.request", + "method", r.Method, + "path", r.URL.Path, + "status", rec.status, + "duration", time.Since(start), + "client", r.RemoteAddr, + ) + }) + } +} diff --git a/requestlog/requestlog_test.go b/requestlog/requestlog_test.go new file mode 100644 index 0000000..025aa71 --- /dev/null +++ b/requestlog/requestlog_test.go @@ -0,0 +1,92 @@ +package requestlog + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.juancwu.dev/juancwu/splinter" +) + +func captureSplinter(t *testing.T) *bytes.Buffer { + t.Helper() + var buf bytes.Buffer + logger := splinter.New(splinter.WithStream(splinter.NewConsoleStream( + splinter.ConsoleJSON, + splinter.LevelDebug, + splinter.ConsoleWriter(&buf), + ))) + prev := splinter.SetDefault(logger) + t.Cleanup(func() { splinter.SetDefault(prev) }) + return &buf +} + +func TestNew(t *testing.T) { + buf := captureSplinter(t) + + h := New(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + })) + + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil)) + + if rr.Code != http.StatusTeapot { + t.Errorf("status code = %d, want 418", rr.Code) + } + out := buf.String() + for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} { + if !strings.Contains(out, want) { + t.Errorf("log output missing %s\nfull output: %s", want, out) + } + } +} + +func TestNewDefaultStatusOK(t *testing.T) { + buf := captureSplinter(t) + + h := New(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + })) + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) + + if !strings.Contains(buf.String(), `"status":200`) { + t.Errorf("expected default status 200 in log, got %q", buf.String()) + } +} + +func TestNewWithCustomLogger(t *testing.T) { + defaultBuf := captureSplinter(t) + + var customBuf bytes.Buffer + custom := splinter.New(splinter.WithStream(splinter.NewConsoleStream( + splinter.ConsoleJSON, + splinter.LevelDebug, + splinter.ConsoleWriter(&customBuf), + ))) + + h := New(custom)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/x", nil)) + + if !strings.Contains(customBuf.String(), `"path":"/x"`) { + t.Errorf("custom logger did not receive record: %q", customBuf.String()) + } + if defaultBuf.Len() != 0 { + t.Errorf("default logger should not have been written to, got: %q", defaultBuf.String()) + } +} + +func TestNewNilFallsBackToDefault(t *testing.T) { + buf := captureSplinter(t) + + h := New(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/y", nil)) + + if !strings.Contains(buf.String(), `"path":"/y"`) { + t.Errorf("nil logger should fall back to splinter.Default(): %q", buf.String()) + } +}