Adding DNS based node lookup

This commit is contained in:
Armon Dadgar 2014-01-02 17:58:58 -08:00
parent ca43075041
commit c0d53a9d62
5 changed files with 224 additions and 14 deletions

View File

@ -144,8 +144,8 @@ func (s *HTTPServer) CatalogNodeServices(resp http.ResponseWriter, req *http.Req
}
// Make the RPC request
var out structs.NodeServices
if err := s.agent.RPC("Catalog.NodeServices", &args, &out); err != nil {
out := new(structs.NodeServices)
if err := s.agent.RPC("Catalog.NodeServices", &args, out); err != nil {
return nil, err
}
return out, nil

View File

@ -232,8 +232,8 @@ func TestCatalogNodeServices(t *testing.T) {
t.Fatalf("err: %v", err)
}
services := obj.(structs.NodeServices)
if len(services) != 1 {
services := obj.(*structs.NodeServices)
if len(services.Services) != 1 {
t.Fatalf("bad: %v", obj)
}
}

View File

@ -2,23 +2,35 @@ package agent
import (
"fmt"
"github.com/hashicorp/consul/consul/structs"
"github.com/miekg/dns"
"io"
"log"
"net"
"strings"
"time"
)
const (
testQuery = "_test.consul."
consulDomain = "consul."
)
// DNSServer is used to wrap an Agent and expose various
// service discovery endpoints using a DNS interface.
type DNSServer struct {
agent *Agent
dnsHandler *dns.ServeMux
dnsServer *dns.Server
domain string
logger *log.Logger
}
// NewDNSServer starts a new DNS server to provide an agent interface
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSServer, error) {
// Make sure domain is FQDN
domain = dns.Fqdn(domain)
// Construct the DNS components
mux := dns.NewServeMux()
@ -35,11 +47,15 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
agent: agent,
dnsHandler: mux,
dnsServer: server,
domain: domain,
logger: log.New(logOutput, "", log.LstdFlags),
}
// Register mux handlers
mux.HandleFunc("consul.", srv.handleConsul)
// Register mux handlers, always handle "consul."
mux.HandleFunc(domain, srv.handleQuery)
if domain != consulDomain {
mux.HandleFunc(consulDomain, srv.handleTest)
}
// Async start the DNS Server, handle a potential error
errCh := make(chan error, 1)
@ -57,7 +73,7 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
time.Sleep(50 * time.Millisecond)
m := new(dns.Msg)
m.SetQuestion("_test.consul.", dns.TypeANY)
m.SetQuestion(testQuery, dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, bind)
@ -85,12 +101,41 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
return srv, nil
}
// handleConsul is used to handle DNS queries in the ".consul." domain
func (d *DNSServer) handleConsul(resp dns.ResponseWriter, req *dns.Msg) {
// handleQUery is used to handle DNS queries in the configured domain
func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
d.logger.Printf("[DEBUG] dns: request for %v", q)
defer func(s time.Time) {
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
}(time.Now())
if q.Qtype != dns.TypeANY && q.Qtype != dns.TypeTXT {
// Check if this is potentially a test query
if q.Name == testQuery {
d.handleTest(resp, req)
return
}
// Setup the message response
m := new(dns.Msg)
m.SetReply(req)
m.Authoritative = true
d.addSOA(d.domain, m)
defer resp.WriteMsg(m)
// Dispatch the correct handler
d.dispatch(req, m)
}
// handleTest is used to handle DNS queries in the ".consul." domain
func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
defer func(s time.Time) {
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
}(time.Now())
if !(q.Qtype == dns.TypeANY || q.Qtype == dns.TypeTXT) {
return
}
if q.Name != testQuery {
return
}
@ -101,7 +146,7 @@ func (d *DNSServer) handleConsul(resp dns.ResponseWriter, req *dns.Msg) {
header := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
txt := &dns.TXT{header, []string{"ok"}}
m.Answer = append(m.Answer, txt)
d.addSOA("consul.", m)
d.addSOA(consulDomain, m)
resp.WriteMsg(m)
}
@ -124,3 +169,103 @@ func (d *DNSServer) addSOA(domain string, msg *dns.Msg) {
}
msg.Ns = append(msg.Ns, soa)
}
// dispatch is used to parse a request and invoke the correct handler
func (d *DNSServer) dispatch(req, resp *dns.Msg) {
// By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter
// Get the QName without the domain suffix
qName := dns.Fqdn(req.Question[0].Name)
qName = strings.TrimSuffix(qName, d.domain)
// Split into the label parts
labels := dns.SplitDomainName(qName)
// The last label is either "node", "service" or a datacenter name
PARSE:
if len(labels) == 0 {
goto INVALID
}
switch labels[len(labels)-1] {
case "service":
// Handle lookup with and without tag
switch len(labels) {
case 2:
d.serviceLookup(datacenter, labels[0], "", req, resp)
case 3:
d.serviceLookup(datacenter, labels[1], labels[0], req, resp)
default:
goto INVALID
}
case "node":
if len(labels) != 2 {
goto INVALID
}
d.nodeLookup(datacenter, labels[0], req, resp)
default:
// Store the DC, and re-parse
datacenter = labels[len(labels)-1]
labels = labels[:len(labels)-1]
goto PARSE
}
return
INVALID:
d.logger.Printf("[WARN] dns: QName invalid: %s", qName)
resp.SetRcode(req, dns.RcodeNameError)
}
// nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(datacenter, node string, req, resp *dns.Msg) {
// Only handle ANY and A type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA {
return
}
// Make an RPC request
args := structs.NodeServicesRequest{
Datacenter: datacenter,
Node: node,
}
var out structs.NodeServices
if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
return
}
// If we have no address, return not found!
if out.Address == "" {
resp.SetRcode(req, dns.RcodeNameError)
return
}
// Parse the IP
ip := net.ParseIP(out.Address)
if ip == nil {
d.logger.Printf("[ERR] dns: failed to parse IP %v for %v", out.Address, node)
resp.SetRcode(req, dns.RcodeServerFailure)
return
}
// Format A record
aRec := &dns.A{
Hdr: dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: ip,
}
// Add the response
resp.Answer = append(resp.Answer, aRec)
}
// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dns.Msg) {
}

