diff --git a/client/allocrunner/csi_hook_test.go b/client/allocrunner/csi_hook_test.go index 5f4812133..f90660fe8 100644 --- a/client/allocrunner/csi_hook_test.go +++ b/client/allocrunner/csi_hook_test.go @@ -22,6 +22,7 @@ import ( "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" "github.com/hashicorp/nomad/plugins/drivers" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" @@ -498,6 +499,10 @@ func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.Mo return mountInfo != nil && vm.hasMounts, nil } +func (vm mockVolumeManager) ExpandVolume(_ context.Context, _, _, _ string, _ *csimanager.UsageOptions, _ *csi.CapacityRange) (int64, error) { + return 0, nil +} + func (vm mockVolumeManager) ExternalID() string { return "i-example" } diff --git a/client/csi_endpoint.go b/client/csi_endpoint.go index 3d0f448e8..f8b95a67e 100644 --- a/client/csi_endpoint.go +++ b/client/csi_endpoint.go @@ -537,6 +537,45 @@ func (c *CSI) NodeDetachVolume(req *structs.ClientCSINodeDetachVolumeRequest, re return nil } +// NodeExpandVolume instructs the node plugin to complete a volume expansion +// for a particular claim held by an allocation. +func (c *CSI) NodeExpandVolume(req *structs.ClientCSINodeExpandVolumeRequest, resp *structs.ClientCSINodeExpandVolumeResponse) error { + defer metrics.MeasureSince([]string{"client", "csi_node", "expand_volume"}, time.Now()) + + if err := req.Validate(); err != nil { + return err + } + usageOpts := &csimanager.UsageOptions{ + // Claim will not be nil here, per req.Validate() above. + ReadOnly: req.Claim.Mode == nstructs.CSIVolumeClaimRead, + AttachmentMode: req.Claim.AttachmentMode, + AccessMode: req.Claim.AccessMode, + } + + ctx, cancel := c.requestContext() // note: this has a 2-minute timeout + defer cancel() + + err := c.c.csimanager.WaitForPlugin(ctx, dynamicplugins.PluginTypeCSINode, req.PluginID) + if err != nil { + return err + } + + manager, err := c.c.csimanager.ManagerForPlugin(ctx, req.PluginID) + if err != nil { + return err + } + + newCapacity, err := manager.ExpandVolume(ctx, + req.VolumeID, req.ExternalID, req.Claim.AllocationID, usageOpts, req.Capacity) + + if err != nil && !errors.Is(err, nstructs.ErrCSIClientRPCIgnorable) { + return err + } + resp.CapacityBytes = newCapacity + + return nil +} + func (c *CSI) findControllerPlugin(name string) (csi.CSIPlugin, error) { return c.findPlugin(dynamicplugins.PluginTypeCSIController, name) } diff --git a/client/csi_endpoint_test.go b/client/csi_endpoint_test.go index 924a0a3f9..a511b128b 100644 --- a/client/csi_endpoint_test.go +++ b/client/csi_endpoint_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/client/dynamicplugins" + "github.com/hashicorp/nomad/client/pluginmanager/csimanager" "github.com/hashicorp/nomad/client/structs" nstructs "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/csi" @@ -1069,3 +1070,101 @@ func TestCSINode_DetachVolume(t *testing.T) { }) } } + +func TestCSINode_ExpandVolume(t *testing.T) { + ci.Parallel(t) + + client, cleanup := TestClient(t, nil) + t.Cleanup(func() { test.NoError(t, cleanup()) }) + + cases := []struct { + Name string + ModRequest func(r *structs.ClientCSINodeExpandVolumeRequest) + ModManager func(m *csimanager.MockCSIManager) + ExpectErr error + }{ + { + Name: "success", + }, + { + Name: "invalid request", + ModRequest: func(r *structs.ClientCSINodeExpandVolumeRequest) { + r.Claim = nil + }, + ExpectErr: errors.New("Claim is required"), + }, + { + Name: "error waiting for plugin", + ModManager: func(m *csimanager.MockCSIManager) { + m.NextWaitForPluginErr = errors.New("sad plugin") + }, + ExpectErr: errors.New("sad plugin"), + }, + { + Name: "error from manager expand", + ModManager: func(m *csimanager.MockCSIManager) { + m.VM.NextExpandVolumeErr = errors.New("no expand, so sad") + }, + ExpectErr: errors.New("no expand, so sad"), + }, + { + Name: "ignorable error from manager expand", + ModManager: func(m *csimanager.MockCSIManager) { + m.VM.NextExpandVolumeErr = fmt.Errorf("%w: not found", nstructs.ErrCSIClientRPCIgnorable) + }, + ExpectErr: nil, // explicitly expecting no error + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + mockManager := &csimanager.MockCSIManager{ + VM: &csimanager.MockVolumeManager{}, + } + if tc.ModManager != nil { + tc.ModManager(mockManager) + } + client.csimanager = mockManager + + req := &structs.ClientCSINodeExpandVolumeRequest{ + PluginID: "fake-plug", + VolumeID: "fake-vol", + ExternalID: "fake-external", + Capacity: &csi.CapacityRange{ + RequiredBytes: 5, + }, + Claim: &nstructs.CSIVolumeClaim{ + // minimal claim to pass validation + AllocationID: "fake-alloc", + }, + } + if tc.ModRequest != nil { + tc.ModRequest(req) + } + + var resp structs.ClientCSINodeExpandVolumeResponse + err := client.ClientRPC("CSI.NodeExpandVolume", req, &resp) + + if tc.ExpectErr != nil { + test.EqError(t, tc.ExpectErr, err.Error()) + return + } + test.NoError(t, err) + + expect := csimanager.MockExpandVolumeCall{ + VolID: req.VolumeID, + RemoteID: req.ExternalID, + AllocID: req.Claim.AllocationID, + Capacity: req.Capacity, + UsageOpts: &csimanager.UsageOptions{ + ReadOnly: true, + }, + } + test.Eq(t, req.Capacity.RequiredBytes, resp.CapacityBytes) + test.NotNil(t, mockManager.VM.LastExpandVolumeCall) + test.Eq(t, &expect, mockManager.VM.LastExpandVolumeCall) + + }) + } +} diff --git a/client/pluginmanager/csimanager/interface.go b/client/pluginmanager/csimanager/interface.go index 87415cc51..5dc1a81eb 100644 --- a/client/pluginmanager/csimanager/interface.go +++ b/client/pluginmanager/csimanager/interface.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/nomad/client/pluginmanager" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" ) type MountInfo struct { @@ -57,6 +58,7 @@ type VolumeManager interface { MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error HasMount(ctx context.Context, mountInfo *MountInfo) (bool, error) + ExpandVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) ExternalID() string } diff --git a/client/pluginmanager/csimanager/testing.go b/client/pluginmanager/csimanager/testing.go new file mode 100644 index 000000000..f27f74265 --- /dev/null +++ b/client/pluginmanager/csimanager/testing.go @@ -0,0 +1,73 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package csimanager + +import ( + "context" + + "github.com/hashicorp/nomad/client/pluginmanager" + nstructs "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" +) + +var _ Manager = &MockCSIManager{} + +type MockCSIManager struct { + VM *MockVolumeManager + + NextWaitForPluginErr error + NextManagerForPluginErr error +} + +func (m *MockCSIManager) PluginManager() pluginmanager.PluginManager { + panic("implement me") +} + +func (m *MockCSIManager) WaitForPlugin(_ context.Context, pluginType, pluginID string) error { + return m.NextWaitForPluginErr +} + +func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) { + return m.VM, m.NextManagerForPluginErr +} + +func (m *MockCSIManager) Shutdown() { + panic("implement me") +} + +var _ VolumeManager = &MockVolumeManager{} + +type MockVolumeManager struct { + NextExpandVolumeErr error + LastExpandVolumeCall *MockExpandVolumeCall +} + +func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) { + panic("implement me") +} + +func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error { + panic("implement me") +} + +func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) { + panic("implement me") +} + +func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) { + m.LastExpandVolumeCall = &MockExpandVolumeCall{ + volID, remoteID, allocID, usageOpts, capacity, + } + return capacity.RequiredBytes, m.NextExpandVolumeErr +} + +type MockExpandVolumeCall struct { + VolID, RemoteID, AllocID string + UsageOpts *UsageOptions + Capacity *csi.CapacityRange +} + +func (m *MockVolumeManager) ExternalID() string { + return "mock-volume-manager" +} diff --git a/client/pluginmanager/csimanager/volume.go b/client/pluginmanager/csimanager/volume.go index 80f839f43..6d0c4ac6e 100644 --- a/client/pluginmanager/csimanager/volume.go +++ b/client/pluginmanager/csimanager/volume.go @@ -383,6 +383,36 @@ func (v *volumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allo return err } +// ExpandVolume sends a NodeExpandVolume request to the node plugin +func (v *volumeManager) ExpandVolume(ctx context.Context, volID, remoteID, allocID string, usage *UsageOptions, capacity *csi.CapacityRange) (newCapacity int64, err error) { + capability, err := csi.VolumeCapabilityFromStructs(usage.AttachmentMode, usage.AccessMode, usage.MountOptions) + if err != nil { + // nil may be acceptable, so let the node plugin decide. + v.logger.Warn("ExpandVolume: unable to detect volume capability", + "volume_id", volID, "alloc_id", allocID, "error", err) + } + + req := &csi.NodeExpandVolumeRequest{ + ExternalVolumeID: remoteID, + CapacityRange: capacity, + Capability: capability, + TargetPath: v.targetForVolume(v.containerMountPoint, volID, allocID, usage), + StagingPath: v.stagingDirForVolume(v.containerMountPoint, volID, usage), + } + resp, err := v.plugin.NodeExpandVolume(ctx, req, + grpc_retry.WithPerRetryTimeout(DefaultMountActionTimeout), + grpc_retry.WithMax(3), + grpc_retry.WithBackoff(grpc_retry.BackoffExponential(100*time.Millisecond)), + ) + if err != nil { + return 0, err + } + if resp == nil { + return 0, errors.New("nil response from plugin.NodeExpandVolume") + } + return resp.CapacityBytes, nil +} + func (v *volumeManager) ExternalID() string { return v.externalNodeID } diff --git a/client/structs/csi.go b/client/structs/csi.go index ae3e26ba0..6caa18aec 100644 --- a/client/structs/csi.go +++ b/client/structs/csi.go @@ -4,6 +4,8 @@ package structs import ( + "errors" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/csi" ) @@ -452,3 +454,44 @@ type ClientCSINodeDetachVolumeRequest struct { } type ClientCSINodeDetachVolumeResponse struct{} + +// ClientCSINodeExpandVolumeRequest is the RPC made from the server to +// a Nomad client to tell a CSI node plugin on that client to perform +// NodeExpandVolume. +type ClientCSINodeExpandVolumeRequest struct { + PluginID string // ID of the plugin that manages the volume (required) + VolumeID string // ID of the volume to be expanded (required) + ExternalID string // External ID of the volume to be expanded (required) + + // Capacity range (required) to be sent to the node plugin + Capacity *csi.CapacityRange + + // Claim currently held for the allocation (required) + // used to determine capabilities and the mount point on the client + Claim *structs.CSIVolumeClaim +} + +func (req *ClientCSINodeExpandVolumeRequest) Validate() error { + var err error + // These should not occur during normal operations; they're here + // mainly to catch potential programmer error. + if req.PluginID == "" { + err = errors.Join(err, errors.New("PluginID is required")) + } + if req.VolumeID == "" { + err = errors.Join(err, errors.New("VolumeID is required")) + } + if req.ExternalID == "" { + err = errors.Join(err, errors.New("ExternalID is required")) + } + if req.Claim == nil { + err = errors.Join(err, errors.New("Claim is required")) + } else if req.Claim.AllocationID == "" { + err = errors.Join(err, errors.New("Claim.AllocationID is required")) + } + return err +} + +type ClientCSINodeExpandVolumeResponse struct { + CapacityBytes int64 +} diff --git a/client/structs/csi_test.go b/client/structs/csi_test.go new file mode 100644 index 000000000..117a230b5 --- /dev/null +++ b/client/structs/csi_test.go @@ -0,0 +1,45 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package structs + +import ( + "testing" + + "github.com/shoenig/test/must" + + "github.com/hashicorp/nomad/nomad/structs" +) + +func TestClientCSINodeExpandVolumeRequest_Validate(t *testing.T) { + req := &ClientCSINodeExpandVolumeRequest{ + PluginID: "plug-id", + VolumeID: "vol-id", + ExternalID: "ext-id", + Claim: &structs.CSIVolumeClaim{ + AllocationID: "alloc-id", + }, + } + err := req.Validate() + must.NoError(t, err) + + req.PluginID = "" + err = req.Validate() + must.ErrorContains(t, err, "PluginID is required") + + req.VolumeID = "" + err = req.Validate() + must.ErrorContains(t, err, "VolumeID is required") + + req.ExternalID = "" + err = req.Validate() + must.ErrorContains(t, err, "ExternalID is required") + + req.Claim.AllocationID = "" + err = req.Validate() + must.ErrorContains(t, err, "Claim.AllocationID is required") + + req.Claim = nil + err = req.Validate() + must.ErrorContains(t, err, "Claim is required") +} diff --git a/nomad/client_csi_endpoint.go b/nomad/client_csi_endpoint.go index 0d1164265..a8e321b8b 100644 --- a/nomad/client_csi_endpoint.go +++ b/nomad/client_csi_endpoint.go @@ -224,12 +224,34 @@ func (a *ClientCSI) isRetryable(err error) bool { func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error { defer metrics.MeasureSince([]string{"nomad", "client_csi_node", "detach_volume"}, time.Now()) + return a.sendCSINodeRPC( + args.NodeID, + "CSI.NodeDetachVolume", + "ClientCSI.NodeDetachVolume", + structs.RateMetricWrite, + args, + reply, + ) +} +func (a *ClientCSI) NodeExpandVolume(args *cstructs.ClientCSINodeExpandVolumeRequest, reply *cstructs.ClientCSINodeExpandVolumeResponse) error { + defer metrics.MeasureSince([]string{"nomad", "client_csi_node", "expand_volume"}, time.Now()) + return a.sendCSINodeRPC( + args.Claim.NodeID, + "CSI.NodeExpandVolume", + "ClientCSI.NodeExpandVolume", + structs.RateMetricWrite, + args, + reply, + ) +} + +func (a *ClientCSI) sendCSINodeRPC(nodeID, method, fwdMethod, op string, args any, reply any) error { // client requests aren't RequestWithIdentity, so we use a placeholder here // to populate the identity data for metrics identityReq := &structs.GenericRequest{} authErr := a.srv.Authenticate(a.ctx, identityReq) - a.srv.MeasureRPCRate("client_csi", structs.RateMetricWrite, identityReq) + a.srv.MeasureRPCRate("client_csi", op, identityReq) // only servers can send these client RPCs err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelServer) @@ -243,24 +265,22 @@ func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeReq return err } - _, err = getNodeForRpc(snap, args.NodeID) + _, err = getNodeForRpc(snap, nodeID) if err != nil { return err } // Get the connection to the client - state, ok := a.srv.getNodeConn(args.NodeID) + state, ok := a.srv.getNodeConn(nodeID) if !ok { - return findNodeConnAndForward(a.srv, args.NodeID, "ClientCSI.NodeDetachVolume", args, reply) + return findNodeConnAndForward(a.srv, nodeID, fwdMethod, args, reply) } // Make the RPC - err = NodeRpc(state.Session, "CSI.NodeDetachVolume", args, reply) - if err != nil { - return fmt.Errorf("node detach volume: %v", err) + if err := NodeRpc(state.Session, method, args, reply); err != nil { + return fmt.Errorf("%s error: %w", method, err) } return nil - } // clientIDsForController returns a shuffled list of client IDs where the diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index 6100078c1..c139121b2 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -45,6 +45,8 @@ type MockClientCSI struct { NextControllerExpandVolumeError error NextControllerExpandVolumeResponse *cstructs.ClientCSIControllerExpandVolumeResponse NextNodeDetachError error + NextNodeExpandError error + LastNodeExpandRequest *cstructs.ClientCSINodeExpandVolumeRequest } func newMockClientCSI() *MockClientCSI { @@ -108,6 +110,11 @@ func (c *MockClientCSI) NodeDetachVolume(req *cstructs.ClientCSINodeDetachVolume return c.NextNodeDetachError } +func (c *MockClientCSI) NodeExpandVolume(req *cstructs.ClientCSINodeExpandVolumeRequest, resp *cstructs.ClientCSINodeExpandVolumeResponse) error { + c.LastNodeExpandRequest = req + return c.NextNodeExpandError +} + func TestClientCSIController_AttachVolume_Local(t *testing.T) { ci.Parallel(t) require := require.New(t) diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index 51db1dcb9..2f50ac8b2 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -1272,12 +1272,57 @@ func (v *CSIVolume) expandVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug logger.Info("controller done expanding volume") if cResp.NodeExpansionRequired { - v.logger.Warn("TODO: also do node volume expansion if needed") // TODO + return v.nodeExpandVolume(vol, plugin, capacity) } return nil } +// nodeExpandVolume sends NodeExpandVolume requests to the appropriate client +// for each allocation that has a claim on the volume. The client will then +// send a gRPC call to the CSI node plugin colocated with the allocation. +func (v *CSIVolume) nodeExpandVolume(vol *structs.CSIVolume, plugin *structs.CSIPlugin, capacity *csi.CapacityRange) error { + var mErr multierror.Error + logger := v.logger.Named("nodeExpandVolume"). + With("volume", vol.ID, "plugin", plugin.ID) + + expand := func(claim *structs.CSIVolumeClaim) { + if claim == nil { + return + } + + logger.Debug("starting volume expansion on node", + "node_id", claim.NodeID, "alloc_id", claim.AllocationID) + + resp := &cstructs.ClientCSINodeExpandVolumeResponse{} + req := &cstructs.ClientCSINodeExpandVolumeRequest{ + PluginID: plugin.ID, + VolumeID: vol.ID, + ExternalID: vol.ExternalID, + Capacity: capacity, + Claim: claim, + } + if err := v.srv.RPC("ClientCSI.NodeExpandVolume", req, resp); err != nil { + mErr.Errors = append(mErr.Errors, err) + } + + if resp.CapacityBytes != vol.Capacity { + // not necessarily an error, but maybe notable + logger.Warn("unexpected capacity from NodeExpandVolume", + "expected", vol.Capacity, "resp", resp.CapacityBytes) + } + } + + for _, claim := range vol.ReadClaims { + expand(claim) + } + for _, claim := range vol.WriteClaims { + expand(claim) + } + + return mErr.ErrorOrNil() +} + func (v *CSIVolume) Delete(args *structs.CSIVolumeDeleteRequest, reply *structs.CSIVolumeDeleteResponse) error { authErr := v.srv.Authenticate(v.ctx, args) diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 6e20b63e1..9703725ee 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -4,6 +4,7 @@ package nomad import ( + "errors" "fmt" "strings" "sync" @@ -1889,7 +1890,8 @@ func TestCSIVolume_expandVolume(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { fake.NextControllerExpandVolumeResponse = &cstructs.ClientCSIControllerExpandVolumeResponse{ - CapacityBytes: tc.ControllerResp, + CapacityBytes: tc.ControllerResp, + // this also exercises some node expand code, incidentally NodeExpansionRequired: true, } @@ -1914,6 +1916,81 @@ func TestCSIVolume_expandVolume(t *testing.T) { }) } + // a nodeExpandVolume error should fail expandVolume too + t.Run("node error", func(t *testing.T) { + expect := "sad node expand" + fake.NextNodeExpandError = errors.New(expect) + fake.NextControllerExpandVolumeResponse = &cstructs.ClientCSIControllerExpandVolumeResponse{ + CapacityBytes: 2000, + NodeExpansionRequired: true, + } + err = endpoint.expandVolume(vol, plug, &csi.CapacityRange{ + RequiredBytes: 2000, + }) + test.ErrorContains(t, err, expect) + }) + +} + +func TestCSIVolume_nodeExpandVolume(t *testing.T) { + ci.Parallel(t) + + srv, cleanupSrv := TestServer(t, nil) + t.Cleanup(cleanupSrv) + testutil.WaitForLeader(t, srv.RPC) + t.Log("server started 👍") + + c, fake, _, fakeVolID := testClientWithCSI(t, srv) + fakeClaim := fakeCSIClaim(c.NodeID()) + + endpoint := NewCSIVolumeEndpoint(srv, nil) + plug, vol, err := endpoint.volAndPluginLookup(structs.DefaultNamespace, fakeVolID) + must.NoError(t, err) + + // there's not a lot of logic here -- validation has been done prior, + // in (controller) expandVolume and what preceeds it. + cases := []struct { + Name string + Error error + }{ + { + Name: "ok", + }, + { + Name: "not ok", + Error: errors.New("test node expand fail"), + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + fake.NextNodeExpandError = tc.Error + capacity := &csi.CapacityRange{ + RequiredBytes: 10, + LimitBytes: 10, + } + + err = endpoint.nodeExpandVolume(vol, plug, capacity) + + if tc.Error == nil { + test.NoError(t, err) + } else { + must.Error(t, err) + must.ErrorContains(t, err, + fmt.Sprintf("CSI.NodeExpandVolume error: %s", tc.Error)) + } + + req := fake.LastNodeExpandRequest + must.NotNil(t, req, must.Sprint("request should have happened")) + test.Eq(t, fakeVolID, req.VolumeID) + test.Eq(t, capacity, req.Capacity) + test.Eq(t, "fake-csi-plugin", req.PluginID) + test.Eq(t, "fake-csi-external-id", req.ExternalID) + test.Eq(t, fakeClaim, req.Claim) + + }) + } } func TestCSIPluginEndpoint_RegisterViaFingerprint(t *testing.T) { @@ -2266,8 +2343,8 @@ func testClientWithCSI(t *testing.T, srv *Server) (c *client.Client, m *MockClie t.Helper() m = newMockClientCSI() - plugID = "fake-plugin" - volID = "fake-volume" + plugID = "fake-csi-plugin" + volID = "fake-csi-volume" c, cleanup := client.TestClientWithRPCs(t, func(c *cconfig.Config) { @@ -2316,15 +2393,19 @@ func testClientWithCSI(t *testing.T, srv *Server) (c *client.Client, m *MockClie // Register a minimum-viable fake volume req := &structs.CSIVolumeRegisterRequest{ Volumes: []*structs.CSIVolume{{ - PluginID: plugID, - ID: volID, - Namespace: structs.DefaultNamespace, + PluginID: plugID, + ID: volID, + ExternalID: "fake-csi-external-id", + Namespace: structs.DefaultNamespace, RequestedCapabilities: []*structs.CSIVolumeCapability{ { - AccessMode: structs.CSIVolumeAccessModeMultiNodeMultiWriter, + AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, }, }, + WriteClaims: map[string]*structs.CSIVolumeClaim{ + "fake-csi-claim": fakeCSIClaim(c.NodeID()), + }, }}, WriteRequest: structs.WriteRequest{Region: srv.Region()}, } @@ -2333,3 +2414,15 @@ func testClientWithCSI(t *testing.T, srv *Server) (c *client.Client, m *MockClie return c, m, plugID, volID } + +func fakeCSIClaim(nodeID string) *structs.CSIVolumeClaim { + return &structs.CSIVolumeClaim{ + NodeID: nodeID, + AllocationID: "fake-csi-alloc", + ExternalNodeID: "fake-csi-external-node", + Mode: structs.CSIVolumeClaimWrite, + AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + State: structs.CSIVolumeClaimStateTaken, + } +} diff --git a/plugins/csi/client.go b/plugins/csi/client.go index 0f101f4c5..736289e9c 100644 --- a/plugins/csi/client.go +++ b/plugins/csi/client.go @@ -926,5 +926,37 @@ func (c *client) NodeUnpublishVolume(ctx context.Context, volumeID, targetPath s } func (c *client) NodeExpandVolume(ctx context.Context, req *NodeExpandVolumeRequest, opts ...grpc.CallOption) (*NodeExpandVolumeResponse, error) { - return nil, nil + if err := req.Validate(); err != nil { + return nil, err + } + if err := c.ensureConnected(ctx); err != nil { + return nil, err + } + + exReq := req.ToCSIRepresentation() + resp, err := c.nodeClient.NodeExpandVolume(ctx, exReq, opts...) + if err != nil { + code := status.Code(err) + switch code { + case codes.InvalidArgument: + return nil, fmt.Errorf( + "requested capabilities not compatible with volume %q: %v", + req.ExternalVolumeID, err) + case codes.NotFound: + return nil, fmt.Errorf("%w: volume %q could not be found: %v", + structs.ErrCSIClientRPCIgnorable, req.ExternalVolumeID, err) + case codes.FailedPrecondition: + return nil, fmt.Errorf("volume %q cannot be expanded while in use: %v", req.ExternalVolumeID, err) + case codes.OutOfRange: + return nil, fmt.Errorf( + "unsupported capacity_range for volume %q: %v", req.ExternalVolumeID, err) + case codes.Internal: + return nil, fmt.Errorf( + "node plugin returned an internal error, check the plugin allocation logs for more information: %v", err) + default: + return nil, fmt.Errorf("node plugin returned an error: %v", err) + } + } + + return &NodeExpandVolumeResponse{resp.GetCapacityBytes()}, nil } diff --git a/plugins/csi/client_test.go b/plugins/csi/client_test.go index 6947a953e..45bfb4783 100644 --- a/plugins/csi/client_test.go +++ b/plugins/csi/client_test.go @@ -1518,3 +1518,163 @@ func TestClient_RPC_NodeUnpublishVolume(t *testing.T) { }) } } + +func TestClient_RPC_NodeExpandVolume(t *testing.T) { + // minimum valid request + minRequest := &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-vol", + TargetPath: "/test-path", + } + + cases := []struct { + Name string + Request *NodeExpandVolumeRequest + ExpectCall *csipbv1.NodeExpandVolumeRequest + ResponseErr error + ExpectedErr error + }{ + { + Name: "success min", + Request: minRequest, + ExpectCall: &csipbv1.NodeExpandVolumeRequest{ + VolumeId: "test-vol", + VolumePath: "/test-path", + }, + }, + { + Name: "success full", + Request: &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-vol", + TargetPath: "/test-path", + StagingPath: "/test-staging-path", + CapacityRange: &CapacityRange{ + RequiredBytes: 5, + LimitBytes: 10, + }, + Capability: &VolumeCapability{ + AccessType: VolumeAccessTypeMount, + AccessMode: VolumeAccessModeMultiNodeSingleWriter, + MountVolume: &structs.CSIMountOptions{ + FSType: "test-fstype", + MountFlags: []string{"test-flags"}, + }, + }, + }, + ExpectCall: &csipbv1.NodeExpandVolumeRequest{ + VolumeId: "test-vol", + VolumePath: "/test-path", + StagingTargetPath: "/test-staging-path", + CapacityRange: &csipbv1.CapacityRange{ + RequiredBytes: 5, + LimitBytes: 10, + }, + VolumeCapability: &csipbv1.VolumeCapability{ + AccessType: &csipbv1.VolumeCapability_Mount{ + Mount: &csipbv1.VolumeCapability_MountVolume{ + FsType: "test-fstype", + MountFlags: []string{"test-flags"}, + VolumeMountGroup: "", + }}, + AccessMode: &csipbv1.VolumeCapability_AccessMode{ + Mode: csipbv1.VolumeCapability_AccessMode_MULTI_NODE_SINGLE_WRITER}, + }, + }, + }, + + { + Name: "validate missing volume id", + Request: &NodeExpandVolumeRequest{ + TargetPath: "/test-path", + }, + ExpectedErr: errors.New("ExternalVolumeID is required"), + }, + { + Name: "validate missing target path", + Request: &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-volume", + }, + ExpectedErr: errors.New("TargetPath is required"), + }, + { + Name: "validate min greater than max", + Request: &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-vol", + TargetPath: "/test-path", + CapacityRange: &CapacityRange{ + RequiredBytes: 4, + LimitBytes: 2, + }, + }, + ExpectedErr: errors.New("LimitBytes cannot be less than RequiredBytes"), + }, + + { + Name: "grpc error default case", + Request: minRequest, + ResponseErr: status.Errorf(codes.DataLoss, "misc unspecified error"), + ExpectedErr: errors.New("node plugin returned an error: rpc error: code = DataLoss desc = misc unspecified error"), + }, + { + Name: "grpc error invalid argument", + Request: minRequest, + ResponseErr: status.Errorf(codes.InvalidArgument, "sad args"), + ExpectedErr: errors.New("requested capabilities not compatible with volume \"test-vol\": rpc error: code = InvalidArgument desc = sad args"), + }, + { + Name: "grpc error NotFound", + Request: minRequest, + ResponseErr: status.Errorf(codes.NotFound, "does not exist"), + ExpectedErr: errors.New("CSI client error (ignorable): volume \"test-vol\" could not be found: rpc error: code = NotFound desc = does not exist"), + }, + { + Name: "grpc error FailedPrecondition", + Request: minRequest, + ResponseErr: status.Errorf(codes.FailedPrecondition, "unsupported"), + ExpectedErr: errors.New("volume \"test-vol\" cannot be expanded while in use: rpc error: code = FailedPrecondition desc = unsupported"), + }, + { + Name: "grpc error OutOfRange", + Request: minRequest, + ResponseErr: status.Errorf(codes.OutOfRange, "too small"), + ExpectedErr: errors.New("unsupported capacity_range for volume \"test-vol\": rpc error: code = OutOfRange desc = too small"), + }, + { + Name: "grpc error Internal", + Request: minRequest, + ResponseErr: status.Errorf(codes.Internal, "some grpc error"), + ExpectedErr: errors.New("node plugin returned an internal error, check the plugin allocation logs for more information: rpc error: code = Internal desc = some grpc error"), + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + _, _, nc, client := newTestClient(t) + + nc.NextErr = tc.ResponseErr + // the fake client should take ~no time, but set a timeout just in case + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + resp, err := client.NodeExpandVolume(ctx, tc.Request) + if tc.ExpectedErr != nil { + must.EqError(t, err, tc.ExpectedErr.Error()) + return + } + must.NoError(t, err) + must.NotNil(t, resp) + must.Eq(t, tc.ExpectCall, nc.LastExpandVolumeRequest) + + }) + } + + t.Run("connection error", func(t *testing.T) { + c := &client{} // induce c.ensureConnected() error + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + resp, err := c.NodeExpandVolume(ctx, &NodeExpandVolumeRequest{ + ExternalVolumeID: "valid-id", + TargetPath: "/some-path", + }) + must.Nil(t, resp) + must.EqError(t, err, "address is empty") + }) +} diff --git a/plugins/csi/plugin.go b/plugins/csi/plugin.go index df3c9c4fd..e4f5155df 100644 --- a/plugins/csi/plugin.go +++ b/plugins/csi/plugin.go @@ -1020,6 +1020,19 @@ type CapacityRange struct { LimitBytes int64 } +func (c *CapacityRange) Validate() error { + if c == nil { + return nil + } + if c.RequiredBytes == 0 && c.LimitBytes == 0 { + return errors.New("either RequiredBytes or LimitBytes must be set") + } + if c.LimitBytes > 0 && c.LimitBytes < c.RequiredBytes { + return errors.New("LimitBytes cannot be less than RequiredBytes") + } + return nil +} + func (c *CapacityRange) ToCSIRepresentation() *csipbv1.CapacityRange { if c == nil { return nil @@ -1032,11 +1045,24 @@ func (c *CapacityRange) ToCSIRepresentation() *csipbv1.CapacityRange { type NodeExpandVolumeRequest struct { ExternalVolumeID string - RequiredBytes int64 - LimitBytes int64 + CapacityRange *CapacityRange + Capability *VolumeCapability TargetPath string StagingPath string - Capability *VolumeCapability +} + +func (r *NodeExpandVolumeRequest) Validate() error { + var err error + if r.ExternalVolumeID == "" { + err = errors.Join(err, errors.New("ExternalVolumeID is required")) + } + if r.TargetPath == "" { + err = errors.Join(err, errors.New("TargetPath is required")) + } + if e := r.CapacityRange.Validate(); e != nil { + err = errors.Join(err, e) + } + return err } func (r *NodeExpandVolumeRequest) ToCSIRepresentation() *csipbv1.NodeExpandVolumeRequest { @@ -1044,13 +1070,10 @@ func (r *NodeExpandVolumeRequest) ToCSIRepresentation() *csipbv1.NodeExpandVolum return nil } return &csipbv1.NodeExpandVolumeRequest{ - VolumeId: r.ExternalVolumeID, - VolumePath: r.TargetPath, - CapacityRange: &csipbv1.CapacityRange{ - RequiredBytes: r.RequiredBytes, - LimitBytes: r.LimitBytes, - }, + VolumeId: r.ExternalVolumeID, + VolumePath: r.TargetPath, StagingTargetPath: r.StagingPath, + CapacityRange: r.CapacityRange.ToCSIRepresentation(), VolumeCapability: r.Capability.ToCSIRepresentation(), } } diff --git a/plugins/csi/testing/client.go b/plugins/csi/testing/client.go index 7595d79f2..64e9cf8b9 100644 --- a/plugins/csi/testing/client.go +++ b/plugins/csi/testing/client.go @@ -150,6 +150,7 @@ type NodeClient struct { NextPublishVolumeResponse *csipbv1.NodePublishVolumeResponse NextUnpublishVolumeResponse *csipbv1.NodeUnpublishVolumeResponse NextExpandVolumeResponse *csipbv1.NodeExpandVolumeResponse + LastExpandVolumeRequest *csipbv1.NodeExpandVolumeRequest } // NewNodeClient returns a new stub NodeClient @@ -193,5 +194,6 @@ func (c *NodeClient) NodeUnpublishVolume(ctx context.Context, in *csipbv1.NodeUn } func (c *NodeClient) NodeExpandVolume(ctx context.Context, in *csipbv1.NodeExpandVolumeRequest, opts ...grpc.CallOption) (*csipbv1.NodeExpandVolumeResponse, error) { + c.LastExpandVolumeRequest = in return c.NextExpandVolumeResponse, c.NextErr }