backport of commit 7bd5c6e84eef890cebdb404d9cb2e281919d4529 (#18555)

Co-authored-by: Daniel Bennett <dbennett@hashicorp.com>
This commit is contained in:
hc-github-team-nomad-core 2023-09-21 17:16:14 -05:00 committed by GitHub
parent a2f56797a0
commit a6ecf954b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 165 additions and 131 deletions

View File

@ -4,29 +4,24 @@
package allocrunner package allocrunner
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"path/filepath"
"sync"
"testing" "testing"
"time" "time"
"github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocrunner/interfaces" "github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/allocrunner/state" "github.com/hashicorp/nomad/client/allocrunner/state"
"github.com/hashicorp/nomad/client/pluginmanager"
"github.com/hashicorp/nomad/client/pluginmanager/csimanager" "github.com/hashicorp/nomad/client/pluginmanager/csimanager"
cstructs "github.com/hashicorp/nomad/client/structs" cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/csi"
"github.com/hashicorp/nomad/plugins/drivers" "github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test/must" "github.com/shoenig/test/must"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
) )
var _ interfaces.RunnerPrerunHook = (*csiHook)(nil) var _ interfaces.RunnerPrerunHook = (*csiHook)(nil)
@ -71,7 +66,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc}, "vol0": &csimanager.MountInfo{Source: testMountSrc},
}, },
expectedCalls: map[string]int{ expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
}, },
{ {
@ -92,7 +87,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc}, "vol0": &csimanager.MountInfo{Source: testMountSrc},
}, },
expectedCalls: map[string]int{ expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
}, },
{ {
@ -137,7 +132,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc}, "vol0": &csimanager.MountInfo{Source: testMountSrc},
}, },
expectedCalls: map[string]int{ expectedCalls: map[string]int{
"claim": 2, "mount": 1, "unmount": 1, "unpublish": 1}, "claim": 2, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
}, },
{ {
name: "already mounted", name: "already mounted",
@ -163,7 +158,7 @@ func TestCSIHook(t *testing.T) {
expectedMounts: map[string]*csimanager.MountInfo{ expectedMounts: map[string]*csimanager.MountInfo{
"vol0": &csimanager.MountInfo{Source: testMountSrc}, "vol0": &csimanager.MountInfo{Source: testMountSrc},
}, },
expectedCalls: map[string]int{"hasMount": 1, "unmount": 1, "unpublish": 1}, expectedCalls: map[string]int{"HasMount": 1, "UnmountVolume": 1, "unpublish": 1},
}, },
{ {
name: "existing but invalid mounts", name: "existing but invalid mounts",
@ -190,7 +185,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc}, "vol0": &csimanager.MountInfo{Source: testMountSrc},
}, },
expectedCalls: map[string]int{ expectedCalls: map[string]int{
"hasMount": 1, "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, "HasMount": 1, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
}, },
{ {
@ -212,7 +207,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc}, "vol0": &csimanager.MountInfo{Source: testMountSrc},
}, },
expectedCalls: map[string]int{ expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 2, "unpublish": 2}, "claim": 1, "MountVolume": 1, "UnmountVolume": 2, "unpublish": 2},
}, },
{ {
@ -227,12 +222,11 @@ func TestCSIHook(t *testing.T) {
alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests
callCounts := &callCounter{counts: map[string]int{}} callCounts := testutil.NewCallCounter()
mgr := mockPluginManager{mounter: mockVolumeManager{ vm := &csimanager.MockVolumeManager{
hasMounts: tc.startsWithValidMounts, CallCounter: callCounts,
callCounts: callCounts, }
failsFirstUnmount: pointer.Of(tc.failsFirstUnmount), mgr := &csimanager.MockCSIManager{VM: vm}
}}
rpcer := mockRPCer{ rpcer := mockRPCer{
alloc: alloc, alloc: alloc,
callCounts: callCounts, callCounts: callCounts,
@ -255,6 +249,17 @@ func TestCSIHook(t *testing.T) {
must.NotNil(t, hook) must.NotNil(t, hook)
if tc.startsWithValidMounts {
// TODO: this works, but it requires knowledge of how the mock works. would rather vm.MountVolume()
vm.Mounts = map[string]bool{
tc.expectedMounts["vol0"].Source: true,
}
}
if tc.failsFirstUnmount {
vm.NextUnmountVolumeErr = errors.New("bad first attempt")
}
if tc.expectedClaimErr != nil { if tc.expectedClaimErr != nil {
must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error()) must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error())
mounts := ar.res.GetCSIMounts() mounts := ar.res.GetCSIMounts()
@ -274,7 +279,7 @@ func TestCSIHook(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
counts := callCounts.get() counts := callCounts.Get()
must.MapEq(t, tc.expectedCalls, counts, must.MapEq(t, tc.expectedCalls, counts,
must.Sprintf("got calls: %v", counts)) must.Sprintf("got calls: %v", counts))
@ -342,14 +347,12 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
alloc.Job.TaskGroups[0].Volumes = volumeRequests alloc.Job.TaskGroups[0].Volumes = volumeRequests
callCounts := &callCounter{counts: map[string]int{}} mgr := &csimanager.MockCSIManager{
mgr := mockPluginManager{mounter: mockVolumeManager{ VM: &csimanager.MockVolumeManager{},
callCounts: callCounts, }
failsFirstUnmount: pointer.Of(false),
}}
rpcer := mockRPCer{ rpcer := mockRPCer{
alloc: alloc, alloc: alloc,
callCounts: callCounts, callCounts: testutil.NewCallCounter(),
hasExistingClaim: pointer.Of(false), hasExistingClaim: pointer.Of(false),
schedulable: pointer.Of(true), schedulable: pointer.Of(true),
} }
@ -379,26 +382,9 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
// HELPERS AND MOCKS // HELPERS AND MOCKS
type callCounter struct {
lock sync.Mutex
counts map[string]int
}
func (c *callCounter) inc(name string) {
c.lock.Lock()
defer c.lock.Unlock()
c.counts[name]++
}
func (c *callCounter) get() map[string]int {
c.lock.Lock()
defer c.lock.Unlock()
return maps.Clone(c.counts)
}
type mockRPCer struct { type mockRPCer struct {
alloc *structs.Allocation alloc *structs.Allocation
callCounts *callCounter callCounts *testutil.CallCounter
hasExistingClaim *bool hasExistingClaim *bool
schedulable *bool schedulable *bool
} }
@ -407,7 +393,7 @@ type mockRPCer struct {
func (r mockRPCer) RPC(method string, args any, reply any) error { func (r mockRPCer) RPC(method string, args any, reply any) error {
switch method { switch method {
case "CSIVolume.Claim": case "CSIVolume.Claim":
r.callCounts.inc("claim") r.callCounts.Inc("claim")
req := args.(*structs.CSIVolumeClaimRequest) req := args.(*structs.CSIVolumeClaimRequest)
vol := r.testVolume(req.VolumeID) vol := r.testVolume(req.VolumeID)
err := vol.Claim(req.ToClaim(), r.alloc) err := vol.Claim(req.ToClaim(), r.alloc)
@ -427,7 +413,7 @@ func (r mockRPCer) RPC(method string, args any, reply any) error {
resp.QueryMeta = structs.QueryMeta{} resp.QueryMeta = structs.QueryMeta{}
case "CSIVolume.Unpublish": case "CSIVolume.Unpublish":
r.callCounts.inc("unpublish") r.callCounts.Inc("unpublish")
resp := reply.(*structs.CSIVolumeUnpublishResponse) resp := reply.(*structs.CSIVolumeUnpublishResponse)
resp.QueryMeta = structs.QueryMeta{} resp.QueryMeta = structs.QueryMeta{}
@ -470,59 +456,6 @@ func (r mockRPCer) testVolume(id string) *structs.CSIVolume {
return vol return vol
} }
type mockVolumeManager struct {
hasMounts bool
failsFirstUnmount *bool
callCounts *callCounter
}
func (vm mockVolumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *csimanager.UsageOptions, publishContext map[string]string) (*csimanager.MountInfo, error) {
vm.callCounts.inc("mount")
return &csimanager.MountInfo{
Source: filepath.Join("test-alloc-dir", alloc.ID, vol.ID, usageOpts.ToFS()),
}, nil
}
func (vm mockVolumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *csimanager.UsageOptions) error {
vm.callCounts.inc("unmount")
if *vm.failsFirstUnmount {
*vm.failsFirstUnmount = false
return fmt.Errorf("could not unmount")
}
return nil
}
func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.MountInfo) (bool, error) {
vm.callCounts.inc("hasMount")
return mountInfo != nil && vm.hasMounts, nil
}
func (vm mockVolumeManager) ExpandVolume(_ context.Context, _, _, _ string, _ *csimanager.UsageOptions, _ *csi.CapacityRange) (int64, error) {
return 0, nil
}
func (vm mockVolumeManager) ExternalID() string {
return "i-example"
}
type mockPluginManager struct {
mounter mockVolumeManager
}
func (mgr mockPluginManager) WaitForPlugin(ctx context.Context, pluginType, pluginID string) error {
return nil
}
func (mgr mockPluginManager) ManagerForPlugin(ctx context.Context, pluginID string) (csimanager.VolumeManager, error) {
return mgr.mounter, nil
}
// no-op methods to fulfill the interface
func (mgr mockPluginManager) PluginManager() pluginmanager.PluginManager { return nil }
func (mgr mockPluginManager) Shutdown() {}
type mockAllocRunner struct { type mockAllocRunner struct {
res *cstructs.AllocHookResources res *cstructs.AllocHookResources
caps *drivers.Capabilities caps *drivers.Capabilities

View File

@ -995,23 +995,21 @@ func TestCSINode_DetachVolume(t *testing.T) {
cases := []struct { cases := []struct {
Name string Name string
ClientSetupFunc func(*fake.Client) ModManager func(m *csimanager.MockCSIManager)
Request *structs.ClientCSINodeDetachVolumeRequest Request *structs.ClientCSINodeDetachVolumeRequest
ExpectedErr error ExpectedErr error
ExpectedResponse *structs.ClientCSINodeDetachVolumeResponse
}{ }{
{ {
Name: "returns plugin not found errors", Name: "success",
Request: &structs.ClientCSINodeDetachVolumeRequest{ Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: "some-garbage", PluginID: "fake-plugin",
VolumeID: "-", VolumeID: "fake-vol",
AllocID: "-", AllocID: "fake-alloc",
NodeID: "-", NodeID: "fake-node",
AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem, AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem,
AccessMode: nstructs.CSIVolumeAccessModeMultiNodeReader, AccessMode: nstructs.CSIVolumeAccessModeMultiNodeReader,
ReadOnly: true, ReadOnly: true,
}, },
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin some-garbage for type csi-node not found"),
}, },
{ {
Name: "validates volumeid is not empty", Name: "validates volumeid is not empty",
@ -1029,43 +1027,51 @@ func TestCSINode_DetachVolume(t *testing.T) {
ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"), ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"),
}, },
{ {
Name: "returns transitive errors", Name: "returns csi manager errors",
ClientSetupFunc: func(fc *fake.Client) { ModManager: func(m *csimanager.MockCSIManager) {
fc.NextNodeUnpublishVolumeErr = errors.New("wont-see-this") m.NextManagerForPluginErr = errors.New("no plugin")
}, },
Request: &structs.ClientCSINodeDetachVolumeRequest{ Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: fakeNodePlugin.Name, PluginID: fakeNodePlugin.Name,
VolumeID: "1234-4321-1234-4321", VolumeID: "1234-4321-1234-4321",
AllocID: "4321-1234-4321-1234", AllocID: "4321-1234-4321-1234",
}, },
// we don't have a csimanager in this context ExpectedErr: errors.New("CSI.NodeDetachVolume: no plugin"),
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin test-plugin for type csi-node not found"), },
{
Name: "returns volume manager errors",
ModManager: func(m *csimanager.MockCSIManager) {
m.VM.NextUnmountVolumeErr = errors.New("error unmounting")
},
Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: fakeNodePlugin.Name,
VolumeID: "1234-4321-1234-4321",
AllocID: "4321-1234-4321-1234",
},
ExpectedErr: errors.New("CSI.NodeDetachVolume: error unmounting"),
}, },
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
client, cleanup := TestClient(t, nil) client, cleanup := TestClient(t, nil)
defer cleanup() defer cleanup()
fakeClient := &fake.Client{} mockManager := &csimanager.MockCSIManager{
if tc.ClientSetupFunc != nil { VM: &csimanager.MockVolumeManager{},
tc.ClientSetupFunc(fakeClient)
} }
if tc.ModManager != nil {
dispenserFunc := func(*dynamicplugins.PluginInfo) (interface{}, error) { tc.ModManager(mockManager)
return fakeClient, nil
} }
client.dynamicRegistry.StubDispenserForType(dynamicplugins.PluginTypeCSINode, dispenserFunc) client.csimanager = mockManager
err := client.dynamicRegistry.RegisterPlugin(fakeNodePlugin)
require.Nil(err)
var resp structs.ClientCSINodeDetachVolumeResponse var resp structs.ClientCSINodeDetachVolumeResponse
err = client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp) err := client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
require.Equal(tc.ExpectedErr, err) if tc.ExpectedErr != nil {
if tc.ExpectedResponse != nil { must.Error(t, err)
require.Equal(tc.ExpectedResponse, &resp) must.EqError(t, tc.ExpectedErr, err.Error())
} else {
must.NoError(t, err)
} }
}) })
} }

View File

@ -5,10 +5,12 @@ package csimanager
import ( import (
"context" "context"
"path/filepath"
"github.com/hashicorp/nomad/client/pluginmanager" "github.com/hashicorp/nomad/client/pluginmanager"
nstructs "github.com/hashicorp/nomad/nomad/structs" nstructs "github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/csi" "github.com/hashicorp/nomad/plugins/csi"
"github.com/hashicorp/nomad/testutil"
) )
var _ Manager = &MockCSIManager{} var _ Manager = &MockCSIManager{}
@ -29,6 +31,9 @@ func (m *MockCSIManager) WaitForPlugin(_ context.Context, pluginType, pluginID s
} }
func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) { func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) {
if m.VM == nil {
m.VM = &MockVolumeManager{}
}
return m.VM, m.NextManagerForPluginErr return m.VM, m.NextManagerForPluginErr
} }
@ -39,20 +44,68 @@ func (m *MockCSIManager) Shutdown() {
var _ VolumeManager = &MockVolumeManager{} var _ VolumeManager = &MockVolumeManager{}
type MockVolumeManager struct { type MockVolumeManager struct {
CallCounter *testutil.CallCounter
Mounts map[string]bool // lazy set
NextMountVolumeErr error
NextUnmountVolumeErr error
NextExpandVolumeErr error NextExpandVolumeErr error
LastExpandVolumeCall *MockExpandVolumeCall LastExpandVolumeCall *MockExpandVolumeCall
} }
func (m *MockVolumeManager) mountName(volID, allocID string, usageOpts *UsageOptions) string {
return filepath.Join("test-alloc-dir", allocID, volID, usageOpts.ToFS())
}
func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) { func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) {
panic("implement me") if m.CallCounter != nil {
m.CallCounter.Inc("MountVolume")
}
if m.NextMountVolumeErr != nil {
err := m.NextMountVolumeErr
m.NextMountVolumeErr = nil // reset it
return nil, err
}
// "mount" it
if m.Mounts == nil {
m.Mounts = make(map[string]bool)
}
source := m.mountName(vol.ID, alloc.ID, usageOpts)
m.Mounts[source] = true
return &MountInfo{
Source: source,
}, nil
} }
func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error { func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error {
panic("implement me") if m.CallCounter != nil {
m.CallCounter.Inc("UnmountVolume")
}
if m.NextUnmountVolumeErr != nil {
err := m.NextUnmountVolumeErr
m.NextUnmountVolumeErr = nil // reset it
return err
}
// "unmount" it
delete(m.Mounts, m.mountName(volID, allocID, usageOpts))
return nil
} }
func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) { func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) {
panic("implement me") if m.CallCounter != nil {
m.CallCounter.Inc("HasMount")
}
if m.Mounts == nil {
return false, nil
}
return m.Mounts[mountInfo.Source], nil
} }
func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) { func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) {

42
testutil/mock_calls.go Normal file
View File

@ -0,0 +1,42 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package testutil
import (
"maps"
"sync"
"github.com/mitchellh/go-testing-interface"
)
func NewCallCounter() *CallCounter {
return &CallCounter{
counts: make(map[string]int),
}
}
type CallCounter struct {
lock sync.Mutex
counts map[string]int
}
func (c *CallCounter) Inc(name string) {
c.lock.Lock()
defer c.lock.Unlock()
c.counts[name]++
}
func (c *CallCounter) Get() map[string]int {
c.lock.Lock()
defer c.lock.Unlock()
return maps.Clone(c.counts)
}
func (c *CallCounter) AssertCalled(t testing.T, name string) {
t.Helper()
counts := c.Get()
if _, ok := counts[name]; !ok {
t.Errorf("'%s' not called; all counts: %v", counts)
}
}