View File

@ -1,9 +1,11 @@
package agent
import (
"github.com/hashicorp/consul/consul/structs"
"github.com/miekg/dns"
"os"
"testing"
"time"
)
func makeDNSServer(t *testing.T) (string, *DNSServer) {
@ -42,3 +44,66 @@ func TestDNS_IsAlive(t *testing.T) {
t.Fatalf("Bad: %#v", in.Answer[0])
}
}
func TestDNS_NodeLookup(t *testing.T) {
dir, srv := makeDNSServer(t)
defer os.RemoveAll(dir)
defer srv.agent.Shutdown()
// Wait for leader
time.Sleep(100 * time.Millisecond)
// Register node
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "foo",
Address: "127.0.0.1",
}
var out struct{}
if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
m := new(dns.Msg)
m.SetQuestion("foo.node.consul.", dns.TypeANY)
c := new(dns.Client)
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
aRec, ok := in.Answer[0].(*dns.A)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if aRec.A.String() != "127.0.0.1" {
t.Fatalf("Bad: %#v", in.Answer[0])
}
// Re-do the query, but specify the DC
m = new(dns.Msg)
m.SetQuestion("foo.node.dc1.consul.", dns.TypeANY)
c = new(dns.Client)
in, _, err = c.Exchange(m, srv.agent.config.DNSAddr)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}
aRec, ok = in.Answer[0].(*dns.A)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if aRec.A.String() != "127.0.0.1" {
t.Fatalf("Bad: %#v", in.Answer[0])
}
}

View File

@ -72,14 +72,14 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque
// Invoke the handler
start := time.Now()
defer func() {
s.logger.Printf("[DEBUG] HTTP Request %v (%v)", req.URL, time.Now().Sub(start))
s.logger.Printf("[DEBUG] http: Request %v (%v)", req.URL, time.Now().Sub(start))
}()
obj, err := handler(resp, req)
// Check for an error
HAS_ERR:
if err != nil {
s.logger.Printf("[ERR] Request %v, error: %v", req.URL, err)
s.logger.Printf("[ERR] http: Request %v, error: %v", req.URL, err)
resp.WriteHeader(500)
resp.Write([]byte(err.Error()))
return