diff --git a/.env.include.example b/.env.include.example index 629e044..084a501 100644 --- a/.env.include.example +++ b/.env.include.example @@ -2,3 +2,4 @@ DNS_PORT=53 API_PORT=8888 DB_PATH=./dns_records.db DEFAULT_TTL=300 +TRUSTED_SUBNETS=127.0.0.1/32,192.168.1.0/24 \ No newline at end of file diff --git a/main.go b/main.go index 83aa221..7e49b52 100644 --- a/main.go +++ b/main.go @@ -59,8 +59,9 @@ func initDB(dbPath string) *sql.DB { } type DNSResolver struct { - db *sql.DB - defaultTTL int + db *sql.DB + defaultTTL int + trustedSubnets []*net.IPNet } func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { @@ -141,8 +142,19 @@ func (resolver *DNSResolver) resolveUpstream(r *dns.Msg, clientIP net.IP) (*dns. // Create a copy of the message to avoid side effects req := r.Copy() - // Add EDNS0 Client Subnet option + // Check if clientIP is allowed + allowForwarding := false if clientIP != nil { + for _, subnet := range resolver.trustedSubnets { + if subnet.Contains(clientIP) { + allowForwarding = true + break + } + } + } + + // Add EDNS0 Client Subnet option + if allowForwarding { o := req.IsEdns0() if o == nil { req.SetEdns0(4096, false) @@ -259,7 +271,8 @@ func startAPIServer(db *sql.DB, apiPort string) { method = r.Method } - if method == http.MethodPost { + switch method { + case http.MethodPost: // Normalize domain to ensure trailing dot if !strings.HasSuffix(domain, ".") { domain += "." @@ -281,7 +294,7 @@ func startAPIServer(db *sql.DB, apiPort string) { } w.Header().Set("Location", "/") w.WriteHeader(http.StatusSeeOther) - } else if method == http.MethodDelete { + case http.MethodDelete: if !strings.HasSuffix(domain, ".") { domain += "." } @@ -309,7 +322,8 @@ func startAPIServer(db *sql.DB, apiPort string) { method = r.Method } - if method == http.MethodPost { + switch method { + case http.MethodPost: _, err := db.Exec("INSERT INTO upstreams (address) VALUES (?)", address) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -317,7 +331,7 @@ func startAPIServer(db *sql.DB, apiPort string) { } w.Header().Set("Location", "/") w.WriteHeader(http.StatusSeeOther) - } else if method == http.MethodDelete { + case http.MethodDelete: _, err := db.Exec("DELETE FROM upstreams WHERE address = ?", address) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -332,6 +346,26 @@ func startAPIServer(db *sql.DB, apiPort string) { log.Fatal(http.ListenAndServe(":"+apiPort, nil)) } +func parseCIDRs(envVar string) ([]*net.IPNet, error) { + var subnets []*net.IPNet + if envVar == "" { + return subnets, nil + } + + for p := range strings.SplitSeq(envVar, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + _, ipnet, err := net.ParseCIDR(p) + if err != nil { + return nil, err + } + subnets = append(subnets, ipnet) + } + return subnets, nil +} + func envString(key, def string) string { value := os.Getenv(key) if value == "" { @@ -363,12 +397,22 @@ func main() { defaultTTL = envInt("DEFAULT_TTL", 300) ) + trustedSubnetsStr := envString("TRUSTED_SUBNETS", "") + trustedSubnets, err := parseCIDRs(trustedSubnetsStr) + if err != nil { + log.Fatalf("Invalid TRUSTED_SUBNETS: %v", err) + } + db := initDB(dbPath) defer db.Close() go startAPIServer(db, apiPort) - resolver := &DNSResolver{db: db, defaultTTL: defaultTTL} + resolver := &DNSResolver{ + db: db, + defaultTTL: defaultTTL, + trustedSubnets: trustedSubnets, + } dns.HandleFunc(".", resolver.handleDNSRequest) diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..6495fb1 --- /dev/null +++ b/main_test.go @@ -0,0 +1,56 @@ +package main + +import ( + "net" + "testing" +) + +func TestParseCIDRs(t *testing.T) { + tests := []struct { + input string + expected int + wantErr bool + }{ + {"", 0, false}, + {"192.168.1.0/24", 1, false}, + {"192.168.1.0/24, 10.0.0.0/8", 2, false}, + {"invalid", 0, true}, + {"192.168.1.0/24, invalid", 0, true}, + } + + for _, tt := range tests { + result, err := parseCIDRs(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parseCIDRs(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + continue + } + if len(result) != tt.expected { + t.Errorf("parseCIDRs(%q) returned %d subnets, expected %d", tt.input, len(result), tt.expected) + } + } +} + +func TestSubnetCheck(t *testing.T) { + // Test the logic used in resolveUpstream + subnets, _ := parseCIDRs("192.168.1.0/24, 10.0.0.0/8") + + allow := func(ipStr string) bool { + ip := net.ParseIP(ipStr) + for _, s := range subnets { + if s.Contains(ip) { + return true + } + } + return false + } + + if !allow("192.168.1.50") { + t.Error("Should allow 192.168.1.50") + } + if !allow("10.5.5.5") { + t.Error("Should allow 10.5.5.5") + } + if allow("8.8.8.8") { + t.Error("Should not allow 8.8.8.8") + } +}