diff --git a/agent/dns.go b/agent/dns.go index 21e9e6714..2750ed6b0 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -718,7 +718,10 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) { func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { hasExtra := len(resp.Extra) > 0 // There is some overhead, 65535 does not work - maxSize := 64000 + maxSize := 65533 // 64k - 2 bytes + // In order to compute properly, we have to avoid compress first + compressed := resp.Compress + resp.Compress = false // We avoid some function calls and allocations by only handling the // extra data when necessary. @@ -745,6 +748,8 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { len(resp.Answer), originalNumRecords, resp.Len(), originalSize) } + // Restore compression if any + resp.Compress = compressed return truncated } diff --git a/agent/dns_test.go b/agent/dns_test.go index 18da89439..cf9571de0 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -2800,28 +2800,31 @@ func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} { for _, question := range questions { for _, protocol := range protocols { - t.Run(fmt.Sprintf("lookup %s %s (qType:=%d)", question, protocol, qType), func(t *testing.T) { - m := new(dns.Msg) - m.SetQuestion(question, dns.TypeANY) - if protocol == "udp" { - m.SetEdns0(8192, true) - } - c := new(dns.Client) - c.Net = protocol - in, out, err := c.Exchange(m, a.DNSAddr()) - if err != nil && err != dns.ErrTruncated { - t.Fatalf("err: %v", err) - } + for _, compress := range []bool{true, false} { + t.Run(fmt.Sprintf("lookup %s %s (qType:=%d) compressed=%b", question, protocol, qType, compress), func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeANY) + if protocol == "udp" { + m.SetEdns0(8192, true) + } + c := new(dns.Client) + c.Net = protocol + m.Compress = compress + in, out, err := c.Exchange(m, a.DNSAddr()) + if err != nil && err != dns.ErrTruncated { + t.Fatalf("err: %v", err) + } - // Check for the truncate bit - shouldBeTruncated := numServices > 4095 + // Check for the truncate bit + shouldBeTruncated := numServices > 4095 - if shouldBeTruncated != in.Truncated { - info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) in %v", - service, question, protocol, numServices, out) - t.Fatalf("Should have truncate:=%v for %s", shouldBeTruncated, info) - } - }) + if shouldBeTruncated != in.Truncated { + info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) in %v", + service, question, protocol, numServices, out) + t.Fatalf("Should have truncate:=%v for %s", shouldBeTruncated, info) + } + }) + } } } }