diff --git a/.changelog/17840.txt b/.changelog/17840.txt new file mode 100644 index 000000000..fde0cdbd8 --- /dev/null +++ b/.changelog/17840.txt @@ -0,0 +1,3 @@ +```release-note:bug +csi: Fixed a bug where CSI volumes would fail to restore during client restarts +``` diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go index d81cedd68..84c1466dd 100644 --- a/client/allocrunner/alloc_runner.go +++ b/client/allocrunner/alloc_runner.go @@ -1460,3 +1460,20 @@ func (ar *allocRunner) GetUpdatePriority(a *structs.Allocation) cstructs.AllocUp return cstructs.AllocUpdatePriorityNone } + +func (ar *allocRunner) SetCSIVolumes(vols map[string]*state.CSIVolumeStub) error { + return ar.stateDB.PutAllocVolumes(ar.id, &state.AllocVolumes{ + CSIVolumes: vols, + }) +} + +func (ar *allocRunner) GetCSIVolumes() (map[string]*state.CSIVolumeStub, error) { + allocVols, err := ar.stateDB.GetAllocVolumes(ar.id) + if err != nil { + return nil, err + } + if allocVols == nil { + return nil, nil + } + return allocVols.CSIVolumes, nil +} diff --git a/client/allocrunner/csi_hook.go b/client/allocrunner/csi_hook.go index 30e63ccc2..0c41f9815 100644 --- a/client/allocrunner/csi_hook.go +++ b/client/allocrunner/csi_hook.go @@ -12,6 +12,7 @@ import ( hclog "github.com/hashicorp/go-hclog" multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/nomad/client/allocrunner/state" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/dynamicplugins" "github.com/hashicorp/nomad/client/pluginmanager/csimanager" @@ -31,43 +32,47 @@ type csiHook struct { csimanager csimanager.Manager // interfaces implemented by the allocRunner - rpcClient config.RPCer - taskCapabilityGetter taskCapabilityGetter - hookResources *cstructs.AllocHookResources + rpcClient config.RPCer + allocRunnerShim allocRunnerShim + hookResources *cstructs.AllocHookResources nodeSecret string - volumeRequests map[string]*volumeAndRequest minBackoffInterval time.Duration maxBackoffInterval time.Duration maxBackoffDuration time.Duration + volumeResultsLock sync.Mutex + volumeResults map[string]*volumePublishResult // alias -> volumePublishResult + shutdownCtx context.Context shutdownCancelFn context.CancelFunc } // implemented by allocrunner -type taskCapabilityGetter interface { +type allocRunnerShim interface { GetTaskDriverCapabilities(string) (*drivers.Capabilities, error) + SetCSIVolumes(vols map[string]*state.CSIVolumeStub) error + GetCSIVolumes() (map[string]*state.CSIVolumeStub, error) } -func newCSIHook(alloc *structs.Allocation, logger hclog.Logger, csi csimanager.Manager, rpcClient config.RPCer, taskCapabilityGetter taskCapabilityGetter, hookResources *cstructs.AllocHookResources, nodeSecret string) *csiHook { +func newCSIHook(alloc *structs.Allocation, logger hclog.Logger, csi csimanager.Manager, rpcClient config.RPCer, arShim allocRunnerShim, hookResources *cstructs.AllocHookResources, nodeSecret string) *csiHook { shutdownCtx, shutdownCancelFn := context.WithCancel(context.Background()) return &csiHook{ - alloc: alloc, - logger: logger.Named("csi_hook"), - csimanager: csi, - rpcClient: rpcClient, - taskCapabilityGetter: taskCapabilityGetter, - hookResources: hookResources, - nodeSecret: nodeSecret, - volumeRequests: map[string]*volumeAndRequest{}, - minBackoffInterval: time.Second, - maxBackoffInterval: time.Minute, - maxBackoffDuration: time.Hour * 24, - shutdownCtx: shutdownCtx, - shutdownCancelFn: shutdownCancelFn, + alloc: alloc, + logger: logger.Named("csi_hook"), + csimanager: 
csi, + rpcClient: rpcClient, + allocRunnerShim: arShim, + hookResources: hookResources, + nodeSecret: nodeSecret, + volumeResults: map[string]*volumePublishResult{}, + minBackoffInterval: time.Second, + maxBackoffInterval: time.Minute, + maxBackoffDuration: time.Hour * 24, + shutdownCtx: shutdownCtx, + shutdownCancelFn: shutdownCancelFn, } } @@ -80,47 +85,61 @@ func (c *csiHook) Prerun() error { return nil } - volumes, err := c.claimVolumesFromAlloc() - if err != nil { - return fmt.Errorf("claim volumes: %v", err) + tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup) + if err := c.validateTasksSupportCSI(tg); err != nil { + return err } - c.volumeRequests = volumes - mounts := make(map[string]*csimanager.MountInfo, len(volumes)) - for alias, pair := range volumes { + // Because operations on CSI volumes are expensive and can error, we do each + // step for all volumes before proceeding to the next step so we have to + // unwind less work. In practice, most allocations with volumes will only + // have one or a few at most. We lock the results so that if an update/stop + // comes in while we're running we can assert we'll safely tear down + // everything that's been done so far. - // make sure the plugin is ready or becomes so quickly. - plugin := pair.volume.PluginID - pType := dynamicplugins.PluginTypeCSINode - if err := c.csimanager.WaitForPlugin(c.shutdownCtx, pType, plugin); err != nil { - return err + c.volumeResultsLock.Lock() + defer c.volumeResultsLock.Unlock() + + // Initially, populate the result map with all of the requests + for alias, volumeRequest := range tg.Volumes { + if volumeRequest.Type == structs.VolumeTypeCSI { + c.volumeResults[alias] = &volumePublishResult{ + request: volumeRequest, + stub: &state.CSIVolumeStub{ + VolumeID: volumeRequest.VolumeID(c.alloc.Name)}, + } } - c.logger.Debug("found CSI plugin", "type", pType, "name", plugin) + } - mounter, err := c.csimanager.MounterForPlugin(c.shutdownCtx, plugin) - if err != nil { - return err - } + err := c.restoreMounts(c.volumeResults) + if err != nil { + return fmt.Errorf("restoring mounts: %w", err) + } - usageOpts := &csimanager.UsageOptions{ - ReadOnly: pair.request.ReadOnly, - AttachmentMode: pair.request.AttachmentMode, - AccessMode: pair.request.AccessMode, - MountOptions: pair.request.MountOptions, - } + err = c.claimVolumes(c.volumeResults) + if err != nil { + return fmt.Errorf("claiming volumes: %w", err) + } - mountInfo, err := mounter.MountVolume( - c.shutdownCtx, pair.volume, c.alloc, usageOpts, pair.publishContext) - if err != nil { - return err - } - - mounts[alias] = mountInfo + err = c.mountVolumes(c.volumeResults) + if err != nil { + return fmt.Errorf("mounting volumes: %w", err) } // make the mounts available to the taskrunner's volume_hook + mounts := helper.ConvertMap(c.volumeResults, + func(result *volumePublishResult) *csimanager.MountInfo { + return result.stub.MountInfo + }) c.hookResources.SetCSIMounts(mounts) + // persist the published mount info so we can restore on client restarts + stubs := helper.ConvertMap(c.volumeResults, + func(result *volumePublishResult) *state.CSIVolumeStub { + return result.stub + }) + c.allocRunnerShim.SetCSIVolumes(stubs) + return nil } @@ -133,39 +152,42 @@ func (c *csiHook) Postrun() error { return nil } - var wg sync.WaitGroup - errs := make(chan error, len(c.volumeRequests)) + c.volumeResultsLock.Lock() + defer c.volumeResultsLock.Unlock() - for _, pair := range c.volumeRequests { + var wg sync.WaitGroup + errs := make(chan error, len(c.volumeResults)) + + 
for _, result := range c.volumeResults { wg.Add(1) // CSI RPCs can potentially take a long time. Split the work // into goroutines so that operators could potentially reuse // one of a set of volumes - go func(pair *volumeAndRequest) { + go func(result *volumePublishResult) { defer wg.Done() - err := c.unmountImpl(pair) + err := c.unmountImpl(result) if err != nil { // we can recover an unmount failure if the operator // brings the plugin back up, so retry every few minutes // but eventually give up. Don't block shutdown so that // we don't block shutting down the client in -dev mode - go func(pair *volumeAndRequest) { - err := c.unmountWithRetry(pair) + go func(result *volumePublishResult) { + err := c.unmountWithRetry(result) if err != nil { c.logger.Error("volume could not be unmounted") } - err = c.unpublish(pair) + err = c.unpublish(result) if err != nil { c.logger.Error("volume could not be unpublished") } - }(pair) + }(result) } // we can't recover from this RPC error client-side; the // volume claim GC job will have to clean up for us once // the allocation is marked terminal - errs <- c.unpublish(pair) - }(pair) + errs <- c.unpublish(result) + }(result) } wg.Wait() @@ -179,67 +201,109 @@ func (c *csiHook) Postrun() error { return mErr.ErrorOrNil() } -type volumeAndRequest struct { - volume *structs.CSIVolume - request *structs.VolumeRequest - - // When volumeAndRequest was returned from a volume claim, this field will be - // populated for plugins that require it. - publishContext map[string]string +type volumePublishResult struct { + request *structs.VolumeRequest // the request from the jobspec + volume *structs.CSIVolume // the volume we get back from the server + publishContext map[string]string // populated after claim if provided by plugin + stub *state.CSIVolumeStub // populated from volume, plugin, or stub } -// claimVolumesFromAlloc is used by the pre-run hook to fetch all of the volume -// metadata and claim it for use by this alloc/node at the same time. -func (c *csiHook) claimVolumesFromAlloc() (map[string]*volumeAndRequest, error) { - result := make(map[string]*volumeAndRequest) - tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup) - supportsVolumes := false +// validateTasksSupportCSI verifies that at least one task in the group uses a +// task driver that supports CSI. This prevents us from publishing CSI volumes +// only to find out once we get to the taskrunner/volume_hook that no task can +// mount them. +func (c *csiHook) validateTasksSupportCSI(tg *structs.TaskGroup) error { for _, task := range tg.Tasks { - caps, err := c.taskCapabilityGetter.GetTaskDriverCapabilities(task.Name) + caps, err := c.allocRunnerShim.GetTaskDriverCapabilities(task.Name) if err != nil { - return nil, fmt.Errorf("could not validate task driver capabilities: %v", err) + return fmt.Errorf("could not validate task driver capabilities: %v", err) } if caps.MountConfigs == drivers.MountConfigSupportNone { continue } - supportsVolumes = true - break + return nil } - if !supportsVolumes { - return nil, fmt.Errorf("no task supports CSI") - } + return fmt.Errorf("no task supports CSI") +} - // Initially, populate the result map with all of the requests - for alias, volumeRequest := range tg.Volumes { - if volumeRequest.Type == structs.VolumeTypeCSI { - result[alias] = &volumeAndRequest{request: volumeRequest} +// restoreMounts tries to restore the mount info from the local client state and +// then verifies it with the plugin. 
If the volume is already mounted, we don't +// want to re-run the claim and mount workflow again. This lets us tolerate +// restarting clients even on disconnected nodes. +func (c *csiHook) restoreMounts(results map[string]*volumePublishResult) error { + stubs, err := c.allocRunnerShim.GetCSIVolumes() + if err != nil { + return err + } + if stubs == nil { + return nil // no previous volumes + } + for _, result := range results { + stub := stubs[result.request.Name] + if stub == nil { + continue + } + + result.stub = stub + + if result.stub.MountInfo != nil && result.stub.PluginID != "" { + + // make sure the plugin is ready or becomes so quickly. + plugin := result.stub.PluginID + pType := dynamicplugins.PluginTypeCSINode + if err := c.csimanager.WaitForPlugin(c.shutdownCtx, pType, plugin); err != nil { + return err + } + c.logger.Debug("found CSI plugin", "type", pType, "name", plugin) + + mounter, err := c.csimanager.MounterForPlugin(c.shutdownCtx, plugin) + if err != nil { + return err + } + + isMounted, err := mounter.HasMount(c.shutdownCtx, result.stub.MountInfo) + if err != nil { + return err + } + if !isMounted { + // the mount is gone, so clear this from our result state so it + // we can try to remount it with the plugin ID we have + result.stub.MountInfo = nil + } } } - // Iterate over the result map and upsert the volume field as each volume gets - // claimed by the server. - for alias, pair := range result { + return nil +} + +// claimVolumes sends a claim to the server for each volume to mark it in use +// and kick off the controller publish workflow (optionally) +func (c *csiHook) claimVolumes(results map[string]*volumePublishResult) error { + + for _, result := range results { + if result.stub.MountInfo != nil { + continue // already mounted + } + + request := result.request + claimType := structs.CSIVolumeClaimWrite - if pair.request.ReadOnly { + if request.ReadOnly { claimType = structs.CSIVolumeClaimRead } - source := pair.request.Source - if pair.request.PerAlloc { - source = source + structs.AllocSuffix(c.alloc.Name) - } - req := &structs.CSIVolumeClaimRequest{ - VolumeID: source, + VolumeID: result.stub.VolumeID, AllocationID: c.alloc.ID, NodeID: c.alloc.NodeID, + ExternalNodeID: result.stub.ExternalNodeID, Claim: claimType, - AccessMode: pair.request.AccessMode, - AttachmentMode: pair.request.AttachmentMode, + AccessMode: request.AccessMode, + AttachmentMode: request.AttachmentMode, WriteRequest: structs.WriteRequest{ Region: c.alloc.Job.Region, Namespace: c.alloc.Job.Namespace, @@ -249,18 +313,64 @@ func (c *csiHook) claimVolumesFromAlloc() (map[string]*volumeAndRequest, error) resp, err := c.claimWithRetry(req) if err != nil { - return nil, fmt.Errorf("could not claim volume %s: %w", req.VolumeID, err) + return fmt.Errorf("could not claim volume %s: %w", req.VolumeID, err) } if resp.Volume == nil { - return nil, fmt.Errorf("Unexpected nil volume returned for ID: %v", pair.request.Source) + return fmt.Errorf("Unexpected nil volume returned for ID: %v", request.Source) } - result[alias].request = pair.request - result[alias].volume = resp.Volume - result[alias].publishContext = resp.PublishContext + result.volume = resp.Volume + + // populate data we'll write later to disk + result.stub.VolumeID = resp.Volume.ID + result.stub.VolumeExternalID = resp.Volume.RemoteID() + result.stub.PluginID = resp.Volume.PluginID + result.publishContext = resp.PublishContext } - return result, nil + return nil +} + +func (c *csiHook) mountVolumes(results 
map[string]*volumePublishResult) error { + + for _, result := range results { + if result.stub.MountInfo != nil { + continue // already mounted + } + if result.volume == nil { + return fmt.Errorf("volume not available from claim for mounting volume request %q", + result.request.Name) // should be unreachable + } + + // make sure the plugin is ready or becomes so quickly. + plugin := result.volume.PluginID + pType := dynamicplugins.PluginTypeCSINode + if err := c.csimanager.WaitForPlugin(c.shutdownCtx, pType, plugin); err != nil { + return err + } + c.logger.Debug("found CSI plugin", "type", pType, "name", plugin) + + mounter, err := c.csimanager.MounterForPlugin(c.shutdownCtx, plugin) + if err != nil { + return err + } + + usageOpts := &csimanager.UsageOptions{ + ReadOnly: result.request.ReadOnly, + AttachmentMode: result.request.AttachmentMode, + AccessMode: result.request.AccessMode, + MountOptions: result.request.MountOptions, + } + + mountInfo, err := mounter.MountVolume( + c.shutdownCtx, result.volume, c.alloc, usageOpts, result.publishContext) + if err != nil { + return err + } + result.stub.MountInfo = mountInfo + } + + return nil } // claimWithRetry tries to claim the volume on the server, retrying @@ -337,15 +447,15 @@ func (c *csiHook) shouldRun() bool { return false } -func (c *csiHook) unpublish(pair *volumeAndRequest) error { +func (c *csiHook) unpublish(result *volumePublishResult) error { mode := structs.CSIVolumeClaimRead - if !pair.request.ReadOnly { + if !result.request.ReadOnly { mode = structs.CSIVolumeClaimWrite } - source := pair.request.Source - if pair.request.PerAlloc { + source := result.request.Source + if result.request.PerAlloc { // NOTE: PerAlloc can't be set if we have canaries source = source + structs.AllocSuffix(c.alloc.Name) } @@ -372,7 +482,7 @@ func (c *csiHook) unpublish(pair *volumeAndRequest) error { // unmountWithRetry tries to unmount/unstage the volume, retrying with // exponential backoff capped to a maximum interval -func (c *csiHook) unmountWithRetry(pair *volumeAndRequest) error { +func (c *csiHook) unmountWithRetry(result *volumePublishResult) error { ctx, cancel := context.WithTimeout(c.shutdownCtx, c.maxBackoffDuration) defer cancel() @@ -387,7 +497,7 @@ func (c *csiHook) unmountWithRetry(pair *volumeAndRequest) error { case <-t.C: } - err = c.unmountImpl(pair) + err = c.unmountImpl(result) if err == nil { break } @@ -407,22 +517,22 @@ func (c *csiHook) unmountWithRetry(pair *volumeAndRequest) error { // unmountImpl implements the call to the CSI plugin manager to // unmount the volume. 
Each retry will write an "Unmount volume" // NodeEvent -func (c *csiHook) unmountImpl(pair *volumeAndRequest) error { +func (c *csiHook) unmountImpl(result *volumePublishResult) error { - mounter, err := c.csimanager.MounterForPlugin(c.shutdownCtx, pair.volume.PluginID) + mounter, err := c.csimanager.MounterForPlugin(c.shutdownCtx, result.stub.PluginID) if err != nil { return err } usageOpts := &csimanager.UsageOptions{ - ReadOnly: pair.request.ReadOnly, - AttachmentMode: pair.request.AttachmentMode, - AccessMode: pair.request.AccessMode, - MountOptions: pair.request.MountOptions, + ReadOnly: result.request.ReadOnly, + AttachmentMode: result.request.AttachmentMode, + AccessMode: result.request.AccessMode, + MountOptions: result.request.MountOptions, } return mounter.UnmountVolume(c.shutdownCtx, - pair.volume.ID, pair.volume.RemoteID(), c.alloc.ID, usageOpts) + result.stub.VolumeID, result.stub.VolumeExternalID, c.alloc.ID, usageOpts) } // Shutdown will get called when the client is gracefully diff --git a/client/allocrunner/csi_hook_test.go b/client/allocrunner/csi_hook_test.go index 32e0c0d06..717b33731 100644 --- a/client/allocrunner/csi_hook_test.go +++ b/client/allocrunner/csi_hook_test.go @@ -8,11 +8,13 @@ import ( "errors" "fmt" "path/filepath" + "sync" "testing" "time" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/allocrunner/state" "github.com/hashicorp/nomad/client/pluginmanager" "github.com/hashicorp/nomad/client/pluginmanager/csimanager" cstructs "github.com/hashicorp/nomad/client/structs" @@ -21,7 +23,9 @@ import ( "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/drivers" + "github.com/shoenig/test/must" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" ) var _ interfaces.RunnerPrerunHook = (*csiHook)(nil) @@ -31,19 +35,21 @@ func TestCSIHook(t *testing.T) { ci.Parallel(t) alloc := mock.Alloc() + testMountSrc := fmt.Sprintf( + "test-alloc-dir/%s/testvolume0/ro-file-system-single-node-reader-only", alloc.ID) logger := testlog.HCLogger(t) testcases := []struct { - name string - volumeRequests map[string]*structs.VolumeRequest - startsUnschedulable bool - startsWithClaims bool - expectedClaimErr error - expectedMounts map[string]*csimanager.MountInfo - expectedMountCalls int - expectedUnmountCalls int - expectedClaimCalls int - expectedUnpublishCalls int + name string + volumeRequests map[string]*structs.VolumeRequest + startsUnschedulable bool + startsWithClaims bool + startsWithStubs map[string]*state.CSIVolumeStub + startsWithValidMounts bool + failsFirstUnmount bool + expectedClaimErr error + expectedMounts map[string]*csimanager.MountInfo + expectedCalls map[string]int }{ { @@ -61,13 +67,10 @@ func TestCSIHook(t *testing.T) { }, }, expectedMounts: map[string]*csimanager.MountInfo{ - "vol0": &csimanager.MountInfo{Source: fmt.Sprintf( - "test-alloc-dir/%s/testvolume0/ro-file-system-single-node-reader-only", alloc.ID)}, + "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, - expectedMountCalls: 1, - expectedUnmountCalls: 1, - expectedClaimCalls: 1, - expectedUnpublishCalls: 1, + expectedCalls: map[string]int{ + "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, }, { @@ -85,13 +88,10 @@ func TestCSIHook(t *testing.T) { }, }, expectedMounts: map[string]*csimanager.MountInfo{ - "vol0": &csimanager.MountInfo{Source: fmt.Sprintf( - "test-alloc-dir/%s/testvolume0/ro-file-system-single-node-reader-only", 
alloc.ID)}, + "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, - expectedMountCalls: 1, - expectedUnmountCalls: 1, - expectedClaimCalls: 1, - expectedUnpublishCalls: 1, + expectedCalls: map[string]int{ + "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, }, { @@ -110,15 +110,11 @@ func TestCSIHook(t *testing.T) { }, startsUnschedulable: true, expectedMounts: map[string]*csimanager.MountInfo{ - "vol0": &csimanager.MountInfo{Source: fmt.Sprintf( - "test-alloc-dir/%s/testvolume0/ro-file-system-single-node-reader-only", alloc.ID)}, + "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, - expectedMountCalls: 0, - expectedUnmountCalls: 0, - expectedClaimCalls: 1, - expectedUnpublishCalls: 0, + expectedCalls: map[string]int{"claim": 1}, expectedClaimErr: errors.New( - "claim volumes: could not claim volume testvolume0: volume is currently unschedulable"), + "claiming volumes: could not claim volume testvolume0: volume is currently unschedulable"), }, { @@ -137,62 +133,105 @@ func TestCSIHook(t *testing.T) { }, startsWithClaims: true, expectedMounts: map[string]*csimanager.MountInfo{ - "vol0": &csimanager.MountInfo{Source: fmt.Sprintf( - "test-alloc-dir/%s/testvolume0/ro-file-system-single-node-reader-only", alloc.ID)}, + "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, - expectedMountCalls: 1, - expectedUnmountCalls: 1, - expectedClaimCalls: 2, - expectedUnpublishCalls: 1, + expectedCalls: map[string]int{ + "claim": 2, "mount": 1, "unmount": 1, "unpublish": 1}, + }, + { + name: "already mounted", + volumeRequests: map[string]*structs.VolumeRequest{ + "vol0": { + Name: "vol0", + Type: structs.VolumeTypeCSI, + Source: "testvolume0", + ReadOnly: true, + AccessMode: structs.CSIVolumeAccessModeSingleNodeReader, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + MountOptions: &structs.CSIMountOptions{}, + PerAlloc: false, + }, + }, + startsWithStubs: map[string]*state.CSIVolumeStub{"vol0": { + VolumeID: "vol0", + PluginID: "vol0-plugin", + ExternalNodeID: "i-example", + MountInfo: &csimanager.MountInfo{Source: testMountSrc}, + }}, + startsWithValidMounts: true, + expectedMounts: map[string]*csimanager.MountInfo{ + "vol0": &csimanager.MountInfo{Source: testMountSrc}, + }, + expectedCalls: map[string]int{"hasMount": 1, "unmount": 1, "unpublish": 1}, + }, + { + name: "existing but invalid mounts", + volumeRequests: map[string]*structs.VolumeRequest{ + "vol0": { + Name: "vol0", + Type: structs.VolumeTypeCSI, + Source: "testvolume0", + ReadOnly: true, + AccessMode: structs.CSIVolumeAccessModeSingleNodeReader, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + MountOptions: &structs.CSIMountOptions{}, + PerAlloc: false, + }, + }, + startsWithStubs: map[string]*state.CSIVolumeStub{"vol0": { + VolumeID: "testvolume0", + PluginID: "vol0-plugin", + ExternalNodeID: "i-example", + MountInfo: &csimanager.MountInfo{Source: testMountSrc}, + }}, + startsWithValidMounts: false, + expectedMounts: map[string]*csimanager.MountInfo{ + "vol0": &csimanager.MountInfo{Source: testMountSrc}, + }, + expectedCalls: map[string]int{ + "hasMount": 1, "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, }, - // TODO: this won't actually work on the client. 
- // https://github.com/hashicorp/nomad/issues/11798 - // - // { - // name: "one source volume mounted read-only twice", - // volumeRequests: map[string]*structs.VolumeRequest{ - // "vol0": { - // Name: "vol0", - // Type: structs.VolumeTypeCSI, - // Source: "testvolume0", - // ReadOnly: true, - // AccessMode: structs.CSIVolumeAccessModeMultiNodeReader, - // AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, - // MountOptions: &structs.CSIMountOptions{}, - // PerAlloc: false, - // }, - // "vol1": { - // Name: "vol1", - // Type: structs.VolumeTypeCSI, - // Source: "testvolume0", - // ReadOnly: false, - // AccessMode: structs.CSIVolumeAccessModeMultiNodeReader, - // AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, - // MountOptions: &structs.CSIMountOptions{}, - // PerAlloc: false, - // }, - // }, - // expectedMounts: map[string]*csimanager.MountInfo{ - // "vol0": &csimanager.MountInfo{Source: fmt.Sprintf( - // "test-alloc-dir/%s/testvolume0/ro-file-system-multi-node-reader-only", alloc.ID)}, - // "vol1": &csimanager.MountInfo{Source: fmt.Sprintf( - // "test-alloc-dir/%s/testvolume0/ro-file-system-multi-node-reader-only", alloc.ID)}, - // }, - // expectedMountCalls: 1, - // expectedUnmountCalls: 1, - // expectedClaimCalls: 1, - // expectedUnpublishCalls: 1, - // }, + { + name: "retry on failed unmount", + volumeRequests: map[string]*structs.VolumeRequest{ + "vol0": { + Name: "vol0", + Type: structs.VolumeTypeCSI, + Source: "testvolume0", + ReadOnly: true, + AccessMode: structs.CSIVolumeAccessModeSingleNodeReader, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + MountOptions: &structs.CSIMountOptions{}, + PerAlloc: false, + }, + }, + failsFirstUnmount: true, + expectedMounts: map[string]*csimanager.MountInfo{ + "vol0": &csimanager.MountInfo{Source: testMountSrc}, + }, + expectedCalls: map[string]int{ + "claim": 1, "mount": 1, "unmount": 2, "unpublish": 2}, + }, + + { + name: "should not run", + volumeRequests: map[string]*structs.VolumeRequest{}, + }, } for i := range testcases { tc := testcases[i] t.Run(tc.name, func(t *testing.T) { + alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests - callCounts := map[string]int{} - mgr := mockPluginManager{mounter: mockVolumeMounter{callCounts: callCounts}} + callCounts := &callCounter{counts: map[string]int{}} + mgr := mockPluginManager{mounter: mockVolumeMounter{ + hasMounts: tc.startsWithValidMounts, + callCounts: callCounts, + failsFirstUnmount: pointer.Of(tc.failsFirstUnmount), + }} rpcer := mockRPCer{ alloc: alloc, callCounts: callCounts, @@ -205,39 +244,47 @@ func TestCSIHook(t *testing.T) { FSIsolation: drivers.FSIsolationChroot, MountConfigs: drivers.MountConfigSupportAll, }, + stubs: tc.startsWithStubs, } + hook := newCSIHook(alloc, logger, mgr, rpcer, ar, ar.res, "secret") hook.minBackoffInterval = 1 * time.Millisecond hook.maxBackoffInterval = 10 * time.Millisecond hook.maxBackoffDuration = 500 * time.Millisecond - require.NotNil(t, hook) + must.NotNil(t, hook) if tc.expectedClaimErr != nil { - require.EqualError(t, hook.Prerun(), tc.expectedClaimErr.Error()) + must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error()) mounts := ar.res.GetCSIMounts() - require.Nil(t, mounts) + must.Nil(t, mounts) } else { - require.NoError(t, hook.Prerun()) + must.NoError(t, hook.Prerun()) mounts := ar.res.GetCSIMounts() - require.NotNil(t, mounts) - require.Equal(t, tc.expectedMounts, mounts) - require.NoError(t, hook.Postrun()) + must.MapEq(t, tc.expectedMounts, mounts, + must.Sprintf("got mounts: %v", mounts)) + 
must.NoError(t, hook.Postrun()) } - require.Equal(t, tc.expectedMountCalls, callCounts["mount"]) - require.Equal(t, tc.expectedUnmountCalls, callCounts["unmount"]) - require.Equal(t, tc.expectedClaimCalls, callCounts["claim"]) - require.Equal(t, tc.expectedUnpublishCalls, callCounts["unpublish"]) + if tc.failsFirstUnmount { + // retrying the unmount doesn't block Postrun, so give it time + // to run once more before checking the call counts to ensure + // this doesn't flake between 1 and 2 unmount/unpublish calls + time.Sleep(100 * time.Millisecond) + } + + counts := callCounts.get() + must.MapEq(t, tc.expectedCalls, counts, + must.Sprintf("got calls: %v", counts)) }) } } -// TestCSIHook_claimVolumesFromAlloc_Validation tests that the validation of task -// capabilities in claimVolumesFromAlloc ensures at least one task supports CSI. -func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { +// TestCSIHook_Prerun_Validation tests that the validation of task capabilities +// in Prerun ensures at least one task supports CSI. +func TestCSIHook_Prerun_Validation(t *testing.T) { ci.Parallel(t) alloc := mock.Alloc() @@ -256,10 +303,10 @@ func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { } type testCase struct { - name string - caps *drivers.Capabilities - capFunc func() (*drivers.Capabilities, error) - expectedClaimErr error + name string + caps *drivers.Capabilities + capFunc func() (*drivers.Capabilities, error) + expectedErr string } testcases := []testCase{ @@ -268,8 +315,8 @@ func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { caps: &drivers.Capabilities{ MountConfigs: drivers.MountConfigSupportNone, }, - capFunc: nil, - expectedClaimErr: errors.New("claim volumes: no task supports CSI"), + capFunc: nil, + expectedErr: "no task supports CSI", }, { @@ -278,7 +325,7 @@ func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { capFunc: func() (*drivers.Capabilities, error) { return nil, errors.New("error thrown by driver") }, - expectedClaimErr: errors.New("claim volumes: could not validate task driver capabilities: error thrown by driver"), + expectedErr: "could not validate task driver capabilities: error thrown by driver", }, { @@ -286,8 +333,7 @@ func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { caps: &drivers.Capabilities{ MountConfigs: drivers.MountConfigSupportAll, }, - capFunc: nil, - expectedClaimErr: nil, + capFunc: nil, }, } @@ -295,9 +341,11 @@ func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { t.Run(tc.name, func(t *testing.T) { alloc.Job.TaskGroups[0].Volumes = volumeRequests - callCounts := map[string]int{} - mgr := mockPluginManager{mounter: mockVolumeMounter{callCounts: callCounts}} - + callCounts := &callCounter{counts: map[string]int{}} + mgr := mockPluginManager{mounter: mockVolumeMounter{ + callCounts: callCounts, + failsFirstUnmount: pointer.Of(false), + }} rpcer := mockRPCer{ alloc: alloc, callCounts: callCounts, @@ -314,8 +362,8 @@ func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { hook := newCSIHook(alloc, logger, mgr, rpcer, ar, ar.res, "secret") require.NotNil(t, hook) - if tc.expectedClaimErr != nil { - require.EqualError(t, hook.Prerun(), tc.expectedClaimErr.Error()) + if tc.expectedErr != "" { + require.EqualError(t, hook.Prerun(), tc.expectedErr) mounts := ar.res.GetCSIMounts() require.Nil(t, mounts) } else { @@ -330,21 +378,44 @@ func TestCSIHook_claimVolumesFromAlloc_Validation(t *testing.T) { // HELPERS AND MOCKS +type callCounter struct { + lock sync.Mutex + counts 
map[string]int +} + +func (c *callCounter) inc(name string) { + c.lock.Lock() + defer c.lock.Unlock() + c.counts[name]++ +} + +func (c *callCounter) get() map[string]int { + c.lock.Lock() + defer c.lock.Unlock() + return maps.Clone(c.counts) +} + type mockRPCer struct { alloc *structs.Allocation - callCounts map[string]int + callCounts *callCounter hasExistingClaim *bool schedulable *bool } // RPC mocks the server RPCs, acting as though any request succeeds -func (r mockRPCer) RPC(method string, args interface{}, reply interface{}) error { +func (r mockRPCer) RPC(method string, args any, reply any) error { switch method { case "CSIVolume.Claim": - r.callCounts["claim"]++ + r.callCounts.inc("claim") req := args.(*structs.CSIVolumeClaimRequest) vol := r.testVolume(req.VolumeID) err := vol.Claim(req.ToClaim(), r.alloc) + + // after the first claim attempt is made, reset the volume's claims as + // though it's been released from another node + *r.hasExistingClaim = false + *r.schedulable = true + if err != nil { return err } @@ -353,22 +424,24 @@ func (r mockRPCer) RPC(method string, args interface{}, reply interface{}) error resp.PublishContext = map[string]string{} resp.Volume = vol resp.QueryMeta = structs.QueryMeta{} + case "CSIVolume.Unpublish": - r.callCounts["unpublish"]++ + r.callCounts.inc("unpublish") resp := reply.(*structs.CSIVolumeUnpublishResponse) resp.QueryMeta = structs.QueryMeta{} + default: return fmt.Errorf("unexpected method") } return nil } -// testVolume is a helper that optionally starts as unschedulable / -// claimed until after the first claim RPC is made, so that we can -// test retryable vs non-retryable failures +// testVolume is a helper that optionally starts as unschedulable / claimed, so +// that we can test retryable vs non-retryable failures func (r mockRPCer) testVolume(id string) *structs.CSIVolume { vol := structs.NewCSIVolume(id, 0) vol.Schedulable = *r.schedulable + vol.PluginID = "plugin-" + id vol.RequestedCapabilities = []*structs.CSIVolumeCapability{ { AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, @@ -393,29 +466,42 @@ func (r mockRPCer) testVolume(id string) *structs.CSIVolume { } } - if r.callCounts["claim"] >= 0 { - *r.hasExistingClaim = false - *r.schedulable = true - } - return vol } type mockVolumeMounter struct { - callCounts map[string]int + hasMounts bool + failsFirstUnmount *bool + callCounts *callCounter } func (vm mockVolumeMounter) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *csimanager.UsageOptions, publishContext map[string]string) (*csimanager.MountInfo, error) { - vm.callCounts["mount"]++ + vm.callCounts.inc("mount") return &csimanager.MountInfo{ Source: filepath.Join("test-alloc-dir", alloc.ID, vol.ID, usageOpts.ToFS()), }, nil } + func (vm mockVolumeMounter) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *csimanager.UsageOptions) error { - vm.callCounts["unmount"]++ + vm.callCounts.inc("unmount") + + if *vm.failsFirstUnmount { + *vm.failsFirstUnmount = false + return fmt.Errorf("could not unmount") + } + return nil } +func (vm mockVolumeMounter) HasMount(_ context.Context, mountInfo *csimanager.MountInfo) (bool, error) { + vm.callCounts.inc("hasMount") + return mountInfo != nil && vm.hasMounts, nil +} + +func (vm mockVolumeMounter) ExternalID() string { + return "i-example" +} + type mockPluginManager struct { mounter mockVolumeMounter } @@ -436,6 +522,9 @@ type mockAllocRunner struct { res *cstructs.AllocHookResources caps 
*drivers.Capabilities capFunc func() (*drivers.Capabilities, error) + + stubs map[string]*state.CSIVolumeStub + stubFunc func() (map[string]*state.CSIVolumeStub, error) } func (ar mockAllocRunner) GetTaskDriverCapabilities(taskName string) (*drivers.Capabilities, error) { @@ -444,3 +533,15 @@ func (ar mockAllocRunner) GetTaskDriverCapabilities(taskName string) (*drivers.C } return ar.caps, nil } + +func (ar mockAllocRunner) SetCSIVolumes(stubs map[string]*state.CSIVolumeStub) error { + ar.stubs = stubs + return nil +} + +func (ar mockAllocRunner) GetCSIVolumes() (map[string]*state.CSIVolumeStub, error) { + if ar.stubFunc != nil { + return ar.stubFunc() + } + return ar.stubs, nil +} diff --git a/client/allocrunner/state/state.go b/client/allocrunner/state/state.go index c3f245eed..42d2cb556 100644 --- a/client/allocrunner/state/state.go +++ b/client/allocrunner/state/state.go @@ -6,6 +6,7 @@ package state import ( "time" + "github.com/hashicorp/nomad/client/pluginmanager/csimanager" "github.com/hashicorp/nomad/nomad/structs" ) @@ -76,3 +77,17 @@ func (s *State) ClientTerminalStatus() bool { return false } } + +type AllocVolumes struct { + CSIVolumes map[string]*CSIVolumeStub // volume request name -> CSIVolumeStub +} + +// CSIVolumeStub is a stripped-down version of the CSIVolume with just the +// relevant data that we need to persist about the volume. +type CSIVolumeStub struct { + VolumeID string + VolumeExternalID string + PluginID string + ExternalNodeID string + MountInfo *csimanager.MountInfo +} diff --git a/client/pluginmanager/csimanager/instance.go b/client/pluginmanager/csimanager/instance.go index 41ca29151..cfb0b15bd 100644 --- a/client/pluginmanager/csimanager/instance.go +++ b/client/pluginmanager/csimanager/instance.go @@ -94,7 +94,13 @@ func (i *instanceManager) setupVolumeManager() { case <-i.shutdownCtx.Done(): return case <-i.fp.hadFirstSuccessfulFingerprintCh: - i.volumeManager = newVolumeManager(i.logger, i.eventer, i.client, i.mountPoint, i.containerMountPoint, i.fp.requiresStaging) + + var externalID string + if i.fp.basicInfo != nil && i.fp.basicInfo.NodeInfo != nil { + externalID = i.fp.basicInfo.NodeInfo.ID + } + + i.volumeManager = newVolumeManager(i.logger, i.eventer, i.client, i.mountPoint, i.containerMountPoint, i.fp.requiresStaging, externalID) i.logger.Debug("volume manager setup complete") close(i.volumeManagerSetupCh) return diff --git a/client/pluginmanager/csimanager/interface.go b/client/pluginmanager/csimanager/interface.go index 0e986858a..73f632e3a 100644 --- a/client/pluginmanager/csimanager/interface.go +++ b/client/pluginmanager/csimanager/interface.go @@ -56,6 +56,8 @@ func (u *UsageOptions) ToFS() string { type VolumeMounter interface { MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error + HasMount(ctx context.Context, mountInfo *MountInfo) (bool, error) + ExternalID() string } type Manager interface { diff --git a/client/pluginmanager/csimanager/volume.go b/client/pluginmanager/csimanager/volume.go index 9cb96461c..9ab76adea 100644 --- a/client/pluginmanager/csimanager/volume.go +++ b/client/pluginmanager/csimanager/volume.go @@ -53,9 +53,14 @@ type volumeManager struct { // requiresStaging shows whether the plugin requires that the volume manager // calls NodeStageVolume and NodeUnstageVolume RPCs during setup and teardown requiresStaging 
bool + + // externalNodeID is the identity of a given nomad client as observed by the + // storage provider (ex. a hostname, VM instance ID, etc.) + externalNodeID string } -func newVolumeManager(logger hclog.Logger, eventer TriggerNodeEvent, plugin csi.CSIPlugin, rootDir, containerRootDir string, requiresStaging bool) *volumeManager { +func newVolumeManager(logger hclog.Logger, eventer TriggerNodeEvent, plugin csi.CSIPlugin, rootDir, containerRootDir string, requiresStaging bool, externalID string) *volumeManager { + return &volumeManager{ logger: logger.Named("volume_manager"), eventer: eventer, @@ -64,6 +69,7 @@ func newVolumeManager(logger hclog.Logger, eventer TriggerNodeEvent, plugin csi. containerMountPoint: containerRootDir, requiresStaging: requiresStaging, usageTracker: newVolumeUsageTracker(), + externalNodeID: externalID, } } @@ -376,3 +382,13 @@ func (v *volumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allo return err } + +func (v *volumeManager) ExternalID() string { + return v.externalNodeID +} + +func (v *volumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) { + m := mount.New() + isNotMount, err := m.IsNotAMountPoint(mountInfo.Source) + return !isNotMount, err +} diff --git a/client/pluginmanager/csimanager/volume_test.go b/client/pluginmanager/csimanager/volume_test.go index 28240d304..fa43375a1 100644 --- a/client/pluginmanager/csimanager/volume_test.go +++ b/client/pluginmanager/csimanager/volume_test.go @@ -92,7 +92,8 @@ func TestVolumeManager_ensureStagingDir(t *testing.T) { csiFake := &csifake.Client{} eventer := func(e *structs.NodeEvent) {} - manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, tmpPath, tmpPath, true) + manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, + tmpPath, tmpPath, true, "i-example") expectedStagingPath := manager.stagingDirForVolume(tmpPath, tc.Volume.ID, tc.UsageOptions) if tc.CreateDirAheadOfTime { @@ -193,7 +194,8 @@ func TestVolumeManager_stageVolume(t *testing.T) { csiFake.NextNodeStageVolumeErr = tc.PluginErr eventer := func(e *structs.NodeEvent) {} - manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, tmpPath, tmpPath, true) + manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, + tmpPath, tmpPath, true, "i-example") ctx := context.Background() err := manager.stageVolume(ctx, tc.Volume, tc.UsageOptions, nil) @@ -251,7 +253,8 @@ func TestVolumeManager_unstageVolume(t *testing.T) { csiFake.NextNodeUnstageVolumeErr = tc.PluginErr eventer := func(e *structs.NodeEvent) {} - manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, tmpPath, tmpPath, true) + manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, + tmpPath, tmpPath, true, "i-example") ctx := context.Background() err := manager.unstageVolume(ctx, @@ -374,7 +377,8 @@ func TestVolumeManager_publishVolume(t *testing.T) { csiFake.NextNodePublishVolumeErr = tc.PluginErr eventer := func(e *structs.NodeEvent) {} - manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, tmpPath, tmpPath, true) + manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, + tmpPath, tmpPath, true, "i-example") ctx := context.Background() _, err := manager.publishVolume(ctx, tc.Volume, tc.Allocation, tc.UsageOptions, nil) @@ -441,7 +445,8 @@ func TestVolumeManager_unpublishVolume(t *testing.T) { csiFake.NextNodeUnpublishVolumeErr = tc.PluginErr eventer := func(e *structs.NodeEvent) {} - manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, tmpPath, 
tmpPath, true) + manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, + tmpPath, tmpPath, true, "i-example") ctx := context.Background() err := manager.unpublishVolume(ctx, @@ -473,7 +478,8 @@ func TestVolumeManager_MountVolumeEvents(t *testing.T) { events = append(events, e) } - manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, tmpPath, tmpPath, true) + manager := newVolumeManager(testlog.HCLogger(t), eventer, csiFake, + tmpPath, tmpPath, true, "i-example") ctx := context.Background() vol := &structs.CSIVolume{ ID: "vol", diff --git a/client/state/db_bolt.go b/client/state/db_bolt.go index f54aa4184..1b819c73c 100644 --- a/client/state/db_bolt.go +++ b/client/state/db_bolt.go @@ -35,6 +35,7 @@ allocations/ |--> deploy_status -> deployStatusEntry{*structs.AllocDeploymentStatus} |--> network_status -> networkStatusEntry{*structs.AllocNetworkStatus} |--> acknowledged_state -> acknowledgedStateEntry{*arstate.State} + |--> alloc_volumes -> allocVolumeStatesEntry{arstate.AllocVolumes} |--> task-/ |--> local_state -> *trstate.LocalState # Local-only state |--> task_state -> *structs.TaskState # Syncs to servers @@ -92,6 +93,8 @@ var ( // acknowledgedStateKey is the key *arstate.State is stored under acknowledgedStateKey = []byte("acknowledged_state") + allocVolumeKey = []byte("alloc_volume") + // checkResultsBucket is the bucket name in which check query results are stored checkResultsBucket = []byte("check_results") @@ -455,6 +458,59 @@ func (s *BoltStateDB) GetAcknowledgedState(allocID string) (*arstate.State, erro return entry.State, nil } +type allocVolumeStatesEntry struct { + State *arstate.AllocVolumes +} + +// PutAllocVolumes stores stubs of an allocation's dynamic volume mounts so they +// can be restored. +func (s *BoltStateDB) PutAllocVolumes(allocID string, state *arstate.AllocVolumes, opts ...WriteOption) error { + return s.updateWithOptions(opts, func(tx *boltdd.Tx) error { + allocBkt, err := getAllocationBucket(tx, allocID) + if err != nil { + return err + } + + entry := allocVolumeStatesEntry{ + State: state, + } + return allocBkt.Put(allocVolumeKey, &entry) + }) +} + +// GetAllocVolumes retrieves stubs of an allocation's dynamic volume mounts so +// they can be restored. +func (s *BoltStateDB) GetAllocVolumes(allocID string) (*arstate.AllocVolumes, error) { + var entry allocVolumeStatesEntry + + err := s.db.View(func(tx *boltdd.Tx) error { + allAllocsBkt := tx.Bucket(allocationsBucketName) + if allAllocsBkt == nil { + // No state, return + return nil + } + + allocBkt := allAllocsBkt.Bucket([]byte(allocID)) + if allocBkt == nil { + // No state for alloc, return + return nil + } + + return allocBkt.Get(allocVolumeKey, &entry) + }) + + // It's valid for this field to be nil/missing + if boltdd.IsErrNotFound(err) { + return nil, nil + } + + if err != nil { + return nil, err + } + + return entry.State, nil +} + // GetTaskRunnerState returns the LocalState and TaskState for a // TaskRunner. LocalState or TaskState will be nil if they do not exist. 
// diff --git a/client/state/db_error.go b/client/state/db_error.go index bf4e1ce52..58340a716 100644 --- a/client/state/db_error.go +++ b/client/state/db_error.go @@ -62,6 +62,14 @@ func (m *ErrDB) GetAcknowledgedState(allocID string) (*arstate.State, error) { return nil, fmt.Errorf("Error!") } +func (m *ErrDB) PutAllocVolumes(allocID string, state *arstate.AllocVolumes, opts ...WriteOption) error { + return fmt.Errorf("Error!") +} + +func (m *ErrDB) GetAllocVolumes(allocID string) (*arstate.AllocVolumes, error) { + return nil, fmt.Errorf("Error!") +} + func (m *ErrDB) GetTaskRunnerState(allocID string, taskName string) (*state.LocalState, *structs.TaskState, error) { return nil, nil, fmt.Errorf("Error!") } diff --git a/client/state/db_mem.go b/client/state/db_mem.go index bad6bb476..d05ff17ad 100644 --- a/client/state/db_mem.go +++ b/client/state/db_mem.go @@ -33,6 +33,9 @@ type MemDB struct { // alloc_id -> value acknowledgedState map[string]*arstate.State + // alloc_id -> value + allocVolumeStates map[string]*arstate.AllocVolumes + // alloc_id -> task_name -> value localTaskState map[string]map[string]*state.LocalState taskState map[string]map[string]*structs.TaskState @@ -139,6 +142,19 @@ func (m *MemDB) GetAcknowledgedState(allocID string) (*arstate.State, error) { return m.acknowledgedState[allocID], nil } +func (m *MemDB) PutAllocVolumes(allocID string, state *arstate.AllocVolumes, opts ...WriteOption) error { + m.mu.Lock() + m.allocVolumeStates[allocID] = state + defer m.mu.Unlock() + return nil +} + +func (m *MemDB) GetAllocVolumes(allocID string) (*arstate.AllocVolumes, error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.allocVolumeStates[allocID], nil +} + func (m *MemDB) GetTaskRunnerState(allocID string, taskName string) (*state.LocalState, *structs.TaskState, error) { m.mu.RLock() defer m.mu.RUnlock() diff --git a/client/state/db_noop.go b/client/state/db_noop.go index 4cee92603..62cfdb834 100644 --- a/client/state/db_noop.go +++ b/client/state/db_noop.go @@ -55,6 +55,12 @@ func (n NoopDB) PutAcknowledgedState(allocID string, state *arstate.State, opts func (n NoopDB) GetAcknowledgedState(allocID string) (*arstate.State, error) { return nil, nil } +func (n NoopDB) PutAllocVolumes(allocID string, state *arstate.AllocVolumes, opts ...WriteOption) error { + return nil +} + +func (n NoopDB) GetAllocVolumes(allocID string) (*arstate.AllocVolumes, error) { return nil, nil } + func (n NoopDB) GetTaskRunnerState(allocID string, taskName string) (*state.LocalState, *structs.TaskState, error) { return nil, nil, nil } diff --git a/client/state/interface.go b/client/state/interface.go index 82427a77c..51691ada3 100644 --- a/client/state/interface.go +++ b/client/state/interface.go @@ -54,6 +54,14 @@ type StateDB interface { // state. It may be nil even if there's no error GetAcknowledgedState(string) (*arstate.State, error) + // PutAllocVolumes stores stubs of an allocation's dynamic volume mounts so + // they can be restored. + PutAllocVolumes(allocID string, state *arstate.AllocVolumes, opts ...WriteOption) error + + // GetAllocVolumes retrieves stubs of an allocation's dynamic volume mounts + // so they can be restored. + GetAllocVolumes(allocID string) (*arstate.AllocVolumes, error) + // GetTaskRunnerState returns the LocalState and TaskState for a // TaskRunner. Either state may be nil if it is not found, but if an // error is encountered only the error will be non-nil. 
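A minimal sketch (illustrative only, not part of this patch) of the persist/restore round trip these new StateDB methods enable, mirroring what allocRunner.SetCSIVolumes and allocRunner.GetCSIVolumes do in this change. The package name, function name, and volume values below are placeholders; the types and StateDB methods are the ones added in this diff.

package csistatesketch

import (
	"fmt"

	arstate "github.com/hashicorp/nomad/client/allocrunner/state"
	"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
	clientstate "github.com/hashicorp/nomad/client/state"
)

// roundTrip persists the published CSI volume stubs for an allocation and then
// reads them back, as the csi_hook does across a client restart.
func roundTrip(db clientstate.StateDB, allocID string) error {
	// Persist the stubs after a successful claim/mount (what
	// allocRunner.SetCSIVolumes does via StateDB.PutAllocVolumes).
	stubs := map[string]*arstate.CSIVolumeStub{
		"vol0": {
			VolumeID:  "testvolume0",
			PluginID:  "plugin-testvolume0",
			MountInfo: &csimanager.MountInfo{Source: "/placeholder/alloc/dir/testvolume0"},
		},
	}
	if err := db.PutAllocVolumes(allocID, &arstate.AllocVolumes{CSIVolumes: stubs}); err != nil {
		return err
	}

	// On client restart, read them back (what allocRunner.GetCSIVolumes does);
	// a nil result simply means no volumes were previously published for this
	// allocation, so the hook falls through to the normal claim/mount path.
	vols, err := db.GetAllocVolumes(allocID)
	if err != nil {
		return err
	}
	if vols != nil {
		if stub := vols.CSIVolumes["vol0"]; stub != nil {
			fmt.Println("restored mount at", stub.MountInfo.Source)
		}
	}
	return nil
}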
diff --git a/helper/funcs.go b/helper/funcs.go index 2e6d8529f..2ec0ceb92 100644 --- a/helper/funcs.go +++ b/helper/funcs.go @@ -412,6 +412,17 @@ func ConvertSlice[A, B any](original []A, conversion func(a A) B) []B { return result } +// ConvertMap takes the input map and generates a new one using the supplied +// conversion function to convert the values. This is useful when converting one +// map to another using the same keys. +func ConvertMap[K comparable, A, B any](original map[K]A, conversion func(a A) B) map[K]B { + result := make(map[K]B, len(original)) + for k, a := range original { + result[k] = conversion(a) + } + return result +} + // IsMethodHTTP returns whether s is a known HTTP method, ignoring case. func IsMethodHTTP(s string) bool { switch strings.ToUpper(s) { diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index a7a48dd9b..0614b87b6 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -525,19 +525,20 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, alloc.NodeID) } - // get the the storage provider's ID for the client node (not - // Nomad's ID for the node) - targetCSIInfo, ok := targetNode.CSINodePlugins[plug.ID] - if !ok { - return fmt.Errorf("failed to find storage provider info for client %q, node plugin %q is not running or has not fingerprinted on this client", targetNode.ID, plug.ID) + // if the RPC is sent by a client node, it may not know the claim's + // external node ID. + if req.ExternalNodeID == "" { + externalNodeID, err := v.lookupExternalNodeID(vol, req.ToClaim()) + if err != nil { + return fmt.Errorf("missing external node ID: %v", err) + } + req.ExternalNodeID = externalNodeID } - externalNodeID := targetCSIInfo.NodeInfo.ID - req.ExternalNodeID = externalNodeID // update with the target info method := "ClientCSI.ControllerAttachVolume" cReq := &cstructs.ClientCSIControllerAttachVolumeRequest{ VolumeID: vol.RemoteID(), - ClientCSINodeID: externalNodeID, + ClientCSINodeID: req.ExternalNodeID, AttachmentMode: req.AttachmentMode, AccessMode: req.AccessMode, MountOptions: csiVolumeMountOptions(vol.MountOptions), @@ -846,7 +847,7 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str } } - // if the RPC is sent by a client node, it doesn't know the claim's + // if the RPC is sent by a client node, it may not know the claim's // external node ID. if claim.ExternalNodeID == "" { externalNodeID, err := v.lookupExternalNodeID(vol, claim) diff --git a/nomad/structs/volumes.go b/nomad/structs/volumes.go index da110b6a5..781e23f9f 100644 --- a/nomad/structs/volumes.go +++ b/nomad/structs/volumes.go @@ -222,6 +222,14 @@ func (v *VolumeRequest) Copy() *VolumeRequest { return nv } +func (v *VolumeRequest) VolumeID(tgName string) string { + source := v.Source + if v.PerAlloc { + source = source + AllocSuffix(tgName) + } + return source +} + func CopyMapVolumeRequest(s map[string]*VolumeRequest) map[string]*VolumeRequest { if s == nil { return nil