backport of commit 7bd5c6e84eef890cebdb404d9cb2e281919d4529 (#18555)
Co-authored-by: Daniel Bennett <dbennett@hashicorp.com>
This commit is contained in:
parent
a2f56797a0
commit
a6ecf954b0
|
@ -4,29 +4,24 @@
|
||||||
package allocrunner
|
package allocrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/nomad/ci"
|
"github.com/hashicorp/nomad/ci"
|
||||||
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
||||||
"github.com/hashicorp/nomad/client/allocrunner/state"
|
"github.com/hashicorp/nomad/client/allocrunner/state"
|
||||||
"github.com/hashicorp/nomad/client/pluginmanager"
|
|
||||||
"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
|
"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
|
||||||
cstructs "github.com/hashicorp/nomad/client/structs"
|
cstructs "github.com/hashicorp/nomad/client/structs"
|
||||||
"github.com/hashicorp/nomad/helper/pointer"
|
"github.com/hashicorp/nomad/helper/pointer"
|
||||||
"github.com/hashicorp/nomad/helper/testlog"
|
"github.com/hashicorp/nomad/helper/testlog"
|
||||||
"github.com/hashicorp/nomad/nomad/mock"
|
"github.com/hashicorp/nomad/nomad/mock"
|
||||||
"github.com/hashicorp/nomad/nomad/structs"
|
"github.com/hashicorp/nomad/nomad/structs"
|
||||||
"github.com/hashicorp/nomad/plugins/csi"
|
|
||||||
"github.com/hashicorp/nomad/plugins/drivers"
|
"github.com/hashicorp/nomad/plugins/drivers"
|
||||||
|
"github.com/hashicorp/nomad/testutil"
|
||||||
"github.com/shoenig/test/must"
|
"github.com/shoenig/test/must"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ interfaces.RunnerPrerunHook = (*csiHook)(nil)
|
var _ interfaces.RunnerPrerunHook = (*csiHook)(nil)
|
||||||
|
@ -71,7 +66,7 @@ func TestCSIHook(t *testing.T) {
|
||||||
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
||||||
},
|
},
|
||||||
expectedCalls: map[string]int{
|
expectedCalls: map[string]int{
|
||||||
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
|
"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -92,7 +87,7 @@ func TestCSIHook(t *testing.T) {
|
||||||
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
||||||
},
|
},
|
||||||
expectedCalls: map[string]int{
|
expectedCalls: map[string]int{
|
||||||
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
|
"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -137,7 +132,7 @@ func TestCSIHook(t *testing.T) {
|
||||||
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
||||||
},
|
},
|
||||||
expectedCalls: map[string]int{
|
expectedCalls: map[string]int{
|
||||||
"claim": 2, "mount": 1, "unmount": 1, "unpublish": 1},
|
"claim": 2, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "already mounted",
|
name: "already mounted",
|
||||||
|
@ -163,7 +158,7 @@ func TestCSIHook(t *testing.T) {
|
||||||
expectedMounts: map[string]*csimanager.MountInfo{
|
expectedMounts: map[string]*csimanager.MountInfo{
|
||||||
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
||||||
},
|
},
|
||||||
expectedCalls: map[string]int{"hasMount": 1, "unmount": 1, "unpublish": 1},
|
expectedCalls: map[string]int{"HasMount": 1, "UnmountVolume": 1, "unpublish": 1},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "existing but invalid mounts",
|
name: "existing but invalid mounts",
|
||||||
|
@ -190,7 +185,7 @@ func TestCSIHook(t *testing.T) {
|
||||||
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
||||||
},
|
},
|
||||||
expectedCalls: map[string]int{
|
expectedCalls: map[string]int{
|
||||||
"hasMount": 1, "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
|
"HasMount": 1, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -212,7 +207,7 @@ func TestCSIHook(t *testing.T) {
|
||||||
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
"vol0": &csimanager.MountInfo{Source: testMountSrc},
|
||||||
},
|
},
|
||||||
expectedCalls: map[string]int{
|
expectedCalls: map[string]int{
|
||||||
"claim": 1, "mount": 1, "unmount": 2, "unpublish": 2},
|
"claim": 1, "MountVolume": 1, "UnmountVolume": 2, "unpublish": 2},
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -227,12 +222,11 @@ func TestCSIHook(t *testing.T) {
|
||||||
|
|
||||||
alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests
|
alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests
|
||||||
|
|
||||||
callCounts := &callCounter{counts: map[string]int{}}
|
callCounts := testutil.NewCallCounter()
|
||||||
mgr := mockPluginManager{mounter: mockVolumeManager{
|
vm := &csimanager.MockVolumeManager{
|
||||||
hasMounts: tc.startsWithValidMounts,
|
CallCounter: callCounts,
|
||||||
callCounts: callCounts,
|
}
|
||||||
failsFirstUnmount: pointer.Of(tc.failsFirstUnmount),
|
mgr := &csimanager.MockCSIManager{VM: vm}
|
||||||
}}
|
|
||||||
rpcer := mockRPCer{
|
rpcer := mockRPCer{
|
||||||
alloc: alloc,
|
alloc: alloc,
|
||||||
callCounts: callCounts,
|
callCounts: callCounts,
|
||||||
|
@ -255,6 +249,17 @@ func TestCSIHook(t *testing.T) {
|
||||||
|
|
||||||
must.NotNil(t, hook)
|
must.NotNil(t, hook)
|
||||||
|
|
||||||
|
if tc.startsWithValidMounts {
|
||||||
|
// TODO: this works, but it requires knowledge of how the mock works. would rather vm.MountVolume()
|
||||||
|
vm.Mounts = map[string]bool{
|
||||||
|
tc.expectedMounts["vol0"].Source: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.failsFirstUnmount {
|
||||||
|
vm.NextUnmountVolumeErr = errors.New("bad first attempt")
|
||||||
|
}
|
||||||
|
|
||||||
if tc.expectedClaimErr != nil {
|
if tc.expectedClaimErr != nil {
|
||||||
must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error())
|
must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error())
|
||||||
mounts := ar.res.GetCSIMounts()
|
mounts := ar.res.GetCSIMounts()
|
||||||
|
@ -274,7 +279,7 @@ func TestCSIHook(t *testing.T) {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
counts := callCounts.get()
|
counts := callCounts.Get()
|
||||||
must.MapEq(t, tc.expectedCalls, counts,
|
must.MapEq(t, tc.expectedCalls, counts,
|
||||||
must.Sprintf("got calls: %v", counts))
|
must.Sprintf("got calls: %v", counts))
|
||||||
|
|
||||||
|
@ -342,14 +347,12 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
alloc.Job.TaskGroups[0].Volumes = volumeRequests
|
alloc.Job.TaskGroups[0].Volumes = volumeRequests
|
||||||
|
|
||||||
callCounts := &callCounter{counts: map[string]int{}}
|
mgr := &csimanager.MockCSIManager{
|
||||||
mgr := mockPluginManager{mounter: mockVolumeManager{
|
VM: &csimanager.MockVolumeManager{},
|
||||||
callCounts: callCounts,
|
}
|
||||||
failsFirstUnmount: pointer.Of(false),
|
|
||||||
}}
|
|
||||||
rpcer := mockRPCer{
|
rpcer := mockRPCer{
|
||||||
alloc: alloc,
|
alloc: alloc,
|
||||||
callCounts: callCounts,
|
callCounts: testutil.NewCallCounter(),
|
||||||
hasExistingClaim: pointer.Of(false),
|
hasExistingClaim: pointer.Of(false),
|
||||||
schedulable: pointer.Of(true),
|
schedulable: pointer.Of(true),
|
||||||
}
|
}
|
||||||
|
@ -379,26 +382,9 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
|
||||||
|
|
||||||
// HELPERS AND MOCKS
|
// HELPERS AND MOCKS
|
||||||
|
|
||||||
type callCounter struct {
|
|
||||||
lock sync.Mutex
|
|
||||||
counts map[string]int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *callCounter) inc(name string) {
|
|
||||||
c.lock.Lock()
|
|
||||||
defer c.lock.Unlock()
|
|
||||||
c.counts[name]++
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *callCounter) get() map[string]int {
|
|
||||||
c.lock.Lock()
|
|
||||||
defer c.lock.Unlock()
|
|
||||||
return maps.Clone(c.counts)
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockRPCer struct {
|
type mockRPCer struct {
|
||||||
alloc *structs.Allocation
|
alloc *structs.Allocation
|
||||||
callCounts *callCounter
|
callCounts *testutil.CallCounter
|
||||||
hasExistingClaim *bool
|
hasExistingClaim *bool
|
||||||
schedulable *bool
|
schedulable *bool
|
||||||
}
|
}
|
||||||
|
@ -407,7 +393,7 @@ type mockRPCer struct {
|
||||||
func (r mockRPCer) RPC(method string, args any, reply any) error {
|
func (r mockRPCer) RPC(method string, args any, reply any) error {
|
||||||
switch method {
|
switch method {
|
||||||
case "CSIVolume.Claim":
|
case "CSIVolume.Claim":
|
||||||
r.callCounts.inc("claim")
|
r.callCounts.Inc("claim")
|
||||||
req := args.(*structs.CSIVolumeClaimRequest)
|
req := args.(*structs.CSIVolumeClaimRequest)
|
||||||
vol := r.testVolume(req.VolumeID)
|
vol := r.testVolume(req.VolumeID)
|
||||||
err := vol.Claim(req.ToClaim(), r.alloc)
|
err := vol.Claim(req.ToClaim(), r.alloc)
|
||||||
|
@ -427,7 +413,7 @@ func (r mockRPCer) RPC(method string, args any, reply any) error {
|
||||||
resp.QueryMeta = structs.QueryMeta{}
|
resp.QueryMeta = structs.QueryMeta{}
|
||||||
|
|
||||||
case "CSIVolume.Unpublish":
|
case "CSIVolume.Unpublish":
|
||||||
r.callCounts.inc("unpublish")
|
r.callCounts.Inc("unpublish")
|
||||||
resp := reply.(*structs.CSIVolumeUnpublishResponse)
|
resp := reply.(*structs.CSIVolumeUnpublishResponse)
|
||||||
resp.QueryMeta = structs.QueryMeta{}
|
resp.QueryMeta = structs.QueryMeta{}
|
||||||
|
|
||||||
|
@ -470,59 +456,6 @@ func (r mockRPCer) testVolume(id string) *structs.CSIVolume {
|
||||||
return vol
|
return vol
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockVolumeManager struct {
|
|
||||||
hasMounts bool
|
|
||||||
failsFirstUnmount *bool
|
|
||||||
callCounts *callCounter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vm mockVolumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *csimanager.UsageOptions, publishContext map[string]string) (*csimanager.MountInfo, error) {
|
|
||||||
vm.callCounts.inc("mount")
|
|
||||||
return &csimanager.MountInfo{
|
|
||||||
Source: filepath.Join("test-alloc-dir", alloc.ID, vol.ID, usageOpts.ToFS()),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vm mockVolumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *csimanager.UsageOptions) error {
|
|
||||||
vm.callCounts.inc("unmount")
|
|
||||||
|
|
||||||
if *vm.failsFirstUnmount {
|
|
||||||
*vm.failsFirstUnmount = false
|
|
||||||
return fmt.Errorf("could not unmount")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.MountInfo) (bool, error) {
|
|
||||||
vm.callCounts.inc("hasMount")
|
|
||||||
return mountInfo != nil && vm.hasMounts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vm mockVolumeManager) ExpandVolume(_ context.Context, _, _, _ string, _ *csimanager.UsageOptions, _ *csi.CapacityRange) (int64, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vm mockVolumeManager) ExternalID() string {
|
|
||||||
return "i-example"
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockPluginManager struct {
|
|
||||||
mounter mockVolumeManager
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr mockPluginManager) WaitForPlugin(ctx context.Context, pluginType, pluginID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr mockPluginManager) ManagerForPlugin(ctx context.Context, pluginID string) (csimanager.VolumeManager, error) {
|
|
||||||
return mgr.mounter, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// no-op methods to fulfill the interface
|
|
||||||
func (mgr mockPluginManager) PluginManager() pluginmanager.PluginManager { return nil }
|
|
||||||
func (mgr mockPluginManager) Shutdown() {}
|
|
||||||
|
|
||||||
type mockAllocRunner struct {
|
type mockAllocRunner struct {
|
||||||
res *cstructs.AllocHookResources
|
res *cstructs.AllocHookResources
|
||||||
caps *drivers.Capabilities
|
caps *drivers.Capabilities
|
||||||
|
|
|
@ -995,23 +995,21 @@ func TestCSINode_DetachVolume(t *testing.T) {
|
||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
Name string
|
Name string
|
||||||
ClientSetupFunc func(*fake.Client)
|
ModManager func(m *csimanager.MockCSIManager)
|
||||||
Request *structs.ClientCSINodeDetachVolumeRequest
|
Request *structs.ClientCSINodeDetachVolumeRequest
|
||||||
ExpectedErr error
|
ExpectedErr error
|
||||||
ExpectedResponse *structs.ClientCSINodeDetachVolumeResponse
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
Name: "returns plugin not found errors",
|
Name: "success",
|
||||||
Request: &structs.ClientCSINodeDetachVolumeRequest{
|
Request: &structs.ClientCSINodeDetachVolumeRequest{
|
||||||
PluginID: "some-garbage",
|
PluginID: "fake-plugin",
|
||||||
VolumeID: "-",
|
VolumeID: "fake-vol",
|
||||||
AllocID: "-",
|
AllocID: "fake-alloc",
|
||||||
NodeID: "-",
|
NodeID: "fake-node",
|
||||||
AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem,
|
AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem,
|
||||||
AccessMode: nstructs.CSIVolumeAccessModeMultiNodeReader,
|
AccessMode: nstructs.CSIVolumeAccessModeMultiNodeReader,
|
||||||
ReadOnly: true,
|
ReadOnly: true,
|
||||||
},
|
},
|
||||||
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin some-garbage for type csi-node not found"),
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "validates volumeid is not empty",
|
Name: "validates volumeid is not empty",
|
||||||
|
@ -1029,43 +1027,51 @@ func TestCSINode_DetachVolume(t *testing.T) {
|
||||||
ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"),
|
ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "returns transitive errors",
|
Name: "returns csi manager errors",
|
||||||
ClientSetupFunc: func(fc *fake.Client) {
|
ModManager: func(m *csimanager.MockCSIManager) {
|
||||||
fc.NextNodeUnpublishVolumeErr = errors.New("wont-see-this")
|
m.NextManagerForPluginErr = errors.New("no plugin")
|
||||||
},
|
},
|
||||||
Request: &structs.ClientCSINodeDetachVolumeRequest{
|
Request: &structs.ClientCSINodeDetachVolumeRequest{
|
||||||
PluginID: fakeNodePlugin.Name,
|
PluginID: fakeNodePlugin.Name,
|
||||||
VolumeID: "1234-4321-1234-4321",
|
VolumeID: "1234-4321-1234-4321",
|
||||||
AllocID: "4321-1234-4321-1234",
|
AllocID: "4321-1234-4321-1234",
|
||||||
},
|
},
|
||||||
// we don't have a csimanager in this context
|
ExpectedErr: errors.New("CSI.NodeDetachVolume: no plugin"),
|
||||||
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin test-plugin for type csi-node not found"),
|
},
|
||||||
|
{
|
||||||
|
Name: "returns volume manager errors",
|
||||||
|
ModManager: func(m *csimanager.MockCSIManager) {
|
||||||
|
m.VM.NextUnmountVolumeErr = errors.New("error unmounting")
|
||||||
|
},
|
||||||
|
Request: &structs.ClientCSINodeDetachVolumeRequest{
|
||||||
|
PluginID: fakeNodePlugin.Name,
|
||||||
|
VolumeID: "1234-4321-1234-4321",
|
||||||
|
AllocID: "4321-1234-4321-1234",
|
||||||
|
},
|
||||||
|
ExpectedErr: errors.New("CSI.NodeDetachVolume: error unmounting"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
require := require.New(t)
|
|
||||||
client, cleanup := TestClient(t, nil)
|
client, cleanup := TestClient(t, nil)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
fakeClient := &fake.Client{}
|
mockManager := &csimanager.MockCSIManager{
|
||||||
if tc.ClientSetupFunc != nil {
|
VM: &csimanager.MockVolumeManager{},
|
||||||
tc.ClientSetupFunc(fakeClient)
|
|
||||||
}
|
}
|
||||||
|
if tc.ModManager != nil {
|
||||||
dispenserFunc := func(*dynamicplugins.PluginInfo) (interface{}, error) {
|
tc.ModManager(mockManager)
|
||||||
return fakeClient, nil
|
|
||||||
}
|
}
|
||||||
client.dynamicRegistry.StubDispenserForType(dynamicplugins.PluginTypeCSINode, dispenserFunc)
|
client.csimanager = mockManager
|
||||||
err := client.dynamicRegistry.RegisterPlugin(fakeNodePlugin)
|
|
||||||
require.Nil(err)
|
|
||||||
|
|
||||||
var resp structs.ClientCSINodeDetachVolumeResponse
|
var resp structs.ClientCSINodeDetachVolumeResponse
|
||||||
err = client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
|
err := client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
|
||||||
require.Equal(tc.ExpectedErr, err)
|
if tc.ExpectedErr != nil {
|
||||||
if tc.ExpectedResponse != nil {
|
must.Error(t, err)
|
||||||
require.Equal(tc.ExpectedResponse, &resp)
|
must.EqError(t, tc.ExpectedErr, err.Error())
|
||||||
|
} else {
|
||||||
|
must.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,10 +5,12 @@ package csimanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/hashicorp/nomad/client/pluginmanager"
|
"github.com/hashicorp/nomad/client/pluginmanager"
|
||||||
nstructs "github.com/hashicorp/nomad/nomad/structs"
|
nstructs "github.com/hashicorp/nomad/nomad/structs"
|
||||||
"github.com/hashicorp/nomad/plugins/csi"
|
"github.com/hashicorp/nomad/plugins/csi"
|
||||||
|
"github.com/hashicorp/nomad/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Manager = &MockCSIManager{}
|
var _ Manager = &MockCSIManager{}
|
||||||
|
@ -29,6 +31,9 @@ func (m *MockCSIManager) WaitForPlugin(_ context.Context, pluginType, pluginID s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) {
|
func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) {
|
||||||
|
if m.VM == nil {
|
||||||
|
m.VM = &MockVolumeManager{}
|
||||||
|
}
|
||||||
return m.VM, m.NextManagerForPluginErr
|
return m.VM, m.NextManagerForPluginErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,20 +44,68 @@ func (m *MockCSIManager) Shutdown() {
|
||||||
var _ VolumeManager = &MockVolumeManager{}
|
var _ VolumeManager = &MockVolumeManager{}
|
||||||
|
|
||||||
type MockVolumeManager struct {
|
type MockVolumeManager struct {
|
||||||
|
CallCounter *testutil.CallCounter
|
||||||
|
|
||||||
|
Mounts map[string]bool // lazy set
|
||||||
|
|
||||||
|
NextMountVolumeErr error
|
||||||
|
NextUnmountVolumeErr error
|
||||||
|
|
||||||
NextExpandVolumeErr error
|
NextExpandVolumeErr error
|
||||||
LastExpandVolumeCall *MockExpandVolumeCall
|
LastExpandVolumeCall *MockExpandVolumeCall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockVolumeManager) mountName(volID, allocID string, usageOpts *UsageOptions) string {
|
||||||
|
return filepath.Join("test-alloc-dir", allocID, volID, usageOpts.ToFS())
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) {
|
func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) {
|
||||||
panic("implement me")
|
if m.CallCounter != nil {
|
||||||
|
m.CallCounter.Inc("MountVolume")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.NextMountVolumeErr != nil {
|
||||||
|
err := m.NextMountVolumeErr
|
||||||
|
m.NextMountVolumeErr = nil // reset it
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// "mount" it
|
||||||
|
if m.Mounts == nil {
|
||||||
|
m.Mounts = make(map[string]bool)
|
||||||
|
}
|
||||||
|
source := m.mountName(vol.ID, alloc.ID, usageOpts)
|
||||||
|
m.Mounts[source] = true
|
||||||
|
|
||||||
|
return &MountInfo{
|
||||||
|
Source: source,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error {
|
func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error {
|
||||||
panic("implement me")
|
if m.CallCounter != nil {
|
||||||
|
m.CallCounter.Inc("UnmountVolume")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.NextUnmountVolumeErr != nil {
|
||||||
|
err := m.NextUnmountVolumeErr
|
||||||
|
m.NextUnmountVolumeErr = nil // reset it
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// "unmount" it
|
||||||
|
delete(m.Mounts, m.mountName(volID, allocID, usageOpts))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) {
|
func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) {
|
||||||
panic("implement me")
|
if m.CallCounter != nil {
|
||||||
|
m.CallCounter.Inc("HasMount")
|
||||||
|
}
|
||||||
|
if m.Mounts == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return m.Mounts[mountInfo.Source], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) {
|
func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) {
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
// Copyright (c) HashiCorp, Inc.
|
||||||
|
// SPDX-License-Identifier: BUSL-1.1
|
||||||
|
|
||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"maps"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mitchellh/go-testing-interface"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewCallCounter() *CallCounter {
|
||||||
|
return &CallCounter{
|
||||||
|
counts: make(map[string]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CallCounter struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
counts map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CallCounter) Inc(name string) {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
c.counts[name]++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CallCounter) Get() map[string]int {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
return maps.Clone(c.counts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CallCounter) AssertCalled(t testing.T, name string) {
|
||||||
|
t.Helper()
|
||||||
|
counts := c.Get()
|
||||||
|
if _, ok := counts[name]; !ok {
|
||||||
|
t.Errorf("'%s' not called; all counts: %v", counts)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue