Avoid issue with compression of DNS messages causing overflow

This commit is contained in:
Pierre Souchay 2018-03-07 23:33:41 +01:00
parent b672707552
commit 1085d5a7b4
2 changed files with 29 additions and 21 deletions

View File

@ -718,7 +718,10 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) {
func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
hasExtra := len(resp.Extra) > 0 hasExtra := len(resp.Extra) > 0
// There is some overhead, 65535 does not work // There is some overhead, 65535 does not work
maxSize := 64000 maxSize := 65533 // 64k - 2 bytes
// In order to compute properly, we have to avoid compress first
compressed := resp.Compress
resp.Compress = false
// We avoid some function calls and allocations by only handling the // We avoid some function calls and allocations by only handling the
// extra data when necessary. // extra data when necessary.
@ -745,6 +748,8 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
len(resp.Answer), originalNumRecords, resp.Len(), originalSize) len(resp.Answer), originalNumRecords, resp.Len(), originalSize)
} }
// Restore compression if any
resp.Compress = compressed
return truncated return truncated
} }

View File

@ -2800,28 +2800,31 @@ func TestDNS_TCP_and_UDP_Truncate(t *testing.T) {
for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} { for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} {
for _, question := range questions { for _, question := range questions {
for _, protocol := range protocols { for _, protocol := range protocols {
t.Run(fmt.Sprintf("lookup %s %s (qType:=%d)", question, protocol, qType), func(t *testing.T) { for _, compress := range []bool{true, false} {
m := new(dns.Msg) t.Run(fmt.Sprintf("lookup %s %s (qType:=%d) compressed=%b", question, protocol, qType, compress), func(t *testing.T) {
m.SetQuestion(question, dns.TypeANY) m := new(dns.Msg)
if protocol == "udp" { m.SetQuestion(question, dns.TypeANY)
m.SetEdns0(8192, true) if protocol == "udp" {
} m.SetEdns0(8192, true)
c := new(dns.Client) }
c.Net = protocol c := new(dns.Client)
in, out, err := c.Exchange(m, a.DNSAddr()) c.Net = protocol
if err != nil && err != dns.ErrTruncated { m.Compress = compress
t.Fatalf("err: %v", err) in, out, err := c.Exchange(m, a.DNSAddr())
} if err != nil && err != dns.ErrTruncated {
t.Fatalf("err: %v", err)
}
// Check for the truncate bit // Check for the truncate bit
shouldBeTruncated := numServices > 4095 shouldBeTruncated := numServices > 4095
if shouldBeTruncated != in.Truncated { if shouldBeTruncated != in.Truncated {
info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) in %v", info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) in %v",
service, question, protocol, numServices, out) service, question, protocol, numServices, out)
t.Fatalf("Should have truncate:=%v for %s", shouldBeTruncated, info) t.Fatalf("Should have truncate:=%v for %s", shouldBeTruncated, info)
} }
}) })
}
} }
} }
} }