diff --git a/agent/consul/state/catalog.go b/agent/consul/state/catalog.go index 5e3ad9c9e..08f0b6dcb 100644 --- a/agent/consul/state/catalog.go +++ b/agent/consul/state/catalog.go @@ -691,6 +691,9 @@ func (s *Store) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeName string) error if err := tx.Insert("index", &IndexEntry{serviceIndexName(svc.ServiceName), idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } + if err := tx.Insert("index", &IndexEntry{serviceKindIndexName(svc.ServiceKind), idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } } // Do the delete in a separate loop so we don't trash the iterator. @@ -1421,6 +1424,10 @@ func (s *Store) deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeName, serviceID } svc := service.(*structs.ServiceNode) + if err := tx.Insert("index", &IndexEntry{serviceKindIndexName(svc.ServiceKind), idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + if remainingService, err := tx.First("services", "service", svc.ServiceName); err == nil { if remainingService != nil { // We have at least one remaining service, update the index @@ -1869,6 +1876,16 @@ func (s *Store) deleteCheckTxn(tx *memdb.Txn, idx uint64, node string, checkID t if err = tx.Insert("index", &IndexEntry{serviceIndexName(existing.ServiceName), idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } + + svcRaw, err := tx.First("services", "id", existing.Node, existing.ServiceID) + if err != nil { + return fmt.Errorf("failed retrieving service from state store: %v", err) + } + + svc := svcRaw.(*structs.ServiceNode) + if err := tx.Insert("index", &IndexEntry{serviceKindIndexName(svc.ServiceKind), idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } } else { err = s.updateAllServiceIndexesOfNode(tx, idx, existing.Node) if err != nil { diff --git a/agent/proxycfg/state.go b/agent/proxycfg/state.go index 27e488096..7d42fb44d 100644 --- a/agent/proxycfg/state.go +++ b/agent/proxycfg/state.go @@ -15,6 +15,11 @@ import ( "github.com/mitchellh/copystructure" ) +type CacheNotifier interface { + Notify(ctx context.Context, t string, r cache.Request, + correlationID string, ch chan<- cache.UpdateEvent) error +} + const ( coalesceTimeout = 200 * time.Millisecond rootsWatchID = "roots" @@ -35,7 +40,7 @@ type state struct { // logger, source and cache are required to be set before calling Watch. logger *log.Logger source *structs.QuerySource - cache *cache.Cache + cache CacheNotifier // ctx and cancel store the context created during initWatches call ctx context.Context @@ -328,13 +333,7 @@ func (s *state) initWatchesMeshGateway() error { return err } -func (s *state) run() { - // Close the channel we return from Watch when we stop so consumers can stop - // watching and clean up their goroutines. It's important we do this here and - // not in Close since this routine sends on this chan and so might panic if it - // gets closed from another goroutine. - defer close(s.snapCh) - +func (s *state) initialConfigSnapshot() ConfigSnapshot { snap := ConfigSnapshot{ Kind: s.kind, Service: s.service, @@ -361,6 +360,18 @@ func (s *state) run() { // fully rebuild it every time we get updates } + return snap +} + +func (s *state) run() { + // Close the channel we return from Watch when we stop so consumers can stop + // watching and clean up their goroutines. It's important we do this here and + // not in Close since this routine sends on this chan and so might panic if it + // gets closed from another goroutine. + defer close(s.snapCh) + + snap := s.initialConfigSnapshot() + // This turns out to be really fiddly/painful by just using time.Timer.C // directly in the code below since you can't detect when a timer is stopped // vs waiting in order to know to reset it. So just use a chan to send @@ -627,11 +638,15 @@ func (s *state) resetWatchesFromChain( meshGateway := structs.MeshGatewayModeDefault if target.Datacenter != s.source.Datacenter { meshGateway = meshGatewayModes[target] + + if meshGateway == structs.MeshGatewayModeDefault { + meshGateway = s.proxyCfg.MeshGateway.Mode + } } // if the default mode if meshGateway == structs.MeshGatewayModeDefault { - meshGateway = s.proxyCfg.MeshGateway.Mode + meshGateway = structs.MeshGatewayModeNone } filterExp := subset.Filter diff --git a/agent/proxycfg/state_test.go b/agent/proxycfg/state_test.go index cfdccfd6d..fcae4b557 100644 --- a/agent/proxycfg/state_test.go +++ b/agent/proxycfg/state_test.go @@ -1,11 +1,16 @@ package proxycfg import ( + "context" + "fmt" + "sync" "testing" - "github.com/stretchr/testify/require" - + "github.com/hashicorp/consul/agent/cache" + cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/sdk/testutil" + "github.com/stretchr/testify/require" ) func TestStateChanged(t *testing.T) { @@ -112,3 +117,388 @@ func TestStateChanged(t *testing.T) { }) } } + +type testCacheNotifierRequest struct { + cacheType string + request cache.Request + ch chan<- cache.UpdateEvent +} + +type testCacheNotifier struct { + lock sync.RWMutex + notifiers map[string]testCacheNotifierRequest +} + +func newTestCacheNotifier() *testCacheNotifier { + return &testCacheNotifier{ + notifiers: make(map[string]testCacheNotifierRequest), + } +} + +func (cn *testCacheNotifier) Notify(ctx context.Context, t string, r cache.Request, correlationId string, ch chan<- cache.UpdateEvent) error { + cn.lock.Lock() + cn.notifiers[correlationId] = testCacheNotifierRequest{t, r, ch} + cn.lock.Unlock() + return nil +} + +func (cn *testCacheNotifier) getNotifierRequest(t testing.TB, correlationId string) testCacheNotifierRequest { + cn.lock.RLock() + req, ok := cn.notifiers[correlationId] + cn.lock.RUnlock() + require.True(t, ok) + return req +} + +func (cn *testCacheNotifier) getChanForCorrelationId(t testing.TB, correlationId string) chan<- cache.UpdateEvent { + req := cn.getNotifierRequest(t, correlationId) + require.NotNil(t, req.ch) + return req.ch +} + +func (cn *testCacheNotifier) sendNotification(t testing.TB, correlationId string, event cache.UpdateEvent) { + cn.getChanForCorrelationId(t, correlationId) <- event +} + +func (cn *testCacheNotifier) verifyWatch(t testing.TB, correlationId string) (string, cache.Request) { + // t.Logf("Watches: %+v", cn.notifiers) + req := cn.getNotifierRequest(t, correlationId) + require.NotNil(t, req.ch) + return req.cacheType, req.request +} + +type verifyWatchRequest func(t testing.TB, cacheType string, request cache.Request) + +func genVerifyDCSpecificWatch(expectedCacheType string, expectedDatacenter string) verifyWatchRequest { + return func(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, expectedCacheType, cacheType) + + reqReal, ok := request.(*structs.DCSpecificRequest) + require.True(t, ok) + require.Equal(t, expectedDatacenter, reqReal.Datacenter) + } +} + +func genVerifyRootsWatch(expectedDatacenter string) verifyWatchRequest { + return genVerifyDCSpecificWatch(cachetype.ConnectCARootName, expectedDatacenter) +} + +func genVerifyListServicesWatch(expectedDatacenter string) verifyWatchRequest { + return genVerifyDCSpecificWatch(cachetype.CatalogListServicesName, expectedDatacenter) +} + +func verifyDatacentersWatch(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, cachetype.CatalogDatacentersName, cacheType) + + _, ok := request.(*structs.DatacentersRequest) + require.True(t, ok) +} + +func genVerifyLeafWatch(expectedService string, expectedDatacenter string) verifyWatchRequest { + return func(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, cachetype.ConnectCALeafName, cacheType) + + reqReal, ok := request.(*cachetype.ConnectCALeafRequest) + require.True(t, ok) + require.Equal(t, expectedDatacenter, reqReal.Datacenter) + require.Equal(t, expectedService, reqReal.Service) + } +} + +func genVerifyIntentionWatch(expectedService string, expectedDatacenter string) verifyWatchRequest { + return func(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, cachetype.IntentionMatchName, cacheType) + + reqReal, ok := request.(*structs.IntentionQueryRequest) + require.True(t, ok) + require.Equal(t, expectedDatacenter, reqReal.Datacenter) + require.NotNil(t, reqReal.Match) + require.Equal(t, structs.IntentionMatchDestination, reqReal.Match.Type) + require.Len(t, reqReal.Match.Entries, 1) + require.Equal(t, structs.IntentionDefaultNamespace, reqReal.Match.Entries[0].Namespace) + require.Equal(t, expectedService, reqReal.Match.Entries[0].Name) + } +} + +func genVerifyPreparedQueryWatch(expectedName string, expectedDatacenter string) verifyWatchRequest { + return func(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, cachetype.PreparedQueryName, cacheType) + + reqReal, ok := request.(*structs.PreparedQueryExecuteRequest) + require.True(t, ok) + require.Equal(t, expectedDatacenter, reqReal.Datacenter) + require.Equal(t, expectedName, reqReal.QueryIDOrName) + require.Equal(t, true, reqReal.Connect) + } +} + +func genVerifyDiscoveryChainWatch(expectedName string, expectedDatacenter string) verifyWatchRequest { + return func(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, cachetype.CompiledDiscoveryChainName, cacheType) + + reqReal, ok := request.(*structs.DiscoveryChainRequest) + require.True(t, ok) + require.Equal(t, expectedDatacenter, reqReal.Datacenter) + require.Equal(t, expectedName, reqReal.Name) + } +} + +func genVerifyGatewayWatch(expectedDatacenter string) verifyWatchRequest { + return func(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, cachetype.InternalServiceDumpName, cacheType) + + reqReal, ok := request.(*structs.ServiceDumpRequest) + require.True(t, ok) + require.Equal(t, expectedDatacenter, reqReal.Datacenter) + require.True(t, reqReal.UseServiceKind) + require.Equal(t, structs.ServiceKindMeshGateway, reqReal.ServiceKind) + } +} + +func genVerifyServiceSpecificRequest(expectedCacheType, expectedService, expectedFilter, expectedDatacenter string, connect bool) verifyWatchRequest { + return func(t testing.TB, cacheType string, request cache.Request) { + require.Equal(t, expectedCacheType, cacheType) + + reqReal, ok := request.(*structs.ServiceSpecificRequest) + require.True(t, ok) + require.Equal(t, expectedDatacenter, reqReal.Datacenter) + require.Equal(t, expectedService, reqReal.ServiceName) + require.Equal(t, expectedFilter, reqReal.QueryOptions.Filter) + require.Equal(t, connect, reqReal.Connect) + } +} + +func genVerifyServiceWatch(expectedService, expectedFilter, expectedDatacenter string, connect bool) verifyWatchRequest { + return genVerifyServiceSpecificRequest(cachetype.HealthServicesName, expectedService, expectedFilter, expectedDatacenter, connect) +} + +// This test is meant to exercise the various parts of the cache watching done by the state as +// well as its management of the ConfigSnapshot +// +// This test is expressly not calling Watch which in turn would execute the run function in a go +// routine. This allows the test to be fully synchronous and deterministic while still being able +// to validate the logic of most of the watching and state updating. +// +// The general strategy here is to +// +// 1. Initialize a state with a call to newState + setting some of the extra stuff like the CacheNotifier +// We will not be using the CacheNotifier to send notifications but calling handleUpdate ourselves +// 2. Iterate through a list of verification stages performing validation and updates for each. +// a. Ensure that the required watches are in place and validate they are correct +// b. Process a bunch of UpdateEvents by calling handleUpdate +// c. Validate that the ConfigSnapshot has been updated appropriately +func TestState_WatchesAndUpdates(t *testing.T) { + t.Parallel() + + type verificationStage struct { + requiredWatches map[string]verifyWatchRequest + events []cache.UpdateEvent + verifySnapshot func(t testing.TB, snap *ConfigSnapshot) + } + + type testCase struct { + // the state to operate on. the logger, source, cache, + // ctx and cancel fields will be filled in by the test + ns structs.NodeService + sourceDC string + stages []verificationStage + } + + cases := map[string]testCase{ + "initial-gateway": testCase{ + ns: structs.NodeService{ + Kind: structs.ServiceKindMeshGateway, + ID: "mesh-gateway", + Service: "mesh-gateway", + Address: "10.0.1.1", + Port: 443, + }, + sourceDC: "dc1", + stages: []verificationStage{ + verificationStage{ + requiredWatches: map[string]verifyWatchRequest{ + rootsWatchID: genVerifyRootsWatch("dc1"), + serviceListWatchID: genVerifyListServicesWatch("dc1"), + datacentersWatchID: verifyDatacentersWatch, + }, + }, + }, + }, + "connect-proxy": testCase{ + ns: structs.NodeService{ + Kind: structs.ServiceKindConnectProxy, + ID: "web-sidecar-proxy", + Service: "web-sidecar-proxy", + Address: "10.0.1.1", + Port: 443, + Proxy: structs.ConnectProxyConfig{ + DestinationServiceName: "web", + Upstreams: structs.Upstreams{ + structs.Upstream{ + DestinationType: structs.UpstreamDestTypePreparedQuery, + DestinationName: "query", + LocalBindPort: 10001, + }, + structs.Upstream{ + DestinationType: structs.UpstreamDestTypeService, + DestinationName: "api", + LocalBindPort: 10002, + }, + structs.Upstream{ + DestinationType: structs.UpstreamDestTypeService, + DestinationName: "api-failover-remote", + Datacenter: "dc2", + LocalBindPort: 10003, + MeshGateway: structs.MeshGatewayConfig{ + Mode: structs.MeshGatewayModeRemote, + }, + }, + structs.Upstream{ + DestinationType: structs.UpstreamDestTypeService, + DestinationName: "api-failover-local", + Datacenter: "dc2", + LocalBindPort: 10004, + MeshGateway: structs.MeshGatewayConfig{ + Mode: structs.MeshGatewayModeLocal, + }, + }, + structs.Upstream{ + DestinationType: structs.UpstreamDestTypeService, + DestinationName: "api-failover-direct", + Datacenter: "dc2", + LocalBindPort: 10005, + MeshGateway: structs.MeshGatewayConfig{ + Mode: structs.MeshGatewayModeNone, + }, + }, + structs.Upstream{ + DestinationType: structs.UpstreamDestTypeService, + DestinationName: "api-dc2", + LocalBindPort: 10006, + MeshGateway: structs.MeshGatewayConfig{ + Mode: structs.MeshGatewayModeLocal, + }, + }, + }, + MeshGateway: structs.MeshGatewayConfig{ + Mode: structs.MeshGatewayModeLocal, + }, + }, + }, + sourceDC: "dc1", + stages: []verificationStage{ + verificationStage{ + requiredWatches: map[string]verifyWatchRequest{ + rootsWatchID: genVerifyRootsWatch("dc1"), + leafWatchID: genVerifyLeafWatch("web", "dc1"), + intentionsWatchID: genVerifyIntentionWatch("web", "dc1"), + "upstream:prepared_query:query": genVerifyPreparedQueryWatch("query", "dc1"), + "discovery-chain:api": genVerifyDiscoveryChainWatch("api", "dc1"), + "upstream:" + serviceIDPrefix + "api-failover-remote?dc=dc2": genVerifyGatewayWatch("dc2"), + "upstream:" + serviceIDPrefix + "api-failover-local?dc=dc2": genVerifyGatewayWatch("dc1"), + "upstream:" + serviceIDPrefix + "api-failover-direct?dc=dc2": genVerifyServiceWatch("api-failover-direct", "", "dc2", true), + "discovery-chain:api-dc2": genVerifyDiscoveryChainWatch("api-dc2", "dc1"), + }, + events: []cache.UpdateEvent{ + cache.UpdateEvent{ + CorrelationID: "discovery-chain:api", + Result: &structs.DiscoveryChainResponse{ + Chain: TestCompileConfigEntries(t, "api", "default", "dc1"), + }, + Err: nil, + }, + cache.UpdateEvent{ + CorrelationID: "discovery-chain:api", + Result: &structs.DiscoveryChainResponse{ + Chain: TestCompileConfigEntries(t, "api-dc2", "default", "dc1", + &structs.ServiceResolverConfigEntry{ + Kind: structs.ServiceResolver, + Name: "api-dc2", + Redirect: &structs.ServiceResolverRedirect{ + Service: "api", + Datacenter: "dc2", + }, + }, + ), + }, + Err: nil, + }, + }, + }, + verificationStage{ + requiredWatches: map[string]verifyWatchRequest{ + "upstream-target:api,,,dc1:api": genVerifyServiceWatch("api", "", "dc1", true), + "upstream-target:api,,,dc2:api": genVerifyGatewayWatch("dc1"), + }, + }, + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + state, err := newState(&tc.ns, "") + + // verify building the initial state worked + require.NoError(t, err) + require.NotNil(t, state) + + // setup the test logger to use the t.Log + state.logger = testutil.TestLogger(t) + + // setup a new testing cache notifier + cn := newTestCacheNotifier() + state.cache = cn + + // setup the local datacenter information + state.source = &structs.QuerySource{ + Datacenter: tc.sourceDC, + } + + // setup the ctx as initWatches expects this to be there + state.ctx, state.cancel = context.WithCancel(context.Background()) + + // ensure the initial watch setup did not error + require.NoError(t, state.initWatches()) + + // get the initial configuration snapshot + snap := state.initialConfigSnapshot() + + //-------------------------------------------------------------------- + // + // All the nested subtests here are to make failures easier to + // correlate back with the test table + // + //-------------------------------------------------------------------- + + for idx, stage := range tc.stages { + require.True(t, t.Run(fmt.Sprintf("stage-%d", idx), func(t *testing.T) { + for correlationId, verifier := range stage.requiredWatches { + require.True(t, t.Run(correlationId, func(t *testing.T) { + // verify that the watch was initiated + cacheType, request := cn.verifyWatch(t, correlationId) + + // run the verifier if any + if verifier != nil { + verifier(t, cacheType, request) + } + })) + } + + // the state is not currently executing the run method in a goroutine + // therefore we just tell it about the updates + for eveIdx, event := range stage.events { + require.True(t, t.Run(fmt.Sprintf("update-%d", eveIdx), func(t *testing.T) { + require.NoError(t, state.handleUpdate(event, &snap)) + })) + } + + // verify the snapshot + if stage.verifySnapshot != nil { + stage.verifySnapshot(t, &snap) + } + })) + } + }) + } +} diff --git a/agent/structs/connect_proxy_config.go b/agent/structs/connect_proxy_config.go index f83eb5774..d05d6752a 100644 --- a/agent/structs/connect_proxy_config.go +++ b/agent/structs/connect_proxy_config.go @@ -209,6 +209,7 @@ func (u *Upstream) ToAPI() api.Upstream { LocalBindAddress: u.LocalBindAddress, LocalBindPort: u.LocalBindPort, Config: u.Config, + MeshGateway: u.MeshGateway.ToAPI(), } } diff --git a/agent/structs/connect_proxy_config_test.go b/agent/structs/connect_proxy_config_test.go index ba587c126..7c3f9ef98 100644 --- a/agent/structs/connect_proxy_config_test.go +++ b/agent/structs/connect_proxy_config_test.go @@ -24,12 +24,18 @@ func TestConnectProxyConfig_ToAPI(t *testing.T) { Config: map[string]interface{}{ "foo": "bar", }, + MeshGateway: MeshGatewayConfig{ + Mode: MeshGatewayModeLocal, + }, Upstreams: Upstreams{ { DestinationType: UpstreamDestTypeService, DestinationName: "foo", Datacenter: "dc1", LocalBindPort: 1234, + MeshGateway: MeshGatewayConfig{ + Mode: MeshGatewayModeLocal, + }, }, { DestinationType: UpstreamDestTypePreparedQuery, @@ -48,12 +54,18 @@ func TestConnectProxyConfig_ToAPI(t *testing.T) { Config: map[string]interface{}{ "foo": "bar", }, + MeshGateway: api.MeshGatewayConfig{ + Mode: api.MeshGatewayModeLocal, + }, Upstreams: []api.Upstream{ { DestinationType: UpstreamDestTypeService, DestinationName: "foo", Datacenter: "dc1", LocalBindPort: 1234, + MeshGateway: api.MeshGatewayConfig{ + Mode: api.MeshGatewayModeLocal, + }, }, { DestinationType: UpstreamDestTypePreparedQuery, diff --git a/sdk/testutil/testlog.go b/sdk/testutil/testlog.go index 6daee3593..3c284f434 100644 --- a/sdk/testutil/testlog.go +++ b/sdk/testutil/testlog.go @@ -16,7 +16,7 @@ func init() { } func TestLogger(t testing.TB) *log.Logger { - return log.New(&testWriter{t}, "test: ", log.LstdFlags) + return log.New(&testWriter{t}, t.Name()+": ", log.LstdFlags) } func TestLoggerWithName(t testing.TB, name string) *log.Logger {