diff --git a/.changelog/13446.txt b/.changelog/13446.txt new file mode 100644 index 000000000..7c10253f1 --- /dev/null +++ b/.changelog/13446.txt @@ -0,0 +1,3 @@ +```release-note:bug +csi: Fixed a bug where CSI hook validation would fail if all tasks didn't support CSI. +`` diff --git a/client/allocrunner/csi_hook.go b/client/allocrunner/csi_hook.go index 2a2fe2963..0a09b4695 100644 --- a/client/allocrunner/csi_hook.go +++ b/client/allocrunner/csi_hook.go @@ -184,24 +184,29 @@ type volumeAndRequest struct { func (c *csiHook) claimVolumesFromAlloc() (map[string]*volumeAndRequest, error) { result := make(map[string]*volumeAndRequest) tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup) + supportsVolumes := false + + for _, task := range tg.Tasks { + caps, err := c.taskCapabilityGetter.GetTaskDriverCapabilities(task.Name) + if err != nil { + return nil, fmt.Errorf("could not validate task driver capabilities: %v", err) + } + + if caps.MountConfigs == drivers.MountConfigSupportNone { + continue + } + + supportsVolumes = true + break + } + + if !supportsVolumes { + return nil, fmt.Errorf("no task supports CSI") + } // Initially, populate the result map with all of the requests for alias, volumeRequest := range tg.Volumes { - if volumeRequest.Type == structs.VolumeTypeCSI { - - for _, task := range tg.Tasks { - caps, err := c.taskCapabilityGetter.GetTaskDriverCapabilities(task.Name) - if err != nil { - return nil, fmt.Errorf("could not validate task driver capabilities: %v", err) - } - - if caps.MountConfigs == drivers.MountConfigSupportNone { - return nil, fmt.Errorf( - "task driver %q for %q does not support CSI", task.Driver, task.Name) - } - } - result[alias] = &volumeAndRequest{request: volumeRequest} } } diff --git a/client/allocrunner/csi_hook_test.go b/client/allocrunner/csi_hook_test.go index 1d3b04ed3..21d3fc91d 100644 --- a/client/allocrunner/csi_hook_test.go +++ b/client/allocrunner/csi_hook_test.go @@ -232,6 +232,99 @@ func TestCSIHook(t *testing.T) { } +// TestCSIHook_claimVolumesFromAlloc_Validation tests that the validation of task +// capabilities in claimVolumesFromAlloc ensures at least one task supports CSI. +func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { + ci.Parallel(t) + + alloc := mock.Alloc() + logger := testlog.HCLogger(t) + volumeRequests := map[string]*structs.VolumeRequest{ + "vol0": { + Name: "vol0", + Type: structs.VolumeTypeCSI, + Source: "testvolume0", + ReadOnly: true, + AccessMode: structs.CSIVolumeAccessModeSingleNodeReader, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + MountOptions: &structs.CSIMountOptions{}, + PerAlloc: false, + }, + } + + type testCase struct { + name string + caps *drivers.Capabilities + capFunc func() (*drivers.Capabilities, error) + expectedClaimErr error + } + + testcases := []testCase{ + { + name: "invalid - driver does not support CSI", + caps: &drivers.Capabilities{ + MountConfigs: drivers.MountConfigSupportNone, + }, + capFunc: nil, + expectedClaimErr: errors.New("claim volumes: no task supports CSI"), + }, + + { + name: "invalid - driver error", + caps: &drivers.Capabilities{}, + capFunc: func() (*drivers.Capabilities, error) { + return nil, errors.New("error thrown by driver") + }, + expectedClaimErr: errors.New("claim volumes: could not validate task driver capabilities: error thrown by driver"), + }, + + { + name: "valid - driver supports CSI", + caps: &drivers.Capabilities{ + MountConfigs: drivers.MountConfigSupportAll, + }, + capFunc: nil, + expectedClaimErr: nil, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + alloc.Job.TaskGroups[0].Volumes = volumeRequests + + callCounts := map[string]int{} + mgr := mockPluginManager{mounter: mockVolumeMounter{callCounts: callCounts}} + + rpcer := mockRPCer{ + alloc: alloc, + callCounts: callCounts, + hasExistingClaim: helper.BoolToPtr(false), + schedulable: helper.BoolToPtr(true), + } + + ar := mockAllocRunner{ + res: &cstructs.AllocHookResources{}, + caps: tc.caps, + capFunc: tc.capFunc, + } + + hook := newCSIHook(alloc, logger, mgr, rpcer, ar, ar, "secret") + require.NotNil(t, hook) + + if tc.expectedClaimErr != nil { + require.EqualError(t, hook.Prerun(), tc.expectedClaimErr.Error()) + mounts := ar.GetAllocHookResources().GetCSIMounts() + require.Nil(t, mounts) + } else { + require.NoError(t, hook.Prerun()) + mounts := ar.GetAllocHookResources().GetCSIMounts() + require.NotNil(t, mounts) + require.NoError(t, hook.Postrun()) + } + }) + } +} + // HELPERS AND MOCKS type mockRPCer struct { @@ -333,8 +426,9 @@ func (mgr mockPluginManager) PluginManager() pluginmanager.PluginManager { retur func (mgr mockPluginManager) Shutdown() {} type mockAllocRunner struct { - res *cstructs.AllocHookResources - caps *drivers.Capabilities + res *cstructs.AllocHookResources + caps *drivers.Capabilities + capFunc func() (*drivers.Capabilities, error) } func (ar mockAllocRunner) GetAllocHookResources() *cstructs.AllocHookResources { @@ -346,5 +440,8 @@ func (ar mockAllocRunner) SetAllocHookResources(res *cstructs.AllocHookResources } func (ar mockAllocRunner) GetTaskDriverCapabilities(taskName string) (*drivers.Capabilities, error) { + if ar.capFunc != nil { + return ar.capFunc() + } return ar.caps, nil }