forward with original client ip under trusted networks only
This commit is contained in:
parent
c58703f735
commit
27e51646fc
3 changed files with 109 additions and 8 deletions
|
|
@ -2,3 +2,4 @@ DNS_PORT=53
|
|||
API_PORT=8888
|
||||
DB_PATH=./dns_records.db
|
||||
DEFAULT_TTL=300
|
||||
TRUSTED_SUBNETS=127.0.0.1/32,192.168.1.0/24
|
||||
60
main.go
60
main.go
|
|
@ -59,8 +59,9 @@ func initDB(dbPath string) *sql.DB {
|
|||
}
|
||||
|
||||
type DNSResolver struct {
|
||||
db *sql.DB
|
||||
defaultTTL int
|
||||
db *sql.DB
|
||||
defaultTTL int
|
||||
trustedSubnets []*net.IPNet
|
||||
}
|
||||
|
||||
func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
|
@ -141,8 +142,19 @@ func (resolver *DNSResolver) resolveUpstream(r *dns.Msg, clientIP net.IP) (*dns.
|
|||
// Create a copy of the message to avoid side effects
|
||||
req := r.Copy()
|
||||
|
||||
// Add EDNS0 Client Subnet option
|
||||
// Check if clientIP is allowed
|
||||
allowForwarding := false
|
||||
if clientIP != nil {
|
||||
for _, subnet := range resolver.trustedSubnets {
|
||||
if subnet.Contains(clientIP) {
|
||||
allowForwarding = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add EDNS0 Client Subnet option
|
||||
if allowForwarding {
|
||||
o := req.IsEdns0()
|
||||
if o == nil {
|
||||
req.SetEdns0(4096, false)
|
||||
|
|
@ -259,7 +271,8 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
|||
method = r.Method
|
||||
}
|
||||
|
||||
if method == http.MethodPost {
|
||||
switch method {
|
||||
case http.MethodPost:
|
||||
// Normalize domain to ensure trailing dot
|
||||
if !strings.HasSuffix(domain, ".") {
|
||||
domain += "."
|
||||
|
|
@ -281,7 +294,7 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
|||
}
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusSeeOther)
|
||||
} else if method == http.MethodDelete {
|
||||
case http.MethodDelete:
|
||||
if !strings.HasSuffix(domain, ".") {
|
||||
domain += "."
|
||||
}
|
||||
|
|
@ -309,7 +322,8 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
|||
method = r.Method
|
||||
}
|
||||
|
||||
if method == http.MethodPost {
|
||||
switch method {
|
||||
case http.MethodPost:
|
||||
_, err := db.Exec("INSERT INTO upstreams (address) VALUES (?)", address)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
|
@ -317,7 +331,7 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
|||
}
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusSeeOther)
|
||||
} else if method == http.MethodDelete {
|
||||
case http.MethodDelete:
|
||||
_, err := db.Exec("DELETE FROM upstreams WHERE address = ?", address)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
|
@ -332,6 +346,26 @@ func startAPIServer(db *sql.DB, apiPort string) {
|
|||
log.Fatal(http.ListenAndServe(":"+apiPort, nil))
|
||||
}
|
||||
|
||||
func parseCIDRs(envVar string) ([]*net.IPNet, error) {
|
||||
var subnets []*net.IPNet
|
||||
if envVar == "" {
|
||||
return subnets, nil
|
||||
}
|
||||
|
||||
for p := range strings.SplitSeq(envVar, ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subnets = append(subnets, ipnet)
|
||||
}
|
||||
return subnets, nil
|
||||
}
|
||||
|
||||
func envString(key, def string) string {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
|
|
@ -363,12 +397,22 @@ func main() {
|
|||
defaultTTL = envInt("DEFAULT_TTL", 300)
|
||||
)
|
||||
|
||||
trustedSubnetsStr := envString("TRUSTED_SUBNETS", "")
|
||||
trustedSubnets, err := parseCIDRs(trustedSubnetsStr)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid TRUSTED_SUBNETS: %v", err)
|
||||
}
|
||||
|
||||
db := initDB(dbPath)
|
||||
defer db.Close()
|
||||
|
||||
go startAPIServer(db, apiPort)
|
||||
|
||||
resolver := &DNSResolver{db: db, defaultTTL: defaultTTL}
|
||||
resolver := &DNSResolver{
|
||||
db: db,
|
||||
defaultTTL: defaultTTL,
|
||||
trustedSubnets: trustedSubnets,
|
||||
}
|
||||
|
||||
dns.HandleFunc(".", resolver.handleDNSRequest)
|
||||
|
||||
|
|
|
|||
56
main_test.go
Normal file
56
main_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseCIDRs(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected int
|
||||
wantErr bool
|
||||
}{
|
||||
{"", 0, false},
|
||||
{"192.168.1.0/24", 1, false},
|
||||
{"192.168.1.0/24, 10.0.0.0/8", 2, false},
|
||||
{"invalid", 0, true},
|
||||
{"192.168.1.0/24, invalid", 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result, err := parseCIDRs(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseCIDRs(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
continue
|
||||
}
|
||||
if len(result) != tt.expected {
|
||||
t.Errorf("parseCIDRs(%q) returned %d subnets, expected %d", tt.input, len(result), tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubnetCheck(t *testing.T) {
|
||||
// Test the logic used in resolveUpstream
|
||||
subnets, _ := parseCIDRs("192.168.1.0/24, 10.0.0.0/8")
|
||||
|
||||
allow := func(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
for _, s := range subnets {
|
||||
if s.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if !allow("192.168.1.50") {
|
||||
t.Error("Should allow 192.168.1.50")
|
||||
}
|
||||
if !allow("10.5.5.5") {
|
||||
t.Error("Should allow 10.5.5.5")
|
||||
}
|
||||
if allow("8.8.8.8") {
|
||||
t.Error("Should not allow 8.8.8.8")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue