Support DNS recursion and TCP queries
This commit is contained in:
parent
43bf345d32
commit
e2e844a70b
|
@ -52,6 +52,7 @@ func (c *Command) readConfig() *Config {
|
||||||
"address to bind RPC listener to")
|
"address to bind RPC listener to")
|
||||||
cmdFlags.StringVar(&cmdConfig.DataDir, "data", "", "path to the data directory")
|
cmdFlags.StringVar(&cmdConfig.DataDir, "data", "", "path to the data directory")
|
||||||
cmdFlags.StringVar(&cmdConfig.Datacenter, "dc", "", "node datacenter")
|
cmdFlags.StringVar(&cmdConfig.Datacenter, "dc", "", "node datacenter")
|
||||||
|
cmdFlags.StringVar(&cmdConfig.DNSRecursor, "recursor", "", "address of dns recursor")
|
||||||
cmdFlags.BoolVar(&cmdConfig.Server, "server", false, "run agent as server")
|
cmdFlags.BoolVar(&cmdConfig.Server, "server", false, "run agent as server")
|
||||||
cmdFlags.BoolVar(&cmdConfig.Bootstrap, "bootstrap", false, "enable server bootstrap mode")
|
cmdFlags.BoolVar(&cmdConfig.Bootstrap, "bootstrap", false, "enable server bootstrap mode")
|
||||||
if err := cmdFlags.Parse(c.args); err != nil {
|
if err := cmdFlags.Parse(c.args); err != nil {
|
||||||
|
@ -148,7 +149,8 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.DNSAddr != "" {
|
if config.DNSAddr != "" {
|
||||||
server, err := NewDNSServer(agent, logOutput, config.Domain, config.DNSAddr)
|
server, err := NewDNSServer(agent, logOutput, config.Domain,
|
||||||
|
config.DNSAddr, config.DNSRecursor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
agent.Shutdown()
|
agent.Shutdown()
|
||||||
c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err))
|
c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err))
|
||||||
|
|
|
@ -30,6 +30,10 @@ type Config struct {
|
||||||
// DNSAddr is the address of the DNS server for the agent
|
// DNSAddr is the address of the DNS server for the agent
|
||||||
DNSAddr string
|
DNSAddr string
|
||||||
|
|
||||||
|
// DNSRecursor can be set to allow the DNS server to recursively
|
||||||
|
// resolve non-consul domains
|
||||||
|
DNSRecursor string
|
||||||
|
|
||||||
// Domain is the DNS domain for the records. Defaults to "consul."
|
// Domain is the DNS domain for the records. Defaults to "consul."
|
||||||
Domain string
|
Domain string
|
||||||
|
|
||||||
|
@ -154,6 +158,9 @@ func MergeConfig(a, b *Config) *Config {
|
||||||
if b.DNSAddr != "" {
|
if b.DNSAddr != "" {
|
||||||
result.DNSAddr = b.DNSAddr
|
result.DNSAddr = b.DNSAddr
|
||||||
}
|
}
|
||||||
|
if b.DNSRecursor != "" {
|
||||||
|
result.DNSRecursor = b.DNSRecursor
|
||||||
|
}
|
||||||
if b.Domain != "" {
|
if b.Domain != "" {
|
||||||
result.Domain = b.Domain
|
result.Domain = b.Domain
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,36 +19,45 @@ const (
|
||||||
// DNSServer is used to wrap an Agent and expose various
|
// DNSServer is used to wrap an Agent and expose various
|
||||||
// service discovery endpoints using a DNS interface.
|
// service discovery endpoints using a DNS interface.
|
||||||
type DNSServer struct {
|
type DNSServer struct {
|
||||||
agent *Agent
|
agent *Agent
|
||||||
dnsHandler *dns.ServeMux
|
dnsHandler *dns.ServeMux
|
||||||
dnsServer *dns.Server
|
dnsServer *dns.Server
|
||||||
domain string
|
dnsServerTCP *dns.Server
|
||||||
logger *log.Logger
|
domain string
|
||||||
|
recursor string
|
||||||
|
logger *log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSServer starts a new DNS server to provide an agent interface
|
// NewDNSServer starts a new DNS server to provide an agent interface
|
||||||
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSServer, error) {
|
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) {
|
||||||
// Make sure domain is FQDN
|
// Make sure domain is FQDN
|
||||||
domain = dns.Fqdn(domain)
|
domain = dns.Fqdn(domain)
|
||||||
|
|
||||||
// Construct the DNS components
|
// Construct the DNS components
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
// Setup the server
|
// Setup the servers
|
||||||
server := &dns.Server{
|
server := &dns.Server{
|
||||||
Addr: bind,
|
Addr: bind,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
UDPSize: 65535,
|
UDPSize: 65535,
|
||||||
}
|
}
|
||||||
|
serverTCP := &dns.Server{
|
||||||
|
Addr: bind,
|
||||||
|
Net: "tcp",
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
|
||||||
// Create the server
|
// Create the server
|
||||||
srv := &DNSServer{
|
srv := &DNSServer{
|
||||||
agent: agent,
|
agent: agent,
|
||||||
dnsHandler: mux,
|
dnsHandler: mux,
|
||||||
dnsServer: server,
|
dnsServer: server,
|
||||||
domain: domain,
|
dnsServerTCP: serverTCP,
|
||||||
logger: log.New(logOutput, "", log.LstdFlags),
|
domain: domain,
|
||||||
|
recursor: recursor,
|
||||||
|
logger: log.New(logOutput, "", log.LstdFlags),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register mux handlers, always handle "consul."
|
// Register mux handlers, always handle "consul."
|
||||||
|
@ -56,15 +65,25 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
|
||||||
if domain != consulDomain {
|
if domain != consulDomain {
|
||||||
mux.HandleFunc(consulDomain, srv.handleTest)
|
mux.HandleFunc(consulDomain, srv.handleTest)
|
||||||
}
|
}
|
||||||
|
if recursor != "" {
|
||||||
|
mux.HandleFunc(".", srv.handleRecurse)
|
||||||
|
}
|
||||||
|
|
||||||
// Async start the DNS Server, handle a potential error
|
// Async start the DNS Servers, handle a potential error
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
srv.logger.Printf("[ERR] dns: error starting server: %v", err)
|
srv.logger.Printf("[ERR] dns: error starting udp server: %v", err)
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
errChTCP := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := serverTCP.ListenAndServe()
|
||||||
|
srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err)
|
||||||
|
errChTCP <- err
|
||||||
|
}()
|
||||||
|
|
||||||
// Check the server is running, do a test lookup
|
// Check the server is running, do a test lookup
|
||||||
checkCh := make(chan error, 1)
|
checkCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -93,6 +112,8 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
|
||||||
select {
|
select {
|
||||||
case e := <-errCh:
|
case e := <-errCh:
|
||||||
return srv, e
|
return srv, e
|
||||||
|
case e := <-errChTCP:
|
||||||
|
return srv, e
|
||||||
case e := <-checkCh:
|
case e := <-checkCh:
|
||||||
return srv, e
|
return srv, e
|
||||||
case <-time.After(time.Second):
|
case <-time.After(time.Second):
|
||||||
|
@ -119,10 +140,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
|
||||||
m.SetReply(req)
|
m.SetReply(req)
|
||||||
m.Authoritative = true
|
m.Authoritative = true
|
||||||
d.addSOA(d.domain, m)
|
d.addSOA(d.domain, m)
|
||||||
defer resp.WriteMsg(m)
|
|
||||||
|
|
||||||
// Dispatch the correct handler
|
// Dispatch the correct handler
|
||||||
d.dispatch(req, m)
|
d.dispatch(req, m)
|
||||||
|
|
||||||
|
// Write out the complete response
|
||||||
|
if err := resp.WriteMsg(m); err != nil {
|
||||||
|
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleTest is used to handle DNS queries in the ".consul." domain
|
// handleTest is used to handle DNS queries in the ".consul." domain
|
||||||
|
@ -147,7 +172,9 @@ func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) {
|
||||||
txt := &dns.TXT{header, []string{"ok"}}
|
txt := &dns.TXT{header, []string{"ok"}}
|
||||||
m.Answer = append(m.Answer, txt)
|
m.Answer = append(m.Answer, txt)
|
||||||
d.addSOA(consulDomain, m)
|
d.addSOA(consulDomain, m)
|
||||||
resp.WriteMsg(m)
|
if err := resp.WriteMsg(m); err != nil {
|
||||||
|
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// addSOA is used to add an SOA record to a message for the given domain
|
// addSOA is used to add an SOA record to a message for the given domain
|
||||||
|
@ -353,3 +380,40 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.ServiceNodes, req
|
||||||
resp.Extra = append(resp.Extra, aRec)
|
resp.Extra = append(resp.Extra, aRec)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleRecurse is used to handle recursive DNS queries
|
||||||
|
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
|
||||||
|
q := req.Question[0]
|
||||||
|
network := "udp"
|
||||||
|
defer func(s time.Time) {
|
||||||
|
d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v)", q, network, time.Now().Sub(s))
|
||||||
|
}(time.Now())
|
||||||
|
|
||||||
|
// Switch to TCP if the client is
|
||||||
|
if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
|
||||||
|
network = "tcp"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively resolve
|
||||||
|
c := &dns.Client{Net: network}
|
||||||
|
r, rtt, err := c.Exchange(req, d.recursor)
|
||||||
|
|
||||||
|
// On failure, return a SERVFAIL message
|
||||||
|
if err != nil {
|
||||||
|
d.logger.Printf("[ERR] dns: recurse failed: %v", err)
|
||||||
|
m := &dns.Msg{}
|
||||||
|
m.SetReply(req)
|
||||||
|
m.SetRcode(req, dns.RcodeServerFailure)
|
||||||
|
resp.WriteMsg(m)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
|
||||||
|
|
||||||
|
// Seems to be a bug that forcing compression fixes...
|
||||||
|
r.Compress = true
|
||||||
|
|
||||||
|
// Forward the response
|
||||||
|
if err := resp.WriteMsg(r); err != nil {
|
||||||
|
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,7 +11,8 @@ import (
|
||||||
func makeDNSServer(t *testing.T) (string, *DNSServer) {
|
func makeDNSServer(t *testing.T) (string, *DNSServer) {
|
||||||
conf := nextConfig()
|
conf := nextConfig()
|
||||||
dir, agent := makeAgent(t, conf)
|
dir, agent := makeAgent(t, conf)
|
||||||
server, err := NewDNSServer(agent, agent.logOutput, conf.Domain, conf.DNSAddr)
|
server, err := NewDNSServer(agent, agent.logOutput, conf.Domain,
|
||||||
|
conf.DNSAddr, "8.8.8.8:53")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -173,3 +174,26 @@ func TestDNS_ServiceLookup(t *testing.T) {
|
||||||
t.Fatalf("Bad: %#v", in.Extra[0])
|
t.Fatalf("Bad: %#v", in.Extra[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNS_Recurse(t *testing.T) {
|
||||||
|
dir, srv := makeDNSServer(t)
|
||||||
|
defer os.RemoveAll(dir)
|
||||||
|
defer srv.agent.Shutdown()
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("apple.com.", dns.TypeANY)
|
||||||
|
|
||||||
|
c := new(dns.Client)
|
||||||
|
c.Net = "tcp"
|
||||||
|
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in.Answer) == 0 {
|
||||||
|
t.Fatalf("Bad: %#v", in)
|
||||||
|
}
|
||||||
|
if in.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Fatalf("Bad: %#v", in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue