diff --git a/command/agent/dns.go b/command/agent/dns.go index 3150683de..2356866c4 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -511,16 +511,6 @@ RPC: // Perform a random shuffle shuffleServiceNodes(out.Nodes) - // If the network is not TCP, restrict the number of responses - if network != "tcp" && len(out.Nodes) > maxServiceResponses { - out.Nodes = out.Nodes[:maxServiceResponses] - - // Flag that there are more records to return in the UDP response - if d.config.EnableTruncate { - resp.Truncated = true - } - } - // Add various responses depending on the request qType := req.Question[0].Qtype d.serviceNodeRecords(out.Nodes, req, resp, ttl) @@ -529,13 +519,22 @@ RPC: d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl) } + // If the network is not TCP, restrict the number of responses + if network != "tcp" && len(resp.Answer) > maxServiceResponses { + resp.Answer = resp.Answer[:maxServiceResponses] + + // Flag that there are more records to return in the UDP response + if d.config.EnableTruncate { + resp.Truncated = true + } + } + // If the answer is empty, return not found if len(resp.Answer) == 0 { d.addSOA(d.domain, resp) resp.SetRcode(req, dns.RcodeNameError) return } - } // filterServiceNodes is used to filter out nodes that are failing diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index e4fb25643..cdd5c3609 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -1416,6 +1416,72 @@ func TestDNS_ServiceLookup_Truncate(t *testing.T) { } } +func TestDNS_ServiceLookup_MaxResponses(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + testutil.WaitForLeader(t, srv.agent.RPC, "dc1") + + // Register nodes + for i := 0; i < 6*maxServiceResponses; i++ { + nodeAddress := fmt.Sprintf("127.0.0.%d", i+1) + if i > 3 { + nodeAddress = fmt.Sprintf("fe80::%d", i+1) + } + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: fmt.Sprintf("foo%d", i), + Address: nodeAddress, + Service: &structs.NodeService{ + Service: "web", + Port: 8000, + }, + } + + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Ensure the response is randomized each time. + m := new(dns.Msg) + m.SetQuestion("web.service.consul.", dns.TypeANY) + + addr, _ := srv.agent.config.ClientListener("", srv.agent.config.Ports.DNS) + c := new(dns.Client) + in, _, err := c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 3 { + t.Fatalf("should receive 3 answers for ANY") + } + + m.SetQuestion("web.service.consul.", dns.TypeA) + in, _, err = c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 3 { + t.Fatalf("should receive 3 answers for A") + } + + m.SetQuestion("web.service.consul.", dns.TypeAAAA) + in, _, err = c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 3 { + t.Fatalf("should receive 3 answers for AAAA") + } + +} + func TestDNS_ServiceLookup_CNAME(t *testing.T) { recursor := makeRecursor(t, []dns.RR{ dnsCNAME("www.google.com", "google.com"),