Adding DNS based node lookup
This commit is contained in:
parent
ca43075041
commit
c0d53a9d62
|
@ -144,8 +144,8 @@ func (s *HTTPServer) CatalogNodeServices(resp http.ResponseWriter, req *http.Req
|
|||
}
|
||||
|
||||
// Make the RPC request
|
||||
var out structs.NodeServices
|
||||
if err := s.agent.RPC("Catalog.NodeServices", &args, &out); err != nil {
|
||||
out := new(structs.NodeServices)
|
||||
if err := s.agent.RPC("Catalog.NodeServices", &args, out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
|
|
|
@ -232,8 +232,8 @@ func TestCatalogNodeServices(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
services := obj.(structs.NodeServices)
|
||||
if len(services) != 1 {
|
||||
services := obj.(*structs.NodeServices)
|
||||
if len(services.Services) != 1 {
|
||||
t.Fatalf("bad: %v", obj)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,23 +2,35 @@ package agent
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/miekg/dns"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
testQuery = "_test.consul."
|
||||
consulDomain = "consul."
|
||||
)
|
||||
|
||||
// DNSServer is used to wrap an Agent and expose various
|
||||
// service discovery endpoints using a DNS interface.
|
||||
type DNSServer struct {
|
||||
agent *Agent
|
||||
dnsHandler *dns.ServeMux
|
||||
dnsServer *dns.Server
|
||||
domain string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewDNSServer starts a new DNS server to provide an agent interface
|
||||
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSServer, error) {
|
||||
// Make sure domain is FQDN
|
||||
domain = dns.Fqdn(domain)
|
||||
|
||||
// Construct the DNS components
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
|
@ -35,11 +47,15 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
|
|||
agent: agent,
|
||||
dnsHandler: mux,
|
||||
dnsServer: server,
|
||||
domain: domain,
|
||||
logger: log.New(logOutput, "", log.LstdFlags),
|
||||
}
|
||||
|
||||
// Register mux handlers
|
||||
mux.HandleFunc("consul.", srv.handleConsul)
|
||||
// Register mux handlers, always handle "consul."
|
||||
mux.HandleFunc(domain, srv.handleQuery)
|
||||
if domain != consulDomain {
|
||||
mux.HandleFunc(consulDomain, srv.handleTest)
|
||||
}
|
||||
|
||||
// Async start the DNS Server, handle a potential error
|
||||
errCh := make(chan error, 1)
|
||||
|
@ -57,7 +73,7 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("_test.consul.", dns.TypeANY)
|
||||
m.SetQuestion(testQuery, dns.TypeANY)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, bind)
|
||||
|
@ -85,12 +101,41 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
|
|||
return srv, nil
|
||||
}
|
||||
|
||||
// handleConsul is used to handle DNS queries in the ".consul." domain
|
||||
func (d *DNSServer) handleConsul(resp dns.ResponseWriter, req *dns.Msg) {
|
||||
// handleQUery is used to handle DNS queries in the configured domain
|
||||
func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
|
||||
q := req.Question[0]
|
||||
d.logger.Printf("[DEBUG] dns: request for %v", q)
|
||||
defer func(s time.Time) {
|
||||
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
|
||||
}(time.Now())
|
||||
|
||||
if q.Qtype != dns.TypeANY && q.Qtype != dns.TypeTXT {
|
||||
// Check if this is potentially a test query
|
||||
if q.Name == testQuery {
|
||||
d.handleTest(resp, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Setup the message response
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
m.Authoritative = true
|
||||
d.addSOA(d.domain, m)
|
||||
defer resp.WriteMsg(m)
|
||||
|
||||
// Dispatch the correct handler
|
||||
d.dispatch(req, m)
|
||||
}
|
||||
|
||||
// handleTest is used to handle DNS queries in the ".consul." domain
|
||||
func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) {
|
||||
q := req.Question[0]
|
||||
defer func(s time.Time) {
|
||||
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
|
||||
}(time.Now())
|
||||
|
||||
if !(q.Qtype == dns.TypeANY || q.Qtype == dns.TypeTXT) {
|
||||
return
|
||||
}
|
||||
if q.Name != testQuery {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -101,7 +146,7 @@ func (d *DNSServer) handleConsul(resp dns.ResponseWriter, req *dns.Msg) {
|
|||
header := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
|
||||
txt := &dns.TXT{header, []string{"ok"}}
|
||||
m.Answer = append(m.Answer, txt)
|
||||
d.addSOA("consul.", m)
|
||||
d.addSOA(consulDomain, m)
|
||||
resp.WriteMsg(m)
|
||||
}
|
||||
|
||||
|
@ -124,3 +169,103 @@ func (d *DNSServer) addSOA(domain string, msg *dns.Msg) {
|
|||
}
|
||||
msg.Ns = append(msg.Ns, soa)
|
||||
}
|
||||
|
||||
// dispatch is used to parse a request and invoke the correct handler
|
||||
func (d *DNSServer) dispatch(req, resp *dns.Msg) {
|
||||
// By default the query is in the default datacenter
|
||||
datacenter := d.agent.config.Datacenter
|
||||
|
||||
// Get the QName without the domain suffix
|
||||
qName := dns.Fqdn(req.Question[0].Name)
|
||||
qName = strings.TrimSuffix(qName, d.domain)
|
||||
|
||||
// Split into the label parts
|
||||
labels := dns.SplitDomainName(qName)
|
||||
|
||||
// The last label is either "node", "service" or a datacenter name
|
||||
PARSE:
|
||||
if len(labels) == 0 {
|
||||
goto INVALID
|
||||
}
|
||||
switch labels[len(labels)-1] {
|
||||
case "service":
|
||||
// Handle lookup with and without tag
|
||||
switch len(labels) {
|
||||
case 2:
|
||||
d.serviceLookup(datacenter, labels[0], "", req, resp)
|
||||
case 3:
|
||||
d.serviceLookup(datacenter, labels[1], labels[0], req, resp)
|
||||
default:
|
||||
goto INVALID
|
||||
}
|
||||
|
||||
case "node":
|
||||
if len(labels) != 2 {
|
||||
goto INVALID
|
||||
}
|
||||
d.nodeLookup(datacenter, labels[0], req, resp)
|
||||
|
||||
default:
|
||||
// Store the DC, and re-parse
|
||||
datacenter = labels[len(labels)-1]
|
||||
labels = labels[:len(labels)-1]
|
||||
goto PARSE
|
||||
}
|
||||
return
|
||||
INVALID:
|
||||
d.logger.Printf("[WARN] dns: QName invalid: %s", qName)
|
||||
resp.SetRcode(req, dns.RcodeNameError)
|
||||
}
|
||||
|
||||
// nodeLookup is used to handle a node query
|
||||
func (d *DNSServer) nodeLookup(datacenter, node string, req, resp *dns.Msg) {
|
||||
// Only handle ANY and A type requests
|
||||
qType := req.Question[0].Qtype
|
||||
if qType != dns.TypeANY && qType != dns.TypeA {
|
||||
return
|
||||
}
|
||||
|
||||
// Make an RPC request
|
||||
args := structs.NodeServicesRequest{
|
||||
Datacenter: datacenter,
|
||||
Node: node,
|
||||
}
|
||||
var out structs.NodeServices
|
||||
if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil {
|
||||
d.logger.Printf("[ERR] dns: rpc error: %v", err)
|
||||
resp.SetRcode(req, dns.RcodeServerFailure)
|
||||
return
|
||||
}
|
||||
|
||||
// If we have no address, return not found!
|
||||
if out.Address == "" {
|
||||
resp.SetRcode(req, dns.RcodeNameError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the IP
|
||||
ip := net.ParseIP(out.Address)
|
||||
if ip == nil {
|
||||
d.logger.Printf("[ERR] dns: failed to parse IP %v for %v", out.Address, node)
|
||||
resp.SetRcode(req, dns.RcodeServerFailure)
|
||||
return
|
||||
}
|
||||
|
||||
// Format A record
|
||||
aRec := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
A: ip,
|
||||
}
|
||||
|
||||
// Add the response
|
||||
resp.Answer = append(resp.Answer, aRec)
|
||||
}
|
||||
|
||||
// serviceLookup is used to handle a service query
|
||||
func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dns.Msg) {
|
||||
}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/consul/consul/structs"
|
||||
"github.com/miekg/dns"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func makeDNSServer(t *testing.T) (string, *DNSServer) {
|
||||
|
@ -42,3 +44,66 @@ func TestDNS_IsAlive(t *testing.T) {
|
|||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_NodeLookup(t *testing.T) {
|
||||
dir, srv := makeDNSServer(t)
|
||||
defer os.RemoveAll(dir)
|
||||
defer srv.agent.Shutdown()
|
||||
|
||||
// Wait for leader
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Register node
|
||||
args := &structs.RegisterRequest{
|
||||
Datacenter: "dc1",
|
||||
Node: "foo",
|
||||
Address: "127.0.0.1",
|
||||
}
|
||||
var out struct{}
|
||||
if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("foo.node.consul.", dns.TypeANY)
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 1 {
|
||||
t.Fatalf("Bad: %#v", in)
|
||||
}
|
||||
|
||||
aRec, ok := in.Answer[0].(*dns.A)
|
||||
if !ok {
|
||||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
if aRec.A.String() != "127.0.0.1" {
|
||||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
|
||||
// Re-do the query, but specify the DC
|
||||
m = new(dns.Msg)
|
||||
m.SetQuestion("foo.node.dc1.consul.", dns.TypeANY)
|
||||
|
||||
c = new(dns.Client)
|
||||
in, _, err = c.Exchange(m, srv.agent.config.DNSAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if len(in.Answer) != 1 {
|
||||
t.Fatalf("Bad: %#v", in)
|
||||
}
|
||||
|
||||
aRec, ok = in.Answer[0].(*dns.A)
|
||||
if !ok {
|
||||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
if aRec.A.String() != "127.0.0.1" {
|
||||
t.Fatalf("Bad: %#v", in.Answer[0])
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,14 +72,14 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque
|
|||
// Invoke the handler
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
s.logger.Printf("[DEBUG] HTTP Request %v (%v)", req.URL, time.Now().Sub(start))
|
||||
s.logger.Printf("[DEBUG] http: Request %v (%v)", req.URL, time.Now().Sub(start))
|
||||
}()
|
||||
obj, err := handler(resp, req)
|
||||
|
||||
// Check for an error
|
||||
HAS_ERR:
|
||||
if err != nil {
|
||||
s.logger.Printf("[ERR] Request %v, error: %v", req.URL, err)
|
||||
s.logger.Printf("[ERR] http: Request %v, error: %v", req.URL, err)
|
||||
resp.WriteHeader(500)
|
||||
resp.Write([]byte(err.Error()))
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue