diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index a6e516111..1613e753a 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -5,9 +5,10 @@ import ( "fmt" "time" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/consul/agent/structs" pbacl "github.com/hashicorp/consul/proto/pbacl" - memdb "github.com/hashicorp/go-memdb" ) type TokenPoliciesIndex struct { diff --git a/agent/subscribe/subscribe.go b/agent/subscribe/subscribe.go index 7908a410b..0b1d0a6a9 100644 --- a/agent/subscribe/subscribe.go +++ b/agent/subscribe/subscribe.go @@ -88,7 +88,6 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub for { events, err := sub.Next(ctx) switch { - // TODO: test case case errors.Is(err, stream.ErrSubscriptionClosed): h.Logger.Trace("subscription reset by server", "stream_id", streamID) return status.Error(codes.Aborted, err.Error()) diff --git a/agent/subscribe/subscribe_test.go b/agent/subscribe/subscribe_test.go index bacd94253..a005f6eee 100644 --- a/agent/subscribe/subscribe_test.go +++ b/agent/subscribe/subscribe_test.go @@ -10,9 +10,12 @@ import ( "github.com/google/go-cmp/cmp" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" gogrpc "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/consul/state" @@ -688,15 +691,11 @@ node "node1" { chEvents := make(chan eventOrError, 0) go recvEvents(chEvents, streamHandle) - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - snapshotEvents = append(snapshotEvents, getEvent(t, chEvents)) - } + event := getEvent(t, chEvents) + require.Equal(t, "foo", event.GetServiceHealth().CheckServiceNode.Service.Service) + require.Equal(t, "node1", event.GetServiceHealth().CheckServiceNode.Node.Node) - require.Len(t, snapshotEvents, 2) - require.Equal(t, "foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) - require.Equal(t, "node1", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) - require.True(t, snapshotEvents[1].GetEndOfSnapshot()) + require.True(t, getEvent(t, chEvents).GetEndOfSnapshot()) // Update the service with a new port to trigger a new event. req := &structs.RegisterRequest{ @@ -719,7 +718,7 @@ node "node1" { } require.NoError(t, backend.store.EnsureRegistration(ids.Next("reg4"), req)) - event := getEvent(t, chEvents) + event = getEvent(t, chEvents) service := event.GetServiceHealth().CheckServiceNode.Service require.Equal(t, "foo", service.Service) require.Equal(t, int32(1234), service.Port) @@ -794,549 +793,79 @@ node "node1" { } } -/* -func TestStreaming_Subscribe_ACLUpdate(t *testing.T) { - t.Parallel() +func TestServer_Subscribe_IntegrationWithBackend_ACLUpdate(t *testing.T) { + backend, err := newTestBackend() + require.NoError(t, err) + srv := &Server{Backend: backend, Logger: hclog.New(nil)} + addr := newTestServer(t, srv) - require := require.New(t) - dir, _, server, codec := testACLFilterServerWithConfigFn(t, func(c *Config) { - c.ACLDatacenter = "dc1" - c.ACLsEnabled = true - c.ACLMasterToken = "root" - c.ACLDefaultPolicy = "deny" - c.ACLEnforceVersion8 = true - c.GRPCEnabled = true + rules := ` +service "foo" { + policy = "write" +} +node "node1" { + policy = "write" +} +` + authorizer, err := acl.NewAuthorizerFromRules( + "1", 0, rules, acl.SyntaxCurrent, + &acl.Config{WildcardName: structs.WildcardSpecifier}, + nil) + require.NoError(t, err) + authorizer = acl.NewChainedAuthorizer([]acl.Authorizer{authorizer, acl.DenyAll()}) + require.Equal(t, acl.Deny, authorizer.NodeRead("denied", nil)) + + // TODO: is there any easy way to do this with the acl package? + token := "this-token-is-good" + backend.authorizer = func(tok string) acl.Authorizer { + if tok == token { + return authorizer + } + return acl.DenyAll() + } + + ids := newCounter() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + + conn, err := gogrpc.DialContext(ctx, addr.String(), gogrpc.WithInsecure()) + require.NoError(t, err) + t.Cleanup(logError(t, conn.Close)) + streamClient := pbsubscribe.NewStateChangeSubscriptionClient(conn) + + streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ + Topic: pbsubscribe.Topic_ServiceHealth, + Key: "foo", + Token: token, }) - defer os.RemoveAll(dir) - defer server.Shutdown() - defer codec.Close() + require.NoError(t, err) - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() + chEvents := make(chan eventOrError, 0) + go recvEvents(chEvents, streamHandle) - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1", testrpc.WithToken("root")) + require.True(t, getEvent(t, chEvents).GetEndOfSnapshot()) - // Create a new token/policy that only has access to one node. - var token structs.ACLToken + tokenID, err := uuid.GenerateUUID() + require.NoError(t, err) - policy, err := upsertTestPolicyWithRules(codec, "root", "dc1", fmt.Sprintf(` - service "foo" { - policy = "write" - } - node "%s" { - policy = "write" - } - `, server.config.NodeName)) - require.NoError(err) - - arg := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: structs.ACLToken{ - Description: "Service/node token", - Policies: []structs.ACLTokenPolicyLink{ - structs.ACLTokenPolicyLink{ - ID: policy.ID, - }, - }, - Local: false, - }, - WriteRequest: structs.WriteRequest{Token: "root"}, + aclToken := &structs.ACLToken{ + AccessorID: tokenID, + SecretID: token, + Rules: "", } + require.NoError(t, backend.store.ACLTokenSet(ids.Next("update"), aclToken, false)) - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) - auth, err := server.ResolveToken(token.SecretID) - require.NoError(err) - require.Equal(auth.NodeRead("denied", nil), acl.Deny) - - // Set up the gRPC client. - conn, err := client.GRPCConn() - require.NoError(err) - streamClient := pbsubscribe.NewConsulClient(conn) - - // Start a Subscribe call to our streaming endpoint for the service we have access to. - { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{ - Topic: pbsubscribe.Topic_ServiceHealth, - Key: "foo", - Token: token.SecretID, - }) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - // Read events off the pbsubscribe. - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(5 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } - } - require.Len(snapshotEvents, 2) - require.Equal("foo", snapshotEvents[0].GetServiceHealth().CheckServiceNode.Service.Service) - require.Equal(server.config.NodeName, snapshotEvents[0].GetServiceHealth().CheckServiceNode.Node.Node) - require.True(snapshotEvents[1].GetEndOfSnapshot()) - - // Update a different token and make sure we don't see an event. - arg2 := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: structs.ACLToken{ - Description: "Ignored token", - Policies: []structs.ACLTokenPolicyLink{ - structs.ACLTokenPolicyLink{ - ID: policy.ID, - }, - }, - Local: false, - }, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - var ignoredToken structs.ACLToken - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg2, &ignoredToken)) - - select { - case event := <-eventCh: - t.Fatalf("should not have received event: %v", event) - case <-time.After(500 * time.Millisecond): - } - - // Update our token to trigger a refresh event. - token.Policies = []structs.ACLTokenPolicyLink{} - arg := structs.ACLTokenSetRequest{ - Datacenter: "dc1", - ACLToken: token, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - require.NoError(msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &arg, &token)) - - select { - case event := <-eventCh: - require.True(event.GetResetStream()) - // 500 ms was not enough in CI apparently... - case <-time.After(2 * time.Second): - t.Fatalf("did not receive reload event") - } + select { + case item := <-chEvents: + require.Error(t, item.err, "got event: %v", item.event) + s, _ := status.FromError(item.err) + require.Equal(t, codes.Aborted, s.Code()) + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for aborted error") } } -func TestStreaming_TLSEnabled(t *testing.T) { - t.Parallel() - - require := require.New(t) - dir1, conf1 := testServerConfig(t) - conf1.VerifyIncoming = true - conf1.VerifyOutgoing = true - conf1.GRPCEnabled = true - configureTLS(conf1) - server, err := newServer(conf1) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir1) - defer server.Shutdown() - - dir2, conf2 := testClientConfig(t) - conf2.VerifyOutgoing = true - conf2.GRPCEnabled = true - configureTLS(conf2) - client, err := NewClient(conf2) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir2) - defer client.Shutdown() - - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") - - // Register a dummy node with our service on it. - { - req := &structs.RegisterRequest{ - Node: "node1", - Address: "3.4.5.6", - Datacenter: "dc1", - Service: &structs.NodeService{ - ID: "redis1", - Service: "redis", - Address: "3.4.5.6", - Port: 8080, - }, - } - var out struct{} - require.NoError(server.RPC("Catalog.Register", &req, &out)) - } - - // Start a Subscribe call to our streaming endpoint from the client. - { - conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } - } - - // Make sure the snapshot events come back with no issues. - require.Len(snapshotEvents, 2) - } - - // Start a Subscribe call to our streaming endpoint from the server's loopback client. - { - conn, err := server.GRPCConn() - require.NoError(err) - - retryFailedConn(t, conn) - - streamClient := pbsubscribe.NewConsulClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - streamHandle, err := streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.NoError(err) - - // Start a goroutine to read updates off the pbsubscribe. - eventCh := make(chan *pbsubscribe.Event, 0) - go recvEvents(t, eventCh, streamHandle) - - var snapshotEvents []*pbsubscribe.Event - for i := 0; i < 2; i++ { - select { - case event := <-eventCh: - snapshotEvents = append(snapshotEvents, event) - case <-time.After(3 * time.Second): - t.Fatalf("did not receive events past %d", len(snapshotEvents)) - } - } - - // Make sure the snapshot events come back with no issues. - require.Len(snapshotEvents, 2) - } -} - -func TestStreaming_TLSReload(t *testing.T) { - t.Parallel() - - // Set up a server with initially bad certificates. - require := require.New(t) - dir1, conf1 := testServerConfig(t) - conf1.VerifyIncoming = true - conf1.VerifyOutgoing = true - conf1.CAFile = "../../test/ca/root.cer" - conf1.CertFile = "../../test/key/ssl-cert-snakeoil.pem" - conf1.KeyFile = "../../test/key/ssl-cert-snakeoil.key" - conf1.GRPCEnabled = true - - server, err := newServer(conf1) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir1) - defer server.Shutdown() - - // Set up a client with valid certs and verify_outgoing = true - dir2, conf2 := testClientConfig(t) - conf2.VerifyOutgoing = true - conf2.GRPCEnabled = true - configureTLS(conf2) - client, err := NewClient(conf2) - if err != nil { - t.Fatalf("err: %v", err) - } - defer os.RemoveAll(dir2) - defer client.Shutdown() - - testrpc.WaitForLeader(t, server.RPC, "dc1") - - // Subscribe calls should fail initially - joinLAN(t, client, server) - conn, err := client.GRPCConn() - require.NoError(err) - { - streamClient := pbsubscribe.NewConsulClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - _, err = streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.Error(err, "tls: bad certificate") - } - - // Reload the server with valid certs - newConf := server.config.ToTLSUtilConfig() - newConf.CertFile = "../../test/key/ourdomain.cer" - newConf.KeyFile = "../../test/key/ourdomain.key" - server.tlsConfigurator.Update(newConf) - - // Try the subscribe call again - { - retryFailedConn(t, conn) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - streamClient := pbsubscribe.NewConsulClient(conn) - _, err = streamClient.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - require.NoError(err) - } -} - -// retryFailedConn forces the ClientConn to reset its backoff timer and retry the connection, -// to simulate the client eventually retrying after the initial failure. This is used both to simulate -// retrying after an expected failure as well as to avoid flakiness when running many tests in parallel. -func retryFailedConn(t *testing.T, conn *grpc.ClientConn) { - state := conn.GetState() - if state.String() != "TRANSIENT_FAILURE" { - return - } - - // If the connection has failed, retry and wait for a state change. - conn.ResetConnectBackoff() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - require.True(t, conn.WaitForStateChange(ctx, state)) -} - -func TestStreaming_DeliversAllMessages(t *testing.T) { - // This is a fuzz/probabilistic test to try to provoke streaming into dropping - // messages. There is a bug in the initial implementation that should make - // this fail. While we can't be certain a pass means it's correct, it is - // useful for finding bugs in our concurrency design. - - // The issue is that when updates are coming in fast such that updates occur - // in between us making the snapshot and beginning the stream updates, we - // shouldn't miss anything. - - // To test this, we will run a background goroutine that will write updates as - // fast as possible while we then try to stream the results and ensure that we - // see every change. We'll make the updates monotonically increasing so we can - // easily tell if we missed one. - - require := require.New(t) - dir1, server := testServerWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.Bootstrap = true - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir1) - defer server.Shutdown() - codec := rpcClient(t, server) - defer codec.Close() - - dir2, client := testClientWithConfig(t, func(c *Config) { - c.Datacenter = "dc1" - c.NodeName = uniqueNodeName(t.Name()) - c.GRPCEnabled = true - }) - defer os.RemoveAll(dir2) - defer client.Shutdown() - - // Try to join - testrpc.WaitForLeader(t, server.RPC, "dc1") - joinLAN(t, client, server) - testrpc.WaitForTestAgent(t, client.RPC, "dc1") - - // Register a whole bunch of service instances so that the initial snapshot on - // subscribe is big enough to take a bit of time to load giving more - // opportunity for missed updates if there is a bug. - for i := 0; i < 1000; i++ { - req := &structs.RegisterRequest{ - Node: fmt.Sprintf("node-redis-%03d", i), - Address: "3.4.5.6", - Datacenter: "dc1", - Service: &structs.NodeService{ - ID: fmt.Sprintf("redis-%03d", i), - Service: "redis", - Port: 11211, - }, - } - var out struct{} - require.NoError(server.RPC("Catalog.Register", &req, &out)) - } - - // Start background writer - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - go func() { - // Update the registration with a monotonically increasing port as fast as - // we can. - req := &structs.RegisterRequest{ - Node: "node1", - Address: "3.4.5.6", - Datacenter: "dc1", - Service: &structs.NodeService{ - ID: "redis-canary", - Service: "redis", - Port: 0, - }, - } - for { - if ctx.Err() != nil { - return - } - var out struct{} - require.NoError(server.RPC("Catalog.Register", &req, &out)) - req.Service.Port++ - if req.Service.Port > 100 { - return - } - time.Sleep(1 * time.Millisecond) - } - }() - - // Now start a whole bunch of streamers in parallel to maximise chance of - // catching a race. - conn, err := client.GRPCConn() - require.NoError(err) - - streamClient := pbsubscribe.NewConsulClient(conn) - - n := 5 - var wg sync.WaitGroup - var updateCount uint64 - // Buffered error chan so that workers can exit and terminate wg without - // blocking on send. We collect errors this way since t isn't thread safe. - errCh := make(chan error, n) - for i := 0; i < n; i++ { - wg.Add(1) - go verifyMonotonicStreamUpdates(ctx, t, streamClient, &wg, i, &updateCount, errCh) - } - - // Wait until all subscribers have verified the first bunch of updates all got - // delivered. - wg.Wait() - - close(errCh) - - // Require that none of them errored. Since we closed the chan above this loop - // should terminate immediately if no errors were buffered. - for err := range errCh { - require.NoError(err) - } - - // Sanity check that at least some non-snapshot messages were delivered. We - // can't know exactly how many because it's timing dependent based on when - // each subscribers snapshot occurs. - require.True(atomic.LoadUint64(&updateCount) > 0, - "at least some of the subscribers should have received non-snapshot updates") -} - -type testLogger interface { - Logf(format string, args ...interface{}) -} - -func verifyMonotonicStreamUpdates(ctx context.Context, logger testLogger, client pbsubscribe.StateChangeSubscriptionClient, wg *sync.WaitGroup, i int, updateCount *uint64, errCh chan<- error) { - defer wg.Done() - streamHandle, err := client.Subscribe(ctx, &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}) - if err != nil { - if strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "context canceled") { - logger.Logf("subscriber %05d: context cancelled before loop") - return - } - errCh <- err - return - } - - snapshotDone := false - expectPort := 0 - for { - event, err := streamHandle.Recv() - if err == io.EOF { - break - } - if err != nil { - if strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "context canceled") { - break - } - errCh <- err - return - } - - // Ignore snapshot message - if event.GetEndOfSnapshot() || event.GetResumeStream() { - snapshotDone = true - logger.Logf("subscriber %05d: snapshot done, expect next port to be %d", i, expectPort) - } else if snapshotDone { - // Verify we get all updates in order - svc, err := svcOrErr(event) - if err != nil { - errCh <- err - return - } - if expectPort != svc.Port { - errCh <- fmt.Errorf("subscriber %05d: missed %d update(s)!", i, svc.Port-expectPort) - return - } - atomic.AddUint64(updateCount, 1) - logger.Logf("subscriber %05d: got event with correct port=%d", i, expectPort) - expectPort++ - } else { - // This is a snapshot update. Check if it's an update for the canary - // instance that got applied before our snapshot was sent (likely) - svc, err := svcOrErr(event) - if err != nil { - errCh <- err - return - } - if svc.ID == "redis-canary" { - // Update the expected port we see in the next update to be one more - // than the port in the snapshot. - expectPort = svc.Port + 1 - logger.Logf("subscriber %05d: saw canary in snapshot with port %d", i, svc.Port) - } - } - if expectPort > 100 { - return - } - } -} - -func svcOrErr(event *pbsubscribe.Event) (*pbservice.NodeService, error) { - health := event.GetServiceHealth() - if health == nil { - return nil, fmt.Errorf("not a health event: %#v", event) - } - csn := health.CheckServiceNode - if csn == nil { - return nil, fmt.Errorf("nil CSN: %#v", event) - } - if csn.Service == nil { - return nil, fmt.Errorf("nil service: %#v", event) - } - return csn.Service, nil -} -*/ - func logError(t *testing.T, f func() error) func() { return func() { if err := f(); err != nil {