diff --git a/agent/consul/health_endpoint.go b/agent/consul/health_endpoint.go index db59356c8..70cc2e37d 100644 --- a/agent/consul/health_endpoint.go +++ b/agent/consul/health_endpoint.go @@ -111,18 +111,30 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc return fmt.Errorf("Must provide service name") } + // Determine the function we'll call + var f func(memdb.WatchSet, *state.Store) (uint64, structs.CheckServiceNodes, error) + switch { + case args.Connect: + f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.CheckServiceNodes, error) { + return s.CheckConnectServiceNodes(ws, args.ServiceName) + } + + case args.TagFilter: + f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.CheckServiceNodes, error) { + return s.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag) + } + + default: + f = func(ws memdb.WatchSet, s *state.Store) (uint64, structs.CheckServiceNodes, error) { + return s.CheckServiceNodes(ws, args.ServiceName) + } + } + err := h.srv.blockingQuery( &args.QueryOptions, &reply.QueryMeta, func(ws memdb.WatchSet, state *state.Store) error { - var index uint64 - var nodes structs.CheckServiceNodes - var err error - if args.TagFilter { - index, nodes, err = state.CheckServiceTagNodes(ws, args.ServiceName, args.ServiceTag) - } else { - index, nodes, err = state.CheckServiceNodes(ws, args.ServiceName) - } + index, nodes, err := f(ws, state) if err != nil { return err } @@ -139,14 +151,20 @@ func (h *Health) ServiceNodes(args *structs.ServiceSpecificRequest, reply *struc // Provide some metrics if err == nil { - metrics.IncrCounterWithLabels([]string{"health", "service", "query"}, 1, + // For metrics, we separate Connect-based lookups from non-Connect + key := "service" + if args.Connect { + key = "connect" + } + + metrics.IncrCounterWithLabels([]string{"health", key, "query"}, 1, []metrics.Label{{Name: "service", Value: args.ServiceName}}) if args.ServiceTag != "" { - metrics.IncrCounterWithLabels([]string{"health", "service", "query-tag"}, 1, + metrics.IncrCounterWithLabels([]string{"health", key, "query-tag"}, 1, []metrics.Label{{Name: "service", Value: args.ServiceName}, {Name: "tag", Value: args.ServiceTag}}) } if len(reply.Nodes) == 0 { - metrics.IncrCounterWithLabels([]string{"health", "service", "not-found"}, 1, + metrics.IncrCounterWithLabels([]string{"health", key, "not-found"}, 1, []metrics.Label{{Name: "service", Value: args.ServiceName}}) } } diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index 3eb733bbe..2ce2da36b 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -1525,14 +1525,36 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t // CheckServiceNodes is used to query all nodes and checks for a given service. func (s *Store) CheckServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) { + return s.checkServiceNodes(ws, serviceName, false) +} + +// CheckConnectServiceNodes is used to query all nodes and checks for Connect +// compatible endpoints for a given service. +func (s *Store) CheckConnectServiceNodes(ws memdb.WatchSet, serviceName string) (uint64, structs.CheckServiceNodes, error) { + return s.checkServiceNodes(ws, serviceName, true) +} + +func (s *Store) checkServiceNodes(ws memdb.WatchSet, serviceName string, connect bool) (uint64, structs.CheckServiceNodes, error) { tx := s.db.Txn(false) defer tx.Abort() // Get the table index. idx := maxIndexForService(tx, serviceName, true) + // Function for lookup + var f func() (memdb.ResultIterator, error) + if !connect { + f = func() (memdb.ResultIterator, error) { + return tx.Get("services", "service", serviceName) + } + } else { + f = func() (memdb.ResultIterator, error) { + return tx.Get("services", "proxy_destination", serviceName) + } + } + // Query the state store for the service. - iter, err := tx.Get("services", "service", serviceName) + iter, err := f() if err != nil { return 0, nil, fmt.Errorf("failed service lookup: %s", err) } diff --git a/agent/consul/state/catalog_test.go b/agent/consul/state/catalog_test.go index 1f20fb9b8..9d771ca48 100644 --- a/agent/consul/state/catalog_test.go +++ b/agent/consul/state/catalog_test.go @@ -2529,6 +2529,48 @@ func TestStateStore_CheckServiceNodes(t *testing.T) { } } +func TestStateStore_CheckConnectServiceNodes(t *testing.T) { + assert := assert.New(t) + s := testStateStore(t) + + // Listing with no results returns an empty list. + ws := memdb.NewWatchSet() + idx, nodes, err := s.CheckConnectServiceNodes(ws, "db") + assert.Nil(err) + assert.Equal(idx, uint64(0)) + assert.Len(nodes, 0) + + // Create some nodes and services. + assert.Nil(s.EnsureNode(10, &structs.Node{Node: "foo", Address: "127.0.0.1"})) + assert.Nil(s.EnsureNode(11, &structs.Node{Node: "bar", Address: "127.0.0.2"})) + assert.Nil(s.EnsureService(12, "foo", &structs.NodeService{ID: "db", Service: "db", Tags: nil, Address: "", Port: 5000})) + assert.Nil(s.EnsureService(13, "bar", &structs.NodeService{ID: "api", Service: "api", Tags: nil, Address: "", Port: 5000})) + assert.Nil(s.EnsureService(14, "foo", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000})) + assert.Nil(s.EnsureService(15, "bar", &structs.NodeService{Kind: structs.ServiceKindConnectProxy, ID: "proxy", Service: "proxy", ProxyDestination: "db", Port: 8000})) + assert.Nil(s.EnsureService(16, "bar", &structs.NodeService{ID: "db2", Service: "db", Tags: []string{"slave"}, Address: "", Port: 8001})) + assert.True(watchFired(ws)) + + // Register node checks + testRegisterCheck(t, s, 17, "foo", "", "check1", api.HealthPassing) + testRegisterCheck(t, s, 18, "bar", "", "check2", api.HealthPassing) + + // Register checks against the services. + testRegisterCheck(t, s, 19, "foo", "db", "check3", api.HealthPassing) + testRegisterCheck(t, s, 20, "bar", "proxy", "check4", api.HealthPassing) + + // Read everything back. + ws = memdb.NewWatchSet() + idx, nodes, err = s.CheckConnectServiceNodes(ws, "db") + assert.Nil(err) + assert.Equal(idx, uint64(idx)) + assert.Len(nodes, 2) + + for _, n := range nodes { + assert.Equal(structs.ServiceKindConnectProxy, n.Service.Kind) + assert.Equal("db", n.Service.ProxyDestination) + } +} + func BenchmarkCheckServiceNodes(b *testing.B) { s, err := NewStateStore(nil) if err != nil { diff --git a/agent/dns_test.go b/agent/dns_test.go index 5d1082888..a501a9c9f 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -1053,6 +1053,7 @@ func TestDNS_ConnectServiceLookup(t *testing.T) { { args := structs.TestRegisterRequestProxy(t) args.Service.ProxyDestination = "db" + args.Service.Address = "" args.Service.Port = 12345 var out struct{} assert.Nil(a.RPC("Catalog.Register", args, &out)) @@ -1073,14 +1074,14 @@ func TestDNS_ConnectServiceLookup(t *testing.T) { srvRec, ok := in.Answer[0].(*dns.SRV) assert.True(ok) - assert.Equal(12345, srvRec.Port) + assert.Equal(uint16(12345), srvRec.Port) assert.Equal("foo.node.dc1.consul.", srvRec.Target) - assert.Equal(0, srvRec.Hdr.Ttl) + assert.Equal(uint32(0), srvRec.Hdr.Ttl) - cnameRec, ok := in.Extra[0].(*dns.CNAME) + cnameRec, ok := in.Extra[0].(*dns.A) assert.True(ok) assert.Equal("foo.node.dc1.consul.", cnameRec.Hdr.Name) - assert.Equal(0, srvRec.Hdr.Ttl) + assert.Equal(uint32(0), srvRec.Hdr.Ttl) } }