agent: Shuffle DNS responses, limit records
This commit is contained in:
parent
0611db496a
commit
622aafb7c9
|
@ -6,6 +6,7 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -14,6 +15,7 @@ import (
|
|||
const (
|
||||
testQuery = "_test.consul."
|
||||
consulDomain = "consul."
|
||||
maxServiceResponses = 3 // TODO: Increase, currently a bug upstream in dns package
|
||||
)
|
||||
|
||||
// DNSServer is used to wrap an Agent and expose various
|
||||
|
@ -318,6 +320,14 @@ func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dn
|
|||
// Filter out any service nodes due to health checks
|
||||
out.Nodes = d.filterServiceNodes(out.Nodes)
|
||||
|
||||
// Perform a random shuffle
|
||||
shuffleServiceNodes(out.Nodes)
|
||||
|
||||
// Restrict the number of responses
|
||||
if len(out.Nodes) > maxServiceResponses {
|
||||
out.Nodes = out.Nodes[:maxServiceResponses]
|
||||
}
|
||||
|
||||
// Add various responses depending on the request
|
||||
qType := req.Question[0].Qtype
|
||||
if qType == dns.TypeANY || qType == dns.TypeA {
|
||||
|
@ -346,6 +356,14 @@ func (d *DNSServer) filterServiceNodes(nodes structs.CheckServiceNodes) structs.
|
|||
return nodes[:n]
|
||||
}
|
||||
|
||||
// shuffleServiceNodes does an in-place random shuffle using the Fisher-Yates algorithm
|
||||
func shuffleServiceNodes(nodes structs.CheckServiceNodes) {
|
||||
for i := len(nodes) - 1; i > 0; i-- {
|
||||
j := rand.Int31() % int32(i+1)
|
||||
nodes[i], nodes[j] = nodes[j], nodes[i]
|
||||
}
|
||||
}
|
||||
|
||||
// serviceARecords is used to add the A records for a service lookup
|
||||
func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) {
|
||||
handled := make(map[string]struct{})
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/miekg/dns"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -256,7 +258,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatalf("Bad: %#v", in.Answer[1])
|
||||
}
|
||||
if srvRec.Port != 12345 {
|
||||
if srvRec.Port != 12345 && srvRec.Port != 12346 {
|
||||
t.Fatalf("Bad: %#v", srvRec)
|
||||
}
|
||||
if srvRec.Target != "foo.node.dc1.consul." {
|
||||
|
@ -267,9 +269,12 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatalf("Bad: %#v", in.Answer[1])
|
||||
}
|
||||
if srvRec.Port != 12346 {
|
||||
if srvRec.Port != 12346 && srvRec.Port != 12345 {
|
||||
t.Fatalf("Bad: %#v", srvRec)
|
||||
}
|
||||
if srvRec.Port == in.Answer[1].(*dns.SRV).Port {
|
||||
t.Fatalf("should be a different port")
|
||||
}
|
||||
if srvRec.Target != "foo.node.dc1.consul." {
|
||||
t.Fatalf("Bad: %#v", srvRec)
|
||||
}
|
||||
|
@ -352,3 +357,66 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) {
|
|||
t.Fatalf("Bad: %#v", in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_ServiceLookup_Randomize(t *testing.T) {
|
||||
dir, srv := makeDNSServer(t)
|
||||
defer os.RemoveAll(dir)
|
||||
defer srv.agent.Shutdown()
|
||||
|
||||
// Wait for leader
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Register nodes
|
||||
for i := 0; i < 3*maxServiceResponses; i++ {
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: fmt.Sprintf("foo%d", i),
|
||||
Address: fmt.Sprintf("127.0.0.%d", i+1),
|
||||
Service: &structs.NodeService{
|
||||
Service: "web",
|
||||
Port: 8000,
|
||||
},
|
||||
}
|
||||
var out struct{}
|
||||
if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the response is randomized each time
|
||||
uniques := map[string]struct{}{}
|
||||
for i := 0; i < 5; i++ {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("web.service.consul.", dns.TypeANY)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Response length should be truncated
|
||||
// We should get an SRV + A record for each response (hence 2x)
|
||||
if len(in.Answer) != 2*maxServiceResponses {
|
||||
t.Fatalf("Bad: %#v", len(in.Answer))
|
||||
}
|
||||
|
||||
// Collect all the names
|
||||
var names []string
|
||||
for _, rec := range in.Answer {
|
||||
switch v := rec.(type) {
|
||||
case *dns.SRV:
|
||||
names = append(names, v.Target)
|
||||
case *dns.A:
|
||||
names = append(names, v.A.String())
|
||||
}
|
||||
}
|
||||
nameS := strings.Join(names, "|")
|
||||
|
||||
// Check if unique
|
||||
if _, ok := uniques[nameS]; ok {
|
||||
t.Fatalf("non-unique response: %v", nameS)
|
||||
}
|
||||
uniques[nameS] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue