158 lines
4 KiB
Go
158 lines
4 KiB
Go
package ipinfo
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
|
|
sdk "github.com/ipinfo/go/v2/ipinfo"
|
|
)
|
|
|
|
func newClient(t *testing.T, handler http.HandlerFunc) *sdk.Client {
|
|
t.Helper()
|
|
srv := httptest.NewServer(handler)
|
|
t.Cleanup(srv.Close)
|
|
c := sdk.NewClient(srv.Client(), nil, "")
|
|
u, err := url.Parse(srv.URL + "/")
|
|
if err != nil {
|
|
t.Fatalf("parse server URL: %v", err)
|
|
}
|
|
c.BaseURL = u
|
|
return c
|
|
}
|
|
|
|
func TestNewAttachesCoreOnSuccess(t *testing.T) {
|
|
const ip = "8.8.8.8"
|
|
hits := 0
|
|
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
hits++
|
|
if got := r.URL.Path; got != "/"+ip {
|
|
t.Errorf("ipinfo path = %q, want /%s", got, ip)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"ip":"` + ip + `","city":"San Francisco","country":"US"}`))
|
|
})
|
|
|
|
var seen *sdk.Core
|
|
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
info, ok := From(r.Context())
|
|
if !ok {
|
|
t.Fatal("From: expected info on context")
|
|
}
|
|
seen = info
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = ip + ":54321"
|
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
|
|
if hits != 1 {
|
|
t.Fatalf("expected 1 ipinfo call, got %d", hits)
|
|
}
|
|
if seen == nil || seen.City != "San Francisco" || seen.Country != "US" {
|
|
t.Errorf("unexpected core: %+v", seen)
|
|
}
|
|
}
|
|
|
|
func TestNewAcceptsBareIPRemoteAddr(t *testing.T) {
|
|
const ip = "8.8.8.8"
|
|
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"ip":"` + ip + `","country":"US"}`))
|
|
})
|
|
|
|
var ok bool
|
|
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, ok = From(r.Context())
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = ip // no port — realip leaves it bare
|
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
|
|
if !ok {
|
|
t.Error("From: expected info on context for bare IP RemoteAddr")
|
|
}
|
|
}
|
|
|
|
func TestNewSkipsLocalAddresses(t *testing.T) {
|
|
cases := []string{
|
|
"127.0.0.1:1234", // loopback
|
|
"10.0.0.1:1234", // private
|
|
"192.168.1.1:1234", // private
|
|
"169.254.0.1:1234", // link-local
|
|
"[::1]:1234", // IPv6 loopback
|
|
"[fe80::1]:1234", // IPv6 link-local
|
|
"not-an-addr", // unparseable
|
|
"", // empty
|
|
}
|
|
called := false
|
|
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
for _, ra := range cases {
|
|
t.Run(ra, func(t *testing.T) {
|
|
var ok bool
|
|
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, ok = From(r.Context())
|
|
}))
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = ra
|
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
|
if ok {
|
|
t.Errorf("expected no context value for %q", ra)
|
|
}
|
|
})
|
|
}
|
|
if called {
|
|
t.Error("ipinfo API should not be called for local/unparseable addresses")
|
|
}
|
|
}
|
|
|
|
func TestNewPassesThroughOnLookupError(t *testing.T) {
|
|
client := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
})
|
|
|
|
served := false
|
|
var ok bool
|
|
h := New(client, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
served = true
|
|
_, ok = From(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
rr := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "8.8.8.8:54321"
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if !served {
|
|
t.Fatal("downstream handler not invoked after lookup error")
|
|
}
|
|
if ok {
|
|
t.Error("From should report no info when lookup failed")
|
|
}
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("response status = %d, want 200", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestNewPanicsOnNilClient(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Error("expected panic on nil client")
|
|
}
|
|
}()
|
|
_ = New(nil, nil)
|
|
}
|
|
|
|
func TestFromEmptyContext(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
if _, ok := From(req.Context()); ok {
|
|
t.Error("From: expected ok=false on empty context")
|
|
}
|
|
}
|