initial implementation
This commit is contained in:
parent
f8b0abc517
commit
cb373e637b
16 changed files with 777 additions and 1 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -24,4 +24,5 @@ go.work.sum
|
|||
|
||||
# env file
|
||||
.env
|
||||
|
||||
.env.*
|
||||
!.env.example
|
||||
|
|
|
|||
43
Taskfile.yml
Normal file
43
Taskfile.yml
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
version: '3'
|
||||
|
||||
tasks:
|
||||
default:
|
||||
desc: Run vet and tests
|
||||
cmds:
|
||||
- task: check
|
||||
|
||||
install:tools:
|
||||
desc: Install development tools (tparse for prettier test output)
|
||||
cmds:
|
||||
- go install github.com/mfridman/tparse@latest
|
||||
|
||||
test:
|
||||
desc: Run tests with prettier output via tparse
|
||||
cmds:
|
||||
- set -o pipefail && go test ./... -json -cover | tparse -all
|
||||
|
||||
test:race:
|
||||
desc: Run tests under the race detector
|
||||
cmds:
|
||||
- set -o pipefail && go test ./... -race -json | tparse -all
|
||||
|
||||
vet:
|
||||
desc: Run go vet
|
||||
cmds:
|
||||
- go vet ./...
|
||||
|
||||
fmt:
|
||||
desc: Format Go source files
|
||||
cmds:
|
||||
- gofmt -w .
|
||||
|
||||
tidy:
|
||||
desc: Tidy go.mod
|
||||
cmds:
|
||||
- go mod tidy
|
||||
|
||||
check:
|
||||
desc: Run vet and tests
|
||||
cmds:
|
||||
- task: vet
|
||||
- task: test
|
||||
46
examples/basic/main.go
Normal file
46
examples/basic/main.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"git.juancwu.dev/juancwu/lightmux"
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
)
|
||||
|
||||
func main() {
|
||||
mux := lightmux.New()
|
||||
mux.Use(middleware.Recoverer, middleware.Logger)
|
||||
|
||||
mux.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "hello from lightmux")
|
||||
})
|
||||
|
||||
mux.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.Atoi(r.PathValue("id"))
|
||||
if err != nil {
|
||||
http.Error(w, "bad id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "user %d\n", id)
|
||||
})
|
||||
|
||||
mux.Get("/panic", func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("demonstrating Recoverer")
|
||||
})
|
||||
|
||||
api := mux.Group("/api")
|
||||
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "pong")
|
||||
})
|
||||
|
||||
v1 := api.Group("/v1")
|
||||
v1.Get("/items/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "item: %s\n", r.PathValue("name"))
|
||||
})
|
||||
|
||||
log.Println("listening on :8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", mux))
|
||||
}
|
||||
3
go.mod
Normal file
3
go.mod
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
module git.juancwu.dev/juancwu/lightmux
|
||||
|
||||
go 1.26.2
|
||||
15
lightmux.go
Normal file
15
lightmux.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Package lightmux is a small wrapper around the Go 1.22+ net/http ServeMux
|
||||
// adding method-named convenience methods, groups, and per-route middleware.
|
||||
package lightmux
|
||||
|
||||
import (
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/router"
|
||||
)
|
||||
|
||||
type (
|
||||
Mux = router.Mux
|
||||
Middleware = middleware.Middleware
|
||||
)
|
||||
|
||||
func New() *Mux { return router.New() }
|
||||
38
pkg/middleware/logger.go
Normal file
38
pkg/middleware/logger.go
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
wrote bool
|
||||
}
|
||||
|
||||
func (s *statusRecorder) WriteHeader(code int) {
|
||||
if !s.wrote {
|
||||
s.status = code
|
||||
s.wrote = true
|
||||
}
|
||||
s.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (s *statusRecorder) Write(b []byte) (int, error) {
|
||||
if !s.wrote {
|
||||
s.status = http.StatusOK
|
||||
s.wrote = true
|
||||
}
|
||||
return s.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func Logger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(rec, r)
|
||||
log.Printf("%s %s %d %s", r.Method, r.URL.Path, rec.status, time.Since(start))
|
||||
})
|
||||
}
|
||||
48
pkg/middleware/logger_test.go
Normal file
48
pkg/middleware/logger_test.go
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
orig := log.Default().Writer()
|
||||
log.Default().SetOutput(&buf)
|
||||
defer log.Default().SetOutput(orig)
|
||||
|
||||
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil))
|
||||
|
||||
if rr.Code != http.StatusTeapot {
|
||||
t.Errorf("status code = %d, want 418", rr.Code)
|
||||
}
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "GET /foo 418") {
|
||||
t.Errorf("log output missing expected fields: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerDefaultStatusOK(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
orig := log.Default().Writer()
|
||||
log.Default().SetOutput(&buf)
|
||||
defer log.Default().SetOutput(orig)
|
||||
|
||||
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hi"))
|
||||
}))
|
||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
if !strings.Contains(buf.String(), "200") {
|
||||
t.Errorf("expected default 200 in log, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
5
pkg/middleware/middleware.go
Normal file
5
pkg/middleware/middleware.go
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
type Middleware = func(http.Handler) http.Handler
|
||||
19
pkg/middleware/recoverer.go
Normal file
19
pkg/middleware/recoverer.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func Recoverer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Printf("panic: %v\n%s", rec, debug.Stack())
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
44
pkg/middleware/recoverer_test.go
Normal file
44
pkg/middleware/recoverer_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecovererCatchesPanic(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
orig := log.Default().Writer()
|
||||
log.Default().SetOutput(&buf)
|
||||
defer log.Default().SetOutput(orig)
|
||||
|
||||
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("boom")
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want 500", rr.Code)
|
||||
}
|
||||
if !strings.Contains(buf.String(), "panic: boom") {
|
||||
t.Errorf("expected panic log, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecovererPassesThrough(t *testing.T) {
|
||||
called := false
|
||||
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if !called || rr.Code != http.StatusOK {
|
||||
t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code)
|
||||
}
|
||||
}
|
||||
19
pkg/router/chain.go
Normal file
19
pkg/router/chain.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
)
|
||||
|
||||
// chain wraps h with groupMws followed by routeMws so that groupMws[0] is the
|
||||
// outermost layer (runs first on request, last on response).
|
||||
func chain(h http.Handler, groupMws, routeMws []middleware.Middleware) http.Handler {
|
||||
all := make([]middleware.Middleware, 0, len(groupMws)+len(routeMws))
|
||||
all = append(all, groupMws...)
|
||||
all = append(all, routeMws...)
|
||||
for i := len(all) - 1; i >= 0; i-- {
|
||||
h = all[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
50
pkg/router/chain_test.go
Normal file
50
pkg/router/chain_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
)
|
||||
|
||||
func tagMW(log *[]string, tag string) middleware.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
*log = append(*log, tag+":before")
|
||||
next.ServeHTTP(w, r)
|
||||
*log = append(*log, tag+":after")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainOrder(t *testing.T) {
|
||||
var log []string
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "handler")
|
||||
})
|
||||
|
||||
wrapped := chain(h,
|
||||
[]middleware.Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")},
|
||||
[]middleware.Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")},
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
wrapped.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
want := "g1:before,g2:before,r1:before,r2:before,handler,r2:after,r1:after,g2:after,g1:after"
|
||||
if got := strings.Join(log, ","); got != want {
|
||||
t.Errorf("order:\n got %s\nwant %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainNoMiddlewares(t *testing.T) {
|
||||
called := false
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true })
|
||||
wrapped := chain(h, nil, nil)
|
||||
wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if !called {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
}
|
||||
87
pkg/router/mux.go
Normal file
87
pkg/router/mux.go
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
// Package router provides a thin, idiomatic wrapper around the Go 1.22+
|
||||
// net/http ServeMux. It adds method-named convenience methods (Get, Post,
|
||||
// ...), per-route and group middleware, and Group sub-routers that share the
|
||||
// underlying mux while carrying their own prefix and middleware stack.
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
)
|
||||
|
||||
type Mux struct {
|
||||
root *http.ServeMux
|
||||
prefix string
|
||||
middlewares []middleware.Middleware
|
||||
}
|
||||
|
||||
func New() *Mux {
|
||||
sm := http.NewServeMux()
|
||||
return &Mux{root: sm}
|
||||
}
|
||||
|
||||
func (m *Mux) Use(mws ...middleware.Middleware) {
|
||||
m.middlewares = append(m.middlewares, mws...)
|
||||
}
|
||||
|
||||
// Group returns a child Mux that registers on the same underlying ServeMux but
|
||||
// with its prefix appended and the parent's current middlewares snapshotted.
|
||||
// Use() calls made on the parent after Group() do not propagate to the child.
|
||||
func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux {
|
||||
validateGroupPrefix(prefix)
|
||||
mwsCopy := make([]middleware.Middleware, 0, len(m.middlewares)+len(mws))
|
||||
mwsCopy = append(mwsCopy, m.middlewares...)
|
||||
mwsCopy = append(mwsCopy, mws...)
|
||||
return &Mux{
|
||||
root: m.root,
|
||||
prefix: m.prefix + normalizeGroupPrefix(prefix),
|
||||
middlewares: mwsCopy,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) Handle(pattern string, h http.Handler, mws ...middleware.Middleware) {
|
||||
full := buildPattern("", m.prefix, pattern)
|
||||
m.root.Handle(full, chain(h, m.middlewares, mws))
|
||||
}
|
||||
|
||||
func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.Handle(pattern, fn, mws...)
|
||||
}
|
||||
|
||||
func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []middleware.Middleware) {
|
||||
full := buildPattern(method, m.prefix, path)
|
||||
m.root.Handle(full, chain(fn, m.middlewares, mws))
|
||||
}
|
||||
|
||||
func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodGet, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodPost, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodPut, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodPatch, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodDelete, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodOptions, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodHead, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.root.ServeHTTP(w, r)
|
||||
}
|
||||
169
pkg/router/mux_test.go
Normal file
169
pkg/router/mux_test.go
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func do(t *testing.T, m *Mux, method, target string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
rr := httptest.NewRecorder()
|
||||
m.ServeHTTP(rr, httptest.NewRequest(method, target, nil))
|
||||
return rr
|
||||
}
|
||||
|
||||
func TestMethodRouting(t *testing.T) {
|
||||
m := New()
|
||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "get") })
|
||||
m.Post("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "post") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/x").Body.String(); got != "get" {
|
||||
t.Errorf("GET /x = %q", got)
|
||||
}
|
||||
if got := do(t, m, http.MethodPost, "/x").Body.String(); got != "post" {
|
||||
t.Errorf("POST /x = %q", got)
|
||||
}
|
||||
if rr := do(t, m, http.MethodPut, "/x"); rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("PUT /x got %d, want 405", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathValueAcrossMiddleware(t *testing.T) {
|
||||
m := New()
|
||||
m.Use(func(next http.Handler) http.Handler { return next })
|
||||
var got string
|
||||
m.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
got = r.PathValue("id")
|
||||
})
|
||||
do(t, m, http.MethodGet, "/users/42")
|
||||
if got != "42" {
|
||||
t.Errorf("PathValue = %q, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupPrefix(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
|
||||
t.Errorf("/api/ping = %q", got)
|
||||
}
|
||||
if rr := do(t, m, http.MethodGet, "/ping"); rr.Code != http.StatusNotFound {
|
||||
t.Errorf("unprefixed /ping got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupNested(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
v1 := api.Group("/v1")
|
||||
v1.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/v1/ping").Body.String(); got != "pong" {
|
||||
t.Errorf("/api/v1/ping = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupTrailingSlashNormalized(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api/")
|
||||
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
|
||||
t.Errorf("/api/ping = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupRootSubtree(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
api.Get("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "subtree") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/").Body.String(); got != "subtree" {
|
||||
t.Errorf("/api/ = %q", got)
|
||||
}
|
||||
if got := do(t, m, http.MethodGet, "/api/anything").Body.String(); got != "subtree" {
|
||||
t.Errorf("/api/anything = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseAfterGroupDoesNotPropagate(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
hits := 0
|
||||
m.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits++
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
api.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||
m.Get("/y", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
do(t, m, http.MethodGet, "/api/x")
|
||||
if hits != 0 {
|
||||
t.Errorf("late Use propagated to api group, hits=%d", hits)
|
||||
}
|
||||
do(t, m, http.MethodGet, "/y")
|
||||
if hits != 1 {
|
||||
t.Errorf("late Use did not apply to root, hits=%d", hits)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerRouteAndGroupMiddlewareStack(t *testing.T) {
|
||||
m := New()
|
||||
var log []string
|
||||
m.Use(tagMW(&log, "use"))
|
||||
api := m.Group("/api", tagMW(&log, "group"))
|
||||
api.Get("/x",
|
||||
func(w http.ResponseWriter, r *http.Request) { log = append(log, "h") },
|
||||
tagMW(&log, "route"),
|
||||
)
|
||||
|
||||
do(t, m, http.MethodGet, "/api/x")
|
||||
want := "use:before,group:before,route:before,h,route:after,group:after,use:after"
|
||||
if got := strings.Join(log, ","); got != want {
|
||||
t.Errorf("\n got %s\nwant %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleAcceptsBarePath(t *testing.T) {
|
||||
m := New()
|
||||
m.Handle("/x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, "any")
|
||||
}))
|
||||
for _, method := range []string{http.MethodGet, http.MethodPost, http.MethodDelete} {
|
||||
if got := do(t, m, method, "/x").Body.String(); got != "any" {
|
||||
t.Errorf("%s /x = %q", method, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleAcceptsMethodPattern(t *testing.T) {
|
||||
m := New()
|
||||
m.Handle("DELETE /x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, "del")
|
||||
}))
|
||||
if got := do(t, m, http.MethodDelete, "/x").Body.String(); got != "del" {
|
||||
t.Errorf("DELETE /x = %q", got)
|
||||
}
|
||||
if rr := do(t, m, http.MethodGet, "/x"); rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("GET /x got %d, want 405", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConflictingPatternsPanic(t *testing.T) {
|
||||
m := New()
|
||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic on duplicate pattern")
|
||||
}
|
||||
}()
|
||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
|
||||
71
pkg/router/pattern.go
Normal file
71
pkg/router/pattern.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package router
|
||||
|
||||
import "strings"
|
||||
|
||||
func splitPattern(pattern string) (method, host, path string) {
|
||||
rest := pattern
|
||||
if i := strings.Index(pattern, " "); i >= 0 {
|
||||
method = pattern[:i]
|
||||
rest = strings.TrimLeft(pattern[i+1:], " ")
|
||||
}
|
||||
if strings.HasPrefix(rest, "/") {
|
||||
return method, "", rest
|
||||
}
|
||||
if i := strings.Index(rest, "/"); i >= 0 {
|
||||
return method, rest[:i], rest[i:]
|
||||
}
|
||||
return method, rest, ""
|
||||
}
|
||||
|
||||
func validateGroupPrefix(p string) {
|
||||
if p == "" {
|
||||
return
|
||||
}
|
||||
if strings.ContainsAny(p, " \t") {
|
||||
panic("lightmux: group prefix must not contain whitespace (no method or host allowed): " + p)
|
||||
}
|
||||
if !strings.HasPrefix(p, "/") {
|
||||
panic("lightmux: group prefix must start with '/': " + p)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeGroupPrefix(p string) string {
|
||||
if p == "" {
|
||||
return ""
|
||||
}
|
||||
if len(p) > 1 && strings.HasSuffix(p, "/") {
|
||||
return p[:len(p)-1]
|
||||
}
|
||||
if p == "/" {
|
||||
return ""
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func joinPath(prefix, sub string) string {
|
||||
if prefix == "" {
|
||||
return sub
|
||||
}
|
||||
if sub == "" {
|
||||
return prefix
|
||||
}
|
||||
if sub == "/" {
|
||||
return prefix + "/"
|
||||
}
|
||||
if !strings.HasPrefix(sub, "/") {
|
||||
panic("lightmux: route path must start with '/': " + sub)
|
||||
}
|
||||
return prefix + sub
|
||||
}
|
||||
|
||||
func buildPattern(method, prefix, pattern string) string {
|
||||
m, host, path := splitPattern(pattern)
|
||||
if method != "" {
|
||||
m = method
|
||||
}
|
||||
full := host + joinPath(prefix, path)
|
||||
if m != "" {
|
||||
return m + " " + full
|
||||
}
|
||||
return full
|
||||
}
|
||||
118
pkg/router/pattern_test.go
Normal file
118
pkg/router/pattern_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package router
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSplitPattern(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
method, host, path string
|
||||
}{
|
||||
{"/foo", "", "", "/foo"},
|
||||
{"GET /foo", "GET", "", "/foo"},
|
||||
{"POST /users/{id}", "POST", "", "/users/{id}"},
|
||||
{"GET example.com/foo", "GET", "example.com", "/foo"},
|
||||
{"example.com/foo", "", "example.com", "/foo"},
|
||||
{"GET /", "GET", "", "/"},
|
||||
{"/", "", "", "/"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.in, func(t *testing.T) {
|
||||
m, h, p := splitPattern(c.in)
|
||||
if m != c.method || h != c.host || p != c.path {
|
||||
t.Fatalf("splitPattern(%q) = (%q, %q, %q), want (%q, %q, %q)",
|
||||
c.in, m, h, p, c.method, c.host, c.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGroupPrefix(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "",
|
||||
"/": "",
|
||||
"/api": "/api",
|
||||
"/api/": "/api",
|
||||
"/api/v1": "/api/v1",
|
||||
"/api/v1/": "/api/v1",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := normalizeGroupPrefix(in); got != want {
|
||||
t.Errorf("normalizeGroupPrefix(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinPath(t *testing.T) {
|
||||
cases := []struct {
|
||||
prefix, sub, want string
|
||||
}{
|
||||
{"", "/foo", "/foo"},
|
||||
{"/api", "/foo", "/api/foo"},
|
||||
{"/api", "", "/api"},
|
||||
{"/api", "/", "/api/"},
|
||||
{"", "", ""},
|
||||
{"/api/v1", "/users/{id}", "/api/v1/users/{id}"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := joinPath(c.prefix, c.sub)
|
||||
if got != c.want {
|
||||
t.Errorf("joinPath(%q, %q) = %q, want %q", c.prefix, c.sub, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinPathPanicsOnBadSub(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic on sub without leading /")
|
||||
}
|
||||
}()
|
||||
joinPath("/api", "foo")
|
||||
}
|
||||
|
||||
func TestValidateGroupPrefix(t *testing.T) {
|
||||
good := []string{"", "/api", "/api/v1"}
|
||||
for _, p := range good {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("validateGroupPrefix(%q) unexpected panic: %v", p, r)
|
||||
}
|
||||
}()
|
||||
validateGroupPrefix(p)
|
||||
}()
|
||||
}
|
||||
bad := []string{"api", "GET /api", "/api with space", "host.com/api"}
|
||||
for _, p := range bad {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("validateGroupPrefix(%q) expected panic", p)
|
||||
}
|
||||
}()
|
||||
validateGroupPrefix(p)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPattern(t *testing.T) {
|
||||
cases := []struct {
|
||||
method, prefix, pattern, want string
|
||||
}{
|
||||
{"", "", "/foo", "/foo"},
|
||||
{"GET", "", "/foo", "GET /foo"},
|
||||
{"GET", "/api", "/foo", "GET /api/foo"},
|
||||
{"", "/api", "GET /foo", "GET /api/foo"},
|
||||
{"", "/api", "GET example.com/foo", "GET example.com/api/foo"},
|
||||
{"", "/api", "/", "/api/"},
|
||||
{"GET", "/api", "/", "GET /api/"},
|
||||
{"GET", "/api", "", "GET /api"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := buildPattern(c.method, c.prefix, c.pattern)
|
||||
if got != c.want {
|
||||
t.Errorf("buildPattern(%q, %q, %q) = %q, want %q",
|
||||
c.method, c.prefix, c.pattern, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue