diff --git a/agent/dns.go b/agent/dns.go index e5f0e5085..937c32d08 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -501,14 +501,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { switch req.Question[0].Qtype { case dns.TypeSOA: - ns, glue := d.nameservers(cfg, maxRecursionLevelDefault) + ns, glue := d.nameservers(req.Question[0].Name, cfg, maxRecursionLevelDefault) m.Answer = append(m.Answer, d.soa(cfg, q.Name)) m.Ns = append(m.Ns, ns...) m.Extra = append(m.Extra, glue...) m.SetRcode(req, dns.RcodeSuccess) case dns.TypeNS: - ns, glue := d.nameservers(cfg, maxRecursionLevelDefault) + ns, glue := d.nameservers(req.Question[0].Name, cfg, maxRecursionLevelDefault) m.Answer = ns m.Extra = glue m.SetRcode(req, dns.RcodeSuccess) @@ -566,7 +566,7 @@ func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg, questionName string) { // nameservers returns the names and ip addresses of up to three random servers // in the current cluster which serve as authoritative name servers for zone. -func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { +func (d *DNSServer) nameservers(questionName string, cfg *dnsConfig, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { out, err := d.lookupServiceNodes(cfg, serviceLookup{ Datacenter: d.agent.config.Datacenter, Service: structs.ConsulServiceName, @@ -594,14 +594,14 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns d.logger.Warn("Skipping invalid node for NS records", "node", name) continue } - - fqdn := name + ".node." + dc + "." + d.domain + respDomain := d.getResponseDomain(questionName) + fqdn := name + ".node." + dc + "." + respDomain fqdn = dns.Fqdn(strings.ToLower(fqdn)) // NS record nsrr := &dns.NS{ Hdr: dns.RR_Header{ - Name: d.domain, + Name: respDomain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(cfg.NodeTTL / time.Second), diff --git a/agent/dns_test.go b/agent/dns_test.go index e6a3f5be6..7bcab1cea 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -2128,6 +2128,58 @@ func TestDNS_NSRecords(t *testing.T) { require.Equal(t, wantExtra, in.Extra, "extra") } +func TestDNS_AltDomain_NSRecords(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Parallel() + a := NewTestAgent(t, ` + domain = "CONSUL." + node_name = "server1" + alt_domain = "test-domain." + `) + defer a.Shutdown() + testrpc.WaitForTestAgent(t, a.RPC, "dc1") + + questions := []struct { + ask string + domain string + wantDomain string + }{ + {"something.node.consul.", "consul.", "server1.node.dc1.consul."}, + {"something.node.test-domain.", "test-domain.", "server1.node.dc1.test-domain."}, + } + + for _, question := range questions { + m := new(dns.Msg) + m.SetQuestion(question.ask, dns.TypeNS) + + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + + wantAnswer := []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{Name: question.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 0, Rdlength: 0x13}, + Ns: question.wantDomain, + }, + } + require.Equal(t, wantAnswer, in.Answer, "answer") + wantExtra := []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: question.wantDomain, Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4, Ttl: 0}, + A: net.ParseIP("127.0.0.1").To4(), + }, + } + + require.Equal(t, wantExtra, in.Extra, "extra") + } + +} + func TestDNS_NSRecords_IPV6(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2169,6 +2221,59 @@ func TestDNS_NSRecords_IPV6(t *testing.T) { } +func TestDNS_AltDomain_NSRecords_IPV6(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Parallel() + a := NewTestAgent(t, ` + domain = "CONSUL." + node_name = "server1" + advertise_addr = "::1" + alt_domain = "test-domain." + `) + defer a.Shutdown() + testrpc.WaitForTestAgent(t, a.RPC, "dc1") + + questions := []struct { + ask string + domain string + wantDomain string + }{ + {"server1.node.dc1.consul.", "consul.", "server1.node.dc1.consul."}, + {"server1.node.dc1.test-domain.", "test-domain.", "server1.node.dc1.test-domain."}, + } + + for _, question := range questions { + m := new(dns.Msg) + m.SetQuestion(question.ask, dns.TypeNS) + + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + + wantAnswer := []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{Name: question.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 0, Rdlength: 0x2}, + Ns: question.wantDomain, + }, + } + require.Equal(t, wantAnswer, in.Answer, "answer") + wantExtra := []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{Name: question.wantDomain, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Rdlength: 0x10, Ttl: 0}, + AAAA: net.ParseIP("::1"), + }, + } + + require.Equal(t, wantExtra, in.Extra, "extra") + } + +} + func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short")