add realip, requestlog, recoverer middlewares
Initial implementation of lightmux-contrib, a sibling module to lightmux that hosts opinionated middlewares with one sub-package per middleware: - realip: resolves the originating client IP from CF-Connecting-IP, True-Client-IP, X-Real-IP, or X-Forwarded-For. Optional peer-CIDR allowlist via netip.Prefix. - requestlog: emits a structured http.request record (method, path, status, duration, client) per request via splinter. - recoverer: catches panics, wraps with errx under op "recoverer", logs with stack, and writes a 500 response. Each package exposes a single New(...) constructor returning func(http.Handler) http.Handler. The contrib module intentionally does not import lightmux — middlewares interoperate via the standard stdlib middleware shape. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
9dc0fc5d26
commit
b26ef7439e
10 changed files with 624 additions and 0 deletions
49
README.md
49
README.md
|
|
@ -1,2 +1,51 @@
|
|||
# lightmux-contrib
|
||||
|
||||
Opinionated middleware collection for [lightmux](https://git.juancwu.dev/juancwu/lightmux). Each middleware lives in its own sub-package so consumers only pull in the dependencies they actually use.
|
||||
|
||||
## Installation
|
||||
|
||||
```sh
|
||||
go get git.juancwu.dev/juancwu/lightmux-contrib
|
||||
```
|
||||
|
||||
## Packages
|
||||
|
||||
### `realip`
|
||||
|
||||
Replaces `r.RemoteAddr` with the originating client IP from `CF-Connecting-IP`, `True-Client-IP`, `X-Real-IP`, or `X-Forwarded-For` (in that order).
|
||||
|
||||
```go
|
||||
import "git.juancwu.dev/juancwu/lightmux-contrib/realip"
|
||||
|
||||
mux.Use(realip.New()) // always trust headers
|
||||
mux.Use(realip.New(netip.MustParsePrefix("10.0.0.0/8"))) // gated by peer CIDR
|
||||
```
|
||||
|
||||
With no arguments, `realip.New()` always honors the proxy headers — only register it when the service sits behind a trusted proxy. With one or more `netip.Prefix` arguments, the headers are honored only when the immediate peer's IP falls within one of them.
|
||||
|
||||
### `requestlog`
|
||||
|
||||
Emits a structured `http.request` record (method, path, status, duration, client) per request via [splinter](https://git.juancwu.dev/juancwu/splinter).
|
||||
|
||||
```go
|
||||
import "git.juancwu.dev/juancwu/lightmux-contrib/requestlog"
|
||||
|
||||
mux.Use(requestlog.New(nil)) // splinter.Default() resolved at request time
|
||||
mux.Use(requestlog.New(custom)) // custom *splinter.Logger
|
||||
```
|
||||
|
||||
When pairing with `realip`, register `realip` first so the `client` field is the resolved client IP rather than the proxy peer:
|
||||
|
||||
```go
|
||||
mux.Use(realip.New(), requestlog.New(nil))
|
||||
```
|
||||
|
||||
### `recoverer`
|
||||
|
||||
Catches panics inside handlers, wraps the value with [errx](https://git.juancwu.dev/juancwu/errx) under op `recoverer`, logs it with stack via the standard `log` package, and writes a 500 response.
|
||||
|
||||
```go
|
||||
import "git.juancwu.dev/juancwu/lightmux-contrib/recoverer"
|
||||
|
||||
mux.Use(recoverer.New())
|
||||
```
|
||||
|
|
|
|||
43
Taskfile.yml
Normal file
43
Taskfile.yml
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
version: '3'
|
||||
|
||||
tasks:
|
||||
default:
|
||||
desc: Run vet and tests
|
||||
cmds:
|
||||
- task: check
|
||||
|
||||
install:tools:
|
||||
desc: Install development tools (tparse for prettier test output)
|
||||
cmds:
|
||||
- go install github.com/mfridman/tparse@latest
|
||||
|
||||
test:
|
||||
desc: Run tests with prettier output via tparse
|
||||
cmds:
|
||||
- set -o pipefail && go test ./... -json -cover | tparse -all
|
||||
|
||||
test:race:
|
||||
desc: Run tests under the race detector
|
||||
cmds:
|
||||
- set -o pipefail && go test ./... -race -json | tparse -all
|
||||
|
||||
vet:
|
||||
desc: Run go vet
|
||||
cmds:
|
||||
- go vet ./...
|
||||
|
||||
fmt:
|
||||
desc: Format Go source files
|
||||
cmds:
|
||||
- gofmt -w .
|
||||
|
||||
tidy:
|
||||
desc: Tidy go.mod
|
||||
cmds:
|
||||
- go mod tidy
|
||||
|
||||
check:
|
||||
desc: Run vet and tests
|
||||
cmds:
|
||||
- task: vet
|
||||
- task: test
|
||||
8
go.mod
Normal file
8
go.mod
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
module git.juancwu.dev/juancwu/lightmux-contrib
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
git.juancwu.dev/juancwu/errx v0.1.0
|
||||
git.juancwu.dev/juancwu/splinter v0.1.0
|
||||
)
|
||||
4
go.sum
Normal file
4
go.sum
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
git.juancwu.dev/juancwu/errx v0.1.0 h1:92yA0O1BkKGXcoEiWtxwH/ztXCjoV1KSTMtKpm3gd2w=
|
||||
git.juancwu.dev/juancwu/errx v0.1.0/go.mod h1:7jNhBOwcZ/q7zDD6mln3QCJBYZ8T6h+dAdxVfykprTk=
|
||||
git.juancwu.dev/juancwu/splinter v0.1.0 h1:ZGvvzyi24hZw/yFAwpUsHtj+q+fh9I2KIGmOAILWD5Q=
|
||||
git.juancwu.dev/juancwu/splinter v0.1.0/go.mod h1:dAYsRQfS6tqWynEGz8xMCtIJUN7+KIp3jLE7kgO3yKE=
|
||||
78
realip/realip.go
Normal file
78
realip/realip.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
// Package realip resolves the originating client IP from common reverse-proxy
|
||||
// and CDN headers (Cloudflare, nginx) and replaces r.RemoteAddr so downstream
|
||||
// handlers and middlewares see the real client.
|
||||
package realip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var headers = []string{
|
||||
"CF-Connecting-IP",
|
||||
"True-Client-IP",
|
||||
"X-Real-IP",
|
||||
"X-Forwarded-For",
|
||||
}
|
||||
|
||||
// New returns a real-IP middleware.
|
||||
//
|
||||
// With no trusted prefixes, it always honors the proxy headers — only register
|
||||
// it when the service sits behind a trusted proxy.
|
||||
//
|
||||
// With one or more prefixes, the headers are honored only when the immediate
|
||||
// peer (parsed from r.RemoteAddr) falls within one of them; requests from
|
||||
// outside the allowlist pass through untouched.
|
||||
func New(trusted ...netip.Prefix) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if len(trusted) > 0 && !peerTrusted(r.RemoteAddr, trusted) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if ip := extract(r); ip != "" {
|
||||
r2 := *r
|
||||
r2.RemoteAddr = ip
|
||||
next.ServeHTTP(w, &r2)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func extract(r *http.Request) string {
|
||||
for _, h := range headers {
|
||||
v := r.Header.Get(h)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
if i := strings.IndexByte(v, ','); i >= 0 {
|
||||
v = v[:i]
|
||||
}
|
||||
v = strings.TrimSpace(v)
|
||||
if net.ParseIP(v) != nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func peerTrusted(remoteAddr string, trusted []netip.Prefix) bool {
|
||||
var peer netip.Addr
|
||||
if ap, err := netip.ParseAddrPort(remoteAddr); err == nil {
|
||||
peer = ap.Addr()
|
||||
} else if a, err2 := netip.ParseAddr(remoteAddr); err2 == nil {
|
||||
peer = a
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
for _, p := range trusted {
|
||||
if p.Contains(peer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
183
realip/realip_test.go
Normal file
183
realip/realip_test.go
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
package realip
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const defaultTestRemoteAddr = "192.0.2.1:1234"
|
||||
|
||||
func captureRemoteAddr(got *string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
*got = r.RemoteAddr
|
||||
})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no headers",
|
||||
want: defaultTestRemoteAddr,
|
||||
},
|
||||
{
|
||||
name: "CF-Connecting-IP",
|
||||
headers: map[string]string{"CF-Connecting-IP": "203.0.113.5"},
|
||||
want: "203.0.113.5",
|
||||
},
|
||||
{
|
||||
name: "True-Client-IP",
|
||||
headers: map[string]string{"True-Client-IP": "203.0.113.6"},
|
||||
want: "203.0.113.6",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP",
|
||||
headers: map[string]string{"X-Real-IP": "203.0.113.7"},
|
||||
want: "203.0.113.7",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For single",
|
||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.8"},
|
||||
want: "203.0.113.8",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For list",
|
||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.9, 10.0.0.1, 10.0.0.2"},
|
||||
want: "203.0.113.9",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For with spaces",
|
||||
headers: map[string]string{"X-Forwarded-For": " 203.0.113.10 , 10.0.0.1"},
|
||||
want: "203.0.113.10",
|
||||
},
|
||||
{
|
||||
name: "precedence CF over XFF",
|
||||
headers: map[string]string{
|
||||
"CF-Connecting-IP": "203.0.113.11",
|
||||
"X-Forwarded-For": "198.51.100.1",
|
||||
},
|
||||
want: "203.0.113.11",
|
||||
},
|
||||
{
|
||||
name: "invalid then valid",
|
||||
headers: map[string]string{
|
||||
"CF-Connecting-IP": "not-an-ip",
|
||||
"X-Real-IP": "203.0.113.12",
|
||||
},
|
||||
want: "203.0.113.12",
|
||||
},
|
||||
{
|
||||
name: "IPv6",
|
||||
headers: map[string]string{"X-Real-IP": "2001:db8::1"},
|
||||
want: "2001:db8::1",
|
||||
},
|
||||
{
|
||||
name: "all invalid falls through",
|
||||
headers: map[string]string{"CF-Connecting-IP": "garbage"},
|
||||
want: defaultTestRemoteAddr,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got string
|
||||
h := New()(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
for k, v := range tc.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("r.RemoteAddr = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDoesNotMutateCallerRequest(t *testing.T) {
|
||||
var seen string
|
||||
h := New()(captureRemoteAddr(&seen))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Real-IP", "203.0.113.13")
|
||||
original := req.RemoteAddr
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if seen != "203.0.113.13" {
|
||||
t.Errorf("handler saw r.RemoteAddr = %q, want %q", seen, "203.0.113.13")
|
||||
}
|
||||
if req.RemoteAddr != original {
|
||||
t.Errorf("caller's request was mutated: RemoteAddr = %q, want %q", req.RemoteAddr, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithTrustedPeer(t *testing.T) {
|
||||
var got string
|
||||
h := New(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.1.2.3:55555"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.20")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "203.0.113.20" {
|
||||
t.Errorf("trusted peer header not honored: r.RemoteAddr = %q, want %q", got, "203.0.113.20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithUntrustedPeer(t *testing.T) {
|
||||
var got string
|
||||
h := New(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "203.0.113.99:55555"
|
||||
req.Header.Set("X-Real-IP", "198.51.100.1")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "203.0.113.99:55555" {
|
||||
t.Errorf("untrusted peer should not have header honored: r.RemoteAddr = %q, want %q", got, "203.0.113.99:55555")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithMultiplePrefixesIPv6(t *testing.T) {
|
||||
var got string
|
||||
h := New(
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
)(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "[2001:db8::abcd]:55555"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.30")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "203.0.113.30" {
|
||||
t.Errorf("IPv6 peer match should honor header: r.RemoteAddr = %q, want %q", got, "203.0.113.30")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithUnparseableRemoteAddr(t *testing.T) {
|
||||
var got string
|
||||
h := New(netip.MustParsePrefix("10.0.0.0/8"))(captureRemoteAddr(&got))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "not-an-addr"
|
||||
req.Header.Set("CF-Connecting-IP", "203.0.113.50")
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
if got != "not-an-addr" {
|
||||
t.Errorf("unparseable RemoteAddr should pass through unchanged: got %q", got)
|
||||
}
|
||||
}
|
||||
36
recoverer/recoverer.go
Normal file
36
recoverer/recoverer.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Package recoverer catches panics inside HTTP handlers, logs them with stack
|
||||
// trace, and writes a 500 response.
|
||||
package recoverer
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"git.juancwu.dev/juancwu/errx"
|
||||
)
|
||||
|
||||
const op = "recoverer"
|
||||
|
||||
// New returns a panic-recovery middleware.
|
||||
func New() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
rec := recover()
|
||||
if rec == nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if e, ok := rec.(error); ok {
|
||||
err = errx.Wrap(op, e)
|
||||
} else {
|
||||
err = errx.Newf(op, "panic: %v", rec)
|
||||
}
|
||||
log.Printf("%v\n%s", err, debug.Stack())
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
70
recoverer/recoverer_test.go
Normal file
70
recoverer/recoverer_test.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package recoverer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func captureLog(t *testing.T) *bytes.Buffer {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
orig := log.Default().Writer()
|
||||
log.Default().SetOutput(&buf)
|
||||
t.Cleanup(func() { log.Default().SetOutput(orig) })
|
||||
return &buf
|
||||
}
|
||||
|
||||
func TestNewCatchesStringPanic(t *testing.T) {
|
||||
buf := captureLog(t)
|
||||
|
||||
h := New()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("boom")
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want 500", rr.Code)
|
||||
}
|
||||
out := buf.String()
|
||||
for _, want := range []string{"recoverer", "panic: boom"} {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Errorf("log missing %q\nfull: %s", want, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWrapsErrorPanic(t *testing.T) {
|
||||
buf := captureLog(t)
|
||||
|
||||
cause := errors.New("db down")
|
||||
h := New()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic(cause)
|
||||
}))
|
||||
|
||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "recoverer: db down") {
|
||||
t.Errorf("expected errx-wrapped breadcrumb, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPassesThrough(t *testing.T) {
|
||||
called := false
|
||||
h := New()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
if !called || rr.Code != http.StatusOK {
|
||||
t.Errorf("non-panic path broken: called=%v code=%d", called, rr.Code)
|
||||
}
|
||||
}
|
||||
61
requestlog/requestlog.go
Normal file
61
requestlog/requestlog.go
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
// Package requestlog emits a structured "http.request" record per request
|
||||
// (method, path, status, duration, client) via splinter.
|
||||
package requestlog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.juancwu.dev/juancwu/splinter"
|
||||
)
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
wrote bool
|
||||
}
|
||||
|
||||
func (s *statusRecorder) WriteHeader(code int) {
|
||||
if !s.wrote {
|
||||
s.status = code
|
||||
s.wrote = true
|
||||
}
|
||||
s.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (s *statusRecorder) Write(b []byte) (int, error) {
|
||||
if !s.wrote {
|
||||
s.status = http.StatusOK
|
||||
s.wrote = true
|
||||
}
|
||||
return s.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// New returns a request-logging middleware. Pass nil to use splinter.Default()
|
||||
// resolved at request time; otherwise records flow through the supplied logger.
|
||||
func New(l *splinter.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(rec, r)
|
||||
if l == nil {
|
||||
splinter.Info("http.request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", rec.status,
|
||||
"duration", time.Since(start),
|
||||
"client", r.RemoteAddr,
|
||||
)
|
||||
return
|
||||
}
|
||||
l.Info("http.request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", rec.status,
|
||||
"duration", time.Since(start),
|
||||
"client", r.RemoteAddr,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
92
requestlog/requestlog_test.go
Normal file
92
requestlog/requestlog_test.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package requestlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.juancwu.dev/juancwu/splinter"
|
||||
)
|
||||
|
||||
func captureSplinter(t *testing.T) *bytes.Buffer {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
logger := splinter.New(splinter.WithStream(splinter.NewConsoleStream(
|
||||
splinter.ConsoleJSON,
|
||||
splinter.LevelDebug,
|
||||
splinter.ConsoleWriter(&buf),
|
||||
)))
|
||||
prev := splinter.SetDefault(logger)
|
||||
t.Cleanup(func() { splinter.SetDefault(prev) })
|
||||
return &buf
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
buf := captureSplinter(t)
|
||||
|
||||
h := New(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
}))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/foo", nil))
|
||||
|
||||
if rr.Code != http.StatusTeapot {
|
||||
t.Errorf("status code = %d, want 418", rr.Code)
|
||||
}
|
||||
out := buf.String()
|
||||
for _, want := range []string{`"method":"GET"`, `"path":"/foo"`, `"status":418`, `"client":"192.0.2.1:1234"`} {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Errorf("log output missing %s\nfull output: %s", want, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDefaultStatusOK(t *testing.T) {
|
||||
buf := captureSplinter(t)
|
||||
|
||||
h := New(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hi"))
|
||||
}))
|
||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
if !strings.Contains(buf.String(), `"status":200`) {
|
||||
t.Errorf("expected default status 200 in log, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithCustomLogger(t *testing.T) {
|
||||
defaultBuf := captureSplinter(t)
|
||||
|
||||
var customBuf bytes.Buffer
|
||||
custom := splinter.New(splinter.WithStream(splinter.NewConsoleStream(
|
||||
splinter.ConsoleJSON,
|
||||
splinter.LevelDebug,
|
||||
splinter.ConsoleWriter(&customBuf),
|
||||
)))
|
||||
|
||||
h := New(custom)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}))
|
||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/x", nil))
|
||||
|
||||
if !strings.Contains(customBuf.String(), `"path":"/x"`) {
|
||||
t.Errorf("custom logger did not receive record: %q", customBuf.String())
|
||||
}
|
||||
if defaultBuf.Len() != 0 {
|
||||
t.Errorf("default logger should not have been written to, got: %q", defaultBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNilFallsBackToDefault(t *testing.T) {
|
||||
buf := captureSplinter(t)
|
||||
|
||||
h := New(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/y", nil))
|
||||
|
||||
if !strings.Contains(buf.String(), `"path":"/y"`) {
|
||||
t.Errorf("nil logger should fall back to splinter.Default(): %q", buf.String())
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue