diff --git a/.gitignore b/.gitignore index 5b90e79..0666eb8 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ go.work.sum # env file .env - +.env.* +!.env.example diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..89710eb --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,43 @@ +version: '3' + +tasks: + default: + desc: Run vet and tests + cmds: + - task: check + + install:tools: + desc: Install development tools (tparse for prettier test output) + cmds: + - go install github.com/mfridman/tparse@latest + + test: + desc: Run tests with prettier output via tparse + cmds: + - set -o pipefail && go test ./... -json -cover | tparse -all + + test:race: + desc: Run tests under the race detector + cmds: + - set -o pipefail && go test ./... -race -json | tparse -all + + vet: + desc: Run go vet + cmds: + - go vet ./... + + fmt: + desc: Format Go source files + cmds: + - gofmt -w . + + tidy: + desc: Tidy go.mod + cmds: + - go mod tidy + + check: + desc: Run vet and tests + cmds: + - task: vet + - task: test diff --git a/examples/basic/main.go b/examples/basic/main.go new file mode 100644 index 0000000..4991fec --- /dev/null +++ b/examples/basic/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "log" + "net/http" + "strconv" + + "git.juancwu.dev/juancwu/lightmux" + "git.juancwu.dev/juancwu/lightmux/pkg/middleware" +) + +func main() { + mux := lightmux.New() + mux.Use(middleware.Recoverer, middleware.Logger) + + mux.Get("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello from lightmux") + }) + + mux.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) { + id, err := strconv.Atoi(r.PathValue("id")) + if err != nil { + http.Error(w, "bad id", http.StatusBadRequest) + return + } + fmt.Fprintf(w, "user %d\n", id) + }) + + mux.Get("/panic", func(w http.ResponseWriter, r *http.Request) { + panic("demonstrating Recoverer") + }) + + api := mux.Group("/api") + api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "pong") + }) + + v1 := api.Group("/v1") + v1.Get("/items/{name}", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "item: %s\n", r.PathValue("name")) + }) + + log.Println("listening on :8080") + log.Fatal(http.ListenAndServe(":8080", mux)) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..329bb20 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.juancwu.dev/juancwu/lightmux + +go 1.26.2 diff --git a/lightmux.go b/lightmux.go new file mode 100644 index 0000000..18fe1ed --- /dev/null +++ b/lightmux.go @@ -0,0 +1,15 @@ +// Package lightmux is a small wrapper around the Go 1.22+ net/http ServeMux +// adding method-named convenience methods, groups, and per-route middleware. +package lightmux + +import ( + "git.juancwu.dev/juancwu/lightmux/pkg/middleware" + "git.juancwu.dev/juancwu/lightmux/pkg/router" +) + +type ( + Mux = router.Mux + Middleware = middleware.Middleware +) + +func New() *Mux { return router.New() } diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go new file mode 100644 index 0000000..a535b43 --- /dev/null +++ b/pkg/middleware/logger.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "log" + "net/http" + "time" +) + +type statusRecorder struct { + http.ResponseWriter + status int + wrote bool +} + +func (s *statusRecorder) WriteHeader(code int) { + if !s.wrote { + s.status = code + s.wrote = true + } + s.ResponseWriter.WriteHeader(code) +} + +func (s *statusRecorder) Write(b []byte) (int, error) { + if !s.wrote { + s.status = http.StatusOK + s.wrote = true + } + return s.ResponseWriter.Write(b) +} + +func Logger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rec, r) + log.Printf("%s %s %d %s", r.Method, r.URL.Path, rec.status, time.Since(start)) + }) +} diff --git a/pkg/middleware/logger_test.go b/pkg/middleware/logger_test.go new file mode 100644 index 0000000..3c886bf --- /dev/null +++ b/pkg/middleware/logger_test.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "bytes" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestLogger(t *testing.T) { + var buf bytes.Buffer + orig := log.Default().Writer() + log.Default().SetOutput(&buf) + defer log.Default().SetOutput(orig) + + h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + })) + + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil)) + + if rr.Code != http.StatusTeapot { + t.Errorf("status code = %d, want 418", rr.Code) + } + out := buf.String() + if !strings.Contains(out, "GET /foo 418") { + t.Errorf("log output missing expected fields: %q", out) + } +} + +func TestLoggerDefaultStatusOK(t *testing.T) { + var buf bytes.Buffer + orig := log.Default().Writer() + log.Default().SetOutput(&buf) + defer log.Default().SetOutput(orig) + + h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + })) + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) + + if !strings.Contains(buf.String(), "200") { + t.Errorf("expected default 200 in log, got %q", buf.String()) + } +} diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go new file mode 100644 index 0000000..f54cea9 --- /dev/null +++ b/pkg/middleware/middleware.go @@ -0,0 +1,5 @@ +package middleware + +import "net/http" + +type Middleware = func(http.Handler) http.Handler diff --git a/pkg/middleware/recoverer.go b/pkg/middleware/recoverer.go new file mode 100644 index 0000000..b7a8258 --- /dev/null +++ b/pkg/middleware/recoverer.go @@ -0,0 +1,19 @@ +package middleware + +import ( + "log" + "net/http" + "runtime/debug" +) + +func Recoverer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + log.Printf("panic: %v\n%s", rec, debug.Stack()) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} diff --git a/pkg/middleware/recoverer_test.go b/pkg/middleware/recoverer_test.go new file mode 100644 index 0000000..a5a07a9 --- /dev/null +++ b/pkg/middleware/recoverer_test.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "bytes" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRecovererCatchesPanic(t *testing.T) { + var buf bytes.Buffer + orig := log.Default().Writer() + log.Default().SetOutput(&buf) + defer log.Default().SetOutput(orig) + + h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("boom") + })) + + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", rr.Code) + } + if !strings.Contains(buf.String(), "panic: boom") { + t.Errorf("expected panic log, got %q", buf.String()) + } +} + +func TestRecovererPassesThrough(t *testing.T) { + called := false + h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + if !called || rr.Code != http.StatusOK { + t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code) + } +} diff --git a/pkg/router/chain.go b/pkg/router/chain.go new file mode 100644 index 0000000..54eedea --- /dev/null +++ b/pkg/router/chain.go @@ -0,0 +1,19 @@ +package router + +import ( + "net/http" + + "git.juancwu.dev/juancwu/lightmux/pkg/middleware" +) + +// chain wraps h with groupMws followed by routeMws so that groupMws[0] is the +// outermost layer (runs first on request, last on response). +func chain(h http.Handler, groupMws, routeMws []middleware.Middleware) http.Handler { + all := make([]middleware.Middleware, 0, len(groupMws)+len(routeMws)) + all = append(all, groupMws...) + all = append(all, routeMws...) + for i := len(all) - 1; i >= 0; i-- { + h = all[i](h) + } + return h +} diff --git a/pkg/router/chain_test.go b/pkg/router/chain_test.go new file mode 100644 index 0000000..3ee5001 --- /dev/null +++ b/pkg/router/chain_test.go @@ -0,0 +1,50 @@ +package router + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.juancwu.dev/juancwu/lightmux/pkg/middleware" +) + +func tagMW(log *[]string, tag string) middleware.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + *log = append(*log, tag+":before") + next.ServeHTTP(w, r) + *log = append(*log, tag+":after") + }) + } +} + +func TestChainOrder(t *testing.T) { + var log []string + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log = append(log, "handler") + }) + + wrapped := chain(h, + []middleware.Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")}, + []middleware.Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")}, + ) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + wrapped.ServeHTTP(httptest.NewRecorder(), req) + + want := "g1:before,g2:before,r1:before,r2:before,handler,r2:after,r1:after,g2:after,g1:after" + if got := strings.Join(log, ","); got != want { + t.Errorf("order:\n got %s\nwant %s", got, want) + } +} + +func TestChainNoMiddlewares(t *testing.T) { + called := false + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true }) + wrapped := chain(h, nil, nil) + wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) + if !called { + t.Fatal("handler not called") + } +} diff --git a/pkg/router/mux.go b/pkg/router/mux.go new file mode 100644 index 0000000..172dbe4 --- /dev/null +++ b/pkg/router/mux.go @@ -0,0 +1,87 @@ +// Package router provides a thin, idiomatic wrapper around the Go 1.22+ +// net/http ServeMux. It adds method-named convenience methods (Get, Post, +// ...), per-route and group middleware, and Group sub-routers that share the +// underlying mux while carrying their own prefix and middleware stack. +package router + +import ( + "net/http" + + "git.juancwu.dev/juancwu/lightmux/pkg/middleware" +) + +type Mux struct { + root *http.ServeMux + prefix string + middlewares []middleware.Middleware +} + +func New() *Mux { + sm := http.NewServeMux() + return &Mux{root: sm} +} + +func (m *Mux) Use(mws ...middleware.Middleware) { + m.middlewares = append(m.middlewares, mws...) +} + +// Group returns a child Mux that registers on the same underlying ServeMux but +// with its prefix appended and the parent's current middlewares snapshotted. +// Use() calls made on the parent after Group() do not propagate to the child. +func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux { + validateGroupPrefix(prefix) + mwsCopy := make([]middleware.Middleware, 0, len(m.middlewares)+len(mws)) + mwsCopy = append(mwsCopy, m.middlewares...) + mwsCopy = append(mwsCopy, mws...) + return &Mux{ + root: m.root, + prefix: m.prefix + normalizeGroupPrefix(prefix), + middlewares: mwsCopy, + } +} + +func (m *Mux) Handle(pattern string, h http.Handler, mws ...middleware.Middleware) { + full := buildPattern("", m.prefix, pattern) + m.root.Handle(full, chain(h, m.middlewares, mws)) +} + +func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.Handle(pattern, fn, mws...) +} + +func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []middleware.Middleware) { + full := buildPattern(method, m.prefix, path) + m.root.Handle(full, chain(fn, m.middlewares, mws)) +} + +func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.method(http.MethodGet, path, fn, mws) +} + +func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.method(http.MethodPost, path, fn, mws) +} + +func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.method(http.MethodPut, path, fn, mws) +} + +func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.method(http.MethodPatch, path, fn, mws) +} + +func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.method(http.MethodDelete, path, fn, mws) +} + +func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.method(http.MethodOptions, path, fn, mws) +} + +func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...middleware.Middleware) { + m.method(http.MethodHead, path, fn, mws) +} + +func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + m.root.ServeHTTP(w, r) +} diff --git a/pkg/router/mux_test.go b/pkg/router/mux_test.go new file mode 100644 index 0000000..1af57d1 --- /dev/null +++ b/pkg/router/mux_test.go @@ -0,0 +1,169 @@ +package router + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func do(t *testing.T, m *Mux, method, target string) *httptest.ResponseRecorder { + t.Helper() + rr := httptest.NewRecorder() + m.ServeHTTP(rr, httptest.NewRequest(method, target, nil)) + return rr +} + +func TestMethodRouting(t *testing.T) { + m := New() + m.Get("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "get") }) + m.Post("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "post") }) + + if got := do(t, m, http.MethodGet, "/x").Body.String(); got != "get" { + t.Errorf("GET /x = %q", got) + } + if got := do(t, m, http.MethodPost, "/x").Body.String(); got != "post" { + t.Errorf("POST /x = %q", got) + } + if rr := do(t, m, http.MethodPut, "/x"); rr.Code != http.StatusMethodNotAllowed { + t.Errorf("PUT /x got %d, want 405", rr.Code) + } +} + +func TestPathValueAcrossMiddleware(t *testing.T) { + m := New() + m.Use(func(next http.Handler) http.Handler { return next }) + var got string + m.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) { + got = r.PathValue("id") + }) + do(t, m, http.MethodGet, "/users/42") + if got != "42" { + t.Errorf("PathValue = %q, want 42", got) + } +} + +func TestGroupPrefix(t *testing.T) { + m := New() + api := m.Group("/api") + api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") }) + + if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" { + t.Errorf("/api/ping = %q", got) + } + if rr := do(t, m, http.MethodGet, "/ping"); rr.Code != http.StatusNotFound { + t.Errorf("unprefixed /ping got %d, want 404", rr.Code) + } +} + +func TestGroupNested(t *testing.T) { + m := New() + api := m.Group("/api") + v1 := api.Group("/v1") + v1.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") }) + + if got := do(t, m, http.MethodGet, "/api/v1/ping").Body.String(); got != "pong" { + t.Errorf("/api/v1/ping = %q", got) + } +} + +func TestGroupTrailingSlashNormalized(t *testing.T) { + m := New() + api := m.Group("/api/") + api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") }) + + if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" { + t.Errorf("/api/ping = %q", got) + } +} + +func TestGroupRootSubtree(t *testing.T) { + m := New() + api := m.Group("/api") + api.Get("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "subtree") }) + + if got := do(t, m, http.MethodGet, "/api/").Body.String(); got != "subtree" { + t.Errorf("/api/ = %q", got) + } + if got := do(t, m, http.MethodGet, "/api/anything").Body.String(); got != "subtree" { + t.Errorf("/api/anything = %q", got) + } +} + +func TestUseAfterGroupDoesNotPropagate(t *testing.T) { + m := New() + api := m.Group("/api") + hits := 0 + m.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits++ + next.ServeHTTP(w, r) + }) + }) + api.Get("/x", func(w http.ResponseWriter, r *http.Request) {}) + m.Get("/y", func(w http.ResponseWriter, r *http.Request) {}) + + do(t, m, http.MethodGet, "/api/x") + if hits != 0 { + t.Errorf("late Use propagated to api group, hits=%d", hits) + } + do(t, m, http.MethodGet, "/y") + if hits != 1 { + t.Errorf("late Use did not apply to root, hits=%d", hits) + } +} + +func TestPerRouteAndGroupMiddlewareStack(t *testing.T) { + m := New() + var log []string + m.Use(tagMW(&log, "use")) + api := m.Group("/api", tagMW(&log, "group")) + api.Get("/x", + func(w http.ResponseWriter, r *http.Request) { log = append(log, "h") }, + tagMW(&log, "route"), + ) + + do(t, m, http.MethodGet, "/api/x") + want := "use:before,group:before,route:before,h,route:after,group:after,use:after" + if got := strings.Join(log, ","); got != want { + t.Errorf("\n got %s\nwant %s", got, want) + } +} + +func TestHandleAcceptsBarePath(t *testing.T) { + m := New() + m.Handle("/x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "any") + })) + for _, method := range []string{http.MethodGet, http.MethodPost, http.MethodDelete} { + if got := do(t, m, method, "/x").Body.String(); got != "any" { + t.Errorf("%s /x = %q", method, got) + } + } +} + +func TestHandleAcceptsMethodPattern(t *testing.T) { + m := New() + m.Handle("DELETE /x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "del") + })) + if got := do(t, m, http.MethodDelete, "/x").Body.String(); got != "del" { + t.Errorf("DELETE /x = %q", got) + } + if rr := do(t, m, http.MethodGet, "/x"); rr.Code != http.StatusMethodNotAllowed { + t.Errorf("GET /x got %d, want 405", rr.Code) + } +} + +func TestConflictingPatternsPanic(t *testing.T) { + m := New() + m.Get("/x", func(w http.ResponseWriter, r *http.Request) {}) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on duplicate pattern") + } + }() + m.Get("/x", func(w http.ResponseWriter, r *http.Request) {}) +} + diff --git a/pkg/router/pattern.go b/pkg/router/pattern.go new file mode 100644 index 0000000..3d79c19 --- /dev/null +++ b/pkg/router/pattern.go @@ -0,0 +1,71 @@ +package router + +import "strings" + +func splitPattern(pattern string) (method, host, path string) { + rest := pattern + if i := strings.Index(pattern, " "); i >= 0 { + method = pattern[:i] + rest = strings.TrimLeft(pattern[i+1:], " ") + } + if strings.HasPrefix(rest, "/") { + return method, "", rest + } + if i := strings.Index(rest, "/"); i >= 0 { + return method, rest[:i], rest[i:] + } + return method, rest, "" +} + +func validateGroupPrefix(p string) { + if p == "" { + return + } + if strings.ContainsAny(p, " \t") { + panic("lightmux: group prefix must not contain whitespace (no method or host allowed): " + p) + } + if !strings.HasPrefix(p, "/") { + panic("lightmux: group prefix must start with '/': " + p) + } +} + +func normalizeGroupPrefix(p string) string { + if p == "" { + return "" + } + if len(p) > 1 && strings.HasSuffix(p, "/") { + return p[:len(p)-1] + } + if p == "/" { + return "" + } + return p +} + +func joinPath(prefix, sub string) string { + if prefix == "" { + return sub + } + if sub == "" { + return prefix + } + if sub == "/" { + return prefix + "/" + } + if !strings.HasPrefix(sub, "/") { + panic("lightmux: route path must start with '/': " + sub) + } + return prefix + sub +} + +func buildPattern(method, prefix, pattern string) string { + m, host, path := splitPattern(pattern) + if method != "" { + m = method + } + full := host + joinPath(prefix, path) + if m != "" { + return m + " " + full + } + return full +} diff --git a/pkg/router/pattern_test.go b/pkg/router/pattern_test.go new file mode 100644 index 0000000..1ef240a --- /dev/null +++ b/pkg/router/pattern_test.go @@ -0,0 +1,118 @@ +package router + +import "testing" + +func TestSplitPattern(t *testing.T) { + cases := []struct { + in string + method, host, path string + }{ + {"/foo", "", "", "/foo"}, + {"GET /foo", "GET", "", "/foo"}, + {"POST /users/{id}", "POST", "", "/users/{id}"}, + {"GET example.com/foo", "GET", "example.com", "/foo"}, + {"example.com/foo", "", "example.com", "/foo"}, + {"GET /", "GET", "", "/"}, + {"/", "", "", "/"}, + } + for _, c := range cases { + t.Run(c.in, func(t *testing.T) { + m, h, p := splitPattern(c.in) + if m != c.method || h != c.host || p != c.path { + t.Fatalf("splitPattern(%q) = (%q, %q, %q), want (%q, %q, %q)", + c.in, m, h, p, c.method, c.host, c.path) + } + }) + } +} + +func TestNormalizeGroupPrefix(t *testing.T) { + cases := map[string]string{ + "": "", + "/": "", + "/api": "/api", + "/api/": "/api", + "/api/v1": "/api/v1", + "/api/v1/": "/api/v1", + } + for in, want := range cases { + if got := normalizeGroupPrefix(in); got != want { + t.Errorf("normalizeGroupPrefix(%q) = %q, want %q", in, got, want) + } + } +} + +func TestJoinPath(t *testing.T) { + cases := []struct { + prefix, sub, want string + }{ + {"", "/foo", "/foo"}, + {"/api", "/foo", "/api/foo"}, + {"/api", "", "/api"}, + {"/api", "/", "/api/"}, + {"", "", ""}, + {"/api/v1", "/users/{id}", "/api/v1/users/{id}"}, + } + for _, c := range cases { + got := joinPath(c.prefix, c.sub) + if got != c.want { + t.Errorf("joinPath(%q, %q) = %q, want %q", c.prefix, c.sub, got, c.want) + } + } +} + +func TestJoinPathPanicsOnBadSub(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on sub without leading /") + } + }() + joinPath("/api", "foo") +} + +func TestValidateGroupPrefix(t *testing.T) { + good := []string{"", "/api", "/api/v1"} + for _, p := range good { + func() { + defer func() { + if r := recover(); r != nil { + t.Errorf("validateGroupPrefix(%q) unexpected panic: %v", p, r) + } + }() + validateGroupPrefix(p) + }() + } + bad := []string{"api", "GET /api", "/api with space", "host.com/api"} + for _, p := range bad { + func() { + defer func() { + if r := recover(); r == nil { + t.Errorf("validateGroupPrefix(%q) expected panic", p) + } + }() + validateGroupPrefix(p) + }() + } +} + +func TestBuildPattern(t *testing.T) { + cases := []struct { + method, prefix, pattern, want string + }{ + {"", "", "/foo", "/foo"}, + {"GET", "", "/foo", "GET /foo"}, + {"GET", "/api", "/foo", "GET /api/foo"}, + {"", "/api", "GET /foo", "GET /api/foo"}, + {"", "/api", "GET example.com/foo", "GET example.com/api/foo"}, + {"", "/api", "/", "/api/"}, + {"GET", "/api", "/", "GET /api/"}, + {"GET", "/api", "", "GET /api"}, + } + for _, c := range cases { + got := buildPattern(c.method, c.prefix, c.pattern) + if got != c.want { + t.Errorf("buildPattern(%q, %q, %q) = %q, want %q", + c.method, c.prefix, c.pattern, got, c.want) + } + } +}