diff --git a/command/agent/dns.go b/command/agent/dns.go index d2152f816..4f7ef33f2 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -289,7 +289,7 @@ func (d *DNSServer) dispatch(network string, req, resp *dns.Msg) { // Split into the label parts labels := dns.SplitDomainName(qName) - // The last label is either "node", "service" or a datacenter name + // The last label is either "node", "service", "query", or a datacenter name PARSE: n := len(labels) if n == 0 { @@ -336,6 +336,14 @@ PARSE: node := strings.Join(labels[:n-1], ".") d.nodeLookup(network, datacenter, node, req, resp) + case "query": + if len(labels) == 1 { + goto INVALID + } + // Allow a "." in the query name, just join all the parts. + query := strings.Join(labels[:n-1], ".") + d.preparedQueryLookup(network, datacenter, query, req, resp) + default: // Store the DC, and re-parse datacenter = labels[n-1] @@ -535,6 +543,84 @@ RPC: } } +// preparedQueryLookup is used to handle a prepared query. +func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req, resp *dns.Msg) { + // Execute the prepared query. + args := structs.PreparedQueryExecuteRequest{ + Datacenter: datacenter, + QueryIDOrName: query, + QueryOptions: structs.QueryOptions{ + Token: d.agent.config.ACLToken, + AllowStale: d.config.AllowStale, + }, + } + + // If the network is not TCP then we just get enough responses to + // tell that things got truncated. This saves bandwidth since we + // will trim the list anyway. + if network != "tcp" { + args.Limit = maxServiceResponses + 1 + } + + endpoint := d.agent.getEndpoint(preparedQueryEndpoint) + var out structs.PreparedQueryExecuteResponse +RPC: + if err := d.agent.RPC(endpoint+".Execute", &args, &out); err != nil { + d.logger.Printf("[ERR] dns: rpc error: %v", err) + resp.SetRcode(req, dns.RcodeServerFailure) + return + } + + // Verify that request is not too stale, redo the request. + if args.AllowStale && out.LastContact > d.config.MaxStale { + args.AllowStale = false + d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") + goto RPC + } + + // Determine the TTL. The parse should never fail since we vet it when + // the query is created, but we check anyway. + var ttl time.Duration + if out.DNS.TTL != "" { + var err error + ttl, err = time.ParseDuration(out.DNS.TTL) + if err != nil { + d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query) + } + } + + // If we have no nodes, return not found! + if len(out.Nodes) == 0 { + d.addSOA(d.domain, resp) + resp.SetRcode(req, dns.RcodeNameError) + return + } + + // Add various responses depending on the request. + qType := req.Question[0].Qtype + d.serviceNodeRecords(out.Nodes, req, resp, ttl) + if qType == dns.TypeSRV { + d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl) + } + + // If the network is not TCP, restrict the number of responses. + if network != "tcp" && len(resp.Answer) > maxServiceResponses { + resp.Answer = resp.Answer[:maxServiceResponses] + + // Flag that there are more records to return in the UDP + // response. + if d.config.EnableTruncate { + resp.Truncated = true + } + } + + // If the answer is empty, return not found. + if len(resp.Answer) == 0 { + d.addSOA(d.domain, resp) + return + } +} + // serviceNodeRecords is used to add the node records for a service lookup func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) { qName := req.Question[0].Name