diff --git a/agent/dns.go b/agent/dns.go index ba65ce75d..f28edae4e 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -212,7 +212,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { switch req.Question[0].Qtype { case dns.TypeSOA: - ns, glue := d.nameservers() + ns, glue := d.nameservers(req.IsEdns0() != nil) m.Answer = append(m.Answer, d.soa()) m.Ns = append(m.Ns, ns...) m.Extra = append(m.Extra, glue...) @@ -236,9 +236,9 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.SetRcode(req, dns.RcodeSuccess) case dns.TypeNS: - ns, _ := d.nameservers() + ns, glue := d.nameservers(req.IsEdns0() != nil) m.Answer = ns - // no need to send A records with the IP address, since the ns record is a node name that resolves correctly + m.Extra = glue m.SetRcode(req, dns.RcodeSuccess) default: @@ -283,7 +283,7 @@ func (d *DNSServer) addSOA(msg *dns.Msg) { // nameservers returns the names and ip addresses of up to three random servers // in the current cluster which serve as authoritative name servers for zone. -func (d *DNSServer) nameservers() (ns []dns.RR, extra []dns.RR) { +func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) { // get server names and store them in a map to randomize the output servers := map[string]net.IP{} for name, addr := range d.agent.delegate.ServerAddrs() { @@ -324,18 +324,8 @@ func (d *DNSServer) nameservers() (ns []dns.RR, extra []dns.RR) { Ns: name, } ns = append(ns, nsrr) - // the glue record providing the ip address - a := &dns.A{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: uint32(d.config.NodeTTL / time.Second), - }, - A: ip, - } - extra = append(extra, a) + extra = append(extra, d.formatNodeRecord(ip.String(), name, dns.TypeANY, d.config.NodeTTL, edns)...) // don't provide more than 3 servers if len(ns) >= 3 { @@ -523,7 +513,7 @@ RPC: n := out.NodeServices.Node edns := req.IsEdns0() != nil addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses) - records := d.formatNodeRecord(out.NodeServices.Node, addr, + records := d.formatNodeRecord(addr, req.Question[0].Name, qType, d.config.NodeTTL, edns) if records != nil { resp.Answer = append(resp.Answer, records...) @@ -531,7 +521,7 @@ RPC: } // formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record -func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) { +func (d *DNSServer) formatNodeRecord(addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) { // Parse the IP ip := net.ParseIP(addr) var ipv4 net.IP @@ -911,7 +901,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode handled[addr] = struct{}{} // Add the node record - records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns) + records := d.formatNodeRecord(addr, qName, qType, ttl, edns) if records != nil { resp.Answer = append(resp.Answer, records...) } @@ -955,7 +945,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } // Add the extra record - records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns) + records := d.formatNodeRecord(addr, srvRec.Target, dns.TypeANY, ttl, edns) if len(records) > 0 { // Use the node address if it doesn't differ from the service address if addr == node.Node.Address { diff --git a/agent/dns_test.go b/agent/dns_test.go index 4bf560055..5d0929ef3 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -885,7 +885,7 @@ func TestDNS_NSRecords(t *testing.T) { t.Parallel() cfg := TestConfig() cfg.Domain = "CONSUL." - cfg.NodeName = "foo" + cfg.NodeName = "server1" a := NewTestAgent(t.Name(), cfg) defer a.Shutdown() @@ -922,13 +922,28 @@ func TestDNS_NSRecords(t *testing.T) { if !ok { t.Fatalf("Bad: %#v", in.Answer[0]) } - if nsRec.Ns != "foo.node.dc1.consul." { + if nsRec.Ns != "server1.node.dc1.consul." { t.Fatalf("Bad: %#v", in.Answer[0]) } if nsRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Answer[0]) } + if len(in.Extra) != 1 { + t.Fatalf("Bad: %#v", in.Extra) + } + + aRec, ok := in.Extra[0].(*dns.A) + if !ok { + t.Fatalf("Bad: %#v", in.Extra) + } + if aRec.A.String() != "127.0.0.1" { + t.Fatalf("Bad: %#v", in.Extra) + } + if aRec.Hdr.Name != "server1.node.dc1.consul." { + t.Fatalf("Bad: %#v", in.Extra) + } + } func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) {