forward with original client ip under trusted networks only
This commit is contained in:
parent
c58703f735
commit
27e51646fc
3 changed files with 109 additions and 8 deletions
|
|
@ -2,3 +2,4 @@ DNS_PORT=53
|
||||||
API_PORT=8888
|
API_PORT=8888
|
||||||
DB_PATH=./dns_records.db
|
DB_PATH=./dns_records.db
|
||||||
DEFAULT_TTL=300
|
DEFAULT_TTL=300
|
||||||
|
TRUSTED_SUBNETS=127.0.0.1/32,192.168.1.0/24
|
||||||
60
main.go
60
main.go
|
|
@ -59,8 +59,9 @@ func initDB(dbPath string) *sql.DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DNSResolver struct {
|
type DNSResolver struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
defaultTTL int
|
defaultTTL int
|
||||||
|
trustedSubnets []*net.IPNet
|
||||||
}
|
}
|
||||||
|
|
||||||
func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
|
@ -141,8 +142,19 @@ func (resolver *DNSResolver) resolveUpstream(r *dns.Msg, clientIP net.IP) (*dns.
|
||||||
// Create a copy of the message to avoid side effects
|
// Create a copy of the message to avoid side effects
|
||||||
req := r.Copy()
|
req := r.Copy()
|
||||||
|
|
||||||
// Add EDNS0 Client Subnet option
|
// Check if clientIP is allowed
|
||||||
|
allowForwarding := false
|
||||||
if clientIP != nil {
|
if clientIP != nil {
|
||||||
|
for _, subnet := range resolver.trustedSubnets {
|
||||||
|
if subnet.Contains(clientIP) {
|
||||||
|
allowForwarding = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add EDNS0 Client Subnet option
|
||||||
|
if allowForwarding {
|
||||||
o := req.IsEdns0()
|
o := req.IsEdns0()
|
||||||
if o == nil {
|
if o == nil {
|
||||||
req.SetEdns0(4096, false)
|
req.SetEdns0(4096, false)
|
||||||
|
|
@ -259,7 +271,8 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
||||||
method = r.Method
|
method = r.Method
|
||||||
}
|
}
|
||||||
|
|
||||||
if method == http.MethodPost {
|
switch method {
|
||||||
|
case http.MethodPost:
|
||||||
// Normalize domain to ensure trailing dot
|
// Normalize domain to ensure trailing dot
|
||||||
if !strings.HasSuffix(domain, ".") {
|
if !strings.HasSuffix(domain, ".") {
|
||||||
domain += "."
|
domain += "."
|
||||||
|
|
@ -281,7 +294,7 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
||||||
}
|
}
|
||||||
w.Header().Set("Location", "/")
|
w.Header().Set("Location", "/")
|
||||||
w.WriteHeader(http.StatusSeeOther)
|
w.WriteHeader(http.StatusSeeOther)
|
||||||
} else if method == http.MethodDelete {
|
case http.MethodDelete:
|
||||||
if !strings.HasSuffix(domain, ".") {
|
if !strings.HasSuffix(domain, ".") {
|
||||||
domain += "."
|
domain += "."
|
||||||
}
|
}
|
||||||
|
|
@ -309,7 +322,8 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
||||||
method = r.Method
|
method = r.Method
|
||||||
}
|
}
|
||||||
|
|
||||||
if method == http.MethodPost {
|
switch method {
|
||||||
|
case http.MethodPost:
|
||||||
_, err := db.Exec("INSERT INTO upstreams (address) VALUES (?)", address)
|
_, err := db.Exec("INSERT INTO upstreams (address) VALUES (?)", address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
|
@ -317,7 +331,7 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
||||||
}
|
}
|
||||||
w.Header().Set("Location", "/")
|
w.Header().Set("Location", "/")
|
||||||
w.WriteHeader(http.StatusSeeOther)
|
w.WriteHeader(http.StatusSeeOther)
|
||||||
} else if method == http.MethodDelete {
|
case http.MethodDelete:
|
||||||
_, err := db.Exec("DELETE FROM upstreams WHERE address = ?", address)
|
_, err := db.Exec("DELETE FROM upstreams WHERE address = ?", address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
|
@ -332,6 +346,26 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
||||||
log.Fatal(http.ListenAndServe(":"+apiPort, nil))
|
log.Fatal(http.ListenAndServe(":"+apiPort, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseCIDRs(envVar string) ([]*net.IPNet, error) {
|
||||||
|
var subnets []*net.IPNet
|
||||||
|
if envVar == "" {
|
||||||
|
return subnets, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for p := range strings.SplitSeq(envVar, ",") {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
subnets = append(subnets, ipnet)
|
||||||
|
}
|
||||||
|
return subnets, nil
|
||||||
|
}
|
||||||
|
|
||||||
func envString(key, def string) string {
|
func envString(key, def string) string {
|
||||||
value := os.Getenv(key)
|
value := os.Getenv(key)
|
||||||
if value == "" {
|
if value == "" {
|
||||||
|
|
@ -363,12 +397,22 @@ func main() {
|
||||||
defaultTTL = envInt("DEFAULT_TTL", 300)
|
defaultTTL = envInt("DEFAULT_TTL", 300)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trustedSubnetsStr := envString("TRUSTED_SUBNETS", "")
|
||||||
|
trustedSubnets, err := parseCIDRs(trustedSubnetsStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Invalid TRUSTED_SUBNETS: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
db := initDB(dbPath)
|
db := initDB(dbPath)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
go startAPIServer(db, apiPort)
|
go startAPIServer(db, apiPort)
|
||||||
|
|
||||||
resolver := &DNSResolver{db: db, defaultTTL: defaultTTL}
|
resolver := &DNSResolver{
|
||||||
|
db: db,
|
||||||
|
defaultTTL: defaultTTL,
|
||||||
|
trustedSubnets: trustedSubnets,
|
||||||
|
}
|
||||||
|
|
||||||
dns.HandleFunc(".", resolver.handleDNSRequest)
|
dns.HandleFunc(".", resolver.handleDNSRequest)
|
||||||
|
|
||||||
|
|
|
||||||
56
main_test.go
Normal file
56
main_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseCIDRs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"", 0, false},
|
||||||
|
{"192.168.1.0/24", 1, false},
|
||||||
|
{"192.168.1.0/24, 10.0.0.0/8", 2, false},
|
||||||
|
{"invalid", 0, true},
|
||||||
|
{"192.168.1.0/24, invalid", 0, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result, err := parseCIDRs(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("parseCIDRs(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(result) != tt.expected {
|
||||||
|
t.Errorf("parseCIDRs(%q) returned %d subnets, expected %d", tt.input, len(result), tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubnetCheck(t *testing.T) {
|
||||||
|
// Test the logic used in resolveUpstream
|
||||||
|
subnets, _ := parseCIDRs("192.168.1.0/24, 10.0.0.0/8")
|
||||||
|
|
||||||
|
allow := func(ipStr string) bool {
|
||||||
|
ip := net.ParseIP(ipStr)
|
||||||
|
for _, s := range subnets {
|
||||||
|
if s.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allow("192.168.1.50") {
|
||||||
|
t.Error("Should allow 192.168.1.50")
|
||||||
|
}
|
||||||
|
if !allow("10.5.5.5") {
|
||||||
|
t.Error("Should allow 10.5.5.5")
|
||||||
|
}
|
||||||
|
if allow("8.8.8.8") {
|
||||||
|
t.Error("Should not allow 8.8.8.8")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue