diff --git a/nomad/csi_volume_endpoint.go b/nomad/csi_volume_endpoint.go new file mode 100644 index 000000000..846cee249 --- /dev/null +++ b/nomad/csi_volume_endpoint.go @@ -0,0 +1,255 @@ +package nomad + +import ( + "fmt" + "time" + + metrics "github.com/armon/go-metrics" + log "github.com/hashicorp/go-hclog" + memdb "github.com/hashicorp/go-memdb" + multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" +) + +// CSIVolume wraps the structs.CSIVolume with request data and server context +type CSIVolume struct { + srv *Server + logger log.Logger +} + +// QueryACLObj looks up the ACL token in the request and returns the acl.ACL object +// - fallback to node secret ids +func (srv *Server) QueryACLObj(args *structs.QueryOptions) (*acl.ACL, error) { + if args.AuthToken == "" { + return nil, fmt.Errorf("authorization required") + } + + // Lookup the token + aclObj, err := srv.ResolveToken(args.AuthToken) + if err != nil { + // If ResolveToken had an unexpected error return that + return nil, err + } + + if aclObj == nil { + ws := memdb.NewWatchSet() + node, stateErr := srv.fsm.State().NodeBySecretID(ws, args.AuthToken) + if stateErr != nil { + // Return the original ResolveToken error with this err + var merr multierror.Error + merr.Errors = append(merr.Errors, err, stateErr) + return nil, merr.ErrorOrNil() + } + + if node == nil { + return nil, structs.ErrTokenNotFound + } + } + + return aclObj, nil +} + +// WriteACLObj calls QueryACLObj for a WriteRequest +func (srv *Server) WriteACLObj(args *structs.WriteRequest) (*acl.ACL, error) { + opts := &structs.QueryOptions{ + Region: args.RequestRegion(), + Namespace: args.RequestNamespace(), + AuthToken: args.AuthToken, + } + return srv.QueryACLObj(opts) +} + +// replyCSIVolumeIndex sets the reply with the last index that modified the table csi_volumes +func (srv *Server) replySetCSIVolumeIndex(state *state.StateStore, reply *structs.QueryMeta) error { + // Use the last index that affected the table + index, err := state.Index("csi_volumes") + if err != nil { + return err + } + reply.Index = index + + // Set the query response + srv.setQueryMeta(reply) + return nil +} + +// List replies with CSIVolumes, filtered by ACL access +func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIVolumeListResponse) error { + if done, err := v.srv.forward("CSIVolume.List", args, args, reply); done { + return err + } + + aclObj, err := v.srv.QueryACLObj(&args.QueryOptions) + if err != nil { + return err + } + + metricsStart := time.Now() + defer metrics.MeasureSince([]string{"nomad", "volume", "list"}, metricsStart) + + ns := args.RequestNamespace() + opts := blockingOptions{ + queryOpts: &args.QueryOptions, + queryMeta: &reply.QueryMeta, + run: func(ws memdb.WatchSet, state *state.StateStore) error { + // Query all volumes + var err error + var iter memdb.ResultIterator + + if ns == "" && args.Driver == "" { + iter, err = state.CSIVolumes(ws) + } else if args.Driver != "" { + iter, err = state.CSIVolumesByNSDriver(ws, args.Namespace, args.Driver) + } else { + iter, err = state.CSIVolumesByNS(ws, args.Namespace) + } + + if err != nil { + return err + } + + // Collect results, filter by ACL access + var vs []*structs.CSIVolListStub + cache := map[string]bool{} + + for { + raw := iter.Next() + if raw == nil { + break + } + vol := raw.(*structs.CSIVolume) + + // Filter on the request namespace to avoid ACL checks by volume + if ns != "" && vol.Namespace != args.RequestNamespace() { + continue + } + + // Cache ACL checks QUESTION: are they expensive? + allowed, ok := cache[vol.Namespace] + if !ok { + allowed = aclObj.AllowNsOp(vol.Namespace, acl.NamespaceCapabilityCSIAccess) + cache[vol.Namespace] = allowed + } + + if allowed { + vs = append(vs, vol.Stub()) + } + } + reply.Volumes = vs + return v.srv.replySetCSIVolumeIndex(state, &reply.QueryMeta) + }} + return v.srv.blockingRPC(&opts) +} + +// Get fetches detailed information about a specific volume +func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVolumeGetResponse) error { + if done, err := v.srv.forward("CSIVolume.Get", args, args, reply); done { + return err + } + + aclObj, err := v.srv.QueryACLObj(&args.QueryOptions) + if err != nil { + return err + } + + if !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilityCSIAccess) { + return structs.ErrPermissionDenied + } + + metricsStart := time.Now() + defer metrics.MeasureSince([]string{"nomad", "volume", "get"}, metricsStart) + + ns := args.RequestNamespace() + opts := blockingOptions{ + queryOpts: &args.QueryOptions, + queryMeta: &reply.QueryMeta, + run: func(ws memdb.WatchSet, state *state.StateStore) error { + vol, err := state.CSIVolumeByID(ws, ns, args.ID) + if err != nil { + return err + } + + if vol == nil { + return structs.ErrMissingCSIVolumeID + } + + reply.Volume = vol + return v.srv.replySetCSIVolumeIndex(state, &reply.QueryMeta) + }} + return v.srv.blockingRPC(&opts) +} + +// Register registers a new volume +func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *structs.CSIVolumeRegisterResponse) error { + if done, err := v.srv.forward("CSIVolume.Register", args, args, reply); done { + return err + } + + aclObj, err := v.srv.WriteACLObj(&args.WriteRequest) + if err != nil { + return err + } + + metricsStart := time.Now() + defer metrics.MeasureSince([]string{"nomad", "volume", "register"}, metricsStart) + + if !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilityCSICreateVolume) { + return structs.ErrPermissionDenied + } + + // This is the only namespace we ACL checked, force all the volumes to use it + for _, v := range args.Volumes { + v.Namespace = args.RequestNamespace() + if err = v.Validate(); err != nil { + return err + } + } + + state := v.srv.State() + index, err := state.LatestIndex() + if err != nil { + return err + } + + err = state.CSIVolumeRegister(index, args.Volumes) + if err != nil { + return err + } + + return v.srv.replySetCSIVolumeIndex(state, &reply.QueryMeta) +} + +// Deregister removes a set of volumes +func (v *CSIVolume) Deregister(args *structs.CSIVolumeDeregisterRequest, reply *structs.CSIVolumeDeregisterResponse) error { + if done, err := v.srv.forward("CSIVolume.Deregister", args, args, reply); done { + return err + } + + aclObj, err := v.srv.WriteACLObj(&args.WriteRequest) + if err != nil { + return err + } + + metricsStart := time.Now() + defer metrics.MeasureSince([]string{"nomad", "volume", "deregister"}, metricsStart) + + ns := args.RequestNamespace() + if !aclObj.AllowNsOp(ns, acl.NamespaceCapabilityCSICreateVolume) { + return structs.ErrPermissionDenied + } + + state := v.srv.State() + index, err := state.LatestIndex() + if err != nil { + return err + } + + err = state.CSIVolumeDeregister(index, ns, args.VolumeIDs) + if err != nil { + return err + } + + return v.srv.replySetCSIVolumeIndex(state, &reply.QueryMeta) +} diff --git a/nomad/csi_volume_endpoint_test.go b/nomad/csi_volume_endpoint_test.go new file mode 100644 index 000000000..35c3c371b --- /dev/null +++ b/nomad/csi_volume_endpoint_test.go @@ -0,0 +1,231 @@ +package nomad + +import ( + "testing" + + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/require" +) + +func TestCSIVolumeEndpoint_Get(t *testing.T) { + t.Parallel() + srv, shutdown := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 // Prevent automatic dequeue + }) + defer shutdown() + testutil.WaitForLeader(t, srv.RPC) + + ns := structs.DefaultNamespace + + state := srv.fsm.State() + state.BootstrapACLTokens(1, 0, mock.ACLManagementToken()) + srv.config.ACLEnabled = true + policy := mock.NamespacePolicy(ns, "", []string{acl.NamespaceCapabilityCSIAccess}) + validToken := mock.CreatePolicyAndToken(t, state, 1001, "csi-access", policy) + + codec := rpcClient(t, srv) + + // Create the volume + vols := []*structs.CSIVolume{{ + ID: "DEADBEEF-70AD-4672-9178-802BCA500C87", + Namespace: ns, + AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + Driver: "minnie", + }} + err := state.CSIVolumeRegister(0, vols) + require.NoError(t, err) + + // Create the register request + req := &structs.CSIVolumeGetRequest{ + ID: "DEADBEEF-70AD-4672-9178-802BCA500C87", + QueryOptions: structs.QueryOptions{ + Region: "global", + Namespace: ns, + AuthToken: validToken.SecretID, + }, + } + + var resp structs.CSIVolumeGetResponse + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Get", req, &resp) + require.NoError(t, err) + require.NotEqual(t, 0, resp.Index) + require.Equal(t, vols[0].ID, resp.Volume.ID) +} + +func TestCSIVolumeEndpoint_Register(t *testing.T) { + t.Parallel() + srv, shutdown := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 // Prevent automatic dequeue + }) + defer shutdown() + testutil.WaitForLeader(t, srv.RPC) + + ns := structs.DefaultNamespace + + state := srv.fsm.State() + state.BootstrapACLTokens(1, 0, mock.ACLManagementToken()) + srv.config.ACLEnabled = true + policy := mock.NamespacePolicy(ns, "", []string{acl.NamespaceCapabilityCSICreateVolume}) + validToken := mock.CreatePolicyAndToken(t, state, 1001, acl.NamespaceCapabilityCSICreateVolume, policy) + + codec := rpcClient(t, srv) + + // Create the volume + vols := []*structs.CSIVolume{{ + ID: "DEADBEEF-70AD-4672-9178-802BCA500C87", + Namespace: "notTheNamespace", + Driver: "minnie", + AccessMode: structs.CSIVolumeAccessModeMultiNodeReader, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + Topologies: []*structs.CSITopology{{ + Segments: map[string]string{"foo": "bar"}, + }}, + }} + + // Create the register request + req1 := &structs.CSIVolumeRegisterRequest{ + Volumes: vols, + WriteRequest: structs.WriteRequest{ + Region: "global", + Namespace: ns, + AuthToken: validToken.SecretID, + }, + } + resp1 := &structs.CSIVolumeRegisterResponse{} + err := msgpackrpc.CallWithCodec(codec, "CSIVolume.Register", req1, resp1) + require.NoError(t, err) + require.NotEqual(t, 0, resp1.Index) + + // Get the volume back out + policy = mock.NamespacePolicy(ns, "", []string{acl.NamespaceCapabilityCSIAccess}) + getToken := mock.CreatePolicyAndToken(t, state, 1001, "csi-access", policy) + + req2 := &structs.CSIVolumeGetRequest{ + ID: "DEADBEEF-70AD-4672-9178-802BCA500C87", + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: getToken.SecretID, + }, + } + resp2 := &structs.CSIVolumeGetResponse{} + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Get", req2, resp2) + require.NoError(t, err) + require.NotEqual(t, 0, resp2.Index) + require.Equal(t, vols[0].ID, resp2.Volume.ID) + + // Registration does not update + req1.Volumes[0].Driver = "adam" + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Register", req1, resp1) + require.Error(t, err, "exists") + + // Deregistration works + req3 := &structs.CSIVolumeDeregisterRequest{ + VolumeIDs: []string{"DEADBEEF-70AD-4672-9178-802BCA500C87"}, + WriteRequest: structs.WriteRequest{ + Region: "global", + Namespace: ns, + AuthToken: validToken.SecretID, + }, + } + resp3 := &structs.CSIVolumeDeregisterResponse{} + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Deregister", req3, resp3) + require.NoError(t, err) + + // Volume is missing + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Get", req2, resp2) + require.Error(t, err, "missing") +} + +func TestCSIVolumeEndpoint_List(t *testing.T) { + t.Parallel() + srv, shutdown := TestServer(t, func(c *Config) { + c.NumSchedulers = 0 // Prevent automatic dequeue + }) + defer shutdown() + testutil.WaitForLeader(t, srv.RPC) + + ns := structs.DefaultNamespace + ms := "altNamespace" + + state := srv.fsm.State() + state.BootstrapACLTokens(1, 0, mock.ACLManagementToken()) + srv.config.ACLEnabled = true + + policy := mock.NamespacePolicy(ns, "", []string{acl.NamespaceCapabilityCSIAccess}) + nsTok := mock.CreatePolicyAndToken(t, state, 1001, "csi-access", policy) + codec := rpcClient(t, srv) + + // Create the volume + vols := []*structs.CSIVolume{{ + ID: "DEADBEEF-70AD-4672-9178-802BCA500C87", + Namespace: ns, + AccessMode: structs.CSIVolumeAccessModeMultiNodeReader, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + Driver: "minnie", + }, { + ID: "BAADF00D-70AD-4672-9178-802BCA500C87", + Namespace: ns, + AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + Driver: "adam", + }, { + ID: "BEADCEED-70AD-4672-9178-802BCA500C87", + Namespace: ms, + AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + Driver: "paddy", + }} + err := state.CSIVolumeRegister(0, vols) + require.NoError(t, err) + + var resp structs.CSIVolumeListResponse + + // Query all, ACL only allows ns + req := &structs.CSIVolumeListRequest{ + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: nsTok.SecretID, + }, + } + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.List", req, &resp) + require.NoError(t, err) + require.NotEqual(t, 0, resp.Index) + require.Equal(t, 2, len(resp.Volumes)) + ids := map[string]bool{vols[0].ID: true, vols[1].ID: true} + for _, v := range resp.Volumes { + delete(ids, v.ID) + } + require.Equal(t, 0, len(ids)) + + // Query by Driver + req = &structs.CSIVolumeListRequest{ + Driver: "adam", + QueryOptions: structs.QueryOptions{ + Region: "global", + Namespace: ns, + AuthToken: nsTok.SecretID, + }, + } + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.List", req, &resp) + require.NoError(t, err) + require.Equal(t, 1, len(resp.Volumes)) + require.Equal(t, vols[1].ID, resp.Volumes[0].ID) + + // Query by Driver, ACL filters all results + req = &structs.CSIVolumeListRequest{ + Driver: "paddy", + QueryOptions: structs.QueryOptions{ + Region: "global", + Namespace: ms, + AuthToken: nsTok.SecretID, + }, + } + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.List", req, &resp) + require.NoError(t, err) + require.Equal(t, 0, len(resp.Volumes)) +}