diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go index ced4a20..4628a5e 100644 --- a/pkg/middleware/logger.go +++ b/pkg/middleware/logger.go @@ -29,12 +29,34 @@ func (s *statusRecorder) Write(b []byte) (int, error) { return s.ResponseWriter.Write(b) } +// Logger uses splinter.Default() resolved at request time. func Logger(next http.Handler) http.Handler { + return loggerHandler(nil, next) +} + +// LoggerWith returns a Logger middleware backed by the given splinter logger. +// Pass nil to fall back to splinter.Default() (equivalent to Logger). +func LoggerWith(l *splinter.Logger) Middleware { + return func(next http.Handler) http.Handler { + return loggerHandler(l, next) + } +} + +func loggerHandler(l *splinter.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(rec, r) - splinter.Info("http.request", + if l == nil { + splinter.Info("http.request", + "method", r.Method, + "path", r.URL.Path, + "status", rec.status, + "duration", time.Since(start), + ) + return + } + l.Info("http.request", "method", r.Method, "path", r.URL.Path, "status", rec.status, diff --git a/pkg/middleware/logger_test.go b/pkg/middleware/logger_test.go index 37f251b..3e5ffa5 100644 --- a/pkg/middleware/logger_test.go +++ b/pkg/middleware/logger_test.go @@ -56,3 +56,37 @@ func TestLoggerDefaultStatusOK(t *testing.T) { t.Errorf("expected default status 200 in log, got %q", buf.String()) } } + +func TestLoggerWith(t *testing.T) { + defaultBuf := captureSplinter(t) + + var customBuf bytes.Buffer + custom := splinter.New(splinter.WithStream(splinter.NewConsoleStream( + splinter.ConsoleJSON, + splinter.LevelDebug, + splinter.ConsoleWriter(&customBuf), + ))) + + h := LoggerWith(custom)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/x", nil)) + + if !strings.Contains(customBuf.String(), `"path":"/x"`) { + t.Errorf("custom logger did not receive record: %q", customBuf.String()) + } + if defaultBuf.Len() != 0 { + t.Errorf("default logger should not have been written to, got: %q", defaultBuf.String()) + } +} + +func TestLoggerWithNilFallsBackToDefault(t *testing.T) { + buf := captureSplinter(t) + + h := LoggerWith(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/y", nil)) + + if !strings.Contains(buf.String(), `"path":"/y"`) { + t.Errorf("nil logger should fall back to splinter.Default(): %q", buf.String()) + } +}