Merge pull request #4868 from hashicorp/b-plugin-ctx

Plugin client's handle plugin dying
This commit is contained in:
Alex Dadgar 2018-11-13 10:26:53 -08:00 committed by GitHub
commit 17e8446484
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 165 additions and 110 deletions

View file

@ -1,6 +1,7 @@
package exec
import (
"context"
"fmt"
"os"
"path/filepath"
@ -21,7 +22,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View file

@ -1,6 +1,7 @@
package java
import (
"context"
"fmt"
"os"
"os/exec"
@ -23,7 +24,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View file

@ -16,7 +16,6 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
netctx "golang.org/x/net/context"
)
const (
@ -232,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) {
return capabilities, nil
}
func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) {
func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) {
ch := make(chan *drivers.Fingerprint)
go d.handleFingerprint(ctx, ch)
return ch, nil
@ -365,7 +364,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
}
func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) {
func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) {
handle, ok := d.tasks.Get(taskID)
if !ok {
return nil, drivers.ErrTaskNotFound
@ -430,7 +429,7 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
return nil, nil
}
func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) {
func (d *Driver) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) {
return d.eventer.TaskEvents(ctx)
}

View file

@ -1,6 +1,7 @@
package qemu
import (
"context"
"errors"
"fmt"
"net"
@ -25,7 +26,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View file

@ -1,6 +1,7 @@
package rawexec
import (
"context"
"fmt"
"os"
"path/filepath"
@ -22,7 +23,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View file

@ -4,6 +4,7 @@ package rkt
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
@ -36,7 +37,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
rktv1 "github.com/rkt/rkt/api/v1"
"golang.org/x/net/context"
)
const (

View file

@ -3,17 +3,16 @@
package rkt
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"testing"
"time"
"os"
"bytes"
"github.com/hashicorp/hcl2/hcl"
ctestutil "github.com/hashicorp/nomad/client/testutil"
"github.com/hashicorp/nomad/helper/testlog"
@ -26,7 +25,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
)
var _ drivers.DriverPlugin = (*Driver)(nil)

View file

@ -1,12 +1,12 @@
package eventer
import (
"context"
"sync"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/plugins/drivers"
"golang.org/x/net/context"
)
var (

View file

@ -12,10 +12,13 @@ import (
// gRPC to communicate to the remote plugin.
type BasePluginClient struct {
Client proto.BasePluginClient
// DoneCtx is closed when the plugin exits
DoneCtx context.Context
}
func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
presp, err := b.Client.PluginInfo(context.Background(), &proto.PluginInfoRequest{})
presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{})
if err != nil {
return nil, err
}
@ -41,7 +44,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
}
func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
presp, err := b.Client.ConfigSchema(context.Background(), &proto.ConfigSchemaRequest{})
presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{})
if err != nil {
return nil, err
}
@ -51,7 +54,7 @@ func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
func (b *BasePluginClient) SetConfig(data []byte, config *ClientAgentConfig) error {
// Send the config
_, err := b.Client.SetConfig(context.Background(), &proto.SetConfigRequest{
_, err := b.Client.SetConfig(b.DoneCtx, &proto.SetConfigRequest{
MsgpackConfig: data,
NomadConfig: config.toProto(),
})

View file

@ -51,7 +51,10 @@ func (p *PluginBase) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error
}
func (p *PluginBase) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &BasePluginClient{Client: proto.NewBasePluginClient(c)}, nil
return &BasePluginClient{
Client: proto.NewBasePluginClient(c),
DoneCtx: ctx,
}, nil
}
// MsgpackHandle is a shared handle for encoding/decoding of structs

View file

@ -9,9 +9,7 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device/proto"
netctx "golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/nomad/plugins/shared"
)
// devicePluginClient implements the client side of a remote device plugin, using
@ -49,28 +47,33 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleFingerprint(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_FingerprintClient,
out chan *FingerprintResponse) {
defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &FingerprintResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
close(out)
return
}
// Send the response
out <- &FingerprintResponse{
f := &FingerprintResponse{
Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()),
}
select {
case <-ctx.Done():
return
case out <- f:
}
}
}
@ -116,69 +119,32 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration)
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleStats(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_StatsClient,
out chan *StatsResponse) {
defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &StatsResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
close(out)
return
}
// Send the response
out <- &StatsResponse{
s := &StatsResponse{
Groups: convertProtoDeviceGroupsStats(resp.GetGroups()),
}
}
}
// handleStreamErr is used to handle a non io.EOF error in a stream. It handles
// detecting if the plugin has shutdown
func (d *devicePluginClient) handleStreamErr(err error, ctx context.Context) error {
if err == nil {
return nil
}
// Determine if the error is because the plugin shutdown
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
// Potentially wait a little before returning an error so we can detect
// the exit
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
case <-ctx.Done():
err = ctx.Err()
// There is no guarantee that the select will choose the
// doneCtx first so we have to double check
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
default:
}
case <-time.After(3 * time.Second):
// Its okay to wait a while since the connection isn't available and
// on local host it is likely shutting down. It is not expected for
// this to ever reach even close to 3 seconds.
return
case out <- s:
}
// It is an error we don't know how to handle, so return it
return err
}
// Context was cancelled
if errStatus := status.FromContextError(ctx.Err()); errStatus.Code() == codes.Canceled {
return context.Canceled
}
return err
}

