open-nomad/plugins/drivers/server.go
Tim Gross eb06c25d5f
deps: remove deprecated net/context (#13932)
The `golang.org/x/net/context` package was merged into the stdlib as of go
1.7. Update the imports to use the identical stdlib version. Clean up import
blocks for the impacted files to remove unnecessary package aliasing.
2022-07-28 14:46:56 -04:00

420 lines
10 KiB
Go

package drivers
import (
"context"
"fmt"
"io"
"math"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-plugin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/drivers/proto"
dstructs "github.com/hashicorp/nomad/plugins/shared/structs"
sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
)
// driverPluginServer exposes a DriverPlugin over gRPC. Each RPC handler
// defined on it converts the protobuf request into driver types, calls the
// matching method on impl, and converts the result back into protobuf.
type driverPluginServer struct {
// broker is the go-plugin gRPC broker associated with this server.
// NOTE(review): no handler visible in this file reads it — confirm where
// it is used.
broker *plugin.GRPCBroker
// impl is the driver implementation every RPC handler delegates to.
impl DriverPlugin
}
// TaskConfigSchema handles the TaskConfigSchema RPC by returning the spec
// reported by the wrapped driver. A driver error is passed through unchanged.
func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) {
	schema, err := b.impl.TaskConfigSchema()
	if err != nil {
		return nil, err
	}

	return &proto.TaskConfigSchemaResponse{Spec: schema}, nil
}
// Capabilities handles the Capabilities RPC by translating the wrapped
// driver's capabilities into their protobuf form. A driver error is passed
// through unchanged.
func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) {
	driverCaps, err := b.impl.Capabilities()
	if err != nil {
		return nil, err
	}

	// Anything other than chroot or image isolation (including
	// FSIsolationNone and unknown values) is reported as NONE.
	fsIsolation := proto.DriverCapabilities_NONE
	switch driverCaps.FSIsolation {
	case FSIsolationChroot:
		fsIsolation = proto.DriverCapabilities_CHROOT
	case FSIsolationImage:
		fsIsolation = proto.DriverCapabilities_IMAGE
	}

	// Start from an empty, non-nil slice so the field is always set.
	netModes := []proto.NetworkIsolationSpec_NetworkIsolationMode{}
	for _, m := range driverCaps.NetIsolationModes {
		netModes = append(netModes, netIsolationModeToProto(m))
	}

	return &proto.CapabilitiesResponse{
		Capabilities: &proto.DriverCapabilities{
			SendSignals:           driverCaps.SendSignals,
			Exec:                  driverCaps.Exec,
			FsIsolation:           fsIsolation,
			MustCreateNetwork:     driverCaps.MustInitiateNetwork,
			NetworkIsolationModes: netModes,
			RemoteTasks:           driverCaps.RemoteTasks,
		},
	}, nil
}
// Fingerprint handles the streaming Fingerprint RPC. It forwards every
// fingerprint emitted by the wrapped driver to the client until the stream
// context is done or the driver closes its channel, both of which end the
// stream with a nil error. A failure to start fingerprinting or to send a
// response is returned as-is.
func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error {
	streamCtx := srv.Context()

	fingerprints, err := b.impl.Fingerprint(streamCtx)
	if err != nil {
		return err
	}

	for {
		select {
		case <-streamCtx.Done():
			return nil

		case fp, open := <-fingerprints:
			if !open {
				return nil
			}

			sendErr := srv.Send(&proto.FingerprintResponse{
				Attributes:        dstructs.ConvertStructAttributeMap(fp.Attributes),
				Health:            healthStateToProto(fp.Health),
				HealthDescription: fp.HealthDescription,
			})
			if sendErr != nil {
				return sendErr
			}
		}
	}
}
// RecoverTask handles the RecoverTask RPC by converting the protobuf task
// handle and passing it to the wrapped driver. A driver error is passed
// through unchanged.
func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) {
	handle := taskHandleFromProto(req.Handle)
	if err := b.impl.RecoverTask(handle); err != nil {
		return nil, err
	}

	return &proto.RecoverTaskResponse{}, nil
}
// StartTask handles the StartTask RPC. It starts the task on the wrapped
// driver and returns the resulting task handle plus, when the driver supplies
// one, its network override.
//
// A driver error implementing structs.Recoverable is returned as a
// FailedPrecondition gRPC status carrying a RecoverableError detail; any
// other driver error is returned as-is. A port value above math.MaxInt32 is
// rejected because it does not fit the protobuf int32 field.
func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) {
	handle, driverNet, err := b.impl.StartTask(taskConfigFromProto(req.Task))
	if err != nil {
		rec, ok := err.(structs.Recoverable)
		if !ok {
			return nil, err
		}

		detailed, detailErr := status.New(codes.FailedPrecondition, rec.Error()).
			WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
		if detailErr != nil {
			// If this error, it will always error
			panic(detailErr)
		}
		return nil, detailed.Err()
	}

	var override *proto.NetworkOverride
	if driverNet != nil {
		ports := map[string]int32{}
		for label, port := range driverNet.PortMap {
			if port > math.MaxInt32 {
				return nil, fmt.Errorf("port map out of bounds")
			}
			ports[label] = int32(port)
		}

		override = &proto.NetworkOverride{
			PortMap:       ports,
			Addr:          driverNet.IP,
			AutoAdvertise: driverNet.AutoAdvertise,
		}
	}

	return &proto.StartTaskResponse{
		Handle:          taskHandleToProto(handle),
		NetworkOverride: override,
	}, nil
}
// WaitTask handles the WaitTask RPC. It blocks until the wrapped driver
// delivers an exit result for the task or ctx is done.
//
// If ctx finishes first, ctx.Err() is returned. If the driver closes its
// channel without delivering a result, the response carries the error string
// "channel closed" and no result. Otherwise the response carries the exit
// result and, when the result holds an error, that error's message.
func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) {
	exitCh, err := b.impl.WaitTask(ctx, req.TaskId)
	if err != nil {
		return nil, err
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()

	case exit, open := <-exitCh:
		if !open {
			return &proto.WaitTaskResponse{Err: "channel closed"}, nil
		}

		resp := &proto.WaitTaskResponse{
			Result: &proto.ExitResult{
				ExitCode:  int32(exit.ExitCode),
				Signal:    int32(exit.Signal),
				OomKilled: exit.OOMKilled,
			},
		}
		if exit.Err != nil {
			resp.Err = exit.Err.Error()
		}
		return resp, nil
	}
}
// StopTask handles the StopTask RPC. It converts the protobuf timeout into a
// duration and asks the wrapped driver to stop the task with the requested
// signal. An invalid timeout or a driver error is returned unchanged.
func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) {
	stopTimeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	if err := b.impl.StopTask(req.TaskId, stopTimeout, req.Signal); err != nil {
		return nil, err
	}

	return &proto.StopTaskResponse{}, nil
}
// DestroyTask handles the DestroyTask RPC by forwarding the task ID and force
// flag to the wrapped driver. A driver error is passed through unchanged.
func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) {
	if err := b.impl.DestroyTask(req.TaskId, req.Force); err != nil {
		return nil, err
	}

	return &proto.DestroyTaskResponse{}, nil
}
// InspectTask handles the InspectTask RPC. It fetches the task's status from
// the wrapped driver and returns it in protobuf form together with the
// driver attributes and, when present, the network override.
//
// Errors from the driver or from converting the status are returned as-is.
// A port value above math.MaxInt32 is rejected with an error, matching
// StartTask, because it cannot be represented in the protobuf int32 field.
func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) {
	// Named taskStatus rather than status so the imported grpc status
	// package is not shadowed inside this function.
	taskStatus, err := b.impl.InspectTask(req.TaskId)
	if err != nil {
		return nil, err
	}

	protoStatus, err := taskStatusToProto(taskStatus)
	if err != nil {
		return nil, err
	}

	var pbNet *proto.NetworkOverride
	if override := taskStatus.NetworkOverride; override != nil {
		pbNet = &proto.NetworkOverride{
			PortMap:       map[string]int32{},
			Addr:          override.IP,
			AutoAdvertise: override.AutoAdvertise,
		}
		for k, v := range override.PortMap {
			// Refuse to silently truncate a value that does not fit in
			// int32.
			if v > math.MaxInt32 {
				return nil, fmt.Errorf("port map out of bounds")
			}
			pbNet.PortMap[k] = int32(v)
		}
	}

	resp := &proto.InspectTaskResponse{
		Task: protoStatus,
		Driver: &proto.TaskDriverStatus{
			Attributes: taskStatus.DriverAttributes,
		},
		NetworkOverride: pbNet,
	}
	return resp, nil
}
// TaskStats handles the streaming TaskStats RPC. It asks the wrapped driver
// for resource usage at the requested collection interval and forwards each
// sample to the client until the driver closes its channel.
//
// A driver error implementing structs.Recoverable is returned as a
// FailedPrecondition gRPC status carrying a RecoverableError detail; any
// other driver error is returned as-is. An io.EOF from Send ends the stream
// with a nil error; any other Send error is returned.
func (b *driverPluginServer) TaskStats(req *proto.TaskStatsRequest, srv proto.Driver_TaskStatsServer) error {
	collectEvery, err := ptypes.Duration(req.CollectionInterval)
	if err != nil {
		return fmt.Errorf("failed to parse collection interval: %v", err)
	}

	usageCh, err := b.impl.TaskStats(srv.Context(), req.TaskId, collectEvery)
	if err != nil {
		rec, ok := err.(structs.Recoverable)
		if !ok {
			return err
		}

		detailed, detailErr := status.New(codes.FailedPrecondition, rec.Error()).
			WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
		if detailErr != nil {
			// If this error, it will always error
			panic(detailErr)
		}
		return detailed.Err()
	}

	for usage := range usageCh {
		encoded, encErr := TaskStatsToProto(usage)
		if encErr != nil {
			return fmt.Errorf("failed to encode task stats: %v", encErr)
		}

		sendErr := srv.Send(&proto.TaskStatsResponse{Stats: encoded})
		if sendErr == io.EOF {
			return nil
		}
		if sendErr != nil {
			return sendErr
		}
	}

	return nil
}
// ExecTask handles the ExecTask RPC. It converts the protobuf timeout into a
// duration, runs the command in the task through the wrapped driver, and
// returns the captured stdout, stderr and exit result. An invalid timeout or
// a driver error is returned unchanged.
func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) {
	execTimeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	execResult, err := b.impl.ExecTask(req.TaskId, req.Command, execTimeout)
	if err != nil {
		return nil, err
	}

	return &proto.ExecTaskResponse{
		Stdout: execResult.Stdout,
		Stderr: execResult.Stderr,
		Result: exitResultToProto(execResult.ExitResult),
	}, nil
}
// ExecTaskStreaming handles the bidirectional ExecTaskStreaming RPC.
//
// The first message on the stream must carry the Setup payload. If the
// wrapped driver implements ExecTaskStreamingRawDriver, the stream is handed
// to it directly. Otherwise the driver must implement
// ExecTaskStreamingDriver; the stream is adapted into exec options, the
// command is run, and a final message with Exited set and the exit result is
// sent once output copying has finished.
//
// It returns an error if the initial message cannot be received or lacks
// Setup, if the driver supports neither exec interface, if the exec or the
// output copy fails, if the stream context ends while waiting for the copy,
// or if the final message cannot be sent.
func (b *driverPluginServer) ExecTaskStreaming(server proto.Driver_ExecTaskStreamingServer) error {
	msg, err := server.Recv()
	if err != nil {
		return fmt.Errorf("failed to receive initial message: %v", err)
	}

	if msg.Setup == nil {
		return fmt.Errorf("first message should always be setup")
	}

	// Drivers that handle the raw stream themselves take precedence.
	if impl, ok := b.impl.(ExecTaskStreamingRawDriver); ok {
		return impl.ExecTaskStreamingRaw(server.Context(),
			msg.Setup.TaskId, msg.Setup.Command, msg.Setup.Tty,
			server)
	}

	d, ok := b.impl.(ExecTaskStreamingDriver)
	if !ok {
		return fmt.Errorf("driver does not support exec")
	}

	execOpts, errCh := StreamToExecOptions(server.Context(),
		msg.Setup.Command, msg.Setup.Tty,
		server)

	result, err := d.ExecTaskStreaming(server.Context(),
		msg.Setup.TaskId, execOpts)

	// Close the output writers whether or not the exec succeeded.
	execOpts.Stdout.Close()
	execOpts.Stderr.Close()

	if err != nil {
		return err
	}

	// wait for copy to be done
	select {
	case err = <-errCh:
	case <-server.Context().Done():
		err = fmt.Errorf("exec timed out: %v", server.Context().Err())
	}

	if err != nil {
		return err
	}

	// Surface a failure to deliver the exit result rather than discarding
	// it; without this the handler reports success even though the client
	// never received the final message.
	return server.Send(&ExecTaskStreamingResponseMsg{
		Exited: true,
		Result: exitResultToProto(result),
	})
}
// SignalTask handles the SignalTask RPC by forwarding the task ID and signal
// to the wrapped driver. A driver error is passed through unchanged.
func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) {
	if err := b.impl.SignalTask(req.TaskId, req.Signal); err != nil {
		return nil, err
	}

	return &proto.SignalTaskResponse{}, nil
}
// TaskEvents handles the streaming TaskEvents RPC. It forwards every event
// emitted by the wrapped driver to the client.
//
// The stream ends with a nil error when the stream context is done, when a
// nil event is received (which includes the driver closing its channel), or
// when Send reports io.EOF. A failure to start the event stream, to convert
// an event timestamp, or any other Send error is returned as-is.
func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error {
	ctx := srv.Context()

	ch, err := b.impl.TaskEvents(ctx)
	if err != nil {
		return err
	}

	for {
		select {
		case <-ctx.Done():
			// Stop as soon as the stream is cancelled, as Fingerprint
			// does, instead of relying on the driver to close its channel.
			return nil

		case event := <-ch:
			if event == nil {
				return nil
			}

			pbTimestamp, err := ptypes.TimestampProto(event.Timestamp)
			if err != nil {
				return err
			}

			pbEvent := &proto.DriverTaskEvent{
				TaskId:      event.TaskID,
				AllocId:     event.AllocID,
				TaskName:    event.TaskName,
				Timestamp:   pbTimestamp,
				Message:     event.Message,
				Annotations: event.Annotations,
			}

			if err = srv.Send(pbEvent); err == io.EOF {
				return nil
			} else if err != nil {
				return err
			}
		}
	}
}
// CreateNetwork handles the CreateNetwork RPC. It requires the wrapped driver
// to implement DriverNetworkManager and returns the isolation spec the driver
// produced along with its created flag. An error is returned if the driver
// does not implement the interface or if the driver call fails.
func (b *driverPluginServer) CreateNetwork(ctx context.Context, req *proto.CreateNetworkRequest) (*proto.CreateNetworkResponse, error) {
	manager, supported := b.impl.(DriverNetworkManager)
	if !supported {
		return nil, fmt.Errorf("CreateNetwork RPC not supported by driver")
	}

	allocID := req.GetAllocId()
	createReq := networkCreateRequestFromProto(req)

	isolation, created, err := manager.CreateNetwork(allocID, createReq)
	if err != nil {
		return nil, err
	}

	resp := &proto.CreateNetworkResponse{
		IsolationSpec: NetworkIsolationSpecToProto(isolation),
		Created:       created,
	}
	return resp, nil
}
// DestroyNetwork handles the DestroyNetwork RPC. It requires the wrapped
// driver to implement DriverNetworkManager and asks it to destroy the network
// described by the request's isolation spec. An error is returned if the
// driver does not implement the interface or if the driver call fails.
func (b *driverPluginServer) DestroyNetwork(ctx context.Context, req *proto.DestroyNetworkRequest) (*proto.DestroyNetworkResponse, error) {
	manager, supported := b.impl.(DriverNetworkManager)
	if !supported {
		return nil, fmt.Errorf("DestroyNetwork RPC not supported by driver")
	}

	isolation := NetworkIsolationSpecFromProto(req.IsolationSpec)
	if err := manager.DestroyNetwork(req.AllocId, isolation); err != nil {
		return nil, err
	}

	return &proto.DestroyNetworkResponse{}, nil
}