Merge pull request #4868 from hashicorp/b-plugin-ctx
Plugin client's handle plugin dying
This commit is contained in:
commit
17e8446484
|
@ -1,6 +1,7 @@
|
|||
package exec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -21,7 +22,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/shared"
|
||||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/hashicorp/nomad/plugins/shared/loader"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package java
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
@ -23,7 +24,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/shared"
|
||||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/hashicorp/nomad/plugins/shared/loader"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/drivers"
|
||||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/hashicorp/nomad/plugins/shared/loader"
|
||||
netctx "golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -232,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) {
|
|||
return capabilities, nil
|
||||
}
|
||||
|
||||
func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) {
|
||||
func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) {
|
||||
ch := make(chan *drivers.Fingerprint)
|
||||
go d.handleFingerprint(ctx, ch)
|
||||
return ch, nil
|
||||
|
@ -365,7 +364,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
|
|||
|
||||
}
|
||||
|
||||
func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) {
|
||||
func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) {
|
||||
handle, ok := d.tasks.Get(taskID)
|
||||
if !ok {
|
||||
return nil, drivers.ErrTaskNotFound
|
||||
|
@ -430,7 +429,7 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) {
|
||||
func (d *Driver) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) {
|
||||
return d.eventer.TaskEvents(ctx)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package qemu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -25,7 +26,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/shared"
|
||||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/hashicorp/nomad/plugins/shared/loader"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package rawexec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -22,7 +23,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/shared"
|
||||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/hashicorp/nomad/plugins/shared/loader"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -4,6 +4,7 @@ package rkt
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
@ -36,7 +37,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/hashicorp/nomad/plugins/shared/loader"
|
||||
rktv1 "github.com/rkt/rkt/api/v1"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -3,17 +3,16 @@
|
|||
package rkt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"os"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/hashicorp/hcl2/hcl"
|
||||
ctestutil "github.com/hashicorp/nomad/client/testutil"
|
||||
"github.com/hashicorp/nomad/helper/testlog"
|
||||
|
@ -26,7 +25,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/hashicorp/nomad/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var _ drivers.DriverPlugin = (*Driver)(nil)
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
package eventer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/nomad/plugins/drivers"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -12,10 +12,13 @@ import (
|
|||
// gRPC to communicate to the remote plugin.
|
||||
type BasePluginClient struct {
|
||||
Client proto.BasePluginClient
|
||||
|
||||
// DoneCtx is closed when the plugin exits
|
||||
DoneCtx context.Context
|
||||
}
|
||||
|
||||
func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
|
||||
presp, err := b.Client.PluginInfo(context.Background(), &proto.PluginInfoRequest{})
|
||||
presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -41,7 +44,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
|
|||
}
|
||||
|
||||
func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
|
||||
presp, err := b.Client.ConfigSchema(context.Background(), &proto.ConfigSchemaRequest{})
|
||||
presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -51,7 +54,7 @@ func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
|
|||
|
||||
func (b *BasePluginClient) SetConfig(data []byte, config *ClientAgentConfig) error {
|
||||
// Send the config
|
||||
_, err := b.Client.SetConfig(context.Background(), &proto.SetConfigRequest{
|
||||
_, err := b.Client.SetConfig(b.DoneCtx, &proto.SetConfigRequest{
|
||||
MsgpackConfig: data,
|
||||
NomadConfig: config.toProto(),
|
||||
})
|
||||
|
|
|
@ -51,7 +51,10 @@ func (p *PluginBase) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error
|
|||
}
|
||||
|
||||
func (p *PluginBase) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
|
||||
return &BasePluginClient{Client: proto.NewBasePluginClient(c)}, nil
|
||||
return &BasePluginClient{
|
||||
Client: proto.NewBasePluginClient(c),
|
||||
DoneCtx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MsgpackHandle is a shared handle for encoding/decoding of structs
|
||||
|
|
|
@ -9,9 +9,7 @@ import (
|
|||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/hashicorp/nomad/plugins/base"
|
||||
"github.com/hashicorp/nomad/plugins/device/proto"
|
||||
netctx "golang.org/x/net/context"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/hashicorp/nomad/plugins/shared"
|
||||
)
|
||||
|
||||
// devicePluginClient implements the client side of a remote device plugin, using
|
||||
|
@ -49,28 +47,33 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
|
|||
// the gRPC stream to a channel. Exits either when context is cancelled or the
|
||||
// stream has an error.
|
||||
func (d *devicePluginClient) handleFingerprint(
|
||||
ctx netctx.Context,
|
||||
ctx context.Context,
|
||||
stream proto.DevicePlugin_FingerprintClient,
|
||||
out chan *FingerprintResponse) {
|
||||
|
||||
defer close(out)
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
out <- &FingerprintResponse{
|
||||
Error: d.handleStreamErr(err, ctx),
|
||||
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
|
||||
}
|
||||
}
|
||||
|
||||
// End the stream
|
||||
close(out)
|
||||
return
|
||||
}
|
||||
|
||||
// Send the response
|
||||
out <- &FingerprintResponse{
|
||||
f := &FingerprintResponse{
|
||||
Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()),
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- f:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,69 +119,32 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration)
|
|||
// the gRPC stream to a channel. Exits either when context is cancelled or the
|
||||
// stream has an error.
|
||||
func (d *devicePluginClient) handleStats(
|
||||
ctx netctx.Context,
|
||||
ctx context.Context,
|
||||
stream proto.DevicePlugin_StatsClient,
|
||||
out chan *StatsResponse) {
|
||||
|
||||
defer close(out)
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
out <- &StatsResponse{
|
||||
Error: d.handleStreamErr(err, ctx),
|
||||
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
|
||||
}
|
||||
}
|
||||
|
||||
// End the stream
|
||||
close(out)
|
||||
return
|
||||
}
|
||||
|
||||
// Send the response
|
||||
out <- &StatsResponse{
|
||||
s := &StatsResponse{
|
||||
Groups: convertProtoDeviceGroupsStats(resp.GetGroups()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamErr is used to handle a non io.EOF error in a stream. It handles
|
||||
// detecting if the plugin has shutdown
|
||||
func (d *devicePluginClient) handleStreamErr(err error, ctx context.Context) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determine if the error is because the plugin shutdown
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
|
||||
// Potentially wait a little before returning an error so we can detect
|
||||
// the exit
|
||||
select {
|
||||
case <-d.doneCtx.Done():
|
||||
err = base.ErrPluginShutdown
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
|
||||
// There is no guarantee that the select will choose the
|
||||
// doneCtx first so we have to double check
|
||||
select {
|
||||
case <-d.doneCtx.Done():
|
||||
err = base.ErrPluginShutdown
|
||||
default:
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
// Its okay to wait a while since the connection isn't available and
|
||||
// on local host it is likely shutting down. It is not expected for
|
||||
// this to ever reach even close to 3 seconds.
|
||||
return
|
||||
case out <- s:
|
||||
}
|
||||
|
||||
// It is an error we don't know how to handle, so return it
|
||||
return err
|
||||
}
|
||||
|
||||
// Context was cancelled
|
||||
if errStatus := status.FromContextError(ctx.Err()); errStatus.Code() == codes.Canceled {
|
||||
return context.Canceled
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -31,7 +31,8 @@ func (p *PluginDevice) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker
|
|||
doneCtx: ctx,
|
||||
client: proto.NewDevicePluginClient(c),
|
||||
BasePluginClient: &base.BasePluginClient{
|
||||
Client: bproto.NewBasePluginClient(c),
|
||||
Client: bproto.NewBasePluginClient(c),
|
||||
DoneCtx: ctx,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
package drivers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/LK4D4/joincontext"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
cstructs "github.com/hashicorp/nomad/client/structs"
|
||||
"github.com/hashicorp/nomad/plugins/base"
|
||||
"github.com/hashicorp/nomad/plugins/drivers/proto"
|
||||
"github.com/hashicorp/nomad/plugins/shared"
|
||||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var _ DriverPlugin = &driverPluginClient{}
|
||||
|
@ -22,12 +23,15 @@ type driverPluginClient struct {
|
|||
|
||||
client proto.DriverClient
|
||||
logger hclog.Logger
|
||||
|
||||
// doneCtx is closed when the plugin exits
|
||||
doneCtx context.Context
|
||||
}
|
||||
|
||||
func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
|
||||
req := &proto.TaskConfigSchemaRequest{}
|
||||
|
||||
resp, err := d.client.TaskConfigSchema(context.Background(), req)
|
||||
resp, err := d.client.TaskConfigSchema(d.doneCtx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -38,7 +42,7 @@ func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
|
|||
func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
|
||||
req := &proto.CapabilitiesRequest{}
|
||||
|
||||
resp, err := d.client.Capabilities(context.Background(), req)
|
||||
resp, err := d.client.Capabilities(d.doneCtx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -67,12 +71,15 @@ func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
|
|||
func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerprint, error) {
|
||||
req := &proto.FingerprintRequest{}
|
||||
|
||||
// Join the passed context and the shutdown context
|
||||
ctx, _ = joincontext.Join(ctx, d.doneCtx)
|
||||
|
||||
stream, err := d.client.Fingerprint(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := make(chan *Fingerprint)
|
||||
ch := make(chan *Fingerprint, 1)
|
||||
go d.handleFingerprint(ctx, ch, stream)
|
||||
|
||||
return ch, nil
|
||||
|
@ -82,17 +89,18 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
|
|||
defer close(ch)
|
||||
for {
|
||||
pb, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case ch <- &Fingerprint{Err: fmt.Errorf("error from RPC stream: %v", err)}:
|
||||
if err != io.EOF {
|
||||
d.logger.Error("error receiving stream from Fingerprint driver RPC", "error", err)
|
||||
ch <- &Fingerprint{
|
||||
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
|
||||
}
|
||||
}
|
||||
|
||||
// End the stream
|
||||
return
|
||||
}
|
||||
|
||||
f := &Fingerprint{
|
||||
Attributes: pb.Attributes,
|
||||
Health: healthStateFromProto(pb.Health),
|
||||
|
@ -112,7 +120,7 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
|
|||
func (d *driverPluginClient) RecoverTask(h *TaskHandle) error {
|
||||
req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)}
|
||||
|
||||
_, err := d.client.RecoverTask(context.Background(), req)
|
||||
_, err := d.client.RecoverTask(d.doneCtx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -124,7 +132,7 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
|
|||
Task: taskConfigToProto(c),
|
||||
}
|
||||
|
||||
resp, err := d.client.StartTask(context.Background(), req)
|
||||
resp, err := d.client.StartTask(d.doneCtx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -150,6 +158,10 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
|
|||
// the same task without issue.
|
||||
func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) {
|
||||
ch := make(chan *ExitResult)
|
||||
|
||||
// Join the passed context and the shutdown context
|
||||
ctx, _ = joincontext.Join(ctx, d.doneCtx)
|
||||
|
||||
go d.handleWaitTask(ctx, id, ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
@ -186,7 +198,7 @@ func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, sign
|
|||
Signal: signal,
|
||||
}
|
||||
|
||||
_, err := d.client.StopTask(context.Background(), req)
|
||||
_, err := d.client.StopTask(d.doneCtx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -199,7 +211,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
|
|||
Force: force,
|
||||
}
|
||||
|
||||
_, err := d.client.DestroyTask(context.Background(), req)
|
||||
_, err := d.client.DestroyTask(d.doneCtx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -207,7 +219,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
|
|||
func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
|
||||
req := &proto.InspectTaskRequest{TaskId: taskID}
|
||||
|
||||
resp, err := d.client.InspectTask(context.Background(), req)
|
||||
resp, err := d.client.InspectTask(d.doneCtx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -238,7 +250,7 @@ func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
|
|||
func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
|
||||
req := &proto.TaskStatsRequest{TaskId: taskID}
|
||||
|
||||
resp, err := d.client.TaskStats(context.Background(), req)
|
||||
resp, err := d.client.TaskStats(d.doneCtx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -255,28 +267,36 @@ func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsa
|
|||
// tasks such as lifecycle events, terminal errors, etc.
|
||||
func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent, error) {
|
||||
req := &proto.TaskEventsRequest{}
|
||||
|
||||
// Join the passed context and the shutdown context
|
||||
ctx, _ = joincontext.Join(ctx, d.doneCtx)
|
||||
|
||||
stream, err := d.client.TaskEvents(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := make(chan *TaskEvent)
|
||||
go d.handleTaskEvents(ch, stream)
|
||||
ch := make(chan *TaskEvent, 1)
|
||||
go d.handleTaskEvents(ctx, ch, stream)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
|
||||
func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
|
||||
defer close(ch)
|
||||
for {
|
||||
ev, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
|
||||
ch <- &TaskEvent{Err: err}
|
||||
break
|
||||
if err != io.EOF {
|
||||
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
|
||||
ch <- &TaskEvent{
|
||||
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
|
||||
}
|
||||
}
|
||||
|
||||
// End the stream
|
||||
return
|
||||
}
|
||||
|
||||
timestamp, _ := ptypes.Timestamp(ev.Timestamp)
|
||||
event := &TaskEvent{
|
||||
TaskID: ev.TaskId,
|
||||
|
@ -284,7 +304,11 @@ func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.D
|
|||
Message: ev.Message,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
ch <- event
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case ch <- event:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -294,7 +318,7 @@ func (d *driverPluginClient) SignalTask(taskID string, signal string) error {
|
|||
TaskId: taskID,
|
||||
Signal: signal,
|
||||
}
|
||||
_, err := d.client.SignalTask(context.Background(), req)
|
||||
_, err := d.client.SignalTask(d.doneCtx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -309,7 +333,7 @@ func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time.
|
|||
Timeout: ptypes.DurationProto(timeout),
|
||||
}
|
||||
|
||||
resp, err := d.client.ExecTask(context.Background(), req)
|
||||
resp, err := d.client.ExecTask(d.doneCtx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package drivers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
@ -14,7 +15,6 @@ import (
|
|||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/zclconf/go-cty/cty"
|
||||
"github.com/zclconf/go-cty/cty/msgpack"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// DriverPlugin is the interface with drivers will implement. It is also
|
||||
|
|
|
@ -38,9 +38,11 @@ func (p *PluginDriver) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) err
|
|||
func (p *PluginDriver) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
|
||||
return &driverPluginClient{
|
||||
BasePluginClient: &base.BasePluginClient{
|
||||
Client: baseproto.NewBasePluginClient(c),
|
||||
DoneCtx: ctx,
|
||||
Client: baseproto.NewBasePluginClient(c),
|
||||
},
|
||||
client: proto.NewDriverClient(c),
|
||||
logger: p.logger,
|
||||
client: proto.NewDriverClient(c),
|
||||
logger: p.logger,
|
||||
doneCtx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package drivers
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -10,7 +11,6 @@ import (
|
|||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ugorji/go/codec"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type testDriverState struct {
|
||||
|
|
|
@ -4,13 +4,12 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
cstructs "github.com/hashicorp/nomad/client/structs"
|
||||
"github.com/hashicorp/nomad/plugins/drivers/proto"
|
||||
context "golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type driverPluginServer struct {
|
||||
|
|
|
@ -1,16 +1,13 @@
|
|||
package drivers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/nomad/client/allocdir"
|
||||
|
@ -21,6 +18,8 @@ import (
|
|||
"github.com/hashicorp/nomad/helper/uuid"
|
||||
"github.com/hashicorp/nomad/plugins/base"
|
||||
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type DriverHarness struct {
|
||||
|
|
61
plugins/shared/grpc_utils.go
Normal file
61
plugins/shared/grpc_utils.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/nomad/plugins/base"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// HandleStreamErr is used to handle a non io.EOF error in a stream. It handles
|
||||
// detecting if the plugin has shutdown via the passeed pluginCtx. The
|
||||
// parameters are:
|
||||
// - err: the error returned from the streaming RPC
|
||||
// - reqCtx: the context passed to the streaming request
|
||||
// - pluginCtx: the plugins done ctx used to detect the plugin dying
|
||||
//
|
||||
// The return values are:
|
||||
// - base.ErrPluginShutdown if the error is because the plugin shutdown
|
||||
// - context.Canceled if the reqCtx is canceled
|
||||
// - The original error
|
||||
func HandleStreamErr(err error, reqCtx, pluginCtx context.Context) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determine if the error is because the plugin shutdown
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
|
||||
// Potentially wait a little before returning an error so we can detect
|
||||
// the exit
|
||||
select {
|
||||
case <-pluginCtx.Done():
|
||||
err = base.ErrPluginShutdown
|
||||
case <-reqCtx.Done():
|
||||
err = reqCtx.Err()
|
||||
|
||||
// There is no guarantee that the select will choose the
|
||||
// doneCtx first so we have to double check
|
||||
select {
|
||||
case <-pluginCtx.Done():
|
||||
err = base.ErrPluginShutdown
|
||||
default:
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
// Its okay to wait a while since the connection isn't available and
|
||||
// on local host it is likely shutting down. It is not expected for
|
||||
// this to ever reach even close to 3 seconds.
|
||||
}
|
||||
|
||||
// It is an error we don't know how to handle, so return it
|
||||
return err
|
||||
}
|
||||
|
||||
// Context was cancelled
|
||||
if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled {
|
||||
return context.Canceled
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
Loading…
Reference in a new issue