Add support for DNS config hot-reload (#4875)

The DNS config parameters `recursors` and `dns_config.*` are now hot
reloaded on SIGHUP or `consul reload` and do not need an agent restart
to be modified.
Config is stored in an atomic.Value and loaded at the beginning of each
request. Reloading only affects requests that start _after_ the
reload. Ongoing requests are not affected. To match the current
behavior the recursor handler is loaded and unloaded as needed on config
reload.
This commit is contained in:
Aestek 2019-04-24 20:11:54 +02:00 committed by Matt Keeler
parent d51fd740bf
commit 6762eeb03c
4 changed files with 402 additions and 169 deletions

View File

@ -3579,6 +3579,12 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error {
a.loadLimits(newCfg)
for _, s := range a.dnsServers {
if err := s.ReloadConfig(newCfg); err != nil {
return fmt.Errorf("Failed reloading dns config : %v", err)
}
}
// create the config for the rpc server/client
consulCfg, err := a.consulConfig()
if err != nil {

View File

@ -46,7 +46,7 @@ type dnsSOAConfig struct {
Refresh uint32 // 3600 by default
Retry uint32 // 600
Expire uint32 // 86400
Minttl uint32 // 0,
Minttl uint32 // 0
}
type dnsConfig struct {
@ -60,12 +60,17 @@ type dnsConfig struct {
NodeTTL time.Duration
OnlyPassing bool
RecursorTimeout time.Duration
Recursors []string
SegmentName string
ServiceTTL map[string]time.Duration
UDPAnswerLimit int
ARecordLimit int
NodeMetaTXT bool
dnsSOAConfig dnsSOAConfig
SOAConfig dnsSOAConfig
// TTLRadix sets service TTLs by prefix, eg: "database-*"
TTLRadix *radix.Tree
// TTLStict sets TTLs to service by full name match. It Has higher priority than TTLRadix
TTLStrict map[string]time.Duration
DisableCompression bool
}
// DNSServer is used to wrap an Agent and expose various
@ -73,63 +78,39 @@ type dnsConfig struct {
type DNSServer struct {
*dns.Server
agent *Agent
config *dnsConfig
mux *dns.ServeMux
domain string
recursors []string
logger *log.Logger
// Those are handling prefix lookups
ttlRadix *radix.Tree
ttlStrict map[string]time.Duration
// disableCompression is the config.DisableCompression flag that can
// be safely changed at runtime. It always contains a bool and is
// initialized with the value from config.DisableCompression.
disableCompression atomic.Value
// config stores the config as an atomic value (for hot-reloading). It is always of type *dnsConfig
config atomic.Value
// recursorEnabled stores whever the recursor handler is enabled as an atomic flag.
// the recursor handler is only enabled if recursors are configured. This flag is used during config hot-reloading
recursorEnabled uint32
}
func NewDNSServer(a *Agent) (*DNSServer, error) {
var recursors []string
for _, r := range a.config.DNSRecursors {
ra, err := recursorAddr(r)
if err != nil {
return nil, fmt.Errorf("Invalid recursor address: %v", err)
}
recursors = append(recursors, ra)
}
// Make sure domain is FQDN, make it case insensitive for ServeMux
domain := dns.Fqdn(strings.ToLower(a.config.DNSDomain))
dnscfg := GetDNSConfig(a.config)
srv := &DNSServer{
agent: a,
config: dnscfg,
domain: domain,
logger: a.logger,
recursors: recursors,
ttlRadix: radix.New(),
ttlStrict: make(map[string]time.Duration),
}
if dnscfg.ServiceTTL != nil {
for key, ttl := range dnscfg.ServiceTTL {
// All suffix with '*' are put in radix
// This include '*' that will match anything
if strings.HasSuffix(key, "*") {
srv.ttlRadix.Insert(key[:len(key)-1], ttl)
} else {
srv.ttlStrict[key] = ttl
cfg, err := GetDNSConfig(a.config)
if err != nil {
return nil, err
}
}
}
srv.disableCompression.Store(a.config.DNSDisableCompression)
srv.config.Store(cfg)
return srv, nil
}
// GetDNSConfig takes global config and creates the config used by DNS server
func GetDNSConfig(conf *config.RuntimeConfig) *dnsConfig {
return &dnsConfig{
func GetDNSConfig(conf *config.RuntimeConfig) (*dnsConfig, error) {
cfg := &dnsConfig{
AllowStale: conf.DNSAllowStale,
ARecordLimit: conf.DNSARecordLimit,
Datacenter: conf.Datacenter,
@ -140,48 +121,73 @@ func GetDNSConfig(conf *config.RuntimeConfig) *dnsConfig {
OnlyPassing: conf.DNSOnlyPassing,
RecursorTimeout: conf.DNSRecursorTimeout,
SegmentName: conf.SegmentName,
ServiceTTL: conf.DNSServiceTTL,
UDPAnswerLimit: conf.DNSUDPAnswerLimit,
NodeMetaTXT: conf.DNSNodeMetaTXT,
DisableCompression: conf.DNSDisableCompression,
UseCache: conf.DNSUseCache,
CacheMaxAge: conf.DNSCacheMaxAge,
dnsSOAConfig: dnsSOAConfig{
SOAConfig: dnsSOAConfig{
Expire: conf.DNSSOA.Expire,
Minttl: conf.DNSSOA.Minttl,
Refresh: conf.DNSSOA.Refresh,
Retry: conf.DNSSOA.Retry,
},
}
if conf.DNSServiceTTL != nil {
cfg.TTLRadix = radix.New()
cfg.TTLStrict = make(map[string]time.Duration)
for key, ttl := range conf.DNSServiceTTL {
// All suffix with '*' are put in radix
// This include '*' that will match anything
if strings.HasSuffix(key, "*") {
cfg.TTLRadix.Insert(key[:len(key)-1], ttl)
} else {
cfg.TTLStrict[key] = ttl
}
}
}
for _, r := range conf.DNSRecursors {
ra, err := recursorAddr(r)
if err != nil {
return nil, fmt.Errorf("Invalid recursor address: %v", err)
}
cfg.Recursors = append(cfg.Recursors, ra)
}
return cfg, nil
}
// GetTTLForService Find the TTL for a given service.
// return ttl, true if found, 0, false otherwise
func (d *DNSServer) GetTTLForService(service string) (time.Duration, bool) {
if d.config.ServiceTTL != nil {
ttl, ok := d.ttlStrict[service]
func (cfg *dnsConfig) GetTTLForService(service string) (time.Duration, bool) {
if cfg.TTLStrict != nil {
ttl, ok := cfg.TTLStrict[service]
if ok {
return ttl, true
}
_, ttlRaw, ok := d.ttlRadix.LongestPrefix(service)
}
if cfg.TTLRadix != nil {
_, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service)
if ok {
return ttlRaw.(time.Duration), true
}
}
return time.Duration(0), false
return 0, false
}
func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error {
mux := dns.NewServeMux()
mux.HandleFunc("arpa.", d.handlePtr)
mux.HandleFunc(d.domain, d.handleQuery)
if len(d.recursors) > 0 {
mux.HandleFunc(".", d.handleRecurse)
}
cfg := d.config.Load().(*dnsConfig)
d.mux = dns.NewServeMux()
d.mux.HandleFunc("arpa.", d.handlePtr)
d.mux.HandleFunc(d.domain, d.handleQuery)
d.toggleRecursorHandlerFromConfig(cfg)
d.Server = &dns.Server{
Addr: addr,
Net: network,
Handler: mux,
Handler: d.mux,
NotifyStartedFunc: notif,
}
if network == "udp" {
@ -190,6 +196,34 @@ func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error {
return d.Server.ListenAndServe()
}
// toggleRecursorHandlerFromConfig enables or disables the recursor handler based on config idempotently
func (d *DNSServer) toggleRecursorHandlerFromConfig(cfg *dnsConfig) {
shouldEnable := len(cfg.Recursors) > 0
if shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 0, 1) {
d.mux.HandleFunc(".", d.handleRecurse)
d.logger.Println("[DEBUG] dns: recursor enabled")
return
}
if !shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 1, 0) {
d.mux.HandleRemove(".")
d.logger.Println("[DEBUG] dns: recursor disabled")
return
}
}
// ReloadConfig hot-reloads the server config with new parameters under config.RuntimeConfig.DNS*
func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error {
cfg, err := GetDNSConfig(newCfg)
if err != nil {
return err
}
d.config.Store(cfg)
d.toggleRecursorHandlerFromConfig(cfg)
return nil
}
// setEDNS is used to set the responses EDNS size headers and
// possibly the ECS headers as well if they were present in the
// original request
@ -258,16 +292,18 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
resp.RemoteAddr().Network())
}(time.Now())
cfg := d.config.Load().(*dnsConfig)
// Setup the message response
m := new(dns.Msg)
m.SetReply(req)
m.Compress = !d.disableCompression.Load().(bool)
m.Compress = !cfg.DisableCompression
m.Authoritative = true
m.RecursionAvailable = (len(d.recursors) > 0)
m.RecursionAvailable = (len(cfg.Recursors) > 0)
// Only add the SOA if requested
if req.Question[0].Qtype == dns.TypeSOA {
d.addSOA(m)
d.addSOA(cfg, m)
}
datacenter := d.agent.config.Datacenter
@ -279,7 +315,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
Datacenter: datacenter,
QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale,
AllowStale: cfg.AllowStale,
},
}
var out structs.IndexedNodes
@ -308,7 +344,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
Datacenter: datacenter,
QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale,
AllowStale: cfg.AllowStale,
},
ServiceAddress: serviceAddress,
}
@ -360,25 +396,27 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
network = "tcp"
}
cfg := d.config.Load().(*dnsConfig)
// Setup the message response
m := new(dns.Msg)
m.SetReply(req)
m.Compress = !d.disableCompression.Load().(bool)
m.Compress = !cfg.DisableCompression
m.Authoritative = true
m.RecursionAvailable = (len(d.recursors) > 0)
m.RecursionAvailable = (len(cfg.Recursors) > 0)
ecsGlobal := true
switch req.Question[0].Qtype {
case dns.TypeSOA:
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = append(m.Answer, d.soa())
ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = append(m.Answer, d.soa(cfg))
m.Ns = append(m.Ns, ns...)
m.Extra = append(m.Extra, glue...)
m.SetRcode(req, dns.RcodeSuccess)
case dns.TypeNS:
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = ns
m.Extra = glue
m.SetRcode(req, dns.RcodeSuccess)
@ -398,34 +436,34 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
}
}
func (d *DNSServer) soa() *dns.SOA {
func (d *DNSServer) soa(cfg *dnsConfig) *dns.SOA {
return &dns.SOA{
Hdr: dns.RR_Header{
Name: d.domain,
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
// Has to be consistent with MinTTL to avoid invalidation
Ttl: d.config.dnsSOAConfig.Minttl,
Ttl: cfg.SOAConfig.Minttl,
},
Ns: "ns." + d.domain,
Serial: uint32(time.Now().Unix()),
Mbox: "hostmaster." + d.domain,
Refresh: d.config.dnsSOAConfig.Refresh,
Retry: d.config.dnsSOAConfig.Retry,
Expire: d.config.dnsSOAConfig.Expire,
Minttl: d.config.dnsSOAConfig.Minttl,
Refresh: cfg.SOAConfig.Refresh,
Retry: cfg.SOAConfig.Retry,
Expire: cfg.SOAConfig.Expire,
Minttl: cfg.SOAConfig.Minttl,
}
}
// addSOA is used to add an SOA record to a message for the given domain
func (d *DNSServer) addSOA(msg *dns.Msg) {
msg.Ns = append(msg.Ns, d.soa())
func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg) {
msg.Ns = append(msg.Ns, d.soa(cfg))
}
// nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel)
func (d *DNSServer) nameservers(cfg *dnsConfig, edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(cfg, d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel)
if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
return nil, nil
@ -456,15 +494,15 @@ func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR,
Name: d.domain,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: uint32(d.config.NodeTTL / time.Second),
Ttl: uint32(cfg.NodeTTL / time.Second),
},
Ns: fqdn,
}
ns = append(ns, nsrr)
glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, maxRecursionLevel, d.config.NodeMetaTXT)
glue, meta := d.formatNodeRecord(cfg, nil, addr, fqdn, dns.TypeANY, cfg.NodeTTL, edns, maxRecursionLevel, cfg.NodeMetaTXT)
extra = append(extra, glue...)
if meta != nil && d.config.NodeMetaTXT {
if meta != nil && cfg.NodeMetaTXT {
extra = append(extra, meta...)
}
@ -499,6 +537,8 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
// Provide a flag for remembering whether the datacenter name was parsed already.
var dcParsed bool
cfg := d.config.Load().(*dnsConfig)
// The last label is either "node", "service", "query", "_<protocol>", or a datacenter name
PARSE:
n := len(labels)
@ -531,7 +571,7 @@ PARSE:
}
// _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel)
d.serviceLookup(cfg, network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel)
// Consul 0.3 and prior format for SRV queries
} else {
@ -543,7 +583,7 @@ PARSE:
}
// tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel)
d.serviceLookup(cfg, network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel)
}
case "connect":
@ -552,7 +592,7 @@ PARSE:
}
// name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel)
d.serviceLookup(cfg, network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel)
case "node":
if n == 1 {
@ -561,7 +601,7 @@ PARSE:
// Allow a "." in the node name, just join all the parts
node := strings.Join(labels[:n-1], ".")
d.nodeLookup(network, datacenter, node, req, resp, maxRecursionLevel)
d.nodeLookup(cfg, network, datacenter, node, req, resp, maxRecursionLevel)
case "query":
if n == 1 {
@ -571,7 +611,7 @@ PARSE:
// Allow a "." in the query name, just join all the parts.
query := strings.Join(labels[:n-1], ".")
ecsGlobal = false
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
case "addr":
if n != 2 {
@ -591,7 +631,7 @@ PARSE:
Name: qName + d.domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(d.config.NodeTTL / time.Second),
Ttl: uint32(cfg.NodeTTL / time.Second),
},
A: ip,
})
@ -607,7 +647,7 @@ PARSE:
Name: qName + d.domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: uint32(d.config.NodeTTL / time.Second),
Ttl: uint32(cfg.NodeTTL / time.Second),
},
AAAA: ip,
})
@ -638,13 +678,13 @@ PARSE:
return
INVALID:
d.logger.Printf("[WARN] dns: QName invalid: %s", qName)
d.addSOA(resp)
d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError)
return
}
// nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
// Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
@ -657,10 +697,10 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
Node: node,
QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale,
AllowStale: cfg.AllowStale,
},
}
out, err := d.lookupNode(args)
out, err := d.lookupNode(cfg, args)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
@ -669,7 +709,7 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
// If we have no address, return not found!
if out.NodeServices == nil {
d.addSOA(resp)
d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError)
return
}
@ -679,7 +719,7 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
if qType == dns.TypeANY || qType == dns.TypeTXT {
generateMeta = true
metaInAnswer = true
} else if d.config.NodeMetaTXT {
} else if cfg.NodeMetaTXT {
generateMeta = true
}
@ -687,21 +727,21 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.
n := out.NodeServices.Node
edns := req.IsEdns0() != nil
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, maxRecursionLevel, generateMeta)
records, meta := d.formatNodeRecord(cfg, out.NodeServices.Node, addr, req.Question[0].Name, qType, cfg.NodeTTL, edns, maxRecursionLevel, generateMeta)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
if meta != nil && metaInAnswer && generateMeta {
resp.Answer = append(resp.Answer, meta...)
} else if meta != nil && generateMeta {
} else if meta != nil && cfg.NodeMetaTXT {
resp.Extra = append(resp.Extra, meta...)
}
}
func (d *DNSServer) lookupNode(args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
var out structs.IndexedNodeServices
useCache := d.config.UseCache
useCache := cfg.UseCache
RPC:
if useCache {
raw, _, err := d.agent.cache.Get(cachetype.NodeServicesName, args)
@ -722,7 +762,7 @@ RPC:
// Verify that request is not too stale, redo the request
if args.AllowStale {
if out.LastContact > d.config.MaxStale {
if out.LastContact > cfg.MaxStale {
args.AllowStale = false
useCache = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
@ -761,7 +801,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
// The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node
// and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the
// generated RRs should go and if they should be used at all.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int, generateMeta bool) (records, meta []dns.RR) {
func (d *DNSServer) formatNodeRecord(cfg *dnsConfig, node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int, generateMeta bool) (records, meta []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
@ -807,7 +847,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
records = append(records, cnRec)
// Recurse
more := d.resolveCNAME(cnRec.Target, maxRecursionLevel)
more := d.resolveCNAME(cfg, cnRec.Target, maxRecursionLevel)
extra := 0
MORE_REC:
for _, rr := range more {
@ -1036,21 +1076,21 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
}
// trimDNSResponse will trim the response for UDP and TCP
func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed bool) {
func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) (trimmed bool) {
if network != "tcp" {
trimmed = trimUDPResponse(req, resp, d.config.UDPAnswerLimit)
trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit)
} else {
trimmed = d.trimTCPResponse(req, resp)
}
// Flag that there are more records to return in the UDP response
if trimmed && d.config.EnableTruncate {
if trimmed && cfg.EnableTruncate {
resp.Truncated = true
}
return trimmed
}
// lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) {
func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{
Connect: connect,
Datacenter: datacenter,
@ -1059,14 +1099,14 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
TagFilter: tag != "",
QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale,
MaxAge: d.config.CacheMaxAge,
AllowStale: cfg.AllowStale,
MaxAge: cfg.CacheMaxAge,
},
}
var out structs.IndexedCheckServiceNodes
if d.config.UseCache {
if cfg.UseCache {
raw, m, err := d.agent.cache.Get(cachetype.HealthServicesName, &args)
if err != nil {
return out, err
@ -1090,7 +1130,7 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
}
// redo the request the response was too stale
if args.AllowStale && out.LastContact > d.config.MaxStale {
if args.AllowStale && out.LastContact > cfg.MaxStale {
args.AllowStale = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
@ -1103,13 +1143,13 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
// We copy the slice to avoid modifying the result if it comes from the cache
nodes := make(structs.CheckServiceNodes, len(out.Nodes))
copy(nodes, out.Nodes)
out.Nodes = nodes.Filter(d.config.OnlyPassing)
out.Nodes = nodes.Filter(cfg.OnlyPassing)
return out, nil
}
// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect, maxRecursionLevel)
func (d *DNSServer) serviceLookup(cfg *dnsConfig, network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) {
out, err := d.lookupServiceNodes(cfg, datacenter, service, tag, connect, maxRecursionLevel)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
@ -1118,7 +1158,7 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
d.addSOA(resp)
d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError)
return
}
@ -1127,21 +1167,21 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
out.Nodes.Shuffle()
// Determine the TTL
ttl, _ := d.GetTTLForService(service)
ttl, _ := cfg.GetTTLForService(service)
// Add various responses depending on the request
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
d.serviceSRVRecords(cfg, datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
d.serviceNodeRecords(cfg, datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}
d.trimDNSResponse(network, req, resp)
d.trimDNSResponse(cfg, network, req, resp)
// If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated {
d.addSOA(resp)
d.addSOA(cfg, resp)
return
}
}
@ -1164,15 +1204,15 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
}
// preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
QueryIDOrName: query,
QueryOptions: structs.QueryOptions{
Token: d.agent.tokens.UserToken(),
AllowStale: d.config.AllowStale,
MaxAge: d.config.CacheMaxAge,
AllowStale: cfg.AllowStale,
MaxAge: cfg.CacheMaxAge,
},
// Always pass the local agent through. In the DNS interface, there
@ -1201,13 +1241,13 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot
}
}
out, err := d.lookupPreparedQuery(args)
out, err := d.lookupPreparedQuery(cfg, args)
// If they give a bogus query name, treat that as a name error,
// not a full on server error. We have to use a string compare
// here since the RPC layer loses the type information.
if err != nil && err.Error() == consul.ErrQueryNotFound.Error() {
d.addSOA(resp)
d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError)
return
} else if err != nil {
@ -1234,13 +1274,13 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot
if err != nil {
d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query)
}
} else if d.config.ServiceTTL != nil {
ttl, _ = d.GetTTLForService(out.Service)
} else {
ttl, _ = cfg.GetTTLForService(out.Service)
}
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
d.addSOA(resp)
d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError)
return
}
@ -1248,25 +1288,25 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot
// Add various responses depending on the request.
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
d.serviceSRVRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}
d.trimDNSResponse(network, req, resp)
d.trimDNSResponse(cfg, network, req, resp)
// If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated {
d.addSOA(resp)
d.addSOA(cfg, resp)
return
}
}
func (d *DNSServer) lookupPreparedQuery(args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
var out structs.PreparedQueryExecuteResponse
RPC:
if d.config.UseCache {
if cfg.UseCache {
raw, m, err := d.agent.cache.Get(cachetype.PreparedQueryName, &args)
if err != nil {
return nil, err
@ -1288,7 +1328,7 @@ RPC:
// Verify that request is not too stale, redo the request.
if args.AllowStale {
if out.LastContact > d.config.MaxStale {
if out.LastContact > cfg.MaxStale {
args.AllowStale = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
goto RPC
@ -1301,7 +1341,7 @@ RPC:
}
// serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
func (d *DNSServer) serviceNodeRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
qName := req.Question[0].Name
qType := req.Question[0].Qtype
handled := make(map[string]struct{})
@ -1335,13 +1375,13 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
if qType == dns.TypeANY || qType == dns.TypeTXT {
generateMeta = true
metaInAnswer = true
} else if d.config.NodeMetaTXT {
} else if cfg.NodeMetaTXT {
generateMeta = true
}
// Add the node record
had_answer := false
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel, generateMeta)
records, meta := d.formatNodeRecord(cfg, node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel, generateMeta)
if records != nil {
switch records[0].(type) {
case *dns.CNAME:
@ -1365,7 +1405,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
if had_answer {
count++
if count == d.config.ARecordLimit {
if count == cfg.ARecordLimit {
// We stop only if greater than 0 or we reached the limit
return
}
@ -1423,7 +1463,7 @@ func findWeight(node structs.CheckServiceNode) int {
}
// serviceARecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
func (d *DNSServer) serviceSRVRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
handled := make(map[string]struct{})
edns := req.IsEdns0() != nil
@ -1460,7 +1500,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}
// Add the extra record
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel, d.config.NodeMetaTXT)
records, meta := d.formatNodeRecord(cfg, node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel, cfg.NodeMetaTXT)
if len(records) > 0 {
// Use the node address if it doesn't differ from the service address
if addr == node.Node.Address {
@ -1491,7 +1531,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}
}
if meta != nil && d.config.NodeMetaTXT {
if meta != nil && cfg.NodeMetaTXT {
resp.Extra = append(resp.Extra, meta...)
}
}
@ -1500,6 +1540,8 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
// handleRecurse is used to handle recursive DNS queries
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
cfg := d.config.Load().(*dnsConfig)
q := req.Question[0]
network := "udp"
defer func(s time.Time) {
@ -1514,11 +1556,11 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
}
// Recursively resolve
c := &dns.Client{Net: network, Timeout: d.config.RecursorTimeout}
c := &dns.Client{Net: network, Timeout: cfg.RecursorTimeout}
var r *dns.Msg
var rtt time.Duration
var err error
for _, recursor := range d.recursors {
for _, recursor := range cfg.Recursors {
r, rtt, err = c.Exchange(req, recursor)
// Check if the response is valid and has the desired Response code
if r != nil && (r.Rcode != dns.RcodeSuccess && r.Rcode != dns.RcodeNameError) {
@ -1530,7 +1572,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
// Compress the response; we don't know if the incoming
// response was compressed or not, so by not compressing
// we might generate an invalid packet on the way out.
r.Compress = !d.disableCompression.Load().(bool)
r.Compress = !cfg.DisableCompression
// Forward the response
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v) Recursor queried: %v", q, rtt, recursor)
@ -1547,7 +1589,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
q, resp.RemoteAddr().String(), resp.RemoteAddr().Network())
m := &dns.Msg{}
m.SetReply(req)
m.Compress = !d.disableCompression.Load().(bool)
m.Compress = !cfg.DisableCompression
m.RecursionAvailable = true
m.SetRcode(req, dns.RcodeServerFailure)
if edns := req.IsEdns0(); edns != nil {
@ -1557,7 +1599,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
}
// resolveCNAME is used to recursively resolve CNAME records
func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain is
// already converted
@ -1577,7 +1619,7 @@ func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
}
// Do nothing if we don't have a recursor
if len(d.recursors) == 0 {
if len(cfg.Recursors) == 0 {
return nil
}
@ -1586,11 +1628,11 @@ func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
m.SetQuestion(name, dns.TypeA)
// Make a DNS lookup request
c := &dns.Client{Net: "udp", Timeout: d.config.RecursorTimeout}
c := &dns.Client{Net: "udp", Timeout: cfg.RecursorTimeout}
var r *dns.Msg
var rtt time.Duration
var err error
for _, recursor := range d.recursors {
for _, recursor := range cfg.Recursors {
r, rtt, err = c.Exchange(m, recursor)
if err == nil {
d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt)

View File

@ -5,6 +5,7 @@ import (
"math/rand"
"net"
"reflect"
"sort"
"strings"
"testing"
"time"
@ -3740,6 +3741,24 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) {
t.Fatalf("Bad: %#v", in.Answer[0])
}
}
newCfg := *a.Config
newCfg.DNSOnlyPassing = false
err := a.ReloadConfig(&newCfg)
require.NoError(t, err)
// only_passing is now false. we should now get two nodes
m := new(dns.Msg)
m.SetQuestion("db.service.consul.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
require.NoError(t, err)
require.Equal(t, 2, len(in.Answer))
ips := []string{in.Answer[0].(*dns.A).A.String(), in.Answer[1].(*dns.A).A.String()}
sort.Strings(ips)
require.Equal(t, []string{"127.0.0.1", "127.0.0.2"}, ips)
}
func TestDNS_ServiceLookup_Randomize(t *testing.T) {
@ -5190,7 +5209,6 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) {
})
}
}
func TestDNS_ServiceLookup_MetaTXT(t *testing.T) {
a := NewTestAgent(t, t.Name(), `dns_config = { enable_additional_node_meta_txt = true }`)
defer a.Shutdown()
@ -6341,11 +6359,177 @@ func TestDNS_formatNodeRecord(t *testing.T) {
},
}
records, meta := s.formatNodeRecord(node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, false)
records, meta := s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, false)
require.Len(t, records, 1)
require.Len(t, meta, 0)
records, meta = s.formatNodeRecord(node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, true)
records, meta = s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, true)
require.Len(t, records, 1)
require.Len(t, meta, 2)
}
func TestDNS_ConfigReload(t *testing.T) {
t.Parallel()
a := NewTestAgent(t, t.Name(), `
recursors = ["8.8.8.8:53"]
dns_config = {
allow_stale = false
max_stale = "20s"
node_ttl = "10s"
service_ttl = {
"my_services*" = "5s"
"my_specific_service" = "30s"
}
enable_truncate = false
only_passing = false
recursor_timeout = "15s"
disable_compression = false
a_record_limit = 1
enable_additional_node_meta_txt = false
soa = {
refresh = 1
retry = 2
expire = 3
min_ttl = 4
}
}
`)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
for _, s := range a.dnsServers {
cfg := s.config.Load().(*dnsConfig)
require.Equal(t, []string{"8.8.8.8:53"}, cfg.Recursors)
require.False(t, cfg.AllowStale)
require.Equal(t, 20*time.Second, cfg.MaxStale)
require.Equal(t, 10*time.Second, cfg.NodeTTL)
ttl, _ := cfg.GetTTLForService("my_services_1")
require.Equal(t, 5*time.Second, ttl)
ttl, _ = cfg.GetTTLForService("my_specific_service")
require.Equal(t, 30*time.Second, ttl)
require.False(t, cfg.EnableTruncate)
require.False(t, cfg.OnlyPassing)
require.Equal(t, 15*time.Second, cfg.RecursorTimeout)
require.False(t, cfg.DisableCompression)
require.Equal(t, 1, cfg.ARecordLimit)
require.False(t, cfg.NodeMetaTXT)
require.Equal(t, uint32(1), cfg.SOAConfig.Refresh)
require.Equal(t, uint32(2), cfg.SOAConfig.Retry)
require.Equal(t, uint32(3), cfg.SOAConfig.Expire)
require.Equal(t, uint32(4), cfg.SOAConfig.Minttl)
}
newCfg := *a.Config
newCfg.DNSRecursors = []string{"1.1.1.1:53"}
newCfg.DNSAllowStale = true
newCfg.DNSMaxStale = 21 * time.Second
newCfg.DNSNodeTTL = 11 * time.Second
newCfg.DNSServiceTTL = map[string]time.Duration{
"2_my_services*": 6 * time.Second,
"2_my_specific_service": 31 * time.Second,
}
newCfg.DNSEnableTruncate = true
newCfg.DNSOnlyPassing = true
newCfg.DNSRecursorTimeout = 16 * time.Second
newCfg.DNSDisableCompression = true
newCfg.DNSARecordLimit = 2
newCfg.DNSNodeMetaTXT = true
newCfg.DNSSOA.Refresh = 10
newCfg.DNSSOA.Retry = 20
newCfg.DNSSOA.Expire = 30
newCfg.DNSSOA.Minttl = 40
err := a.ReloadConfig(&newCfg)
require.NoError(t, err)
for _, s := range a.dnsServers {
cfg := s.config.Load().(*dnsConfig)
require.Equal(t, []string{"1.1.1.1:53"}, cfg.Recursors)
require.True(t, cfg.AllowStale)
require.Equal(t, 21*time.Second, cfg.MaxStale)
require.Equal(t, 11*time.Second, cfg.NodeTTL)
ttl, _ := cfg.GetTTLForService("my_services_1")
require.Equal(t, time.Duration(0), ttl)
ttl, _ = cfg.GetTTLForService("2_my_services_1")
require.Equal(t, 6*time.Second, ttl)
ttl, _ = cfg.GetTTLForService("my_specific_service")
require.Equal(t, time.Duration(0), ttl)
ttl, _ = cfg.GetTTLForService("2_my_specific_service")
require.Equal(t, 31*time.Second, ttl)
require.True(t, cfg.EnableTruncate)
require.True(t, cfg.OnlyPassing)
require.Equal(t, 16*time.Second, cfg.RecursorTimeout)
require.True(t, cfg.DisableCompression)
require.Equal(t, 2, cfg.ARecordLimit)
require.True(t, cfg.NodeMetaTXT)
require.Equal(t, uint32(10), cfg.SOAConfig.Refresh)
require.Equal(t, uint32(20), cfg.SOAConfig.Retry)
require.Equal(t, uint32(30), cfg.SOAConfig.Expire)
require.Equal(t, uint32(40), cfg.SOAConfig.Minttl)
}
}
func TestDNS_ReloadConfig_DuringQuery(t *testing.T) {
t.Parallel()
a := NewTestAgent(t, t.Name(), "")
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
m := MockPreparedQuery{
executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
time.Sleep(100 * time.Millisecond)
reply.Nodes = structs.CheckServiceNodes{
{
Node: &structs.Node{
ID: "my_node",
Address: "127.0.0.1",
},
Service: &structs.NodeService{
Address: "127.0.0.1",
Port: 8080,
},
},
}
return nil
},
}
err := a.registerEndpoint("PreparedQuery", &m)
require.NoError(t, err)
{
m := new(dns.Msg)
m.SetQuestion("nope.query.consul.", dns.TypeA)
timeout := time.NewTimer(time.Second)
res := make(chan *dns.Msg)
errs := make(chan error)
go func() {
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
errs <- err
return
}
res <- in
}()
time.Sleep(50 * time.Millisecond)
// reload the config halfway through, that should not affect the ongoing query
newCfg := *a.Config
newCfg.DNSAllowStale = true
a.ReloadConfig(&newCfg)
select {
case in := <-res:
require.Equal(t, "127.0.0.1", in.Answer[0].(*dns.A).A.String())
case err := <-errs:
require.NoError(t, err)
case <-timeout.C:
require.FailNow(t, "timeout")
}
}
}

View File

@ -279,7 +279,8 @@ func (a *TestAgent) Client() *api.Client {
// DNSDisableCompression disables compression for all started DNS servers.
func (a *TestAgent) DNSDisableCompression(b bool) {
for _, srv := range a.dnsServers {
srv.disableCompression.Store(b)
cfg := srv.config.Load().(*dnsConfig)
cfg.DisableCompression = b
}
}