// ccretdns/main.go
//
// 425 lines
// 9.8 KiB
// Go

package main
import (
"database/sql"
"embed"
"fmt"
"html/template"
"log"
"log/slog"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/joho/godotenv"
_ "github.com/mattn/go-sqlite3"
"github.com/miekg/dns"
"github.com/pressly/goose/v3"
)
// embedMigrations embeds the SQL migration files and the index.html page
// template. It is handed to goose as its base filesystem and to
// html/template for parsing.
//
//go:embed migrations/*.sql index.html
var embedMigrations embed.FS
// indexTmpl is the parsed index.html template rendered by the "/" handler.
// template.Must panics during package initialisation if parsing fails.
var indexTmpl = template.Must(template.ParseFS(embedMigrations, "index.html"))
// Record is one row of the records table as shown in the web UI.
//
// Domain is the queried name (the API stores it with a trailing dot), IP is
// the address in textual form, and RecordType is the record type name; the
// API only accepts "A" and "AAAA".
type Record struct {
	Domain, IP, RecordType string
}
// Upstream is one row of the upstreams table as shown in the web UI.
type Upstream struct {
// Address is the upstream server address exactly as stored in the database.
Address string
}
// PageData is the value passed to the index.html template.
type PageData struct {
// Records holds every row read from the records table.
Records []Record
// Upstreams holds every row read from the upstreams table.
Upstreams []Upstream
}
// initDB opens the SQLite database at dbPath and applies the embedded goose
// migrations from the "migrations" directory. Any failure terminates the
// process through log.Fatal, so the returned handle is always usable.
func initDB(dbPath string) *sql.DB {
	// Every step is fatal on error; funnel them through one helper.
	fatalOn := func(err error) {
		if err != nil {
			log.Fatal(err)
		}
	}

	conn, err := sql.Open("sqlite3", dbPath)
	fatalOn(err)

	fatalOn(goose.SetDialect("sqlite3"))
	goose.SetBaseFS(embedMigrations)
	fatalOn(goose.Up(conn, "migrations"))

	return conn
}
// DNSResolver answers DNS queries from the local records table and forwards
// everything else to the upstream servers listed in the database.
type DNSResolver struct {
// db is the database queried for local records and upstream addresses.
db *sql.DB
// defaultTTL is the TTL written into every locally generated answer.
defaultTTL int
// trustedSubnets lists the client networks whose address is sent to the
// upstream in an EDNS0 Client Subnet option.
trustedSubnets []*net.IPNet
}
// handleDNSRequest answers one incoming DNS message.
//
// For standard queries, A and AAAA questions under ".lan." or ".internal."
// are looked up in the local records table first. Every other question, and
// every local miss, is forwarded to the upstream servers. Messages with any
// other opcode receive an empty reply.
func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = false

	if r.Opcode == dns.OpcodeQuery {
		for _, q := range m.Question {
			if resolver.answerLocally(m, q) {
				continue
			}
			resolver.answerFromUpstream(w, r, m, q)
		}
	}

	if err := w.WriteMsg(m); err != nil {
		log.Printf("[ERROR] Could not write response: %v", err)
	}
}

// answerLocally tries to answer q from the records table and appends the
// resulting record to m. It reports whether the question was handled locally;
// false means the caller should fall back to the upstream servers.
func (resolver *DNSResolver) answerLocally(m *dns.Msg, q dns.Question) bool {
	var qTypeString string
	switch q.Qtype {
	case dns.TypeA:
		qTypeString = "A"
	case dns.TypeAAAA:
		qTypeString = "AAAA"
	default:
		return false
	}

	// DNS names compare case-insensitively and some resolvers randomise the
	// case of the query name, so match the local suffixes on a lowered copy.
	lowered := strings.ToLower(q.Name)
	if !strings.HasSuffix(lowered, ".lan.") && !strings.HasSuffix(lowered, ".internal.") {
		return false
	}

	// Try the name exactly as asked first so records stored in mixed case
	// keep matching, then retry with the lowered form.
	ip, err := resolver.getRecordFromDB(q.Name, qTypeString)
	if (err != nil || ip == "") && lowered != q.Name {
		ip, err = resolver.getRecordFromDB(lowered, qTypeString)
	}
	if err != nil || ip == "" {
		return false
	}

	log.Printf("[LOCAL] Resolved %s -> %s", q.Name, ip)
	rr, err := dns.NewRR(fmt.Sprintf("%s %d %s %s", q.Name, resolver.defaultTTL, qTypeString, ip))
	if err != nil {
		// The stored value cannot be turned into a record. The question is
		// still treated as handled (no upstream fallback), but the reason
		// is logged instead of being dropped silently.
		log.Printf("[ERROR] Stored %s record for %s is invalid (%q): %v", qTypeString, q.Name, ip, err)
		return true
	}
	m.Answer = append(m.Answer, rr)
	return true
}

