diff --git a/agent/consul/fsm/snapshot_oss_test.go b/agent/consul/fsm/snapshot_oss_test.go index 10e1cd061..652706865 100644 --- a/agent/consul/fsm/snapshot_oss_test.go +++ b/agent/consul/fsm/snapshot_oss_test.go @@ -464,6 +464,14 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { require.NoError(t, err) require.Equal(t, vip, "240.0.0.2") + _, serviceNames, err := fsm.state.ServiceNamesOfKind(nil, structs.ServiceKindTypical) + require.NoError(t, err) + + expect := []string{"backend", "db", "frontend", "web"} + for i, sn := range serviceNames { + require.Equal(t, expect[i], sn.Service.Name) + } + // Snapshot snap, err := fsm.Snapshot() require.NoError(t, err) @@ -690,10 +698,10 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { require.Len(t, roots, 2) // Verify provider state is restored. - _, state, err := fsm2.state.CAProviderState("asdf") + _, provider, err := fsm2.state.CAProviderState("asdf") require.NoError(t, err) - require.Equal(t, "foo", state.PrivateKey) - require.Equal(t, "bar", state.RootCert) + require.Equal(t, "foo", provider.PrivateKey) + require.Equal(t, "bar", provider.RootCert) // Verify CA configuration is restored. _, caConf, err := fsm2.state.CAConfig(nil) @@ -751,6 +759,14 @@ func TestFSM_SnapshotRestore_OSS(t *testing.T) { require.NoError(t, err) require.Equal(t, meshConfig, meshConfigEntry) + _, restoredServiceNames, err := fsm2.state.ServiceNamesOfKind(nil, structs.ServiceKindTypical) + require.NoError(t, err) + + expect = []string{"backend", "db", "frontend", "web"} + for i, sn := range restoredServiceNames { + require.Equal(t, expect[i], sn.Service.Name) + } + // Snapshot snap, err = fsm2.Snapshot() require.NoError(t, err) diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index d0974c30e..896cbc1ee 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -741,6 +741,9 @@ func ensureServiceTxn(tx WriteTxn, idx uint64, node string, preserveIndexes bool if err = checkGatewayWildcardsAndUpdate(tx, idx, svc); err != nil { return fmt.Errorf("failed updating gateway mapping: %s", err) } + if err := upsertKindServiceName(tx, idx, svc.Kind, svc.CompoundServiceName()); err != nil { + return fmt.Errorf("failed to persist service name: %v", err) + } // Update upstream/downstream mappings if it's a connect service if svc.Kind == structs.ServiceKindConnectProxy || svc.Connect.Native { @@ -1691,6 +1694,9 @@ func (s *Store) deleteServiceTxn(tx WriteTxn, idx uint64, nodeName, serviceID st if err := freeServiceVirtualIP(tx, svc.ServiceName, entMeta); err != nil { return fmt.Errorf("failed to clean up virtual IP for %q: %v", name.String(), err) } + if err := cleanupKindServiceName(tx, idx, svc.CompoundServiceName(), svc.ServiceKind); err != nil { + return fmt.Errorf("failed to persist service name: %v", err) + } } } else { return fmt.Errorf("Could not find any service %s: %s", svc.ServiceName, err) @@ -2526,6 +2532,26 @@ func (s *Store) VirtualIPForService(sn structs.ServiceName) (string, error) { return result.String(), nil } +func (s *Store) KindServiceNamesOfKind(ws memdb.WatchSet, kind structs.ServiceKind) (uint64, []*KindServiceName, error) { + tx := s.db.Txn(false) + defer tx.Abort() + + var names []*KindServiceName + iter, err := tx.Get(tableKindServiceNames, indexKindOnly, kind) + if err != nil { + return 0, nil, err + } + ws.Add(iter.WatchCh()) + + idx := kindServiceNamesMaxIndex(tx, ws, kind) + for name := iter.Next(); name != nil; name = iter.Next() { + ksn := name.(*KindServiceName) + names = append(names, ksn) + } + + return idx, names, nil +} + // parseCheckServiceNodes is used to parse through a given set of services, // and query for an associated node and a set of checks. This is the inner // method used to return a rich set of results from a more simple query. @@ -3862,3 +3888,44 @@ func truncateGatewayServiceTopologyMappings(tx WriteTxn, idx uint64, gateway str return nil } + +func upsertKindServiceName(tx WriteTxn, idx uint64, kind structs.ServiceKind, name structs.ServiceName) error { + q := KindServiceNameQuery{Name: name.Name, Kind: kind, EnterpriseMeta: name.EnterpriseMeta} + existing, err := tx.First(tableKindServiceNames, indexID, q) + if err != nil { + return err + } + + // Service name is already known. Nothing to do. + if existing != nil { + return nil + } + + ksn := KindServiceName{ + Kind: kind, + Service: name, + RaftIndex: structs.RaftIndex{ + CreateIndex: idx, + ModifyIndex: idx, + }, + } + if err := tx.Insert(tableKindServiceNames, &ksn); err != nil { + return fmt.Errorf("failed inserting %s/%s into %s: %s", kind, name.String(), tableKindServiceNames, err) + } + if err := indexUpdateMaxTxn(tx, idx, kindServiceNameIndexName(kind)); err != nil { + return fmt.Errorf("failed updating %s index: %v", tableKindServiceNames, err) + } + return nil +} + +func cleanupKindServiceName(tx WriteTxn, idx uint64, name structs.ServiceName, kind structs.ServiceKind) error { + q := KindServiceNameQuery{Name: name.Name, Kind: kind, EnterpriseMeta: name.EnterpriseMeta} + if _, err := tx.DeleteAll(tableKindServiceNames, indexID, q); err != nil { + return fmt.Errorf("failed to delete %s from %s: %s", name, tableKindServiceNames, err) + } + + if err := indexUpdateMaxTxn(tx, idx, kindServiceNameIndexName(kind)); err != nil { + return fmt.Errorf("failed updating %s index: %v", tableKindServiceNames, err) + } + return nil +} diff --git a/agent/consul/state/catalog_oss.go b/agent/consul/state/catalog_oss.go index e71d13ae3..f2902ca71 100644 --- a/agent/consul/state/catalog_oss.go +++ b/agent/consul/state/catalog_oss.go @@ -5,6 +5,7 @@ package state import ( "fmt" + "strings" memdb "github.com/hashicorp/go-memdb" @@ -18,13 +19,7 @@ func serviceIndexName(name string, _ *structs.EnterpriseMeta) string { } func serviceKindIndexName(kind structs.ServiceKind, _ *structs.EnterpriseMeta) string { - switch kind { - case structs.ServiceKindTypical: - // needs a special case here - return "service_kind.typical" - default: - return "service_kind." + string(kind) - } + return "service_kind." + kind.Normalized() } func catalogUpdateNodesIndexes(tx WriteTxn, idx uint64, entMeta *structs.EnterpriseMeta) error { @@ -192,3 +187,22 @@ func validateRegisterRequestTxn(_ ReadTxn, _ *structs.RegisterRequest, _ bool) ( func (s *Store) ValidateRegisterRequest(_ *structs.RegisterRequest) (*structs.EnterpriseMeta, error) { return nil, nil } + +func indexFromKindServiceName(arg interface{}) ([]byte, error) { + var b indexBuilder + + switch n := arg.(type) { + case KindServiceNameQuery: + b.String(strings.ToLower(string(n.Kind))) + b.String(strings.ToLower(n.Name)) + return b.Bytes(), nil + + case *KindServiceName: + b.String(strings.ToLower(string(n.Kind))) + b.String(strings.ToLower(n.Service.Name)) + return b.Bytes(), nil + + default: + return nil, fmt.Errorf("type must be KindServiceNameQuery or *KindServiceName: %T", arg) + } +} diff --git a/agent/consul/state/catalog_oss_test.go b/agent/consul/state/catalog_oss_test.go index 04162072b..5811416b1 100644 --- a/agent/consul/state/catalog_oss_test.go +++ b/agent/consul/state/catalog_oss_test.go @@ -412,3 +412,40 @@ func testIndexerTableServiceVirtualIPs() map[string]indexerTestCase { }, } } + +func testIndexerTableKindServiceNames() map[string]indexerTestCase { + obj := &KindServiceName{ + Service: structs.ServiceName{ + Name: "web-sidecar-proxy", + }, + Kind: structs.ServiceKindConnectProxy, + } + + return map[string]indexerTestCase{ + indexID: { + read: indexValue{ + source: &KindServiceName{ + Service: structs.ServiceName{ + Name: "web-sidecar-proxy", + }, + Kind: structs.ServiceKindConnectProxy, + }, + expected: []byte("connect-proxy\x00web-sidecar-proxy\x00"), + }, + write: indexValue{ + source: obj, + expected: []byte("connect-proxy\x00web-sidecar-proxy\x00"), + }, + }, + indexKind: { + read: indexValue{ + source: structs.ServiceKindConnectProxy, + expected: []byte("connect-proxy\x00"), + }, + write: indexValue{ + source: obj, + expected: []byte("connect-proxy\x00"), + }, + }, + } +} diff --git a/agent/consul/state/catalog_schema.go b/agent/consul/state/catalog_schema.go index b67bf5049..c03f649be 100644 --- a/agent/consul/state/catalog_schema.go +++ b/agent/consul/state/catalog_schema.go @@ -19,6 +19,7 @@ const ( tableMeshTopology = "mesh-topology" tableServiceVirtualIPs = "service-virtual-ips" tableFreeVirtualIPs = "free-virtual-ips" + tableKindServiceNames = "kind-service-names" indexID = "id" indexService = "service" @@ -661,3 +662,80 @@ func freeVirtualIPTableSchema() *memdb.TableSchema { }, } } + +type KindServiceName struct { + Kind structs.ServiceKind + Service structs.ServiceName + + structs.RaftIndex +} + +func kindServiceNameTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: tableKindServiceNames, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + AllowMissing: false, + Unique: true, + Indexer: indexerSingle{ + readIndex: indexFromKindServiceName, + writeIndex: indexFromKindServiceName, + }, + }, + indexKindOnly: { + Name: indexKindOnly, + AllowMissing: false, + Unique: false, + Indexer: indexerSingle{ + readIndex: indexFromKindServiceNameKindOnly, + writeIndex: indexFromKindServiceNameKindOnly, + }, + }, + }, + } +} + +// KindServiceNameQuery is used to lookup service names by kind or enterprise meta. +type KindServiceNameQuery struct { + Kind structs.ServiceKind + Name string + structs.EnterpriseMeta +} + +// NamespaceOrDefault exists because structs.EnterpriseMeta uses a pointer +// receiver for this method. Remove once that is fixed. +func (q KindServiceNameQuery) NamespaceOrDefault() string { + return q.EnterpriseMeta.NamespaceOrDefault() +} + +// PartitionOrDefault exists because structs.EnterpriseMeta uses a pointer +// receiver for this method. Remove once that is fixed. +func (q KindServiceNameQuery) PartitionOrDefault() string { + return q.EnterpriseMeta.PartitionOrDefault() +} + +func indexFromKindServiceNameKindOnly(raw interface{}) ([]byte, error) { + switch x := raw.(type) { + case *KindServiceName: + var b indexBuilder + b.String(strings.ToLower(string(x.Kind))) + return b.Bytes(), nil + + case structs.ServiceKind: + var b indexBuilder + b.String(strings.ToLower(string(x))) + return b.Bytes(), nil + + default: + return nil, fmt.Errorf("type must be *KindServiceName or structs.ServiceKind: %T", raw) + } +} + +func kindServiceNamesMaxIndex(tx ReadTxn, ws memdb.WatchSet, kind structs.ServiceKind) uint64 { + return maxIndexWatchTxn(tx, ws, kindServiceNameIndexName(kind)) +} + +func kindServiceNameIndexName(kind structs.ServiceKind) string { + return "kind_service_names." + kind.Normalized() +} diff --git a/agent/consul/state/catalog_test.go b/agent/consul/state/catalog_test.go index c95989b80..c4d7a775a 100644 --- a/agent/consul/state/catalog_test.go +++ b/agent/consul/state/catalog_test.go @@ -7656,6 +7656,143 @@ func TestProtocolForIngressGateway(t *testing.T) { } } +func TestStateStore_EnsureService_ServiceNames(t *testing.T) { + s := testStateStore(t) + + // Create the service registration. + entMeta := structs.DefaultEnterpriseMetaInDefaultPartition() + + services := []structs.NodeService{ + { + Kind: structs.ServiceKindIngressGateway, + ID: "ingress-gateway", + Service: "ingress-gateway", + Address: "2.2.2.2", + Port: 2222, + EnterpriseMeta: *entMeta, + }, + { + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway", + Service: "mesh-gateway", + Address: "4.4.4.4", + Port: 4444, + EnterpriseMeta: *entMeta, + }, + { + Kind: structs.ServiceKindConnectProxy, + ID: "connect-proxy", + Service: "connect-proxy", + Address: "1.1.1.1", + Port: 1111, + Proxy: structs.ConnectProxyConfig{DestinationServiceName: "foo"}, + EnterpriseMeta: *entMeta, + }, + { + Kind: structs.ServiceKindTerminatingGateway, + ID: "terminating-gateway", + Service: "terminating-gateway", + Address: "3.3.3.3", + Port: 3333, + EnterpriseMeta: *entMeta, + }, + { + Kind: structs.ServiceKindTypical, + ID: "web", + Service: "web", + Address: "5.5.5.5", + Port: 5555, + EnterpriseMeta: *entMeta, + }, + } + + var idx uint64 + testRegisterNode(t, s, idx, "node1") + + for _, svc := range services { + idx++ + require.NoError(t, s.EnsureService(idx, "node1", &svc)) + + // Ensure the service name was stored for all of them under the appropriate kind + gotIdx, gotNames, err := s.KindServiceNamesOfKind(nil, svc.Kind) + require.NoError(t, err) + require.Equal(t, idx, gotIdx) + require.Len(t, gotNames, 1) + require.Equal(t, svc.CompoundServiceName(), gotNames[0].Service) + require.Equal(t, svc.Kind, gotNames[0].Kind) + } + + // Register another ingress gateway and there should be two names under the kind index + newIngress := structs.NodeService{ + Kind: structs.ServiceKindIngressGateway, + ID: "new-ingress-gateway", + Service: "new-ingress-gateway", + Address: "6.6.6.6", + Port: 6666, + EnterpriseMeta: *entMeta, + } + idx++ + require.NoError(t, s.EnsureService(idx, "node1", &newIngress)) + + gotIdx, got, err := s.KindServiceNamesOfKind(nil, structs.ServiceKindIngressGateway) + require.NoError(t, err) + require.Equal(t, idx, gotIdx) + + expect := []*KindServiceName{ + { + Kind: structs.ServiceKindIngressGateway, + Service: structs.NewServiceName("ingress-gateway", nil), + RaftIndex: structs.RaftIndex{ + CreateIndex: 1, + ModifyIndex: 1, + }, + }, + { + Kind: structs.ServiceKindIngressGateway, + Service: structs.NewServiceName("new-ingress-gateway", nil), + RaftIndex: structs.RaftIndex{ + CreateIndex: idx, + ModifyIndex: idx, + }, + }, + } + require.Equal(t, expect, got) + + // Deregister an ingress gateway and the index should not slide back + idx++ + require.NoError(t, s.DeleteService(idx, "node1", "new-ingress-gateway", entMeta)) + + gotIdx, got, err = s.ServiceNamesOfKind(nil, structs.ServiceKindIngressGateway) + require.NoError(t, err) + require.Equal(t, idx, gotIdx) + require.Equal(t, expect[:1], got) + + // Registering another instance of a known service should not bump the kind index + newMGW := structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway-1", + Service: "mesh-gateway", + Address: "7.7.7.7", + Port: 7777, + EnterpriseMeta: *entMeta, + } + idx++ + require.NoError(t, s.EnsureService(idx, "node1", &newMGW)) + + gotIdx, _, err = s.KindServiceNamesOfKind(nil, structs.ServiceKindMeshGateway) + require.NoError(t, err) + require.Equal(t, uint64(2), gotIdx) + + // Deregister the single typical service and the service name should also be dropped + idx++ + require.NoError(t, s.DeleteService(idx, "node1", "web", entMeta)) + + gotIdx, got, err = s.KindServiceNamesOfKind(nil, structs.ServiceKindTypical) + require.NoError(t, err) + require.Equal(t, idx, gotIdx) + require.Empty(t, got) +} + func runStep(t *testing.T, name string, fn func(t *testing.T)) { t.Helper() if !t.Run(name, fn) { diff --git a/agent/consul/state/schema.go b/agent/consul/state/schema.go index 4005469fd..75a2ffa74 100644 --- a/agent/consul/state/schema.go +++ b/agent/consul/state/schema.go @@ -40,6 +40,7 @@ func newDBSchema() *memdb.DBSchema { tombstonesTableSchema, usageTableSchema, freeVirtualIPTableSchema, + kindServiceNameTableSchema, ) withEnterpriseSchema(db) return db diff --git a/agent/consul/state/schema_test.go b/agent/consul/state/schema_test.go index b83491587..7ef17c8fd 100644 --- a/agent/consul/state/schema_test.go +++ b/agent/consul/state/schema_test.go @@ -50,6 +50,7 @@ func TestNewDBSchema_Indexers(t *testing.T) { tableMeshTopology: testIndexerTableMeshTopology, tableGatewayServices: testIndexerTableGatewayServices, tableServiceVirtualIPs: testIndexerTableServiceVirtualIPs, + tableKindServiceNames: testIndexerTableKindServiceNames, // KV tableKVs: testIndexerTableKVs, tableTombstones: testIndexerTableTombstones, diff --git a/agent/structs/structs.go b/agent/structs/structs.go index 6edbc5545..e79cf6ef9 100644 --- a/agent/structs/structs.go +++ b/agent/structs/structs.go @@ -72,6 +72,7 @@ const ( SystemMetadataRequestType = 31 ServiceVirtualIPRequestType = 32 FreeVirtualIPRequestType = 33 + KindServiceNamesType = 34 ) // if a new request type is added above it must be @@ -114,6 +115,7 @@ var requestTypeStrings = map[MessageType]string{ SystemMetadataRequestType: "SystemMetadata", ServiceVirtualIPRequestType: "ServiceVirtualIP", FreeVirtualIPRequestType: "FreeVirtualIP", + KindServiceNamesType: "KindServiceName", } const ( @@ -1029,6 +1031,13 @@ type ServiceNodes []*ServiceNode // ServiceKind is the kind of service being registered. type ServiceKind string +func (k ServiceKind) Normalized() string { + if k == ServiceKindTypical { + return "typical" + } + return string(k) +} + const ( // ServiceKindTypical is a typical, classic Consul service. This is // represented by the absence of a value. This was chosen for ease of