Adds DNS support for prepared queries (needs tests).

This commit is contained in:
James Phillips 2015-11-12 09:28:05 -08:00
parent 38daaea503
commit f9c91479ef

View file

@ -289,7 +289,7 @@ func (d *DNSServer) dispatch(network string, req, resp *dns.Msg) {
// Split into the label parts
labels := dns.SplitDomainName(qName)
// The last label is either "node", "service" or a datacenter name
// The last label is either "node", "service", "query", or a datacenter name
PARSE:
n := len(labels)
if n == 0 {
@ -336,6 +336,14 @@ PARSE:
node := strings.Join(labels[:n-1], ".")
d.nodeLookup(network, datacenter, node, req, resp)
case "query":
if len(labels) == 1 {
goto INVALID
}
// Allow a "." in the query name, just join all the parts.
query := strings.Join(labels[:n-1], ".")
d.preparedQueryLookup(network, datacenter, query, req, resp)
default:
// Store the DC, and re-parse
datacenter = labels[n-1]
@ -535,6 +543,84 @@ RPC:
}
}
// preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req, resp *dns.Msg) {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
QueryIDOrName: query,
QueryOptions: structs.QueryOptions{
Token: d.agent.config.ACLToken,
AllowStale: d.config.AllowStale,
},
}
// If the network is not TCP then we just get enough responses to
// tell that things got truncated. This saves bandwidth since we
// will trim the list anyway.
if network != "tcp" {
args.Limit = maxServiceResponses + 1
}
endpoint := d.agent.getEndpoint(preparedQueryEndpoint)
var out structs.PreparedQueryExecuteResponse
RPC:
if err := d.agent.RPC(endpoint+".Execute", &args, &out); err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
return
}
// Verify that request is not too stale, redo the request.
if args.AllowStale && out.LastContact > d.config.MaxStale {
args.AllowStale = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
goto RPC
}
// Determine the TTL. The parse should never fail since we vet it when
// the query is created, but we check anyway.
var ttl time.Duration
if out.DNS.TTL != "" {
var err error
ttl, err = time.ParseDuration(out.DNS.TTL)
if err != nil {
d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query)
}
}
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
d.addSOA(d.domain, resp)
resp.SetRcode(req, dns.RcodeNameError)
return
}
// Add various responses depending on the request.
qType := req.Question[0].Qtype
d.serviceNodeRecords(out.Nodes, req, resp, ttl)
if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
}
// If the network is not TCP, restrict the number of responses.
if network != "tcp" && len(resp.Answer) > maxServiceResponses {
resp.Answer = resp.Answer[:maxServiceResponses]
// Flag that there are more records to return in the UDP
// response.
if d.config.EnableTruncate {
resp.Truncated = true
}
}
// If the answer is empty, return not found.
if len(resp.Answer) == 0 {
d.addSOA(d.domain, resp)
return
}
}
// serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
qName := req.Question[0].Name