// answerFromUpstream forwards the whole request r to the upstream servers
// and copies the result into m. Because the upstream receives the complete
// request, its answer, authority and additional sections replace whatever m
// held so far. The upstream response code is copied as well, so NXDOMAIN and
// similar results reach the client unchanged.
func (resolver *DNSResolver) answerFromUpstream(w dns.ResponseWriter, r, m *dns.Msg, q dns.Question) {
	var clientIP net.IP
	switch addr := w.RemoteAddr().(type) {
	case *net.UDPAddr:
		clientIP = addr.IP
	case *net.TCPAddr:
		clientIP = addr.IP
	}

	resp, err := resolver.resolveUpstream(r, clientIP)
	if err != nil || resp == nil {
		log.Printf("[ERROR] Could not resolve %s: %v", q.Name, err)
		m.Rcode = dns.RcodeServerFailure
		return
	}

	log.Printf("[UPSTREAM] Forwarded %s (ClientIP: %s)", q.Name, clientIP)
	m.Answer = resp.Answer
	m.Ns = resp.Ns
	m.Extra = resp.Extra
	m.Rcode = resp.Rcode
	m.RecursionAvailable = resp.RecursionAvailable
}
// getRecordFromDB returns the stored address for domain and recordType.
//
// domain is expected in the form used on the wire, with a trailing dot
// (e.g. "host.lan."); the API normalises stored names the same way. The
// domain comparison is case-insensitive because DNS names are, so a record
// saved as "Host.lan." also answers a query for "host.lan.".
//
// When no row matches, the error from Scan (sql.ErrNoRows) is returned
// together with an empty string.
func (resolver *DNSResolver) getRecordFromDB(domain string, recordType string) (string, error) {
	const query = "SELECT ip FROM records WHERE domain COLLATE NOCASE = ? AND record_type = ?"

	var ip string
	if err := resolver.db.QueryRow(query, domain, recordType).Scan(&ip); err != nil {
		return "", err
	}
	return ip, nil
}
// resolveUpstream forwards a copy of r to each configured upstream in turn
// and returns the first response that is not SERVFAIL.
//
// When clientIP belongs to one of the trusted subnets, the forwarded request
// carries an EDNS0 Client Subnet option with that address, replacing any
// subnet option the client sent. An error is returned when the upstream list
// cannot be read, is empty, or every upstream fails.
func (resolver *DNSResolver) resolveUpstream(r *dns.Msg, clientIP net.IP) (*dns.Msg, error) {
	// Read the full list before any network I/O so the database cursor is
	// not held open while waiting on upstream servers.
	upstreams, err := resolver.upstreamAddrs()
	if err != nil {
		return nil, err
	}
	if len(upstreams) == 0 {
		return nil, fmt.Errorf("no upstreams configured")
	}

	// Work on a copy so the caller's message is not modified.
	req := r.Copy()
	if resolver.isTrusted(clientIP) {
		setClientSubnet(req, clientIP)
	}

	c := &dns.Client{Net: "udp", Timeout: 5 * time.Second}
	for _, addr := range upstreams {
		resp, _, err := c.Exchange(req, withDefaultPort(addr))
		if err != nil {
			log.Printf("[UPSTREAM] %s failed: %v", addr, err)
			continue
		}
		if resp != nil && resp.Rcode != dns.RcodeServerFailure {
			return resp, nil
		}
	}
	return nil, fmt.Errorf("all upstreams failed")
}

// upstreamAddrs returns every address stored in the upstreams table. Rows
// that cannot be scanned are logged and skipped; an iteration error is
// returned to the caller.
func (resolver *DNSResolver) upstreamAddrs() ([]string, error) {
	rows, err := resolver.db.Query("SELECT address FROM upstreams")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var addrs []string
	for rows.Next() {
		var addr string
		if err := rows.Scan(&addr); err != nil {
			log.Printf("[UPSTREAM] Skipping unreadable upstream row: %v", err)
			continue
		}
		addrs = append(addrs, addr)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return addrs, nil
}

// isTrusted reports whether clientIP lies inside any trusted subnet.
func (resolver *DNSResolver) isTrusted(clientIP net.IP) bool {
	if clientIP == nil {
		return false
	}
	for _, subnet := range resolver.trustedSubnets {
		if subnet.Contains(clientIP) {
			return true
		}
	}
	return false
}

// setClientSubnet puts a single EDNS0 Client Subnet option carrying clientIP
// on req, adding an OPT record when req has none and dropping any subnet
// option already present.
func setClientSubnet(req *dns.Msg, clientIP net.IP) {
	opt := req.IsEdns0()
	if opt == nil {
		req.SetEdns0(4096, false)
		opt = req.IsEdns0()
	}

	kept := make([]dns.EDNS0, 0, len(opt.Option)+1)
	for _, o := range opt.Option {
		if o.Option() != dns.EDNS0SUBNET {
			kept = append(kept, o)
		}
	}

	ecs := &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET}
	if ip4 := clientIP.To4(); ip4 != nil {
		ecs.Family = 1 // IPv4
		ecs.SourceNetmask = 32
		ecs.Address = ip4
	} else {
		ecs.Family = 2 // IPv6
		ecs.SourceNetmask = 128
		ecs.Address = clientIP
	}
	opt.Option = append(kept, ecs)
}