View file

@ -31,7 +31,8 @@ func (p *PluginDevice) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker
doneCtx: ctx,
client: proto.NewDevicePluginClient(c),
BasePluginClient: &base.BasePluginClient{
Client: bproto.NewBasePluginClient(c),
Client: bproto.NewBasePluginClient(c),
DoneCtx: ctx,
},
}, nil
}

View file

@ -1,18 +1,19 @@
package drivers
import (
"context"
"errors"
"fmt"
"io"
"time"
"github.com/LK4D4/joincontext"
"github.com/golang/protobuf/ptypes"
hclog "github.com/hashicorp/go-hclog"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/drivers/proto"
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"golang.org/x/net/context"
)
var _ DriverPlugin = &driverPluginClient{}
@ -22,12 +23,15 @@ type driverPluginClient struct {
client proto.DriverClient
logger hclog.Logger
// doneCtx is closed when the plugin exits
doneCtx context.Context
}
func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
req := &proto.TaskConfigSchemaRequest{}
resp, err := d.client.TaskConfigSchema(context.Background(), req)
resp, err := d.client.TaskConfigSchema(d.doneCtx, req)
if err != nil {
return nil, err
}
@ -38,7 +42,7 @@ func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
req := &proto.CapabilitiesRequest{}
resp, err := d.client.Capabilities(context.Background(), req)
resp, err := d.client.Capabilities(d.doneCtx, req)
if err != nil {
return nil, err
}
@ -67,12 +71,15 @@ func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerprint, error) {
req := &proto.FingerprintRequest{}
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
stream, err := d.client.Fingerprint(ctx, req)
if err != nil {
return nil, err
}
ch := make(chan *Fingerprint)
ch := make(chan *Fingerprint, 1)
go d.handleFingerprint(ctx, ch, stream)
return ch, nil
@ -82,17 +89,18 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
defer close(ch)
for {
pb, err := stream.Recv()
if err == io.EOF {
return
}
if err != nil {
select {
case <-ctx.Done():
case ch <- &Fingerprint{Err: fmt.Errorf("error from RPC stream: %v", err)}:
if err != io.EOF {
d.logger.Error("error receiving stream from Fingerprint driver RPC", "error", err)
ch <- &Fingerprint{
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
return
}
f := &Fingerprint{
Attributes: pb.Attributes,
Health: healthStateFromProto(pb.Health),
@ -112,7 +120,7 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
func (d *driverPluginClient) RecoverTask(h *TaskHandle) error {
req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)}
_, err := d.client.RecoverTask(context.Background(), req)
_, err := d.client.RecoverTask(d.doneCtx, req)
return err
}
@ -124,7 +132,7 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
Task: taskConfigToProto(c),
}
resp, err := d.client.StartTask(context.Background(), req)
resp, err := d.client.StartTask(d.doneCtx, req)
if err != nil {
return nil, nil, err
}
@ -150,6 +158,10 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
// the same task without issue.
func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) {
ch := make(chan *ExitResult)
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
go d.handleWaitTask(ctx, id, ch)
return ch, nil
}
@ -186,7 +198,7 @@ func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, sign
Signal: signal,
}
_, err := d.client.StopTask(context.Background(), req)
_, err := d.client.StopTask(d.doneCtx, req)
return err
}
@ -199,7 +211,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
Force: force,
}
_, err := d.client.DestroyTask(context.Background(), req)
_, err := d.client.DestroyTask(d.doneCtx, req)
return err
}
@ -207,7 +219,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
req := &proto.InspectTaskRequest{TaskId: taskID}
resp, err := d.client.InspectTask(context.Background(), req)
resp, err := d.client.InspectTask(d.doneCtx, req)
if err != nil {
return nil, err
}
@ -238,7 +250,7 @@ func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
req := &proto.TaskStatsRequest{TaskId: taskID}
resp, err := d.client.TaskStats(context.Background(), req)
resp, err := d.client.TaskStats(d.doneCtx, req)
if err != nil {
return nil, err
}
@ -255,28 +267,36 @@ func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsa
// tasks such as lifecycle events, terminal errors, etc.
func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent, error) {
req := &proto.TaskEventsRequest{}
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
stream, err := d.client.TaskEvents(ctx, req)
if err != nil {
return nil, err
}
ch := make(chan *TaskEvent)
go d.handleTaskEvents(ch, stream)
ch := make(chan *TaskEvent, 1)
go d.handleTaskEvents(ctx, ch, stream)
return ch, nil
}
func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
defer close(ch)
for {
ev, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
ch <- &TaskEvent{Err: err}
break
if err != io.EOF {
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
ch <- &TaskEvent{
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
return
}
timestamp, _ := ptypes.Timestamp(ev.Timestamp)
event := &TaskEvent{
TaskID: ev.TaskId,
@ -284,7 +304,11 @@ func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.D
Message: ev.Message,
Timestamp: timestamp,
}
ch <- event
select {
case <-ctx.Done():
return
case ch <- event:
}
}
}
@ -294,7 +318,7 @@ func (d *driverPluginClient) SignalTask(taskID string, signal string) error {
TaskId: taskID,
Signal: signal,
}
_, err := d.client.SignalTask(context.Background(), req)
_, err := d.client.SignalTask(d.doneCtx, req)
return err
}
@ -309,7 +333,7 @@ func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time.
Timeout: ptypes.DurationProto(timeout),
}
resp, err := d.client.ExecTask(context.Background(), req)
resp, err := d.client.ExecTask(d.doneCtx, req)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package drivers
import (
"context"
"fmt"
"path/filepath"
"sort"
@ -14,7 +15,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/msgpack"
"golang.org/x/net/context"
)
// DriverPlugin is the interface with drivers will implement. It is also

View file

@ -38,9 +38,11 @@ func (p *PluginDriver) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) err
func (p *PluginDriver) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &driverPluginClient{
BasePluginClient: &base.BasePluginClient{
Client: baseproto.NewBasePluginClient(c),
DoneCtx: ctx,
Client: baseproto.NewBasePluginClient(c),
},
client: proto.NewDriverClient(c),
logger: p.logger,
client: proto.NewDriverClient(c),
logger: p.logger,
doneCtx: ctx,
}, nil
}

View file

@ -2,6 +2,7 @@ package drivers
import (
"bytes"
"context"
"sync"
"testing"
"time"
@ -10,7 +11,6 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
"github.com/stretchr/testify/require"
"github.com/ugorji/go/codec"
"golang.org/x/net/context"
)
type testDriverState struct {

View file

@ -4,13 +4,12 @@ import (
"fmt"
"io"
"golang.org/x/net/context"
"github.com/golang/protobuf/ptypes"
hclog "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/plugins/drivers/proto"
context "golang.org/x/net/context"
)
type driverPluginServer struct {

View file

@ -1,16 +1,13 @@
package drivers
import (
"context"
"fmt"
"io/ioutil"
"path/filepath"
"runtime"
"time"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
hclog "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/nomad/client/allocdir"
@ -21,6 +18,8 @@ import (
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
)
type DriverHarness struct {

View file

@ -0,0 +1,61 @@
package shared
import (
"context"
"time"
"github.com/hashicorp/nomad/plugins/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// HandleStreamErr is used to handle a non io.EOF error in a stream. It handles
// detecting if the plugin has shutdown via the passeed pluginCtx. The
// parameters are:
// - err: the error returned from the streaming RPC
// - reqCtx: the context passed to the streaming request
// - pluginCtx: the plugins done ctx used to detect the plugin dying
//
// The return values are:
// - base.ErrPluginShutdown if the error is because the plugin shutdown
// - context.Canceled if the reqCtx is canceled
// - The original error
func HandleStreamErr(err error, reqCtx, pluginCtx context.Context) error {
if err == nil {
return nil
}
// Determine if the error is because the plugin shutdown
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
// Potentially wait a little before returning an error so we can detect
// the exit
select {
case <-pluginCtx.Done():
err = base.ErrPluginShutdown
case <-reqCtx.Done():
err = reqCtx.Err()
// There is no guarantee that the select will choose the
// doneCtx first so we have to double check
select {
case <-pluginCtx.Done():
err = base.ErrPluginShutdown
default:
}
case <-time.After(3 * time.Second):
// Its okay to wait a while since the connection isn't available and
// on local host it is likely shutting down. It is not expected for
// this to ever reach even close to 3 seconds.
}
// It is an error we don't know how to handle, so return it
return err
}
// Context was cancelled
if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled {
return context.Canceled
}
return err
}