diff --git a/agent/dns.go b/agent/dns.go index e499cc5ab..fd9df8ebf 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -135,6 +135,40 @@ func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error { return d.Server.ListenAndServe() } +// setEDNS is used to set the responses EDNS size headers and +// possibly the ECS headers as well if they were present in the +// original request +func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) { + // Enable EDNS if enabled + if edns := request.IsEdns0(); edns != nil { + // cannot just use the SetEdns0 function as we need to embed + // the ECS option as well + ednsResp := new(dns.OPT) + ednsResp.Hdr.Name = "." + ednsResp.Hdr.Rrtype = dns.TypeOPT + ednsResp.SetUDPSize(edns.UDPSize()) + + // Setup the ECS option if present + if subnet := ednsSubnetForRequest(request); subnet != nil { + subOp := new(dns.EDNS0_SUBNET) + subOp.Code = dns.EDNS0SUBNET + subOp.Family = subnet.Family + subOp.Address = subnet.Address + subOp.SourceNetmask = subnet.SourceNetmask + if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented { + // reply is globally valid and should be cached accordingly + subOp.SourceScope = 0 + } else { + // reply is only valid for the subnet it was queried with + subOp.SourceScope = subnet.SourceNetmask + } + ednsResp.Option = append(ednsResp.Option, subOp) + } + + response.Extra = append(response.Extra, ednsResp) + } +} + // recursorAddr is used to add a port to the recursor if omitted. func recursorAddr(recursor string) (string, error) { // Add the port if none @@ -245,10 +279,8 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { return } - // Enable EDNS if enabled - if edns := req.IsEdns0(); edns != nil { - m.SetEdns0(edns.UDPSize(), false) - } + // ptr record responses are globally valid + setEDNS(req, m, true) // Write out the complete response if err := resp.WriteMsg(m); err != nil { @@ -280,6 +312,8 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.Authoritative = true m.RecursionAvailable = (len(d.recursors) > 0) + ecsGlobal := true + switch req.Question[0].Qtype { case dns.TypeSOA: ns, glue := d.nameservers(req.IsEdns0() != nil) @@ -298,13 +332,10 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.SetRcode(req, dns.RcodeNotImplemented) default: - d.dispatch(network, resp.RemoteAddr(), req, m) + ecsGlobal = d.dispatch(network, resp.RemoteAddr(), req, m) } - // Handle EDNS - if edns := req.IsEdns0(); edns != nil { - m.SetEdns0(edns.UDPSize(), false) - } + setEDNS(req, m, ecsGlobal) // Write out the complete response if err := resp.WriteMsg(m); err != nil { @@ -393,7 +424,8 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) { } // dispatch is used to parse a request and invoke the correct handler -func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) { +func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) { + ecsGlobal = true // By default the query is in the default datacenter datacenter := d.agent.config.Datacenter @@ -478,6 +510,7 @@ PARSE: // Allow a "." in the query name, just join all the parts. query := strings.Join(labels[:n-1], ".") + ecsGlobal = false d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp) case "addr": @@ -547,6 +580,7 @@ INVALID: d.logger.Printf("[WARN] dns: QName invalid: %s", qName) d.addSOA(resp) resp.SetRcode(req, dns.RcodeNameError) + return } // nodeLookup is used to handle a node query @@ -1380,7 +1414,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { m.RecursionAvailable = true m.SetRcode(req, dns.RcodeServerFailure) if edns := req.IsEdns0(); edns != nil { - m.SetEdns0(edns.UDPSize(), false) + setEDNS(req, m, true) } resp.WriteMsg(m) } diff --git a/agent/dns_test.go b/agent/dns_test.go index 48bc3bf11..d52684803 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -755,6 +755,98 @@ func TestDNS_EDNS0(t *testing.T) { } } +func TestDNS_EDNS0_ECS(t *testing.T) { + t.Parallel() + a := NewTestAgent(t.Name(), "") + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + // Register a node with a service. + { + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "db", + Tags: []string{"master"}, + Port: 12345, + }, + } + + var out struct{} + require.NoError(t, a.RPC("Catalog.Register", args, &out)) + } + + // Register an equivalent prepared query. + var id string + { + args := &structs.PreparedQueryRequest{ + Datacenter: "dc1", + Op: structs.PreparedQueryCreate, + Query: &structs.PreparedQuery{ + Name: "test", + Service: structs.ServiceQuery{ + Service: "db", + }, + }, + } + require.NoError(t, a.RPC("PreparedQuery.Apply", args, &id)) + } + + cases := []struct { + Name string + Question string + SubnetAddr string + SourceNetmask uint8 + ExpectedScope uint8 + }{ + {"global", "db.service.consul.", "198.18.0.1", 32, 0}, + {"query", "test.query.consul.", "198.18.0.1", 32, 32}, + {"query-subnet", "test.query.consul.", "198.18.0.0", 21, 21}, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + c := new(dns.Client) + // Query the service directly - should have a globally valid scope (0) + m := new(dns.Msg) + edns := new(dns.OPT) + edns.Hdr.Name = "." + edns.Hdr.Rrtype = dns.TypeOPT + edns.SetUDPSize(12345) + edns.SetDo(true) + subnetOp := new(dns.EDNS0_SUBNET) + subnetOp.Code = dns.EDNS0SUBNET + subnetOp.Family = 1 + subnetOp.SourceNetmask = tc.SourceNetmask + subnetOp.Address = net.ParseIP(tc.SubnetAddr) + edns.Option = append(edns.Option, subnetOp) + m.Extra = append(m.Extra, edns) + m.SetQuestion(tc.Question, dns.TypeA) + + in, _, err := c.Exchange(m, a.DNSAddr()) + require.NoError(t, err) + require.Len(t, in.Answer, 1) + aRec, ok := in.Answer[0].(*dns.A) + require.True(t, ok) + require.Equal(t, "127.0.0.1", aRec.A.String()) + + optRR := in.IsEdns0() + require.NotNil(t, optRR) + require.Len(t, optRR.Option, 1) + + subnet, ok := optRR.Option[0].(*dns.EDNS0_SUBNET) + require.True(t, ok) + require.Equal(t, uint16(1), subnet.Family) + require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask) + // scope set to 0 for a globally valid reply + require.Equal(t, tc.ExpectedScope, subnet.SourceScope) + require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address) + }) + } +} + func TestDNS_ReverseLookup(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), "")