package agent import ( "context" "encoding/hex" "errors" "fmt" "net" "regexp" "strings" "sync/atomic" "time" "github.com/armon/go-metrics/prometheus" metrics "github.com/armon/go-metrics" radix "github.com/armon/go-radix" "github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/hashicorp/go-hclog" "github.com/miekg/dns" cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" agentdns "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/logging" ) var DNSCounters = []prometheus.CounterDefinition{ { Name: []string{"dns", "stale_queries"}, Help: "Increments when an agent serves a query within the allowed stale threshold.", }, } var DNSSummaries = []prometheus.SummaryDefinition{ { Name: []string{"dns", "ptr_query"}, Help: "Measures the time spent handling a reverse DNS query for the given node.", }, { Name: []string{"dns", "domain_query"}, Help: "Measures the time spent handling a domain query for the given node.", }, } const ( // UDP can fit ~25 A records in a 512B response, and ~14 AAAA // records. Limit further to prevent unintentional configuration // abuse that would have a negative effect on application response // times. maxUDPAnswerLimit = 8 maxRecurseRecords = 5 maxRecursionLevelDefault = 3 // Increment a counter when requests staler than this are served staleCounterThreshold = 5 * time.Second defaultMaxUDPSize = 512 ) type dnsSOAConfig struct { Refresh uint32 // 3600 by default Retry uint32 // 600 Expire uint32 // 86400 Minttl uint32 // 0 } type dnsConfig struct { AllowStale bool Datacenter string EnableTruncate bool MaxStale time.Duration UseCache bool CacheMaxAge time.Duration NodeName string NodeTTL time.Duration OnlyPassing bool RecursorTimeout time.Duration Recursors []string SegmentName string UDPAnswerLimit int ARecordLimit int NodeMetaTXT bool SOAConfig dnsSOAConfig // TTLRadix sets service TTLs by prefix, eg: "database-*" TTLRadix *radix.Tree // TTLStict sets TTLs to service by full name match. It Has higher priority than TTLRadix TTLStrict map[string]time.Duration DisableCompression bool enterpriseDNSConfig } type serviceLookup struct { Datacenter string Service string Tag string MaxRecursionLevel int Connect bool Ingress bool structs.EnterpriseMeta } // DNSServer is used to wrap an Agent and expose various // service discovery endpoints using a DNS interface. type DNSServer struct { *dns.Server agent *Agent mux *dns.ServeMux domain string altDomain string logger hclog.Logger // config stores the config as an atomic value (for hot-reloading). It is always of type *dnsConfig config atomic.Value // recursorEnabled stores whever the recursor handler is enabled as an atomic flag. // the recursor handler is only enabled if recursors are configured. This flag is used during config hot-reloading recursorEnabled uint32 } func NewDNSServer(a *Agent) (*DNSServer, error) { // Make sure domains are FQDN, make them case insensitive for ServeMux domain := dns.Fqdn(strings.ToLower(a.config.DNSDomain)) altDomain := dns.Fqdn(strings.ToLower(a.config.DNSAltDomain)) srv := &DNSServer{ agent: a, domain: domain, altDomain: altDomain, logger: a.logger.Named(logging.DNS), } cfg, err := GetDNSConfig(a.config) if err != nil { return nil, err } srv.config.Store(cfg) return srv, nil } // GetDNSConfig takes global config and creates the config used by DNS server func GetDNSConfig(conf *config.RuntimeConfig) (*dnsConfig, error) { cfg := &dnsConfig{ AllowStale: conf.DNSAllowStale, ARecordLimit: conf.DNSARecordLimit, Datacenter: conf.Datacenter, EnableTruncate: conf.DNSEnableTruncate, MaxStale: conf.DNSMaxStale, NodeName: conf.NodeName, NodeTTL: conf.DNSNodeTTL, OnlyPassing: conf.DNSOnlyPassing, RecursorTimeout: conf.DNSRecursorTimeout, SegmentName: conf.SegmentName, UDPAnswerLimit: conf.DNSUDPAnswerLimit, NodeMetaTXT: conf.DNSNodeMetaTXT, DisableCompression: conf.DNSDisableCompression, UseCache: conf.DNSUseCache, CacheMaxAge: conf.DNSCacheMaxAge, SOAConfig: dnsSOAConfig{ Expire: conf.DNSSOA.Expire, Minttl: conf.DNSSOA.Minttl, Refresh: conf.DNSSOA.Refresh, Retry: conf.DNSSOA.Retry, }, enterpriseDNSConfig: getEnterpriseDNSConfig(conf), } if conf.DNSServiceTTL != nil { cfg.TTLRadix = radix.New() cfg.TTLStrict = make(map[string]time.Duration) for key, ttl := range conf.DNSServiceTTL { // All suffix with '*' are put in radix // This include '*' that will match anything if strings.HasSuffix(key, "*") { cfg.TTLRadix.Insert(key[:len(key)-1], ttl) } else { cfg.TTLStrict[key] = ttl } } } for _, r := range conf.DNSRecursors { ra, err := recursorAddr(r) if err != nil { return nil, fmt.Errorf("Invalid recursor address: %v", err) } cfg.Recursors = append(cfg.Recursors, ra) } return cfg, nil } // GetTTLForService Find the TTL for a given service. // return ttl, true if found, 0, false otherwise func (cfg *dnsConfig) GetTTLForService(service string) (time.Duration, bool) { if cfg.TTLStrict != nil { ttl, ok := cfg.TTLStrict[service] if ok { return ttl, true } } if cfg.TTLRadix != nil { _, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service) if ok { return ttlRaw.(time.Duration), true } } return 0, false } func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error { cfg := d.config.Load().(*dnsConfig) d.mux = dns.NewServeMux() d.mux.HandleFunc("arpa.", d.handlePtr) d.mux.HandleFunc(d.domain, d.handleQuery) // this is not an empty string check because NewDNSServer will have // converted the configured alt domain into an FQDN which will ensure that // the value ends with a ".". Therefore "." is the empty string equivalent // for originally having no alternate domain set. If there is a reason // why consul should be configured to handle the root zone I have yet // to think of it. if d.altDomain != "." { d.mux.HandleFunc(d.altDomain, d.handleQuery) } d.toggleRecursorHandlerFromConfig(cfg) d.Server = &dns.Server{ Addr: addr, Net: network, Handler: d.mux, NotifyStartedFunc: notif, } if network == "udp" { d.UDPSize = 65535 } return d.Server.ListenAndServe() } // toggleRecursorHandlerFromConfig enables or disables the recursor handler based on config idempotently func (d *DNSServer) toggleRecursorHandlerFromConfig(cfg *dnsConfig) { shouldEnable := len(cfg.Recursors) > 0 if shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 0, 1) { d.mux.HandleFunc(".", d.handleRecurse) d.logger.Debug("recursor enabled") return } if !shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 1, 0) { d.mux.HandleRemove(".") d.logger.Debug("recursor disabled") return } } // ReloadConfig hot-reloads the server config with new parameters under config.RuntimeConfig.DNS* func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error { cfg, err := GetDNSConfig(newCfg) if err != nil { return err } d.config.Store(cfg) d.toggleRecursorHandlerFromConfig(cfg) return nil } // setEDNS is used to set the responses EDNS size headers and // possibly the ECS headers as well if they were present in the // original request func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) { edns := request.IsEdns0() if edns == nil { return } // cannot just use the SetEdns0 function as we need to embed // the ECS option as well ednsResp := new(dns.OPT) ednsResp.Hdr.Name = "." ednsResp.Hdr.Rrtype = dns.TypeOPT ednsResp.SetUDPSize(edns.UDPSize()) // Setup the ECS option if present if subnet := ednsSubnetForRequest(request); subnet != nil { subOp := new(dns.EDNS0_SUBNET) subOp.Code = dns.EDNS0SUBNET subOp.Family = subnet.Family subOp.Address = subnet.Address subOp.SourceNetmask = subnet.SourceNetmask if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented { // reply is globally valid and should be cached accordingly subOp.SourceScope = 0 } else { // reply is only valid for the subnet it was queried with subOp.SourceScope = subnet.SourceNetmask } ednsResp.Option = append(ednsResp.Option, subOp) } response.Extra = append(response.Extra, ednsResp) } // recursorAddr is used to add a port to the recursor if omitted. func recursorAddr(recursor string) (string, error) { // Add the port if none START: _, _, err := net.SplitHostPort(recursor) if ae, ok := err.(*net.AddrError); ok { if ae.Err == "missing port in address" { recursor = ipaddr.FormatAddressPort(recursor, 53) goto START } else if ae.Err == "too many colons in address" { if ip := net.ParseIP(recursor); ip != nil && ip.To4() == nil { recursor = ipaddr.FormatAddressPort(recursor, 53) goto START } } } if err != nil { return "", err } // Get the address addr, err := net.ResolveTCPAddr("tcp", recursor) if err != nil { return "", err } // Return string return addr.String(), nil } func serviceNodeCanonicalDNSName(sn *structs.ServiceNode, domain string) string { return serviceCanonicalDNSName(sn.ServiceName, "service", sn.Datacenter, domain, &sn.EnterpriseMeta) } func serviceIngressDNSName(service, datacenter, domain string, entMeta *structs.EnterpriseMeta) string { return serviceCanonicalDNSName(service, "ingress", datacenter, domain, entMeta) } // handlePtr is used to handle "reverse" DNS queries func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] defer func(s time.Time) { metrics.MeasureSinceWithLabels([]string{"dns", "ptr_query"}, s, []metrics.Label{{Name: "node", Value: d.agent.config.NodeName}}) d.logger.Debug("request served from client", "question", q, "latency", time.Since(s).String(), "client", resp.RemoteAddr().String(), "client_network", resp.RemoteAddr().Network(), ) }(time.Now()) cfg := d.config.Load().(*dnsConfig) // Setup the message response m := new(dns.Msg) m.SetReply(req) m.Compress = !cfg.DisableCompression m.Authoritative = true m.RecursionAvailable = (len(cfg.Recursors) > 0) // Only add the SOA if requested if req.Question[0].Qtype == dns.TypeSOA { d.addSOA(cfg, m) } datacenter := d.agent.config.Datacenter // Get the QName without the domain suffix qName := strings.ToLower(dns.Fqdn(req.Question[0].Name)) args := structs.DCSpecificRequest{ Datacenter: datacenter, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), AllowStale: cfg.AllowStale, }, } var out structs.IndexedNodes // TODO: Replace ListNodes with an internal RPC that can do the filter // server side to avoid transferring the entire node list. if err := d.agent.RPC("Catalog.ListNodes", &args, &out); err == nil { for _, n := range out.Nodes { arpa, _ := dns.ReverseAddr(n.Address) if arpa == qName { ptr := &dns.PTR{ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0}, Ptr: fmt.Sprintf("%s.node.%s.%s", n.Node, datacenter, d.domain), } m.Answer = append(m.Answer, ptr) break } } } // only look into the services if we didn't find a node if len(m.Answer) == 0 { // lookup the service address serviceAddress := dnsutil.ExtractAddressFromReverse(qName) sargs := structs.ServiceSpecificRequest{ Datacenter: datacenter, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), AllowStale: cfg.AllowStale, }, ServiceAddress: serviceAddress, EnterpriseMeta: *structs.WildcardEnterpriseMeta(), } var sout structs.IndexedServiceNodes if err := d.agent.RPC("Catalog.ServiceNodes", &sargs, &sout); err == nil { for _, n := range sout.ServiceNodes { if n.ServiceAddress == serviceAddress { ptr := &dns.PTR{ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0}, Ptr: serviceNodeCanonicalDNSName(n, d.domain), } m.Answer = append(m.Answer, ptr) break } } } } // nothing found locally, recurse if len(m.Answer) == 0 { d.handleRecurse(resp, req) return } // ptr record responses are globally valid setEDNS(req, m, true) // Write out the complete response if err := resp.WriteMsg(m); err != nil { d.logger.Warn("failed to respond", "error", err) } } // handleQuery is used to handle DNS queries in the configured domain func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] defer func(s time.Time) { metrics.MeasureSinceWithLabels([]string{"dns", "domain_query"}, s, []metrics.Label{{Name: "node", Value: d.agent.config.NodeName}}) d.logger.Debug("request served from client", "name", q.Name, "type", dns.Type(q.Qtype), "class", dns.Class(q.Qclass), "latency", time.Since(s).String(), "client", resp.RemoteAddr().String(), "client_network", resp.RemoteAddr().Network(), ) }(time.Now()) // Switch to TCP if the client is network := "udp" if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok { network = "tcp" } cfg := d.config.Load().(*dnsConfig) // Setup the message response m := new(dns.Msg) m.SetReply(req) m.Compress = !cfg.DisableCompression m.Authoritative = true m.RecursionAvailable = (len(cfg.Recursors) > 0) var err error switch req.Question[0].Qtype { case dns.TypeSOA: ns, glue := d.nameservers(cfg, maxRecursionLevelDefault) m.Answer = append(m.Answer, d.soa(cfg)) m.Ns = append(m.Ns, ns...) m.Extra = append(m.Extra, glue...) m.SetRcode(req, dns.RcodeSuccess) case dns.TypeNS: ns, glue := d.nameservers(cfg, maxRecursionLevelDefault) m.Answer = ns m.Extra = glue m.SetRcode(req, dns.RcodeSuccess) case dns.TypeAXFR: m.SetRcode(req, dns.RcodeNotImplemented) default: err = d.dispatch(resp.RemoteAddr(), req, m, maxRecursionLevelDefault) rCode := rCodeFromError(err) if rCode == dns.RcodeNameError || errors.Is(err, errNoData) { d.addSOA(cfg, m) } m.SetRcode(req, rCode) } setEDNS(req, m, !errors.Is(err, errECSNotGlobal)) d.trimDNSResponse(cfg, network, req, m) if err := resp.WriteMsg(m); err != nil { d.logger.Warn("failed to respond", "error", err) } } func (d *DNSServer) soa(cfg *dnsConfig) *dns.SOA { return &dns.SOA{ Hdr: dns.RR_Header{ Name: d.domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, // Has to be consistent with MinTTL to avoid invalidation Ttl: cfg.SOAConfig.Minttl, }, Ns: "ns." + d.domain, Serial: uint32(time.Now().Unix()), Mbox: "hostmaster." + d.domain, Refresh: cfg.SOAConfig.Refresh, Retry: cfg.SOAConfig.Retry, Expire: cfg.SOAConfig.Expire, Minttl: cfg.SOAConfig.Minttl, } } // addSOA is used to add an SOA record to a message for the given domain func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg) { msg.Ns = append(msg.Ns, d.soa(cfg)) } // nameservers returns the names and ip addresses of up to three random servers // in the current cluster which serve as authoritative name servers for zone. func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { out, err := d.lookupServiceNodes(cfg, serviceLookup{ Datacenter: d.agent.config.Datacenter, Service: structs.ConsulServiceName, Connect: false, Ingress: false, EnterpriseMeta: *structs.DefaultEnterpriseMeta(), }) if err != nil { d.logger.Warn("Unable to get list of servers", "error", err) return nil, nil } if len(out.Nodes) == 0 { d.logger.Warn("no servers found") return } // shuffle the nodes to randomize the output out.Nodes.Shuffle() for _, o := range out.Nodes { name, dc := o.Node.Node, o.Node.Datacenter if agentdns.InvalidNameRe.MatchString(name) { d.logger.Warn("Skipping invalid node for NS records", "node", name) continue } fqdn := name + ".node." + dc + "." + d.domain fqdn = dns.Fqdn(strings.ToLower(fqdn)) // NS record nsrr := &dns.NS{ Hdr: dns.RR_Header{ Name: d.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(cfg.NodeTTL / time.Second), }, Ns: fqdn, } ns = append(ns, nsrr) extra = append(extra, d.makeRecordFromNode(o.Node, dns.TypeANY, fqdn, cfg.NodeTTL, maxRecursionLevel)...) // don't provide more than 3 servers if len(ns) >= 3 { return } } return } func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { switch len(labels) { case 1: *datacenter = labels[0] return true case 0: return true default: return false } } var errECSNotGlobal = fmt.Errorf("ECS response is not global") var errNameNotFound = fmt.Errorf("DNS name not found") // errNoData is used to indicate no resource records exist for the specified query type. // Per the recommendation from Section 2.2 of RFC 2308, the server will return a TYPE 2 // NODATA response in which the RCODE is set to NOERROR (RcodeSuccess), the Answer // section is empty, and the Authority section contains the SOA record. var errNoData = fmt.Errorf("no DNS Answer") // ecsNotGlobalError may be used to wrap an error or nil, to indicate that the // EDNS client subnet source scope is not global. type ecsNotGlobalError struct { error } func (e ecsNotGlobalError) Error() string { if e.error == nil { return "" } return e.error.Error() } func (e ecsNotGlobalError) Is(other error) bool { return other == errECSNotGlobal } func (e ecsNotGlobalError) Unwrap() error { return e.error } // dispatch is used to parse a request and invoke the correct handler. // parameter maxRecursionLevel will handle whether recursive call can be performed func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error { // By default the query is in the default datacenter datacenter := d.agent.config.Datacenter // have to deref to clone it so we don't modify var entMeta structs.EnterpriseMeta // Get the QName without the domain suffix qName := strings.ToLower(dns.Fqdn(req.Question[0].Name)) qName = d.trimDomain(qName) // Split into the label parts labels := dns.SplitDomainName(qName) cfg := d.config.Load().(*dnsConfig) var queryKind string var queryParts []string var querySuffixes []string done := false for i := len(labels) - 1; i >= 0 && !done; i-- { switch labels[i] { case "service", "connect", "ingress", "node", "query", "addr": queryParts = labels[:i] querySuffixes = labels[i+1:] queryKind = labels[i] done = true default: // If this is a SRV query the "service" label is optional, we add it back to use the // existing code-path. if req.Question[0].Qtype == dns.TypeSRV && strings.HasPrefix(labels[i], "_") { queryKind = "service" queryParts = labels[:i+1] querySuffixes = labels[i+1:] done = true } } } invalid := func() error { d.logger.Warn("QName invalid", "qname", qName) return errNameNotFound } switch queryKind { case "service": n := len(queryParts) if n < 1 { return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { return invalid() } lookup := serviceLookup{ Datacenter: datacenter, Connect: false, Ingress: false, MaxRecursionLevel: maxRecursionLevel, EnterpriseMeta: entMeta, } // Support RFC 2782 style syntax if n == 2 && strings.HasPrefix(queryParts[1], "_") && strings.HasPrefix(queryParts[0], "_") { // Grab the tag since we make nuke it if it's tcp tag := queryParts[1][1:] // Treat _name._tcp.service.consul as a default, no need to filter on that tag if tag == "tcp" { tag = "" } lookup.Tag = tag lookup.Service = queryParts[0][1:] // _name._tag.service.consul return d.serviceLookup(cfg, lookup, req, resp) } // Consul 0.3 and prior format for SRV queries // Support "." in the label, re-join all the parts tag := "" if n >= 2 { tag = strings.Join(queryParts[:n-1], ".") } lookup.Tag = tag lookup.Service = queryParts[n-1] // tag[.tag].name.service.consul return d.serviceLookup(cfg, lookup, req, resp) case "connect": if len(queryParts) < 1 { return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { return invalid() } lookup := serviceLookup{ Datacenter: datacenter, Service: queryParts[len(queryParts)-1], Connect: true, Ingress: false, MaxRecursionLevel: maxRecursionLevel, EnterpriseMeta: entMeta, } // name.connect.consul return d.serviceLookup(cfg, lookup, req, resp) case "ingress": if len(queryParts) < 1 { return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { return invalid() } lookup := serviceLookup{ Datacenter: datacenter, Service: queryParts[len(queryParts)-1], Connect: false, Ingress: true, MaxRecursionLevel: maxRecursionLevel, EnterpriseMeta: entMeta, } // name.ingress.consul return d.serviceLookup(cfg, lookup, req, resp) case "node": if len(queryParts) < 1 { return invalid() } if !d.parseDatacenter(querySuffixes, &datacenter) { return invalid() } // Allow a "." in the node name, just join all the parts node := strings.Join(queryParts, ".") return d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel) case "query": // ensure we have a query name if len(queryParts) < 1 { return invalid() } if !d.parseDatacenter(querySuffixes, &datacenter) { return invalid() } // Allow a "." in the query name, just join all the parts. query := strings.Join(queryParts, ".") err := d.preparedQueryLookup(cfg, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) return ecsNotGlobalError{error: err} case "addr": //
.addr.. - addr must be the second label, datacenter is optional if len(queryParts) != 1 { return invalid() } switch len(queryParts[0]) / 2 { // IPv4 case 4: ip, err := hex.DecodeString(queryParts[0]) if err != nil { return invalid() } //check if the query type is A for IPv4 or ANY aRecord := &dns.A{ Hdr: dns.RR_Header{ Name: qName + d.domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(cfg.NodeTTL / time.Second), }, A: ip, } if req.Question[0].Qtype != dns.TypeA && req.Question[0].Qtype != dns.TypeANY { resp.Extra = append(resp.Answer, aRecord) } else { resp.Answer = append(resp.Answer, aRecord) } // IPv6 case 16: ip, err := hex.DecodeString(queryParts[0]) if err != nil { return invalid() } //check if the query type is AAAA for IPv6 or ANY aaaaRecord := &dns.AAAA{ Hdr: dns.RR_Header{ Name: qName + d.domain, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: uint32(cfg.NodeTTL / time.Second), }, AAAA: ip, } if req.Question[0].Qtype != dns.TypeAAAA && req.Question[0].Qtype != dns.TypeANY { resp.Extra = append(resp.Extra, aaaaRecord) } else { resp.Answer = append(resp.Answer, aaaaRecord) } } return nil default: return invalid() } } func (d *DNSServer) trimDomain(query string) string { longer := d.domain shorter := d.altDomain if len(shorter) > len(longer) { longer, shorter = shorter, longer } if strings.HasSuffix(query, longer) { return strings.TrimSuffix(query, longer) } return strings.TrimSuffix(query, shorter) } // rCodeFromError return the appropriate DNS response code for a given error func rCodeFromError(err error) int { switch { case err == nil: return dns.RcodeSuccess case errors.Is(err, errNoData): return dns.RcodeSuccess case errors.Is(err, errECSNotGlobal): return rCodeFromError(errors.Unwrap(err)) case errors.Is(err, errNameNotFound): return dns.RcodeNameError case structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err): return dns.RcodeNameError default: return dns.RcodeServerFailure } } // nodeLookup is used to handle a node query func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) error { // Only handle ANY, A, AAAA, and TXT type requests qType := req.Question[0].Qtype if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT { return nil } // Make an RPC request args := &structs.NodeSpecificRequest{ Datacenter: datacenter, Node: node, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), AllowStale: cfg.AllowStale, }, } out, err := d.lookupNode(cfg, args) if err != nil { return fmt.Errorf("failed rpc request: %w", err) } // If we have no out.NodeServices.Nodeaddress, return not found! if out.NodeServices == nil { return errNameNotFound } // Add the node record n := out.NodeServices.Node metaTarget := &resp.Extra if qType == dns.TypeTXT || qType == dns.TypeANY { metaTarget = &resp.Answer } q := req.Question[0] // Only compute A and CNAME record if query is not TXT type if qType != dns.TypeTXT { records := d.makeRecordFromNode(n, q.Qtype, q.Name, cfg.NodeTTL, maxRecursionLevel) resp.Answer = append(resp.Answer, records...) } if cfg.NodeMetaTXT || qType == dns.TypeTXT || qType == dns.TypeANY { metas := d.generateMeta(q.Name, n, cfg.NodeTTL) *metaTarget = append(*metaTarget, metas...) } return nil } func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { var out structs.IndexedNodeServices useCache := cfg.UseCache RPC: if useCache { raw, _, err := d.agent.cache.Get(context.TODO(), cachetype.NodeServicesName, args) if err != nil { return nil, err } reply, ok := raw.(*structs.IndexedNodeServices) if !ok { // This should never happen, but we want to protect against panics return nil, fmt.Errorf("internal error: response type not correct") } out = *reply } else { if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { return nil, err } } // Verify that request is not too stale, redo the request if args.AllowStale { if out.LastContact > cfg.MaxStale { args.AllowStale = false useCache = false d.logger.Warn("Query results too stale, re-requesting") goto RPC } else if out.LastContact > staleCounterThreshold { metrics.IncrCounter([]string{"dns", "stale_queries"}, 1) } } return &out, nil } // encodeKVasRFC1464 encodes a key-value pair according to RFC1464 func encodeKVasRFC1464(key, value string) (txt string) { // For details on these replacements c.f. https://www.ietf.org/rfc/rfc1464.txt key = strings.Replace(key, "`", "``", -1) key = strings.Replace(key, "=", "`=", -1) // Backquote the leading spaces leadingSpacesRE := regexp.MustCompile("^ +") numLeadingSpaces := len(leadingSpacesRE.FindString(key)) key = leadingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numLeadingSpaces)) // Backquote the trailing spaces trailingSpacesRE := regexp.MustCompile(" +$") numTrailingSpaces := len(trailingSpacesRE.FindString(key)) key = trailingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numTrailingSpaces)) value = strings.Replace(value, "`", "``", -1) return key + "=" + value } // indexRRs populates a map which indexes a given list of RRs by name. NOTE that // the names are all squashed to lower case so we can perform case-insensitive // lookups; the RRs are not modified. func indexRRs(rrs []dns.RR, index map[string]dns.RR) { for _, rr := range rrs { name := strings.ToLower(rr.Header().Name) if _, ok := index[name]; !ok { index[name] = rr } } } // syncExtra takes a DNS response message and sets the extra data to the most // minimal set needed to cover the answer data. A pre-made index of RRs is given // so that can be re-used between calls. This assumes that the extra data is // only used to provide info for SRV records. If that's not the case, then this // will wipe out any additional data. func syncExtra(index map[string]dns.RR, resp *dns.Msg) { extra := make([]dns.RR, 0, len(resp.Answer)) resolved := make(map[string]struct{}, len(resp.Answer)) for _, ansRR := range resp.Answer { srv, ok := ansRR.(*dns.SRV) if !ok { continue } // Note that we always use lower case when using the index so // that compares are not case-sensitive. We don't alter the actual // RRs we add into the extra section, however. target := strings.ToLower(srv.Target) RESOLVE: if _, ok := resolved[target]; ok { continue } resolved[target] = struct{}{} extraRR, ok := index[target] if ok { extra = append(extra, extraRR) if cname, ok := extraRR.(*dns.CNAME); ok { target = strings.ToLower(cname.Target) goto RESOLVE } } } resp.Extra = extra } // dnsBinaryTruncate find the optimal number of records using a fast binary search and return // it in order to return a DNS answer lower than maxSize parameter. func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasExtra bool) int { originalAnswser := resp.Answer startIndex := 0 endIndex := len(resp.Answer) + 1 for endIndex-startIndex > 1 { median := startIndex + (endIndex-startIndex)/2 resp.Answer = originalAnswser[:median] if hasExtra { syncExtra(index, resp) } aLen := resp.Len() if aLen <= maxSize { if maxSize-aLen < 10 { // We are good, increasing will go out of bounds return median } startIndex = median } else { endIndex = median } } return startIndex } // trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit // of DNS responses func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { hasExtra := len(resp.Extra) > 0 // There is some overhead, 65535 does not work maxSize := 65523 // 64k - 12 bytes DNS raw overhead // We avoid some function calls and allocations by only handling the // extra data when necessary. var index map[string]dns.RR // It is not possible to return more than 4k records even with compression // Since we are performing binary search it is not a big deal, but it // improves a bit performance, even with binary search truncateAt := 4096 if req.Question[0].Qtype == dns.TypeSRV { // More than 1024 SRV records do not fit in 64k truncateAt = 1024 } if len(resp.Answer) > truncateAt { resp.Answer = resp.Answer[:truncateAt] } if hasExtra { index = make(map[string]dns.RR, len(resp.Extra)) indexRRs(resp.Extra, index) } truncated := false // This enforces the given limit on 64k, the max limit for DNS messages for len(resp.Answer) > 1 && resp.Len() > maxSize { truncated = true // first try to remove the NS section may be it will truncate enough if len(resp.Ns) != 0 { resp.Ns = []dns.RR{} } // More than 100 bytes, find with a binary search if resp.Len()-maxSize > 100 { bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) resp.Answer = resp.Answer[:bestIndex] } else { resp.Answer = resp.Answer[:len(resp.Answer)-1] } if hasExtra { syncExtra(index, resp) } } return truncated } // trimUDPResponse makes sure a UDP response is not longer than allowed by RFC // 1035. Enforce an arbitrary limit that can be further ratcheted down by // config, and then make sure the response doesn't exceed 512 bytes. Any extra // records will be trimmed along with answers. func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { numAnswers := len(resp.Answer) hasExtra := len(resp.Extra) > 0 maxSize := defaultMaxUDPSize // Update to the maximum edns size if edns := req.IsEdns0(); edns != nil { if size := edns.UDPSize(); size > uint16(maxSize) { maxSize = int(size) } } // We avoid some function calls and allocations by only handling the // extra data when necessary. var index map[string]dns.RR if hasExtra { index = make(map[string]dns.RR, len(resp.Extra)) indexRRs(resp.Extra, index) } // This cuts UDP responses to a useful but limited number of responses. maxAnswers := lib.MinInt(maxUDPAnswerLimit, udpAnswerLimit) compress := resp.Compress if maxSize == defaultMaxUDPSize && numAnswers > maxAnswers { // We disable computation of Len ONLY for non-eDNS request (512 bytes) resp.Compress = false resp.Answer = resp.Answer[:maxAnswers] if hasExtra { syncExtra(index, resp) } } // This enforces the given limit on the number bytes. The default is 512 as // per the RFC, but EDNS0 allows for the user to specify larger sizes. Note // that we temporarily switch to uncompressed so that we limit to a response // that will not exceed 512 bytes uncompressed, which is more conservative and // will allow our responses to be compliant even if some downstream server // uncompresses them. // Even when size is too big for one single record, try to send it anyway // (useful for 512 bytes messages) for len(resp.Answer) > 1 && resp.Len() > maxSize-7 { // first try to remove the NS section may be it will truncate enough if len(resp.Ns) != 0 { resp.Ns = []dns.RR{} } // More than 100 bytes, find with a binary search if resp.Len()-maxSize > 100 { bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) resp.Answer = resp.Answer[:bestIndex] } else { resp.Answer = resp.Answer[:len(resp.Answer)-1] } if hasExtra { syncExtra(index, resp) } } // For 512 non-eDNS responses, while we compute size non-compressed, // we send result compressed resp.Compress = compress return len(resp.Answer) < numAnswers } // trimDNSResponse will trim the response for UDP and TCP func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) { var trimmed bool originalSize := resp.Len() originalNumRecords := len(resp.Answer) if network != "tcp" { trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit) } else { trimmed = trimTCPResponse(req, resp) } // Flag that there are more records to return in the UDP response if trimmed { if cfg.EnableTruncate { resp.Truncated = true } d.logger.Debug("DNS response too large, truncated", "protocol", network, "question", req.Question, "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords), "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize), ) } } // lookupServiceNodes returns nodes with a given service. func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (structs.IndexedCheckServiceNodes, error) { serviceTags := []string{} if lookup.Tag != "" { serviceTags = []string{lookup.Tag} } args := structs.ServiceSpecificRequest{ Connect: lookup.Connect, Ingress: lookup.Ingress, Datacenter: lookup.Datacenter, ServiceName: lookup.Service, ServiceTags: serviceTags, TagFilter: lookup.Tag != "", QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), AllowStale: cfg.AllowStale, MaxAge: cfg.CacheMaxAge, UseCache: cfg.UseCache, MaxStaleDuration: cfg.MaxStale, }, EnterpriseMeta: lookup.EnterpriseMeta, } out, _, err := d.agent.rpcClientHealth.ServiceNodes(context.TODO(), args) if err != nil { return out, err } // Filter out any service nodes due to health checks // We copy the slice to avoid modifying the result if it comes from the cache nodes := make(structs.CheckServiceNodes, len(out.Nodes)) copy(nodes, out.Nodes) out.Nodes = nodes.Filter(cfg.OnlyPassing) return out, nil } // serviceLookup is used to handle a service query func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) error { out, err := d.lookupServiceNodes(cfg, lookup) if err != nil { return fmt.Errorf("rpc request failed: %w", err) } // If we have no nodes, return not found! if len(out.Nodes) == 0 { return errNameNotFound } // Perform a random shuffle out.Nodes.Shuffle() // Determine the TTL ttl, _ := cfg.GetTTLForService(lookup.Service) // Add various responses depending on the request qType := req.Question[0].Qtype if qType == dns.TypeSRV { d.serviceSRVRecords(cfg, lookup.Datacenter, out.Nodes, req, resp, ttl, lookup.MaxRecursionLevel) } else { d.serviceNodeRecords(cfg, lookup.Datacenter, out.Nodes, req, resp, ttl, lookup.MaxRecursionLevel) } if len(resp.Answer) == 0 { return errNoData } return nil } func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { // IsEdns0 returns the EDNS RR if present or nil otherwise edns := req.IsEdns0() if edns == nil { return nil } for _, o := range edns.Option { if subnet, ok := o.(*dns.EDNS0_SUBNET); ok { return subnet } } return nil } // preparedQueryLookup is used to handle a prepared query. func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error { // Execute the prepared query. args := structs.PreparedQueryExecuteRequest{ Datacenter: datacenter, QueryIDOrName: query, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), AllowStale: cfg.AllowStale, MaxAge: cfg.CacheMaxAge, }, // Always pass the local agent through. In the DNS interface, there // is no provision for passing additional query parameters, so we // send the local agent's data through to allow distance sorting // relative to ourself on the server side. Agent: structs.QuerySource{ Datacenter: d.agent.config.Datacenter, Segment: d.agent.config.SegmentName, Node: d.agent.config.NodeName, }, } subnet := ednsSubnetForRequest(req) if subnet != nil { args.Source.Ip = subnet.Address.String() } else { switch v := remoteAddr.(type) { case *net.UDPAddr: args.Source.Ip = v.IP.String() case *net.TCPAddr: args.Source.Ip = v.IP.String() case *net.IPAddr: args.Source.Ip = v.IP.String() } } out, err := d.lookupPreparedQuery(cfg, args) if err != nil { return err } // TODO (slackpad) - What's a safe limit we can set here? It seems like // with dup filtering done at this level we need to get everything to // match the previous behavior. We can optimize by pushing more filtering // into the query execution, but for now I think we need to get the full // response. We could also choose a large arbitrary number that will // likely work in practice, like 10*maxUDPAnswerLimit which should help // reduce bandwidth if there are thousands of nodes available. // Determine the TTL. The parse should never fail since we vet it when // the query is created, but we check anyway. If the query didn't // specify a TTL then we will try to use the agent's service-specific // TTL configs. var ttl time.Duration if out.DNS.TTL != "" { var err error ttl, err = time.ParseDuration(out.DNS.TTL) if err != nil { d.logger.Warn("Failed to parse TTL for prepared query , ignoring", "ttl", out.DNS.TTL, "prepared_query", query, ) } } else { ttl, _ = cfg.GetTTLForService(out.Service) } // If we have no nodes, return not found! if len(out.Nodes) == 0 { return errNameNotFound } // Add various responses depending on the request. qType := req.Question[0].Qtype if qType == dns.TypeSRV { d.serviceSRVRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } else { d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } if len(resp.Answer) == 0 { return errNoData } return nil } func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { var out structs.PreparedQueryExecuteResponse RPC: if cfg.UseCache { raw, m, err := d.agent.cache.Get(context.TODO(), cachetype.PreparedQueryName, &args) if err != nil { return nil, err } reply, ok := raw.(*structs.PreparedQueryExecuteResponse) if !ok { // This should never happen, but we want to protect against panics return nil, err } d.logger.Trace("cache results for prepared query", "cache_hit", m.Hit, "prepared_query", args.QueryIDOrName, ) out = *reply } else { if err := d.agent.RPC("PreparedQuery.Execute", &args, &out); err != nil { return nil, err } } // Verify that request is not too stale, redo the request. if args.AllowStale { if out.LastContact > cfg.MaxStale { args.AllowStale = false d.logger.Warn("Query results too stale, re-requesting") goto RPC } else if out.LastContact > staleCounterThreshold { metrics.IncrCounter([]string{"dns", "stale_queries"}, 1) } } return &out, nil } // serviceNodeRecords is used to add the node records for a service lookup func (d *DNSServer) serviceNodeRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { handled := make(map[string]struct{}) var answerCNAME []dns.RR = nil count := 0 for _, node := range nodes { // Add the node record had_answer := false records, _ := d.nodeServiceRecords(dc, node, req, ttl, cfg, maxRecursionLevel) if len(records) == 0 { continue } // Avoid duplicate entries, possible if a node has // the same service on multiple ports, etc. if _, ok := handled[records[0].String()]; ok { continue } handled[records[0].String()] = struct{}{} switch records[0].(type) { case *dns.CNAME: // keep track of the first CNAME + associated RRs but don't add to the resp.Answer yet // this will only be added if no non-CNAME RRs are found if len(answerCNAME) == 0 { answerCNAME = records } default: resp.Answer = append(resp.Answer, records...) had_answer = true } if had_answer { count++ if count == cfg.ARecordLimit { // We stop only if greater than 0 or we reached the limit return } } } if len(resp.Answer) == 0 && len(answerCNAME) > 0 { resp.Answer = answerCNAME } } func findWeight(node structs.CheckServiceNode) int { // By default, when only_passing is false, warning and passing nodes are returned // Those values will be used if using a client with support while server has no // support for weights weightPassing := 1 weightWarning := 1 if node.Service.Weights != nil { weightPassing = node.Service.Weights.Passing weightWarning = node.Service.Weights.Warning } serviceChecks := make(api.HealthChecks, 0) for _, c := range node.Checks { if c.ServiceName == node.Service.Service || c.ServiceName == "" { healthCheck := &api.HealthCheck{ Node: c.Node, CheckID: string(c.CheckID), Name: c.Name, Status: c.Status, Notes: c.Notes, Output: c.Output, ServiceID: c.ServiceID, ServiceName: c.ServiceName, ServiceTags: c.ServiceTags, } serviceChecks = append(serviceChecks, healthCheck) } } status := serviceChecks.AggregatedStatus() switch status { case api.HealthWarning: return weightWarning case api.HealthPassing: return weightPassing case api.HealthMaint: // Not used in theory return 0 case api.HealthCritical: // Should not happen since already filtered return 0 default: // When non-standard status, return 1 return 1 } } func (d *DNSServer) encodeIPAsFqdn(dc string, ip net.IP) string { ipv4 := ip.To4() if ipv4 != nil { ipStr := hex.EncodeToString(ip) return fmt.Sprintf("%s.addr.%s.%s", ipStr[len(ipStr)-(net.IPv4len*2):], dc, d.domain) } else { return fmt.Sprintf("%s.addr.%s.%s", hex.EncodeToString(ip), dc, d.domain) } } func makeARecord(qType uint16, ip net.IP, ttl time.Duration) dns.RR { var ipRecord dns.RR ipv4 := ip.To4() if ipv4 != nil { if qType == dns.TypeSRV || qType == dns.TypeA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT { ipRecord = &dns.A{ Hdr: dns.RR_Header{ Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, A: ipv4, } } } else if qType == dns.TypeSRV || qType == dns.TypeAAAA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT { ipRecord = &dns.AAAA{ Hdr: dns.RR_Header{ Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, AAAA: ip, } } return ipRecord } // Craft dns records for a node // In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the node IP // Otherwise it will return a IN A record func (d *DNSServer) makeRecordFromNode(node *structs.Node, qType uint16, qName string, ttl time.Duration, maxRecursionLevel int) []dns.RR { addrTranslate := TranslateAddressAcceptDomain if qType == dns.TypeA { addrTranslate |= TranslateAddressAcceptIPv4 } else if qType == dns.TypeAAAA { addrTranslate |= TranslateAddressAcceptIPv6 } else { addrTranslate |= TranslateAddressAcceptAny } addr := d.agent.TranslateAddress(node.Datacenter, node.Address, node.TaggedAddresses, addrTranslate) ip := net.ParseIP(addr) var res []dns.RR if ip == nil { res = append(res, &dns.CNAME{ Hdr: dns.RR_Header{ Name: qName, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, Target: dns.Fqdn(node.Address), }) res = append(res, d.resolveCNAME(d.config.Load().(*dnsConfig), dns.Fqdn(node.Address), maxRecursionLevel)..., ) return res } ipRecord := makeARecord(qType, ip, ttl) if ipRecord == nil { return nil } ipRecord.Header().Name = qName return []dns.RR{ipRecord} } // Craft dns records for a service // In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the node IP // Otherwise it will return a IN A record func (d *DNSServer) makeRecordFromServiceNode(dc string, serviceNode structs.CheckServiceNode, addr net.IP, req *dns.Msg, ttl time.Duration) ([]dns.RR, []dns.RR) { q := req.Question[0] ipRecord := makeARecord(q.Qtype, addr, ttl) if ipRecord == nil { return nil, nil } if q.Qtype == dns.TypeSRV { nodeFQDN := fmt.Sprintf("%s.node.%s.%s", serviceNode.Node.Node, dc, d.domain) answers := []dns.RR{ &dns.SRV{ Hdr: dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, Priority: 1, Weight: uint16(findWeight(serviceNode)), Port: uint16(d.agent.TranslateServicePort(dc, serviceNode.Service.Port, serviceNode.Service.TaggedAddresses)), Target: nodeFQDN, }, } ipRecord.Header().Name = nodeFQDN return answers, []dns.RR{ipRecord} } ipRecord.Header().Name = q.Name return []dns.RR{ipRecord}, nil } // Craft dns records for an IP // In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the IP // Otherwise it will return a IN A record func (d *DNSServer) makeRecordFromIP(dc string, addr net.IP, serviceNode structs.CheckServiceNode, req *dns.Msg, ttl time.Duration) ([]dns.RR, []dns.RR) { q := req.Question[0] ipRecord := makeARecord(q.Qtype, addr, ttl) if ipRecord == nil { return nil, nil } if q.Qtype == dns.TypeSRV { ipFQDN := d.encodeIPAsFqdn(dc, addr) answers := []dns.RR{ &dns.SRV{ Hdr: dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, Priority: 1, Weight: uint16(findWeight(serviceNode)), Port: uint16(d.agent.TranslateServicePort(dc, serviceNode.Service.Port, serviceNode.Service.TaggedAddresses)), Target: ipFQDN, }, } ipRecord.Header().Name = ipFQDN return answers, []dns.RR{ipRecord} } ipRecord.Header().Name = q.Name return []dns.RR{ipRecord}, nil } // Craft dns records for an FQDN // In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the IP // Otherwise it will return a CNAME and a IN A record func (d *DNSServer) makeRecordFromFQDN(dc string, fqdn string, serviceNode structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) { edns := req.IsEdns0() != nil q := req.Question[0] more := d.resolveCNAME(cfg, dns.Fqdn(fqdn), maxRecursionLevel) var additional []dns.RR extra := 0 MORE_REC: for _, rr := range more { switch rr.Header().Rrtype { case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA: // set the TTL manually rr.Header().Ttl = uint32(ttl / time.Second) additional = append(additional, rr) extra++ if extra == maxRecurseRecords && !edns { break MORE_REC } } } if q.Qtype == dns.TypeSRV { answers := []dns.RR{ &dns.SRV{ Hdr: dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, Priority: 1, Weight: uint16(findWeight(serviceNode)), Port: uint16(d.agent.TranslateServicePort(dc, serviceNode.Service.Port, serviceNode.Service.TaggedAddresses)), Target: dns.Fqdn(fqdn), }, } return answers, additional } answers := []dns.RR{ &dns.CNAME{ Hdr: dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, Target: dns.Fqdn(fqdn), }} answers = append(answers, additional...) return answers, nil } func (d *DNSServer) nodeServiceRecords(dc string, node structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) { addrTranslate := TranslateAddressAcceptDomain if req.Question[0].Qtype == dns.TypeA { addrTranslate |= TranslateAddressAcceptIPv4 } else if req.Question[0].Qtype == dns.TypeAAAA { addrTranslate |= TranslateAddressAcceptIPv6 } else { addrTranslate |= TranslateAddressAcceptAny } serviceAddr := d.agent.TranslateServiceAddress(dc, node.Service.Address, node.Service.TaggedAddresses, addrTranslate) nodeAddr := d.agent.TranslateAddress(node.Node.Datacenter, node.Node.Address, node.Node.TaggedAddresses, addrTranslate) if serviceAddr == "" && nodeAddr == "" { return nil, nil } nodeIPAddr := net.ParseIP(nodeAddr) serviceIPAddr := net.ParseIP(serviceAddr) // There is no service address and the node address is an IP if serviceAddr == "" && nodeIPAddr != nil { if node.Node.Address != nodeAddr { // Do not CNAME node address in case of WAN address return d.makeRecordFromIP(dc, nodeIPAddr, node, req, ttl) } return d.makeRecordFromServiceNode(dc, node, nodeIPAddr, req, ttl) } // There is no service address and the node address is a FQDN (external service) if serviceAddr == "" { return d.makeRecordFromFQDN(dc, nodeAddr, node, req, ttl, cfg, maxRecursionLevel) } // The service address is an IP if serviceIPAddr != nil { return d.makeRecordFromIP(dc, serviceIPAddr, node, req, ttl) } // If the service address is a CNAME for the service we are looking // for then use the node address. if dns.Fqdn(serviceAddr) == req.Question[0].Name && nodeIPAddr != nil { return d.makeRecordFromServiceNode(dc, node, nodeIPAddr, req, ttl) } // The service address is a FQDN (external service) return d.makeRecordFromFQDN(dc, serviceAddr, node, req, ttl, cfg, maxRecursionLevel) } func (d *DNSServer) generateMeta(qName string, node *structs.Node, ttl time.Duration) []dns.RR { extra := make([]dns.RR, 0, len(node.Meta)) for key, value := range node.Meta { txt := value if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") { txt = encodeKVasRFC1464(key, value) } extra = append(extra, &dns.TXT{ Hdr: dns.RR_Header{ Name: qName, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(ttl / time.Second), }, Txt: []string{txt}, }) } return extra } // serviceARecords is used to add the SRV records for a service lookup func (d *DNSServer) serviceSRVRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { handled := make(map[string]struct{}) for _, node := range nodes { // Avoid duplicate entries, possible if a node has // the same service the same port, etc. serviceAddress := d.agent.TranslateServiceAddress(dc, node.Service.Address, node.Service.TaggedAddresses, TranslateAddressAcceptAny) servicePort := d.agent.TranslateServicePort(dc, node.Service.Port, node.Service.TaggedAddresses) tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort) if _, ok := handled[tuple]; ok { continue } handled[tuple] = struct{}{} answers, extra := d.nodeServiceRecords(dc, node, req, ttl, cfg, maxRecursionLevel) resp.Answer = append(resp.Answer, answers...) resp.Extra = append(resp.Extra, extra...) if cfg.NodeMetaTXT { resp.Extra = append(resp.Extra, d.generateMeta(fmt.Sprintf("%s.node.%s.%s", node.Node.Node, dc, d.domain), node.Node, ttl)...) } } } // handleRecurse is used to handle recursive DNS queries func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { cfg := d.config.Load().(*dnsConfig) q := req.Question[0] network := "udp" defer func(s time.Time) { d.logger.Debug("request served from client", "question", q, "network", network, "latency", time.Since(s).String(), "client", resp.RemoteAddr().String(), "client_network", resp.RemoteAddr().Network(), ) }(time.Now()) // Switch to TCP if the client is if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok { network = "tcp" } // Recursively resolve c := &dns.Client{Net: network, Timeout: cfg.RecursorTimeout} var r *dns.Msg var rtt time.Duration var err error for _, recursor := range cfg.Recursors { r, rtt, err = c.Exchange(req, recursor) // Check if the response is valid and has the desired Response code if r != nil && (r.Rcode != dns.RcodeSuccess && r.Rcode != dns.RcodeNameError) { d.logger.Debug("recurse failed for question", "question", q, "rtt", rtt, "recursor", recursor, "rcode", dns.RcodeToString[r.Rcode], ) // If we still have recursors to forward the query to, // we move forward onto the next one else the loop ends continue } else if err == nil || (r != nil && r.Truncated) { // Compress the response; we don't know if the incoming // response was compressed or not, so by not compressing // we might generate an invalid packet on the way out. r.Compress = !cfg.DisableCompression // Forward the response d.logger.Debug("recurse succeeded for question", "question", q, "rtt", rtt, "recursor", recursor, ) if err := resp.WriteMsg(r); err != nil { d.logger.Warn("failed to respond", "error", err) } return } d.logger.Error("recurse failed", "error", err) } // If all resolvers fail, return a SERVFAIL message d.logger.Error("all resolvers failed for question from client", "question", q, "client", resp.RemoteAddr().String(), "client_network", resp.RemoteAddr().Network(), ) m := &dns.Msg{} m.SetReply(req) m.Compress = !cfg.DisableCompression m.RecursionAvailable = true m.SetRcode(req, dns.RcodeServerFailure) if edns := req.IsEdns0(); edns != nil { setEDNS(req, m, true) } resp.WriteMsg(m) } // resolveCNAME is used to recursively resolve CNAME records func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel int) []dns.RR { // If the CNAME record points to a Consul address, resolve it internally // Convert query to lowercase because DNS is case insensitive; d.domain and // d.altDomain are already converted if ln := strings.ToLower(name); strings.HasSuffix(ln, "."+d.domain) || strings.HasSuffix(ln, "."+d.altDomain) { if maxRecursionLevel < 1 { d.logger.Error("Infinite recursion detected for name, won't perform any CNAME resolution.", "name", name) return nil } req := &dns.Msg{} resp := &dns.Msg{} req.SetQuestion(name, dns.TypeANY) // TODO: handle error response d.dispatch(nil, req, resp, maxRecursionLevel-1) return resp.Answer } // Do nothing if we don't have a recursor if len(cfg.Recursors) == 0 { return nil } // Ask for any A records m := new(dns.Msg) m.SetQuestion(name, dns.TypeA) // Make a DNS lookup request c := &dns.Client{Net: "udp", Timeout: cfg.RecursorTimeout} var r *dns.Msg var rtt time.Duration var err error for _, recursor := range cfg.Recursors { r, rtt, err = c.Exchange(m, recursor) if err == nil { d.logger.Debug("cname recurse RTT for name", "name", name, "rtt", rtt, ) return r.Answer } d.logger.Error("cname recurse failed for name", "name", name, "error", err, ) } d.logger.Error("all resolvers failed for name", "name", name) return nil }