diff --git a/agent/consul/acl_oss.go b/agent/consul/acl_oss.go index be493e5a3..6808ad4f7 100644 --- a/agent/consul/acl_oss.go +++ b/agent/consul/acl_oss.go @@ -12,6 +12,8 @@ import ( // EnterpriseACLResolverDelegate stub type EnterpriseACLResolverDelegate interface{} +func (s *Server) fillReplicationEnterpriseMeta(_ *structs.EnterpriseMeta) {} + func newEnterpriseACLConfig(*log.Logger) *acl.EnterpriseACLConfig { return nil } diff --git a/agent/consul/acl_replication.go b/agent/consul/acl_replication.go index 2e0cb1e99..50e8644b9 100644 --- a/agent/consul/acl_replication.go +++ b/agent/consul/acl_replication.go @@ -112,6 +112,7 @@ func (s *Server) fetchACLRoles(lastRemoteIndex uint64) (*structs.ACLRoleListResp Token: s.tokens.ReplicationToken(), }, } + s.fillReplicationEnterpriseMeta(&req.EnterpriseMeta) var response structs.ACLRoleListResponse if err := s.RPC("ACL.RoleList", &req, &response); err != nil { @@ -149,6 +150,7 @@ func (s *Server) fetchACLPolicies(lastRemoteIndex uint64) (*structs.ACLPolicyLis Token: s.tokens.ReplicationToken(), }, } + s.fillReplicationEnterpriseMeta(&req.EnterpriseMeta) var response structs.ACLPolicyListResponse if err := s.RPC("ACL.PolicyList", &req, &response); err != nil { @@ -342,6 +344,7 @@ func (s *Server) fetchACLTokens(lastRemoteIndex uint64) (*structs.ACLTokenListRe IncludeLocal: false, IncludeGlobal: true, } + s.fillReplicationEnterpriseMeta(&req.EnterpriseMeta) var response structs.ACLTokenListResponse if err := s.RPC("ACL.TokenList", &req, &response); err != nil { diff --git a/agent/consul/acl_replication_types.go b/agent/consul/acl_replication_types.go index a7a703dac..97d0d1316 100644 --- a/agent/consul/acl_replication_types.go +++ b/agent/consul/acl_replication_types.go @@ -34,7 +34,10 @@ func (r *aclTokenReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (i func (r *aclTokenReplicator) FetchLocal(srv *Server) (int, uint64, error) { r.local = nil - idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "", "", structs.ReplicationEnterpriseMeta()) + var entMeta structs.EnterpriseMeta + srv.fillReplicationEnterpriseMeta(&entMeta) + + idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "", "", &entMeta) if err != nil { return 0, 0, err } @@ -155,7 +158,10 @@ func (r *aclPolicyReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) ( func (r *aclPolicyReplicator) FetchLocal(srv *Server) (int, uint64, error) { r.local = nil - idx, local, err := srv.fsm.State().ACLPolicyList(nil, structs.ReplicationEnterpriseMeta()) + var entMeta structs.EnterpriseMeta + srv.fillReplicationEnterpriseMeta(&entMeta) + + idx, local, err := srv.fsm.State().ACLPolicyList(nil, &entMeta) if err != nil { return 0, 0, err } @@ -265,7 +271,10 @@ func (r *aclRoleReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (in func (r *aclRoleReplicator) FetchLocal(srv *Server) (int, uint64, error) { r.local = nil - idx, local, err := srv.fsm.State().ACLRoleList(nil, "", nil) + var entMeta structs.EnterpriseMeta + srv.fillReplicationEnterpriseMeta(&entMeta) + + idx, local, err := srv.fsm.State().ACLRoleList(nil, "", &entMeta) if err != nil { return 0, 0, err } diff --git a/agent/consul/replication.go b/agent/consul/replication.go index 1aae6d7ac..bf0eb1eeb 100644 --- a/agent/consul/replication.go +++ b/agent/consul/replication.go @@ -8,6 +8,7 @@ import ( "sync/atomic" "time" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/lib" "golang.org/x/time/rate" ) @@ -20,11 +21,15 @@ const ( replicationDefaultRate = 1 ) +type ReplicatorDelegate interface { + Replicate(ctx context.Context, lastRemoteIndex uint64) (index uint64, exit bool, err error) +} + type ReplicatorConfig struct { // Name to be used in various logging Name string - // Function to perform the actual replication - ReplicateFn ReplicatorFunc + // Delegate to perform each round of replication + Delegate ReplicatorDelegate // The number of replication rounds per second that are allowed Rate int // The number of replication rounds that can be done in a burst @@ -37,13 +42,11 @@ type ReplicatorConfig struct { Logger *log.Logger } -type ReplicatorFunc func(ctx context.Context, lastRemoteIndex uint64) (index uint64, exit bool, err error) - type Replicator struct { name string limiter *rate.Limiter waiter *lib.RetryWaiter - replicateFn ReplicatorFunc + delegate ReplicatorDelegate logger *log.Logger lastRemoteIndex uint64 } @@ -52,8 +55,8 @@ func NewReplicator(config *ReplicatorConfig) (*Replicator, error) { if config == nil { return nil, fmt.Errorf("Cannot create the Replicator without a config") } - if config.ReplicateFn == nil { - return nil, fmt.Errorf("Cannot create the Replicator without a ReplicateFn set in the config") + if config.Delegate == nil { + return nil, fmt.Errorf("Cannot create the Replicator without a Delegate set in the config") } if config.Logger == nil { config.Logger = log.New(os.Stderr, "", log.LstdFlags) @@ -71,11 +74,11 @@ func NewReplicator(config *ReplicatorConfig) (*Replicator, error) { } waiter := lib.NewRetryWaiter(minFailures, 0*time.Second, maxWait, lib.NewJitterRandomStagger(10)) return &Replicator{ - name: config.Name, - limiter: limiter, - waiter: waiter, - replicateFn: config.ReplicateFn, - logger: config.Logger, + name: config.Name, + limiter: limiter, + waiter: waiter, + delegate: config.Delegate, + logger: config.Logger, }, nil } @@ -91,7 +94,7 @@ func (r *Replicator) Run(ctx context.Context) error { } // Perform a single round of replication - index, exit, err := r.replicateFn(ctx, atomic.LoadUint64(&r.lastRemoteIndex)) + index, exit, err := r.delegate.Replicate(ctx, atomic.LoadUint64(&r.lastRemoteIndex)) if exit { // the replication function told us to exit return nil @@ -120,3 +123,143 @@ func (r *Replicator) Run(ctx context.Context) error { func (r *Replicator) Index() uint64 { return atomic.LoadUint64(&r.lastRemoteIndex) } + +type ReplicatorFunc func(ctx context.Context, lastRemoteIndex uint64) (index uint64, exit bool, err error) + +type FunctionReplicator struct { + ReplicateFn ReplicatorFunc +} + +func (r *FunctionReplicator) Replicate(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + return r.ReplicateFn(ctx, lastRemoteIndex) +} + +type IndexReplicatorDiff struct { + NumUpdates int + Updates interface{} + NumDeletions int + Deletions interface{} +} + +type IndexReplicatorDelegate interface { + // SingularNoun is the singular form of the item being replicated. + SingularNoun() string + + // PluralNoun is the plural form of the item being replicated. + PluralNoun() string + + // Name to use when emitting metrics + MetricName() string + + // FetchRemote retrieves items newer than the provided index from the + // remote datacenter (for diffing purposes). + FetchRemote(lastRemoteIndex uint64) (int, interface{}, uint64, error) + + // FetchLocal retrieves items from the current datacenter (for diffing + // purposes). + FetchLocal() (int, interface{}, error) + + DiffRemoteAndLocalState(local interface{}, remote interface{}, lastRemoteIndex uint64) (*IndexReplicatorDiff, error) + + PerformDeletions(ctx context.Context, deletions interface{}) (exit bool, err error) + + PerformUpdates(ctx context.Context, updates interface{}) (exit bool, err error) +} + +type IndexReplicator struct { + Delegate IndexReplicatorDelegate + Logger *log.Logger +} + +func (r *IndexReplicator) Replicate(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + fetchStart := time.Now() + lenRemote, remote, remoteIndex, err := r.Delegate.FetchRemote(lastRemoteIndex) + metrics.MeasureSince([]string{"leader", "replication", r.Delegate.MetricName(), "fetch"}, fetchStart) + + if err != nil { + return 0, false, fmt.Errorf("failed to retrieve %s: %v", r.Delegate.PluralNoun(), err) + } + + r.Logger.Printf("[DEBUG] replication: finished fetching %s: %d", r.Delegate.PluralNoun(), lenRemote) + + // Need to check if we should be stopping. This will be common as the fetching process is a blocking + // RPC which could have been hanging around for a long time and during that time leadership could + // have been lost. + select { + case <-ctx.Done(): + return 0, true, nil + default: + // do nothing + } + + // Measure everything after the remote query, which can block for long + // periods of time. This metric is a good measure of how expensive the + // replication process is. + defer metrics.MeasureSince([]string{"leader", "replication", r.Delegate.MetricName(), "apply"}, time.Now()) + + lenLocal, local, err := r.Delegate.FetchLocal() + if err != nil { + return 0, false, fmt.Errorf("failed to retrieve local %s: %v", r.Delegate.PluralNoun(), err) + } + + // If the remote index ever goes backwards, it's a good indication that + // the remote side was rebuilt and we should do a full sync since we + // can't make any assumptions about what's going on. + // + // Resetting lastRemoteIndex to 0 will work because we never consider local + // raft indices. Instead we compare the raft modify index in the response object + // with the lastRemoteIndex (only when we already have a config entry of the same kind/name) + // to determine if an update is needed. Resetting lastRemoteIndex to 0 then has the affect + // of making us think all the local state is out of date and any matching entries should + // still be updated. + // + // The lastRemoteIndex is not used when the entry exists either only in the local state or + // only in the remote state. In those situations we need to either delete it or create it. + if remoteIndex < lastRemoteIndex { + r.Logger.Printf("[WARN] replication: %[1]s replication remote index moved backwards (%d to %d), forcing a full %[1]s sync", r.Delegate.SingularNoun(), lastRemoteIndex, remoteIndex) + lastRemoteIndex = 0 + } + + r.Logger.Printf("[DEBUG] replication: %s replication - local: %d, remote: %d", r.Delegate.SingularNoun(), lenLocal, lenRemote) + + // Calculate the changes required to bring the state into sync and then + // apply them. + diff, err := r.Delegate.DiffRemoteAndLocalState(local, remote, lastRemoteIndex) + if err != nil { + return 0, false, fmt.Errorf("failed to diff %s local and remote states: %v", r.Delegate.SingularNoun(), err) + } + + r.Logger.Printf("[DEBUG] replication: %s replication - deletions: %d, updates: %d", r.Delegate.SingularNoun(), diff.NumDeletions, diff.NumUpdates) + + if diff.NumDeletions > 0 { + r.Logger.Printf("[DEBUG] replication: %s replication - performing %d deletions", r.Delegate.SingularNoun(), diff.NumDeletions) + + exit, err := r.Delegate.PerformDeletions(ctx, diff.Deletions) + if exit { + return 0, true, nil + } + + if err != nil { + return 0, false, fmt.Errorf("failed to apply local %s deletions: %v", r.Delegate.SingularNoun(), err) + } + r.Logger.Printf("[DEBUG] replication: %s replication - finished deletions", r.Delegate.SingularNoun()) + } + + if diff.NumUpdates > 0 { + r.Logger.Printf("[DEBUG] replication: %s replication - performing %d updates", r.Delegate.SingularNoun(), diff.NumUpdates) + + exit, err := r.Delegate.PerformUpdates(ctx, diff.Updates) + if exit { + return 0, true, nil + } + + if err != nil { + return 0, false, fmt.Errorf("failed to apply local %s updates: %v", r.Delegate.SingularNoun(), err) + } + r.Logger.Printf("[DEBUG] replication: %s replication - finished updates", r.Delegate.SingularNoun()) + } + + // Return the index we got back from the remote side, since we've synced + // up with the remote state as of that index. + return remoteIndex, false, nil +} diff --git a/agent/consul/replication_test.go b/agent/consul/replication_test.go index 145a7635c..cbb88cc16 100644 --- a/agent/consul/replication_test.go +++ b/agent/consul/replication_test.go @@ -2,9 +2,11 @@ package consul import ( "context" + "fmt" "testing" "github.com/hashicorp/consul/sdk/testutil" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -13,8 +15,10 @@ func TestReplicationRestart(t *testing.T) { config := ReplicatorConfig{ Name: "mock", - ReplicateFn: func(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { - return 1, false, nil + Delegate: &FunctionReplicator{ + ReplicateFn: func(ctx context.Context, lastRemoteIndex uint64) (uint64, bool, error) { + return 1, false, nil + }, }, Rate: 1, @@ -30,3 +34,239 @@ func TestReplicationRestart(t *testing.T) { // Previously this would have segfaulted mgr.Stop("mock") } + +type indexReplicatorTestDelegate struct { + mock.Mock +} + +func (d *indexReplicatorTestDelegate) SingularNoun() string { + return "test" +} + +func (d *indexReplicatorTestDelegate) PluralNoun() string { + return "tests" +} + +func (d *indexReplicatorTestDelegate) MetricName() string { + return "test" +} + +func (d *indexReplicatorTestDelegate) FetchRemote(lastRemoteIndex uint64) (int, interface{}, uint64, error) { + ret := d.Called(lastRemoteIndex) + return ret.Int(0), ret.Get(1), ret.Get(2).(uint64), ret.Error(3) +} + +func (d *indexReplicatorTestDelegate) FetchLocal() (int, interface{}, error) { + ret := d.Called() + return ret.Int(0), ret.Get(1), ret.Error(2) +} + +func (d *indexReplicatorTestDelegate) DiffRemoteAndLocalState(local interface{}, remote interface{}, lastRemoteIndex uint64) (*IndexReplicatorDiff, error) { + ret := d.Called(local, remote, lastRemoteIndex) + return ret.Get(0).(*IndexReplicatorDiff), ret.Error(1) +} + +func (d *indexReplicatorTestDelegate) PerformDeletions(ctx context.Context, deletions interface{}) (exit bool, err error) { + // ignore the context for the call + ret := d.Called(deletions) + return ret.Bool(0), ret.Error(1) +} + +func (d *indexReplicatorTestDelegate) PerformUpdates(ctx context.Context, updates interface{}) (exit bool, err error) { + // ignore the context for the call + ret := d.Called(updates) + return ret.Bool(0), ret.Error(1) +} + +func TestIndexReplicator(t *testing.T) { + t.Parallel() + + t.Run("Remote Fetch Error", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(0)).Return(0, nil, uint64(0), fmt.Errorf("induced error")) + + idx, done, err := replicator.Replicate(context.Background(), 0) + + require.Equal(t, uint64(0), idx) + require.False(t, done) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to retrieve tests: induced error") + delegate.AssertExpectations(t) + }) + + t.Run("Local Fetch Error", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(1, nil, uint64(1), nil) + delegate.On("FetchLocal").Return(0, nil, fmt.Errorf("induced error")) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(0), idx) + require.False(t, done) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to retrieve local tests: induced error") + delegate.AssertExpectations(t) + }) + + t.Run("Diff Error", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(1, nil, uint64(1), nil) + delegate.On("FetchLocal").Return(1, nil, nil) + // this also is verifying that when the remote index goes backwards then we reset the index to 0 + delegate.On("DiffRemoteAndLocalState", nil, nil, uint64(0)).Return(&IndexReplicatorDiff{}, fmt.Errorf("induced error")) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(0), idx) + require.False(t, done) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to diff test local and remote states: induced error") + delegate.AssertExpectations(t) + }) + + t.Run("No Change", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(1, nil, uint64(4), nil) + delegate.On("FetchLocal").Return(1, nil, nil) + delegate.On("DiffRemoteAndLocalState", nil, nil, uint64(3)).Return(&IndexReplicatorDiff{}, nil) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(4), idx) + require.False(t, done) + require.NoError(t, err) + delegate.AssertExpectations(t) + }) + + t.Run("Deletion Error", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(1, nil, uint64(4), nil) + delegate.On("FetchLocal").Return(1, nil, nil) + delegate.On("DiffRemoteAndLocalState", nil, nil, uint64(3)).Return(&IndexReplicatorDiff{NumDeletions: 1}, nil) + delegate.On("PerformDeletions", nil).Return(false, fmt.Errorf("induced error")) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(0), idx) + require.False(t, done) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to apply local test deletions: induced error") + delegate.AssertExpectations(t) + }) + + t.Run("Deletion Exit", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(1, nil, uint64(4), nil) + delegate.On("FetchLocal").Return(1, nil, nil) + delegate.On("DiffRemoteAndLocalState", nil, nil, uint64(3)).Return(&IndexReplicatorDiff{NumDeletions: 1}, nil) + delegate.On("PerformDeletions", nil).Return(true, nil) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(0), idx) + require.True(t, done) + require.NoError(t, err) + delegate.AssertExpectations(t) + }) + + t.Run("Update Error", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(1, nil, uint64(4), nil) + delegate.On("FetchLocal").Return(1, nil, nil) + delegate.On("DiffRemoteAndLocalState", nil, nil, uint64(3)).Return(&IndexReplicatorDiff{NumUpdates: 1}, nil) + delegate.On("PerformUpdates", nil).Return(false, fmt.Errorf("induced error")) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(0), idx) + require.False(t, done) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to apply local test updates: induced error") + delegate.AssertExpectations(t) + }) + + t.Run("Update Exit", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(1, nil, uint64(4), nil) + delegate.On("FetchLocal").Return(1, nil, nil) + delegate.On("DiffRemoteAndLocalState", nil, nil, uint64(3)).Return(&IndexReplicatorDiff{NumUpdates: 1}, nil) + delegate.On("PerformUpdates", nil).Return(true, nil) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(0), idx) + require.True(t, done) + require.NoError(t, err) + delegate.AssertExpectations(t) + }) + + t.Run("All Good", func(t *testing.T) { + delegate := &indexReplicatorTestDelegate{} + + replicator := IndexReplicator{ + Delegate: delegate, + Logger: testutil.TestLogger(t), + } + + delegate.On("FetchRemote", uint64(3)).Return(3, "bcd", uint64(4), nil) + delegate.On("FetchLocal").Return(1, "a", nil) + delegate.On("DiffRemoteAndLocalState", "a", "bcd", uint64(3)).Return(&IndexReplicatorDiff{NumDeletions: 1, Deletions: "a", NumUpdates: 3, Updates: "bcd"}, nil) + delegate.On("PerformDeletions", "a").Return(false, nil) + delegate.On("PerformUpdates", "bcd").Return(false, nil) + + idx, done, err := replicator.Replicate(context.Background(), 3) + + require.Equal(t, uint64(4), idx) + require.False(t, done) + require.NoError(t, err) + delegate.AssertExpectations(t) + }) +} diff --git a/agent/consul/server.go b/agent/consul/server.go index 8bcc8ed0b..1b6f1087c 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -382,11 +382,11 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store, tl s.rpcLimiter.Store(rate.NewLimiter(config.RPCRate, config.RPCMaxBurst)) configReplicatorConfig := ReplicatorConfig{ - Name: "Config Entry", - ReplicateFn: s.replicateConfig, - Rate: s.config.ConfigReplicationRate, - Burst: s.config.ConfigReplicationBurst, - Logger: logger, + Name: "Config Entry", + Delegate: &FunctionReplicator{ReplicateFn: s.replicateConfig}, + Rate: s.config.ConfigReplicationRate, + Burst: s.config.ConfigReplicationBurst, + Logger: logger, } s.configReplicator, err = NewReplicator(&configReplicatorConfig) if err != nil {