diff --git a/command/agent/dns.go b/command/agent/dns.go index e36d6a49d..cc4929178 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -16,7 +16,7 @@ import ( const ( maxServiceResponses = 3 // For UDP only - maxRecurseRecords = 3 + maxRecurseRecords = 5 ) // DNSServer is used to wrap an Agent and expose various @@ -426,8 +426,9 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy // Recurse more := d.resolveCNAME(cnRec.Target) + extra := 0 MORE_REC: - for extra, rr := range more { + for _, rr := range more { switch rr.Header().Rrtype { case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA: records = append(records, rr) diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index d6d5e3f0f..90f7f5826 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "os" - "reflect" "strings" "testing" "time" @@ -14,23 +13,72 @@ import ( "github.com/miekg/dns" ) -func makeDNSServer(t *testing.T) (string, *DNSServer) { - config := &DNSConfig{} - return makeDNSServerConfig(t, config) -} - -func makeDNSServerConfig(t *testing.T, config *DNSConfig) (string, *DNSServer) { +func makeDNSServer(t *testing.T, config *DNSConfig, recursor *dns.Server) (string, *DNSServer) { + if config == nil { + config = &DNSConfig{} + } + recursors := []string{} + if recursor != nil { + recursors = append(recursors, recursor.Addr) + } conf := nextConfig() addr, _ := conf.ClientListener(conf.Addresses.DNS, conf.Ports.DNS) dir, agent := makeAgent(t, conf) server, err := NewDNSServer(agent, config, agent.logOutput, - conf.Domain, addr.String(), []string{"8.8.8.8:53"}) + conf.Domain, addr.String(), recursors) if err != nil { t.Fatalf("err: %v", err) } return dir, server } +// makeRecursor creates a generic DNS server which always returns +// the provided reply. This is useful for mocking a DNS recursor with +// an expected result. +func makeRecursor(t *testing.T, answer []dns.RR) *dns.Server { + dnsConf := nextConfig() + dnsAddr := fmt.Sprintf("%s:%d", dnsConf.Addresses.DNS, dnsConf.Ports.DNS) + mux := dns.NewServeMux() + mux.HandleFunc(".", func(resp dns.ResponseWriter, msg *dns.Msg) { + ans := &dns.Msg{Answer: answer[:]} + ans.SetReply(msg) + if err := resp.WriteMsg(ans); err != nil { + t.Fatalf("err: %s", err) + } + }) + server := &dns.Server{ + Addr: dnsAddr, + Net: "udp", + Handler: mux, + } + go server.ListenAndServe() + return server +} + +// dnsCNAME returns a DNS CNAME record struct +func dnsCNAME(src, dest string) *dns.CNAME { + return &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(src), + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + }, + Target: dns.Fqdn(dest), + } +} + +// dnsA returns a DNS A record struct +func dnsA(src, dest string) *dns.A { + return &dns.A{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(src), + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP(dest), + } +} + func TestRecursorAddr(t *testing.T) { addr, err := recursorAddr("8.8.8.8") if err != nil { @@ -42,7 +90,7 @@ func TestRecursorAddr(t *testing.T) { } func TestDNS_NodeLookup(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -112,7 +160,7 @@ func TestDNS_NodeLookup(t *testing.T) { } func TestDNS_CaseInsensitiveNodeLookup(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -146,7 +194,7 @@ func TestDNS_CaseInsensitiveNodeLookup(t *testing.T) { } func TestDNS_NodeLookup_PeriodName(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -188,7 +236,7 @@ func TestDNS_NodeLookup_PeriodName(t *testing.T) { } func TestDNS_NodeLookup_AAAA(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -233,7 +281,13 @@ func TestDNS_NodeLookup_AAAA(t *testing.T) { } func TestDNS_NodeLookup_CNAME(t *testing.T) { - dir, srv := makeDNSServer(t) + recursor := makeRecursor(t, []dns.RR{ + dnsCNAME("www.google.com", "google.com"), + dnsA("google.com", "1.2.3.4"), + }) + defer recursor.Shutdown() + + dir, srv := makeDNSServer(t, nil, recursor) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -261,8 +315,8 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) { t.Fatalf("err: %v", err) } - // Should have the CNAME record + a few A records - if len(in.Answer) < 2 { + // Should have the service record, CNAME record + A record + if len(in.Answer) != 3 { t.Fatalf("Bad: %#v", in) } @@ -279,7 +333,7 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) { } func TestDNS_ReverseLookup(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -321,7 +375,7 @@ func TestDNS_ReverseLookup(t *testing.T) { } func TestDNS_ReverseLookup_CustomDomain(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() srv.domain = dns.Fqdn("custom") @@ -364,7 +418,7 @@ func TestDNS_ReverseLookup_CustomDomain(t *testing.T) { } func TestDNS_ReverseLookup_IPV6(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -406,7 +460,7 @@ func TestDNS_ReverseLookup_IPV6(t *testing.T) { } func TestDNS_ServiceLookup(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -473,7 +527,7 @@ func TestDNS_ServiceLookup(t *testing.T) { } func TestDNS_ServiceLookup_ServiceAddress(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -541,7 +595,7 @@ func TestDNS_ServiceLookup_ServiceAddress(t *testing.T) { } func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -580,7 +634,7 @@ func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { } func TestDNS_ServiceLookup_TagPeriod(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -641,7 +695,7 @@ func TestDNS_ServiceLookup_TagPeriod(t *testing.T) { } func TestDNS_ServiceLookup_Dedup(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -718,7 +772,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { } func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -823,7 +877,10 @@ func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { } func TestDNS_Recurse(t *testing.T) { - dir, srv := makeDNSServer(t) + recursor := makeRecursor(t, []dns.RR{dnsA("apple.com", "1.2.3.4")}) + defer recursor.Shutdown() + + dir, srv := makeDNSServer(t, nil, recursor) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -831,7 +888,6 @@ func TestDNS_Recurse(t *testing.T) { m.SetQuestion("apple.com.", dns.TypeANY) c := new(dns.Client) - c.Net = "tcp" addr, _ := srv.agent.config.ClientListener("", srv.agent.config.Ports.DNS) in, _, err := c.Exchange(m, addr.String()) if err != nil { @@ -847,7 +903,7 @@ func TestDNS_Recurse(t *testing.T) { } func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -978,7 +1034,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { } func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { - dir, srv := makeDNSServerConfig(t, &DNSConfig{OnlyPassing: true}) + dir, srv := makeDNSServer(t, &DNSConfig{OnlyPassing: true}, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -1094,7 +1150,7 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { } func TestDNS_ServiceLookup_Randomize(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -1162,7 +1218,13 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { } func TestDNS_ServiceLookup_CNAME(t *testing.T) { - dir, srv := makeDNSServer(t) + recursor := makeRecursor(t, []dns.RR{ + dnsCNAME("www.google.com", "google.com"), + dnsA("google.com", "1.2.3.4"), + }) + defer recursor.Shutdown() + + dir, srv := makeDNSServer(t, nil, recursor) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -1194,11 +1256,12 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Answer) < 2 { + // Service CNAME, google CNAME, google A record + if len(in.Answer) != 3 { t.Fatalf("Bad: %#v", in) } - // Should have google CNAME + // Should have service CNAME cnRec, ok := in.Answer[0].(*dns.CNAME) if !ok { t.Fatalf("Bad: %#v", in.Answer[0]) @@ -1207,22 +1270,35 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } + // Should have google CNAME + cnRec, ok = in.Answer[1].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[1]) + } + if cnRec.Target != "google.com." { + t.Fatalf("Bad: %#v", in.Answer[1]) + } + // Check we recursively resolve - for _, ans := range in.Answer[1:] { - if _, ok := ans.(*dns.A); !ok { - t.Fatalf("Bad: %#v", ans) - } + if _, ok := in.Answer[2].(*dns.A); !ok { + t.Fatalf("Bad: %#v", in.Answer[2]) } } func TestDNS_NodeLookup_TTL(t *testing.T) { + recursor := makeRecursor(t, []dns.RR{ + dnsCNAME("www.google.com", "google.com"), + dnsA("google.com", "1.2.3.4"), + }) + defer recursor.Shutdown() + config := &DNSConfig{ NodeTTL: 10 * time.Second, AllowStale: true, MaxStale: time.Second, } - dir, srv := makeDNSServerConfig(t, config) + dir, srv := makeDNSServer(t, config, recursor) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -1344,7 +1420,7 @@ func TestDNS_ServiceLookup_TTL(t *testing.T) { MaxStale: time.Second, } - dir, srv := makeDNSServerConfig(t, config) + dir, srv := makeDNSServer(t, config, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -1439,7 +1515,7 @@ func TestDNS_ServiceLookup_TTL(t *testing.T) { } func TestDNS_ServiceLookup_SRV_RFC(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -1506,7 +1582,7 @@ func TestDNS_ServiceLookup_SRV_RFC(t *testing.T) { } func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) { - dir, srv := makeDNSServer(t) + dir, srv := makeDNSServer(t, nil, nil) defer os.RemoveAll(dir) defer srv.agent.Shutdown() @@ -1571,133 +1647,3 @@ func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) { t.Fatalf("Bad: %#v", in.Extra[0]) } } - -func TestDNS_CNAME_recurse(t *testing.T) { - // Create our recursor - Consul will recurse to this - dnsConf := nextConfig() - dnsAddr := fmt.Sprintf("%s:%d", dnsConf.Addresses.DNS, dnsConf.Ports.DNS) - mux := dns.NewServeMux() - mux.HandleFunc(".", func(resp dns.ResponseWriter, msg *dns.Msg) { - - cnResp := func(src, target string) *dns.CNAME { - return &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: src, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - }, - Target: target, - } - } - - // Create the answer - ans := &dns.Msg{} - ans.SetReply(msg) - ans.Answer = append(ans.Answer, - cnResp("a.example.com.", "b.example.com."), - cnResp("b.example.com.", "c.example.com."), - cnResp("c.example.com.", "d.example.com."), - &dns.A{ - Hdr: dns.RR_Header{ - Name: "d.example.com.", - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: net.ParseIP("1.2.3.4"), - }) - - // Write the answer back to the client - if err := resp.WriteMsg(ans); err != nil { - t.Fatalf("err: %s", err) - } - }) - server := &dns.Server{ - Addr: dnsAddr, - Net: "udp", - Handler: mux, - } - go server.ListenAndServe() - defer server.Shutdown() - - // Create the Consul server - dconf := &DNSConfig{} - config := nextConfig() - addr, _ := config.ClientListener(config.Addresses.DNS, config.Ports.DNS) - dir, agent := makeAgent(t, config) - defer os.RemoveAll(dir) - defer agent.Shutdown() - - srv, err := NewDNSServer(agent, dconf, agent.logOutput, - config.Domain, addr.String(), []string{dnsAddr}) - if err != nil { - t.Fatalf("err: %v", err) - } - - testutil.WaitForLeader(t, srv.agent.RPC, "dc1") - - // Register a service with a recursing CNAME as the address - args := &structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "a.example.com", - Service: &structs.NodeService{ - Service: "db", - Tags: []string{"master"}, - Address: "a.example.com", - Port: 12345, - }, - } - - var out struct{} - if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { - t.Fatalf("err: %v", err) - } - - // Create the DNS query against the Consul server - m := new(dns.Msg) - m.SetQuestion("db.service.consul.", dns.TypeA) - - c := new(dns.Client) - c.Net = "tcp" - in, _, err := c.Exchange(m, addr.String()) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Should have all 3 CNAMES and the A record - if len(in.Answer) != 4 { - t.Fatalf("Bad: %#v", in) - } - - // Check all the records - expected := []dns.RR{ - &dns.CNAME{ - Hdr: dns.RR_Header{ - Rrtype: dns.TypeCNAME, - }, - Target: "abc", - }, - &dns.CNAME{ - Hdr: dns.RR_Header{ - Rrtype: dns.TypeCNAME, - }, - Target: "abc", - }, - &dns.CNAME{ - Hdr: dns.RR_Header{ - Rrtype: dns.TypeCNAME, - }, - Target: "abc", - }, - &dns.A{ - Hdr: dns.RR_Header{ - Rrtype: dns.TypeCNAME, - }, - A: net.ParseIP("1.2.3.4"), - }, - } - - if !reflect.DeepEqual(expected, in.Answer) { - t.Fatalf("Bad: %v %v", expected, in.Answer) - } -}