Support DNS recursion and TCP queries
This commit is contained in:
parent
43bf345d32
commit
e2e844a70b
|
@ -52,6 +52,7 @@ func (c *Command) readConfig() *Config {
|
|||
"address to bind RPC listener to")
|
||||
cmdFlags.StringVar(&cmdConfig.DataDir, "data", "", "path to the data directory")
|
||||
cmdFlags.StringVar(&cmdConfig.Datacenter, "dc", "", "node datacenter")
|
||||
cmdFlags.StringVar(&cmdConfig.DNSRecursor, "recursor", "", "address of dns recursor")
|
||||
cmdFlags.BoolVar(&cmdConfig.Server, "server", false, "run agent as server")
|
||||
cmdFlags.BoolVar(&cmdConfig.Bootstrap, "bootstrap", false, "enable server bootstrap mode")
|
||||
if err := cmdFlags.Parse(c.args); err != nil {
|
||||
|
@ -148,7 +149,8 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
|
|||
}
|
||||
|
||||
if config.DNSAddr != "" {
|
||||
server, err := NewDNSServer(agent, logOutput, config.Domain, config.DNSAddr)
|
||||
server, err := NewDNSServer(agent, logOutput, config.Domain,
|
||||
config.DNSAddr, config.DNSRecursor)
|
||||
if err != nil {
|
||||
agent.Shutdown()
|
||||
c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err))
|
||||
|
|
|
@ -30,6 +30,10 @@ type Config struct {
|
|||
// DNSAddr is the address of the DNS server for the agent
|
||||
DNSAddr string
|
||||
|
||||
// DNSRecursor can be set to allow the DNS server to recursively
|
||||
// resolve non-consul domains
|
||||
DNSRecursor string
|
||||
|
||||
// Domain is the DNS domain for the records. Defaults to "consul."
|
||||
Domain string
|
||||
|
||||
|
@ -154,6 +158,9 @@ func MergeConfig(a, b *Config) *Config {
|
|||
if b.DNSAddr != "" {
|
||||
result.DNSAddr = b.DNSAddr
|
||||
}
|
||||
if b.DNSRecursor != "" {
|
||||
result.DNSRecursor = b.DNSRecursor
|
||||
}
|
||||
if b.Domain != "" {
|
||||
result.Domain = b.Domain
|
||||
}
|
||||
|
|
|
@ -19,36 +19,45 @@ const (
|
|||
// DNSServer is used to wrap an Agent and expose various
|
||||
// service discovery endpoints using a DNS interface.
|
||||
type DNSServer struct {
|
||||
agent *Agent
|
||||
dnsHandler *dns.ServeMux
|
||||
dnsServer *dns.Server
|
||||
domain string
|
||||
logger *log.Logger
|
||||
agent *Agent
|
||||
dnsHandler *dns.ServeMux
|
||||
dnsServer *dns.Server
|
||||
dnsServerTCP *dns.Server
|
||||
domain string
|
||||
recursor string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewDNSServer starts a new DNS server to provide an agent interface
|
||||
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSServer, error) {
|
||||
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) {
|
||||
// Make sure domain is FQDN
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
// Construct the DNS components
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
// Setup the server
|
||||
// Setup the servers
|
||||
server := &dns.Server{
|
||||
Addr: bind,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
serverTCP := &dns.Server{
|
||||
Addr: bind,
|
||||
Net: "tcp",
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Create the server
|
||||
srv := &DNSServer{
|
||||
agent: agent,
|
||||
dnsHandler: mux,
|
||||
dnsServer: server,
|
||||
domain: domain,
|
||||
logger: log.New(logOutput, "", log.LstdFlags),
|
||||
agent: agent,
|
||||
dnsHandler: mux,
|
||||
dnsServer: server,
|
||||
dnsServerTCP: serverTCP,
|
||||
domain: domain,
|
||||
recursor: recursor,
|
||||
logger: log.New(logOutput, "", log.LstdFlags),
|
||||
}
|
||||
|
||||
// Register mux handlers, always handle "consul."
|
||||
|
@ -56,15 +65,25 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
|
|||
if domain != consulDomain {
|
||||
mux.HandleFunc(consulDomain, srv.handleTest)
|
||||
}
|
||||
if recursor != "" {
|
||||
mux.HandleFunc(".", srv.handleRecurse)
|
||||
}
|
||||
|
||||
// Async start the DNS Server, handle a potential error
|
||||
// Async start the DNS Servers, handle a potential error
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
srv.logger.Printf("[ERR] dns: error starting server: %v", err)
|
||||
srv.logger.Printf("[ERR] dns: error starting udp server: %v", err)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
errChTCP := make(chan error, 1)
|
||||
go func() {
|
||||
err := serverTCP.ListenAndServe()
|
||||
srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err)
|
||||
errChTCP <- err
|
||||
}()
|
||||
|
||||
// Check the server is running, do a test lookup
|
||||
checkCh := make(chan error, 1)
|
||||
go func() {
|
||||
|
@ -93,6 +112,8 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
|
|||
select {
|
||||
case e := <-errCh:
|
||||
return srv, e
|
||||
case e := <-errChTCP:
|
||||
return srv, e
|
||||
case e := <-checkCh:
|
||||
return srv, e
|
||||
case <-time.After(time.Second):
|
||||
|
@ -119,10 +140,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
|
|||
m.SetReply(req)
|
||||
m.Authoritative = true
|
||||
d.addSOA(d.domain, m)
|
||||
defer resp.WriteMsg(m)
|
||||
|
||||
// Dispatch the correct handler
|
||||
d.dispatch(req, m)
|
||||
|
||||
// Write out the complete response
|
||||
if err := resp.WriteMsg(m); err != nil {
|
||||
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleTest is used to handle DNS queries in the ".consul." domain
|
||||
|
@ -147,7 +172,9 @@ func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) {
|
|||
txt := &dns.TXT{header, []string{"ok"}}
|
||||
m.Answer = append(m.Answer, txt)
|
||||
d.addSOA(consulDomain, m)
|
||||
resp.WriteMsg(m)
|
||||
if err := resp.WriteMsg(m); err != nil {
|
||||
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// addSOA is used to add an SOA record to a message for the given domain
|
||||
|
@ -353,3 +380,40 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.ServiceNodes, req
|
|||
resp.Extra = append(resp.Extra, aRec)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRecurse is used to handle recursive DNS queries
|
||||
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
|
||||
q := req.Question[0]
|
||||
network := "udp"
|
||||
defer func(s time.Time) {
|
||||
d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v)", q, network, time.Now().Sub(s))
|
||||
}(time.Now())
|
||||
|
||||
// Switch to TCP if the client is
|
||||
if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
|
||||
network = "tcp"
|
||||
}
|
||||
|
||||
// Recursively resolve
|
||||
c := &dns.Client{Net: network}
|
||||
r, rtt, err := c.Exchange(req, d.recursor)
|
||||
|
||||
// On failure, return a SERVFAIL message
|
||||
if err != nil {
|
||||
d.logger.Printf("[ERR] dns: recurse failed: %v", err)
|
||||
m := &dns.Msg{}
|
||||
m.SetReply(req)
|
||||
m.SetRcode(req, dns.RcodeServerFailure)
|
||||
resp.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
|
||||
|
||||
// Seems to be a bug that forcing compression fixes...
|
||||
r.Compress = true
|
||||
|
||||
// Forward the response
|
||||
if err := resp.WriteMsg(r); err != nil {
|
||||
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,8 @@ import (
|
|||
func makeDNSServer(t *testing.T) (string, *DNSServer) {
|
||||
conf := nextConfig()
|
||||
dir, agent := makeAgent(t, conf)
|
||||
server, err := NewDNSServer(agent, agent.logOutput, conf.Domain, conf.DNSAddr)
|
||||
server, err := NewDNSServer(agent, agent.logOutput, conf.Domain,
|
||||
conf.DNSAddr, "8.8.8.8:53")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -173,3 +174,26 @@ func TestDNS_ServiceLookup(t *testing.T) {
|
|||
t.Fatalf("Bad: %#v", in.Extra[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Recurse(t *testing.T) {
|
||||
dir, srv := makeDNSServer(t)
|
||||
defer os.RemoveAll(dir)
|
||||
defer srv.agent.Shutdown()
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("apple.com.", dns.TypeANY)
|
||||
|
||||
c := new(dns.Client)
|
||||
c.Net = "tcp"
|
||||
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(in.Answer) == 0 {
|
||||
t.Fatalf("Bad: %#v", in)
|
||||
}
|
||||
if in.Rcode != dns.RcodeSuccess {
|
||||
t.Fatalf("Bad: %#v", in)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue