diff --git a/main.go b/main.go index 0beca5e..a8a16cc 100644 --- a/main.go +++ b/main.go @@ -63,7 +63,15 @@ func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) switch r.Opcode { case dns.OpcodeQuery: for _, q := range m.Question { - ip, err := resolver.getRecordFromDB(q.Name) + var qTypeString string + switch q.Qtype { + case dns.TypeA: + qTypeString = "A" + case dns.TypeAAAA: + qTypeString = "AAAA" + } + + ip, err := resolver.getRecordFromDB(q.Name, qTypeString) if err == nil && ip != "" { log.Printf("[LOCAL] Resolved %s -> %s", q.Name, ip) rr, err := dns.NewRR(fmt.Sprintf("%s %d A %s", q.Name, resolver.defaultTTL, ip)) @@ -89,12 +97,12 @@ func (resolver *DNSResolver) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) w.WriteMsg(m) } -func (resolver *DNSResolver) getRecordFromDB(domain string) (string, error) { +func (resolver *DNSResolver) getRecordFromDB(domain string, recordType string) (string, error) { // DNS queries usually come with a trailing dot (e.g., "google.com.") // Ensure consistency by checking both with and without it if needed, // or enforcing it in the DB. Here we assume DB stores with trailing dot. var ip string - err := resolver.db.QueryRow("SELECT ip FROM records WHERE domain = ?", domain).Scan(&ip) + err := resolver.db.QueryRow("SELECT ip FROM records WHERE domain = ? AND record_type = ?", domain, recordType).Scan(&ip) if err != nil { return "", err } @@ -141,13 +149,23 @@ func startAPIServer(db *sql.DB, apiPort string) { req.Domain += "." } - _, err := db.Exec("INSERT OR REPLACE INTO records (domain, ip, record_type) VALUES (?, ?, ?)", req.Domain, req.IP, "A") + if req.Type == "" { + req.Type = "A" + } + req.Type = strings.ToUpper(req.Type) + + if req.Type != "A" && req.Type != "AAAA" { + http.Error(w, "Invalid record type. Must be A or AAAA", http.StatusBadRequest) + return + } + + _, err := db.Exec("INSERT OR REPLACE INTO records (domain, ip, record_type) VALUES (?, ?, ?)", req.Domain, req.IP, req.Type) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.WriteHeader(http.StatusCreated) - fmt.Fprintf(w, "Added %s -> %s", req.Domain, req.IP) + fmt.Fprintf(w, "Added %s (%s) -> %s", req.Domain, req.Type, req.IP) } else if r.Method == http.MethodDelete { var req RecordRequest