diff --git a/README.md b/README.md index 5a1370f..06b020d 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,10 @@ import ( "strconv" "git.juancwu.dev/juancwu/lightmux" - "git.juancwu.dev/juancwu/lightmux/pkg/middleware" ) func main() { mux := lightmux.New() - mux.Use(middleware.Recoverer, middleware.Logger) mux.Get("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "hello from lightmux") @@ -71,34 +69,29 @@ There are two ways to attach middleware: chain-style with `Use`, or per-route as ```go mux := lightmux.New() -mux.Use(middleware.Recoverer, middleware.Logger) // applies to every route registered after +mux.Use(authMiddleware) // applies to every route registered after -mux.Get("/admin", adminHandler, requireAdmin) // requireAdmin only on this route +mux.Get("/admin", adminHandler, requireAdmin) // requireAdmin only on this route ``` Order is outer → inner: `Use`-order first, then per-route mws. The outermost middleware runs first on the request and last on the response. Middleware values are plain `func(http.Handler) http.Handler`, so any stdlib-compatible middleware works without an adapter. -## Built-in middleware +An opinionated set of middlewares (request logging, panic recovery, real-IP resolution) lives in a sibling module — see [lightmux-contrib](https://git.juancwu.dev/juancwu/lightmux-contrib): -The `pkg/middleware` package ships: - -- **`Logger`** — emits a structured `http.request` record (method, path, status, duration, client) via [splinter](https://git.juancwu.dev/juancwu/splinter)'s default logger. The `client` field is `r.RemoteAddr`, so pairing with `RealIP` makes it the resolved client IP. -- **`LoggerWith(*splinter.Logger)`** — same, but routes records through the supplied splinter logger instead of the default. -- **`Recoverer`** — catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `middleware.Recoverer`, logs it with the stack, and writes a 500 response. -- **`RealIP`** — replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order). Always trusts these headers — only register it when the service sits behind a trusted proxy. -- **`RealIPWith(trusted ...netip.Prefix)`** — same, but only honors the headers when the immediate peer's IP falls within one of the trusted prefixes. Requests from outside the allowlist pass through untouched. - -```go -custom := splinter.New(splinter.WithStream(...)) -mux.Use(middleware.Recoverer, middleware.LoggerWith(custom)) +```sh +go get git.juancwu.dev/juancwu/lightmux-contrib ``` -When using `RealIP` together with `Logger`, register `RealIP` first so the logged `client` field is the resolved client IP rather than the proxy's peer address: - ```go -mux.Use(middleware.RealIP, middleware.Logger) +import ( + "git.juancwu.dev/juancwu/lightmux-contrib/realip" + "git.juancwu.dev/juancwu/lightmux-contrib/recoverer" + "git.juancwu.dev/juancwu/lightmux-contrib/requestlog" +) + +mux.Use(recoverer.New(), realip.New(), requestlog.New(nil)) ``` ## Path parameters @@ -114,5 +107,5 @@ mux.Get("/items/{name}", func(w http.ResponseWriter, r *http.Request) { ## Notes - `mux.Get("/", h)` registers `"GET /"`, which is a stdlib **subtree** pattern — it matches every unmatched path. Use `"/{$}"` to match only the literal root. -- Bad route registrations (invalid prefix, conflicting wildcards) panic at startup, matching stdlib `http.ServeMux` behavior. `Recoverer` handles panics that occur *inside* request handlers. +- Bad route registrations (invalid prefix, conflicting wildcards) panic at startup, matching stdlib `http.ServeMux` behavior. The `recoverer` middleware in [lightmux-contrib](https://git.juancwu.dev/juancwu/lightmux-contrib) handles panics that occur *inside* request handlers. - Group prefixes apply to the path only — they never inject a host. Routes can still carry an explicit host via `Handle("GET host.com/path", h)`. diff --git a/go.mod b/go.mod index b5e1214..1a7ff78 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,4 @@ module git.juancwu.dev/juancwu/lightmux go 1.26.2 -require ( - git.juancwu.dev/juancwu/errx v0.1.0 - git.juancwu.dev/juancwu/splinter v0.1.0 -) +require git.juancwu.dev/juancwu/errx v0.1.0 diff --git a/go.sum b/go.sum index 207ed5b..871cea8 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,2 @@ git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w= git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk= -git.juancwu.dev/juancwu/splinter v0.1.0 h1:ZGvvzyi24hZw/yFAwpUsHtj+q+fh9I2KIGmOAILWD5Q= -git.juancwu.dev/juancwu/splinter v0.1.0/go.mod h1:dAYsRQfS6tqWynEGz8xMCtIJUN7+KIp3jLE7kgO3yKE= diff --git a/lightmux.go b/lightmux.go index 18fe1ed..9deb2bd 100644 --- a/lightmux.go +++ b/lightmux.go @@ -2,14 +2,11 @@ // adding method-named convenience methods, groups, and per-route middleware. package lightmux -import ( - "git.juancwu.dev/juancwu/lightmux/pkg/middleware" - "git.juancwu.dev/juancwu/lightmux/pkg/router" -) +import "git.juancwu.dev/juancwu/lightmux/pkg/router" type ( Mux = router.Mux - Middleware = middleware.Middleware + Middleware = router.Middleware ) func New() *Mux { return router.New() } diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go deleted file mode 100644 index 4d5e754..0000000 --- a/pkg/middleware/logger.go +++ /dev/null @@ -1,68 +0,0 @@ -package middleware - -import ( - "net/http" - "time" - - "git.juancwu.dev/juancwu/splinter" -) - -type statusRecorder struct { - http.ResponseWriter - status int - wrote bool -} - -func (s *statusRecorder) WriteHeader(code int) { - if !s.wrote { - s.status = code - s.wrote = true - } - s.ResponseWriter.WriteHeader(code) -} - -func (s *statusRecorder) Write(b []byte) (int, error) { - if !s.wrote { - s.status = http.StatusOK - s.wrote = true - } - return s.ResponseWriter.Write(b) -} - -// Logger uses splinter.Default() resolved at request time. -func Logger(next http.Handler) http.Handler { - return loggerHandler(nil, next) -} - -// LoggerWith returns a Logger middleware backed by the given splinter logger. -// Pass nil to fall back to splinter.Default() (equivalent to Logger). -func LoggerWith(l *splinter.Logger) Middleware { - return func(next http.Handler) http.Handler { - return loggerHandler(l, next) - } -} - -func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} - next.ServeHTTP(rec, r) - if l == nil { - splinter.Info("http.request", - "method", r.Method, - "path", r.URL.Path, - "status", rec.status, - "duration", time.Since(start), - "client", r.RemoteAddr, - ) - return - } - l.Info("http.request", - "method", r.Method, - "path", r.URL.Path, - "status", rec.status, - "duration", time.Since(start), - "client", r.RemoteAddr, - ) - }) -} diff --git a/pkg/middleware/logger_test.go b/pkg/middleware/logger_test.go deleted file mode 100644 index a35b3d8..0000000 --- a/pkg/middleware/logger_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package middleware - -import ( - "bytes" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "git.juancwu.dev/juancwu/splinter" -) - -func captureSplinter(t *testing.T) *bytes.Buffer { - t.Helper() - var buf bytes.Buffer - logger := splinter.New(splinter.WithStream(splinter.NewConsoleStream( - splinter.ConsoleJSON, - splinter.LevelDebug, - splinter.ConsoleWriter(&buf), - ))) - prev := splinter.SetDefault(logger) - t.Cleanup(func() { splinter.SetDefault(prev) }) - return &buf -} - -func TestLogger(t *testing.T) { - buf := captureSplinter(t) - - h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusTeapot) - })) - - rr := httptest.NewRecorder() - h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil)) - - if rr.Code != http.StatusTeapot { - t.Errorf("status code = %d, want 418", rr.Code) - } - out := buf.String() - for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} { - if !strings.Contains(out, want) { - t.Errorf("log output missing %s\nfull output: %s", want, out) - } - } -} - -func TestLoggerDefaultStatusOK(t *testing.T) { - buf := captureSplinter(t) - - h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hi")) - })) - h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) - - if !strings.Contains(buf.String(), `"status":200`) { - t.Errorf("expected default status 200 in log, got %q", buf.String()) - } -} - -func TestLoggerWith(t *testing.T) { - defaultBuf := captureSplinter(t) - - var customBuf bytes.Buffer - custom := splinter.New(splinter.WithStream(splinter.NewConsoleStream( - splinter.ConsoleJSON, - splinter.LevelDebug, - splinter.ConsoleWriter(&customBuf), - ))) - - h := LoggerWith(custom)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusAccepted) - })) - h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/x", nil)) - - if !strings.Contains(customBuf.String(), `"path":"/x"`) { - t.Errorf("custom logger did not receive record: %q", customBuf.String()) - } - if defaultBuf.Len() != 0 { - t.Errorf("default logger should not have been written to, got: %q", defaultBuf.String()) - } -} - -func TestLoggerWithNilFallsBackToDefault(t *testing.T) { - buf := captureSplinter(t) - - h := LoggerWith(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/y", nil)) - - if !strings.Contains(buf.String(), `"path":"/y"`) { - t.Errorf("nil logger should fall back to splinter.Default(): %q", buf.String()) - } -} diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go deleted file mode 100644 index f54cea9..0000000 --- a/pkg/middleware/middleware.go +++ /dev/null @@ -1,5 +0,0 @@ -package middleware - -import "net/http" - -type Middleware = func(http.Handler) http.Handler diff --git a/pkg/middleware/realip.go b/pkg/middleware/realip.go deleted file mode 100644 index a0d323c..0000000 --- a/pkg/middleware/realip.go +++ /dev/null @@ -1,86 +0,0 @@ -package middleware - -import ( - "net" - "net/http" - "net/netip" - "strings" -) - -var realIPHeaders = []string{ - "CF-Connecting-IP", - "True-Client-IP", - "X-Real-IP", - "X-Forwarded-For", -} - -// RealIP rewrites r.RemoteAddr with the originating client IP found in common -// reverse-proxy headers (Cloudflare, nginx). It always trusts these headers — -// only register it when the service sits behind a trusted proxy. Use -// RealIPWith for an allowlist-gated variant. -func RealIP(next http.Handler) http.Handler { - return realIPHandler(nil, next) -} - -// RealIPWith returns a RealIP middleware that only honors the proxy headers -// when the immediate peer (parsed from r.RemoteAddr) falls within one of the -// trusted prefixes. Requests from outside the allowlist are passed through -// untouched. -func RealIPWith(trusted ...netip.Prefix) Middleware { - if trusted == nil { - trusted = []netip.Prefix{} - } - return func(next http.Handler) http.Handler { - return realIPHandler(trusted, next) - } -} - -func realIPHandler(trusted []netip.Prefix, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if trusted != nil && !peerTrusted(r.RemoteAddr, trusted) { - next.ServeHTTP(w, r) - return - } - if ip := extractRealIP(r); ip != "" { - r2 := *r - r2.RemoteAddr = ip - next.ServeHTTP(w, &r2) - return - } - next.ServeHTTP(w, r) - }) -} - -func extractRealIP(r *http.Request) string { - for _, h := range realIPHeaders { - v := r.Header.Get(h) - if v == "" { - continue - } - if i := strings.IndexByte(v, ','); i >= 0 { - v = v[:i] - } - v = strings.TrimSpace(v) - if net.ParseIP(v) != nil { - return v - } - } - return "" -} - -func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool { - var peer netip.Addr - if ap, err := netip.ParseAddrPort(remoteAddr); err == nil { - peer = ap.Addr() - } else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil { - peer = a - } else { - return false - } - for _, p := range trusted { - if p.Contains(peer) { - return true - } - } - return false -} diff --git a/pkg/middleware/realip_test.go b/pkg/middleware/realip_test.go deleted file mode 100644 index 39897e0..0000000 --- a/pkg/middleware/realip_test.go +++ /dev/null @@ -1,198 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "net/netip" - "testing" -) - -const defaultTestRemoteAddr = "192.0.2.1:1234" - -func captureRemoteAddr(got *string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - *got = r.RemoteAddr - }) -} - -func TestRealIP(t *testing.T) { - cases := []struct { - name string - headers map[string]string - want string - }{ - { - name: "no headers", - want: defaultTestRemoteAddr, - }, - { - name: "CF-Connecting-IP", - headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"}, - want: "203.0.113.5", - }, - { - name: "True-Client-IP", - headers: map[string]string{"True-Client-IP": "203.0.113.6"}, - want: "203.0.113.6", - }, - { - name: "X-Real-IP", - headers: map[string]string{"X-Real-IP": "203.0.113.7"}, - want: "203.0.113.7", - }, - { - name: "X-Forwarded-For single", - headers: map[string]string{"X-Forwarded-For": "203.0.113.8"}, - want: "203.0.113.8", - }, - { - name: "X-Forwarded-For list", - headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"}, - want: "203.0.113.9", - }, - { - name: "X-Forwarded-For with spaces", - headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"}, - want: "203.0.113.10", - }, - { - name: "precedence CF over XFF", - headers: map[string]string{ - "CF-Connecting-IP": "203.0.113.11", - "X-Forwarded-For": "198.51.100.1", - }, - want: "203.0.113.11", - }, - { - name: "invalid then valid", - headers: map[string]string{ - "CF-Connecting-IP": "not-an-ip", - "X-Real-IP": "203.0.113.12", - }, - want: "203.0.113.12", - }, - { - name: "IPv6", - headers: map[string]string{"X-Real-IP": "2001:db8::1"}, - want: "2001:db8::1", - }, - { - name: "all invalid falls through", - headers: map[string]string{"CF-Connecting-IP": "garbage"}, - want: defaultTestRemoteAddr, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - var got string - h := RealIP(captureRemoteAddr(&got)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - for k, v := range tc.headers { - req.Header.Set(k, v) - } - h.ServeHTTP(httptest.NewRecorder(), req) - - if got != tc.want { - t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want) - } - }) - } -} - -func TestRealIPDoesNotMutateCallerRequest(t *testing.T) { - var seen string - h := RealIP(captureRemoteAddr(&seen)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-Real-IP", "203.0.113.13") - original := req.RemoteAddr - - h.ServeHTTP(httptest.NewRecorder(), req) - - if seen != "203.0.113.13" { - t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13") - } - if req.RemoteAddr != original { - t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original) - } -} - -func TestRealIPWithTrustedPeer(t *testing.T) { - var got string - h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "10.1.2.3:55555" - req.Header.Set("X-Real-IP", "203.0.113.20") - - h.ServeHTTP(httptest.NewRecorder(), req) - - if got != "203.0.113.20" { - t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20") - } -} - -func TestRealIPWithUntrustedPeer(t *testing.T) { - var got string - h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "203.0.113.99:55555" - req.Header.Set("X-Real-IP", "198.51.100.1") - - h.ServeHTTP(httptest.NewRecorder(), req) - - if got != "203.0.113.99:55555" { - t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555") - } -} - -func TestRealIPWithMultiplePrefixesIPv6(t *testing.T) { - var got string - h := RealIPWith( - netip.MustParsePrefix("10.0.0.0/8"), - netip.MustParsePrefix("2001:db8::/32"), - )(captureRemoteAddr(&got)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "[2001:db8::abcd]:55555" - req.Header.Set("X-Real-IP", "203.0.113.30") - - h.ServeHTTP(httptest.NewRecorder(), req) - - if got != "203.0.113.30" { - t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30") - } -} - -func TestRealIPWithZeroPrefixesNoOp(t *testing.T) { - var got string - h := RealIPWith()(captureRemoteAddr(&got)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("CF-Connecting-IP", "203.0.113.40") - original := req.RemoteAddr - - h.ServeHTTP(httptest.NewRecorder(), req) - - if got != original { - t.Errorf("RealIPWith() with no prefixes should be a no-op: r.RemoteAddr = %q, want %q", got, original) - } -} - -func TestRealIPWithUnparseableRemoteAddr(t *testing.T) { - var got string - h := RealIPWith(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got)) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.RemoteAddr = "not-an-addr" - req.Header.Set("CF-Connecting-IP", "203.0.113.50") - - h.ServeHTTP(httptest.NewRecorder(), req) - - if got != "not-an-addr" { - t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got) - } -} diff --git a/pkg/middleware/recoverer.go b/pkg/middleware/recoverer.go deleted file mode 100644 index 37a3262..0000000 --- a/pkg/middleware/recoverer.go +++ /dev/null @@ -1,31 +0,0 @@ -package middleware - -import ( - "log" - "net/http" - "runtime/debug" - - "git.juancwu.dev/juancwu/errx" -) - -const recovererOp = "middleware.Recoverer" - -func Recoverer(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - rec := recover() - if rec == nil { - return - } - var err error - if e, ok := rec.(error); ok { - err = errx.Wrap(recovererOp, e) - } else { - err = errx.Newf(recovererOp, "panic: %v", rec) - } - log.Printf("%v\n%s", err, debug.Stack()) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - }() - next.ServeHTTP(w, r) - }) -} diff --git a/pkg/middleware/recoverer_test.go b/pkg/middleware/recoverer_test.go deleted file mode 100644 index 471bf5c..0000000 --- a/pkg/middleware/recoverer_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package middleware - -import ( - "bytes" - "errors" - "log" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func captureLog(t *testing.T) *bytes.Buffer { - t.Helper() - var buf bytes.Buffer - orig := log.Default().Writer() - log.Default().SetOutput(&buf) - t.Cleanup(func() { log.Default().SetOutput(orig) }) - return &buf -} - -func TestRecovererCatchesStringPanic(t *testing.T) { - buf := captureLog(t) - - h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic("boom") - })) - - rr := httptest.NewRecorder() - h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) - - if rr.Code != http.StatusInternalServerError { - t.Errorf("status = %d, want 500", rr.Code) - } - out := buf.String() - for _, want := range []string{"middleware.Recoverer", "panic: boom"} { - if !strings.Contains(out, want) { - t.Errorf("log missing %q\nfull: %s", want, out) - } - } -} - -func TestRecovererWrapsErrorPanic(t *testing.T) { - buf := captureLog(t) - - cause := errors.New("db down") - h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - panic(cause) - })) - - h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) - - out := buf.String() - if !strings.Contains(out, "middleware.Recoverer: db down") { - t.Errorf("expected errx-wrapped breadcrumb, got: %s", out) - } -} - -func TestRecovererPassesThrough(t *testing.T) { - called := false - h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - w.WriteHeader(http.StatusOK) - })) - rr := httptest.NewRecorder() - h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) - if !called || rr.Code != http.StatusOK { - t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code) - } -} diff --git a/pkg/router/chain.go b/pkg/router/chain.go index 54eedea..052e61c 100644 --- a/pkg/router/chain.go +++ b/pkg/router/chain.go @@ -1,15 +1,11 @@ package router -import ( - "net/http" - - "git.juancwu.dev/juancwu/lightmux/pkg/middleware" -) +import "net/http" // chain wraps h with groupMws followed by routeMws so that groupMws[0] is the // outermost layer (runs first on request, last on response). -func chain(h http.Handler, groupMws, routeMws []middleware.Middleware) http.Handler { - all := make([]middleware.Middleware, 0, len(groupMws)+len(routeMws)) +func chain(h http.Handler, groupMws, routeMws []Middleware) http.Handler { + all := make([]Middleware, 0, len(groupMws)+len(routeMws)) all = append(all, groupMws...) all = append(all, routeMws...) for i := len(all) - 1; i >= 0; i-- { diff --git a/pkg/router/chain_test.go b/pkg/router/chain_test.go index 3ee5001..6259f39 100644 --- a/pkg/router/chain_test.go +++ b/pkg/router/chain_test.go @@ -5,11 +5,9 @@ import ( "net/http/httptest" "strings" "testing" - - "git.juancwu.dev/juancwu/lightmux/pkg/middleware" ) -func tagMW(log *[]string, tag string) middleware.Middleware { +func tagMW(log *[]string, tag string) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { *log = append(*log, tag+":before") @@ -26,8 +24,8 @@ func TestChainOrder(t *testing.T) { }) wrapped := chain(h, - []middleware.Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")}, - []middleware.Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")}, + []Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")}, + []Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")}, ) req := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/pkg/router/middleware.go b/pkg/router/middleware.go new file mode 100644 index 0000000..b0d951a --- /dev/null +++ b/pkg/router/middleware.go @@ -0,0 +1,7 @@ +package router + +import "net/http" + +// Middleware is the standard func(http.Handler) http.Handler middleware shape, +// compatible with any stdlib-style middleware. +type Middleware = func(http.Handler) http.Handler diff --git a/pkg/router/mux.go b/pkg/router/mux.go index 172dbe4..9acde2b 100644 --- a/pkg/router/mux.go +++ b/pkg/router/mux.go @@ -4,16 +4,12 @@ // underlying mux while carrying their own prefix and middleware stack. package router -import ( - "net/http" - - "git.juancwu.dev/juancwu/lightmux/pkg/middleware" -) +import "net/http" type Mux struct { root *http.ServeMux prefix string - middlewares []middleware.Middleware + middlewares []Middleware } func New() *Mux { @@ -21,16 +17,16 @@ func New() *Mux { return &Mux{root: sm} } -func (m *Mux) Use(mws ...middleware.Middleware) { +func (m *Mux) Use(mws ...Middleware) { m.middlewares = append(m.middlewares, mws...) } // Group returns a child Mux that registers on the same underlying ServeMux but // with its prefix appended and the parent's current middlewares snapshotted. // Use() calls made on the parent after Group() do not propagate to the child. -func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux { +func (m *Mux) Group(prefix string, mws ...Middleware) *Mux { validateGroupPrefix(prefix) - mwsCopy := make([]middleware.Middleware, 0, len(m.middlewares)+len(mws)) + mwsCopy := make([]Middleware, 0, len(m.middlewares)+len(mws)) mwsCopy = append(mwsCopy, m.middlewares...) mwsCopy = append(mwsCopy, mws...) return &Mux{ @@ -40,45 +36,45 @@ func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux { } } -func (m *Mux) Handle(pattern string, h http.Handler, mws ...middleware.Middleware) { +func (m *Mux) Handle(pattern string, h http.Handler, mws ...Middleware) { full := buildPattern("", m.prefix, pattern) m.root.Handle(full, chain(h, m.middlewares, mws)) } -func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...Middleware) { m.Handle(pattern, fn, mws...) } -func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []middleware.Middleware) { +func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []Middleware) { full := buildPattern(method, m.prefix, path) m.root.Handle(full, chain(fn, m.middlewares, mws)) } -func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...Middleware) { m.method(http.MethodGet, path, fn, mws) } -func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...Middleware) { m.method(http.MethodPost, path, fn, mws) } -func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...Middleware) { m.method(http.MethodPut, path, fn, mws) } -func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...Middleware) { m.method(http.MethodPatch, path, fn, mws) } -func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...Middleware) { m.method(http.MethodDelete, path, fn, mws) } -func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...Middleware) { m.method(http.MethodOptions, path, fn, mws) } -func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { +func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...Middleware) { m.method(http.MethodHead, path, fn, mws) }