open-nomad/nomad/client_csi_endpoint.go
Tim Gross ce3eef8037
metrics: Add rate metrics to Client CSI endpoints (#15905)
Also tightens up authentication for these endpoints by enforcing that the
server certificate name is valid. We currently protect these endpoints with
mTLS and can't use an auth token, because these endpoints are (uniquely) called
by the leader, and the followers for a given node won't have the leader's
ephemeral ACL token. Add a certificate name check that requests come from a
server and not a client, because no client should ever send these RPCs directly.
2023-01-26 16:40:58 -05:00

301 lines
9.7 KiB
Go

package nomad
import (
"fmt"
"math/rand"
"strings"
"time"
metrics "github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
memdb "github.com/hashicorp/go-memdb"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/nomad/structs"
)
// ClientCSI is used to forward RPC requests to the targeted Nomad client's
// CSI endpoints: the CSI.Controller* RPCs and CSI.NodeDetachVolume.
type ClientCSI struct {
	// srv is the server this endpoint runs on; it supplies authentication,
	// rate metrics, state snapshots and node connections.
	srv *Server
	// ctx is the RPC context of the request being served; it is passed to
	// authentication and to the TLS certificate level check.
	ctx *RPCContext
	// logger is the endpoint's logger.
	logger log.Logger
}
// NewClientCSIEndpoint builds the ClientCSI endpoint for the server srv and
// the RPC context ctx. The endpoint logs through srv's logger under the name
// "client_csi".
func NewClientCSIEndpoint(srv *Server, ctx *RPCContext) *ClientCSI {
	endpoint := &ClientCSI{
		srv: srv,
		ctx: ctx,
	}
	endpoint.logger = srv.logger.Named("client_csi")
	return endpoint
}
// ControllerAttachVolume sends a CSI.ControllerAttachVolume RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a write
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "attach_volume"}.
//
// On failure the returned error is prefixed with "controller attach volume: ".
// The cause is wrapped with %w, so errors.Is and errors.As can see through
// the prefix (for example to match structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAttachVolumeRequest, reply *cstructs.ClientCSIControllerAttachVolumeResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "attach_volume"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerAttachVolume",
		"ClientCSI.ControllerAttachVolume",
		structs.RateMetricWrite,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller attach volume: %w", err)
	}
	return nil
}
// ControllerValidateVolume sends a CSI.ControllerValidateVolume RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a write
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "validate_volume"}.
//
// On failure the returned error is prefixed with
// "controller validate volume: ". The cause is wrapped with %w, so errors.Is
// and errors.As can see through the prefix (for example to match
// structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerValidateVolumeRequest, reply *cstructs.ClientCSIControllerValidateVolumeResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "validate_volume"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerValidateVolume",
		"ClientCSI.ControllerValidateVolume",
		structs.RateMetricWrite,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller validate volume: %w", err)
	}
	return nil
}
// ControllerDetachVolume sends a CSI.ControllerDetachVolume RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a write
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "detach_volume"}.
//
// On failure the returned error is prefixed with "controller detach volume: ".
// The cause is wrapped with %w, so errors.Is and errors.As can see through
// the prefix (for example to match structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "detach_volume"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerDetachVolume",
		"ClientCSI.ControllerDetachVolume",
		structs.RateMetricWrite,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller detach volume: %w", err)
	}
	return nil
}
// ControllerCreateVolume sends a CSI.ControllerCreateVolume RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a write
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "create_volume"}.
//
// On failure the returned error is prefixed with "controller create volume: ".
// The cause is wrapped with %w, so errors.Is and errors.As can see through
// the prefix (for example to match structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerCreateVolume(args *cstructs.ClientCSIControllerCreateVolumeRequest, reply *cstructs.ClientCSIControllerCreateVolumeResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "create_volume"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerCreateVolume",
		"ClientCSI.ControllerCreateVolume",
		structs.RateMetricWrite,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller create volume: %w", err)
	}
	return nil
}
// ControllerDeleteVolume sends a CSI.ControllerDeleteVolume RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a write
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "delete_volume"}.
//
// On failure the returned error is prefixed with "controller delete volume: ".
// The cause is wrapped with %w, so errors.Is and errors.As can see through
// the prefix (for example to match structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerDeleteVolume(args *cstructs.ClientCSIControllerDeleteVolumeRequest, reply *cstructs.ClientCSIControllerDeleteVolumeResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "delete_volume"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerDeleteVolume",
		"ClientCSI.ControllerDeleteVolume",
		structs.RateMetricWrite,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller delete volume: %w", err)
	}
	return nil
}
// ControllerListVolumes sends a CSI.ControllerListVolumes RPC for the plugin
// args.PluginID through sendCSIControllerRPC, rate-metered as a list
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "list_volumes"}.
//
// On failure the returned error is prefixed with "controller list volumes: ".
// The cause is wrapped with %w, so errors.Is and errors.As can see through
// the prefix (for example to match structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerListVolumes(args *cstructs.ClientCSIControllerListVolumesRequest, reply *cstructs.ClientCSIControllerListVolumesResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "list_volumes"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerListVolumes",
		"ClientCSI.ControllerListVolumes",
		structs.RateMetricList,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller list volumes: %w", err)
	}
	return nil
}
// ControllerCreateSnapshot sends a CSI.ControllerCreateSnapshot RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a write
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "create_snapshot"}.
//
// On failure the returned error is prefixed with
// "controller create snapshot: ". The cause is wrapped with %w, so errors.Is
// and errors.As can see through the prefix (for example to match
// structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerCreateSnapshot(args *cstructs.ClientCSIControllerCreateSnapshotRequest, reply *cstructs.ClientCSIControllerCreateSnapshotResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "create_snapshot"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerCreateSnapshot",
		"ClientCSI.ControllerCreateSnapshot",
		structs.RateMetricWrite,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller create snapshot: %w", err)
	}
	return nil
}
// ControllerDeleteSnapshot sends a CSI.ControllerDeleteSnapshot RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a write
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "delete_snapshot"}.
//
// On failure the returned error is prefixed with
// "controller delete snapshot: ". The cause is wrapped with %w, so errors.Is
// and errors.As can see through the prefix (for example to match
// structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerDeleteSnapshot(args *cstructs.ClientCSIControllerDeleteSnapshotRequest, reply *cstructs.ClientCSIControllerDeleteSnapshotResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "delete_snapshot"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerDeleteSnapshot",
		"ClientCSI.ControllerDeleteSnapshot",
		structs.RateMetricWrite,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller delete snapshot: %w", err)
	}
	return nil
}
// ControllerListSnapshots sends a CSI.ControllerListSnapshots RPC for the
// plugin args.PluginID through sendCSIControllerRPC, rate-metered as a list
// operation, and records the call duration under the metric key
// {"nomad", "client_csi_controller", "list_snapshots"}.
//
// On failure the returned error is prefixed with
// "controller list snapshots: ". The cause is wrapped with %w, so errors.Is
// and errors.As can see through the prefix (for example to match
// structs.ErrPermissionDenied).
func (a *ClientCSI) ControllerListSnapshots(args *cstructs.ClientCSIControllerListSnapshotsRequest, reply *cstructs.ClientCSIControllerListSnapshotsResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "list_snapshots"}, time.Now())

	err := a.sendCSIControllerRPC(args.PluginID,
		"CSI.ControllerListSnapshots",
		"ClientCSI.ControllerListSnapshots",
		structs.RateMetricList,
		args, reply)
	if err != nil {
		return fmt.Errorf("controller list snapshots: %w", err)
	}
	return nil
}
// sendCSIControllerRPC authorizes the caller and then delivers a CSI
// controller RPC to a client running the controller plugin pluginID.
//
// method is the RPC name invoked on the client over an existing node
// connection; fwdMethod is the name handed to findNodeConnAndForward when this
// server has no connection to the chosen client. op is the operation label
// passed to the RPC rate metric. args has its controller node ID set to each
// candidate client before that client is tried.
//
// It returns structs.ErrPermissionDenied when authentication fails or the
// caller's TLS certificate is not at server level. Otherwise candidates from
// clientIDsForController are tried in order: the first success returns nil, a
// retryable failure (see isRetryable) moves on to the next candidate, and any
// other failure is returned immediately. If every candidate fails with a
// retryable error, the last such error is returned.
func (a *ClientCSI) sendCSIControllerRPC(pluginID, method, fwdMethod, op string, args cstructs.CSIControllerRequest, reply interface{}) error {
	// client requests aren't RequestWithIdentity, so a placeholder request
	// carries the identity data needed for the rate metric
	placeholder := &structs.GenericRequest{}
	authErr := a.srv.Authenticate(a.ctx, placeholder)
	a.srv.MeasureRPCRate("client_csi", op, placeholder)

	// only servers can send these client RPCs
	tlsErr := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelServer)
	if authErr != nil || tlsErr != nil {
		return structs.ErrPermissionDenied
	}

	candidates, err := a.clientIDsForController(pluginID)
	if err != nil {
		return err
	}

	var lastErr error
	for _, nodeID := range candidates {
		args.SetControllerNodeID(nodeID)

		conn, found := a.srv.getNodeConn(nodeID)
		if !found {
			// no local connection to this node: hand the request off and
			// return whatever the forwarding yields, without trying the
			// remaining candidates
			return findNodeConnAndForward(a.srv, nodeID, fwdMethod, args, reply)
		}

		lastErr = NodeRpc(conn.Session, method, args, reply)
		if lastErr == nil {
			return nil
		}
		if !a.isRetryable(lastErr) {
			return lastErr
		}

		// retryable: note the failure and fall through to the next candidate
		a.logger.Debug("failed to reach controller on client",
			"nodeID", nodeID, "error", lastErr)
	}
	return lastErr
}
// isRetryable reports whether err indicates that the same RPC can be retried
// on a different controller. That is the case where the client has stopped
// and been GC'd, or where the controller has stopped but we don't have the
// fingerprint update yet.
func (a *ClientCSI) isRetryable(err error) bool {
	// TODO: msgpack-rpc mangles the error so we lose the wrapping, which is
	// why this matches on the message text; if that can be fixed upstream we
	// should use that here instead
	text := err.Error()
	for _, marker := range []string{
		"CSI client error (retryable)",
		"Unknown node",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
// NodeDetachVolume sends a CSI.NodeDetachVolume RPC to the client args.NodeID
// and records the call duration under the metric key
// {"nomad", "client_csi_node", "detach_volume"}.
//
// It returns structs.ErrPermissionDenied when authentication fails or the
// caller's TLS certificate is not at server level. If this server has no
// connection to the node, the request is handed to findNodeConnAndForward and
// that result is returned as-is. An error from the RPC made over a local
// connection is prefixed with "node detach volume: " and wrapped with %w, so
// errors.Is and errors.As can see through the prefix.
func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error {
	defer metrics.MeasureSince([]string{"nomad", "client_csi_node", "detach_volume"}, time.Now())

	// client requests aren't RequestWithIdentity, so we use a placeholder here
	// to populate the identity data for metrics
	identityReq := &structs.GenericRequest{}
	authErr := a.srv.Authenticate(a.ctx, identityReq)
	a.srv.MeasureRPCRate("client_csi", structs.RateMetricWrite, identityReq)

	// only servers can send these client RPCs
	err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelServer)
	if authErr != nil || err != nil {
		return structs.ErrPermissionDenied
	}

	// Make sure Node is valid and new enough to support RPC
	snap, err := a.srv.State().Snapshot()
	if err != nil {
		return err
	}

	_, err = getNodeForRpc(snap, args.NodeID)
	if err != nil {
		return err
	}

	// Get the connection to the client
	state, ok := a.srv.getNodeConn(args.NodeID)
	if !ok {
		return findNodeConnAndForward(a.srv, args.NodeID, "ClientCSI.NodeDetachVolume", args, reply)
	}

	// Make the RPC
	err = NodeRpc(state.Session, "CSI.NodeDetachVolume", args, reply)
	if err != nil {
		return fmt.Errorf("node detach volume: %w", err)
	}
	return nil
}
// clientIDsForController returns a shuffled list of client IDs where the
// controller plugin is expected to be running.
//
// A client is included only if the plugin's controller entry for it reports
// IsController, getNodeForRpc accepts the node, and the node is Ready.
//
// It returns an error when pluginID is empty, when the state snapshot or the
// plugin lookup fails (the lookup error is wrapped with %w), when the plugin
// does not exist, or when no eligible client is found.
func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
	// reject an empty ID before doing any work, so a malformed request never
	// costs a state snapshot
	if pluginID == "" {
		return nil, fmt.Errorf("missing plugin ID")
	}

	snap, err := a.srv.State().Snapshot()
	if err != nil {
		return nil, err
	}

	ws := memdb.NewWatchSet()

	// note: plugin IDs are not scoped to region/DC but volumes are.
	// so any node we get for a controller is already in the same
	// region/DC for the volume.
	plugin, err := snap.CSIPluginByID(ws, pluginID)
	if err != nil {
		return nil, fmt.Errorf("error getting plugin: %s, %w", pluginID, err)
	}
	if plugin == nil {
		return nil, fmt.Errorf("plugin missing: %s", pluginID)
	}

	// iterating maps is "random" but unspecified and isn't particularly
	// random with small maps, so not well-suited for load balancing.
	// so we shuffle the keys and iterate over them.
	clientIDs := make([]string, 0, len(plugin.Controllers))

	for clientID, controller := range plugin.Controllers {
		if !controller.IsController() {
			// we don't have separate types for CSIInfo depending on
			// whether it's a controller or node. a non-controller entry
			// shouldn't make it to production, so it is skipped silently
			// rather than reported
			continue
		}
		// nodes that fail the RPC eligibility check or aren't ready are
		// skipped rather than treated as an error
		node, err := getNodeForRpc(snap, clientID)
		if err == nil && node != nil && node.Ready() {
			clientIDs = append(clientIDs, clientID)
		}
	}
	if len(clientIDs) == 0 {
		return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
	}

	rand.Shuffle(len(clientIDs), func(i, j int) {
		clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
	})

	return clientIDs, nil
}