extract middlewares to lightmux-contrib
Removes pkg/middleware. The Logger, Recoverer, and RealIP middlewares now live in the sibling lightmux-contrib module as the realip, requestlog, and recoverer packages, each exposing a single New(...) constructor. The Middleware type alias moves to pkg/router. The splinter dependency is dropped from go.mod; only errx remains. BREAKING CHANGE: consumers must replace pkg/middleware imports with the corresponding lightmux-contrib sub-packages. See README for the new usage. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
d983cca25e
commit
22277186ae
15 changed files with 44 additions and 612 deletions
31
README.md
31
README.md
|
|
@ -22,12 +22,10 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/lightmux"
|
"git.juancwu.dev/juancwu/lightmux"
|
||||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
mux := lightmux.New()
|
mux := lightmux.New()
|
||||||
mux.Use(middleware.Recoverer, middleware.Logger)
|
|
||||||
|
|
||||||
mux.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
mux.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintln(w, "hello from lightmux")
|
fmt.Fprintln(w, "hello from lightmux")
|
||||||
|
|
@ -71,7 +69,7 @@ There are two ways to attach middleware: chain-style with `Use`, or per-route as
|
||||||
|
|
||||||
```go
|
```go
|
||||||
mux := lightmux.New()
|
mux := lightmux.New()
|
||||||
mux.Use(middleware.Recoverer, middleware.Logger) // applies to every route registered after
|
mux.Use(authMiddleware) // applies to every route registered after
|
||||||
|
|
||||||
mux.Get("/admin", adminHandler, requireAdmin) // requireAdmin only on this route
|
mux.Get("/admin", adminHandler, requireAdmin) // requireAdmin only on this route
|
||||||
```
|
```
|
||||||
|
|
@ -80,25 +78,20 @@ Order is outer → inner: `Use`-order first, then per-route mws. The outermost m
|
||||||
|
|
||||||
Middleware values are plain `func(http.Handler) http.Handler`, so any stdlib-compatible middleware works without an adapter.
|
Middleware values are plain `func(http.Handler) http.Handler`, so any stdlib-compatible middleware works without an adapter.
|
||||||
|
|
||||||
## Built-in middleware
|
An opinionated set of middlewares (request logging, panic recovery, real-IP resolution) lives in a sibling module — see [lightmux-contrib](https://git.juancwu.dev/juancwu/lightmux-contrib):
|
||||||
|
|
||||||
The `pkg/middleware` package ships:
|
```sh
|
||||||
|
go get git.juancwu.dev/juancwu/lightmux-contrib
|
||||||
- **`Logger`** — emits a structured `http.request` record (method, path, status, duration, client) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. The `client` field is `r.RemoteAddr`, so pairing with `RealIP` makes it the resolved client IP.
|
|
||||||
- **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default.
|
|
||||||
- **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response.
|
|
||||||
- **`RealIP`** — replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order). Always trusts these headers — only register it when the service sits behind a trusted proxy.
|
|
||||||
- **`RealIPWith(trusted ...netip.Prefix)`** — same, but only honors the headers when the immediate peer's IP falls within one of the trusted prefixes. Requests from outside the allowlist pass through untouched.
|
|
||||||
|
|
||||||
```go
|
|
||||||
custom := splinter.New(splinter.WithStream(...))
|
|
||||||
mux.Use(middleware.Recoverer, middleware.LoggerWith(custom))
|
|
||||||
```
|
```
|
||||||
|
|
||||||
When using `RealIP` together with `Logger`, register `RealIP` first so the logged `client` field is the resolved client IP rather than the proxy's peer address:
|
|
||||||
|
|
||||||
```go
|
```go
|
||||||
mux.Use(middleware.RealIP, middleware.Logger)
|
import (
|
||||||
|
"git.juancwu.dev/juancwu/lightmux-contrib/realip"
|
||||||
|
"git.juancwu.dev/juancwu/lightmux-contrib/recoverer"
|
||||||
|
"git.juancwu.dev/juancwu/lightmux-contrib/requestlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
mux.Use(recoverer.New(), realip.New(), requestlog.New(nil))
|
||||||
```
|
```
|
||||||
|
|
||||||
## Path parameters
|
## Path parameters
|
||||||
|
|
@ -114,5 +107,5 @@ mux.Get("/items/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- `mux.Get("/", h)` registers `"GET /"`, which is a stdlib **subtree** pattern — it matches every unmatched path. Use `"/{$}"` to match only the literal root.
|
- `mux.Get("/", h)` registers `"GET /"`, which is a stdlib **subtree** pattern — it matches every unmatched path. Use `"/{$}"` to match only the literal root.
|
||||||
- Bad route registrations (invalid prefix, conflicting wildcards) panic at startup, matching stdlib `http.ServeMux` behavior. `Recoverer` handles panics that occur *inside* request handlers.
|
- Bad route registrations (invalid prefix, conflicting wildcards) panic at startup, matching stdlib `http.ServeMux` behavior. The `recoverer` middleware in [lightmux-contrib](https://git.juancwu.dev/juancwu/lightmux-contrib) handles panics that occur *inside* request handlers.
|
||||||
- Group prefixes apply to the path only — they never inject a host. Routes can still carry an explicit host via `Handle("GET host.com/path", h)`.
|
- Group prefixes apply to the path only — they never inject a host. Routes can still carry an explicit host via `Handle("GET host.com/path", h)`.
|
||||||
|
|
|
||||||
5
go.mod
5
go.mod
|
|
@ -2,7 +2,4 @@ module git.juancwu.dev/juancwu/lightmux
|
||||||
|
|
||||||
go 1.26.2
|
go 1.26.2
|
||||||
|
|
||||||
require (
|
require git.juancwu.dev/juancwu/errx v0.1.0
|
||||||
git.juancwu.dev/juancwu/errx v0.1.0
|
|
||||||
git.juancwu.dev/juancwu/splinter v0.1.0
|
|
||||||
)
|
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -1,4 +1,2 @@
|
||||||
git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w=
|
git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w=
|
||||||
git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk=
|
git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk=
|
||||||
git.juancwu.dev/juancwu/splinter v0.1.0 h1:ZGvvzyi24hZw/yFAwpUsHtj+q+fh9I2KIGmOAILWD5Q=
|
|
||||||
git.juancwu.dev/juancwu/splinter v0.1.0/go.mod h1:dAYsRQfS6tqWynEGz8xMCtIJUN7+KIp3jLE7kgO3yKE=
|
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,11 @@
|
||||||
// adding method-named convenience methods, groups, and per-route middleware.
|
// adding method-named convenience methods, groups, and per-route middleware.
|
||||||
package lightmux
|
package lightmux
|
||||||
|
|
||||||
import (
|
import "git.juancwu.dev/juancwu/lightmux/pkg/router"
|
||||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
|
||||||
"git.juancwu.dev/juancwu/lightmux/pkg/router"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Mux = router.Mux
|
Mux = router.Mux
|
||||||
Middleware = middleware.Middleware
|
Middleware = router.Middleware
|
||||||
)
|
)
|
||||||
|
|
||||||
func New() *Mux { return router.New() }
|
func New() *Mux { return router.New() }
|
||||||
|
|
|
||||||
|
|
@ -1,68 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/splinter"
|
|
||||||
)
|
|
||||||
|
|
||||||
type statusRecorder struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
status int
|
|
||||||
wrote bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *statusRecorder) WriteHeader(code int) {
|
|
||||||
if !s.wrote {
|
|
||||||
s.status = code
|
|
||||||
s.wrote = true
|
|
||||||
}
|
|
||||||
s.ResponseWriter.WriteHeader(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *statusRecorder) Write(b []byte) (int, error) {
|
|
||||||
if !s.wrote {
|
|
||||||
s.status = http.StatusOK
|
|
||||||
s.wrote = true
|
|
||||||
}
|
|
||||||
return s.ResponseWriter.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logger uses splinter.Default() resolved at request time.
|
|
||||||
func Logger(next http.Handler) http.Handler {
|
|
||||||
return loggerHandler(nil, next)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoggerWith returns a Logger middleware backed by the given splinter logger.
|
|
||||||
// Pass nil to fall back to splinter.Default() (equivalent to Logger).
|
|
||||||
func LoggerWith(l *splinter.Logger) Middleware {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return loggerHandler(l, next)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
start := time.Now()
|
|
||||||
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
|
||||||
next.ServeHTTP(rec, r)
|
|
||||||
if l == nil {
|
|
||||||
splinter.Info("http.request",
|
|
||||||
"method", r.Method,
|
|
||||||
"path", r.URL.Path,
|
|
||||||
"status", rec.status,
|
|
||||||
"duration", time.Since(start),
|
|
||||||
"client", r.RemoteAddr,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.Info("http.request",
|
|
||||||
"method", r.Method,
|
|
||||||
"path", r.URL.Path,
|
|
||||||
"status", rec.status,
|
|
||||||
"duration", time.Since(start),
|
|
||||||
"client", r.RemoteAddr,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,92 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/splinter"
|
|
||||||
)
|
|
||||||
|
|
||||||
func captureSplinter(t *testing.T) *bytes.Buffer {
|
|
||||||
t.Helper()
|
|
||||||
var buf bytes.Buffer
|
|
||||||
logger := splinter.New(splinter.WithStream(splinter.NewConsoleStream(
|
|
||||||
splinter.ConsoleJSON,
|
|
||||||
splinter.LevelDebug,
|
|
||||||
splinter.ConsoleWriter(&buf),
|
|
||||||
)))
|
|
||||||
prev := splinter.SetDefault(logger)
|
|
||||||
t.Cleanup(func() { splinter.SetDefault(prev) })
|
|
||||||
return &buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger(t *testing.T) {
|
|
||||||
buf := captureSplinter(t)
|
|
||||||
|
|
||||||
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
}))
|
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil))
|
|
||||||
|
|
||||||
if rr.Code != http.StatusTeapot {
|
|
||||||
t.Errorf("status code = %d, want 418", rr.Code)
|
|
||||||
}
|
|
||||||
out := buf.String()
|
|
||||||
for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} {
|
|
||||||
if !strings.Contains(out, want) {
|
|
||||||
t.Errorf("log output missing %s\nfull output: %s", want, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerDefaultStatusOK(t *testing.T) {
|
|
||||||
buf := captureSplinter(t)
|
|
||||||
|
|
||||||
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte("hi"))
|
|
||||||
}))
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), `"status":200`) {
|
|
||||||
t.Errorf("expected default status 200 in log, got %q", buf.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerWith(t *testing.T) {
|
|
||||||
defaultBuf := captureSplinter(t)
|
|
||||||
|
|
||||||
var customBuf bytes.Buffer
|
|
||||||
custom := splinter.New(splinter.WithStream(splinter.NewConsoleStream(
|
|
||||||
splinter.ConsoleJSON,
|
|
||||||
splinter.LevelDebug,
|
|
||||||
splinter.ConsoleWriter(&customBuf),
|
|
||||||
)))
|
|
||||||
|
|
||||||
h := LoggerWith(custom)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusAccepted)
|
|
||||||
}))
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/x", nil))
|
|
||||||
|
|
||||||
if !strings.Contains(customBuf.String(), `"path":"/x"`) {
|
|
||||||
t.Errorf("custom logger did not receive record: %q", customBuf.String())
|
|
||||||
}
|
|
||||||
if defaultBuf.Len() != 0 {
|
|
||||||
t.Errorf("default logger should not have been written to, got: %q", defaultBuf.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerWithNilFallsBackToDefault(t *testing.T) {
|
|
||||||
buf := captureSplinter(t)
|
|
||||||
|
|
||||||
h := LoggerWith(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/y", nil))
|
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), `"path":"/y"`) {
|
|
||||||
t.Errorf("nil logger should fall back to splinter.Default(): %q", buf.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
type Middleware = func(http.Handler) http.Handler
|
|
||||||
|
|
@ -1,86 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/netip"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
var realIPHeaders = []string{
|
|
||||||
"CF-Connecting-IP",
|
|
||||||
"True-Client-IP",
|
|
||||||
"X-Real-IP",
|
|
||||||
"X-Forwarded-For",
|
|
||||||
}
|
|
||||||
|
|
||||||
// RealIP rewrites r.RemoteAddr with the originating client IP found in common
|
|
||||||
// reverse-proxy headers (Cloudflare, nginx). It always trusts these headers —
|
|
||||||
// only register it when the service sits behind a trusted proxy. Use
|
|
||||||
// RealIPWith for an allowlist-gated variant.
|
|
||||||
func RealIP(next http.Handler) http.Handler {
|
|
||||||
return realIPHandler(nil, next)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RealIPWith returns a RealIP middleware that only honors the proxy headers
|
|
||||||
// when the immediate peer (parsed from r.RemoteAddr) falls within one of the
|
|
||||||
// trusted prefixes. Requests from outside the allowlist are passed through
|
|
||||||
// untouched.
|
|
||||||
func RealIPWith(trusted ...netip.Prefix) Middleware {
|
|
||||||
if trusted == nil {
|
|
||||||
trusted = []netip.Prefix{}
|
|
||||||
}
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return realIPHandler(trusted, next)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func realIPHandler(trusted []netip.Prefix, next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if trusted != nil && !peerTrusted(r.RemoteAddr, trusted) {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if ip := extractRealIP(r); ip != "" {
|
|
||||||
r2 := *r
|
|
||||||
r2.RemoteAddr = ip
|
|
||||||
next.ServeHTTP(w, &r2)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractRealIP(r *http.Request) string {
|
|
||||||
for _, h := range realIPHeaders {
|
|
||||||
v := r.Header.Get(h)
|
|
||||||
if v == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if i := strings.IndexByte(v, ','); i >= 0 {
|
|
||||||
v = v[:i]
|
|
||||||
}
|
|
||||||
v = strings.TrimSpace(v)
|
|
||||||
if net.ParseIP(v) != nil {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool {
|
|
||||||
var peer netip.Addr
|
|
||||||
if ap, err := netip.ParseAddrPort(remoteAddr); err == nil {
|
|
||||||
peer = ap.Addr()
|
|
||||||
} else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil {
|
|
||||||
peer = a
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, p := range trusted {
|
|
||||||
if p.Contains(peer) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
@ -1,198 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
const defaultTestRemoteAddr = "192.0.2.1:1234"
|
|
||||||
|
|
||||||
func captureRemoteAddr(got *string) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
*got = r.RemoteAddr
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRealIP(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
headers map[string]string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no headers",
|
|
||||||
want: defaultTestRemoteAddr,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "CF-Connecting-IP",
|
|
||||||
headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"},
|
|
||||||
want: "203.0.113.5",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "True-Client-IP",
|
|
||||||
headers: map[string]string{"True-Client-IP": "203.0.113.6"},
|
|
||||||
want: "203.0.113.6",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "X-Real-IP",
|
|
||||||
headers: map[string]string{"X-Real-IP": "203.0.113.7"},
|
|
||||||
want: "203.0.113.7",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "X-Forwarded-For single",
|
|
||||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.8"},
|
|
||||||
want: "203.0.113.8",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "X-Forwarded-For list",
|
|
||||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"},
|
|
||||||
want: "203.0.113.9",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "X-Forwarded-For with spaces",
|
|
||||||
headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"},
|
|
||||||
want: "203.0.113.10",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "precedence CF over XFF",
|
|
||||||
headers: map[string]string{
|
|
||||||
"CF-Connecting-IP": "203.0.113.11",
|
|
||||||
"X-Forwarded-For": "198.51.100.1",
|
|
||||||
},
|
|
||||||
want: "203.0.113.11",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid then valid",
|
|
||||||
headers: map[string]string{
|
|
||||||
"CF-Connecting-IP": "not-an-ip",
|
|
||||||
"X-Real-IP": "203.0.113.12",
|
|
||||||
},
|
|
||||||
want: "203.0.113.12",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6",
|
|
||||||
headers: map[string]string{"X-Real-IP": "2001:db8::1"},
|
|
||||||
want: "2001:db8::1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "all invalid falls through",
|
|
||||||
headers: map[string]string{"CF-Connecting-IP": "garbage"},
|
|
||||||
want: defaultTestRemoteAddr,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
var got string
|
|
||||||
h := RealIP(captureRemoteAddr(&got))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
for k, v := range tc.headers {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
||||||
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRealIPDoesNotMutateCallerRequest(t *testing.T) {
|
|
||||||
var seen string
|
|
||||||
h := RealIP(captureRemoteAddr(&seen))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.Header.Set("X-Real-IP", "203.0.113.13")
|
|
||||||
original := req.RemoteAddr
|
|
||||||
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
||||||
|
|
||||||
if seen != "203.0.113.13" {
|
|
||||||
t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13")
|
|
||||||
}
|
|
||||||
if req.RemoteAddr != original {
|
|
||||||
t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRealIPWithTrustedPeer(t *testing.T) {
|
|
||||||
var got string
|
|
||||||
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "10.1.2.3:55555"
|
|
||||||
req.Header.Set("X-Real-IP", "203.0.113.20")
|
|
||||||
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
||||||
|
|
||||||
if got != "203.0.113.20" {
|
|
||||||
t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRealIPWithUntrustedPeer(t *testing.T) {
|
|
||||||
var got string
|
|
||||||
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "203.0.113.99:55555"
|
|
||||||
req.Header.Set("X-Real-IP", "198.51.100.1")
|
|
||||||
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
||||||
|
|
||||||
if got != "203.0.113.99:55555" {
|
|
||||||
t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRealIPWithMultiplePrefixesIPv6(t *testing.T) {
|
|
||||||
var got string
|
|
||||||
h := RealIPWith(
|
|
||||||
netip.MustParsePrefix("10.0.0.0/8"),
|
|
||||||
netip.MustParsePrefix("2001:db8::/32"),
|
|
||||||
)(captureRemoteAddr(&got))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "[2001:db8::abcd]:55555"
|
|
||||||
req.Header.Set("X-Real-IP", "203.0.113.30")
|
|
||||||
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
||||||
|
|
||||||
if got != "203.0.113.30" {
|
|
||||||
t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRealIPWithZeroPrefixesNoOp(t *testing.T) {
|
|
||||||
var got string
|
|
||||||
h := RealIPWith()(captureRemoteAddr(&got))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.Header.Set("CF-Connecting-IP", "203.0.113.40")
|
|
||||||
original := req.RemoteAddr
|
|
||||||
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
||||||
|
|
||||||
if got != original {
|
|
||||||
t.Errorf("RealIPWith() with no prefixes should be a no-op: r.RemoteAddr = %q, want %q", got, original)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRealIPWithUnparseableRemoteAddr(t *testing.T) {
|
|
||||||
var got string
|
|
||||||
h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.RemoteAddr = "not-an-addr"
|
|
||||||
req.Header.Set("CF-Connecting-IP", "203.0.113.50")
|
|
||||||
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
||||||
|
|
||||||
if got != "not-an-addr" {
|
|
||||||
t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"runtime/debug"
|
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/errx"
|
|
||||||
)
|
|
||||||
|
|
||||||
const recovererOp = "middleware.Recoverer"
|
|
||||||
|
|
||||||
func Recoverer(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer func() {
|
|
||||||
rec := recover()
|
|
||||||
if rec == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
if e, ok := rec.(error); ok {
|
|
||||||
err = errx.Wrap(recovererOp, e)
|
|
||||||
} else {
|
|
||||||
err = errx.Newf(recovererOp, "panic: %v", rec)
|
|
||||||
}
|
|
||||||
log.Printf("%v\n%s", err, debug.Stack())
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
}()
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,70 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func captureLog(t *testing.T) *bytes.Buffer {
|
|
||||||
t.Helper()
|
|
||||||
var buf bytes.Buffer
|
|
||||||
orig := log.Default().Writer()
|
|
||||||
log.Default().SetOutput(&buf)
|
|
||||||
t.Cleanup(func() { log.Default().SetOutput(orig) })
|
|
||||||
return &buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecovererCatchesStringPanic(t *testing.T) {
|
|
||||||
buf := captureLog(t)
|
|
||||||
|
|
||||||
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
panic("boom")
|
|
||||||
}))
|
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
||||||
|
|
||||||
if rr.Code != http.StatusInternalServerError {
|
|
||||||
t.Errorf("status = %d, want 500", rr.Code)
|
|
||||||
}
|
|
||||||
out := buf.String()
|
|
||||||
for _, want := range []string{"middleware.Recoverer", "panic: boom"} {
|
|
||||||
if !strings.Contains(out, want) {
|
|
||||||
t.Errorf("log missing %q\nfull: %s", want, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecovererWrapsErrorPanic(t *testing.T) {
|
|
||||||
buf := captureLog(t)
|
|
||||||
|
|
||||||
cause := errors.New("db down")
|
|
||||||
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
panic(cause)
|
|
||||||
}))
|
|
||||||
|
|
||||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
|
||||||
|
|
||||||
out := buf.String()
|
|
||||||
if !strings.Contains(out, "middleware.Recoverer: db down") {
|
|
||||||
t.Errorf("expected errx-wrapped breadcrumb, got: %s", out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecovererPassesThrough(t *testing.T) {
|
|
||||||
called := false
|
|
||||||
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
called = true
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
|
||||||
if !called || rr.Code != http.StatusOK {
|
|
||||||
t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,15 +1,11 @@
|
||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import "net/http"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
|
||||||
)
|
|
||||||
|
|
||||||
// chain wraps h with groupMws followed by routeMws so that groupMws[0] is the
|
// chain wraps h with groupMws followed by routeMws so that groupMws[0] is the
|
||||||
// outermost layer (runs first on request, last on response).
|
// outermost layer (runs first on request, last on response).
|
||||||
func chain(h http.Handler, groupMws, routeMws []middleware.Middleware) http.Handler {
|
func chain(h http.Handler, groupMws, routeMws []Middleware) http.Handler {
|
||||||
all := make([]middleware.Middleware, 0, len(groupMws)+len(routeMws))
|
all := make([]Middleware, 0, len(groupMws)+len(routeMws))
|
||||||
all = append(all, groupMws...)
|
all = append(all, groupMws...)
|
||||||
all = append(all, routeMws...)
|
all = append(all, routeMws...)
|
||||||
for i := len(all) - 1; i >= 0; i-- {
|
for i := len(all) - 1; i >= 0; i-- {
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,9 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func tagMW(log *[]string, tag string) middleware.Middleware {
|
func tagMW(log *[]string, tag string) Middleware {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
*log = append(*log, tag+":before")
|
*log = append(*log, tag+":before")
|
||||||
|
|
@ -26,8 +24,8 @@ func TestChainOrder(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
wrapped := chain(h,
|
wrapped := chain(h,
|
||||||
[]middleware.Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")},
|
[]Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")},
|
||||||
[]middleware.Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")},
|
[]Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")},
|
||||||
)
|
)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
|
||||||
7
pkg/router/middleware.go
Normal file
7
pkg/router/middleware.go
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
package router
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Middleware is the standard func(http.Handler) http.Handler middleware shape,
|
||||||
|
// compatible with any stdlib-style middleware.
|
||||||
|
type Middleware = func(http.Handler) http.Handler
|
||||||
|
|
@ -4,16 +4,12 @@
|
||||||
// underlying mux while carrying their own prefix and middleware stack.
|
// underlying mux while carrying their own prefix and middleware stack.
|
||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import "net/http"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Mux struct {
|
type Mux struct {
|
||||||
root *http.ServeMux
|
root *http.ServeMux
|
||||||
prefix string
|
prefix string
|
||||||
middlewares []middleware.Middleware
|
middlewares []Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() *Mux {
|
func New() *Mux {
|
||||||
|
|
@ -21,16 +17,16 @@ func New() *Mux {
|
||||||
return &Mux{root: sm}
|
return &Mux{root: sm}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Use(mws ...middleware.Middleware) {
|
func (m *Mux) Use(mws ...Middleware) {
|
||||||
m.middlewares = append(m.middlewares, mws...)
|
m.middlewares = append(m.middlewares, mws...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group returns a child Mux that registers on the same underlying ServeMux but
|
// Group returns a child Mux that registers on the same underlying ServeMux but
|
||||||
// with its prefix appended and the parent's current middlewares snapshotted.
|
// with its prefix appended and the parent's current middlewares snapshotted.
|
||||||
// Use() calls made on the parent after Group() do not propagate to the child.
|
// Use() calls made on the parent after Group() do not propagate to the child.
|
||||||
func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux {
|
func (m *Mux) Group(prefix string, mws ...Middleware) *Mux {
|
||||||
validateGroupPrefix(prefix)
|
validateGroupPrefix(prefix)
|
||||||
mwsCopy := make([]middleware.Middleware, 0, len(m.middlewares)+len(mws))
|
mwsCopy := make([]Middleware, 0, len(m.middlewares)+len(mws))
|
||||||
mwsCopy = append(mwsCopy, m.middlewares...)
|
mwsCopy = append(mwsCopy, m.middlewares...)
|
||||||
mwsCopy = append(mwsCopy, mws...)
|
mwsCopy = append(mwsCopy, mws...)
|
||||||
return &Mux{
|
return &Mux{
|
||||||
|
|
@ -40,45 +36,45 @@ func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Handle(pattern string, h http.Handler, mws ...middleware.Middleware) {
|
func (m *Mux) Handle(pattern string, h http.Handler, mws ...Middleware) {
|
||||||
full := buildPattern("", m.prefix, pattern)
|
full := buildPattern("", m.prefix, pattern)
|
||||||
m.root.Handle(full, chain(h, m.middlewares, mws))
|
m.root.Handle(full, chain(h, m.middlewares, mws))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.Handle(pattern, fn, mws...)
|
m.Handle(pattern, fn, mws...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []middleware.Middleware) {
|
func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []Middleware) {
|
||||||
full := buildPattern(method, m.prefix, path)
|
full := buildPattern(method, m.prefix, path)
|
||||||
m.root.Handle(full, chain(fn, m.middlewares, mws))
|
m.root.Handle(full, chain(fn, m.middlewares, mws))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.method(http.MethodGet, path, fn, mws)
|
m.method(http.MethodGet, path, fn, mws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.method(http.MethodPost, path, fn, mws)
|
m.method(http.MethodPost, path, fn, mws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.method(http.MethodPut, path, fn, mws)
|
m.method(http.MethodPut, path, fn, mws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.method(http.MethodPatch, path, fn, mws)
|
m.method(http.MethodPatch, path, fn, mws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.method(http.MethodDelete, path, fn, mws)
|
m.method(http.MethodDelete, path, fn, mws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.method(http.MethodOptions, path, fn, mws)
|
m.method(http.MethodOptions, path, fn, mws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...Middleware) {
|
||||||
m.method(http.MethodHead, path, fn, mws)
|
m.method(http.MethodHead, path, fn, mws)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue