backport of commit 4e3b91d91f379b6368e778849c044fadfa7e67e5 (#23691)

* backport of commit 4e3b91d91f379b6368e778849c044fadfa7e67e5

* workerpool implementation

* rollback tests

* website documentation

* add changelog

* fix failing test

* backport of commit de043d673692e91bdb82f0decb5dfa316dcbc48a

* fix flaky rollback test

* better fix

* switch to defer

* add comment

---------

Co-authored-by: miagilepner <mia.epner@hashicorp.com>
This commit is contained in:
hc-github-team-secure-vault-core 2023-10-17 08:33:54 -04:00 committed by GitHub
parent f3e2841fcd
commit ea40c49f6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 360 additions and 17 deletions

3
changelog/22567.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
core: Use a worker pool for the rollback manager. Add new metrics for the rollback manager to track the queued tasks.
```

2
go.mod
View File

@ -56,6 +56,7 @@ require (
github.com/fatih/color v1.15.0
github.com/fatih/structs v1.1.0
github.com/favadi/protoc-go-inject-tag v1.4.0
github.com/gammazero/workerpool v1.1.3
github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32
github.com/go-errors/errors v1.4.2
github.com/go-jose/go-jose/v3 v3.0.0
@ -334,7 +335,6 @@ require (
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gammazero/deque v0.2.1 // indirect
github.com/gammazero/workerpool v1.1.3 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
github.com/go-ldap/ldif v0.0.0-20200320164324-fd88d9b715b3 // indirect
github.com/go-logr/logr v1.2.4 // indirect

View File

@ -692,7 +692,8 @@ type Core struct {
// heartbeating with the active node. Default to the current SDK version.
effectiveSDKVersion string
rollbackPeriod time.Duration
numRollbackWorkers int
rollbackPeriod time.Duration
experiments []string
@ -879,6 +880,8 @@ type CoreConfig struct {
AdministrativeNamespacePath string
UserLockoutLogInterval time.Duration
NumRollbackWorkers int
}
// SubloggerHook implements the SubloggerAdder interface. This implementation
@ -971,6 +974,9 @@ func CreateCore(conf *CoreConfig) (*Core, error) {
conf.NumExpirationWorkers = numExpirationWorkersDefault
}
if conf.NumRollbackWorkers == 0 {
conf.NumRollbackWorkers = RollbackDefaultNumWorkers
}
// Use imported logging deadlock if requested
var stateLock locking.RWMutex
if strings.Contains(conf.DetectDeadlocks, "statelock") {
@ -1055,6 +1061,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) {
experiments: conf.Experiments,
pendingRemovalMountsAllowed: conf.PendingRemovalMountsAllowed,
expirationRevokeRetryBase: conf.ExpirationRevokeRetryBase,
numRollbackWorkers: conf.NumRollbackWorkers,
impreciseLeaseRoleTracking: conf.ImpreciseLeaseRoleTracking,
}

View File

@ -6,16 +6,25 @@ package vault
import (
"context"
"errors"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
metrics "github.com/armon/go-metrics"
"github.com/gammazero/workerpool"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
)
const (
// RollbackDefaultNumWorkers is the rollback worker count that CreateCore
// applies when CoreConfig.NumRollbackWorkers is left at zero.
RollbackDefaultNumWorkers = 256
// RollbackWorkersEnvVar names the environment variable that, when set to a
// positive integer, overrides the configured number of rollback workers.
RollbackWorkersEnvVar = "VAULT_ROLLBACK_WORKERS"
)
// RollbackManager is responsible for performing rollbacks of partial
// secrets within logical backends.
//
@ -50,8 +59,10 @@ type RollbackManager struct {
stopTicker chan struct{}
tickerIsStopped bool
quitContext context.Context
core *Core
runner *workerpool.WorkerPool
core *Core
// This channel is used for testing
rollbacksDoneCh chan struct{}
}
// rollbackState is used to track the state of a single rollback attempt
@ -60,6 +71,9 @@ type rollbackState struct {
sync.WaitGroup
cancelLockGrabCtx context.Context
cancelLockGrabCtxCancel context.CancelFunc
// scheduled is the time that this job was created and submitted to the
// rollbackRunner
scheduled time.Time
}
// NewRollbackManager is used to create a new rollback manager
@ -76,9 +90,26 @@ func NewRollbackManager(ctx context.Context, logger log.Logger, backendsFunc fun
quitContext: ctx,
core: core,
}
numWorkers := r.numRollbackWorkers()
r.logger.Info(fmt.Sprintf("Starting the rollback manager with %d workers", numWorkers))
r.runner = workerpool.New(numWorkers)
return r
}
// numRollbackWorkers reports how many workers the rollback pool should use.
// The value configured on the core is the baseline; if the environment
// variable named by RollbackWorkersEnvVar is set to a positive integer, that
// value wins. Any other non-empty value is logged as a warning and ignored.
func (m *RollbackManager) numRollbackWorkers() int {
	configured := m.core.numRollbackWorkers

	raw := os.Getenv(RollbackWorkersEnvVar)
	if raw == "" {
		// no override requested
		return configured
	}

	parsed, err := strconv.Atoi(raw)
	if err == nil && parsed >= 1 {
		return parsed
	}

	// the override is unusable, so say so and keep the configured value
	m.logger.Warn(fmt.Sprintf("%s must be a positive integer, but was %s", RollbackWorkersEnvVar, raw))
	return configured
}
// Start starts the rollback manager
func (m *RollbackManager) Start() {
go m.run()
@ -94,7 +125,7 @@ func (m *RollbackManager) Stop() {
close(m.shutdownCh)
<-m.doneCh
}
m.inflightAll.Wait()
m.runner.StopWait()
}
// StopTicker stops the automatic Rollback manager's ticker, causing us
@ -164,6 +195,8 @@ func (m *RollbackManager) triggerRollbacks() {
func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState {
m.inflightLock.Lock()
defer m.inflightLock.Unlock()
defer metrics.SetGauge([]string{"rollback", "queued"}, float32(m.runner.WaitingQueueSize()))
defer metrics.SetGauge([]string{"rollback", "inflight"}, float32(len(m.inflight)))
rsInflight, ok := m.inflight[fullPath]
if ok {
return rsInflight
@ -179,22 +212,44 @@ func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath st
m.inflight[fullPath] = rs
rs.Add(1)
m.inflightAll.Add(1)
go m.attemptRollback(ctx, fullPath, rs, grabStatelock)
rs.scheduled = time.Now()
select {
case <-m.doneCh:
// if we've already shut down, then don't submit the task to avoid a panic
// we should still call finishRollback for the rollback state in order to remove
// it from the map and decrement the waitgroup.
// we already have the inflight lock, so we can't grab it here
m.finishRollback(rs, errors.New("rollback manager is stopped"), fullPath, false)
default:
m.runner.Submit(func() {
m.attemptRollback(ctx, fullPath, rs, grabStatelock)
select {
case m.rollbacksDoneCh <- struct{}{}:
default:
}
})
}
return rs
}
// finishRollback completes the bookkeeping for a single rollback attempt: it
// stores err on the rollback state, releases both the per-rollback and the
// manager-wide wait groups, and removes fullPath from the inflight map.
//
// grabInflightLock must be false when the caller already holds inflightLock;
// otherwise the lock is taken here for the map deletion.
func (m *RollbackManager) finishRollback(rs *rollbackState, err error, fullPath string, grabInflightLock bool) {
	// record the result before signalling waiters
	rs.lastError = err
	rs.Done()
	m.inflightAll.Done()

	if !grabInflightLock {
		// the caller owns inflightLock already
		delete(m.inflight, fullPath)
		return
	}

	m.inflightLock.Lock()
	defer m.inflightLock.Unlock()
	delete(m.inflight, fullPath)
}
// attemptRollback invokes a RollbackOperation for the given path
func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, rs *rollbackState, grabStatelock bool) (err error) {
metrics.MeasureSince([]string{"rollback", "waiting"}, rs.scheduled)
defer metrics.MeasureSince([]string{"rollback", "attempt", strings.ReplaceAll(fullPath, "/", "-")}, time.Now())
defer func() {
rs.lastError = err
rs.Done()
m.inflightAll.Done()
m.inflightLock.Lock()
delete(m.inflight, fullPath)
m.inflightLock.Unlock()
}()
defer m.finishRollback(rs, err, fullPath, true)
ns, err := namespace.FromContext(ctx)
if err != nil {

View File

@ -5,6 +5,7 @@ package vault
import (
"context"
"fmt"
"sync"
"testing"
"time"
@ -13,6 +14,8 @@ import (
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// mockRollback returns a mock rollback manager
@ -77,6 +80,247 @@ func TestRollbackManager(t *testing.T) {
}
}
// TestRollbackManager_ManyWorkers adds 10 backends that require a rollback
// operation, with 20 workers. The test verifies that the 10
// work items will run in parallel
func TestRollbackManager_ManyWorkers(t *testing.T) {
	core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 20, RollbackPeriod: time.Millisecond * 10})
	view := NewBarrierView(core.barrier, "logical/")

	ran := make(chan string)
	release := make(chan struct{})
	core, _, _ = testCoreUnsealed(t, core)

	// create 10 backends
	// when a rollback happens, each backend will try to write to an unbuffered
	// channel, then wait to be released
	for i := 0; i < 10; i++ {
		b := &NoopBackend{}
		b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) {
			if request.Operation == logical.RollbackOperation {
				ran <- request.Path
				<-release
			}
			return nil, nil
		}
		b.Root = []string{fmt.Sprintf("foo/%d", i)}
		meUUID, err := uuid.GenerateUUID()
		require.NoError(t, err)
		mountEntry := &MountEntry{
			Table:       mountTableType,
			UUID:        meUUID,
			Accessor:    fmt.Sprintf("accessor-%d", i),
			NamespaceID: namespace.RootNamespaceID,
			namespace:   namespace.RootNamespace,
			Path:        fmt.Sprintf("logical/foo/%d", i),
		}
		func() {
			core.mountsLock.Lock()
			defer core.mountsLock.Unlock()
			newTable := core.mounts.shallowClone()
			newTable.Entries = append(newTable.Entries, mountEntry)
			core.mounts = newTable
			// fail fast if the backend could not be mounted; otherwise the
			// problem would only surface later as an unexplained timeout
			require.NoError(t, core.router.Mount(b, "logical", mountEntry, view))
			require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local))
		}()
	}

	timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()
	got := make(map[string]bool)
	hasMore := true
	for hasMore {
		// there are more workers (20) than backends (10), so we would expect
		// to see all 10 writes to the channel from each of the backends. Once
		// that happens, close the release channel so that the functions can
		// exit
		select {
		case <-timeout.Done():
			require.Fail(t, "test timed out")
		case i := <-ran:
			got[i] = true
			if len(got) == 10 {
				close(release)
				hasMore = false
			}
		}
	}
	done := make(chan struct{})

	// start a goroutine to consume the remaining items from the queued work
	go func() {
		for {
			select {
			case <-ran:
			case <-done:
				return
			}
		}
	}()
	// stop the rollback worker, which will wait for all inflight rollbacks to
	// complete
	core.rollback.Stop()
	close(done)
}
// TestRollbackManager_WorkerPool adds 10 backends that require a rollback
// operation, with 5 workers. The test verifies that the 5 work items can occur
// concurrently, and that the remainder of the work is queued and run when
// workers are available
func TestRollbackManager_WorkerPool(t *testing.T) {
	core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 5, RollbackPeriod: time.Millisecond * 10})
	view := NewBarrierView(core.barrier, "logical/")

	ran := make(chan string)
	release := make(chan struct{})
	core, _, _ = testCoreUnsealed(t, core)

	// create 10 backends
	// when a rollback happens, each backend will try to write to an unbuffered
	// channel, then wait to be released
	for i := 0; i < 10; i++ {
		b := &NoopBackend{}
		b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) {
			if request.Operation == logical.RollbackOperation {
				ran <- request.Path
				<-release
			}
			return nil, nil
		}
		b.Root = []string{fmt.Sprintf("foo/%d", i)}
		meUUID, err := uuid.GenerateUUID()
		require.NoError(t, err)
		mountEntry := &MountEntry{
			Table:       mountTableType,
			UUID:        meUUID,
			Accessor:    fmt.Sprintf("accessor-%d", i),
			NamespaceID: namespace.RootNamespaceID,
			namespace:   namespace.RootNamespace,
			Path:        fmt.Sprintf("logical/foo/%d", i),
		}
		func() {
			core.mountsLock.Lock()
			defer core.mountsLock.Unlock()
			newTable := core.mounts.shallowClone()
			newTable.Entries = append(newTable.Entries, mountEntry)
			core.mounts = newTable
			// fail fast if the backend could not be mounted; otherwise the
			// problem would only surface later as an unexplained timeout
			require.NoError(t, core.router.Mount(b, "logical", mountEntry, view))
			require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local))
		}()
	}

	timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()
	got := make(map[string]bool)
	hasMore := true
	for hasMore {
		// we're using 5 workers, so we would expect to see 5 writes to the
		// channel. Once that happens, close the release channel so that the
		// functions can exit and new rollback operations can run
		select {
		case <-timeout.Done():
			require.Fail(t, "test timed out")
		case i := <-ran:
			got[i] = true
			numGot := len(got)
			if numGot == 5 {
				close(release)
				hasMore = false
			}
		}
	}
	done := make(chan struct{})
	defer close(done)

	// start a goroutine to consume the remaining items from the queued work.
	// From here on this goroutine is the only one touching the got map.
	gotAllPaths := make(chan struct{})
	go func() {
		channelClosed := false
		for {
			select {
			case i := <-ran:
				got[i] = true

				// keep this goroutine running even after there are 10 paths.
				// More rollback operations might get queued before Stop() is
				// called, and we don't want them to block on writing to the
				// ran channel
				if len(got) == 10 && !channelClosed {
					close(gotAllPaths)
					channelClosed = true
				}
			case <-done:
				return
			}
		}
	}()

	// wait until all 10 backends have each ran at least once. The timeout is
	// handled here, on the test goroutine, because require.Fail stops the
	// calling goroutine and must not be invoked from a spawned one; it also
	// keeps this wait from blocking forever when the paths never all arrive.
	select {
	case <-gotAllPaths:
	case <-timeout.Done():
		require.Fail(t, "test timed out")
	}
	// stop the rollback worker, which will wait for any inflight rollbacks to
	// complete
	core.rollback.Stop()
}
// TestRollbackManager_numRollbackWorkers checks how the rollback worker count
// is resolved: the configured value is used unless the environment variable
// supplies a positive integer, in which case the environment variable wins.
// Because the cases set an environment variable, the test must not run in
// parallel.
func TestRollbackManager_numRollbackWorkers(t *testing.T) {
	type scenario struct {
		name          string
		configWorkers int
		setEnvVar     bool
		envVar        string
		wantWorkers   int
	}

	scenarios := []scenario{
		{
			name:          "default in config",
			configWorkers: RollbackDefaultNumWorkers,
			wantWorkers:   RollbackDefaultNumWorkers,
		},
		{
			name:          "invalid envvar",
			configWorkers: RollbackDefaultNumWorkers,
			setEnvVar:     true,
			envVar:        "invalid",
			wantWorkers:   RollbackDefaultNumWorkers,
		},
		{
			name:          "envvar overrides config",
			configWorkers: RollbackDefaultNumWorkers,
			setEnvVar:     true,
			envVar:        "20",
			wantWorkers:   20,
		},
		{
			name:          "envvar negative",
			configWorkers: RollbackDefaultNumWorkers,
			setEnvVar:     true,
			envVar:        "-1",
			wantWorkers:   RollbackDefaultNumWorkers,
		},
		{
			name:          "envvar zero",
			configWorkers: RollbackDefaultNumWorkers,
			setEnvVar:     true,
			envVar:        "0",
			wantWorkers:   RollbackDefaultNumWorkers,
		},
	}

	for idx := range scenarios {
		sc := scenarios[idx]
		t.Run(sc.name, func(t *testing.T) {
			if sc.setEnvVar {
				// t.Setenv restores the previous value when the subtest ends
				t.Setenv(RollbackWorkersEnvVar, sc.envVar)
			}
			manager := &RollbackManager{
				logger: logger.Named("test"),
				core:   &Core{numRollbackWorkers: sc.configWorkers},
			}
			require.Equal(t, sc.wantWorkers, manager.numRollbackWorkers())
		})
	}
}
func TestRollbackManager_Join(t *testing.T) {
m, backend := mockRollback(t)
if len(backend.Paths) > 0 {

View File

@ -243,6 +243,12 @@ func TestCoreWithSealAndUINoCleanup(t testing.T, opts *CoreConfig) *Core {
for k, v := range opts.AuditBackends {
conf.AuditBackends[k] = v
}
if opts.RollbackPeriod != time.Duration(0) {
conf.RollbackPeriod = opts.RollbackPeriod
}
if opts.NumRollbackWorkers != 0 {
conf.NumRollbackWorkers = opts.NumRollbackWorkers
}
conf.ActivityLogConfig = opts.ActivityLogConfig
testApplyEntBaseConfig(conf, opts)
@ -299,6 +305,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo
CredentialBackends: credentialBackends,
DisableMlock: true,
Logger: logger,
NumRollbackWorkers: 10,
BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(),
}

View File

@ -616,6 +616,12 @@ alphabetic order by name.
@include 'telemetry-metrics/vault/rollback/attempt/mountpoint.mdx'
@include 'telemetry-metrics/vault/rollback/inflight.mdx'
@include 'telemetry-metrics/vault/rollback/queued.mdx'
@include 'telemetry-metrics/vault/rollback/waiting.mdx'
@include 'telemetry-metrics/vault/route/create/mountpoint.mdx'
@include 'telemetry-metrics/vault/route/delete/mountpoint.mdx'
@ -722,4 +728,4 @@ alphabetic order by name.
@include 'telemetry-metrics/vault/zookeeper/list.mdx'
@include 'telemetry-metrics/vault/zookeeper/put.mdx'
@include 'telemetry-metrics/vault/zookeeper/put.mdx'

View File

@ -112,6 +112,12 @@ Vault instance.
@include 'telemetry-metrics/vault/rollback/attempt/mountpoint.mdx'
@include 'telemetry-metrics/vault/rollback/inflight.mdx'
@include 'telemetry-metrics/vault/rollback/queued.mdx'
@include 'telemetry-metrics/vault/rollback/waiting.mdx'
## Route metrics
@include 'telemetry-metrics/route-intro.mdx'
@ -146,4 +152,4 @@ Vault instance.
@include 'telemetry-metrics/vault/runtime/total_gc_pause_ns.mdx'
@include 'telemetry-metrics/vault/runtime/total_gc_runs.mdx'
@include 'telemetry-metrics/vault/runtime/total_gc_runs.mdx'

View File

@ -0,0 +1,5 @@
### vault.rollback.inflight ((#vault-rollback-inflight))
Metric type | Value | Description
----------- | ------ | -----------
gauge | number | Number of rollback operations inflight

View File

@ -0,0 +1,5 @@
### vault.rollback.queued ((#vault-rollback-queued))
Metric type | Value | Description
----------- | ------ | -----------
gauge | number | The number of rollback operations waiting to be started

View File

@ -0,0 +1,5 @@
### vault.rollback.waiting ((#vault-rollback-waiting))
Metric type | Value | Description
----------- | ----- | -----------
summary | ms | Time between queueing a rollback operation and the operation starting