diff --git a/agent/dns.go b/agent/dns.go index 98cb5a9b5..6211e71b9 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -717,6 +717,33 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) { resp.Extra = extra } +// dnsBinaryTruncate find the optimal number of records using a fast binary search and return +// it in order to return a DNS answer lower than maxSize parameter. +func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasExtra bool) int { + originalAnswser := resp.Answer + startIndex := 0 + endIndex := len(resp.Answer) + 1 + for endIndex-startIndex > 1 { + median := startIndex + (endIndex-startIndex)/2 + + resp.Answer = originalAnswser[:median] + if hasExtra { + syncExtra(index, resp) + } + aLen := resp.Len() + if aLen <= maxSize { + if maxSize-aLen < 10 { + // We are good, increasing will go out of bounds + return median + } + startIndex = median + } else { + endIndex = median + } + } + return startIndex +} + // trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit // of DNS responses func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { @@ -752,7 +779,13 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { // This enforces the given limit on 64k, the max limit for DNS messages for len(resp.Answer) > 0 && resp.Len() > maxSize { truncated = true - resp.Answer = resp.Answer[:len(resp.Answer)-1] + // More than 100 bytes, find with a binary search + if resp.Len()-maxSize > 100 { + bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) + resp.Answer = resp.Answer[:bestIndex] + } else { + resp.Answer = resp.Answer[:len(resp.Answer)-1] + } if hasExtra { syncExtra(index, resp) } @@ -809,7 +842,13 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { compress := resp.Compress resp.Compress = false for len(resp.Answer) > 0 && resp.Len() > maxSize { - resp.Answer = resp.Answer[:len(resp.Answer)-1] + // More than 100 bytes, find with a binary search + if resp.Len()-maxSize > 100 { + bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) + resp.Answer = resp.Answer[:bestIndex] + } else { + resp.Answer = resp.Answer[:len(resp.Answer)-1] + } if hasExtra { syncExtra(index, resp) } diff --git a/agent/dns_test.go b/agent/dns_test.go index 83bf2ca87..0770c92a0 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -2983,6 +2983,46 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { } } +func TestBinarySearch(t *testing.T) { + t.Parallel() + msgSrc := new(dns.Msg) + msgSrc.Compress = true + msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV) + + for i := 0; i < 5000; i++ { + target := fmt.Sprintf("host-redis-%d-%d.test.acme.com.node.dc1.consul.", i/256, i%256) + msgSrc.Answer = append(msgSrc.Answer, &dns.SRV{Hdr: dns.RR_Header{Name: "redis.service.consul.", Class: 1, Rrtype: dns.TypeSRV, Ttl: 0x3c}, Port: 0x4c57, Target: target}) + msgSrc.Extra = append(msgSrc.Extra, &dns.CNAME{Hdr: dns.RR_Header{Name: target, Class: 1, Rrtype: dns.TypeCNAME, Ttl: 0x3c}, Target: fmt.Sprintf("fx.168.%d.%d.", i/256, i%256)}) + } + for _, compress := range []bool{true, false} { + for idx, maxSize := range []int{12, 256, 512, 8192, 65535} { + t.Run(fmt.Sprintf("binarySearch %d", maxSize), func(t *testing.T) { + msg := new(dns.Msg) + msgSrc.Compress = compress + msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV) + msg.Answer = msgSrc.Answer + msg.Extra = msgSrc.Extra + index := make(map[string]dns.RR, len(msg.Extra)) + indexRRs(msg.Extra, index) + blen := dnsBinaryTruncate(msg, maxSize, index, true) + msg.Answer = msg.Answer[:blen] + syncExtra(index, msg) + predicted := msg.Len() + buf, err := msg.Pack() + if err != nil { + t.Error(err) + } + if predicted < len(buf) { + t.Fatalf("Bug in DNS library: %d != %d", predicted, len(buf)) + } + if len(buf) > maxSize || (idx != 0 && len(buf) < 16) { + t.Fatalf("bad[%d]: %d > %d", idx, len(buf), maxSize) + } + }) + } + } +} + func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), ` @@ -3057,9 +3097,8 @@ func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { if err != nil && err != dns.ErrTruncated { t.Fatalf("err: %v", err) } - // Check for the truncate bit - shouldBeTruncated := numServices > 4095 + shouldBeTruncated := numServices > 5000 if shouldBeTruncated != in.Truncated || len(in.Answer) > 2000 || len(in.Answer) < 1 || in.Len() > 65535 { info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) sz:= %d in %v", @@ -3311,17 +3350,17 @@ func testDNSServiceLookupResponseLimits(t *testing.T, answerLimit int, qType uin case 0: if (expectedService > 0 && len(in.Answer) != expectedService) || (expectedService < -1 && len(in.Answer) < lib.AbsInt(expectedService)) { - return false, fmt.Errorf("%d/%d answers received for type %v for %s", len(in.Answer), answerLimit, qType, question) + return false, fmt.Errorf("%d/%d answers received for type %v for %s, sz:=%d", len(in.Answer), answerLimit, qType, question, in.Len()) } case 1: if (expectedQuery > 0 && len(in.Answer) != expectedQuery) || (expectedQuery < -1 && len(in.Answer) < lib.AbsInt(expectedQuery)) { - return false, fmt.Errorf("%d/%d answers received for type %v for %s", len(in.Answer), answerLimit, qType, question) + return false, fmt.Errorf("%d/%d answers received for type %v for %s, sz:=%d", len(in.Answer), answerLimit, qType, question, in.Len()) } case 2: if (expectedQueryID > 0 && len(in.Answer) != expectedQueryID) || (expectedQueryID < -1 && len(in.Answer) < lib.AbsInt(expectedQueryID)) { - return false, fmt.Errorf("%d/%d answers received for type %v for %s", len(in.Answer), answerLimit, qType, question) + return false, fmt.Errorf("%d/%d answers received for type %v for %s, sz:=%d", len(in.Answer), answerLimit, qType, question, in.Len()) } default: panic("abort") @@ -3473,7 +3512,7 @@ func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) { t.Parallel() err := checkDNSService(t, test.numNodesTotal, test.aRecordLimit, qType, test.expectedAResults, test.udpSize, test.udpAnswerLimit) if err != nil { - t.Errorf("Expected lookup %s to pass: %v", test.name, err) + t.Fatalf("Expected lookup %s to pass: %v", test.name, err) } }) } @@ -3482,7 +3521,7 @@ func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) { t.Parallel() err := checkDNSService(t, test.expectedSRVResults, test.aRecordLimit, dns.TypeSRV, test.numNodesTotal, test.udpSize, test.udpAnswerLimit) if err != nil { - t.Errorf("Expected service SRV lookup %s to pass: %v", test.name, err) + t.Fatalf("Expected service SRV lookup %s to pass: %v", test.name, err) } }) } @@ -3528,27 +3567,27 @@ func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) { } for _, test := range tests { test := test // capture loop var - t.Run("A lookup", func(t *testing.T) { + t.Run(fmt.Sprintf("A lookup %v", test), func(t *testing.T) { t.Parallel() ok, err := testDNSServiceLookupResponseLimits(t, test.udpAnswerLimit, dns.TypeA, test.expectedAService, test.expectedAQuery, test.expectedAQueryID) if !ok { - t.Errorf("Expected service A lookup %s to pass: %v", test.name, err) + t.Fatalf("Expected service A lookup %s to pass: %v", test.name, err) } }) - t.Run("AAAA lookup", func(t *testing.T) { + t.Run(fmt.Sprintf("AAAA lookup %v", test), func(t *testing.T) { t.Parallel() ok, err := testDNSServiceLookupResponseLimits(t, test.udpAnswerLimit, dns.TypeAAAA, test.expectedAAAAService, test.expectedAAAAQuery, test.expectedAAAAQueryID) if !ok { - t.Errorf("Expected service AAAA lookup %s to pass: %v", test.name, err) + t.Fatalf("Expected service AAAA lookup %s to pass: %v", test.name, err) } }) - t.Run("ANY lookup", func(t *testing.T) { + t.Run(fmt.Sprintf("ANY lookup %v", test), func(t *testing.T) { t.Parallel() ok, err := testDNSServiceLookupResponseLimits(t, test.udpAnswerLimit, dns.TypeANY, test.expectedANYService, test.expectedANYQuery, test.expectedANYQueryID) if !ok { - t.Errorf("Expected service ANY lookup %s to pass: %v", test.name, err) + t.Fatalf("Expected service ANY lookup %s to pass: %v", test.name, err) } }) }