diff --git a/main.go b/main.go index 6787a0a..83aa221 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "html/template" "log" "log/slog" + "net" "net/http" "os" "strconv" @@ -94,9 +95,17 @@ func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) } // Fallback to upstream for all other cases - resp, err := resolver.resolveUpstream(r) + var clientIP net.IP + switch addr := w.RemoteAddr().(type) { + case *net.UDPAddr: + clientIP = addr.IP + case *net.TCPAddr: + clientIP = addr.IP + } + + resp, err := resolver.resolveUpstream(r, clientIP) if err == nil && resp != nil { - log.Printf("[UPSTREAM] Forwarded %s", q.Name) + log.Printf("[UPSTREAM] Forwarded %s (ClientIP: %s)", q.Name, clientIP) m.Answer = resp.Answer m.Ns = resp.Ns m.Extra = resp.Extra @@ -122,13 +131,50 @@ func (resolver *DNSResolver) getRecordFromDB(domain string, recordType string) ( return ip, nil } -func (resolver *DNSResolver) resolveUpstream(r *dns.Msg) (*dns.Msg, error) { +func (resolver *DNSResolver) resolveUpstream(r *dns.Msg, clientIP net.IP) (*dns.Msg, error) { rows, err := resolver.db.Query("SELECT address FROM upstreams") if err != nil { return nil, err } defer rows.Close() + // Create a copy of the message to avoid side effects + req := r.Copy() + + // Add EDNS0 Client Subnet option + if clientIP != nil { + o := req.IsEdns0() + if o == nil { + req.SetEdns0(4096, false) + o = req.IsEdns0() + } + + // Filter out existing SUBNET options to ensure we are authoritative about the client IP + var newOptions []dns.EDNS0 + for _, opt := range o.Option { + if opt.Option() != dns.EDNS0SUBNET { + newOptions = append(newOptions, opt) + } + } + + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + if ip4 := clientIP.To4(); ip4 != nil { + e.Family = 1 // IPv4 + e.SourceNetmask = 32 + e.SourceScope = 0 + e.Address = ip4 + } else { + e.Family = 2 // IPv6 + e.SourceNetmask = 128 + e.SourceScope = 0 + e.Address = clientIP + } + + newOptions = append(newOptions, e) + o.Option = newOptions + } + c := new(dns.Client) c.Net = "udp" @@ -140,7 +186,7 @@ func (resolver *DNSResolver) resolveUpstream(r *dns.Msg) (*dns.Msg, error) { continue } - resp, _, err := c.Exchange(r, upstreamAddr) + resp, _, err := c.Exchange(req, upstreamAddr) if err == nil && resp != nil && resp.Rcode != dns.RcodeServerFailure { return resp, nil }