initial implementation
This commit is contained in:
parent
f8b0abc517
commit
cb373e637b
16 changed files with 777 additions and 1 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -24,4 +24,5 @@ go.work.sum
|
||||||
|
|
||||||
# env file
|
# env file
|
||||||
.env
|
.env
|
||||||
|
.env.*
|
||||||
|
!.env.example
|
||||||
|
|
|
||||||
43
Taskfile.yml
Normal file
43
Taskfile.yml
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
version: '3'
|
||||||
|
|
||||||
|
tasks:
|
||||||
|
default:
|
||||||
|
desc: Run vet and tests
|
||||||
|
cmds:
|
||||||
|
- task: check
|
||||||
|
|
||||||
|
install:tools:
|
||||||
|
desc: Install development tools (tparse for prettier test output)
|
||||||
|
cmds:
|
||||||
|
- go install github.com/mfridman/tparse@latest
|
||||||
|
|
||||||
|
test:
|
||||||
|
desc: Run tests with prettier output via tparse
|
||||||
|
cmds:
|
||||||
|
- set -o pipefail && go test ./... -json -cover | tparse -all
|
||||||
|
|
||||||
|
test:race:
|
||||||
|
desc: Run tests under the race detector
|
||||||
|
cmds:
|
||||||
|
- set -o pipefail && go test ./... -race -json | tparse -all
|
||||||
|
|
||||||
|
vet:
|
||||||
|
desc: Run go vet
|
||||||
|
cmds:
|
||||||
|
- go vet ./...
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
desc: Format Go source files
|
||||||
|
cmds:
|
||||||
|
- gofmt -w .
|
||||||
|
|
||||||
|
tidy:
|
||||||
|
desc: Tidy go.mod
|
||||||
|
cmds:
|
||||||
|
- go mod tidy
|
||||||
|
|
||||||
|
check:
|
||||||
|
desc: Run vet and tests
|
||||||
|
cmds:
|
||||||
|
- task: vet
|
||||||
|
- task: test
|
||||||
46
examples/basic/main.go
Normal file
46
examples/basic/main.go
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/lightmux"
|
||||||
|
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
mux := lightmux.New()
|
||||||
|
mux.Use(middleware.Recoverer, middleware.Logger)
|
||||||
|
|
||||||
|
mux.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "hello from lightmux")
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id, err := strconv.Atoi(r.PathValue("id"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "bad id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "user %d\n", id)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.Get("/panic", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("demonstrating Recoverer")
|
||||||
|
})
|
||||||
|
|
||||||
|
api := mux.Group("/api")
|
||||||
|
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "pong")
|
||||||
|
})
|
||||||
|
|
||||||
|
v1 := api.Group("/v1")
|
||||||
|
v1.Get("/items/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintf(w, "item: %s\n", r.PathValue("name"))
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Println("listening on :8080")
|
||||||
|
log.Fatal(http.ListenAndServe(":8080", mux))
|
||||||
|
}
|
||||||
3
go.mod
Normal file
3
go.mod
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
module git.juancwu.dev/juancwu/lightmux
|
||||||
|
|
||||||
|
go 1.26.2
|
||||||
15
lightmux.go
Normal file
15
lightmux.go
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
// Package lightmux is a small wrapper around the Go 1.22+ net/http ServeMux
|
||||||
|
// adding method-named convenience methods, groups, and per-route middleware.
|
||||||
|
package lightmux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||||
|
"git.juancwu.dev/juancwu/lightmux/pkg/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Mux = router.Mux
|
||||||
|
Middleware = middleware.Middleware
|
||||||
|
)
|
||||||
|
|
||||||
|
func New() *Mux { return router.New() }
|
||||||
38
pkg/middleware/logger.go
Normal file
38
pkg/middleware/logger.go
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type statusRecorder struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
wrote bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusRecorder) WriteHeader(code int) {
|
||||||
|
if !s.wrote {
|
||||||
|
s.status = code
|
||||||
|
s.wrote = true
|
||||||
|
}
|
||||||
|
s.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statusRecorder) Write(b []byte) (int, error) {
|
||||||
|
if !s.wrote {
|
||||||
|
s.status = http.StatusOK
|
||||||
|
s.wrote = true
|
||||||
|
}
|
||||||
|
return s.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Logger(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||||
|
next.ServeHTTP(rec, r)
|
||||||
|
log.Printf("%s %s %d %s", r.Method, r.URL.Path, rec.status, time.Since(start))
|
||||||
|
})
|
||||||
|
}
|
||||||
48
pkg/middleware/logger_test.go
Normal file
48
pkg/middleware/logger_test.go
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogger(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
orig := log.Default().Writer()
|
||||||
|
log.Default().SetOutput(&buf)
|
||||||
|
defer log.Default().SetOutput(orig)
|
||||||
|
|
||||||
|
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusTeapot)
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil))
|
||||||
|
|
||||||
|
if rr.Code != http.StatusTeapot {
|
||||||
|
t.Errorf("status code = %d, want 418", rr.Code)
|
||||||
|
}
|
||||||
|
out := buf.String()
|
||||||
|
if !strings.Contains(out, "GET /foo 418") {
|
||||||
|
t.Errorf("log output missing expected fields: %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerDefaultStatusOK(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
orig := log.Default().Writer()
|
||||||
|
log.Default().SetOutput(&buf)
|
||||||
|
defer log.Default().SetOutput(orig)
|
||||||
|
|
||||||
|
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("hi"))
|
||||||
|
}))
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
|
||||||
|
if !strings.Contains(buf.String(), "200") {
|
||||||
|
t.Errorf("expected default 200 in log, got %q", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
5
pkg/middleware/middleware.go
Normal file
5
pkg/middleware/middleware.go
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type Middleware = func(http.Handler) http.Handler
|
||||||
19
pkg/middleware/recoverer.go
Normal file
19
pkg/middleware/recoverer.go
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Recoverer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
log.Printf("panic: %v\n%s", rec, debug.Stack())
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
44
pkg/middleware/recoverer_test.go
Normal file
44
pkg/middleware/recoverer_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecovererCatchesPanic(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
orig := log.Default().Writer()
|
||||||
|
log.Default().SetOutput(&buf)
|
||||||
|
defer log.Default().SetOutput(orig)
|
||||||
|
|
||||||
|
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("boom")
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
|
||||||
|
if rr.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("status = %d, want 500", rr.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "panic: boom") {
|
||||||
|
t.Errorf("expected panic log, got %q", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecovererPassesThrough(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
if !called || rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
19
pkg/router/chain.go
Normal file
19
pkg/router/chain.go
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// chain wraps h with groupMws followed by routeMws so that groupMws[0] is the
|
||||||
|
// outermost layer (runs first on request, last on response).
|
||||||
|
func chain(h http.Handler, groupMws, routeMws []middleware.Middleware) http.Handler {
|
||||||
|
all := make([]middleware.Middleware, 0, len(groupMws)+len(routeMws))
|
||||||
|
all = append(all, groupMws...)
|
||||||
|
all = append(all, routeMws...)
|
||||||
|
for i := len(all) - 1; i >= 0; i-- {
|
||||||
|
h = all[i](h)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
50
pkg/router/chain_test.go
Normal file
50
pkg/router/chain_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func tagMW(log *[]string, tag string) middleware.Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
*log = append(*log, tag+":before")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
*log = append(*log, tag+":after")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChainOrder(t *testing.T) {
|
||||||
|
var log []string
|
||||||
|
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log = append(log, "handler")
|
||||||
|
})
|
||||||
|
|
||||||
|
wrapped := chain(h,
|
||||||
|
[]middleware.Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")},
|
||||||
|
[]middleware.Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")},
|
||||||
|
)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
wrapped.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
|
want := "g1:before,g2:before,r1:before,r2:before,handler,r2:after,r1:after,g2:after,g1:after"
|
||||||
|
if got := strings.Join(log, ","); got != want {
|
||||||
|
t.Errorf("order:\n got %s\nwant %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChainNoMiddlewares(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true })
|
||||||
|
wrapped := chain(h, nil, nil)
|
||||||
|
wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
if !called {
|
||||||
|
t.Fatal("handler not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
87
pkg/router/mux.go
Normal file
87
pkg/router/mux.go
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
// Package router provides a thin, idiomatic wrapper around the Go 1.22+
|
||||||
|
// net/http ServeMux. It adds method-named convenience methods (Get, Post,
|
||||||
|
// ...), per-route and group middleware, and Group sub-routers that share the
|
||||||
|
// underlying mux while carrying their own prefix and middleware stack.
|
||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Mux struct {
|
||||||
|
root *http.ServeMux
|
||||||
|
prefix string
|
||||||
|
middlewares []middleware.Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Mux {
|
||||||
|
sm := http.NewServeMux()
|
||||||
|
return &Mux{root: sm}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Use(mws ...middleware.Middleware) {
|
||||||
|
m.middlewares = append(m.middlewares, mws...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group returns a child Mux that registers on the same underlying ServeMux but
|
||||||
|
// with its prefix appended and the parent's current middlewares snapshotted.
|
||||||
|
// Use() calls made on the parent after Group() do not propagate to the child.
|
||||||
|
func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux {
|
||||||
|
validateGroupPrefix(prefix)
|
||||||
|
mwsCopy := make([]middleware.Middleware, 0, len(m.middlewares)+len(mws))
|
||||||
|
mwsCopy = append(mwsCopy, m.middlewares...)
|
||||||
|
mwsCopy = append(mwsCopy, mws...)
|
||||||
|
return &Mux{
|
||||||
|
root: m.root,
|
||||||
|
prefix: m.prefix + normalizeGroupPrefix(prefix),
|
||||||
|
middlewares: mwsCopy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Handle(pattern string, h http.Handler, mws ...middleware.Middleware) {
|
||||||
|
full := buildPattern("", m.prefix, pattern)
|
||||||
|
m.root.Handle(full, chain(h, m.middlewares, mws))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.Handle(pattern, fn, mws...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []middleware.Middleware) {
|
||||||
|
full := buildPattern(method, m.prefix, path)
|
||||||
|
m.root.Handle(full, chain(fn, m.middlewares, mws))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.method(http.MethodGet, path, fn, mws)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.method(http.MethodPost, path, fn, mws)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.method(http.MethodPut, path, fn, mws)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.method(http.MethodPatch, path, fn, mws)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.method(http.MethodDelete, path, fn, mws)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.method(http.MethodOptions, path, fn, mws)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||||
|
m.method(http.MethodHead, path, fn, mws)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
m.root.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
169
pkg/router/mux_test.go
Normal file
169
pkg/router/mux_test.go
Normal file
|
|
@ -0,0 +1,169 @@
|
||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func do(t *testing.T, m *Mux, method, target string) *httptest.ResponseRecorder {
|
||||||
|
t.Helper()
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
m.ServeHTTP(rr, httptest.NewRequest(method, target, nil))
|
||||||
|
return rr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodRouting(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
m.Get("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "get") })
|
||||||
|
m.Post("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "post") })
|
||||||
|
|
||||||
|
if got := do(t, m, http.MethodGet, "/x").Body.String(); got != "get" {
|
||||||
|
t.Errorf("GET /x = %q", got)
|
||||||
|
}
|
||||||
|
if got := do(t, m, http.MethodPost, "/x").Body.String(); got != "post" {
|
||||||
|
t.Errorf("POST /x = %q", got)
|
||||||
|
}
|
||||||
|
if rr := do(t, m, http.MethodPut, "/x"); rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("PUT /x got %d, want 405", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathValueAcrossMiddleware(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
m.Use(func(next http.Handler) http.Handler { return next })
|
||||||
|
var got string
|
||||||
|
m.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
got = r.PathValue("id")
|
||||||
|
})
|
||||||
|
do(t, m, http.MethodGet, "/users/42")
|
||||||
|
if got != "42" {
|
||||||
|
t.Errorf("PathValue = %q, want 42", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupPrefix(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
api := m.Group("/api")
|
||||||
|
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||||
|
|
||||||
|
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
|
||||||
|
t.Errorf("/api/ping = %q", got)
|
||||||
|
}
|
||||||
|
if rr := do(t, m, http.MethodGet, "/ping"); rr.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("unprefixed /ping got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupNested(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
api := m.Group("/api")
|
||||||
|
v1 := api.Group("/v1")
|
||||||
|
v1.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||||
|
|
||||||
|
if got := do(t, m, http.MethodGet, "/api/v1/ping").Body.String(); got != "pong" {
|
||||||
|
t.Errorf("/api/v1/ping = %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupTrailingSlashNormalized(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
api := m.Group("/api/")
|
||||||
|
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||||
|
|
||||||
|
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
|
||||||
|
t.Errorf("/api/ping = %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupRootSubtree(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
api := m.Group("/api")
|
||||||
|
api.Get("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "subtree") })
|
||||||
|
|
||||||
|
if got := do(t, m, http.MethodGet, "/api/").Body.String(); got != "subtree" {
|
||||||
|
t.Errorf("/api/ = %q", got)
|
||||||
|
}
|
||||||
|
if got := do(t, m, http.MethodGet, "/api/anything").Body.String(); got != "subtree" {
|
||||||
|
t.Errorf("/api/anything = %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUseAfterGroupDoesNotPropagate(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
api := m.Group("/api")
|
||||||
|
hits := 0
|
||||||
|
m.Use(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hits++
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
api.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
m.Get("/y", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
do(t, m, http.MethodGet, "/api/x")
|
||||||
|
if hits != 0 {
|
||||||
|
t.Errorf("late Use propagated to api group, hits=%d", hits)
|
||||||
|
}
|
||||||
|
do(t, m, http.MethodGet, "/y")
|
||||||
|
if hits != 1 {
|
||||||
|
t.Errorf("late Use did not apply to root, hits=%d", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPerRouteAndGroupMiddlewareStack(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
var log []string
|
||||||
|
m.Use(tagMW(&log, "use"))
|
||||||
|
api := m.Group("/api", tagMW(&log, "group"))
|
||||||
|
api.Get("/x",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) { log = append(log, "h") },
|
||||||
|
tagMW(&log, "route"),
|
||||||
|
)
|
||||||
|
|
||||||
|
do(t, m, http.MethodGet, "/api/x")
|
||||||
|
want := "use:before,group:before,route:before,h,route:after,group:after,use:after"
|
||||||
|
if got := strings.Join(log, ","); got != want {
|
||||||
|
t.Errorf("\n got %s\nwant %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleAcceptsBarePath(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
m.Handle("/x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, "any")
|
||||||
|
}))
|
||||||
|
for _, method := range []string{http.MethodGet, http.MethodPost, http.MethodDelete} {
|
||||||
|
if got := do(t, m, method, "/x").Body.String(); got != "any" {
|
||||||
|
t.Errorf("%s /x = %q", method, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleAcceptsMethodPattern(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
m.Handle("DELETE /x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, "del")
|
||||||
|
}))
|
||||||
|
if got := do(t, m, http.MethodDelete, "/x").Body.String(); got != "del" {
|
||||||
|
t.Errorf("DELETE /x = %q", got)
|
||||||
|
}
|
||||||
|
if rr := do(t, m, http.MethodGet, "/x"); rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("GET /x got %d, want 405", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConflictingPatternsPanic(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Fatal("expected panic on duplicate pattern")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
}
|
||||||
|
|
||||||
71
pkg/router/pattern.go
Normal file
71
pkg/router/pattern.go
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
package router
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
func splitPattern(pattern string) (method, host, path string) {
|
||||||
|
rest := pattern
|
||||||
|
if i := strings.Index(pattern, " "); i >= 0 {
|
||||||
|
method = pattern[:i]
|
||||||
|
rest = strings.TrimLeft(pattern[i+1:], " ")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(rest, "/") {
|
||||||
|
return method, "", rest
|
||||||
|
}
|
||||||
|
if i := strings.Index(rest, "/"); i >= 0 {
|
||||||
|
return method, rest[:i], rest[i:]
|
||||||
|
}
|
||||||
|
return method, rest, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateGroupPrefix(p string) {
|
||||||
|
if p == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(p, " \t") {
|
||||||
|
panic("lightmux: group prefix must not contain whitespace (no method or host allowed): " + p)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(p, "/") {
|
||||||
|
panic("lightmux: group prefix must start with '/': " + p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGroupPrefix(p string) string {
|
||||||
|
if p == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(p) > 1 && strings.HasSuffix(p, "/") {
|
||||||
|
return p[:len(p)-1]
|
||||||
|
}
|
||||||
|
if p == "/" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinPath(prefix, sub string) string {
|
||||||
|
if prefix == "" {
|
||||||
|
return sub
|
||||||
|
}
|
||||||
|
if sub == "" {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
if sub == "/" {
|
||||||
|
return prefix + "/"
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(sub, "/") {
|
||||||
|
panic("lightmux: route path must start with '/': " + sub)
|
||||||
|
}
|
||||||
|
return prefix + sub
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPattern(method, prefix, pattern string) string {
|
||||||
|
m, host, path := splitPattern(pattern)
|
||||||
|
if method != "" {
|
||||||
|
m = method
|
||||||
|
}
|
||||||
|
full := host + joinPath(prefix, path)
|
||||||
|
if m != "" {
|
||||||
|
return m + " " + full
|
||||||
|
}
|
||||||
|
return full
|
||||||
|
}
|
||||||
118
pkg/router/pattern_test.go
Normal file
118
pkg/router/pattern_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
package router
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSplitPattern(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
method, host, path string
|
||||||
|
}{
|
||||||
|
{"/foo", "", "", "/foo"},
|
||||||
|
{"GET /foo", "GET", "", "/foo"},
|
||||||
|
{"POST /users/{id}", "POST", "", "/users/{id}"},
|
||||||
|
{"GET example.com/foo", "GET", "example.com", "/foo"},
|
||||||
|
{"example.com/foo", "", "example.com", "/foo"},
|
||||||
|
{"GET /", "GET", "", "/"},
|
||||||
|
{"/", "", "", "/"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.in, func(t *testing.T) {
|
||||||
|
m, h, p := splitPattern(c.in)
|
||||||
|
if m != c.method || h != c.host || p != c.path {
|
||||||
|
t.Fatalf("splitPattern(%q) = (%q, %q, %q), want (%q, %q, %q)",
|
||||||
|
c.in, m, h, p, c.method, c.host, c.path)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeGroupPrefix(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"": "",
|
||||||
|
"/": "",
|
||||||
|
"/api": "/api",
|
||||||
|
"/api/": "/api",
|
||||||
|
"/api/v1": "/api/v1",
|
||||||
|
"/api/v1/": "/api/v1",
|
||||||
|
}
|
||||||
|
for in, want := range cases {
|
||||||
|
if got := normalizeGroupPrefix(in); got != want {
|
||||||
|
t.Errorf("normalizeGroupPrefix(%q) = %q, want %q", in, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinPath(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
prefix, sub, want string
|
||||||
|
}{
|
||||||
|
{"", "/foo", "/foo"},
|
||||||
|
{"/api", "/foo", "/api/foo"},
|
||||||
|
{"/api", "", "/api"},
|
||||||
|
{"/api", "/", "/api/"},
|
||||||
|
{"", "", ""},
|
||||||
|
{"/api/v1", "/users/{id}", "/api/v1/users/{id}"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
got := joinPath(c.prefix, c.sub)
|
||||||
|
if got != c.want {
|
||||||
|
t.Errorf("joinPath(%q, %q) = %q, want %q", c.prefix, c.sub, got, c.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinPathPanicsOnBadSub(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Fatal("expected panic on sub without leading /")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
joinPath("/api", "foo")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateGroupPrefix(t *testing.T) {
|
||||||
|
good := []string{"", "/api", "/api/v1"}
|
||||||
|
for _, p := range good {
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("validateGroupPrefix(%q) unexpected panic: %v", p, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
validateGroupPrefix(p)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
bad := []string{"api", "GET /api", "/api with space", "host.com/api"}
|
||||||
|
for _, p := range bad {
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("validateGroupPrefix(%q) expected panic", p)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
validateGroupPrefix(p)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPattern(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
method, prefix, pattern, want string
|
||||||
|
}{
|
||||||
|
{"", "", "/foo", "/foo"},
|
||||||
|
{"GET", "", "/foo", "GET /foo"},
|
||||||
|
{"GET", "/api", "/foo", "GET /api/foo"},
|
||||||
|
{"", "/api", "GET /foo", "GET /api/foo"},
|
||||||
|
{"", "/api", "GET example.com/foo", "GET example.com/api/foo"},
|
||||||
|
{"", "/api", "/", "/api/"},
|
||||||
|
{"GET", "/api", "/", "GET /api/"},
|
||||||
|
{"GET", "/api", "", "GET /api"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
got := buildPattern(c.method, c.prefix, c.pattern)
|
||||||
|
if got != c.want {
|
||||||
|
t.Errorf("buildPattern(%q, %q, %q) = %q, want %q",
|
||||||
|
c.method, c.prefix, c.pattern, got, c.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue