367 lines
9.6 KiB
Go
367 lines
9.6 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package vault
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/helper/namespace"
|
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// mockRollback returns a mock rollback manager
|
|
func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
|
|
backend := new(NoopBackend)
|
|
mounts := new(MountTable)
|
|
router := NewRouter()
|
|
core, _, _ := TestCoreUnsealed(t)
|
|
|
|
_, barrier, _ := mockBarrier(t)
|
|
view := NewBarrierView(barrier, "logical/")
|
|
|
|
mounts.Entries = []*MountEntry{
|
|
{
|
|
Path: "foo",
|
|
NamespaceID: namespace.RootNamespaceID,
|
|
namespace: namespace.RootNamespace,
|
|
},
|
|
}
|
|
meUUID, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID, Accessor: "noopaccessor", NamespaceID: namespace.RootNamespaceID, namespace: namespace.RootNamespace}, view); err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
mountsFunc := func() []*MountEntry {
|
|
return mounts.Entries
|
|
}
|
|
|
|
logger := logging.NewVaultLogger(log.Trace)
|
|
|
|
rb := NewRollbackManager(context.Background(), logger, mountsFunc, router, core)
|
|
rb.period = 10 * time.Millisecond
|
|
return rb, backend
|
|
}
|
|
|
|
func TestRollbackManager(t *testing.T) {
|
|
m, backend := mockRollback(t)
|
|
if len(backend.Paths) > 0 {
|
|
t.Fatalf("bad: %#v", backend)
|
|
}
|
|
|
|
m.Start()
|
|
time.Sleep(50 * time.Millisecond)
|
|
m.Stop()
|
|
|
|
count := len(backend.Paths)
|
|
if count == 0 {
|
|
t.Fatalf("bad: %#v", backend)
|
|
}
|
|
if backend.Paths[0] != "" {
|
|
t.Fatalf("bad: %#v", backend)
|
|
}
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if count != len(backend.Paths) {
|
|
t.Fatalf("should stop requests: %#v", backend)
|
|
}
|
|
}
|
|
|
|
// TestRollbackManager_ManyWorkers adds 10 backends that require a rollback
|
|
// operation, with 20 workers. The test verifies that the 10
|
|
// work items will run in parallel
|
|
func TestRollbackManager_ManyWorkers(t *testing.T) {
|
|
core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 20, RollbackPeriod: time.Millisecond * 10})
|
|
view := NewBarrierView(core.barrier, "logical/")
|
|
|
|
ran := make(chan string)
|
|
release := make(chan struct{})
|
|
core, _, _ = testCoreUnsealed(t, core)
|
|
|
|
// create 10 backends
|
|
// when a rollback happens, each backend will try to write to an unbuffered
|
|
// channel, then wait to be released
|
|
for i := 0; i < 10; i++ {
|
|
b := &NoopBackend{}
|
|
b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) {
|
|
if request.Operation == logical.RollbackOperation {
|
|
ran <- request.Path
|
|
<-release
|
|
}
|
|
return nil, nil
|
|
}
|
|
b.Root = []string{fmt.Sprintf("foo/%d", i)}
|
|
meUUID, err := uuid.GenerateUUID()
|
|
require.NoError(t, err)
|
|
mountEntry := &MountEntry{
|
|
Table: mountTableType,
|
|
UUID: meUUID,
|
|
Accessor: fmt.Sprintf("accessor-%d", i),
|
|
NamespaceID: namespace.RootNamespaceID,
|
|
namespace: namespace.RootNamespace,
|
|
Path: fmt.Sprintf("logical/foo/%d", i),
|
|
}
|
|
func() {
|
|
core.mountsLock.Lock()
|
|
defer core.mountsLock.Unlock()
|
|
newTable := core.mounts.shallowClone()
|
|
newTable.Entries = append(newTable.Entries, mountEntry)
|
|
core.mounts = newTable
|
|
err = core.router.Mount(b, "logical", mountEntry, view)
|
|
require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local))
|
|
}()
|
|
}
|
|
|
|
timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
|
defer cancel()
|
|
got := make(map[string]bool)
|
|
hasMore := true
|
|
for hasMore {
|
|
// we're not bounding the number of workers, so we would expect to see
|
|
// all 10 writes to the channel from each of the backends. Once that
|
|
// happens, close the release channel so that the functions can exit
|
|
select {
|
|
case <-timeout.Done():
|
|
require.Fail(t, "test timed out")
|
|
case i := <-ran:
|
|
got[i] = true
|
|
if len(got) == 10 {
|
|
close(release)
|
|
hasMore = false
|
|
}
|
|
}
|
|
}
|
|
done := make(chan struct{})
|
|
|
|
// start a goroutine to consume the remaining items from the queued work
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ran:
|
|
case <-done:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
// stop the rollback worker, which will wait for all inflight rollbacks to
|
|
// complete
|
|
core.rollback.Stop()
|
|
close(done)
|
|
}
|
|
|
|
// TestRollbackManager_WorkerPool adds 10 backends that require a rollback
|
|
// operation, with 5 workers. The test verifies that the 5 work items can occur
|
|
// concurrently, and that the remainder of the work is queued and run when
|
|
// workers are available
|
|
func TestRollbackManager_WorkerPool(t *testing.T) {
|
|
core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 5, RollbackPeriod: time.Millisecond * 10})
|
|
view := NewBarrierView(core.barrier, "logical/")
|
|
|
|
ran := make(chan string)
|
|
release := make(chan struct{})
|
|
core, _, _ = testCoreUnsealed(t, core)
|
|
|
|
// create 10 backends
|
|
// when a rollback happens, each backend will try to write to an unbuffered
|
|
// channel, then wait to be released
|
|
for i := 0; i < 10; i++ {
|
|
b := &NoopBackend{}
|
|
b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) {
|
|
if request.Operation == logical.RollbackOperation {
|
|
ran <- request.Path
|
|
<-release
|
|
}
|
|
return nil, nil
|
|
}
|
|
b.Root = []string{fmt.Sprintf("foo/%d", i)}
|
|
meUUID, err := uuid.GenerateUUID()
|
|
require.NoError(t, err)
|
|
mountEntry := &MountEntry{
|
|
Table: mountTableType,
|
|
UUID: meUUID,
|
|
Accessor: fmt.Sprintf("accessor-%d", i),
|
|
NamespaceID: namespace.RootNamespaceID,
|
|
namespace: namespace.RootNamespace,
|
|
Path: fmt.Sprintf("logical/foo/%d", i),
|
|
}
|
|
func() {
|
|
core.mountsLock.Lock()
|
|
defer core.mountsLock.Unlock()
|
|
newTable := core.mounts.shallowClone()
|
|
newTable.Entries = append(newTable.Entries, mountEntry)
|
|
core.mounts = newTable
|
|
err = core.router.Mount(b, "logical", mountEntry, view)
|
|
require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local))
|
|
}()
|
|
}
|
|
|
|
timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
|
defer cancel()
|
|
got := make(map[string]bool)
|
|
hasMore := true
|
|
for hasMore {
|
|
// we're using 5 workers, so we would expect to see 5 writes to the
|
|
// channel. Once that happens, close the release channel so that the
|
|
// functions can exit and new rollback operations can run
|
|
select {
|
|
case <-timeout.Done():
|
|
require.Fail(t, "test timed out")
|
|
case i := <-ran:
|
|
got[i] = true
|
|
numGot := len(got)
|
|
if numGot == 5 {
|
|
close(release)
|
|
hasMore = false
|
|
}
|
|
}
|
|
}
|
|
done := make(chan struct{})
|
|
defer close(done)
|
|
|
|
// start a goroutine to consume the remaining items from the queued work
|
|
gotAllPaths := make(chan struct{})
|
|
go func() {
|
|
channelClosed := false
|
|
for {
|
|
select {
|
|
case i := <-ran:
|
|
got[i] = true
|
|
|
|
// keep this goroutine running even after there are 10 paths.
|
|
// More rollback operations might get queued before Stop() is
|
|
// called, and we don't want them to block on writing the to the
|
|
// ran channel
|
|
if len(got) == 10 && !channelClosed {
|
|
close(gotAllPaths)
|
|
channelClosed = true
|
|
}
|
|
case <-timeout.Done():
|
|
require.Fail(t, "test timed out")
|
|
case <-done:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// wait until all 10 backends have each ran at least once
|
|
<-gotAllPaths
|
|
// stop the rollback worker, which will wait for any inflight rollbacks to
|
|
// complete
|
|
core.rollback.Stop()
|
|
}
|
|
|
|
// TestRollbackManager_numRollbackWorkers verifies that the number of rollback
|
|
// workers is parsed from the configuration, but can be overridden by an
|
|
// environment variable. This test cannot be run in parallel because of the
|
|
// environment variable
|
|
func TestRollbackManager_numRollbackWorkers(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
configWorkers int
|
|
setEnvVar bool
|
|
envVar string
|
|
wantWorkers int
|
|
}{
|
|
{
|
|
name: "default in config",
|
|
configWorkers: RollbackDefaultNumWorkers,
|
|
wantWorkers: RollbackDefaultNumWorkers,
|
|
},
|
|
{
|
|
name: "invalid envvar",
|
|
configWorkers: RollbackDefaultNumWorkers,
|
|
wantWorkers: RollbackDefaultNumWorkers,
|
|
setEnvVar: true,
|
|
envVar: "invalid",
|
|
},
|
|
{
|
|
name: "envvar overrides config",
|
|
configWorkers: RollbackDefaultNumWorkers,
|
|
wantWorkers: 20,
|
|
setEnvVar: true,
|
|
envVar: "20",
|
|
},
|
|
{
|
|
name: "envvar negative",
|
|
configWorkers: RollbackDefaultNumWorkers,
|
|
wantWorkers: RollbackDefaultNumWorkers,
|
|
setEnvVar: true,
|
|
envVar: "-1",
|
|
},
|
|
{
|
|
name: "envvar zero",
|
|
configWorkers: RollbackDefaultNumWorkers,
|
|
wantWorkers: RollbackDefaultNumWorkers,
|
|
setEnvVar: true,
|
|
envVar: "0",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.setEnvVar {
|
|
t.Setenv(RollbackWorkersEnvVar, tc.envVar)
|
|
}
|
|
core := &Core{numRollbackWorkers: tc.configWorkers}
|
|
r := &RollbackManager{logger: logger.Named("test"), core: core}
|
|
require.Equal(t, tc.wantWorkers, r.numRollbackWorkers())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRollbackManager_Join(t *testing.T) {
|
|
m, backend := mockRollback(t)
|
|
if len(backend.Paths) > 0 {
|
|
t.Fatalf("bad: %#v", backend)
|
|
}
|
|
|
|
m.Start()
|
|
defer m.Stop()
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(3)
|
|
|
|
errCh := make(chan error, 3)
|
|
go func() {
|
|
defer wg.Done()
|
|
err := m.Rollback(namespace.RootContext(nil), "foo")
|
|
if err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
err := m.Rollback(namespace.RootContext(nil), "foo")
|
|
if err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
err := m.Rollback(namespace.RootContext(nil), "foo")
|
|
if err != nil {
|
|
errCh <- err
|
|
}
|
|
}()
|
|
wg.Wait()
|
|
close(errCh)
|
|
err := <-errCh
|
|
if err != nil {
|
|
t.Fatalf("Error on rollback:%v", err)
|
|
}
|
|
}
|