// withDefaultPort returns addr unchanged when it already splits as
// host:port, and otherwise appends the standard DNS port 53. This lets
// entries such as "8.8.8.8" or "2001:4860:4860::8888" work as upstreams.
func withDefaultPort(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	// JoinHostPort adds brackets around IPv6 literals itself, so strip any
	// the user already typed.
	return net.JoinHostPort(strings.Trim(addr, "[]"), "53")
}
// startAPIServer registers the web UI handlers on http.DefaultServeMux and
// serves them on apiPort. It blocks while the server runs and terminates the
// process through log.Fatal when the listener fails.
//
// Routes:
//   - "/"          renders index.html with all records and upstreams.
//   - "/records"   POST inserts or replaces a record, DELETE removes one.
//   - "/upstreams" POST adds an upstream, DELETE removes one.
//
// HTML forms can only submit GET and POST, so a POST may carry a "_method"
// form value to act as a DELETE. Any other method is answered with 405.
func startAPIServer(db *sql.DB, apiPort string) {
	http.HandleFunc("/", indexHandler(db))
	http.HandleFunc("/records", recordsHandler(db))
	http.HandleFunc("/upstreams", upstreamsHandler(db))

	// A nil Handler makes the server use http.DefaultServeMux.
	srv := &http.Server{
		Addr:              ":" + apiPort,
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Printf("API Listening on %s...", apiPort)
	log.Fatal(srv.ListenAndServe())
}

// indexHandler renders the index page with the current records and upstreams.
func indexHandler(db *sql.DB) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// "/" is a catch-all pattern; reject every other path explicitly.
		if r.URL.Path != "/" {
			http.NotFound(w, r)
			return
		}
		records, err := listRecords(db)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		upstreams, err := listUpstreams(db)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		data := PageData{Records: records, Upstreams: upstreams}
		if err := indexTmpl.Execute(w, data); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
	}
}

// recordsHandler creates, replaces and deletes rows of the records table.
func recordsHandler(db *sql.DB) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		domain := r.FormValue("domain")
		recordType := r.FormValue("type")

		switch effectiveMethod(r) {
		case http.MethodPost:
			domain = strings.TrimSpace(domain)
			if domain == "" {
				http.Error(w, "Domain is required", http.StatusBadRequest)
				return
			}
			if recordType == "" {
				recordType = "A"
			}
			recordType = strings.ToUpper(recordType)
			if recordType != "A" && recordType != "AAAA" {
				http.Error(w, "Invalid record type. Must be A or AAAA", http.StatusBadRequest)
				return
			}
			ip := strings.TrimSpace(r.FormValue("ip"))
			if !validRecordIP(ip, recordType) {
				http.Error(w, "Invalid IP address for record type "+recordType, http.StatusBadRequest)
				return
			}
			execAndRedirect(w, r, db,
				"INSERT OR REPLACE INTO records (domain, ip, record_type) VALUES (?, ?, ?)",
				withTrailingDot(domain), ip, recordType)
		case http.MethodDelete:
			execAndRedirect(w, r, db,
				"DELETE FROM records WHERE domain = ? AND record_type = ?",
				withTrailingDot(domain), recordType)
		default:
			methodNotAllowed(w)
		}
	}
}

// upstreamsHandler adds and removes rows of the upstreams table.
func upstreamsHandler(db *sql.DB) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		address := r.FormValue("address")

		switch effectiveMethod(r) {
		case http.MethodPost:
			address = strings.TrimSpace(address)
			if address == "" {
				http.Error(w, "Address is required", http.StatusBadRequest)
				return
			}
			execAndRedirect(w, r, db, "INSERT INTO upstreams (address) VALUES (?)", address)
		case http.MethodDelete:
			execAndRedirect(w, r, db, "DELETE FROM upstreams WHERE address = ?", address)
		default:
			methodNotAllowed(w)
		}
	}
}

// effectiveMethod returns the method a request should be treated as. The
// "_method" override is honoured for POST requests only, so a plain GET
// (a link, an image tag, a prefetch) can never modify data.
func effectiveMethod(r *http.Request) string {
	if r.Method != http.MethodPost {
		return r.Method
	}
	if override := r.FormValue("_method"); override != "" {
		return strings.ToUpper(override)
	}
	return r.Method
}

