lightmux-contrib/ipinfo/ipinfo.go
2026-04-26 22:14:42 +00:00

85 lines
2.4 KiB
Go

// Package ipinfo enriches incoming requests with geolocation data from the
// ipinfo.io API and attaches the result to the request context for downstream
// handlers to consume via From.
package ipinfo
import (
"context"
"net"
"net/http"
"git.juancwu.dev/juancwu/errx"
"git.juancwu.dev/juancwu/splinter"
"github.com/ipinfo/go/v2/ipinfo"
)
const (
	// op labels errors created or wrapped by this package via errx.
	op = "ipinfo"
)

// ctxKey is the unexported context key under which New stores the lookup
// result. Because the type is private to this package, no other package can
// read or overwrite the value through this key.
type ctxKey struct{}
// New returns a middleware that looks up the client IP (parsed from
// r.RemoteAddr) against the ipinfo.io API and attaches the *ipinfo.Core
// result to the request context. Pair with realip.New() upstream so the
// lookup uses the originating client IP rather than the proxy peer.
//
// Pass nil for the logger to use splinter.Default() resolved at request time.
//
// Loopback, private, link-local, and unspecified addresses are skipped so the
// upstream API quota is preserved. Lookup errors are logged at warn level but
// do not abort the request — downstream handlers should treat the context
// value as optional via From. A lookup that returns neither an error nor a
// result is treated as "no result": nothing is attached, so From reports
// false rather than handing callers a nil *ipinfo.Core.
//
// The lookup runs synchronously on the request path and does not observe the
// request context, because client.GetIPInfo takes no context. NOTE(review):
// its latency is bounded only by however the client's HTTP transport is
// configured — confirm a timeout is set where the client is constructed.
//
// New panics if client is nil.
func New(client *ipinfo.Client, l *splinter.Logger) func(http.Handler) http.Handler {
	if client == nil {
		panic(errx.New(op, "client must not be nil"))
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ip := parseClientIP(r.RemoteAddr)
			if ip == nil || isLocal(ip) {
				// Unparseable or non-routable address: nothing worth spending
				// API quota on, so pass the request through untouched.
				next.ServeHTTP(w, r)
				return
			}
			info, err := client.GetIPInfo(ip)
			if err != nil {
				warn(l, "ipinfo.lookup_failed",
					"client", ip.String(),
					"err", errx.Wrap(op, err),
				)
				next.ServeHTTP(w, r)
				return
			}
			if info == nil {
				// Storing a typed-nil *ipinfo.Core would make From return
				// (nil, true) and set up nil dereferences downstream.
				next.ServeHTTP(w, r)
				return
			}
			ctx := context.WithValue(r.Context(), ctxKey{}, info)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// From retrieves the *ipinfo.Core that the middleware returned by New stored
// in ctx. The boolean is true only when a value of that type is present under
// the package's private key; otherwise From returns (nil, false).
func From(ctx context.Context) (*ipinfo.Core, bool) {
	if core, found := ctx.Value(ctxKey{}).(*ipinfo.Core); found {
		return core, true
	}
	return nil, false
}
func parseClientIP(remoteAddr string) net.IP {
host := remoteAddr
if h, _, err := net.SplitHostPort(remoteAddr); err == nil {
host = h
}
return net.ParseIP(host)
}
func isLocal(ip net.IP) bool {
return ip.IsLoopback() ||
ip.IsPrivate() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsUnspecified()
}
// warn emits a warn-level log entry with structured key/value args. It uses
// the injected logger when one was supplied and otherwise falls back to the
// package-level splinter.Warn.
func warn(l *splinter.Logger, msg string, args ...any) {
	if l != nil {
		l.Warn(msg, args...)
		return
	}
	splinter.Warn(msg, args...)
}