initial implementation

This commit is contained in:
juancwu 2026-04-25 20:26:07 +00:00
commit cb373e637b
16 changed files with 777 additions and 1 deletions

3
.gitignore vendored
View file

@ -24,4 +24,5 @@ go.work.sum
# env file
.env
.env.*
!.env.example

43
Taskfile.yml Normal file
View file

@ -0,0 +1,43 @@
version: '3'
tasks:
default:
desc: Run vet and tests
cmds:
- task: check
install:tools:
desc: Install development tools (tparse for prettier test output)
cmds:
- go install github.com/mfridman/tparse@latest
test:
desc: Run tests with prettier output via tparse
cmds:
- set -o pipefail && go test ./... -json -cover | tparse -all
test:race:
desc: Run tests under the race detector
cmds:
- set -o pipefail && go test ./... -race -json | tparse -all
vet:
desc: Run go vet
cmds:
- go vet ./...
fmt:
desc: Format Go source files
cmds:
- gofmt -w .
tidy:
desc: Tidy go.mod
cmds:
- go mod tidy
check:
desc: Run vet and tests
cmds:
- task: vet
- task: test

46
examples/basic/main.go Normal file
View file

@ -0,0 +1,46 @@
package main
import (
"fmt"
"log"
"net/http"
"strconv"
"git.juancwu.dev/juancwu/lightmux"
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
)
func main() {
mux := lightmux.New()
mux.Use(middleware.Recoverer, middleware.Logger)
mux.Get("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello from lightmux")
})
mux.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
id, err := strconv.Atoi(r.PathValue("id"))
if err != nil {
http.Error(w, "bad id", http.StatusBadRequest)
return
}
fmt.Fprintf(w, "user %d\n", id)
})
mux.Get("/panic", func(w http.ResponseWriter, r *http.Request) {
panic("demonstrating Recoverer")
})
api := mux.Group("/api")
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "pong")
})
v1 := api.Group("/v1")
v1.Get("/items/{name}", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "item: %s\n", r.PathValue("name"))
})
log.Println("listening on :8080")
log.Fatal(http.ListenAndServe(":8080", mux))
}

3
go.mod Normal file
View file

@ -0,0 +1,3 @@
module git.juancwu.dev/juancwu/lightmux
go 1.26.2

15
lightmux.go Normal file
View file

@ -0,0 +1,15 @@
// Package lightmux is a small wrapper around the Go 1.22+ net/http ServeMux
// adding method-named convenience methods, groups, and per-route middleware.
package lightmux
import (
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
"git.juancwu.dev/juancwu/lightmux/pkg/router"
)
type (
Mux = router.Mux
Middleware = middleware.Middleware
)
func New() *Mux { return router.New() }

38
pkg/middleware/logger.go Normal file
View file

@ -0,0 +1,38 @@
package middleware
import (
"log"
"net/http"
"time"
)
type statusRecorder struct {
http.ResponseWriter
status int
wrote bool
}
func (s *statusRecorder) WriteHeader(code int) {
if !s.wrote {
s.status = code
s.wrote = true
}
s.ResponseWriter.WriteHeader(code)
}
func (s *statusRecorder) Write(b []byte) (int, error) {
if !s.wrote {
s.status = http.StatusOK
s.wrote = true
}
return s.ResponseWriter.Write(b)
}
func Logger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(rec, r)
log.Printf("%s %s %d %s", r.Method, r.URL.Path, rec.status, time.Since(start))
})
}

View file

@ -0,0 +1,48 @@
package middleware
import (
"bytes"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestLogger(t *testing.T) {
var buf bytes.Buffer
orig := log.Default().Writer()
log.Default().SetOutput(&buf)
defer log.Default().SetOutput(orig)
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
}))
rr := httptest.NewRecorder()
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil))
if rr.Code != http.StatusTeapot {
t.Errorf("status code = %d, want 418", rr.Code)
}
out := buf.String()
if !strings.Contains(out, "GET /foo 418") {
t.Errorf("log output missing expected fields: %q", out)
}
}
func TestLoggerDefaultStatusOK(t *testing.T) {
var buf bytes.Buffer
orig := log.Default().Writer()
log.Default().SetOutput(&buf)
defer log.Default().SetOutput(orig)
h := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hi"))
}))
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
if !strings.Contains(buf.String(), "200") {
t.Errorf("expected default 200 in log, got %q", buf.String())
}
}

View file

@ -0,0 +1,5 @@
package middleware
import "net/http"
type Middleware = func(http.Handler) http.Handler

View file

@ -0,0 +1,19 @@
package middleware
import (
"log"
"net/http"
"runtime/debug"
)
func Recoverer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
log.Printf("panic: %v\n%s", rec, debug.Stack())
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,44 @@
package middleware
import (
"bytes"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestRecovererCatchesPanic(t *testing.T) {
var buf bytes.Buffer
orig := log.Default().Writer()
log.Default().SetOutput(&buf)
defer log.Default().SetOutput(orig)
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("boom")
}))
rr := httptest.NewRecorder()
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
if rr.Code != http.StatusInternalServerError {
t.Errorf("status = %d, want 500", rr.Code)
}
if !strings.Contains(buf.String(), "panic: boom") {
t.Errorf("expected panic log, got %q", buf.String())
}
}
func TestRecovererPassesThrough(t *testing.T) {
called := false
h := Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
rr := httptest.NewRecorder()
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
if !called || rr.Code != http.StatusOK {
t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code)
}
}

19
pkg/router/chain.go Normal file
View file

@ -0,0 +1,19 @@
package router
import (
"net/http"
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
)
// chain wraps h with groupMws followed by routeMws so that groupMws[0] is the
// outermost layer (runs first on request, last on response).
func chain(h http.Handler, groupMws, routeMws []middleware.Middleware) http.Handler {
all := make([]middleware.Middleware, 0, len(groupMws)+len(routeMws))
all = append(all, groupMws...)
all = append(all, routeMws...)
for i := len(all) - 1; i >= 0; i-- {
h = all[i](h)
}
return h
}

50
pkg/router/chain_test.go Normal file
View file

@ -0,0 +1,50 @@
package router
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
)
func tagMW(log *[]string, tag string) middleware.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
*log = append(*log, tag+":before")
next.ServeHTTP(w, r)
*log = append(*log, tag+":after")
})
}
}
func TestChainOrder(t *testing.T) {
var log []string
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log = append(log, "handler")
})
wrapped := chain(h,
[]middleware.Middleware{tagMW(&log, "g1"), tagMW(&log, "g2")},
[]middleware.Middleware{tagMW(&log, "r1"), tagMW(&log, "r2")},
)
req := httptest.NewRequest(http.MethodGet, "/", nil)
wrapped.ServeHTTP(httptest.NewRecorder(), req)
want := "g1:before,g2:before,r1:before,r2:before,handler,r2:after,r1:after,g2:after,g1:after"
if got := strings.Join(log, ","); got != want {
t.Errorf("order:\n got %s\nwant %s", got, want)
}
}
func TestChainNoMiddlewares(t *testing.T) {
called := false
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true })
wrapped := chain(h, nil, nil)
wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
if !called {
t.Fatal("handler not called")
}
}

87
pkg/router/mux.go Normal file
View file

@ -0,0 +1,87 @@
// Package router provides a thin, idiomatic wrapper around the Go 1.22+
// net/http ServeMux. It adds method-named convenience methods (Get, Post,
// ...), per-route and group middleware, and Group sub-routers that share the
// underlying mux while carrying their own prefix and middleware stack.
package router
import (
"net/http"
"git.juancwu.dev/juancwu/lightmux/pkg/middleware"
)
type Mux struct {
root *http.ServeMux
prefix string
middlewares []middleware.Middleware
}
func New() *Mux {
sm := http.NewServeMux()
return &Mux{root: sm}
}
func (m *Mux) Use(mws ...middleware.Middleware) {
m.middlewares = append(m.middlewares, mws...)
}
// Group returns a child Mux that registers on the same underlying ServeMux but
// with its prefix appended and the parent's current middlewares snapshotted.
// Use() calls made on the parent after Group() do not propagate to the child.
func (m *Mux) Group(prefix string, mws ...middleware.Middleware) *Mux {
validateGroupPrefix(prefix)
mwsCopy := make([]middleware.Middleware, 0, len(m.middlewares)+len(mws))
mwsCopy = append(mwsCopy, m.middlewares...)
mwsCopy = append(mwsCopy, mws...)
return &Mux{
root: m.root,
prefix: m.prefix + normalizeGroupPrefix(prefix),
middlewares: mwsCopy,
}
}
func (m *Mux) Handle(pattern string, h http.Handler, mws ...middleware.Middleware) {
full := buildPattern("", m.prefix, pattern)
m.root.Handle(full, chain(h, m.middlewares, mws))
}
func (m *Mux) HandleFunc(pattern string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.Handle(pattern, fn, mws...)
}
func (m *Mux) method(method, path string, fn http.HandlerFunc, mws []middleware.Middleware) {
full := buildPattern(method, m.prefix, path)
m.root.Handle(full, chain(fn, m.middlewares, mws))
}
func (m *Mux) Get(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.method(http.MethodGet, path, fn, mws)
}
func (m *Mux) Post(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.method(http.MethodPost, path, fn, mws)
}
func (m *Mux) Put(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.method(http.MethodPut, path, fn, mws)
}
func (m *Mux) Patch(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.method(http.MethodPatch, path, fn, mws)
}
func (m *Mux) Delete(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.method(http.MethodDelete, path, fn, mws)
}
func (m *Mux) Options(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.method(http.MethodOptions, path, fn, mws)
}
func (m *Mux) Head(path string, fn http.HandlerFunc, mws ...middleware.Middleware) {
m.method(http.MethodHead, path, fn, mws)
}
func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.root.ServeHTTP(w, r)
}

169
pkg/router/mux_test.go Normal file
View file

@ -0,0 +1,169 @@
package router
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func do(t *testing.T, m *Mux, method, target string) *httptest.ResponseRecorder {
t.Helper()
rr := httptest.NewRecorder()
m.ServeHTTP(rr, httptest.NewRequest(method, target, nil))
return rr
}
func TestMethodRouting(t *testing.T) {
m := New()
m.Get("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "get") })
m.Post("/x", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "post") })
if got := do(t, m, http.MethodGet, "/x").Body.String(); got != "get" {
t.Errorf("GET /x = %q", got)
}
if got := do(t, m, http.MethodPost, "/x").Body.String(); got != "post" {
t.Errorf("POST /x = %q", got)
}
if rr := do(t, m, http.MethodPut, "/x"); rr.Code != http.StatusMethodNotAllowed {
t.Errorf("PUT /x got %d, want 405", rr.Code)
}
}
func TestPathValueAcrossMiddleware(t *testing.T) {
m := New()
m.Use(func(next http.Handler) http.Handler { return next })
var got string
m.Get("/users/{id}", func(w http.ResponseWriter, r *http.Request) {
got = r.PathValue("id")
})
do(t, m, http.MethodGet, "/users/42")
if got != "42" {
t.Errorf("PathValue = %q, want 42", got)
}
}
func TestGroupPrefix(t *testing.T) {
m := New()
api := m.Group("/api")
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
t.Errorf("/api/ping = %q", got)
}
if rr := do(t, m, http.MethodGet, "/ping"); rr.Code != http.StatusNotFound {
t.Errorf("unprefixed /ping got %d, want 404", rr.Code)
}
}
func TestGroupNested(t *testing.T) {
m := New()
api := m.Group("/api")
v1 := api.Group("/v1")
v1.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
if got := do(t, m, http.MethodGet, "/api/v1/ping").Body.String(); got != "pong" {
t.Errorf("/api/v1/ping = %q", got)
}
}
func TestGroupTrailingSlashNormalized(t *testing.T) {
m := New()
api := m.Group("/api/")
api.Get("/ping", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "pong") })
if got := do(t, m, http.MethodGet, "/api/ping").Body.String(); got != "pong" {
t.Errorf("/api/ping = %q", got)
}
}
func TestGroupRootSubtree(t *testing.T) {
m := New()
api := m.Group("/api")
api.Get("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "subtree") })
if got := do(t, m, http.MethodGet, "/api/").Body.String(); got != "subtree" {
t.Errorf("/api/ = %q", got)
}
if got := do(t, m, http.MethodGet, "/api/anything").Body.String(); got != "subtree" {
t.Errorf("/api/anything = %q", got)
}
}
func TestUseAfterGroupDoesNotPropagate(t *testing.T) {
m := New()
api := m.Group("/api")
hits := 0
m.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits++
next.ServeHTTP(w, r)
})
})
api.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
m.Get("/y", func(w http.ResponseWriter, r *http.Request) {})
do(t, m, http.MethodGet, "/api/x")
if hits != 0 {
t.Errorf("late Use propagated to api group, hits=%d", hits)
}
do(t, m, http.MethodGet, "/y")
if hits != 1 {
t.Errorf("late Use did not apply to root, hits=%d", hits)
}
}
func TestPerRouteAndGroupMiddlewareStack(t *testing.T) {
m := New()
var log []string
m.Use(tagMW(&log, "use"))
api := m.Group("/api", tagMW(&log, "group"))
api.Get("/x",
func(w http.ResponseWriter, r *http.Request) { log = append(log, "h") },
tagMW(&log, "route"),
)
do(t, m, http.MethodGet, "/api/x")
want := "use:before,group:before,route:before,h,route:after,group:after,use:after"
if got := strings.Join(log, ","); got != want {
t.Errorf("\n got %s\nwant %s", got, want)
}
}
func TestHandleAcceptsBarePath(t *testing.T) {
m := New()
m.Handle("/x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "any")
}))
for _, method := range []string{http.MethodGet, http.MethodPost, http.MethodDelete} {
if got := do(t, m, method, "/x").Body.String(); got != "any" {
t.Errorf("%s /x = %q", method, got)
}
}
}
func TestHandleAcceptsMethodPattern(t *testing.T) {
m := New()
m.Handle("DELETE /x", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "del")
}))
if got := do(t, m, http.MethodDelete, "/x").Body.String(); got != "del" {
t.Errorf("DELETE /x = %q", got)
}
if rr := do(t, m, http.MethodGet, "/x"); rr.Code != http.StatusMethodNotAllowed {
t.Errorf("GET /x got %d, want 405", rr.Code)
}
}
func TestConflictingPatternsPanic(t *testing.T) {
m := New()
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on duplicate pattern")
}
}()
m.Get("/x", func(w http.ResponseWriter, r *http.Request) {})
}

71
pkg/router/pattern.go Normal file
View file

@ -0,0 +1,71 @@
package router
import "strings"
func splitPattern(pattern string) (method, host, path string) {
rest := pattern
if i := strings.Index(pattern, " "); i >= 0 {
method = pattern[:i]
rest = strings.TrimLeft(pattern[i+1:], " ")
}
if strings.HasPrefix(rest, "/") {
return method, "", rest
}
if i := strings.Index(rest, "/"); i >= 0 {
return method, rest[:i], rest[i:]
}
return method, rest, ""
}
func validateGroupPrefix(p string) {
if p == "" {
return
}
if strings.ContainsAny(p, " \t") {
panic("lightmux: group prefix must not contain whitespace (no method or host allowed): " + p)
}
if !strings.HasPrefix(p, "/") {
panic("lightmux: group prefix must start with '/': " + p)
}
}
func normalizeGroupPrefix(p string) string {
if p == "" {
return ""
}
if len(p) > 1 && strings.HasSuffix(p, "/") {
return p[:len(p)-1]
}
if p == "/" {
return ""
}
return p
}
func joinPath(prefix, sub string) string {
if prefix == "" {
return sub
}
if sub == "" {
return prefix
}
if sub == "/" {
return prefix + "/"
}
if !strings.HasPrefix(sub, "/") {
panic("lightmux: route path must start with '/': " + sub)
}
return prefix + sub
}
func buildPattern(method, prefix, pattern string) string {
m, host, path := splitPattern(pattern)
if method != "" {
m = method
}
full := host + joinPath(prefix, path)
if m != "" {
return m + " " + full
}
return full
}

