open-nomad/plugins/drivers/server.go
Alex Dadgar 693f244cce Plugin client's handle plugin dying
This PR plumbs the plugins done ctx through the base and driver plugin
clients (device already had it). Further, it adds generic handling of
gRPC stream errors.
2018-11-12 17:09:27 -08:00

291 lines
6.7 KiB
Go

package drivers
import (
"fmt"
"io"
"github.com/golang/protobuf/ptypes"
hclog "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/plugins/drivers/proto"
context "golang.org/x/net/context"
)
// driverPluginServer is the receiver for the driver plugin's gRPC handlers.
// Each handler translates the protobuf request, delegates to impl, and
// converts the result back into a protobuf response.
type driverPluginServer struct {
// broker is the go-plugin gRPC broker. It is not referenced by the handlers
// shown in this file.
broker *plugin.GRPCBroker
// impl is the driver implementation that every RPC handler delegates to.
impl DriverPlugin
// logger is an hclog logger. It is not referenced by the handlers shown in
// this file.
logger hclog.Logger
}
// TaskConfigSchema asks the wrapped driver for its task configuration schema
// and returns it as the Spec of the response. An error from the driver is
// returned unchanged with a nil response.
func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) {
	schema, err := b.impl.TaskConfigSchema()
	if err != nil {
		return nil, err
	}

	return &proto.TaskConfigSchemaResponse{Spec: schema}, nil
}
// Capabilities queries the wrapped driver for its capabilities and converts
// them to their protobuf form. An error from the driver is returned unchanged
// with a nil response.
func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) {
	driverCaps, err := b.impl.Capabilities()
	if err != nil {
		return nil, err
	}

	// Start from NONE: both FSIsolationNone and any unrecognised isolation
	// mode are reported as NONE, so only the other modes need a case.
	isolation := proto.DriverCapabilities_NONE
	switch driverCaps.FSIsolation {
	case cstructs.FSIsolationChroot:
		isolation = proto.DriverCapabilities_CHROOT
	case cstructs.FSIsolationImage:
		isolation = proto.DriverCapabilities_IMAGE
	}

	return &proto.CapabilitiesResponse{
		Capabilities: &proto.DriverCapabilities{
			SendSignals: driverCaps.SendSignals,
			Exec:        driverCaps.Exec,
			FsIsolation: isolation,
		},
	}, nil
}
// Fingerprint streams the wrapped driver's fingerprints to the client. It
// starts the driver's fingerprint channel using the stream's context and
// forwards every value received on it.
//
// The stream ends with a nil error when the stream context is done or the
// driver closes its channel. An error from starting the fingerprinter or from
// sending on the stream is returned as-is.
func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error {
	streamCtx := srv.Context()

	fingerprints, err := b.impl.Fingerprint(streamCtx)
	if err != nil {
		return err
	}

	for {
		select {
		case <-streamCtx.Done():
			return nil

		case fp, open := <-fingerprints:
			if !open {
				// The driver closed its channel; nothing more to stream.
				return nil
			}

			msg := &proto.FingerprintResponse{
				Attributes:        fp.Attributes,
				Health:            healthStateToProto(fp.Health),
				HealthDescription: fp.HealthDescription,
			}
			if sendErr := srv.Send(msg); sendErr != nil {
				return sendErr
			}
		}
	}
}
// RecoverTask converts the task handle in the request from its protobuf form
// and passes it to the wrapped driver's RecoverTask. It returns an empty
// response on success, or the driver's error with a nil response.
func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) {
	if err := b.impl.RecoverTask(taskHandleFromProto(req.Handle)); err != nil {
		return nil, err
	}

	return &proto.RecoverTaskResponse{}, nil
}
// StartTask converts the task configuration in the request from its protobuf
// form and starts it through the wrapped driver. The response carries the
// resulting task handle and, when the driver returned one, the network
// override converted to protobuf. An error from the driver is returned
// unchanged with a nil response.
func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) {
	handle, override, err := b.impl.StartTask(taskConfigFromProto(req.Task))
	if err != nil {
		return nil, err
	}

	resp := &proto.StartTaskResponse{}

	// The network override is optional; the field stays nil when the driver
	// did not supply one.
	if override != nil {
		ports := make(map[string]int32, len(override.PortMap))
		for label, port := range override.PortMap {
			ports[label] = int32(port)
		}
		resp.NetworkOverride = &proto.NetworkOverride{
			PortMap:       ports,
			Addr:          override.IP,
			AutoAdvertise: override.AutoAdvertise,
		}
	}

	resp.Handle = taskHandleToProto(handle)
	return resp, nil
}
// WaitTask blocks until the wrapped driver reports an exit result for the
// requested task, or until ctx is done.
//
// Outcomes:
//   - ctx done first: nil response and ctx.Err().
//   - the driver's channel is closed without a result: a response whose Err
//     is "channel closed" and a nil error.
//   - a result arrives: a response carrying the exit code, signal and OOM
//     flag, with Err set to the result's error message if it has one.
//
// An error from starting the wait is returned unchanged with a nil response.
func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) {
	waitCh, err := b.impl.WaitTask(ctx, req.TaskId)
	if err != nil {
		return nil, err
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()

	case exit, open := <-waitCh:
		if !open {
			return &proto.WaitTaskResponse{Err: "channel closed"}, nil
		}

		// The error travels over the wire as a plain string; empty means
		// the result carried no error.
		var msg string
		if exit.Err != nil {
			msg = exit.Err.Error()
		}

		return &proto.WaitTaskResponse{
			Err: msg,
			Result: &proto.ExitResult{
				ExitCode:  int32(exit.ExitCode),
				Signal:    int32(exit.Signal),
				OomKilled: exit.OOMKilled,
			},
		}, nil
	}
}
// StopTask decodes the protobuf timeout from the request and asks the wrapped
// driver to stop the task with that timeout and the requested signal. It
// returns an empty response on success; a timeout that fails to decode or a
// driver error is returned with a nil response.
func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) {
	killTimeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	if stopErr := b.impl.StopTask(req.TaskId, killTimeout, req.Signal); stopErr != nil {
		return nil, stopErr
	}

	return &proto.StopTaskResponse{}, nil
}
// DestroyTask forwards the task ID and force flag from the request to the
// wrapped driver's DestroyTask. It returns an empty response on success, or
// the driver's error with a nil response.
func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) {
	if err := b.impl.DestroyTask(req.TaskId, req.Force); err != nil {
		return nil, err
	}

	return &proto.DestroyTaskResponse{}, nil
}
// InspectTask fetches the status of the requested task from the wrapped
// driver and converts it to protobuf. The response carries the converted task
// status, the driver attributes and, when present on the status, the network
// override. An error from the driver or from the status conversion is
// returned unchanged with a nil response.
func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) {
	status, err := b.impl.InspectTask(req.TaskId)
	if err != nil {
		return nil, err
	}

	pbStatus, err := taskStatusToProto(status)
	if err != nil {
		return nil, err
	}

	resp := &proto.InspectTaskResponse{
		Task: pbStatus,
		Driver: &proto.TaskDriverStatus{
			Attributes: status.DriverAttributes,
		},
	}

	// The network override is optional; the field stays nil when the status
	// does not carry one.
	if override := status.NetworkOverride; override != nil {
		ports := make(map[string]int32, len(override.PortMap))
		for label, port := range override.PortMap {
			ports[label] = int32(port)
		}
		resp.NetworkOverride = &proto.NetworkOverride{
			PortMap:       ports,
			Addr:          override.IP,
			AutoAdvertise: override.AutoAdvertise,
		}
	}

	return resp, nil
}
// TaskStats fetches the stats for the requested task from the wrapped driver
// and returns them converted to protobuf. A driver error is returned
// unchanged; a conversion failure is returned as a "failed to encode task
// stats" error. In both cases the response is nil.
func (b *driverPluginServer) TaskStats(ctx context.Context, req *proto.TaskStatsRequest) (*proto.TaskStatsResponse, error) {
	usage, err := b.impl.TaskStats(req.TaskId)
	if err != nil {
		return nil, err
	}

	encoded, err := taskStatsToProto(usage)
	if err != nil {
		return nil, fmt.Errorf("failed to encode task stats: %v", err)
	}

	return &proto.TaskStatsResponse{Stats: encoded}, nil
}
// ExecTask decodes the protobuf timeout from the request and runs the
// requested command in the task through the wrapped driver. The response
// carries the command's stdout, stderr and exit result converted to protobuf.
// A timeout that fails to decode or a driver error is returned with a nil
// response.
func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) {
	execTimeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	out, err := b.impl.ExecTask(req.TaskId, req.Command, execTimeout)
	if err != nil {
		return nil, err
	}

	return &proto.ExecTaskResponse{
		Stdout: out.Stdout,
		Stderr: out.Stderr,
		Result: exitResultToProto(out.ExitResult),
	}, nil
}
// SignalTask forwards the task ID and signal from the request to the wrapped
// driver's SignalTask. It returns an empty response on success, or the
// driver's error with a nil response.
func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) {
	if err := b.impl.SignalTask(req.TaskId, req.Signal); err != nil {
		return nil, err
	}

	return &proto.SignalTaskResponse{}, nil
}
// TaskEvents streams the wrapped driver's task events to the client. It
// starts the driver's event channel using the stream's context and forwards
// every event received on it, converted to protobuf.
//
// The stream ends with a nil error when:
//   - the stream's context is done (checked so the handler does not block
//     forever on the channel if the driver never closes it),
//   - a nil event is received (which is also what a closed channel yields), or
//   - sending on the stream reports io.EOF.
//
// An error from starting the event stream, from converting an event's
// timestamp, or any other send error is returned as-is.
func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error {
	ctx := srv.Context()

	ch, err := b.impl.TaskEvents(ctx)
	if err != nil {
		return err
	}

	for {
		select {
		case <-ctx.Done():
			// Mirror Fingerprint: a finished stream context ends the
			// handler cleanly instead of leaving it parked on <-ch.
			return nil

		case event := <-ch:
			if event == nil {
				return nil
			}

			pbTimestamp, err := ptypes.TimestampProto(event.Timestamp)
			if err != nil {
				return err
			}

			pbEvent := &proto.DriverTaskEvent{
				TaskId:      event.TaskID,
				Timestamp:   pbTimestamp,
				Message:     event.Message,
				Annotations: event.Annotations,
			}

			if err := srv.Send(pbEvent); err != nil {
				// io.EOF on send is treated as a normal end of stream.
				if err == io.EOF {
					return nil
				}
				return err
			}
		}
	}
}