open-nomad/nomad/csi_endpoint.go
Tim Gross b03b78b212 csi: server-to-controller publish/unpublish RPCs (#7124)
Nomad servers need to make requests to CSI controller plugins running
on a client for publish/unpublish. The RPC needs to look up the client
node based on the plugin, load balancing across controllers, and then
perform the required client RPC to that node (via server forwarding if
neccessary).
2020-03-23 13:58:30 -04:00

562 lines
16 KiB
Go

package nomad
import (
"fmt"
"math/rand"
"time"
metrics "github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
memdb "github.com/hashicorp/go-memdb"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/nomad/acl"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
)
// CSIVolume wraps the structs.CSIVolume with request data and server context
type CSIVolume struct {
srv *Server
logger log.Logger
}
// QueryACLObj looks up the ACL token in the request and returns the acl.ACL object
// - fallback to node secret ids
func (srv *Server) QueryACLObj(args *structs.QueryOptions, allowNodeAccess bool) (*acl.ACL, error) {
// Lookup the token
aclObj, err := srv.ResolveToken(args.AuthToken)
if err != nil {
// If ResolveToken had an unexpected error return that
if !structs.IsErrTokenNotFound(err) {
return nil, err
}
// If we don't allow access to this endpoint from Nodes, then return token
// not found.
if !allowNodeAccess {
return nil, structs.ErrTokenNotFound
}
ws := memdb.NewWatchSet()
// Attempt to lookup AuthToken as a Node.SecretID since nodes may call
// call this endpoint and don't have an ACL token.
node, stateErr := srv.fsm.State().NodeBySecretID(ws, args.AuthToken)
if stateErr != nil {
// Return the original ResolveToken error with this err
var merr multierror.Error
merr.Errors = append(merr.Errors, err, stateErr)
return nil, merr.ErrorOrNil()
}
// We did not find a Node for this ID, so return Token Not Found.
if node == nil {
return nil, structs.ErrTokenNotFound
}
}
// Return either the users aclObj, or nil if ACLs are disabled.
return aclObj, nil
}
// WriteACLObj calls QueryACLObj for a WriteRequest
func (srv *Server) WriteACLObj(args *structs.WriteRequest, allowNodeAccess bool) (*acl.ACL, error) {
opts := &structs.QueryOptions{
Region: args.RequestRegion(),
Namespace: args.RequestNamespace(),
AuthToken: args.AuthToken,
}
return srv.QueryACLObj(opts, allowNodeAccess)
}
const (
csiVolumeTable = "csi_volumes"
csiPluginTable = "csi_plugins"
)
// replySetIndex sets the reply with the last index that modified the table
func (srv *Server) replySetIndex(table string, reply *structs.QueryMeta) error {
s := srv.fsm.State()
index, err := s.Index(table)
if err != nil {
return err
}
reply.Index = index
// Set the query response
srv.setQueryMeta(reply)
return nil
}
// List replies with CSIVolumes, filtered by ACL access
func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIVolumeListResponse) error {
if done, err := v.srv.forward("CSIVolume.List", args, args, reply); done {
return err
}
allowCSIAccess := acl.NamespaceValidator(acl.NamespaceCapabilityCSIAccess)
aclObj, err := v.srv.QueryACLObj(&args.QueryOptions, false)
if err != nil {
return err
}
metricsStart := time.Now()
defer metrics.MeasureSince([]string{"nomad", "volume", "list"}, metricsStart)
ns := args.RequestNamespace()
opts := blockingOptions{
queryOpts: &args.QueryOptions,
queryMeta: &reply.QueryMeta,
run: func(ws memdb.WatchSet, state *state.StateStore) error {
// Query all volumes
var err error
var iter memdb.ResultIterator
if args.PluginID != "" {
iter, err = state.CSIVolumesByPluginID(ws, args.PluginID)
} else {
iter, err = state.CSIVolumes(ws)
}
if err != nil {
return err
}
// Collect results, filter by ACL access
var vs []*structs.CSIVolListStub
cache := map[string]bool{}
for {
raw := iter.Next()
if raw == nil {
break
}
vol := raw.(*structs.CSIVolume)
vol, err := state.CSIVolumeDenormalizePlugins(ws, vol)
if err != nil {
return err
}
// Filter on the request namespace to avoid ACL checks by volume
if ns != "" && vol.Namespace != args.RequestNamespace() {
continue
}
// Cache ACL checks QUESTION: are they expensive?
allowed, ok := cache[vol.Namespace]
if !ok {
allowed = allowCSIAccess(aclObj, vol.Namespace)
cache[vol.Namespace] = allowed
}
if allowed {
vs = append(vs, vol.Stub())
}
}
reply.Volumes = vs
return v.srv.replySetIndex(csiVolumeTable, &reply.QueryMeta)
}}
return v.srv.blockingRPC(&opts)
}
// Get fetches detailed information about a specific volume
func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVolumeGetResponse) error {
if done, err := v.srv.forward("CSIVolume.Get", args, args, reply); done {
return err
}
allowCSIAccess := acl.NamespaceValidator(acl.NamespaceCapabilityCSIAccess)
aclObj, err := v.srv.QueryACLObj(&args.QueryOptions, true)
if err != nil {
return err
}
if !allowCSIAccess(aclObj, args.RequestNamespace()) {
return structs.ErrPermissionDenied
}
metricsStart := time.Now()
defer metrics.MeasureSince([]string{"nomad", "volume", "get"}, metricsStart)
opts := blockingOptions{
queryOpts: &args.QueryOptions,
queryMeta: &reply.QueryMeta,
run: func(ws memdb.WatchSet, state *state.StateStore) error {
vol, err := state.CSIVolumeByID(ws, args.ID)
if err != nil {
return err
}
if vol != nil {
vol, err = state.CSIVolumeDenormalize(ws, vol)
}
if err != nil {
return err
}
reply.Volume = vol
return v.srv.replySetIndex(csiVolumeTable, &reply.QueryMeta)
}}
return v.srv.blockingRPC(&opts)
}
// Register registers a new volume
func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *structs.CSIVolumeRegisterResponse) error {
if done, err := v.srv.forward("CSIVolume.Register", args, args, reply); done {
return err
}
allowCSIVolumeManagement := acl.NamespaceValidator(acl.NamespaceCapabilityCSICreateVolume)
aclObj, err := v.srv.WriteACLObj(&args.WriteRequest, false)
if err != nil {
return err
}
metricsStart := time.Now()
defer metrics.MeasureSince([]string{"nomad", "volume", "register"}, metricsStart)
if !allowCSIVolumeManagement(aclObj, args.RequestNamespace()) {
return structs.ErrPermissionDenied
}
// This is the only namespace we ACL checked, force all the volumes to use it
for _, vol := range args.Volumes {
vol.Namespace = args.RequestNamespace()
if err = vol.Validate(); err != nil {
return err
}
}
resp, index, err := v.srv.raftApply(structs.CSIVolumeRegisterRequestType, args)
if err != nil {
v.logger.Error("csi raft apply failed", "error", err, "method", "register")
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
reply.Index = index
v.srv.setQueryMeta(&reply.QueryMeta)
return nil
}
// Deregister removes a set of volumes
func (v *CSIVolume) Deregister(args *structs.CSIVolumeDeregisterRequest, reply *structs.CSIVolumeDeregisterResponse) error {
if done, err := v.srv.forward("CSIVolume.Deregister", args, args, reply); done {
return err
}
allowCSIVolumeManagement := acl.NamespaceValidator(acl.NamespaceCapabilityCSICreateVolume)
aclObj, err := v.srv.WriteACLObj(&args.WriteRequest, false)
if err != nil {
return err
}
metricsStart := time.Now()
defer metrics.MeasureSince([]string{"nomad", "volume", "deregister"}, metricsStart)
ns := args.RequestNamespace()
if !allowCSIVolumeManagement(aclObj, ns) {
return structs.ErrPermissionDenied
}
resp, index, err := v.srv.raftApply(structs.CSIVolumeDeregisterRequestType, args)
if err != nil {
v.logger.Error("csi raft apply failed", "error", err, "method", "deregister")
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
reply.Index = index
v.srv.setQueryMeta(&reply.QueryMeta)
return nil
}
// Claim claims a volume
func (v *CSIVolume) Claim(args *structs.CSIVolumeClaimRequest, reply *structs.CSIVolumeClaimResponse) error {
if done, err := v.srv.forward("CSIVolume.Claim", args, args, reply); done {
return err
}
allowCSIAccess := acl.NamespaceValidator(acl.NamespaceCapabilityCSIAccess)
aclObj, err := v.srv.WriteACLObj(&args.WriteRequest, true)
if err != nil {
return err
}
metricsStart := time.Now()
defer metrics.MeasureSince([]string{"nomad", "volume", "claim"}, metricsStart)
if !allowCSIAccess(aclObj, args.RequestNamespace()) {
return structs.ErrPermissionDenied
}
// adds a PublishContext from the controller (if any) to the reply
err = v.srv.controllerPublishVolume(args, reply)
if err != nil {
return err
}
resp, index, err := v.srv.raftApply(structs.CSIVolumeClaimRequestType, args)
if err != nil {
v.logger.Error("csi raft apply failed", "error", err, "method", "claim")
return err
}
if respErr, ok := resp.(error); ok {
return respErr
}
reply.Index = index
v.srv.setQueryMeta(&reply.QueryMeta)
return nil
}
// CSIPlugin wraps the structs.CSIPlugin with request data and server context
type CSIPlugin struct {
srv *Server
logger log.Logger
}
// List replies with CSIPlugins, filtered by ACL access
func (v *CSIPlugin) List(args *structs.CSIPluginListRequest, reply *structs.CSIPluginListResponse) error {
if done, err := v.srv.forward("CSIPlugin.List", args, args, reply); done {
return err
}
allowCSIAccess := acl.NamespaceValidator(acl.NamespaceCapabilityCSIAccess)
aclObj, err := v.srv.QueryACLObj(&args.QueryOptions, false)
if err != nil {
return err
}
if !allowCSIAccess(aclObj, args.RequestNamespace()) {
return structs.ErrPermissionDenied
}
metricsStart := time.Now()
defer metrics.MeasureSince([]string{"nomad", "plugin", "list"}, metricsStart)
opts := blockingOptions{
queryOpts: &args.QueryOptions,
queryMeta: &reply.QueryMeta,
run: func(ws memdb.WatchSet, state *state.StateStore) error {
// Query all plugins
iter, err := state.CSIPlugins(ws)
if err != nil {
return err
}
// Collect results
var ps []*structs.CSIPluginListStub
for {
raw := iter.Next()
if raw == nil {
break
}
plug := raw.(*structs.CSIPlugin)
// FIXME we should filter the ACL access for the plugin's
// namespace, but plugins don't currently have namespaces
ps = append(ps, plug.Stub())
}
reply.Plugins = ps
return v.srv.replySetIndex(csiPluginTable, &reply.QueryMeta)
}}
return v.srv.blockingRPC(&opts)
}
// Get fetches detailed information about a specific plugin
func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPluginGetResponse) error {
if done, err := v.srv.forward("CSIPlugin.Get", args, args, reply); done {
return err
}
allowCSIAccess := acl.NamespaceValidator(acl.NamespaceCapabilityCSIAccess)
aclObj, err := v.srv.QueryACLObj(&args.QueryOptions, false)
if err != nil {
return err
}
if !allowCSIAccess(aclObj, args.RequestNamespace()) {
return structs.ErrPermissionDenied
}
metricsStart := time.Now()
defer metrics.MeasureSince([]string{"nomad", "plugin", "get"}, metricsStart)
opts := blockingOptions{
queryOpts: &args.QueryOptions,
queryMeta: &reply.QueryMeta,
run: func(ws memdb.WatchSet, state *state.StateStore) error {
plug, err := state.CSIPluginByID(ws, args.ID)
if err != nil {
return err
}
if plug != nil {
plug, err = state.CSIPluginDenormalize(ws, plug)
}
if err != nil {
return err
}
// FIXME we should re-check the ACL access for the plugin's
// namespace, but plugins don't currently have namespaces
reply.Plugin = plug
return v.srv.replySetIndex(csiPluginTable, &reply.QueryMeta)
}}
return v.srv.blockingRPC(&opts)
}
// controllerPublishVolume sends publish request to the CSI controller
// plugin associated with a volume, if any.
func (srv *Server) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, resp *structs.CSIVolumeClaimResponse) error {
plug, vol, err := srv.volAndPluginLookup(req.VolumeID)
if plug == nil || vol == nil || err != nil {
return err // possibly nil if no controller required
}
method := "ClientCSI.AttachVolume"
cReq := &cstructs.ClientCSIControllerAttachVolumeRequest{
PluginName: plug.ID,
VolumeID: req.VolumeID,
NodeID: req.Allocation.NodeID,
AttachmentMode: vol.AttachmentMode,
AccessMode: vol.AccessMode,
ReadOnly: req.Claim == structs.CSIVolumeClaimRead,
// TODO(tgross): we don't have a way of setting these yet.
// ref https://github.com/hashicorp/nomad/issues/7007
// MountOptions: vol.MountOptions,
}
cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}
// CSI controller plugins can block for arbitrarily long times,
// but we need to make sure it completes before we can safely
// mark the volume as claimed and return to the client so it
// can do a `NodePublish`.
err = srv.csiControllerRPC(plug, method, cReq, cResp)
if err != nil {
return err
}
resp.PublishContext = cResp.PublishContext
return nil
}
// controllerUnpublishVolume sends an unpublish request to the CSI
// controller plugin associated with a volume, if any.
func (srv *Server) controllerUnpublishVolume(req *structs.CSIVolumeClaimRequest, nodeID string) error {
plug, vol, err := srv.volAndPluginLookup(req.VolumeID)
if plug == nil || vol == nil || err != nil {
return err // possibly nil if no controller required
}
method := "ClientCSI.DetachVolume"
cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{
PluginName: plug.ID,
VolumeID: req.VolumeID,
NodeID: nodeID,
}
err = srv.csiControllerRPC(plug, method, cReq,
&cstructs.ClientCSIControllerDetachVolumeResponse{})
if err != nil {
return err
}
return nil
}
func (srv *Server) volAndPluginLookup(volID string) (*structs.CSIPlugin, *structs.CSIVolume, error) {
state := srv.fsm.State()
ws := memdb.NewWatchSet()
vol, err := state.CSIVolumeByID(ws, volID)
if err != nil {
return nil, nil, err
}
if vol == nil {
return nil, nil, fmt.Errorf("volume not found: %s", volID)
}
if !vol.ControllerRequired {
return nil, nil, nil
}
// note: we do this same lookup in CSIVolumeByID but then throw
// away the pointer to the plugin rather than attaching it to
// the volume so we have to do it again here.
plug, err := state.CSIPluginByID(ws, vol.PluginID)
if err != nil {
return nil, nil, err
}
if plug == nil {
return nil, nil, fmt.Errorf("plugin not found: %s", vol.PluginID)
}
return plug, vol, nil
}
func (srv *Server) csiControllerRPC(plugin *structs.CSIPlugin, method string, args, reply interface{}) error {
// plugin IDs are not scoped to region/DC but volumes are.
// so any node we get for a controller is already in the same region/DC
// for the volume.
nodeID, err := srv.nodeForControllerPlugin(plugin)
if err != nil || nodeID == "" {
return err
}
err = findNodeConnAndForward(srv, nodeID, method, args, reply)
if err != nil {
return err
}
if replyErr, ok := reply.(error); ok {
return replyErr
}
return nil
}
// nodeForControllerPlugin returns the node ID for a random controller
// to load-balance long-blocking RPCs across client nodes.
func (srv *Server) nodeForControllerPlugin(plugin *structs.CSIPlugin) (string, error) {
count := len(plugin.Controllers)
if count == 0 {
return "", fmt.Errorf("no controllers available for plugin %q", plugin.ID)
}
snap, err := srv.fsm.State().Snapshot()
if err != nil {
return "", err
}
// iterating maps is "random" but unspecified and isn't particularly
// random with small maps, so not well-suited for load balancing.
// so we shuffle the keys and iterate over them.
clientIDs := make([]string, count)
for clientID := range plugin.Controllers {
clientIDs = append(clientIDs, clientID)
}
rand.Shuffle(count, func(i, j int) {
clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
})
for _, clientID := range clientIDs {
controller := plugin.Controllers[clientID]
if !controller.IsController() {
// we don't have separate types for CSIInfo depending on
// whether it's a controller or node. this error shouldn't
// make it to production but is to aid developers during
// development
err = fmt.Errorf("plugin is not a controller")
continue
}
_, err = getNodeForRpc(snap, clientID)
if err != nil {
continue
}
return clientID, nil
}
return "", err
}