open-vault/vault/rollback_test.go

367 lines
9.6 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"context"
"fmt"
"sync"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// mockRollback builds a RollbackManager backed by a fresh router with a single
// NoopBackend mounted at "foo", and shortens the manager's rollback period to
// 10ms so tests can observe periodic rollbacks quickly. It returns the manager
// together with the backend so callers can inspect the requests it received.
func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
	noop := new(NoopBackend)
	table := new(MountTable)
	r := NewRouter()
	core, _, _ := TestCoreUnsealed(t)
	_, barrier, _ := mockBarrier(t)
	storage := NewBarrierView(barrier, "logical/")

	// The mount table handed to the manager lists one root-namespace mount.
	table.Entries = []*MountEntry{{
		Path:        "foo",
		NamespaceID: namespace.RootNamespaceID,
		namespace:   namespace.RootNamespace,
	}}

	id, err := uuid.GenerateUUID()
	if err != nil {
		t.Fatal(err)
	}
	entry := &MountEntry{
		UUID:        id,
		Accessor:    "noopaccessor",
		NamespaceID: namespace.RootNamespaceID,
		namespace:   namespace.RootNamespace,
	}
	if err := r.Mount(noop, "foo", entry, storage); err != nil {
		t.Fatalf("err: %s", err)
	}

	manager := NewRollbackManager(
		context.Background(),
		logging.NewVaultLogger(log.Trace),
		// Read table.Entries on every call rather than snapshotting it.
		func() []*MountEntry { return table.Entries },
		r,
		core,
	)
	manager.period = 10 * time.Millisecond
	return manager, noop
}
// TestRollbackManager runs the mock rollback manager for a few rollback
// periods. It checks that the backend received requests, the first with an
// empty path, and that no further requests arrive after Stop.
func TestRollbackManager(t *testing.T) {
	m, backend := mockRollback(t)
	if len(backend.Paths) != 0 {
		t.Fatalf("bad: %#v", backend)
	}

	m.Start()
	time.Sleep(50 * time.Millisecond)
	m.Stop()

	seen := len(backend.Paths)
	switch {
	case seen == 0:
		t.Fatalf("bad: %#v", backend)
	case backend.Paths[0] != "":
		t.Fatalf("bad: %#v", backend)
	}

	// Once stopped, the manager must not issue any more requests.
	time.Sleep(50 * time.Millisecond)
	if len(backend.Paths) != seen {
		t.Fatalf("should stop requests: %#v", backend)
	}
}
// TestRollbackManager_ManyWorkers mounts 10 backends that need a rollback and
// configures 20 rollback workers. With more workers than backends, all 10
// rollbacks should be able to run in parallel. The test waits until every
// backend has reported in before releasing any of them.
func TestRollbackManager_ManyWorkers(t *testing.T) {
	core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 20, RollbackPeriod: time.Millisecond * 10})
	view := NewBarrierView(core.barrier, "logical/")
	ran := make(chan string)
	release := make(chan struct{})
	core, _, _ = testCoreUnsealed(t, core)

	const numBackends = 10

	// addBackend mounts backend number i. Its rollback handler sends the
	// request path on the unbuffered ran channel and then blocks until
	// release is closed.
	addBackend := func(i int) {
		b := &NoopBackend{}
		b.RequestHandler = func(ctx context.Context, req *logical.Request) (*logical.Response, error) {
			if req.Operation != logical.RollbackOperation {
				return nil, nil
			}
			ran <- req.Path
			<-release
			return nil, nil
		}
		b.Root = []string{fmt.Sprintf("foo/%d", i)}

		id, err := uuid.GenerateUUID()
		require.NoError(t, err)
		entry := &MountEntry{
			Table:       mountTableType,
			UUID:        id,
			Accessor:    fmt.Sprintf("accessor-%d", i),
			NamespaceID: namespace.RootNamespaceID,
			namespace:   namespace.RootNamespace,
			Path:        fmt.Sprintf("logical/foo/%d", i),
		}

		core.mountsLock.Lock()
		defer core.mountsLock.Unlock()
		table := core.mounts.shallowClone()
		table.Entries = append(table.Entries, entry)
		core.mounts = table
		// NOTE(review): the router error is deliberately not checked. Every
		// backend is mounted at the same "logical" prefix, so mounts after
		// the first may be rejected — confirm this is intended.
		_ = core.router.Mount(b, "logical", entry, view)
		require.NoError(t, core.persistMounts(context.Background(), table, &entry.Local))
	}
	for i := 0; i < numBackends; i++ {
		addBackend(i)
	}

	timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	// There are more workers than backends, so every backend should report
	// in before any of them is released.
	got := make(map[string]bool)
	for len(got) < numBackends {
		select {
		case <-timeout.Done():
			require.Fail(t, "test timed out")
		case p := <-ran:
			got[p] = true
		}
	}
	close(release)

	// Keep draining ran so that rollbacks queued after the release cannot
	// block while Stop waits for the in-flight rollbacks to complete.
	done := make(chan struct{})
	go func() {
		for {
			select {
			case <-ran:
			case <-done:
				return
			}
		}
	}()

	core.rollback.Stop()
	close(done)
}
// TestRollbackManager_WorkerPool mounts 10 backends that need a rollback but
// configures only 5 rollback workers. It waits until 5 distinct rollbacks are
// in flight at once, releases them, and then checks that the remaining queued
// work eventually runs so that all 10 backends are seen.
func TestRollbackManager_WorkerPool(t *testing.T) {
	const (
		numWorkers  = 5
		numBackends = 10
	)
	core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: numWorkers, RollbackPeriod: time.Millisecond * 10})
	view := NewBarrierView(core.barrier, "logical/")
	ran := make(chan string)
	release := make(chan struct{})
	core, _, _ = testCoreUnsealed(t, core)

	// Each backend's rollback handler sends its path on the unbuffered ran
	// channel and then blocks until release is closed.
	for i := 0; i < numBackends; i++ {
		b := &NoopBackend{}
		b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) {
			if request.Operation == logical.RollbackOperation {
				ran <- request.Path
				<-release
			}
			return nil, nil
		}
		b.Root = []string{fmt.Sprintf("foo/%d", i)}
		meUUID, err := uuid.GenerateUUID()
		require.NoError(t, err)
		mountEntry := &MountEntry{
			Table:       mountTableType,
			UUID:        meUUID,
			Accessor:    fmt.Sprintf("accessor-%d", i),
			NamespaceID: namespace.RootNamespaceID,
			namespace:   namespace.RootNamespace,
			Path:        fmt.Sprintf("logical/foo/%d", i),
		}
		func() {
			core.mountsLock.Lock()
			defer core.mountsLock.Unlock()
			newTable := core.mounts.shallowClone()
			newTable.Entries = append(newTable.Entries, mountEntry)
			core.mounts = newTable
			// NOTE(review): the router error is deliberately not checked.
			// Every backend is mounted at the same "logical" prefix, so mounts
			// after the first may be rejected — confirm this is intended.
			_ = core.router.Mount(b, "logical", mountEntry, view)
			require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local))
		}()
	}

	timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	// With numWorkers workers, numWorkers distinct rollbacks should be
	// blocked in their handlers at once. Release them so the queued work can
	// proceed.
	got := make(map[string]bool)
	for len(got) < numWorkers {
		select {
		case <-timeout.Done():
			require.Fail(t, "test timed out")
		case p := <-ran:
			got[p] = true
		}
	}
	close(release)

	done := make(chan struct{})
	defer close(done)

	// Drain ran in the background and signal once every backend has been
	// seen. The goroutine keeps draining after that point, because more
	// rollbacks may be queued before Stop is called and they must not block
	// on the unbuffered channel. It never calls t.FailNow (or require.*),
	// since that is only permitted from the test goroutine.
	gotAllPaths := make(chan struct{})
	go func() {
		signalled := false
		for {
			select {
			case p := <-ran:
				got[p] = true
				if !signalled && len(got) == numBackends {
					close(gotAllPaths)
					signalled = true
				}
			case <-done:
				return
			}
		}
	}()

	// Wait on the test goroutine, bounded by the timeout. On timeout, record
	// the failure but still fall through to Stop, so the workers are shut
	// down while the drain goroutine is still consuming ran.
	select {
	case <-gotAllPaths:
	case <-timeout.Done():
		t.Error("test timed out")
	}

	// Stop the rollback manager, which waits for any in-flight rollbacks to
	// complete.
	core.rollback.Stop()
}
// TestRollbackManager_numRollbackWorkers checks that the rollback worker count
// is taken from the core configuration and that RollbackWorkersEnvVar
// overrides it only when it holds a positive integer. Invalid, negative, and
// zero values are ignored. Subtests use t.Setenv, so this test must not run in
// parallel.
func TestRollbackManager_numRollbackWorkers(t *testing.T) {
	for _, tc := range []struct {
		name       string
		configured int
		setEnv     bool
		env        string
		want       int
	}{
		{name: "default in config", configured: RollbackDefaultNumWorkers, want: RollbackDefaultNumWorkers},
		{name: "invalid envvar", configured: RollbackDefaultNumWorkers, want: RollbackDefaultNumWorkers, setEnv: true, env: "invalid"},
		{name: "envvar overrides config", configured: RollbackDefaultNumWorkers, want: 20, setEnv: true, env: "20"},
		{name: "envvar negative", configured: RollbackDefaultNumWorkers, want: RollbackDefaultNumWorkers, setEnv: true, env: "-1"},
		{name: "envvar zero", configured: RollbackDefaultNumWorkers, want: RollbackDefaultNumWorkers, setEnv: true, env: "0"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			if tc.setEnv {
				t.Setenv(RollbackWorkersEnvVar, tc.env)
			}
			manager := &RollbackManager{
				logger: logger.Named("test"),
				core:   &Core{numRollbackWorkers: tc.configured},
			}
			require.Equal(t, tc.want, manager.numRollbackWorkers())
		})
	}
}
// TestRollbackManager_Join issues three concurrent Rollback calls for the same
// mount path ("foo") against a running manager and checks that none of them
// returns an error. Presumably this exercises concurrent callers joining one
// in-flight rollback — confirm against RollbackManager.Rollback.
func TestRollbackManager_Join(t *testing.T) {
	m, backend := mockRollback(t)
	if len(backend.Paths) > 0 {
		t.Fatalf("bad: %#v", backend)
	}

	m.Start()
	defer m.Stop()

	const callers = 3
	var wg sync.WaitGroup
	errCh := make(chan error, callers)
	for i := 0; i < callers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Pass a real parent context instead of nil (staticcheck SA1012).
			if err := m.Rollback(namespace.RootContext(context.Background()), "foo"); err != nil {
				errCh <- err
			}
		}()
	}
	wg.Wait()
	close(errCh)

	// Report every failed call, not just the first one received.
	for err := range errCh {
		t.Errorf("Error on rollback:%v", err)
	}
}