Merge pull request #833 from mgood/dns-startup-hooks

Simplify DNS server startup check
This commit is contained in:
Armon Dadgar 2015-03-31 17:26:45 -07:00
commit 325e87cdaa
2 changed files with 21 additions and 98 deletions

View File

@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/consul/structs"
@ -14,8 +15,6 @@ import (
) )
const ( const (
testQuery = "_test.consul."
consulDomain = "consul."
maxServiceResponses = 3 // For UDP only maxServiceResponses = 3 // For UDP only
maxRecurseRecords = 3 maxRecurseRecords = 3
) )
@ -51,17 +50,21 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
// Construct the DNS components // Construct the DNS components
mux := dns.NewServeMux() mux := dns.NewServeMux()
var wg sync.WaitGroup
// Setup the servers // Setup the servers
server := &dns.Server{ server := &dns.Server{
Addr: bind, Addr: bind,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
UDPSize: 65535, UDPSize: 65535,
NotifyStartedFunc: wg.Done,
} }
serverTCP := &dns.Server{ serverTCP := &dns.Server{
Addr: bind, Addr: bind,
Net: "tcp", Net: "tcp",
Handler: mux, Handler: mux,
NotifyStartedFunc: wg.Done,
} }
// Create the server // Create the server
@ -79,11 +82,8 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
// Register mux handler, for reverse lookup // Register mux handler, for reverse lookup
mux.HandleFunc("arpa.", srv.handlePtr) mux.HandleFunc("arpa.", srv.handlePtr)
// Register mux handlers, always handle "consul." // Register mux handlers
mux.HandleFunc(domain, srv.handleQuery) mux.HandleFunc(domain, srv.handleQuery)
if domain != consulDomain {
mux.HandleFunc(consulDomain, srv.handleTest)
}
if len(recursors) > 0 { if len(recursors) > 0 {
validatedRecursors := make([]string, len(recursors)) validatedRecursors := make([]string, len(recursors))
@ -99,6 +99,8 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
mux.HandleFunc(".", srv.handleRecurse) mux.HandleFunc(".", srv.handleRecurse)
} }
wg.Add(2)
// Async start the DNS Servers, handle a potential error // Async start the DNS Servers, handle a potential error
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
@ -116,28 +118,11 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
} }
}() }()
// Check the server is running, do a test lookup // Wait for NotifyStartedFunc callbacks indicating server has started
checkCh := make(chan error, 1) startCh := make(chan struct{})
go func() { go func() {
// This is jank, but we have no way to edge trigger on wg.Wait()
// the start of our server, so we just wait and hope it is up. close(startCh)
time.Sleep(50 * time.Millisecond)
m := new(dns.Msg)
m.SetQuestion(testQuery, dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, bind)
if err != nil {
checkCh <- fmt.Errorf("dns test query failed: %v", err)
return
}
if len(in.Answer) == 0 {
checkCh <- fmt.Errorf("no response to test message")
return
}
close(checkCh)
}() }()
// Wait for either the check, listen error, or timeout // Wait for either the check, listen error, or timeout
@ -146,8 +131,8 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
return srv, e return srv, e
case e := <-errChTCP: case e := <-errChTCP:
return srv, e return srv, e
case e := <-checkCh: case <-startCh:
return srv, e return srv, nil
case <-time.After(time.Second): case <-time.After(time.Second):
return srv, fmt.Errorf("timeout setting up DNS server") return srv, fmt.Errorf("timeout setting up DNS server")
} }
@ -234,12 +219,6 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s)) d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
}(time.Now()) }(time.Now())
// Check if this is potentially a test query
if q.Name == testQuery {
d.handleTest(resp, req)
return
}
// Switch to TCP if the client is // Switch to TCP if the client is
network := "udp" network := "udp"
if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok { if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
@ -266,34 +245,6 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
} }
} }
// handleTest is used to handle DNS queries in the ".consul." domain
func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
defer func(s time.Time) {
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
}(time.Now())
if !(q.Qtype == dns.TypeANY || q.Qtype == dns.TypeTXT) {
return
}
if q.Name != testQuery {
return
}
// Always respond with TXT "ok"
m := new(dns.Msg)
m.SetReply(req)
m.Authoritative = true
m.RecursionAvailable = true
header := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
txt := &dns.TXT{Hdr: header, Txt: []string{"ok"}}
m.Answer = append(m.Answer, txt)
d.addSOA(consulDomain, m)
if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
}
}
// addSOA is used to add an SOA record to a message for the given domain // addSOA is used to add an SOA record to a message for the given domain
func (d *DNSServer) addSOA(domain string, msg *dns.Msg) { func (d *DNSServer) addSOA(domain string, msg *dns.Msg) {
soa := &dns.SOA{ soa := &dns.SOA{

View File

@ -39,34 +39,6 @@ func TestRecursorAddr(t *testing.T) {
} }
} }
func TestDNS_IsAlive(t *testing.T) {
dir, srv := makeDNSServer(t)
defer os.RemoveAll(dir)
defer srv.agent.Shutdown()
m := new(dns.Msg)
m.SetQuestion("_test.consul.", dns.TypeANY)
c := new(dns.Client)
addr, _ := srv.agent.config.ClientListener("", srv.agent.config.Ports.DNS)
in, _, err := c.Exchange(m, addr.String())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
txt, ok := in.Answer[0].(*dns.TXT)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if txt.Txt[0] != "ok" {
t.Fatalf("Bad: %#v", in.Answer[0])
}
}
func TestDNS_NodeLookup(t *testing.T) { func TestDNS_NodeLookup(t *testing.T) {
dir, srv := makeDNSServer(t) dir, srv := makeDNSServer(t)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)