diff --git a/command/agent/dns.go b/command/agent/dns.go index 78b3bf26a..7a0b918fd 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -175,7 +175,11 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Authoritative = true - d.addSOA(d.domain, m) + + // Only add the SOA if requested + if req.Question[0].Qtype == dns.TypeSOA { + d.addSOA(d.domain, m) + } // Dispatch the correct handler d.dispatch(network, req, m) @@ -310,6 +314,14 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. record := formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name, qType) if record != nil { resp.Answer = append(resp.Answer, record) + + // Try to recursively resolve the CNAME + if cnRec, ok := record.(*dns.CNAME); ok { + aRecs := d.resolveCNAME(cnRec.Target) + if len(aRecs) > 0 { + resp.Extra = append(resp.Extra, aRecs[0]) + } + } } } @@ -399,6 +411,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, if qType == dns.TypeANY || qType == dns.TypeSRV { d.serviceSRVRecords(datacenter, out.Nodes, req, resp) } + + // Cleanup duplicate extra entries + resp.Extra = removeDuplicates(resp.Extra) } // filterServiceNodes is used to filter out nodes that are failing @@ -446,6 +461,14 @@ func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, res record := formatNodeRecord(&node.Node, qName, qType) if record != nil { resp.Answer = append(resp.Answer, record) + + // Try to recursively resolve the CNAME + if cnRec, ok := record.(*dns.CNAME); ok { + aRecs := d.resolveCNAME(cnRec.Target) + if len(aRecs) > 0 { + resp.Extra = append(resp.Extra, aRecs[0]) + } + } } } } @@ -489,6 +512,14 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes record := formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY) if record != nil { resp.Extra = append(resp.Extra, record) + + // Try to recursively resolve the CNAME + if cnRec, ok := record.(*dns.CNAME); ok { + aRecs := d.resolveCNAME(cnRec.Target) + if len(aRecs) > 0 { + resp.Extra = append(resp.Extra, aRecs[0]) + } + } } } } @@ -526,3 +557,47 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { d.logger.Printf("[WARN] dns: failed to respond: %v", err) } } + +// resolveCNAME is used to recursively resolve CNAME records +func (d *DNSServer) resolveCNAME(name string) []dns.RR { + // Do nothing if we don't have a recursor + if d.recursor == "" { + return nil + } + + // Ask for any A records + m := new(dns.Msg) + m.SetQuestion(name, dns.TypeA) + + // Make a DNS lookup request + c := &dns.Client{Net: "udp"} + r, rtt, err := c.Exchange(m, d.recursor) + if err != nil { + d.logger.Printf("[ERR] dns: cname recurse failed: %v", err) + return nil + } + d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) + + // Return all the answers + return r.Answer +} + +// removeDuplicates is used to remove the duplicate entries. +// This only deduplicates on the QName and QType +func removeDuplicates(rr []dns.RR) []dns.RR { + handled := make(map[string]struct{}) + n := len(rr) + for i := 0; i < n; i++ { + rec := rr[i] + hdr := rec.Header() + key := fmt.Sprintf("%s:%d", hdr.Name, hdr.Rrtype) + if _, ok := handled[key]; ok { + // Remove duplicate + rr[i], rr[n-1] = rr[n-1], nil + n-- + i-- + } + handled[key] = struct{}{} + } + return rr[:n] +} diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index f1b3a4a9f..bc1b40068 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -587,14 +587,22 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { t.Fatalf("Bad: %#v", srvRec) } - cnRec, ok = in.Extra[0].(*dns.CNAME) + aRec, ok := in.Extra[0].(*dns.A) if !ok { t.Fatalf("Bad: %#v", in.Extra[0]) } - if cnRec.Hdr.Name != "google.node.dc1.consul." { + if aRec.Hdr.Name != "www.google.com." { t.Fatalf("Bad: %#v", in.Extra[0]) } + + cnRec, ok = in.Extra[1].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Extra[1]) + } + if cnRec.Hdr.Name != "google.node.dc1.consul." { + t.Fatalf("Bad: %#v", in.Extra[1]) + } if cnRec.Target != "www.google.com." { - t.Fatalf("Bad: %#v", in.Extra[0]) + t.Fatalf("Bad: %#v", in.Extra[1]) } }