upstream with actual client ip to upstream

This commit is contained in:
juancwu 2026-01-17 20:27:28 +00:00
commit c58703f735

54
main.go
View file

@ -7,6 +7,7 @@ import (
"html/template"
"log"
"log/slog"
"net"
"net/http"
"os"
"strconv"
@ -94,9 +95,17 @@ func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg)
}
// Fallback to upstream for all other cases
resp, err := resolver.resolveUpstream(r)
var clientIP net.IP
switch addr := w.RemoteAddr().(type) {
case *net.UDPAddr:
clientIP = addr.IP
case *net.TCPAddr:
clientIP = addr.IP
}
resp, err := resolver.resolveUpstream(r, clientIP)
if err == nil && resp != nil {
log.Printf("[UPSTREAM] Forwarded %s", q.Name)
log.Printf("[UPSTREAM] Forwarded %s (ClientIP: %s)", q.Name, clientIP)
m.Answer = resp.Answer
m.Ns = resp.Ns
m.Extra = resp.Extra
@ -122,13 +131,50 @@ func (resolver *DNSResolver) getRecordFromDB(domain string, recordType string) (
return ip, nil
}
func (resolver *DNSResolver) resolveUpstream(r *dns.Msg) (*dns.Msg, error) {
func (resolver *DNSResolver) resolveUpstream(r *dns.Msg, clientIP net.IP) (*dns.Msg, error) {
rows, err := resolver.db.Query("SELECT address FROM upstreams")
if err != nil {
return nil, err
}
defer rows.Close()
// Create a copy of the message to avoid side effects
req := r.Copy()
// Add EDNS0 Client Subnet option
if clientIP != nil {
o := req.IsEdns0()
if o == nil {
req.SetEdns0(4096, false)
o = req.IsEdns0()
}
// Filter out existing SUBNET options to ensure we are authoritative about the client IP
var newOptions []dns.EDNS0
for _, opt := range o.Option {
if opt.Option() != dns.EDNS0SUBNET {
newOptions = append(newOptions, opt)
}
}
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
if ip4 := clientIP.To4(); ip4 != nil {
e.Family = 1 // IPv4
e.SourceNetmask = 32
e.SourceScope = 0
e.Address = ip4
} else {
e.Family = 2 // IPv6
e.SourceNetmask = 128
e.SourceScope = 0
e.Address = clientIP
}
newOptions = append(newOptions, e)
o.Option = newOptions
}
c := new(dns.Client)
c.Net = "udp"
@ -140,7 +186,7 @@ func (resolver *DNSResolver) resolveUpstream(r *dns.Msg) (*dns.Msg, error) {
continue
}
resp, _, err := c.Exchange(r, upstreamAddr)
resp, _, err := c.Exchange(req, upstreamAddr)
if err == nil && resp != nil && resp.Rcode != dns.RcodeServerFailure {
return resp, nil
}