add ipinfo middleware

This commit is contained in:
juancwu 2026-04-26 22:14:42 +00:00
commit 6fd78737dc
5 changed files with 278 additions and 0 deletions

85
ipinfo/ipinfo.go Normal file
View file

@ -0,0 +1,85 @@
// Package ipinfo enriches incoming requests with geolocation data from the
// ipinfo.io API and attaches the result to the request context for downstream
// handlers to consume via From.
package ipinfo
import (
"context"
"net"
"net/http"
"git.juancwu.dev/juancwu/errx"
"git.juancwu.dev/juancwu/splinter"
"github.com/ipinfo/go/v2/ipinfo"
)
const op = "ipinfo"
type ctxKey struct{}
// New returns a middleware that looks up the client IP (parsed from
// r.RemoteAddr) against the ipinfo.io API and attaches the *ipinfo.Core
// result to the request context. Pair with realip.New() upstream so the
// lookup uses the originating client IP rather than the proxy peer.
//
// Pass nil for the logger to use splinter.Default() resolved at request time.
//
// Loopback, private, link-local, and unspecified addresses are skipped so the
// upstream API quota is preserved. Lookup errors are logged at warn level but
// do not abort the request — downstream handlers should treat the context
// value as optional via From.
func New(client *ipinfo.Client, l *splinter.Logger) func(http.Handler) http.Handler {
if client == nil {
panic(errx.New(op, "client must not be nil"))
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := parseClientIP(r.RemoteAddr)
if ip == nil || isLocal(ip) {
next.ServeHTTP(w, r)
return
}
info, err := client.GetIPInfo(ip)
if err != nil {
warn(l, "ipinfo.lookup_failed",
"client", ip.String(),
"err", errx.Wrap(op, err),
)
next.ServeHTTP(w, r)
return
}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxKey{}, info)))
})
}
}
// From returns the *ipinfo.Core attached to ctx by the middleware. The second
// return value reports whether a lookup ran and produced a result.
func From(ctx context.Context) (*ipinfo.Core, bool) {
info, ok := ctx.Value(ctxKey{}).(*ipinfo.Core)
return info, ok
}
func parseClientIP(remoteAddr string) net.IP {
host := remoteAddr
if h, _, err := net.SplitHostPort(remoteAddr); err == nil {
host = h
}
return net.ParseIP(host)
}
func isLocal(ip net.IP) bool {
return ip.IsLoopback() ||
ip.IsPrivate() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsUnspecified()
}
func warn(l *splinter.Logger, msg string, args ...any) {
if l == nil {
splinter.Warn(msg, args...)
return
}
l.Warn(msg, args...)
}

158
ipinfo/ipinfo_test.go Normal file
View file

@ -0,0 +1,158 @@
package ipinfo
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
sdk "github.com/ipinfo/go/v2/ipinfo"
)
func newClient(t *testing.T, handler http.HandlerFunc) *sdk.Client {
t.Helper()
srv := httptest.NewServer(handler)
t.Cleanup(srv.Close)
c := sdk.NewClient(srv.Client(), nil, "")
u, err := url.Parse(srv.URL + "/")
if err != nil {
t.Fatalf("parse server URL: %v", err)
}
c.BaseURL = u
return c
}
func TestNewAttachesCoreOnSuccess(t *testing.T) {
const ip = "8.8.8.8"
hits := 0
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
hits++
if got := r.URL.Path; got != "/"+ip {
t.Errorf("ipinfo path = %q, want /%s", got, ip)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ip":"` + ip + `","city":"San Francisco","country":"US"}`))
})
var seen *sdk.Core
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
info, ok := From(r.Context())
if !ok {
t.Fatal("From: expected info on context")
}
seen = info
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = ip + ":54321"
h.ServeHTTP(httptest.NewRecorder(), req)
if hits != 1 {
t.Fatalf("expected 1 ipinfo call, got %d", hits)
}
if seen == nil || seen.City != "San Francisco" || seen.Country != "US" {
t.Errorf("unexpected core: %+v", seen)
}
}
func TestNewAcceptsBareIPRemoteAddr(t *testing.T) {
const ip = "8.8.8.8"
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ip":"` + ip + `","country":"US"}`))
})
var ok bool
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok = From(r.Context())
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = ip // no port — realip leaves it bare
h.ServeHTTP(httptest.NewRecorder(), req)
if !ok {
t.Error("From: expected info on context for bare IP RemoteAddr")
}
}
func TestNewSkipsLocalAddresses(t *testing.T) {
cases := []string{
"127.0.0.1:1234", // loopback
"10.0.0.1:1234", // private
"192.168.1.1:1234", // private
"169.254.0.1:1234", // link-local
"[::1]:1234", // IPv6 loopback
"[fe80::1]:1234", // IPv6 link-local
"not-an-addr", // unparseable
"", // empty
}
called := false
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
for _, ra := range cases {
t.Run(ra, func(t *testing.T) {
var ok bool
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok = From(r.Context())
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = ra
h.ServeHTTP(httptest.NewRecorder(), req)
if ok {
t.Errorf("expected no context value for %q", ra)
}
})
}
if called {
t.Error("ipinfo API should not be called for local/unparseable addresses")
}
}
func TestNewPassesThroughOnLookupError(t *testing.T) {
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})
served := false
var ok bool
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
served = true
_, ok = From(r.Context())
w.WriteHeader(http.StatusOK)
}))
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "8.8.8.8:54321"
h.ServeHTTP(rr, req)
if !served {
t.Fatal("downstream handler not invoked after lookup error")
}
if ok {
t.Error("From should report no info when lookup failed")
}
if rr.Code != http.StatusOK {
t.Errorf("response status = %d, want 200", rr.Code)
}
}
func TestNewPanicsOnNilClient(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic on nil client")
}
}()
_ = New(nil, nil)
}
func TestFromEmptyContext(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if _, ok := From(req.Context()); ok {
t.Error("From: expected ok=false on empty context")
}
}