118
pkg/router/pattern_test.go Normal file
View file

@ -0,0 +1,118 @@
package router
import "testing"
func TestSplitPattern(t *testing.T) {
cases := []struct {
in string
method, host, path string
}{
{"/foo", "", "", "/foo"},
{"GET /foo", "GET", "", "/foo"},
{"POST /users/{id}", "POST", "", "/users/{id}"},
{"GET example.com/foo", "GET", "example.com", "/foo"},
{"example.com/foo", "", "example.com", "/foo"},
{"GET /", "GET", "", "/"},
{"/", "", "", "/"},
}
for _, c := range cases {
t.Run(c.in, func(t *testing.T) {
m, h, p := splitPattern(c.in)
if m != c.method || h != c.host || p != c.path {
t.Fatalf("splitPattern(%q) = (%q, %q, %q), want (%q, %q, %q)",
c.in, m, h, p, c.method, c.host, c.path)
}
})
}
}
func TestNormalizeGroupPrefix(t *testing.T) {
cases := map[string]string{
"": "",
"/": "",
"/api": "/api",
"/api/": "/api",
"/api/v1": "/api/v1",
"/api/v1/": "/api/v1",
}
for in, want := range cases {
if got := normalizeGroupPrefix(in); got != want {
t.Errorf("normalizeGroupPrefix(%q) = %q, want %q", in, got, want)
}
}
}
func TestJoinPath(t *testing.T) {
cases := []struct {
prefix, sub, want string
}{
{"", "/foo", "/foo"},
{"/api", "/foo", "/api/foo"},
{"/api", "", "/api"},
{"/api", "/", "/api/"},
{"", "", ""},
{"/api/v1", "/users/{id}", "/api/v1/users/{id}"},
}
for _, c := range cases {
got := joinPath(c.prefix, c.sub)
if got != c.want {
t.Errorf("joinPath(%q, %q) = %q, want %q", c.prefix, c.sub, got, c.want)
}
}
}
func TestJoinPathPanicsOnBadSub(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on sub without leading /")
}
}()
joinPath("/api", "foo")
}
func TestValidateGroupPrefix(t *testing.T) {
good := []string{"", "/api", "/api/v1"}
for _, p := range good {
func() {
defer func() {
if r := recover(); r != nil {
t.Errorf("validateGroupPrefix(%q) unexpected panic: %v", p, r)
}
}()
validateGroupPrefix(p)
}()
}
bad := []string{"api", "GET /api", "/api with space", "host.com/api"}
for _, p := range bad {
func() {
defer func() {
if r := recover(); r == nil {
t.Errorf("validateGroupPrefix(%q) expected panic", p)
}
}()
validateGroupPrefix(p)
}()
}
}
func TestBuildPattern(t *testing.T) {
cases := []struct {
method, prefix, pattern, want string
}{
{"", "", "/foo", "/foo"},
{"GET", "", "/foo", "GET /foo"},
{"GET", "/api", "/foo", "GET /api/foo"},
{"", "/api", "GET /foo", "GET /api/foo"},
{"", "/api", "GET example.com/foo", "GET example.com/api/foo"},
{"", "/api", "/", "/api/"},
{"GET", "/api", "/", "GET /api/"},
{"GET", "/api", "", "GET /api"},
}
for _, c := range cases {
got := buildPattern(c.method, c.prefix, c.pattern)
if got != c.want {
t.Errorf("buildPattern(%q, %q, %q) = %q, want %q",
c.method, c.prefix, c.pattern, got, c.want)
}
}
}