From 10c450c78db2b220a852df236190712c40e256f4 Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Wed, 14 Jun 2017 16:22:54 -0700 Subject: [PATCH] Add EDNS0 support (#3131) This is a refactor of GH-1980. Originally I tried to do a straight rebase, but the code has changed too much. --- agent/dns.go | 58 +++++++++++++++----- agent/dns_test.go | 136 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 171 insertions(+), 23 deletions(-) diff --git a/agent/dns.go b/agent/dns.go index 53bdb4158..e9dbdb878 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -25,6 +25,8 @@ const ( // Increment a counter when requests staler than this are served staleCounterThreshold = 5 * time.Second + + defaultMaxUDPSize = 512 ) // DNSServer is used to wrap an Agent and expose various @@ -163,6 +165,11 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { return } + // Enable EDNS if enabled + if edns := req.IsEdns0(); edns != nil { + m.SetEdns0(edns.UDPSize(), false) + } + // Write out the complete response if err := resp.WriteMsg(m); err != nil { d.logger.Printf("[WARN] dns: failed to respond: %v", err) @@ -200,6 +207,11 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { // Dispatch the correct handler d.dispatch(network, req, m) + // Handle EDNS + if edns := req.IsEdns0(); edns != nil { + m.SetEdns0(edns.UDPSize(), false) + } + // Write out the complete response if err := resp.WriteMsg(m); err != nil { d.logger.Printf("[WARN] dns: failed to respond: %v", err) @@ -401,16 +413,17 @@ RPC: // Add the node record n := out.NodeServices.Node + edns := req.IsEdns0() != nil addr := translateAddress(d.agent.config, datacenter, n.Address, n.TaggedAddresses) records := d.formatNodeRecord(out.NodeServices.Node, addr, - req.Question[0].Name, qType, d.config.NodeTTL) + req.Question[0].Name, qType, d.config.NodeTTL, edns) if records != nil { resp.Answer = append(resp.Answer, records...) } } // formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record -func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration) (records []dns.RR) { +func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) { // Parse the IP ip := net.ParseIP(addr) var ipv4 net.IP @@ -463,7 +476,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA: records = append(records, rr) extra++ - if extra == maxRecurseRecords { + if extra == maxRecurseRecords && !edns { break MORE_REC } } @@ -525,9 +538,17 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) { // 1035. Enforce an arbitrary limit that can be further ratcheted down by // config, and then make sure the response doesn't exceed 512 bytes. Any extra // records will be trimmed along with answers. -func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) { +func trimUDPResponse(config *DNSConfig, req, resp *dns.Msg) (trimmed bool) { numAnswers := len(resp.Answer) hasExtra := len(resp.Extra) > 0 + maxSize := defaultMaxUDPSize + + // Update to the maximum edns size + if edns := req.IsEdns0(); edns != nil { + if size := edns.UDPSize(); size > uint16(maxSize) { + maxSize = int(size) + } + } // We avoid some function calls and allocations by only handling the // extra data when necessary. @@ -539,21 +560,22 @@ func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) { // This cuts UDP responses to a useful but limited number of responses. maxAnswers := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit) - if numAnswers > maxAnswers { + if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers { resp.Answer = resp.Answer[:maxAnswers] if hasExtra { syncExtra(index, resp) } } - // This enforces the hard limit of 512 bytes per the RFC. Note that we - // temporarily switch to uncompressed so that we limit to a response - // that will not exceed 512 bytes uncompressed, which is more - // conservative and will allow our responses to be compliant even if - // some downstream server uncompresses them. + // This enforces the given limit on the number bytes. The default is 512 as + // per the RFC, but EDNS0 allows for the user to specify larger sizes. Note + // that we temporarily switch to uncompressed so that we limit to a response + // that will not exceed 512 bytes uncompressed, which is more conservative and + // will allow our responses to be compliant even if some downstream server + // uncompresses them. compress := resp.Compress resp.Compress = false - for len(resp.Answer) > 0 && resp.Len() > 512 { + for len(resp.Answer) > 0 && resp.Len() > maxSize { resp.Answer = resp.Answer[:len(resp.Answer)-1] if hasExtra { syncExtra(index, resp) @@ -629,7 +651,7 @@ RPC: // If the network is not TCP, restrict the number of responses if network != "tcp" { - wasTrimmed := trimUDPResponse(d.config, resp) + wasTrimmed := trimUDPResponse(d.config, req, resp) // Flag that there are more records to return in the UDP response if wasTrimmed && d.config.EnableTruncate { @@ -738,7 +760,7 @@ RPC: // If the network is not TCP, restrict the number of responses. if network != "tcp" { - wasTrimmed := trimUDPResponse(d.config, resp) + wasTrimmed := trimUDPResponse(d.config, req, resp) // Flag that there are more records to return in the UDP response if wasTrimmed && d.config.EnableTruncate { @@ -758,6 +780,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode qName := req.Question[0].Name qType := req.Question[0].Qtype handled := make(map[string]struct{}) + edns := req.IsEdns0() != nil for _, node := range nodes { // Start with the translated address but use the service address, @@ -781,7 +804,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode handled[addr] = struct{}{} // Add the node record - records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl) + records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns) if records != nil { resp.Answer = append(resp.Answer, records...) } @@ -791,6 +814,8 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode // serviceARecords is used to add the SRV records for a service lookup func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) { handled := make(map[string]struct{}) + edns := req.IsEdns0() != nil + for _, node := range nodes { // Avoid duplicate entries, possible if a node has // the same service the same port, etc. @@ -823,7 +848,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } // Add the extra record - records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl) + records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns) if len(records) > 0 { // Use the node address if it doesn't differ from the service address if addr == node.Node.Address { @@ -903,6 +928,9 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { m.Compress = !d.config.DisableCompression m.RecursionAvailable = true m.SetRcode(req, dns.RcodeServerFailure) + if edns := req.IsEdns0(); edns != nil { + m.SetEdns0(edns.UDPSize(), false) + } resp.WriteMsg(m) } diff --git a/agent/dns_test.go b/agent/dns_test.go index 139b45ffd..fd69dcb5d 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -367,6 +367,46 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) { } } +func TestDNS_EDNS0(t *testing.T) { + t.Parallel() + a := NewTestAgent(t.Name(), nil) + defer a.Shutdown() + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.2", + } + + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetEdns0(12345, true) + m.SetQuestion("foo.node.dc1.consul.", dns.TypeANY) + + c := new(dns.Client) + addr, _ := a.Config.ClientListener("", a.Config.Ports.DNS) + in, _, err := c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) != 1 { + t.Fatalf("empty lookup: %#v", in) + } + edns := in.IsEdns0() + if edns == nil { + t.Fatalf("empty edns: %#v", in) + } + if edns.UDPSize() != 12345 { + t.Fatalf("bad edns size: %d", edns.UDPSize()) + } +} + func TestDNS_ReverseLookup(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), nil) @@ -4001,6 +4041,7 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) { func TestDNS_trimUDPResponse_NoTrim(t *testing.T) { t.Parallel() + req := &dns.Msg{} resp := &dns.Msg{ Answer: []dns.RR{ &dns.SRV{ @@ -4025,7 +4066,7 @@ func TestDNS_trimUDPResponse_NoTrim(t *testing.T) { } config := &DefaultConfig().DNSConfig - if trimmed := trimUDPResponse(config, resp); trimmed { + if trimmed := trimUDPResponse(config, req, resp); trimmed { t.Fatalf("Bad %#v", *resp) } @@ -4060,7 +4101,7 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) { t.Parallel() config := &DefaultConfig().DNSConfig - resp, expected := &dns.Msg{}, &dns.Msg{} + req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{} for i := 0; i < config.UDPAnswerLimit+1; i++ { target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i) srv := &dns.SRV{ @@ -4088,7 +4129,7 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) { } } - if trimmed := trimUDPResponse(config, resp); !trimmed { + if trimmed := trimUDPResponse(config, req, resp); !trimmed { t.Fatalf("Bad %#v", *resp) } if !reflect.DeepEqual(resp, expected) { @@ -4100,7 +4141,7 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) { t.Parallel() config := &DefaultConfig().DNSConfig - resp := &dns.Msg{} + req, resp := &dns.Msg{}, &dns.Msg{} for i := 0; i < 100; i++ { target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i) srv := &dns.SRV{ @@ -4126,7 +4167,7 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) { // We don't know the exact trim, but we know the resulting answer // data should match its extra data. - if trimmed := trimUDPResponse(config, resp); !trimmed { + if trimmed := trimUDPResponse(config, req, resp); !trimmed { t.Fatalf("Bad %#v", *resp) } if len(resp.Answer) == 0 || len(resp.Answer) != len(resp.Extra) { @@ -4149,6 +4190,85 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) { } } +func TestDNS_trimUDPResponse_TrimSizeEDNS(t *testing.T) { + t.Parallel() + config := &DefaultConfig().DNSConfig + + req, resp := &dns.Msg{}, &dns.Msg{} + + for i := 0; i < 100; i++ { + target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 150+i) + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "redis-cache-redis.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + }, + Target: target, + } + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 150+i)), + } + + resp.Answer = append(resp.Answer, srv) + resp.Extra = append(resp.Extra, a) + } + + // Copy over to a new slice since we are trimming both. + reqEDNS, respEDNS := &dns.Msg{}, &dns.Msg{} + reqEDNS.SetEdns0(2048, true) + respEDNS.Answer = append(respEDNS.Answer, resp.Answer...) + respEDNS.Extra = append(respEDNS.Extra, resp.Extra...) + + // Trim each response + if trimmed := trimUDPResponse(config, req, resp); !trimmed { + t.Errorf("expected response to be trimmed: %#v", resp) + } + if trimmed := trimUDPResponse(config, reqEDNS, respEDNS); !trimmed { + t.Errorf("expected edns to be trimmed: %#v", resp) + } + + // Check answer lengths + if len(resp.Answer) == 0 || len(resp.Answer) != len(resp.Extra) { + t.Errorf("bad response answer length: %#v", resp) + } + if len(respEDNS.Answer) == 0 || len(respEDNS.Answer) != len(respEDNS.Extra) { + t.Errorf("bad edns answer length: %#v", resp) + } + + // Due to the compression, we can't check exact equality of sizes, but we can + // make two requests and ensure that the edns one returns a larger payload + // than the non-edns0 one. + if len(resp.Answer) >= len(respEDNS.Answer) { + t.Errorf("expected edns have larger answer: %#v\n%#v", resp, respEDNS) + } + if len(resp.Extra) >= len(respEDNS.Extra) { + t.Errorf("expected edns have larger extra: %#v\n%#v", resp, respEDNS) + } + + // Verify that the things point where they should + for i := range resp.Answer { + srv, ok := resp.Answer[i].(*dns.SRV) + if !ok { + t.Errorf("%d should be an SRV", i) + } + + a, ok := resp.Extra[i].(*dns.A) + if !ok { + t.Errorf("%d should be an A", i) + } + + if srv.Target != a.Header().Name { + t.Errorf("%d: bad %#v vs. %#v", i, srv, a) + } + } +} + func TestDNS_syncExtra(t *testing.T) { t.Parallel() resp := &dns.Msg{ @@ -4377,8 +4497,8 @@ func TestDNS_Compression_trimUDPResponse(t *testing.T) { t.Parallel() config := &DefaultConfig().DNSConfig - m := dns.Msg{} - trimUDPResponse(config, &m) + req, m := dns.Msg{}, dns.Msg{} + trimUDPResponse(config, &req, &m) if m.Compress { t.Fatalf("compression should be off") } @@ -4386,7 +4506,7 @@ func TestDNS_Compression_trimUDPResponse(t *testing.T) { // The trim function temporarily turns off compression, so we need to // make sure the setting gets restored properly. m.Compress = true - trimUDPResponse(config, &m) + trimUDPResponse(config, &req, &m) if !m.Compress { t.Fatalf("compression should be on") }