// withTrailingDot returns domain in fully qualified form, ending in ".".
func withTrailingDot(domain string) string {
	if strings.HasSuffix(domain, ".") {
		return domain
	}
	return domain + "."
}

// validRecordIP reports whether ip is a literal of the address family that
// recordType ("A" or "AAAA") requires. The family is decided from the text:
// an IPv4 literal never contains ':' and an IPv6 literal always does. Using
// To4 instead would classify IPv4-mapped IPv6 addresses as IPv4.
func validRecordIP(ip, recordType string) bool {
	if net.ParseIP(ip) == nil {
		return false
	}
	hasColon := strings.Contains(ip, ":")
	if recordType == "A" {
		return !hasColon
	}
	return hasColon
}

// execAndRedirect runs a data-modifying statement and, on success, sends the
// browser back to the index page with 303 See Other. A database error is
// reported as 500.
func execAndRedirect(w http.ResponseWriter, r *http.Request, db *sql.DB, query string, args ...any) {
	if _, err := db.Exec(query, args...); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, "/", http.StatusSeeOther)
}

// methodNotAllowed answers 405 and advertises the supported methods.
func methodNotAllowed(w http.ResponseWriter) {
	w.Header().Set("Allow", "POST, DELETE")
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

// listRecords returns every row of the records table.
func listRecords(db *sql.DB) ([]Record, error) {
	rows, err := db.Query("SELECT domain, ip, record_type FROM records")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var records []Record
	for rows.Next() {
		var rec Record
		if err := rows.Scan(&rec.Domain, &rec.IP, &rec.RecordType); err != nil {
			return nil, err
		}
		records = append(records, rec)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return records, nil
}

// listUpstreams returns every row of the upstreams table.
func listUpstreams(db *sql.DB) ([]Upstream, error) {
	rows, err := db.Query("SELECT address FROM upstreams")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var upstreams []Upstream
	for rows.Next() {
		var ups Upstream
		if err := rows.Scan(&ups.Address); err != nil {
			return nil, err
		}
		upstreams = append(upstreams, ups)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return upstreams, nil
}
func parseCIDRs(envVar string) ([]*net.IPNet, error) {
var subnets []*net.IPNet
if envVar == "" {
return subnets, nil
}
for p := range strings.SplitSeq(envVar, ",") {
p = strings.TrimSpace(p)
if p == "" {
continue
}
_, ipnet, err := net.ParseCIDR(p)
if err != nil {
return nil, err
}
subnets = append(subnets, ipnet)
}
return subnets, nil
}
// envString returns the value of the environment variable key, or def when
// the variable is unset or empty.
func envString(key, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}
// envInt returns the environment variable key parsed as a base-10 integer
// that fits in 32 bits. It returns def when the variable is unset or empty,
// and also when it does not parse, in which case a warning is logged.
func envInt(key string, def int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return def
	}
	parsed, err := strconv.ParseInt(raw, 10, 32)
	if err == nil {
		return int(parsed)
	}
	slog.Warn("config invalid integer, using default", "key", key, "value", raw, "default", def)
	return def
}
// main reads the configuration from the environment (optionally seeded from
// a .env file), opens the database, starts the web UI in the background and
// then serves DNS over UDP until the server fails.
//
// Environment variables and defaults: DB_PATH ("./dns_records.db"),
// DNS_PORT ("53"), API_PORT ("8888"), DEFAULT_TTL (300) and TRUSTED_SUBNETS
// (empty, a comma-separated list of CIDR blocks).
func main() {
	// The result is ignored on purpose: configuration then comes from the
	// process environment and the built-in defaults.
	_ = godotenv.Load()

	dbPath := envString("DB_PATH", "./dns_records.db")
	dnsPort := envString("DNS_PORT", "53")
	apiPort := envString("API_PORT", "8888")
	ttl := envInt("DEFAULT_TTL", 300)

	subnets, err := parseCIDRs(envString("TRUSTED_SUBNETS", ""))
	if err != nil {
		log.Fatalf("Invalid TRUSTED_SUBNETS: %v", err)
	}

	db := initDB(dbPath)
	defer db.Close()

	// The web UI runs alongside the DNS server for the life of the process.
	go startAPIServer(db, apiPort)

	resolver := &DNSResolver{db: db, defaultTTL: ttl, trustedSubnets: subnets}
	dns.HandleFunc(".", resolver.handleDNSRequest)

	server := &dns.Server{Addr: ":" + dnsPort, Net: "udp"}
	log.Printf("DNS Server listening on port %s...", dnsPort)
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Failed to set up DNS server: %v", err)
	}
}