Backport of CSI: improve controller RPC reliability into release/1.6.x (#18015)

This pull request was automerged via backport-assistant
This commit is contained in:
hc-github-team-nomad-core 2023-07-20 13:52:27 -05:00 committed by GitHub
parent 180ea2df9c
commit e891026755
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 186 additions and 31 deletions

11
.changelog/17996.txt Normal file
View File

@ -0,0 +1,11 @@
```release-note:bug
csi: Fixed a bug in sending concurrent requests to CSI controller plugins by serializing them per plugin
```
```release-note:bug
csi: Fixed a bug where CSI controller requests could be sent to unhealthy plugins
```
```release-note:bug
csi: Fixed a bug where CSI controller requests could not be sent to controllers on nodes ineligible for scheduling
```

View File

@ -4,8 +4,9 @@
package nomad package nomad
import ( import (
"errors"
"fmt" "fmt"
"math/rand" "sort"
"strings" "strings"
"time" "time"
@ -262,9 +263,9 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
// note: plugin IDs are not scoped to region/DC but volumes are. // note: plugin IDs are not scoped to region but volumes are. so any Nomad
// so any node we get for a controller is already in the same // client we get for a controller is already in the same region for the
// region/DC for the volume. // volume.
plugin, err := snap.CSIPluginByID(ws, pluginID) plugin, err := snap.CSIPluginByID(ws, pluginID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err) return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
@ -273,31 +274,55 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
return nil, fmt.Errorf("plugin missing: %s", pluginID) return nil, fmt.Errorf("plugin missing: %s", pluginID)
} }
// iterating maps is "random" but unspecified and isn't particularly
// random with small maps, so not well-suited for load balancing.
// so we shuffle the keys and iterate over them.
clientIDs := []string{} clientIDs := []string{}
for clientID, controller := range plugin.Controllers { if len(plugin.Controllers) == 0 {
if !controller.IsController() { return nil, fmt.Errorf("failed to find instances of controller plugin %q", pluginID)
// we don't have separate types for CSIInfo depending on
// whether it's a controller or node. this error shouldn't
// make it to production but is to aid developers during
// development
continue
}
node, err := getNodeForRpc(snap, clientID)
if err == nil && node != nil && node.Ready() {
clientIDs = append(clientIDs, clientID)
}
}
if len(clientIDs) == 0 {
return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
} }
rand.Shuffle(len(clientIDs), func(i, j int) { var merr error
clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i] for clientID, controller := range plugin.Controllers {
}) if !controller.IsController() {
// we don't have separate types for CSIInfo depending on whether
// it's a controller or node. this error should never make it to
// production
merr = errors.Join(merr, fmt.Errorf(
"plugin instance %q is not a controller but was registered as one - this is always a bug", controller.AllocID))
continue
}
if !controller.Healthy {
merr = errors.Join(merr, fmt.Errorf(
"plugin instance %q is not healthy", controller.AllocID))
continue
}
node, err := getNodeForRpc(snap, clientID)
if err != nil || node == nil {
merr = errors.Join(merr, fmt.Errorf(
"cannot find node %q for plugin instance %q", clientID, controller.AllocID))
continue
}
if node.Status != structs.NodeStatusReady {
merr = errors.Join(merr, fmt.Errorf(
"node %q for plugin instance %q is not ready", clientID, controller.AllocID))
continue
}
clientIDs = append(clientIDs, clientID)
}
if len(clientIDs) == 0 {
return nil, fmt.Errorf("failed to find clients running controller plugin %q: %v",
pluginID, merr)
}
// Many plugins don't handle concurrent requests as described in the spec,
// and have undocumented expectations of using k8s-specific sidecars to
// leader elect. Sort the client IDs so that we prefer sending requests to
// the same controller to hack around this.
clientIDs = sort.StringSlice(clientIDs)
return clientIDs, nil return clientIDs, nil
} }

View File

@ -4,6 +4,7 @@
package nomad package nomad
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -549,7 +550,9 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest,
cReq.PluginID = plug.ID cReq.PluginID = plug.ID
cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{} cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}
err = v.srv.RPC(method, cReq, cResp) err = v.serializedControllerRPC(plug.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil { if err != nil {
if strings.Contains(err.Error(), "FailedPrecondition") { if strings.Contains(err.Error(), "FailedPrecondition") {
return fmt.Errorf("%v: %v", structs.ErrCSIClientRPCRetryable, err) return fmt.Errorf("%v: %v", structs.ErrCSIClientRPCRetryable, err)
@ -586,6 +589,57 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu
return plug, vol, nil return plug, vol, nil
} }
// serializedControllerRPC ensures we're only sending a single controller RPC to
// a given plugin if the RPC can cause conflicting state changes.
//
// The CSI specification says that we SHOULD send no more than one in-flight
// request per *volume* at a time, with an allowance for losing state
// (ex. leadership transitions) which the plugins SHOULD handle gracefully.
//
// In practice many CSI plugins rely on k8s-specific sidecars for serializing
// storage provider API calls globally (ex. concurrently attaching EBS volumes
// to an EC2 instance results in a race for device names). So we have to be much
// more conservative about concurrency in Nomad than the spec allows.
func (v *CSIVolume) serializedControllerRPC(pluginID string, fn func() error) error {
for {
v.srv.volumeControllerLock.Lock()
future := v.srv.volumeControllerFutures[pluginID]
if future == nil {
future, futureDone := context.WithCancel(v.srv.shutdownCtx)
v.srv.volumeControllerFutures[pluginID] = future
v.srv.volumeControllerLock.Unlock()
err := fn()
// close the future while holding the lock and not in a defer so
// that we can ensure we've cleared it from the map before allowing
// anyone else to take the lock and write a new one
v.srv.volumeControllerLock.Lock()
futureDone()
delete(v.srv.volumeControllerFutures, pluginID)
v.srv.volumeControllerLock.Unlock()
return err
} else {
v.srv.volumeControllerLock.Unlock()
select {
case <-future.Done():
continue
case <-v.srv.shutdownCh:
// The csi_hook publish workflow on the client will retry if it
// gets this error. On unpublish, we don't want to block client
// shutdown so we give up on error. The new leader's
// volumewatcher will iterate all the claims at startup to
// detect this and mop up any claims in the NodeDetached state
// (volume GC will run periodically as well)
return structs.ErrNoLeader
}
}
}
}
// allowCSIMount is called on Job register to check mount permission // allowCSIMount is called on Job register to check mount permission
func allowCSIMount(aclObj *acl.ACL, namespace string) bool { func allowCSIMount(aclObj *acl.ACL, namespace string) bool {
return aclObj.AllowPluginRead() && return aclObj.AllowPluginRead() &&
@ -863,8 +917,11 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
Secrets: vol.Secrets, Secrets: vol.Secrets,
} }
req.PluginID = vol.PluginID req.PluginID = vol.PluginID
err = v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
err = v.serializedControllerRPC(vol.PluginID, func() error {
return v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
&cstructs.ClientCSIControllerDetachVolumeResponse{}) &cstructs.ClientCSIControllerDetachVolumeResponse{})
})
if err != nil { if err != nil {
return fmt.Errorf("could not detach from controller: %v", err) return fmt.Errorf("could not detach from controller: %v", err)
} }
@ -1139,7 +1196,9 @@ func (v *CSIVolume) deleteVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug
cReq.PluginID = plugin.ID cReq.PluginID = plugin.ID
cResp := &cstructs.ClientCSIControllerDeleteVolumeResponse{} cResp := &cstructs.ClientCSIControllerDeleteVolumeResponse{}
return v.serializedControllerRPC(plugin.ID, func() error {
return v.srv.RPC(method, cReq, cResp) return v.srv.RPC(method, cReq, cResp)
})
} }
func (v *CSIVolume) ListExternal(args *structs.CSIVolumeExternalListRequest, reply *structs.CSIVolumeExternalListResponse) error { func (v *CSIVolume) ListExternal(args *structs.CSIVolumeExternalListRequest, reply *structs.CSIVolumeExternalListResponse) error {
@ -1286,7 +1345,9 @@ func (v *CSIVolume) CreateSnapshot(args *structs.CSISnapshotCreateRequest, reply
} }
cReq.PluginID = pluginID cReq.PluginID = pluginID
cResp := &cstructs.ClientCSIControllerCreateSnapshotResponse{} cResp := &cstructs.ClientCSIControllerCreateSnapshotResponse{}
err = v.srv.RPC(method, cReq, cResp) err = v.serializedControllerRPC(pluginID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil { if err != nil {
multierror.Append(&mErr, fmt.Errorf("could not create snapshot: %v", err)) multierror.Append(&mErr, fmt.Errorf("could not create snapshot: %v", err))
continue continue
@ -1360,7 +1421,9 @@ func (v *CSIVolume) DeleteSnapshot(args *structs.CSISnapshotDeleteRequest, reply
cReq := &cstructs.ClientCSIControllerDeleteSnapshotRequest{ID: snap.ID} cReq := &cstructs.ClientCSIControllerDeleteSnapshotRequest{ID: snap.ID}
cReq.PluginID = plugin.ID cReq.PluginID = plugin.ID
cResp := &cstructs.ClientCSIControllerDeleteSnapshotResponse{} cResp := &cstructs.ClientCSIControllerDeleteSnapshotResponse{}
err = v.srv.RPC(method, cReq, cResp) err = v.serializedControllerRPC(plugin.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil { if err != nil {
multierror.Append(&mErr, fmt.Errorf("could not delete %q: %v", snap.ID, err)) multierror.Append(&mErr, fmt.Errorf("could not delete %q: %v", snap.ID, err))
} }

View File

@ -6,6 +6,7 @@ package nomad
import ( import (
"fmt" "fmt"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -21,6 +22,7 @@ import (
cconfig "github.com/hashicorp/nomad/client/config" cconfig "github.com/hashicorp/nomad/client/config"
cstructs "github.com/hashicorp/nomad/client/structs" cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/lib/lang"
"github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs"
@ -1971,3 +1973,49 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
require.Nil(t, vol) require.Nil(t, vol)
require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2)) require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2))
} }
func TestCSI_SerializedControllerRPC(t *testing.T) {
ci.Parallel(t)
srv, shutdown := TestServer(t, func(c *Config) { c.NumSchedulers = 0 })
defer shutdown()
testutil.WaitForLeader(t, srv.RPC)
var wg sync.WaitGroup
wg.Add(3)
timeCh := make(chan lang.Pair[string, time.Duration])
testFn := func(pluginID string, dur time.Duration) {
defer wg.Done()
c := NewCSIVolumeEndpoint(srv, nil)
now := time.Now()
err := c.serializedControllerRPC(pluginID, func() error {
time.Sleep(dur)
return nil
})
elapsed := time.Since(now)
timeCh <- lang.Pair[string, time.Duration]{pluginID, elapsed}
must.NoError(t, err)
}
go testFn("plugin1", 50*time.Millisecond)
go testFn("plugin2", 50*time.Millisecond)
go testFn("plugin1", 50*time.Millisecond)
totals := map[string]time.Duration{}
for i := 0; i < 3; i++ {
pair := <-timeCh
totals[pair.First] += pair.Second
}
wg.Wait()
// plugin1 RPCs should block each other
must.GreaterEq(t, 150*time.Millisecond, totals["plugin1"])
must.Less(t, 200*time.Millisecond, totals["plugin1"])
// plugin1 RPCs should not block plugin2 RPCs
must.GreaterEq(t, 50*time.Millisecond, totals["plugin2"])
must.Less(t, 100*time.Millisecond, totals["plugin2"])
}

View File

@ -218,6 +218,13 @@ type Server struct {
// volumeWatcher is used to release volume claims // volumeWatcher is used to release volume claims
volumeWatcher *volumewatcher.Watcher volumeWatcher *volumewatcher.Watcher
// volumeControllerFutures is a map of plugin IDs to pending controller RPCs. If
// no RPC is pending for a given plugin, this may be nil.
volumeControllerFutures map[string]context.Context
// volumeControllerLock synchronizes access controllerFutures map
volumeControllerLock sync.Mutex
// keyringReplicator is used to replicate root encryption keys from the // keyringReplicator is used to replicate root encryption keys from the
// leader // leader
keyringReplicator *KeyringReplicator keyringReplicator *KeyringReplicator
@ -445,6 +452,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigEntr
s.logger.Error("failed to create volume watcher", "error", err) s.logger.Error("failed to create volume watcher", "error", err)
return nil, fmt.Errorf("failed to create volume watcher: %v", err) return nil, fmt.Errorf("failed to create volume watcher: %v", err)
} }
s.volumeControllerFutures = map[string]context.Context{}
// Start the eval broker notification system so any subscribers can get // Start the eval broker notification system so any subscribers can get
// updates when the processes SetEnabled is triggered. // updates when the processes SetEnabled is triggered.