forward with original client ip under trusted networks only

This commit is contained in:
juancwu 2026-01-17 20:38:05 +00:00
commit 27e51646fc
3 changed files with 109 additions and 8 deletions

60
main.go
View file

@ -59,8 +59,9 @@ func initDB(dbPath string) *sql.DB {
}
type DNSResolver struct {
db *sql.DB
defaultTTL int
db *sql.DB
defaultTTL int
trustedSubnets []*net.IPNet
}
func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
@ -141,8 +142,19 @@ func (resolver *DNSResolver) resolveUpstream(r *dns.Msg, clientIP net.IP) (*dns.
// Create a copy of the message to avoid side effects
req := r.Copy()
// Add EDNS0 Client Subnet option
// Check if clientIP is allowed
allowForwarding := false
if clientIP != nil {
for _, subnet := range resolver.trustedSubnets {
if subnet.Contains(clientIP) {
allowForwarding = true
break
}
}
}
// Add EDNS0 Client Subnet option
if allowForwarding {
o := req.IsEdns0()
if o == nil {
req.SetEdns0(4096, false)
@ -259,7 +271,8 @@ func startAPIServer(db *sql.DB, apiPort string) {
method = r.Method
}
if method == http.MethodPost {
switch method {
case http.MethodPost:
// Normalize domain to ensure trailing dot
if !strings.HasSuffix(domain, ".") {
domain += "."
@ -281,7 +294,7 @@ func startAPIServer(db *sql.DB, apiPort string) {
}
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusSeeOther)
} else if method == http.MethodDelete {
case http.MethodDelete:
if !strings.HasSuffix(domain, ".") {
domain += "."
}
@ -309,7 +322,8 @@ func startAPIServer(db *sql.DB, apiPort string) {
method = r.Method
}
if method == http.MethodPost {
switch method {
case http.MethodPost:
_, err := db.Exec("INSERT INTO upstreams (address) VALUES (?)", address)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -317,7 +331,7 @@ func startAPIServer(db *sql.DB, apiPort string) {
}
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusSeeOther)
} else if method == http.MethodDelete {
case http.MethodDelete:
_, err := db.Exec("DELETE FROM upstreams WHERE address = ?", address)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -332,6 +346,26 @@ func startAPIServer(db *sql.DB, apiPort string) {
log.Fatal(http.ListenAndServe(":"+apiPort, nil))
}
func parseCIDRs(envVar string) ([]*net.IPNet, error) {
var subnets []*net.IPNet
if envVar == "" {
return subnets, nil
}
for p := range strings.SplitSeq(envVar, ",") {
p = strings.TrimSpace(p)
if p == "" {
continue
}
_, ipnet, err := net.ParseCIDR(p)
if err != nil {
return nil, err
}
subnets = append(subnets, ipnet)
}
return subnets, nil
}
func envString(key, def string) string {
value := os.Getenv(key)
if value == "" {
@ -363,12 +397,22 @@ func main() {
defaultTTL = envInt("DEFAULT_TTL", 300)
)
trustedSubnetsStr := envString("TRUSTED_SUBNETS", "")
trustedSubnets, err := parseCIDRs(trustedSubnetsStr)
if err != nil {
log.Fatalf("Invalid TRUSTED_SUBNETS: %v", err)
}
db := initDB(dbPath)
defer db.Close()
go startAPIServer(db, apiPort)
resolver := &DNSResolver{db: db, defaultTTL: defaultTTL}
resolver := &DNSResolver{
db: db,
defaultTTL: defaultTTL,
trustedSubnets: trustedSubnets,
}
dns.HandleFunc(".", resolver.handleDNSRequest)