diff --git a/CHANGELOG.md b/CHANGELOG.md index 60011bf4c..4da8e90f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## UNRELEASED +BUG FIXES: + +* Fixing SOA record to return proper domain when alt domain in use. [[GH-10431]](https://github.com/hashicorp/consul/pull/10431) + ## 1.11.0-alpha (September 16, 2021) SECURITY: diff --git a/agent/dns.go b/agent/dns.go index 1e25305fd..b4f7b0c6d 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -373,7 +373,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { // Only add the SOA if requested if req.Question[0].Qtype == dns.TypeSOA { - d.addSOA(cfg, m) + d.addSOA(cfg, m, q.Name) } datacenter := d.agent.config.Datacenter @@ -486,7 +486,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { switch req.Question[0].Qtype { case dns.TypeSOA: ns, glue := d.nameservers(cfg, maxRecursionLevelDefault) - m.Answer = append(m.Answer, d.soa(cfg)) + m.Answer = append(m.Answer, d.soa(cfg, q.Name)) m.Ns = append(m.Ns, ns...) m.Extra = append(m.Extra, glue...) m.SetRcode(req, dns.RcodeSuccess) @@ -504,7 +504,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { err = d.dispatch(resp.RemoteAddr(), req, m, maxRecursionLevelDefault) rCode := rCodeFromError(err) if rCode == dns.RcodeNameError || errors.Is(err, errNoData) { - d.addSOA(cfg, m) + d.addSOA(cfg, m, q.Name) } m.SetRcode(req, rCode) } @@ -518,18 +518,23 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { } } -func (d *DNSServer) soa(cfg *dnsConfig) *dns.SOA { +func (d *DNSServer) soa(cfg *dnsConfig, questionName string) *dns.SOA { + domain := d.domain + if d.altDomain != "" && strings.HasSuffix(questionName, "."+d.altDomain) { + domain = d.altDomain + } + return &dns.SOA{ Hdr: dns.RR_Header{ - Name: d.domain, + Name: domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, // Has to be consistent with MinTTL to avoid invalidation Ttl: cfg.SOAConfig.Minttl, }, - Ns: "ns." + d.domain, + Ns: "ns." + domain, Serial: uint32(time.Now().Unix()), - Mbox: "hostmaster." + d.domain, + Mbox: "hostmaster." + domain, Refresh: cfg.SOAConfig.Refresh, Retry: cfg.SOAConfig.Retry, Expire: cfg.SOAConfig.Expire, @@ -538,8 +543,8 @@ func (d *DNSServer) soa(cfg *dnsConfig) *dns.SOA { } // addSOA is used to add an SOA record to a message for the given domain -func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg) { - msg.Ns = append(msg.Ns, d.soa(cfg)) +func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg, questionName string) { + msg.Ns = append(msg.Ns, d.soa(cfg, questionName)) } // nameservers returns the names and ip addresses of up to three random servers @@ -600,6 +605,12 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns return } +func (d *DNSServer) invalidQuery(req, resp *dns.Msg, cfg *dnsConfig, qName string) { + d.logger.Warn("QName invalid", "qname", qName) + d.addSOA(cfg, resp, qName) + resp.SetRcode(req, dns.RcodeNameError) +} + func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { switch len(labels) { case 1: diff --git a/agent/dns_test.go b/agent/dns_test.go index 60e5e1db2..ccbf5b012 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -6528,14 +6528,17 @@ func TestDNS_AltDomains_SOA(t *testing.T) { defer a.Shutdown() testrpc.WaitForLeader(t, a.RPC, "dc1") - questions := []string{ - "test-node.node.consul.", - "test-node.node.test-domain.", + questions := []struct { + ask string + want_domain string + }{ + {"test-node.node.consul.", "consul."}, + {"test-node.node.test-domain.", "test-domain."}, } for _, question := range questions { m := new(dns.Msg) - m.SetQuestion(question, dns.TypeSOA) + m.SetQuestion(question.ask, dns.TypeSOA) c := new(dns.Client) in, _, err := c.Exchange(m, a.DNSAddr()) @@ -6552,10 +6555,10 @@ func TestDNS_AltDomains_SOA(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } - if got, want := soaRec.Hdr.Name, "consul."; got != want { + if got, want := soaRec.Hdr.Name, question.want_domain; got != want { t.Fatalf("SOA name invalid, got %q want %q", got, want) } - if got, want := soaRec.Ns, "ns.consul."; got != want { + if got, want := soaRec.Ns, ("ns." + question.want_domain); got != want { t.Fatalf("SOA ns invalid, got %q want %q", got, want) } }