open-nomad/plugins/drivers/server.go

423 lines
10 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package drivers
import (
"context"
"fmt"
"io"
"math"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-plugin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/drivers/proto"
dstructs "github.com/hashicorp/nomad/plugins/shared/structs"
sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
)
// driverPluginServer is the gRPC-facing wrapper around a DriverPlugin. Each
// RPC method defined on it forwards the call to impl and converts the
// arguments and results between their proto and native forms.
type driverPluginServer struct {
// broker is the go-plugin gRPC broker handle. NOTE(review): none of the
// handlers visible in this part of the file read it.
broker *plugin.GRPCBroker
// impl is the driver implementation that every RPC handler delegates to.
impl DriverPlugin
}
// TaskConfigSchema serves the TaskConfigSchema RPC. It asks the wrapped driver
// for its task configuration spec and returns it in the response; a driver
// error is returned unchanged.
func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) {
	schema, err := b.impl.TaskConfigSchema()
	if err != nil {
		return nil, err
	}

	return &proto.TaskConfigSchemaResponse{Spec: schema}, nil
}
// Capabilities serves the Capabilities RPC by translating the capabilities
// reported by the wrapped driver into their proto representation. A driver
// error is returned unchanged.
func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) {
	driverCaps, err := b.impl.Capabilities()
	if err != nil {
		return nil, err
	}

	// FSIsolationNone and any unrecognised value are both reported as NONE, so
	// only the two non-default modes need an explicit case.
	fsIsolation := proto.DriverCapabilities_NONE
	switch driverCaps.FSIsolation {
	case FSIsolationChroot:
		fsIsolation = proto.DriverCapabilities_CHROOT
	case FSIsolationImage:
		fsIsolation = proto.DriverCapabilities_IMAGE
	}

	// Always a non-nil slice, even when the driver reports no modes.
	isolationModes := make([]proto.NetworkIsolationSpec_NetworkIsolationMode, 0, len(driverCaps.NetIsolationModes))
	for _, m := range driverCaps.NetIsolationModes {
		isolationModes = append(isolationModes, netIsolationModeToProto(m))
	}

	return &proto.CapabilitiesResponse{
		Capabilities: &proto.DriverCapabilities{
			SendSignals:           driverCaps.SendSignals,
			Exec:                  driverCaps.Exec,
			FsIsolation:           fsIsolation,
			MustCreateNetwork:     driverCaps.MustInitiateNetwork,
			NetworkIsolationModes: isolationModes,
			RemoteTasks:           driverCaps.RemoteTasks,
		},
	}, nil
}
// Fingerprint serves the streaming Fingerprint RPC. It subscribes to the
// wrapped driver's fingerprint channel using the stream's context and relays
// every fingerprint to the client. The stream ends with a nil error when the
// context is done or the driver closes its channel, and with the send error
// if a message cannot be delivered.
func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error {
	ctx := srv.Context()

	updates, err := b.impl.Fingerprint(ctx)
	if err != nil {
		return err
	}

	done := ctx.Done()
	for {
		select {
		case <-done:
			return nil

		case fp, open := <-updates:
			if !open {
				return nil
			}

			msg := &proto.FingerprintResponse{
				Attributes:        dstructs.ConvertStructAttributeMap(fp.Attributes),
				Health:            healthStateToProto(fp.Health),
				HealthDescription: fp.HealthDescription,
			}
			if sendErr := srv.Send(msg); sendErr != nil {
				return sendErr
			}
		}
	}
}
// RecoverTask serves the RecoverTask RPC. The proto task handle from the
// request is converted and passed to the wrapped driver; a driver error is
// returned unchanged, otherwise an empty response is returned.
func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) {
	if err := b.impl.RecoverTask(taskHandleFromProto(req.Handle)); err != nil {
		return nil, err
	}

	return &proto.RecoverTaskResponse{}, nil
}
// StartTask serves the StartTask RPC. The proto task config is converted and
// handed to the wrapped driver, and the resulting task handle and optional
// network override are returned in proto form.
//
// A driver error that implements structs.Recoverable is returned as a gRPC
// FailedPrecondition status carrying a RecoverableError detail; any other
// driver error is returned unchanged. A port map entry larger than
// math.MaxInt32 yields a "port map out of bounds" error.
func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) {
	handle, driverNet, err := b.impl.StartTask(taskConfigFromProto(req.Task))
	if err != nil {
		rec, recoverable := err.(structs.Recoverable)
		if !recoverable {
			return nil, err
		}

		base := status.New(codes.FailedPrecondition, rec.Error())
		detailed, detailErr := base.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
		if detailErr != nil {
			// A failure here would happen on every call, so surface it loudly.
			panic(detailErr)
		}
		return nil, detailed.Err()
	}

	// The override stays nil when the driver did not return network details.
	var override *proto.NetworkOverride
	if driverNet != nil {
		ports := map[string]int32{}
		override = &proto.NetworkOverride{
			PortMap:       ports,
			Addr:          driverNet.IP,
			AutoAdvertise: driverNet.AutoAdvertise,
		}
		for label, port := range driverNet.PortMap {
			if port > math.MaxInt32 {
				return nil, fmt.Errorf("port map out of bounds")
			}
			ports[label] = int32(port)
		}
	}

	return &proto.StartTaskResponse{
		Handle:          taskHandleToProto(handle),
		NetworkOverride: override,
	}, nil
}
// WaitTask serves the WaitTask RPC. It asks the wrapped driver for the task's
// wait channel and blocks until either a result arrives or ctx is done.
//
// When ctx finishes first, ctx.Err() is returned. When the driver closes the
// channel without sending a result, the response carries the error string
// "channel closed". Otherwise the exit result is returned in proto form, with
// the result's error (if any) rendered into the response's Err field.
func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) {
	waitCh, err := b.impl.WaitTask(ctx, req.TaskId)
	if err != nil {
		return nil, err
	}

	var exit *ExitResult
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case received, open := <-waitCh:
		if !open {
			return &proto.WaitTaskResponse{Err: "channel closed"}, nil
		}
		exit = received
	}

	msg := ""
	if exit.Err != nil {
		msg = exit.Err.Error()
	}

	return &proto.WaitTaskResponse{
		Err: msg,
		Result: &proto.ExitResult{
			ExitCode:  int32(exit.ExitCode),
			Signal:    int32(exit.Signal),
			OomKilled: exit.OOMKilled,
		},
	}, nil
}
// StopTask serves the StopTask RPC. The proto timeout is converted to a
// time.Duration and passed to the wrapped driver together with the task ID
// and signal. A conversion or driver error is returned unchanged.
func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) {
	timeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	if stopErr := b.impl.StopTask(req.TaskId, timeout, req.Signal); stopErr != nil {
		return nil, stopErr
	}

	return &proto.StopTaskResponse{}, nil
}
// DestroyTask serves the DestroyTask RPC by forwarding the task ID and force
// flag to the wrapped driver. A driver error is returned unchanged.
func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) {
	if err := b.impl.DestroyTask(req.TaskId, req.Force); err != nil {
		return nil, err
	}

	return &proto.DestroyTaskResponse{}, nil
}
// InspectTask serves the InspectTask RPC. It fetches the task status from the
// wrapped driver and returns it in proto form together with the driver
// attributes and, when the status carries one, the network override.
//
// Driver and status-conversion errors are returned unchanged. A port map
// entry larger than math.MaxInt32 yields a "port map out of bounds" error,
// matching StartTask, instead of being silently wrapped by the int32
// conversion.
func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) {
	// Named taskStatus so the imported grpc status package is not shadowed.
	taskStatus, err := b.impl.InspectTask(req.TaskId)
	if err != nil {
		return nil, err
	}

	protoStatus, err := taskStatusToProto(taskStatus)
	if err != nil {
		return nil, err
	}

	// The override stays nil when the status has no network details.
	var pbNet *proto.NetworkOverride
	if taskStatus.NetworkOverride != nil {
		pbNet = &proto.NetworkOverride{
			PortMap:       map[string]int32{},
			Addr:          taskStatus.NetworkOverride.IP,
			AutoAdvertise: taskStatus.NetworkOverride.AutoAdvertise,
		}
		for k, v := range taskStatus.NetworkOverride.PortMap {
			// Reject values that do not fit in the proto's int32 field.
			if v > math.MaxInt32 {
				return nil, fmt.Errorf("port map out of bounds")
			}
			pbNet.PortMap[k] = int32(v)
		}
	}

	resp := &proto.InspectTaskResponse{
		Task: protoStatus,
		Driver: &proto.TaskDriverStatus{
			Attributes: taskStatus.DriverAttributes,
		},
		NetworkOverride: pbNet,
	}

	return resp, nil
}
// TaskStats serves the streaming TaskStats RPC. The requested collection
// interval is converted to a time.Duration, the wrapped driver is asked for a
// stats channel bound to the stream's context, and every value received is
// encoded and sent to the client until the channel is closed.
//
// A driver error that implements structs.Recoverable is returned as a gRPC
// FailedPrecondition status carrying a RecoverableError detail; any other
// driver error is returned unchanged. An io.EOF from Send ends the stream
// with a nil error; any other send error is returned.
func (b *driverPluginServer) TaskStats(req *proto.TaskStatsRequest, srv proto.Driver_TaskStatsServer) error {
	interval, err := ptypes.Duration(req.CollectionInterval)
	if err != nil {
		return fmt.Errorf("failed to parse collection interval: %v", err)
	}

	statsCh, err := b.impl.TaskStats(srv.Context(), req.TaskId, interval)
	if err != nil {
		rec, recoverable := err.(structs.Recoverable)
		if !recoverable {
			return err
		}

		base := status.New(codes.FailedPrecondition, rec.Error())
		detailed, detailErr := base.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
		if detailErr != nil {
			// A failure here would happen on every call, so surface it loudly.
			panic(detailErr)
		}
		return detailed.Err()
	}

	for usage := range statsCh {
		encoded, encErr := TaskStatsToProto(usage)
		if encErr != nil {
			return fmt.Errorf("failed to encode task stats: %v", encErr)
		}

		switch sendErr := srv.Send(&proto.TaskStatsResponse{Stats: encoded}); sendErr {
		case nil:
			// delivered; keep streaming
		case io.EOF:
			return nil
		default:
			return sendErr
		}
	}

	return nil
}
// ExecTask serves the ExecTask RPC. The proto timeout is converted to a
// time.Duration, the command is run through the wrapped driver, and the
// captured stdout, stderr and exit result are returned in proto form. A
// conversion or driver error is returned unchanged.
func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) {
	timeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	out, err := b.impl.ExecTask(req.TaskId, req.Command, timeout)
	if err != nil {
		return nil, err
	}

	return &proto.ExecTaskResponse{
		Stdout: out.Stdout,
		Stderr: out.Stderr,
		Result: exitResultToProto(out.ExitResult),
	}, nil
}
// ExecTaskStreaming serves the bidirectional ExecTaskStreaming RPC.
//
// The first message on the stream must carry the Setup payload. If the
// wrapped driver implements ExecTaskStreamingRawDriver, the stream is handed
// to it directly. Otherwise the driver must implement
// ExecTaskStreamingDriver: the stream is adapted into exec options, the
// command is run, and a final message with Exited set and the exit result is
// sent to the client.
//
// Errors are returned when the initial message cannot be received or has no
// Setup, when the driver supports neither exec interface, when the driver's
// exec call fails, when the stream copy reports an error or the stream
// context ends first, and when the final message cannot be sent.
func (b *driverPluginServer) ExecTaskStreaming(server proto.Driver_ExecTaskStreamingServer) error {
	msg, err := server.Recv()
	if err != nil {
		return fmt.Errorf("failed to receive initial message: %v", err)
	}

	if msg.Setup == nil {
		return fmt.Errorf("first message should always be setup")
	}

	// Drivers that handle the raw stream themselves take over entirely.
	if impl, ok := b.impl.(ExecTaskStreamingRawDriver); ok {
		return impl.ExecTaskStreamingRaw(server.Context(),
			msg.Setup.TaskId, msg.Setup.Command, msg.Setup.Tty,
			server)
	}

	d, ok := b.impl.(ExecTaskStreamingDriver)
	if !ok {
		return fmt.Errorf("driver does not support exec")
	}

	execOpts, errCh := StreamToExecOptions(server.Context(),
		msg.Setup.Command, msg.Setup.Tty,
		server)

	result, err := d.ExecTaskStreaming(server.Context(),
		msg.Setup.TaskId, execOpts)

	// The output writers are closed as soon as the driver call returns,
	// whether or not it failed.
	execOpts.Stdout.Close()
	execOpts.Stderr.Close()
	if err != nil {
		return err
	}

	// Wait for the copy to be done, or for the stream context to end.
	select {
	case err = <-errCh:
	case <-server.Context().Done():
		err = fmt.Errorf("exec timed out: %v", server.Context().Err())
	}
	if err != nil {
		return err
	}

	// Report a failure to deliver the final exit message instead of dropping it.
	return server.Send(&ExecTaskStreamingResponseMsg{
		Exited: true,
		Result: exitResultToProto(result),
	})
}
// SignalTask serves the SignalTask RPC by forwarding the task ID and signal
// to the wrapped driver. A driver error is returned unchanged.
func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) {
	if err := b.impl.SignalTask(req.TaskId, req.Signal); err != nil {
		return nil, err
	}

	return &proto.SignalTaskResponse{}, nil
}
// TaskEvents serves the streaming TaskEvents RPC. It subscribes to the
// wrapped driver's event channel using the stream's context and relays each
// event to the client in proto form.
//
// The stream ends with a nil error when the channel is closed, when a nil
// event is received, or when Send reports io.EOF. A timestamp conversion
// error or any other send error is returned.
func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error {
	events, err := b.impl.TaskEvents(srv.Context())
	if err != nil {
		return err
	}

	for ev := range events {
		// A nil event ends the stream just like a closed channel does.
		if ev == nil {
			break
		}

		ts, tsErr := ptypes.TimestampProto(ev.Timestamp)
		if tsErr != nil {
			return tsErr
		}

		sendErr := srv.Send(&proto.DriverTaskEvent{
			TaskId:      ev.TaskID,
			AllocId:     ev.AllocID,
			TaskName:    ev.TaskName,
			Timestamp:   ts,
			Message:     ev.Message,
			Annotations: ev.Annotations,
		})
		if sendErr == io.EOF {
			break
		}
		if sendErr != nil {
			return sendErr
		}
	}

	return nil
}
// CreateNetwork serves the CreateNetwork RPC. It requires the wrapped driver
// to implement DriverNetworkManager and returns an error if it does not.
// Otherwise the allocation ID and converted request are passed to the driver,
// and the resulting isolation spec and created flag are returned in proto
// form. A driver error is returned unchanged.
func (b *driverPluginServer) CreateNetwork(ctx context.Context, req *proto.CreateNetworkRequest) (*proto.CreateNetworkResponse, error) {
	manager, supported := b.impl.(DriverNetworkManager)
	if !supported {
		return nil, fmt.Errorf("CreateNetwork RPC not supported by driver")
	}

	isolation, wasCreated, err := manager.CreateNetwork(req.GetAllocId(), networkCreateRequestFromProto(req))
	if err != nil {
		return nil, err
	}

	resp := &proto.CreateNetworkResponse{
		IsolationSpec: NetworkIsolationSpecToProto(isolation),
		Created:       wasCreated,
	}
	return resp, nil
}
// DestroyNetwork serves the DestroyNetwork RPC. It requires the wrapped
// driver to implement DriverNetworkManager and returns an error if it does
// not. Otherwise the allocation ID and converted isolation spec are passed to
// the driver; a driver error is returned unchanged.
func (b *driverPluginServer) DestroyNetwork(ctx context.Context, req *proto.DestroyNetworkRequest) (*proto.DestroyNetworkResponse, error) {
	manager, supported := b.impl.(DriverNetworkManager)
	if !supported {
		return nil, fmt.Errorf("DestroyNetwork RPC not supported by driver")
	}

	if err := manager.DestroyNetwork(req.AllocId, NetworkIsolationSpecFromProto(req.IsolationSpec)); err != nil {
		return nil, err
	}

	return &proto.DestroyNetworkResponse{}, nil
}