lightmux/pkg/router/mux_test.go
2026-04-26 13:14:30 +00:00

168 lines
4.7 KiB
Go

package router
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// do sends one body-less request with the given method and target through m
// and returns the recorder that captured the response.
func do(t *testing.T, m *Mux, method, target string) *httptest.ResponseRecorder {
	t.Helper()
	req := httptest.NewRequest(method, target, nil)
	rec := httptest.NewRecorder()
	m.ServeHTTP(rec, req)
	return rec
}
// TestMethodRouting registers distinct GET and POST handlers on the same path
// and checks that each method reaches its own handler, while a method with no
// registration (PUT) is answered with 405 Method Not Allowed.
func TestMethodRouting(t *testing.T) {
	m := New()
	m.Get("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "get") })
	m.Post("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "post") })

	cases := []struct {
		method, body string
	}{
		{http.MethodGet, "get"},
		{http.MethodPost, "post"},
	}
	for _, tc := range cases {
		if body := do(t, m, tc.method, "/x").Body.String(); body != tc.body {
			t.Errorf("%s /x = %q", tc.method, body)
		}
	}

	put := do(t, m, http.MethodPut, "/x")
	if put.Code != http.StatusMethodNotAllowed {
		t.Errorf("PUT /x got %d, want 405", put.Code)
	}
}
// TestPathValueAcrossMiddleware checks that a wildcard captured by the route
// pattern ("{id}") is still readable through r.PathValue inside the handler
// when root middleware is installed with Use.
//
// The middleware genuinely wraps next instead of returning it unchanged. An
// identity middleware adds no layer to the chain and so would not exercise the
// "across middleware" path. The middleware also records that it ran, so the
// test cannot pass if Use were silently ignored.
func TestPathValueAcrossMiddleware(t *testing.T) {
	m := New()
	wrapped := false
	m.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			wrapped = true
			next.ServeHTTP(w, r)
		})
	})
	var got string
	m.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
		got = r.PathValue("id")
	})
	do(t, m, http.MethodGet, "/users/42")
	if !wrapped {
		t.Error("middleware registered with Use did not run")
	}
	if got != "42" {
		t.Errorf("PathValue = %q, want 42", got)
	}
}
// TestGroupPrefix checks that a route registered on a "/api" group is served
// only under the prefixed path and that the bare path returns 404.
func TestGroupPrefix(t *testing.T) {
	m := New()
	grp := m.Group("/api")
	grp.Get("/ping", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "pong")
	})

	prefixed := do(t, m, http.MethodGet, "/api/ping")
	if body := prefixed.Body.String(); body != "pong" {
		t.Errorf("/api/ping = %q", body)
	}

	bare := do(t, m, http.MethodGet, "/ping")
	if bare.Code != http.StatusNotFound {
		t.Errorf("unprefixed /ping got %d, want 404", bare.Code)
	}
}
// TestGroupNested checks that prefixes of nested groups concatenate: a route
// on Group("/api").Group("/v1") is served at "/api/v1/...".
func TestGroupNested(t *testing.T) {
	const path = "/api/v1/ping"

	m := New()
	outer := m.Group("/api")
	inner := outer.Group("/v1")
	pong := func(w http.ResponseWriter, _ *http.Request) { fmt.Fprint(w, "pong") }
	inner.Get("/ping", pong)

	if body := do(t, m, http.MethodGet, path).Body.String(); body != "pong" {
		t.Errorf("%s = %q", path, body)
	}
}
// TestGroupTrailingSlashNormalized registers a route on a group whose prefix
// ends in a slash ("/api/") and expects it to be reachable at "/api/ping",
// i.e. without a doubled slash where prefix and route path are joined.
func TestGroupTrailingSlashNormalized(t *testing.T) {
	m := New()
	slashed := m.Group("/api/")
	slashed.Get("/ping", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "pong")
	})

	resp := do(t, m, http.MethodGet, "/api/ping")
	if body := resp.Body.String(); body != "pong" {
		t.Errorf("/api/ping = %q", body)
	}
}
// TestGroupRootSubtree checks that registering "/" on the "/api" group yields
// a subtree route: both "/api/" itself and deeper paths under it reach the
// handler.
func TestGroupRootSubtree(t *testing.T) {
	m := New()
	grp := m.Group("/api")
	grp.Get("/", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "subtree")
	})

	for _, path := range []string{"/api/", "/api/anything"} {
		if body := do(t, m, http.MethodGet, path).Body.String(); body != "subtree" {
			t.Errorf("%s = %q", path, body)
		}
	}
}
// TestUseAfterGroupDoesNotPropagate asserts that middleware added to the root
// with Use after a group was created does not run for that group's routes, but
// does run for root routes registered afterwards.
//
// Each request's status is checked as well. Without it, a route that failed to
// register (404) would leave hits at 0 and make the "did not propagate"
// assertion pass vacuously. The handlers write nothing, so a matched route
// yields the recorder's default 200.
func TestUseAfterGroupDoesNotPropagate(t *testing.T) {
	m := New()
	api := m.Group("/api")
	hits := 0
	m.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			hits++
			next.ServeHTTP(w, r)
		})
	})
	api.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
	m.Get("/y", func(w http.ResponseWriter, r *http.Request) {})

	if rr := do(t, m, http.MethodGet, "/api/x"); rr.Code != http.StatusOK {
		t.Errorf("GET /api/x got %d, want 200", rr.Code)
	}
	if hits != 0 {
		t.Errorf("late Use propagated to api group, hits=%d", hits)
	}

	if rr := do(t, m, http.MethodGet, "/y"); rr.Code != http.StatusOK {
		t.Errorf("GET /y got %d, want 200", rr.Code)
	}
	if hits != 1 {
		t.Errorf("late Use did not apply to root, hits=%d", hits)
	}
}
// TestPerRouteAndGroupMiddlewareStack checks nesting order when root (Use),
// group, and per-route middleware are all present. The recorded trace must
// show root outermost, then group, then route, around the handler, and unwind
// in reverse. tagMW is a helper defined elsewhere in this file that appends
// "<tag>:before"/"<tag>:after" entries to the shared trace.
func TestPerRouteAndGroupMiddlewareStack(t *testing.T) {
	var trace []string
	m := New()
	m.Use(tagMW(&trace, "use"))
	api := m.Group("/api", tagMW(&trace, "group"))
	handler := func(w http.ResponseWriter, r *http.Request) { trace = append(trace, "h") }
	api.Get("/x", handler, tagMW(&trace, "route"))

	do(t, m, http.MethodGet, "/api/x")

	want := strings.Join([]string{
		"use:before", "group:before", "route:before",
		"h",
		"route:after", "group:after", "use:after",
	}, ",")
	got := strings.Join(trace, ",")
	if got != want {
		t.Errorf("\n got %s\nwant %s", got, want)
	}
}
// TestHandleAcceptsBarePath checks that Handle with a pattern carrying no
// method prefix serves every method (GET, POST and DELETE are sampled).
func TestHandleAcceptsBarePath(t *testing.T) {
	anyMethod := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "any")
	})
	m := New()
	m.Handle("/x", anyMethod)

	methods := []string{http.MethodGet, http.MethodPost, http.MethodDelete}
	for i := range methods {
		verb := methods[i]
		body := do(t, m, verb, "/x").Body.String()
		if body != "any" {
			t.Errorf("%s /x = %q", verb, body)
		}
	}
}
// TestHandleAcceptsMethodPattern checks that Handle understands a
// "METHOD /path" pattern: the named method is served, and another method on
// the same path gets 405 Method Not Allowed.
func TestHandleAcceptsMethodPattern(t *testing.T) {
	del := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "del")
	})
	m := New()
	m.Handle("DELETE /x", del)

	if body := do(t, m, http.MethodDelete, "/x").Body.String(); body != "del" {
		t.Errorf("DELETE /x = %q", body)
	}

	rec := do(t, m, http.MethodGet, "/x")
	if rec.Code != http.StatusMethodNotAllowed {
		t.Errorf("GET /x got %d, want 405", rec.Code)
	}
}
// TestConflictingPatternsPanic checks that registering the same method+path
// twice panics on the second registration. The first registration is outside
// the recover so an unexpected panic there still fails the test loudly.
func TestConflictingPatternsPanic(t *testing.T) {
	m := New()
	m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})

	// registerAgain reports whether the duplicate registration panicked.
	registerAgain := func() (panicked bool) {
		defer func() { panicked = recover() != nil }()
		m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
		return false
	}
	if !registerAgain() {
		t.Fatal("expected panic on duplicate pattern")
	}
}