initial implementation
This commit is contained in:
parent
f8b0abc517
commit
cb373e637b
16 changed files with 777 additions and 1 deletions
19
pkg/router/chain.go
Normal file
19
pkg/router/chain.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
)
|
||||
|
||||
// chain wraps h with groupMws followed by routeMws so that groupMws[0] is the
|
||||
// outermost layer (runs first on request, last on response).
|
||||
func chain(h http.Handler, groupMws, routeMws []middleware.Middleware) http.Handler {
|
||||
all := make([]middleware.Middleware, 0, len(groupMws)+len(routeMws))
|
||||
all = append(all, groupMws...)
|
||||
all = append(all, routeMws...)
|
||||
for i := len(all) - 1; i >= 0; i-- {
|
||||
h = all[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
50
pkg/router/chain_test.go
Normal file
50
pkg/router/chain_test.go
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
)
|
||||
|
||||
func tagMW(log *[]string, tag string) middleware.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
*log = append(*log, tag+":before")
|
||||
next.ServeHTTP(w, r)
|
||||
*log = append(*log, tag+":after")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainOrder(t *testing.T) {
|
||||
var log []string
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log = append(log, "handler")
|
||||
})
|
||||
|
||||
wrapped := chain(h,
|
||||
[]middleware.Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")},
|
||||
[]middleware.Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")},
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
wrapped.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
want := "g1:before,g2:before,r1:before,r2:before,handler,r2:after,r1:after,g2:after,g1:after"
|
||||
if got := strings.Join(log, ","); got != want {
|
||||
t.Errorf("order:\n got %s\nwant %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainNoMiddlewares(t *testing.T) {
|
||||
called := false
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true })
|
||||
wrapped := chain(h, nil, nil)
|
||||
wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if !called {
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
}
|
||||
87
pkg/router/mux.go
Normal file
87
pkg/router/mux.go
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
// Package router provides a thin, idiomatic wrapper around the Go 1.22+
|
||||
// net/http ServeMux. It adds method-named convenience methods (Get, Post,
|
||||
// ...), per-route and group middleware, and Group sub-routers that share the
|
||||
// underlying mux while carrying their own prefix and middleware stack.
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
|
||||
)
|
||||
|
||||
type Mux struct {
|
||||
root *http.ServeMux
|
||||
prefix string
|
||||
middlewares []middleware.Middleware
|
||||
}
|
||||
|
||||
func New() *Mux {
|
||||
sm := http.NewServeMux()
|
||||
return &Mux{root: sm}
|
||||
}
|
||||
|
||||
func (m *Mux) Use(mws ...middleware.Middleware) {
|
||||
m.middlewares = append(m.middlewares, mws...)
|
||||
}
|
||||
|
||||
// Group returns a child Mux that registers on the same underlying ServeMux but
|
||||
// with its prefix appended and the parent's current middlewares snapshotted.
|
||||
// Use() calls made on the parent after Group() do not propagate to the child.
|
||||
func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux {
|
||||
validateGroupPrefix(prefix)
|
||||
mwsCopy := make([]middleware.Middleware, 0, len(m.middlewares)+len(mws))
|
||||
mwsCopy = append(mwsCopy, m.middlewares...)
|
||||
mwsCopy = append(mwsCopy, mws...)
|
||||
return &Mux{
|
||||
root: m.root,
|
||||
prefix: m.prefix + normalizeGroupPrefix(prefix),
|
||||
middlewares: mwsCopy,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) Handle(pattern string, h http.Handler, mws ...middleware.Middleware) {
|
||||
full := buildPattern("", m.prefix, pattern)
|
||||
m.root.Handle(full, chain(h, m.middlewares, mws))
|
||||
}
|
||||
|
||||
func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.Handle(pattern, fn, mws...)
|
||||
}
|
||||
|
||||
func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []middleware.Middleware) {
|
||||
full := buildPattern(method, m.prefix, path)
|
||||
m.root.Handle(full, chain(fn, m.middlewares, mws))
|
||||
}
|
||||
|
||||
func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodGet, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodPost, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodPut, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodPatch, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodDelete, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodOptions, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
|
||||
m.method(http.MethodHead, path, fn, mws)
|
||||
}
|
||||
|
||||
func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.root.ServeHTTP(w, r)
|
||||
}
|
||||
169
pkg/router/mux_test.go
Normal file
169
pkg/router/mux_test.go
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func do(t *testing.T, m *Mux, method, target string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
rr := httptest.NewRecorder()
|
||||
m.ServeHTTP(rr, httptest.NewRequest(method, target, nil))
|
||||
return rr
|
||||
}
|
||||
|
||||
func TestMethodRouting(t *testing.T) {
|
||||
m := New()
|
||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "get") })
|
||||
m.Post("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "post") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/x").Body.String(); got != "get" {
|
||||
t.Errorf("GET /x = %q", got)
|
||||
}
|
||||
if got := do(t, m, http.MethodPost, "/x").Body.String(); got != "post" {
|
||||
t.Errorf("POST /x = %q", got)
|
||||
}
|
||||
if rr := do(t, m, http.MethodPut, "/x"); rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("PUT /x got %d, want 405", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathValueAcrossMiddleware(t *testing.T) {
|
||||
m := New()
|
||||
m.Use(func(next http.Handler) http.Handler { return next })
|
||||
var got string
|
||||
m.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
got = r.PathValue("id")
|
||||
})
|
||||
do(t, m, http.MethodGet, "/users/42")
|
||||
if got != "42" {
|
||||
t.Errorf("PathValue = %q, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupPrefix(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
|
||||
t.Errorf("/api/ping = %q", got)
|
||||
}
|
||||
if rr := do(t, m, http.MethodGet, "/ping"); rr.Code != http.StatusNotFound {
|
||||
t.Errorf("unprefixed /ping got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupNested(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
v1 := api.Group("/v1")
|
||||
v1.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/v1/ping").Body.String(); got != "pong" {
|
||||
t.Errorf("/api/v1/ping = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupTrailingSlashNormalized(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api/")
|
||||
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
|
||||
t.Errorf("/api/ping = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupRootSubtree(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
api.Get("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "subtree") })
|
||||
|
||||
if got := do(t, m, http.MethodGet, "/api/").Body.String(); got != "subtree" {
|
||||
t.Errorf("/api/ = %q", got)
|
||||
}
|
||||
if got := do(t, m, http.MethodGet, "/api/anything").Body.String(); got != "subtree" {
|
||||
t.Errorf("/api/anything = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseAfterGroupDoesNotPropagate(t *testing.T) {
|
||||
m := New()
|
||||
api := m.Group("/api")
|
||||
hits := 0
|
||||
m.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits++
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
api.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||
m.Get("/y", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
do(t, m, http.MethodGet, "/api/x")
|
||||
if hits != 0 {
|
||||
t.Errorf("late Use propagated to api group, hits=%d", hits)
|
||||
}
|
||||
do(t, m, http.MethodGet, "/y")
|
||||
if hits != 1 {
|
||||
t.Errorf("late Use did not apply to root, hits=%d", hits)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerRouteAndGroupMiddlewareStack(t *testing.T) {
|
||||
m := New()
|
||||
var log []string
|
||||
m.Use(tagMW(&log, "use"))
|
||||
api := m.Group("/api", tagMW(&log, "group"))
|
||||
api.Get("/x",
|
||||
func(w http.ResponseWriter, r *http.Request) { log = append(log, "h") },
|
||||
tagMW(&log, "route"),
|
||||
)
|
||||
|
||||
do(t, m, http.MethodGet, "/api/x")
|
||||
want := "use:before,group:before,route:before,h,route:after,group:after,use:after"
|
||||
if got := strings.Join(log, ","); got != want {
|
||||
t.Errorf("\n got %s\nwant %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleAcceptsBarePath(t *testing.T) {
|
||||
m := New()
|
||||
m.Handle("/x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, "any")
|
||||
}))
|
||||
for _, method := range []string{http.MethodGet, http.MethodPost, http.MethodDelete} {
|
||||
if got := do(t, m, method, "/x").Body.String(); got != "any" {
|
||||
t.Errorf("%s /x = %q", method, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleAcceptsMethodPattern(t *testing.T) {
|
||||
m := New()
|
||||
m.Handle("DELETE /x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, "del")
|
||||
}))
|
||||
if got := do(t, m, http.MethodDelete, "/x").Body.String(); got != "del" {
|
||||
t.Errorf("DELETE /x = %q", got)
|
||||
}
|
||||
if rr := do(t, m, http.MethodGet, "/x"); rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("GET /x got %d, want 405", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConflictingPatternsPanic(t *testing.T) {
|
||||
m := New()
|
||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic on duplicate pattern")
|
||||
}
|
||||
}()
|
||||
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
|
||||
71
pkg/router/pattern.go
Normal file
71
pkg/router/pattern.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package router
|
||||
|
||||
import "strings"
|
||||
|
||||
func splitPattern(pattern string) (method, host, path string) {
|
||||
rest := pattern
|
||||
if i := strings.Index(pattern, " "); i >= 0 {
|
||||
method = pattern[:i]
|
||||
rest = strings.TrimLeft(pattern[i+1:], " ")
|
||||
}
|
||||
if strings.HasPrefix(rest, "/") {
|
||||
return method, "", rest
|
||||
}
|
||||
if i := strings.Index(rest, "/"); i >= 0 {
|
||||
return method, rest[:i], rest[i:]
|
||||
}
|
||||
return method, rest, ""
|
||||
}
|
||||
|
||||
func validateGroupPrefix(p string) {
|
||||
if p == "" {
|
||||
return
|
||||
}
|
||||
if strings.ContainsAny(p, " \t") {
|
||||
panic("lightmux: group prefix must not contain whitespace (no method or host allowed): " + p)
|
||||
}
|
||||
if !strings.HasPrefix(p, "/") {
|
||||
panic("lightmux: group prefix must start with '/': " + p)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeGroupPrefix(p string) string {
|
||||
if p == "" {
|
||||
return ""
|
||||
}
|
||||
if len(p) > 1 && strings.HasSuffix(p, "/") {
|
||||
return p[:len(p)-1]
|
||||
}
|
||||
if p == "/" {
|
||||
return ""
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func joinPath(prefix, sub string) string {
|
||||
if prefix == "" {
|
||||
return sub
|
||||
}
|
||||
if sub == "" {
|
||||
return prefix
|
||||
}
|
||||
if sub == "/" {
|
||||
return prefix + "/"
|
||||
}
|
||||
if !strings.HasPrefix(sub, "/") {
|
||||
panic("lightmux: route path must start with '/': " + sub)
|
||||
}
|
||||
return prefix + sub
|
||||
}
|
||||
|
||||
func buildPattern(method, prefix, pattern string) string {
|
||||
m, host, path := splitPattern(pattern)
|
||||
if method != "" {
|
||||
m = method
|
||||
}
|
||||
full := host + joinPath(prefix, path)
|
||||
if m != "" {
|
||||
return m + " " + full
|
||||
}
|
||||
return full
|
||||
}
|
||||
118
pkg/router/pattern_test.go
Normal file
118
pkg/router/pattern_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package router
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSplitPattern(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
method, host, path string
|
||||
}{
|
||||
{"/foo", "", "", "/foo"},
|
||||
{"GET /foo", "GET", "", "/foo"},
|
||||
{"POST /users/{id}", "POST", "", "/users/{id}"},
|
||||
{"GET example.com/foo", "GET", "example.com", "/foo"},
|
||||
{"example.com/foo", "", "example.com", "/foo"},
|
||||
{"GET /", "GET", "", "/"},
|
||||
{"/", "", "", "/"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.in, func(t *testing.T) {
|
||||
m, h, p := splitPattern(c.in)
|
||||
if m != c.method || h != c.host || p != c.path {
|
||||
t.Fatalf("splitPattern(%q) = (%q, %q, %q), want (%q, %q, %q)",
|
||||
c.in, m, h, p, c.method, c.host, c.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGroupPrefix(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "",
|
||||
"/": "",
|
||||
"/api": "/api",
|
||||
"/api/": "/api",
|
||||
"/api/v1": "/api/v1",
|
||||
"/api/v1/": "/api/v1",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := normalizeGroupPrefix(in); got != want {
|
||||
t.Errorf("normalizeGroupPrefix(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinPath(t *testing.T) {
|
||||
cases := []struct {
|
||||
prefix, sub, want string
|
||||
}{
|
||||
{"", "/foo", "/foo"},
|
||||
{"/api", "/foo", "/api/foo"},
|
||||
{"/api", "", "/api"},
|
||||
{"/api", "/", "/api/"},
|
||||
{"", "", ""},
|
||||
{"/api/v1", "/users/{id}", "/api/v1/users/{id}"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := joinPath(c.prefix, c.sub)
|
||||
if got != c.want {
|
||||
t.Errorf("joinPath(%q, %q) = %q, want %q", c.prefix, c.sub, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinPathPanicsOnBadSub(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic on sub without leading /")
|
||||
}
|
||||
}()
|
||||
joinPath("/api", "foo")
|
||||
}
|
||||
|
||||
func TestValidateGroupPrefix(t *testing.T) {
|
||||
good := []string{"", "/api", "/api/v1"}
|
||||
for _, p := range good {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("validateGroupPrefix(%q) unexpected panic: %v", p, r)
|
||||
}
|
||||
}()
|
||||
validateGroupPrefix(p)
|
||||
}()
|
||||
}
|
||||
bad := []string{"api", "GET /api", "/api with space", "host.com/api"}
|
||||
for _, p := range bad {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("validateGroupPrefix(%q) expected panic", p)
|
||||
}
|
||||
}()
|
||||
validateGroupPrefix(p)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPattern(t *testing.T) {
|
||||
cases := []struct {
|
||||
method, prefix, pattern, want string
|
||||
}{
|
||||
{"", "", "/foo", "/foo"},
|
||||
{"GET", "", "/foo", "GET /foo"},
|
||||
{"GET", "/api", "/foo", "GET /api/foo"},
|
||||
{"", "/api", "GET /foo", "GET /api/foo"},
|
||||
{"", "/api", "GET example.com/foo", "GET example.com/api/foo"},
|
||||
{"", "/api", "/", "/api/"},
|
||||
{"GET", "/api", "/", "GET /api/"},
|
||||
{"GET", "/api", "", "GET /api"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := buildPattern(c.method, c.prefix, c.pattern)
|
||||
if got != c.want {
|
||||
t.Errorf("buildPattern(%q, %q, %q) = %q, want %q",
|
||||
c.method, c.prefix, c.pattern, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue