Add EDNS0 support (#3131)

This is a refactor of GH-1980. Originally I tried to do a straight
rebase, but the code has changed too much.
This commit is contained in:
Seth Vargo 2017-06-14 16:22:54 -07:00 committed by James Phillips
parent 345531deaa
commit 10c450c78d
2 changed files with 171 additions and 23 deletions

View File

@ -25,6 +25,8 @@ const (
// Increment a counter when requests staler than this are served
staleCounterThreshold = 5 * time.Second
defaultMaxUDPSize = 512
)
// DNSServer is used to wrap an Agent and expose various
@ -163,6 +165,11 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
return
}
// Enable EDNS if enabled
if edns := req.IsEdns0(); edns != nil {
m.SetEdns0(edns.UDPSize(), false)
}
// Write out the complete response
if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
@ -200,6 +207,11 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
// Dispatch the correct handler
d.dispatch(network, req, m)
// Handle EDNS
if edns := req.IsEdns0(); edns != nil {
m.SetEdns0(edns.UDPSize(), false)
}
// Write out the complete response
if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
@ -401,16 +413,17 @@ RPC:
// Add the node record
n := out.NodeServices.Node
edns := req.IsEdns0() != nil
addr := translateAddress(d.agent.config, datacenter, n.Address, n.TaggedAddresses)
records := d.formatNodeRecord(out.NodeServices.Node, addr,
req.Question[0].Name, qType, d.config.NodeTTL)
req.Question[0].Name, qType, d.config.NodeTTL, edns)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
}
// formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration) (records []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
@ -463,7 +476,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA:
records = append(records, rr)
extra++
if extra == maxRecurseRecords {
if extra == maxRecurseRecords && !edns {
break MORE_REC
}
}
@ -525,9 +538,17 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) {
// 1035. Enforce an arbitrary limit that can be further ratcheted down by
// config, and then make sure the response doesn't exceed 512 bytes. Any extra
// records will be trimmed along with answers.
func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
func trimUDPResponse(config *DNSConfig, req, resp *dns.Msg) (trimmed bool) {
numAnswers := len(resp.Answer)
hasExtra := len(resp.Extra) > 0
maxSize := defaultMaxUDPSize
// Update to the maximum edns size
if edns := req.IsEdns0(); edns != nil {
if size := edns.UDPSize(); size > uint16(maxSize) {
maxSize = int(size)
}
}
// We avoid some function calls and allocations by only handling the
// extra data when necessary.
@ -539,21 +560,22 @@ func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
// This cuts UDP responses to a useful but limited number of responses.
maxAnswers := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit)
if numAnswers > maxAnswers {
if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers {
resp.Answer = resp.Answer[:maxAnswers]
if hasExtra {
syncExtra(index, resp)
}
}
// This enforces the hard limit of 512 bytes per the RFC. Note that we
// temporarily switch to uncompressed so that we limit to a response
// that will not exceed 512 bytes uncompressed, which is more
// conservative and will allow our responses to be compliant even if
// some downstream server uncompresses them.
// This enforces the given limit on the number bytes. The default is 512 as
// per the RFC, but EDNS0 allows for the user to specify larger sizes. Note
// that we temporarily switch to uncompressed so that we limit to a response
// that will not exceed 512 bytes uncompressed, which is more conservative and
// will allow our responses to be compliant even if some downstream server
// uncompresses them.
compress := resp.Compress
resp.Compress = false
for len(resp.Answer) > 0 && resp.Len() > 512 {
for len(resp.Answer) > 0 && resp.Len() > maxSize {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
if hasExtra {
syncExtra(index, resp)
@ -629,7 +651,7 @@ RPC:
// If the network is not TCP, restrict the number of responses
if network != "tcp" {
wasTrimmed := trimUDPResponse(d.config, resp)
wasTrimmed := trimUDPResponse(d.config, req, resp)
// Flag that there are more records to return in the UDP response
if wasTrimmed && d.config.EnableTruncate {
@ -738,7 +760,7 @@ RPC:
// If the network is not TCP, restrict the number of responses.
if network != "tcp" {
wasTrimmed := trimUDPResponse(d.config, resp)
wasTrimmed := trimUDPResponse(d.config, req, resp)
// Flag that there are more records to return in the UDP response
if wasTrimmed && d.config.EnableTruncate {
@ -758,6 +780,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
qName := req.Question[0].Name
qType := req.Question[0].Qtype
handled := make(map[string]struct{})
edns := req.IsEdns0() != nil
for _, node := range nodes {
// Start with the translated address but use the service address,
@ -781,7 +804,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
handled[addr] = struct{}{}
// Add the node record
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl)
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
@ -791,6 +814,8 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
// serviceARecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
handled := make(map[string]struct{})
edns := req.IsEdns0() != nil
for _, node := range nodes {
// Avoid duplicate entries, possible if a node has
// the same service the same port, etc.
@ -823,7 +848,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}
// Add the extra record
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl)
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
if len(records) > 0 {
// Use the node address if it doesn't differ from the service address
if addr == node.Node.Address {
@ -903,6 +928,9 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
m.Compress = !d.config.DisableCompression
m.RecursionAvailable = true
m.SetRcode(req, dns.RcodeServerFailure)
if edns := req.IsEdns0(); edns != nil {
m.SetEdns0(edns.UDPSize(), false)
}
resp.WriteMsg(m)
}

View File

@ -367,6 +367,46 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) {
}
}
func TestDNS_EDNS0(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), nil)
defer a.Shutdown()
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.2",
}
var out struct{}
if err := a.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetEdns0(12345, true)
m.SetQuestion("foo.node.dc1.consul.", dns.TypeANY)
c := new(dns.Client)
addr, _ := a.Config.ClientListener("", a.Config.Ports.DNS)
in, _, err := c.Exchange(m, addr.String())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("empty lookup: %#v", in)
}
edns := in.IsEdns0()
if edns == nil {
t.Fatalf("empty edns: %#v", in)
}
if edns.UDPSize() != 12345 {
t.Fatalf("bad edns size: %d", edns.UDPSize())
}
}
func TestDNS_ReverseLookup(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), nil)
@ -4001,6 +4041,7 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) {
func TestDNS_trimUDPResponse_NoTrim(t *testing.T) {
t.Parallel()
req := &dns.Msg{}
resp := &dns.Msg{
Answer: []dns.RR{
&dns.SRV{
@ -4025,7 +4066,7 @@ func TestDNS_trimUDPResponse_NoTrim(t *testing.T) {
}
config := &DefaultConfig().DNSConfig
if trimmed := trimUDPResponse(config, resp); trimmed {
if trimmed := trimUDPResponse(config, req, resp); trimmed {
t.Fatalf("Bad %#v", *resp)
}
@ -4060,7 +4101,7 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) {
t.Parallel()
config := &DefaultConfig().DNSConfig
resp, expected := &dns.Msg{}, &dns.Msg{}
req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{}
for i := 0; i < config.UDPAnswerLimit+1; i++ {
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
srv := &dns.SRV{
@ -4088,7 +4129,7 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) {
}
}
if trimmed := trimUDPResponse(config, resp); !trimmed {
if trimmed := trimUDPResponse(config, req, resp); !trimmed {
t.Fatalf("Bad %#v", *resp)
}
if !reflect.DeepEqual(resp, expected) {
@ -4100,7 +4141,7 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
t.Parallel()
config := &DefaultConfig().DNSConfig
resp := &dns.Msg{}
req, resp := &dns.Msg{}, &dns.Msg{}
for i := 0; i < 100; i++ {
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
srv := &dns.SRV{
@ -4126,7 +4167,7 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
// We don't know the exact trim, but we know the resulting answer
// data should match its extra data.
if trimmed := trimUDPResponse(config, resp); !trimmed {
if trimmed := trimUDPResponse(config, req, resp); !trimmed {
t.Fatalf("Bad %#v", *resp)
}
if len(resp.Answer) == 0 || len(resp.Answer) != len(resp.Extra) {
@ -4149,6 +4190,85 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
}
}
func TestDNS_trimUDPResponse_TrimSizeEDNS(t *testing.T) {
t.Parallel()
config := &DefaultConfig().DNSConfig
req, resp := &dns.Msg{}, &dns.Msg{}
for i := 0; i < 100; i++ {
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 150+i)
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Target: target,
}
a := &dns.A{
Hdr: dns.RR_Header{
Name: target,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 150+i)),
}
resp.Answer = append(resp.Answer, srv)
resp.Extra = append(resp.Extra, a)
}
// Copy over to a new slice since we are trimming both.
reqEDNS, respEDNS := &dns.Msg{}, &dns.Msg{}
reqEDNS.SetEdns0(2048, true)
respEDNS.Answer = append(respEDNS.Answer, resp.Answer...)
respEDNS.Extra = append(respEDNS.Extra, resp.Extra...)
// Trim each response
if trimmed := trimUDPResponse(config, req, resp); !trimmed {
t.Errorf("expected response to be trimmed: %#v", resp)
}
if trimmed := trimUDPResponse(config, reqEDNS, respEDNS); !trimmed {
t.Errorf("expected edns to be trimmed: %#v", resp)
}
// Check answer lengths
if len(resp.Answer) == 0 || len(resp.Answer) != len(resp.Extra) {
t.Errorf("bad response answer length: %#v", resp)
}
if len(respEDNS.Answer) == 0 || len(respEDNS.Answer) != len(respEDNS.Extra) {
t.Errorf("bad edns answer length: %#v", resp)
}
// Due to the compression, we can't check exact equality of sizes, but we can
// make two requests and ensure that the edns one returns a larger payload
// than the non-edns0 one.
if len(resp.Answer) >= len(respEDNS.Answer) {
t.Errorf("expected edns have larger answer: %#v\n%#v", resp, respEDNS)
}
if len(resp.Extra) >= len(respEDNS.Extra) {
t.Errorf("expected edns have larger extra: %#v\n%#v", resp, respEDNS)
}
// Verify that the things point where they should
for i := range resp.Answer {
srv, ok := resp.Answer[i].(*dns.SRV)
if !ok {
t.Errorf("%d should be an SRV", i)
}
a, ok := resp.Extra[i].(*dns.A)
if !ok {
t.Errorf("%d should be an A", i)
}
if srv.Target != a.Header().Name {
t.Errorf("%d: bad %#v vs. %#v", i, srv, a)
}
}
}
func TestDNS_syncExtra(t *testing.T) {
t.Parallel()
resp := &dns.Msg{
@ -4377,8 +4497,8 @@ func TestDNS_Compression_trimUDPResponse(t *testing.T) {
t.Parallel()
config := &DefaultConfig().DNSConfig
m := dns.Msg{}
trimUDPResponse(config, &m)
req, m := dns.Msg{}, dns.Msg{}
trimUDPResponse(config, &req, &m)
if m.Compress {
t.Fatalf("compression should be off")
}
@ -4386,7 +4506,7 @@ func TestDNS_Compression_trimUDPResponse(t *testing.T) {
// The trim function temporarily turns off compression, so we need to
// make sure the setting gets restored properly.
m.Compress = true
trimUDPResponse(config, &m)
trimUDPResponse(config, &req, &m)
if !m.Compress {
t.Fatalf("compression should be on")
}