diff --git a/.env.include.example b/.env.include.example new file mode 100644 index 0000000..629e044 --- /dev/null +++ b/.env.include.example @@ -0,0 +1,4 @@ +DNS_PORT=53 +API_PORT=8888 +DB_PATH=./dns_records.db +DEFAULT_TTL=300 diff --git a/.gitignore b/.gitignore index ab7d69d..21d70d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .env .env.* +!.env.include* *.db diff --git a/go.mod b/go.mod index e7744d9..1c4ab75 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module git.juancwu.dev/juancwu/ccretdns go 1.25.0 require ( + github.com/joho/godotenv v1.5.1 github.com/mattn/go-sqlite3 v1.14.32 github.com/miekg/dns v1.1.69 ) diff --git a/go.sum b/go.sum index 782c2c8..dad74db 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/miekg/dns v1.1.69 h1:Kb7Y/1Jo+SG+a2GtfoFUfDkG//csdRPwRLkCsxDG9Sc= diff --git a/main.go b/main.go index 7d07467..b52ba50 100644 --- a/main.go +++ b/main.go @@ -5,21 +5,18 @@ import ( "encoding/json" "fmt" "log" + "log/slog" "net/http" + "os" + "strconv" "strings" "time" + "github.com/joho/godotenv" _ "github.com/mattn/go-sqlite3" "github.com/miekg/dns" ) -const ( - DBPath = "./dns_records.db" - DNSPort = "53" - APIPort = ":8080" - DefaultTTL = 300 -) - type RecordRequest struct { Domain string `json:"domain"` IP string `json:"ip"` @@ -30,8 +27,9 @@ type UpstreamRequest struct { Address string `json:"address"` } -func initDB() *sql.DB { - db, err := sql.Open("sqlite3", DBPath) +func initDB(dbPath string) *sql.DB { + + db, err := sql.Open("sqlite3", dbPath) if err != nil { log.Fatal(err) } @@ -61,7 +59,8 @@ func initDB() *sql.DB { } type DNSResolver struct { - db *sql.DB + db *sql.DB + defaultTTL int } func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { @@ -75,7 +74,7 @@ func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) ip, err := resolver.getRecordFromDB(q.Name) if err == nil && ip != "" { log.Printf("[LOCAL] Resolved %s -> %s", q.Name, ip) - rr, err := dns.NewRR(fmt.Sprintf("%s %d A %s", q.Name, DefaultTTL, ip)) + rr, err := dns.NewRR(fmt.Sprintf("%s %d A %s", q.Name, resolver.defaultTTL, ip)) if err == nil { m.Answer = append(m.Answer, rr) } @@ -136,7 +135,7 @@ func (resolver *DNSResolver) resolveUpstream(r *dns.Msg) (*dns.Msg, error) { return nil, fmt.Errorf("all upstreams failed") } -func startAPIServer(db *sql.DB) { +func startAPIServer(db *sql.DB, apiPort string) { http.HandleFunc("/records", func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var req RecordRequest @@ -206,22 +205,52 @@ func startAPIServer(db *sql.DB) { } }) - log.Printf("API Listening on %s...", APIPort) - log.Fatal(http.ListenAndServe(APIPort, nil)) + log.Printf("API Listening on %s...", apiPort) + log.Fatal(http.ListenAndServe(":"+apiPort, nil)) +} + +func envString(key, def string) string { + value := os.Getenv(key) + if value == "" { + value = def + } + return value +} + +func envInt(key string, def int) int { + value, ok := os.LookupEnv(key) + if !ok || value == "" { + return def + } + val, err := strconv.ParseInt(value, 10, 32) + if err != nil { + slog.Warn("config invalid integer, using default", "key", key, "value", value, "default", def) + return def + } + return int(val) } func main() { - db := initDB() + godotenv.Load() + + var ( + dbPath = envString("DB_PATH", "./dns_records.db") + dnsPort = envString("DNS_PORT", "53") + apiPort = envString("API_PORT", "8888") + defaultTTL = envInt("DEFAULT_TTL", 300) + ) + + db := initDB(dbPath) defer db.Close() - go startAPIServer(db) + go startAPIServer(db, apiPort) - resolver := &DNSResolver{db: db} + resolver := &DNSResolver{db: db, defaultTTL: defaultTTL} dns.HandleFunc(".", resolver.handleDNSRequest) - server := &dns.Server{Addr: ":" + DNSPort, Net: "udp"} - log.Printf("DNS Server listening on port %s...", DNSPort) + server := &dns.Server{Addr: ":" + dnsPort, Net: "udp"} + log.Printf("DNS Server listening on port %s...", dnsPort) if err := server.ListenAndServe(); err != nil { log.Fatalf("Failed to set up DNS server: %v", err)