allow custom splinter logger for middleware

This commit is contained in:
juancwu 2026-04-25 20:45:20 +00:00
commit f46916fdf5
2 changed files with 57 additions and 1 deletions

View file

@ -29,16 +29,38 @@ func (s *statusRecorder) Write(b []byte) (int, error) {
return s.ResponseWriter.Write(b) return s.ResponseWriter.Write(b)
} }
// Logger uses splinter.Default() resolved at request time.
func Logger(next http.Handler) http.Handler { func Logger(next http.Handler) http.Handler {
return loggerHandler(nil, next)
}
// LoggerWith returns a Logger middleware backed by the given splinter logger.
// Pass nil to fall back to splinter.Default() (equivalent to Logger).
func LoggerWith(l *splinter.Logger) Middleware {
return func(next http.Handler) http.Handler {
return loggerHandler(l, next)
}
}
func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(rec, r) next.ServeHTTP(rec, r)
if l == nil {
splinter.Info("http.request", splinter.Info("http.request",
"method", r.Method, "method", r.Method,
"path", r.URL.Path, "path", r.URL.Path,
"status", rec.status, "status", rec.status,
"duration", time.Since(start), "duration", time.Since(start),
) )
return
}
l.Info("http.request",
"method", r.Method,
"path", r.URL.Path,
"status", rec.status,
"duration", time.Since(start),
)
}) })
} }

View file

@ -56,3 +56,37 @@ func TestLoggerDefaultStatusOK(t *testing.T) {
t.Errorf("expected default status 200 in log, got %q", buf.String()) t.Errorf("expected default status 200 in log, got %q", buf.String())
} }
} }
func TestLoggerWith(t *testing.T) {
defaultBuf := captureSplinter(t)
var customBuf bytes.Buffer
custom := splinter.New(splinter.WithStream(splinter.NewConsoleStream(
splinter.ConsoleJSON,
splinter.LevelDebug,
splinter.ConsoleWriter(&customBuf),
)))
h := LoggerWith(custom)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
}))
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/x", nil))
if !strings.Contains(customBuf.String(), `"path":"/x"`) {
t.Errorf("custom logger did not receive record: %q", customBuf.String())
}
if defaultBuf.Len() != 0 {
t.Errorf("default logger should not have been written to, got: %q", defaultBuf.String())
}
}
func TestLoggerWithNilFallsBackToDefault(t *testing.T) {
buf := captureSplinter(t)
h := LoggerWith(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/y", nil))
if !strings.Contains(buf.String(), `"path":"/y"`) {
t.Errorf("nil logger should fall back to splinter.Default(): %q", buf.String())
}
}