diff --git a/changelog/22567.txt b/changelog/22567.txt new file mode 100644 index 000000000..d9e557013 --- /dev/null +++ b/changelog/22567.txt @@ -0,0 +1,3 @@ +```release-note:improvement +core: Use a worker pool for the rollback manager. Add new metrics for the rollback manager to track the queued tasks. +``` \ No newline at end of file diff --git a/go.mod b/go.mod index 849d64a85..f183b25f5 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/fatih/color v1.15.0 github.com/fatih/structs v1.1.0 github.com/favadi/protoc-go-inject-tag v1.4.0 + github.com/gammazero/workerpool v1.1.3 github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 github.com/go-errors/errors v1.4.2 github.com/go-jose/go-jose/v3 v3.0.0 @@ -334,7 +335,6 @@ require ( github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gammazero/deque v0.2.1 // indirect - github.com/gammazero/workerpool v1.1.3 // indirect github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect github.com/go-ldap/ldif v0.0.0-20200320164324-fd88d9b715b3 // indirect github.com/go-logr/logr v1.2.4 // indirect diff --git a/vault/core.go b/vault/core.go index 50d7641a9..2767d6bf4 100644 --- a/vault/core.go +++ b/vault/core.go @@ -692,7 +692,8 @@ type Core struct { // heartbeating with the active node. Default to the current SDK version. effectiveSDKVersion string - rollbackPeriod time.Duration + numRollbackWorkers int + rollbackPeriod time.Duration experiments []string @@ -879,6 +880,8 @@ type CoreConfig struct { AdministrativeNamespacePath string UserLockoutLogInterval time.Duration + + NumRollbackWorkers int } // SubloggerHook implements the SubloggerAdder interface. 
This implementation @@ -971,6 +974,9 @@ func CreateCore(conf *CoreConfig) (*Core, error) { conf.NumExpirationWorkers = numExpirationWorkersDefault } + if conf.NumRollbackWorkers == 0 { + conf.NumRollbackWorkers = RollbackDefaultNumWorkers + } // Use imported logging deadlock if requested var stateLock locking.RWMutex if strings.Contains(conf.DetectDeadlocks, "statelock") { @@ -1055,6 +1061,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { experiments: conf.Experiments, pendingRemovalMountsAllowed: conf.PendingRemovalMountsAllowed, expirationRevokeRetryBase: conf.ExpirationRevokeRetryBase, + numRollbackWorkers: conf.NumRollbackWorkers, impreciseLeaseRoleTracking: conf.ImpreciseLeaseRoleTracking, } diff --git a/vault/rollback.go b/vault/rollback.go index 922a8a709..aa35b814a 100644 --- a/vault/rollback.go +++ b/vault/rollback.go @@ -6,16 +6,25 @@ package vault import ( "context" "errors" + "fmt" + "os" + "strconv" "strings" "sync" "time" metrics "github.com/armon/go-metrics" + "github.com/gammazero/workerpool" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" ) +const ( + RollbackDefaultNumWorkers = 256 + RollbackWorkersEnvVar = "VAULT_ROLLBACK_WORKERS" +) + // RollbackManager is responsible for performing rollbacks of partial // secrets within logical backends. 
// @@ -50,8 +59,10 @@ type RollbackManager struct { stopTicker chan struct{} tickerIsStopped bool quitContext context.Context - - core *Core + runner *workerpool.WorkerPool + core *Core + // This channel is used for testing + rollbacksDoneCh chan struct{} } // rollbackState is used to track the state of a single rollback attempt @@ -60,6 +71,9 @@ type rollbackState struct { sync.WaitGroup cancelLockGrabCtx context.Context cancelLockGrabCtxCancel context.CancelFunc + // scheduled is the time that this job was created and submitted to the + // rollbackRunner + scheduled time.Time } // NewRollbackManager is used to create a new rollback manager @@ -76,9 +90,26 @@ func NewRollbackManager(ctx context.Context, logger log.Logger, backendsFunc fun quitContext: ctx, core: core, } + numWorkers := r.numRollbackWorkers() + r.logger.Info(fmt.Sprintf("Starting the rollback manager with %d workers", numWorkers)) + r.runner = workerpool.New(numWorkers) return r } +func (m *RollbackManager) numRollbackWorkers() int { + numWorkers := m.core.numRollbackWorkers + envOverride := os.Getenv(RollbackWorkersEnvVar) + if envOverride != "" { + envVarWorkers, err := strconv.Atoi(envOverride) + if err != nil || envVarWorkers < 1 { + m.logger.Warn(fmt.Sprintf("%s must be a positive integer, but was %s", RollbackWorkersEnvVar, envOverride)) + } else { + numWorkers = envVarWorkers + } + } + return numWorkers +} + // Start starts the rollback manager func (m *RollbackManager) Start() { go m.run() @@ -94,7 +125,7 @@ func (m *RollbackManager) Stop() { close(m.shutdownCh) <-m.doneCh } - m.inflightAll.Wait() + m.runner.StopWait() } // StopTicker stops the automatic Rollback manager's ticker, causing us @@ -164,6 +195,8 @@ func (m *RollbackManager) triggerRollbacks() { func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState { m.inflightLock.Lock() defer m.inflightLock.Unlock() + defer metrics.SetGauge([]string{"rollback", "queued"}, 
float32(m.runner.WaitingQueueSize())) + defer metrics.SetGauge([]string{"rollback", "inflight"}, float32(len(m.inflight))) rsInflight, ok := m.inflight[fullPath] if ok { return rsInflight @@ -179,22 +212,44 @@ func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath st m.inflight[fullPath] = rs rs.Add(1) m.inflightAll.Add(1) - go m.attemptRollback(ctx, fullPath, rs, grabStatelock) + rs.scheduled = time.Now() + select { + case <-m.doneCh: + // if we've already shut down, then don't submit the task to avoid a panic + // we should still call finishRollback for the rollback state in order to remove + // it from the map and decrement the waitgroup. + + // we already have the inflight lock, so we can't grab it here + m.finishRollback(rs, errors.New("rollback manager is stopped"), fullPath, false) + default: + m.runner.Submit(func() { + m.attemptRollback(ctx, fullPath, rs, grabStatelock) + select { + case m.rollbacksDoneCh <- struct{}{}: + default: + } + }) + + } return rs } +func (m *RollbackManager) finishRollback(rs *rollbackState, err error, fullPath string, grabInflightLock bool) { + rs.lastError = err + rs.Done() + m.inflightAll.Done() + if grabInflightLock { + m.inflightLock.Lock() + defer m.inflightLock.Unlock() + } + delete(m.inflight, fullPath) +} + // attemptRollback invokes a RollbackOperation for the given path func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, rs *rollbackState, grabStatelock bool) (err error) { + metrics.MeasureSince([]string{"rollback", "waiting"}, rs.scheduled) defer metrics.MeasureSince([]string{"rollback", "attempt", strings.ReplaceAll(fullPath, "/", "-")}, time.Now()) - - defer func() { - rs.lastError = err - rs.Done() - m.inflightAll.Done() - m.inflightLock.Lock() - delete(m.inflight, fullPath) - m.inflightLock.Unlock() - }() + // use a deferred closure so the named return err is read at return time, + // not captured (as nil) when the defer statement executes + defer func() { m.finishRollback(rs, err, fullPath, true) }() ns, err := namespace.FromContext(ctx) if err != nil { diff --git a/vault/rollback_test.go 
b/vault/rollback_test.go index 8eb457c12..f67ef7c94 100644 --- a/vault/rollback_test.go +++ b/vault/rollback_test.go @@ -5,6 +5,7 @@ package vault import ( "context" + "fmt" "sync" "testing" "time" @@ -13,6 +14,8 @@ import ( "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/require" ) // mockRollback returns a mock rollback manager @@ -77,6 +80,247 @@ func TestRollbackManager(t *testing.T) { } } +// TestRollbackManager_ManyWorkers adds 10 backends that require a rollback +// operation, with 20 workers. The test verifies that the 10 +// work items will run in parallel +func TestRollbackManager_ManyWorkers(t *testing.T) { + core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 20, RollbackPeriod: time.Millisecond * 10}) + view := NewBarrierView(core.barrier, "logical/") + + ran := make(chan string) + release := make(chan struct{}) + core, _, _ = testCoreUnsealed(t, core) + + // create 10 backends + // when a rollback happens, each backend will try to write to an unbuffered + // channel, then wait to be released + for i := 0; i < 10; i++ { + b := &NoopBackend{} + b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) { + if request.Operation == logical.RollbackOperation { + ran <- request.Path + <-release + } + return nil, nil + } + b.Root = []string{fmt.Sprintf("foo/%d", i)} + meUUID, err := uuid.GenerateUUID() + require.NoError(t, err) + mountEntry := &MountEntry{ + Table: mountTableType, + UUID: meUUID, + Accessor: fmt.Sprintf("accessor-%d", i), + NamespaceID: namespace.RootNamespaceID, + namespace: namespace.RootNamespace, + Path: fmt.Sprintf("logical/foo/%d", i), + } + func() { + core.mountsLock.Lock() + defer core.mountsLock.Unlock() + newTable := core.mounts.shallowClone() + newTable.Entries = append(newTable.Entries, mountEntry) + core.mounts = newTable + err = 
core.router.Mount(b, "logical", mountEntry, view) + require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local)) + }() + } + + timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + got := make(map[string]bool) + hasMore := true + for hasMore { + // we're not bounding the number of workers, so we would expect to see + // all 10 writes to the channel from each of the backends. Once that + // happens, close the release channel so that the functions can exit + select { + case <-timeout.Done(): + require.Fail(t, "test timed out") + case i := <-ran: + got[i] = true + if len(got) == 10 { + close(release) + hasMore = false + } + } + } + done := make(chan struct{}) + + // start a goroutine to consume the remaining items from the queued work + go func() { + for { + select { + case <-ran: + case <-done: + return + } + } + }() + // stop the rollback worker, which will wait for all inflight rollbacks to + // complete + core.rollback.Stop() + close(done) +} + +// TestRollbackManager_WorkerPool adds 10 backends that require a rollback +// operation, with 5 workers. 
The test verifies that the 5 work items can occur +// concurrently, and that the remainder of the work is queued and run when +// workers are available +func TestRollbackManager_WorkerPool(t *testing.T) { + core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 5, RollbackPeriod: time.Millisecond * 10}) + view := NewBarrierView(core.barrier, "logical/") + + ran := make(chan string) + release := make(chan struct{}) + core, _, _ = testCoreUnsealed(t, core) + + // create 10 backends + // when a rollback happens, each backend will try to write to an unbuffered + // channel, then wait to be released + for i := 0; i < 10; i++ { + b := &NoopBackend{} + b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) { + if request.Operation == logical.RollbackOperation { + ran <- request.Path + <-release + } + return nil, nil + } + b.Root = []string{fmt.Sprintf("foo/%d", i)} + meUUID, err := uuid.GenerateUUID() + require.NoError(t, err) + mountEntry := &MountEntry{ + Table: mountTableType, + UUID: meUUID, + Accessor: fmt.Sprintf("accessor-%d", i), + NamespaceID: namespace.RootNamespaceID, + namespace: namespace.RootNamespace, + Path: fmt.Sprintf("logical/foo/%d", i), + } + func() { + core.mountsLock.Lock() + defer core.mountsLock.Unlock() + newTable := core.mounts.shallowClone() + newTable.Entries = append(newTable.Entries, mountEntry) + core.mounts = newTable + err = core.router.Mount(b, "logical", mountEntry, view) + require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local)) + }() + } + + timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + got := make(map[string]bool) + hasMore := true + for hasMore { + // we're using 5 workers, so we would expect to see 5 writes to the + // channel. 
Once that happens, close the release channel so that the + // functions can exit and new rollback operations can run + select { + case <-timeout.Done(): + require.Fail(t, "test timed out") + case i := <-ran: + got[i] = true + numGot := len(got) + if numGot == 5 { + close(release) + hasMore = false + } + } + } + done := make(chan struct{}) + defer close(done) + + // start a goroutine to consume the remaining items from the queued work + gotAllPaths := make(chan struct{}) + go func() { + channelClosed := false + for { + select { + case i := <-ran: + got[i] = true + + // keep this goroutine running even after there are 10 paths. + // More rollback operations might get queued before Stop() is + // called, and we don't want them to block on writing the to the + // ran channel + if len(got) == 10 && !channelClosed { + close(gotAllPaths) + channelClosed = true + } + case <-timeout.Done(): + require.Fail(t, "test timed out") + case <-done: + return + } + } + }() + + // wait until all 10 backends have each ran at least once + <-gotAllPaths + // stop the rollback worker, which will wait for any inflight rollbacks to + // complete + core.rollback.Stop() +} + +// TestRollbackManager_numRollbackWorkers verifies that the number of rollback +// workers is parsed from the configuration, but can be overridden by an +// environment variable. 
This test cannot be run in parallel because of the +// environment variable +func TestRollbackManager_numRollbackWorkers(t *testing.T) { + testCases := []struct { + name string + configWorkers int + setEnvVar bool + envVar string + wantWorkers int + }{ + { + name: "default in config", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + }, + { + name: "invalid envvar", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + setEnvVar: true, + envVar: "invalid", + }, + { + name: "envvar overrides config", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: 20, + setEnvVar: true, + envVar: "20", + }, + { + name: "envvar negative", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + setEnvVar: true, + envVar: "-1", + }, + { + name: "envvar zero", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + setEnvVar: true, + envVar: "0", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.setEnvVar { + t.Setenv(RollbackWorkersEnvVar, tc.envVar) + } + core := &Core{numRollbackWorkers: tc.configWorkers} + r := &RollbackManager{logger: logger.Named("test"), core: core} + require.Equal(t, tc.wantWorkers, r.numRollbackWorkers()) + }) + } +} + func TestRollbackManager_Join(t *testing.T) { m, backend := mockRollback(t) if len(backend.Paths) > 0 { diff --git a/vault/testing.go b/vault/testing.go index 38144ca28..649268a1b 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -243,6 +243,12 @@ func TestCoreWithSealAndUINoCleanup(t testing.T, opts *CoreConfig) *Core { for k, v := range opts.AuditBackends { conf.AuditBackends[k] = v } + if opts.RollbackPeriod != time.Duration(0) { + conf.RollbackPeriod = opts.RollbackPeriod + } + if opts.NumRollbackWorkers != 0 { + conf.NumRollbackWorkers = opts.NumRollbackWorkers + } conf.ActivityLogConfig = opts.ActivityLogConfig testApplyEntBaseConfig(conf, opts) @@ 
-299,6 +305,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo CredentialBackends: credentialBackends, DisableMlock: true, Logger: logger, + NumRollbackWorkers: 10, BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(), } diff --git a/website/content/docs/internals/telemetry/metrics/all.mdx b/website/content/docs/internals/telemetry/metrics/all.mdx index 67e1a8ec8..8099dea76 100644 --- a/website/content/docs/internals/telemetry/metrics/all.mdx +++ b/website/content/docs/internals/telemetry/metrics/all.mdx @@ -616,6 +616,12 @@ alphabetic order by name. @include 'telemetry-metrics/vault/rollback/attempt/mountpoint.mdx' +@include 'telemetry-metrics/vault/rollback/inflight.mdx' + +@include 'telemetry-metrics/vault/rollback/queued.mdx' + +@include 'telemetry-metrics/vault/rollback/waiting.mdx' + @include 'telemetry-metrics/vault/route/create/mountpoint.mdx' @include 'telemetry-metrics/vault/route/delete/mountpoint.mdx' @@ -722,4 +728,4 @@ alphabetic order by name. @include 'telemetry-metrics/vault/zookeeper/list.mdx' -@include 'telemetry-metrics/vault/zookeeper/put.mdx' \ No newline at end of file +@include 'telemetry-metrics/vault/zookeeper/put.mdx' diff --git a/website/content/docs/internals/telemetry/metrics/core-system.mdx b/website/content/docs/internals/telemetry/metrics/core-system.mdx index eea0e91f9..1039c6f19 100644 --- a/website/content/docs/internals/telemetry/metrics/core-system.mdx +++ b/website/content/docs/internals/telemetry/metrics/core-system.mdx @@ -112,6 +112,12 @@ Vault instance. @include 'telemetry-metrics/vault/rollback/attempt/mountpoint.mdx' +@include 'telemetry-metrics/vault/rollback/inflight.mdx' + +@include 'telemetry-metrics/vault/rollback/queued.mdx' + +@include 'telemetry-metrics/vault/rollback/waiting.mdx' + ## Route metrics @include 'telemetry-metrics/route-intro.mdx' @@ -146,4 +152,4 @@ Vault instance. 
@include 'telemetry-metrics/vault/runtime/total_gc_pause_ns.mdx' -@include 'telemetry-metrics/vault/runtime/total_gc_runs.mdx' \ No newline at end of file +@include 'telemetry-metrics/vault/runtime/total_gc_runs.mdx' diff --git a/website/content/partials/telemetry-metrics/vault/rollback/inflight.mdx b/website/content/partials/telemetry-metrics/vault/rollback/inflight.mdx new file mode 100644 index 000000000..832cb3088 --- /dev/null +++ b/website/content/partials/telemetry-metrics/vault/rollback/inflight.mdx @@ -0,0 +1,5 @@ +### vault.rollback.inflight ((#vault-rollback-inflight)) + +Metric type | Value | Description +----------- | ------ | ----------- +gauge | number | Number of rollback operations inflight diff --git a/website/content/partials/telemetry-metrics/vault/rollback/queued.mdx b/website/content/partials/telemetry-metrics/vault/rollback/queued.mdx new file mode 100644 index 000000000..e8a7d099f --- /dev/null +++ b/website/content/partials/telemetry-metrics/vault/rollback/queued.mdx @@ -0,0 +1,5 @@ +### vault.rollback.queued ((#vault-rollback-queued)) + +Metric type | Value | Description +----------- | ------ | ----------- +gauge | number | The number of rollback operations waiting to be started diff --git a/website/content/partials/telemetry-metrics/vault/rollback/waiting.mdx b/website/content/partials/telemetry-metrics/vault/rollback/waiting.mdx new file mode 100644 index 000000000..2fb0e2eab --- /dev/null +++ b/website/content/partials/telemetry-metrics/vault/rollback/waiting.mdx @@ -0,0 +1,5 @@ +### vault.rollback.waiting ((#vault-rollback-waiting)) + +Metric type | Value | Description +----------- | ----- | ----------- +summary | ms | Time between queueing a rollback operation and the operation starting