Add EDNS0 support (#3131)
This is a refactor of GH-1980. Originally I tried to do a straight rebase, but the code has changed too much.
This commit is contained in:
parent
345531deaa
commit
10c450c78d
58
agent/dns.go
58
agent/dns.go
|
@ -25,6 +25,8 @@ const (
|
|||
|
||||
// Increment a counter when requests staler than this are served
|
||||
staleCounterThreshold = 5 * time.Second
|
||||
|
||||
defaultMaxUDPSize = 512
|
||||
)
|
||||
|
||||
// DNSServer is used to wrap an Agent and expose various
|
||||
|
@ -163,6 +165,11 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
|
|||
return
|
||||
}
|
||||
|
||||
// Enable EDNS if enabled
|
||||
if edns := req.IsEdns0(); edns != nil {
|
||||
m.SetEdns0(edns.UDPSize(), false)
|
||||
}
|
||||
|
||||
// Write out the complete response
|
||||
if err := resp.WriteMsg(m); err != nil {
|
||||
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||
|
@ -200,6 +207,11 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
|
|||
// Dispatch the correct handler
|
||||
d.dispatch(network, req, m)
|
||||
|
||||
// Handle EDNS
|
||||
if edns := req.IsEdns0(); edns != nil {
|
||||
m.SetEdns0(edns.UDPSize(), false)
|
||||
}
|
||||
|
||||
// Write out the complete response
|
||||
if err := resp.WriteMsg(m); err != nil {
|
||||
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||
|
@ -401,16 +413,17 @@ RPC:
|
|||
|
||||
// Add the node record
|
||||
n := out.NodeServices.Node
|
||||
edns := req.IsEdns0() != nil
|
||||
addr := translateAddress(d.agent.config, datacenter, n.Address, n.TaggedAddresses)
|
||||
records := d.formatNodeRecord(out.NodeServices.Node, addr,
|
||||
req.Question[0].Name, qType, d.config.NodeTTL)
|
||||
req.Question[0].Name, qType, d.config.NodeTTL, edns)
|
||||
if records != nil {
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
}
|
||||
}
|
||||
|
||||
// formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record
|
||||
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration) (records []dns.RR) {
|
||||
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) {
|
||||
// Parse the IP
|
||||
ip := net.ParseIP(addr)
|
||||
var ipv4 net.IP
|
||||
|
@ -463,7 +476,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
|
|||
case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA:
|
||||
records = append(records, rr)
|
||||
extra++
|
||||
if extra == maxRecurseRecords {
|
||||
if extra == maxRecurseRecords && !edns {
|
||||
break MORE_REC
|
||||
}
|
||||
}
|
||||
|
@ -525,9 +538,17 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) {
|
|||
// 1035. Enforce an arbitrary limit that can be further ratcheted down by
|
||||
// config, and then make sure the response doesn't exceed 512 bytes. Any extra
|
||||
// records will be trimmed along with answers.
|
||||
func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
|
||||
func trimUDPResponse(config *DNSConfig, req, resp *dns.Msg) (trimmed bool) {
|
||||
numAnswers := len(resp.Answer)
|
||||
hasExtra := len(resp.Extra) > 0
|
||||
maxSize := defaultMaxUDPSize
|
||||
|
||||
// Update to the maximum edns size
|
||||
if edns := req.IsEdns0(); edns != nil {
|
||||
if size := edns.UDPSize(); size > uint16(maxSize) {
|
||||
maxSize = int(size)
|
||||
}
|
||||
}
|
||||
|
||||
// We avoid some function calls and allocations by only handling the
|
||||
// extra data when necessary.
|
||||
|
@ -539,21 +560,22 @@ func trimUDPResponse(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
|
|||
|
||||
// This cuts UDP responses to a useful but limited number of responses.
|
||||
maxAnswers := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit)
|
||||
if numAnswers > maxAnswers {
|
||||
if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers {
|
||||
resp.Answer = resp.Answer[:maxAnswers]
|
||||
if hasExtra {
|
||||
syncExtra(index, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// This enforces the hard limit of 512 bytes per the RFC. Note that we
|
||||
// temporarily switch to uncompressed so that we limit to a response
|
||||
// that will not exceed 512 bytes uncompressed, which is more
|
||||
// conservative and will allow our responses to be compliant even if
|
||||
// some downstream server uncompresses them.
|
||||
// This enforces the given limit on the number bytes. The default is 512 as
|
||||
// per the RFC, but EDNS0 allows for the user to specify larger sizes. Note
|
||||
// that we temporarily switch to uncompressed so that we limit to a response
|
||||
// that will not exceed 512 bytes uncompressed, which is more conservative and
|
||||
// will allow our responses to be compliant even if some downstream server
|
||||
// uncompresses them.
|
||||
compress := resp.Compress
|
||||
resp.Compress = false
|
||||
for len(resp.Answer) > 0 && resp.Len() > 512 {
|
||||
for len(resp.Answer) > 0 && resp.Len() > maxSize {
|
||||
resp.Answer = resp.Answer[:len(resp.Answer)-1]
|
||||
if hasExtra {
|
||||
syncExtra(index, resp)
|
||||
|
@ -629,7 +651,7 @@ RPC:
|
|||
|
||||
// If the network is not TCP, restrict the number of responses
|
||||
if network != "tcp" {
|
||||
wasTrimmed := trimUDPResponse(d.config, resp)
|
||||
wasTrimmed := trimUDPResponse(d.config, req, resp)
|
||||
|
||||
// Flag that there are more records to return in the UDP response
|
||||
if wasTrimmed && d.config.EnableTruncate {
|
||||
|
@ -738,7 +760,7 @@ RPC:
|
|||
|
||||
// If the network is not TCP, restrict the number of responses.
|
||||
if network != "tcp" {
|
||||
wasTrimmed := trimUDPResponse(d.config, resp)
|
||||
wasTrimmed := trimUDPResponse(d.config, req, resp)
|
||||
|
||||
// Flag that there are more records to return in the UDP response
|
||||
if wasTrimmed && d.config.EnableTruncate {
|
||||
|
@ -758,6 +780,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
|
|||
qName := req.Question[0].Name
|
||||
qType := req.Question[0].Qtype
|
||||
handled := make(map[string]struct{})
|
||||
edns := req.IsEdns0() != nil
|
||||
|
||||
for _, node := range nodes {
|
||||
// Start with the translated address but use the service address,
|
||||
|
@ -781,7 +804,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
|
|||
handled[addr] = struct{}{}
|
||||
|
||||
// Add the node record
|
||||
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl)
|
||||
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
|
||||
if records != nil {
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
}
|
||||
|
@ -791,6 +814,8 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
|
|||
// serviceARecords is used to add the SRV records for a service lookup
|
||||
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
|
||||
handled := make(map[string]struct{})
|
||||
edns := req.IsEdns0() != nil
|
||||
|
||||
for _, node := range nodes {
|
||||
// Avoid duplicate entries, possible if a node has
|
||||
// the same service the same port, etc.
|
||||
|
@ -823,7 +848,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
|
|||
}
|
||||
|
||||
// Add the extra record
|
||||
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl)
|
||||
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
|
||||
if len(records) > 0 {
|
||||
// Use the node address if it doesn't differ from the service address
|
||||
if addr == node.Node.Address {
|
||||
|
@ -903,6 +928,9 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
|
|||
m.Compress = !d.config.DisableCompression
|
||||
m.RecursionAvailable = true
|
||||
m.SetRcode(req, dns.RcodeServerFailure)
|
||||
if edns := req.IsEdns0(); edns != nil {
|
||||
m.SetEdns0(edns.UDPSize(), false)
|
||||
}
|
||||
resp.WriteMsg(m)
|
||||
}
|
||||
|
||||
|
|
|
@ -367,6 +367,46 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDNS_EDNS0(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), nil)
|
||||
defer a.Shutdown()
|
||||
|
||||
// Register node
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: "foo",
|
||||
Address: "127.0.0.2",
|
||||
}
|
||||
|
||||
var out struct{}
|
||||
if err := a.RPC("Catalog.Register", args, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetEdns0(12345, true)
|
||||
m.SetQuestion("foo.node.dc1.consul.", dns.TypeANY)
|
||||
|
||||
c := new(dns.Client)
|
||||
addr, _ := a.Config.ClientListener("", a.Config.Ports.DNS)
|
||||
in, _, err := c.Exchange(m, addr.String())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 1 {
|
||||
t.Fatalf("empty lookup: %#v", in)
|
||||
}
|
||||
edns := in.IsEdns0()
|
||||
if edns == nil {
|
||||
t.Fatalf("empty edns: %#v", in)
|
||||
}
|
||||
if edns.UDPSize() != 12345 {
|
||||
t.Fatalf("bad edns size: %d", edns.UDPSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_ReverseLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewTestAgent(t.Name(), nil)
|
||||
|
@ -4001,6 +4041,7 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) {
|
|||
|
||||
func TestDNS_trimUDPResponse_NoTrim(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := &dns.Msg{}
|
||||
resp := &dns.Msg{
|
||||
Answer: []dns.RR{
|
||||
&dns.SRV{
|
||||
|
@ -4025,7 +4066,7 @@ func TestDNS_trimUDPResponse_NoTrim(t *testing.T) {
|
|||
}
|
||||
|
||||
config := &DefaultConfig().DNSConfig
|
||||
if trimmed := trimUDPResponse(config, resp); trimmed {
|
||||
if trimmed := trimUDPResponse(config, req, resp); trimmed {
|
||||
t.Fatalf("Bad %#v", *resp)
|
||||
}
|
||||
|
||||
|
@ -4060,7 +4101,7 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) {
|
|||
t.Parallel()
|
||||
config := &DefaultConfig().DNSConfig
|
||||
|
||||
resp, expected := &dns.Msg{}, &dns.Msg{}
|
||||
req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{}
|
||||
for i := 0; i < config.UDPAnswerLimit+1; i++ {
|
||||
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
|
||||
srv := &dns.SRV{
|
||||
|
@ -4088,7 +4129,7 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
if trimmed := trimUDPResponse(config, resp); !trimmed {
|
||||
if trimmed := trimUDPResponse(config, req, resp); !trimmed {
|
||||
t.Fatalf("Bad %#v", *resp)
|
||||
}
|
||||
if !reflect.DeepEqual(resp, expected) {
|
||||
|
@ -4100,7 +4141,7 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
|
|||
t.Parallel()
|
||||
config := &DefaultConfig().DNSConfig
|
||||
|
||||
resp := &dns.Msg{}
|
||||
req, resp := &dns.Msg{}, &dns.Msg{}
|
||||
for i := 0; i < 100; i++ {
|
||||
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
|
||||
srv := &dns.SRV{
|
||||
|
@ -4126,7 +4167,7 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
|
|||
|
||||
// We don't know the exact trim, but we know the resulting answer
|
||||
// data should match its extra data.
|
||||
if trimmed := trimUDPResponse(config, resp); !trimmed {
|
||||
if trimmed := trimUDPResponse(config, req, resp); !trimmed {
|
||||
t.Fatalf("Bad %#v", *resp)
|
||||
}
|
||||
if len(resp.Answer) == 0 || len(resp.Answer) != len(resp.Extra) {
|
||||
|
@ -4149,6 +4190,85 @@ func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDNS_trimUDPResponse_TrimSizeEDNS(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &DefaultConfig().DNSConfig
|
||||
|
||||
req, resp := &dns.Msg{}, &dns.Msg{}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 150+i)
|
||||
srv := &dns.SRV{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "redis-cache-redis.service.consul.",
|
||||
Rrtype: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Target: target,
|
||||
}
|
||||
a := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: target,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 150+i)),
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, srv)
|
||||
resp.Extra = append(resp.Extra, a)
|
||||
}
|
||||
|
||||
// Copy over to a new slice since we are trimming both.
|
||||
reqEDNS, respEDNS := &dns.Msg{}, &dns.Msg{}
|
||||
reqEDNS.SetEdns0(2048, true)
|
||||
respEDNS.Answer = append(respEDNS.Answer, resp.Answer...)
|
||||
respEDNS.Extra = append(respEDNS.Extra, resp.Extra...)
|
||||
|
||||
// Trim each response
|
||||
if trimmed := trimUDPResponse(config, req, resp); !trimmed {
|
||||
t.Errorf("expected response to be trimmed: %#v", resp)
|
||||
}
|
||||
if trimmed := trimUDPResponse(config, reqEDNS, respEDNS); !trimmed {
|
||||
t.Errorf("expected edns to be trimmed: %#v", resp)
|
||||
}
|
||||
|
||||
// Check answer lengths
|
||||
if len(resp.Answer) == 0 || len(resp.Answer) != len(resp.Extra) {
|
||||
t.Errorf("bad response answer length: %#v", resp)
|
||||
}
|
||||
if len(respEDNS.Answer) == 0 || len(respEDNS.Answer) != len(respEDNS.Extra) {
|
||||
t.Errorf("bad edns answer length: %#v", resp)
|
||||
}
|
||||
|
||||
// Due to the compression, we can't check exact equality of sizes, but we can
|
||||
// make two requests and ensure that the edns one returns a larger payload
|
||||
// than the non-edns0 one.
|
||||
if len(resp.Answer) >= len(respEDNS.Answer) {
|
||||
t.Errorf("expected edns have larger answer: %#v\n%#v", resp, respEDNS)
|
||||
}
|
||||
if len(resp.Extra) >= len(respEDNS.Extra) {
|
||||
t.Errorf("expected edns have larger extra: %#v\n%#v", resp, respEDNS)
|
||||
}
|
||||
|
||||
// Verify that the things point where they should
|
||||
for i := range resp.Answer {
|
||||
srv, ok := resp.Answer[i].(*dns.SRV)
|
||||
if !ok {
|
||||
t.Errorf("%d should be an SRV", i)
|
||||
}
|
||||
|
||||
a, ok := resp.Extra[i].(*dns.A)
|
||||
if !ok {
|
||||
t.Errorf("%d should be an A", i)
|
||||
}
|
||||
|
||||
if srv.Target != a.Header().Name {
|
||||
t.Errorf("%d: bad %#v vs. %#v", i, srv, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_syncExtra(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := &dns.Msg{
|
||||
|
@ -4377,8 +4497,8 @@ func TestDNS_Compression_trimUDPResponse(t *testing.T) {
|
|||
t.Parallel()
|
||||
config := &DefaultConfig().DNSConfig
|
||||
|
||||
m := dns.Msg{}
|
||||
trimUDPResponse(config, &m)
|
||||
req, m := dns.Msg{}, dns.Msg{}
|
||||
trimUDPResponse(config, &req, &m)
|
||||
if m.Compress {
|
||||
t.Fatalf("compression should be off")
|
||||
}
|
||||
|
@ -4386,7 +4506,7 @@ func TestDNS_Compression_trimUDPResponse(t *testing.T) {
|
|||
// The trim function temporarily turns off compression, so we need to
|
||||
// make sure the setting gets restored properly.
|
||||
m.Compress = true
|
||||
trimUDPResponse(config, &m)
|
||||
trimUDPResponse(config, &req, &m)
|
||||
if !m.Compress {
|
||||
t.Fatalf("compression should be on")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue