From 393a0eae93c6672cfa8d2d1e2fc8c414a3fffdcd Mon Sep 17 00:00:00 2001 From: Preetha Appan Date: Sun, 6 Aug 2017 18:18:30 -0500 Subject: [PATCH] Added test case with IPV6 bind address for NS records, rewrote tests to use verify library and other code review feedback --- agent/dns.go | 7 +++- agent/dns_test.go | 83 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/agent/dns.go b/agent/dns.go index f28edae4e..3245a4ea8 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -287,7 +287,12 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) { // get server names and store them in a map to randomize the output servers := map[string]net.IP{} for name, addr := range d.agent.delegate.ServerAddrs() { - ip := net.ParseIP(strings.Split(addr, ":")[0]) + host, _, err := net.SplitHostPort(addr) + if err != nil { + d.logger.Println("[WARN] Unable to parse address %v, got error: %v", addr, err) + continue + } + ip := net.ParseIP(host) if ip == nil { continue } diff --git a/agent/dns_test.go b/agent/dns_test.go index 5d0929ef3..152f5f237 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -914,36 +914,75 @@ func TestDNS_NSRecords(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Answer) != 1 { - t.Fatalf("Bad: %#v", in) + wantAnswer := []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{Name: "consul.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 0, Rdlength: 0x13}, + Ns: "server1.node.dc1.consul.", + }, + } + verify.Values(t, "answer", in.Answer, wantAnswer) + wantExtra := []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "server1.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4, Ttl: 0}, + A: net.ParseIP("127.0.0.1").To4(), + }, } - nsRec, ok := in.Answer[0].(*dns.NS) - if !ok { - t.Fatalf("Bad: %#v", in.Answer[0]) - } - if nsRec.Ns != "server1.node.dc1.consul." { - t.Fatalf("Bad: %#v", in.Answer[0]) - } - if nsRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Answer[0]) + verify.Values(t, "extra", in.Extra, wantExtra) + +} + +func TestDNS_NSRecords_IPV6(t *testing.T) { + t.Parallel() + cfg := TestConfig() + cfg.Domain = "CONSUL." + cfg.NodeName = "server1" + cfg.AdvertiseAddr = "::1" + cfg.AdvertiseAddrWan = "::1" + a := NewTestAgent(t.Name(), cfg) + defer a.Shutdown() + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, } - if len(in.Extra) != 1 { - t.Fatalf("Bad: %#v", in.Extra) + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) } - aRec, ok := in.Extra[0].(*dns.A) - if !ok { - t.Fatalf("Bad: %#v", in.Extra) + m := new(dns.Msg) + m.SetQuestion("server1.node.dc1.consul.", dns.TypeNS) + + c := new(dns.Client) + addr, _ := a.Config.ClientListener("", a.Config.Ports.DNS) + in, _, err := c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) } - if aRec.A.String() != "127.0.0.1" { - t.Fatalf("Bad: %#v", in.Extra) + + wantAnswer := []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{Name: "consul.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 0, Rdlength: 0x2}, + Ns: "server1.node.dc1.consul.", + }, } - if aRec.Hdr.Name != "server1.node.dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra) + verify.Values(t, "answer", in.Answer, wantAnswer) + wantExtra := []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "server1.node.dc1.consul.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Rdlength: 0x10, Ttl: 0}, + AAAA: net.ParseIP("::1"), + }, } + verify.Values(t, "extra", in.Extra, wantExtra) + } func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { @@ -4773,7 +4812,9 @@ func TestDNSInvalidRegex(t *testing.T) { {"Valid Hostname", "testnode", false}, {"Valid Hostname", "test-node", false}, {"Invalid Hostname with special chars", "test#$$!node", true}, - {"Invalid Hostname with special chars in the end", "test-node%^", true}, + {"Invalid Hostname with special chars in the end", "testnode%^", true}, + {"Whitespace", " ", true}, + {"Only special chars", "./$", true}, } for _, test := range tests { t.Run(test.desc, func(t *testing.T) {