diff --git a/nomad/core_sched.go b/nomad/core_sched.go
index 4306783ef..f6aa3c112 100644
--- a/nomad/core_sched.go
+++ b/nomad/core_sched.go
@@ -773,7 +773,6 @@ func (c *CoreScheduler) csiVolumeClaimGC(eval *structs.Evaluation) error {
 		"index", oldThreshold,
 		"csi_volume_claim_gc_threshold", c.srv.config.CSIVolumeClaimGCThreshold)

-NEXT_VOLUME:
 	for i := iter.Next(); i != nil; i = iter.Next() {
 		vol := i.(*structs.CSIVolume)

@@ -782,38 +781,12 @@ NEXT_VOLUME:
 			continue
 		}

-		// TODO(tgross): consider moving the TerminalStatus check into
-		// the denormalize volume logic so that we can just check the
-		// volume for past claims
-
 		// we only call the claim release RPC if the volume has claims
 		// that no longer have valid allocations. otherwise we'd send
 		// out a lot of do-nothing RPCs.
-		for id := range vol.ReadClaims {
-			alloc, err := c.snap.AllocByID(ws, id)
-			if err != nil {
-				return err
-			}
-			if alloc == nil || alloc.TerminalStatus() {
-				err = gcClaims(vol.Namespace, vol.ID)
-				if err != nil {
-					return err
-				}
-				goto NEXT_VOLUME
-			}
-		}
-		for id := range vol.WriteClaims {
-			alloc, err := c.snap.AllocByID(ws, id)
-			if err != nil {
-				return err
-			}
-			if alloc == nil || alloc.TerminalStatus() {
-				err = gcClaims(vol.Namespace, vol.ID)
-				if err != nil {
-					return err
-				}
-				goto NEXT_VOLUME
-			}
+		vol, err := c.snap.CSIVolumeDenormalize(ws, vol)
+		if err != nil {
+			return err
 		}
 		if len(vol.PastClaims) > 0 {
 			err = gcClaims(vol.Namespace, vol.ID)
diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go
index fea730dfc..ac63b8fa5 100644
--- a/nomad/csi_endpoint.go
+++ b/nomad/csi_endpoint.go
@@ -615,39 +615,25 @@ func (v *CSIVolume) nodeUnpublishVolume(vol *structs.CSIVolume, claim *structs.C
 		return v.checkpointClaim(vol, claim)
 	}

-	// The RPC sent from the 'nomad node detach' command won't have an
+	// The RPC sent from the 'nomad node detach' command or GC won't have an
 	// allocation ID set so we try to unpublish every terminal or invalid
-	// alloc on the node
-	allocIDs := []string{}
+	// alloc on the node, all of which will be in PastClaims after denormalizing
 	state := v.srv.fsm.State()
 	vol, err := state.CSIVolumeDenormalize(memdb.NewWatchSet(), vol)
 	if err != nil {
 		return err
 	}
-	for allocID, alloc := range vol.ReadAllocs {
-		if alloc == nil {
-			rclaim, ok := vol.ReadClaims[allocID]
-			if ok && rclaim.NodeID == claim.NodeID {
-				allocIDs = append(allocIDs, allocID)
-			}
-		} else if alloc.NodeID == claim.NodeID && alloc.TerminalStatus() {
-			allocIDs = append(allocIDs, allocID)
-		}
-	}
-	for allocID, alloc := range vol.WriteAllocs {
-		if alloc == nil {
-			wclaim, ok := vol.WriteClaims[allocID]
-			if ok && wclaim.NodeID == claim.NodeID {
-				allocIDs = append(allocIDs, allocID)
-			}
-		} else if alloc.NodeID == claim.NodeID && alloc.TerminalStatus() {
-			allocIDs = append(allocIDs, allocID)
+
+	claimsToUnpublish := []*structs.CSIVolumeClaim{}
+	for _, pastClaim := range vol.PastClaims {
+		if claim.NodeID == pastClaim.NodeID {
+			claimsToUnpublish = append(claimsToUnpublish, pastClaim)
 		}
 	}
+
 	var merr multierror.Error
-	for _, allocID := range allocIDs {
-		claim.AllocationID = allocID
-		err := v.nodeUnpublishVolumeImpl(vol, claim)
+	for _, pastClaim := range claimsToUnpublish {
+		err := v.nodeUnpublishVolumeImpl(vol, pastClaim)
 		if err != nil {
 			merr.Errors = append(merr.Errors, err)
 		}
@@ -668,8 +654,8 @@ func (v *CSIVolume) nodeUnpublishVolumeImpl(vol *structs.CSIVolume, claim *struc
 		ExternalID:     vol.RemoteID(),
 		AllocID:        claim.AllocationID,
 		NodeID:         claim.NodeID,
-		AttachmentMode: vol.AttachmentMode,
-		AccessMode:     vol.AccessMode,
+		AttachmentMode: claim.AttachmentMode,
+		AccessMode:     claim.AccessMode,
 		ReadOnly:       claim.Mode == structs.CSIVolumeClaimRead,
 	}
 	err := v.srv.RPC("ClientCSI.NodeDetachVolume",
diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go
index 78f98bcb1..9b451f101 100644
--- a/nomad/state/state_store.go
+++ b/nomad/state/state_store.go
@@ -2531,7 +2531,7 @@ func (s *StateStore) csiVolumeDenormalizeTxn(txn Txn, ws memdb.WatchSet, vol *st
 		}
 		currentAllocs[id] = a

-		if a == nil && pastClaim == nil {
+		if (a == nil || a.TerminalStatus()) && pastClaim == nil {
 			// the alloc is garbage collected but nothing has written a PastClaim,
 			// so create one now
 			pastClaim = &structs.CSIVolumeClaim{
diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go
index 28fc94f35..fe69bca41 100644
--- a/nomad/volumewatcher/volume_watcher.go
+++ b/nomad/volumewatcher/volume_watcher.go
@@ -177,17 +177,10 @@ func (vw *volumeWatcher) isUnclaimed(vol *structs.CSIVolume) bool {
 	return len(vol.ReadClaims) == 0 && len(vol.WriteClaims) == 0 && len(vol.PastClaims) == 0
 }

+// volumeReapImpl unpublishes all the volume's PastClaims. PastClaims
+// will be populated from nil or terminal allocs when we call
+// CSIVolumeDenormalize(), so this assumes we've done so in the caller.
 func (vw *volumeWatcher) volumeReapImpl(vol *structs.CSIVolume) error {
-
-	// PastClaims written by a volume GC core job will have no allocation,
-	// so we need to find out which allocs are eligible for cleanup.
-	for _, claim := range vol.PastClaims {
-		if claim.AllocationID == "" {
-			vol = vw.collectPastClaims(vol)
-			break // only need to collect once
-		}
-	}
-
 	var result *multierror.Error
 	for _, claim := range vol.PastClaims {
 		err := vw.unpublish(vol, claim)
@@ -195,9 +188,7 @@ func (vw *volumeWatcher) volumeReapImpl(vol *structs.CSIVolume) error {
 			result = multierror.Append(result, err)
 		}
 	}
-
 	return result.ErrorOrNil()
-
 }

 func (vw *volumeWatcher) collectPastClaims(vol *structs.CSIVolume) *structs.CSIVolume {
diff --git a/nomad/volumewatcher/volume_watcher_test.go b/nomad/volumewatcher/volume_watcher_test.go
index 4e8a556a4..4bb4ddae4 100644
--- a/nomad/volumewatcher/volume_watcher_test.go
+++ b/nomad/volumewatcher/volume_watcher_test.go
@@ -37,6 +37,7 @@ func TestVolumeWatch_Reap(t *testing.T) {
 		logger: testlog.HCLogger(t),
 	}

+	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
 	err := w.volumeReapImpl(vol)
 	require.NoError(err)

@@ -48,6 +49,7 @@ func TestVolumeWatch_Reap(t *testing.T) {
 			State: structs.CSIVolumeClaimStateNodeDetached,
 		},
 	}
+	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
 	err = w.volumeReapImpl(vol)
 	require.NoError(err)
 	require.Len(vol.PastClaims, 1)
@@ -59,6 +61,7 @@ func TestVolumeWatch_Reap(t *testing.T) {
 			Mode: structs.CSIVolumeClaimGC,
 		},
 	}
+	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
 	err = w.volumeReapImpl(vol)
 	require.NoError(err)
 	require.Len(vol.PastClaims, 2) // alloc claim + GC claim
@@ -71,6 +74,7 @@ func TestVolumeWatch_Reap(t *testing.T) {
 			Mode: structs.CSIVolumeClaimRead,
 		},
 	}
+	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
 	err = w.volumeReapImpl(vol)
 	require.NoError(err)
 	require.Len(vol.PastClaims, 2) // alloc claim + GC claim
diff --git a/nomad/volumewatcher/volumes_watcher_test.go b/nomad/volumewatcher/volumes_watcher_test.go
index 2271c2f20..c66411631 100644
--- a/nomad/volumewatcher/volumes_watcher_test.go
+++ b/nomad/volumewatcher/volumes_watcher_test.go
@@ -67,7 +67,7 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) {
 	plugin := mock.CSIPlugin()
 	node := testNode(plugin, srv.State())
 	alloc := mock.Alloc()
-	alloc.ClientStatus = structs.AllocClientStatusComplete
+	alloc.ClientStatus = structs.AllocClientStatusRunning
 	vol := testVolume(plugin, alloc, node.ID)
 	index++
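
Note for reviewers: every call site above now leans on the one-line change in csiVolumeDenormalizeTxn, which writes a PastClaim for any claim whose alloc is nil *or* terminal. Below is a minimal, self-contained sketch of that invariant; the types are illustrative stand-ins for nomad/structs, and denormalize is a simplified hypothetical stand-in for CSIVolumeDenormalize, not the real implementation.

package main

import "fmt"

// Illustrative stand-ins for the nomad/structs types involved.
type CSIVolumeClaim struct {
	AllocationID string
	NodeID       string
}

type CSIVolume struct {
	PastClaims map[string]*CSIVolumeClaim
}

type Allocation struct{ terminal bool }

func (a *Allocation) TerminalStatus() bool { return a.terminal }

// denormalize mimics the updated csiVolumeDenormalizeTxn check: any claim
// whose alloc is missing, nil, or terminal gets a PastClaim, so callers
// (core GC, volumewatcher, nodeUnpublishVolume) only inspect PastClaims.
func denormalize(vol *CSIVolume, claims map[string]*CSIVolumeClaim, allocs map[string]*Allocation) *CSIVolume {
	for id, claim := range claims {
		a, ok := allocs[id]
		if (!ok || a == nil || a.TerminalStatus()) && vol.PastClaims[id] == nil {
			vol.PastClaims[id] = claim
		}
	}
	return vol
}

func main() {
	vol := &CSIVolume{PastClaims: map[string]*CSIVolumeClaim{}}
	claims := map[string]*CSIVolumeClaim{
		"alloc1": {AllocationID: "alloc1", NodeID: "node1"}, // alloc is terminal
		"alloc2": {AllocationID: "alloc2", NodeID: "node1"}, // alloc already GC'd
	}
	allocs := map[string]*Allocation{"alloc1": {terminal: true}}
	vol = denormalize(vol, claims, allocs)
	fmt.Println(len(vol.PastClaims)) // 2: both stale claims surface in PastClaims
}

This is also why volume_watcher_test.go now denormalizes before each volumeReapImpl call: volumeReapImpl no longer collects claims itself, and it assumes its caller has already populated PastClaims.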