diff --git a/command/agent/command.go b/command/agent/command.go index 09ba1577f..89e67d446 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -52,6 +52,7 @@ func (c *Command) readConfig() *Config { "address to bind RPC listener to") cmdFlags.StringVar(&cmdConfig.DataDir, "data", "", "path to the data directory") cmdFlags.StringVar(&cmdConfig.Datacenter, "dc", "", "node datacenter") + cmdFlags.StringVar(&cmdConfig.DNSRecursor, "recursor", "", "address of dns recursor") cmdFlags.BoolVar(&cmdConfig.Server, "server", false, "run agent as server") cmdFlags.BoolVar(&cmdConfig.Bootstrap, "bootstrap", false, "enable server bootstrap mode") if err := cmdFlags.Parse(c.args); err != nil { @@ -148,7 +149,8 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log } if config.DNSAddr != "" { - server, err := NewDNSServer(agent, logOutput, config.Domain, config.DNSAddr) + server, err := NewDNSServer(agent, logOutput, config.Domain, + config.DNSAddr, config.DNSRecursor) if err != nil { agent.Shutdown() c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err)) diff --git a/command/agent/config.go b/command/agent/config.go index 08c16e27f..1b981c54e 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -30,6 +30,10 @@ type Config struct { // DNSAddr is the address of the DNS server for the agent DNSAddr string + // DNSRecursor can be set to allow the DNS server to recursively + // resolve non-consul domains + DNSRecursor string + // Domain is the DNS domain for the records. Defaults to "consul." Domain string @@ -154,6 +158,9 @@ func MergeConfig(a, b *Config) *Config { if b.DNSAddr != "" { result.DNSAddr = b.DNSAddr } + if b.DNSRecursor != "" { + result.DNSRecursor = b.DNSRecursor + } if b.Domain != "" { result.Domain = b.Domain } diff --git a/command/agent/dns.go b/command/agent/dns.go index 9a8eddbf0..2828588b6 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -19,36 +19,45 @@ const ( // DNSServer is used to wrap an Agent and expose various // service discovery endpoints using a DNS interface. type DNSServer struct { - agent *Agent - dnsHandler *dns.ServeMux - dnsServer *dns.Server - domain string - logger *log.Logger + agent *Agent + dnsHandler *dns.ServeMux + dnsServer *dns.Server + dnsServerTCP *dns.Server + domain string + recursor string + logger *log.Logger } // NewDNSServer starts a new DNS server to provide an agent interface -func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSServer, error) { +func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) { // Make sure domain is FQDN domain = dns.Fqdn(domain) // Construct the DNS components mux := dns.NewServeMux() - // Setup the server + // Setup the servers server := &dns.Server{ Addr: bind, Net: "udp", Handler: mux, UDPSize: 65535, } + serverTCP := &dns.Server{ + Addr: bind, + Net: "tcp", + Handler: mux, + } // Create the server srv := &DNSServer{ - agent: agent, - dnsHandler: mux, - dnsServer: server, - domain: domain, - logger: log.New(logOutput, "", log.LstdFlags), + agent: agent, + dnsHandler: mux, + dnsServer: server, + dnsServerTCP: serverTCP, + domain: domain, + recursor: recursor, + logger: log.New(logOutput, "", log.LstdFlags), } // Register mux handlers, always handle "consul." @@ -56,15 +65,25 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS if domain != consulDomain { mux.HandleFunc(consulDomain, srv.handleTest) } + if recursor != "" { + mux.HandleFunc(".", srv.handleRecurse) + } - // Async start the DNS Server, handle a potential error + // Async start the DNS Servers, handle a potential error errCh := make(chan error, 1) go func() { err := server.ListenAndServe() - srv.logger.Printf("[ERR] dns: error starting server: %v", err) + srv.logger.Printf("[ERR] dns: error starting udp server: %v", err) errCh <- err }() + errChTCP := make(chan error, 1) + go func() { + err := serverTCP.ListenAndServe() + srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err) + errChTCP <- err + }() + // Check the server is running, do a test lookup checkCh := make(chan error, 1) go func() { @@ -93,6 +112,8 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS select { case e := <-errCh: return srv, e + case e := <-errChTCP: + return srv, e case e := <-checkCh: return srv, e case <-time.After(time.Second): @@ -119,10 +140,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.SetReply(req) m.Authoritative = true d.addSOA(d.domain, m) - defer resp.WriteMsg(m) // Dispatch the correct handler d.dispatch(req, m) + + // Write out the complete response + if err := resp.WriteMsg(m); err != nil { + d.logger.Printf("[WARN] dns: failed to respond: %v", err) + } } // handleTest is used to handle DNS queries in the ".consul." domain @@ -147,7 +172,9 @@ func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) { txt := &dns.TXT{header, []string{"ok"}} m.Answer = append(m.Answer, txt) d.addSOA(consulDomain, m) - resp.WriteMsg(m) + if err := resp.WriteMsg(m); err != nil { + d.logger.Printf("[WARN] dns: failed to respond: %v", err) + } } // addSOA is used to add an SOA record to a message for the given domain @@ -353,3 +380,40 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.ServiceNodes, req resp.Extra = append(resp.Extra, aRec) } } + +// handleRecurse is used to handle recursive DNS queries +func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { + q := req.Question[0] + network := "udp" + defer func(s time.Time) { + d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v)", q, network, time.Now().Sub(s)) + }(time.Now()) + + // Switch to TCP if the client is + if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok { + network = "tcp" + } + + // Recursively resolve + c := &dns.Client{Net: network} + r, rtt, err := c.Exchange(req, d.recursor) + + // On failure, return a SERVFAIL message + if err != nil { + d.logger.Printf("[ERR] dns: recurse failed: %v", err) + m := &dns.Msg{} + m.SetReply(req) + m.SetRcode(req, dns.RcodeServerFailure) + resp.WriteMsg(m) + return + } + d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt) + + // Seems to be a bug that forcing compression fixes... + r.Compress = true + + // Forward the response + if err := resp.WriteMsg(r); err != nil { + d.logger.Printf("[WARN] dns: failed to respond: %v", err) + } +} diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index 771f8dd52..82585bc02 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -11,7 +11,8 @@ import ( func makeDNSServer(t *testing.T) (string, *DNSServer) { conf := nextConfig() dir, agent := makeAgent(t, conf) - server, err := NewDNSServer(agent, agent.logOutput, conf.Domain, conf.DNSAddr) + server, err := NewDNSServer(agent, agent.logOutput, conf.Domain, + conf.DNSAddr, "8.8.8.8:53") if err != nil { t.Fatalf("err: %v", err) } @@ -173,3 +174,26 @@ func TestDNS_ServiceLookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Extra[0]) } } + +func TestDNS_Recurse(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + m := new(dns.Msg) + m.SetQuestion("apple.com.", dns.TypeANY) + + c := new(dns.Client) + c.Net = "tcp" + in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) == 0 { + t.Fatalf("Bad: %#v", in) + } + if in.Rcode != dns.RcodeSuccess { + t.Fatalf("Bad: %#v", in) + } +}