open-nomad/plugins/drivers/server.go

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

423 lines
10 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package drivers
import (
	"context"
	"fmt"
	"io"
	"math"

	"github.com/golang/protobuf/ptypes"
	"github.com/hashicorp/go-plugin"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/plugins/drivers/proto"
	dstructs "github.com/hashicorp/nomad/plugins/shared/structs"
	sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
)
// driverPluginServer serves the driver plugin RPCs. Each handler forwards the
// call to the wrapped DriverPlugin and converts between proto messages and the
// driver package's native types.
type driverPluginServer struct {
// broker is the go-plugin gRPC broker handed to this server.
// NOTE(review): none of the handlers in this file read it — confirm whether
// it is still needed.
broker *plugin.GRPCBroker
// impl is the driver implementation that every RPC handler delegates to.
impl DriverPlugin
}
// TaskConfigSchema asks the wrapped driver for its task config spec and
// returns it in the RPC response. Any driver error is returned unchanged.
func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) {
	schema, err := b.impl.TaskConfigSchema()
	if err != nil {
		return nil, err
	}

	return &proto.TaskConfigSchemaResponse{Spec: schema}, nil
}
func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) {
caps, err := b.impl.Capabilities()
if err != nil {
return nil, err
}
resp := &proto.CapabilitiesResponse{
Capabilities: &proto.DriverCapabilities{
SendSignals: caps.SendSignals,
Exec: caps.Exec,
MustCreateNetwork: caps.MustInitiateNetwork,
NetworkIsolationModes: []proto.NetworkIsolationSpec_NetworkIsolationMode{},
RemoteTasks: caps.RemoteTasks,
},
}
switch caps.FSIsolation {
2019-01-04 21:11:25 +00:00
case FSIsolationNone:
resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
2019-01-04 21:11:25 +00:00
case FSIsolationChroot:
resp.Capabilities.FsIsolation = proto.DriverCapabilities_CHROOT
2019-01-04 21:11:25 +00:00
case FSIsolationImage:
resp.Capabilities.FsIsolation = proto.DriverCapabilities_IMAGE
default:
resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
}
for _, mode := range caps.NetIsolationModes {
resp.Capabilities.NetworkIsolationModes = append(resp.Capabilities.NetworkIsolationModes, netIsolationModeToProto(mode))
}
return resp, nil
}
// Fingerprint streams the wrapped driver's fingerprints to the client.
//
// It returns nil when the stream context is done or when the driver closes its
// fingerprint channel, and returns the error from the driver or from Send
// otherwise.
func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error {
	streamCtx := srv.Context()

	fingerprints, err := b.impl.Fingerprint(streamCtx)
	if err != nil {
		return err
	}

	for {
		select {
		case <-streamCtx.Done():
			return nil
		case fp, open := <-fingerprints:
			if !open {
				return nil
			}

			sendErr := srv.Send(&proto.FingerprintResponse{
				Attributes:        dstructs.ConvertStructAttributeMap(fp.Attributes),
				Health:            healthStateToProto(fp.Health),
				HealthDescription: fp.HealthDescription,
			})
			if sendErr != nil {
				return sendErr
			}
		}
	}
}
// RecoverTask converts the task handle in the request to its native form and
// passes it to the wrapped driver's RecoverTask.
func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) {
	handle := taskHandleFromProto(req.Handle)
	if err := b.impl.RecoverTask(handle); err != nil {
		return nil, err
	}

	return &proto.RecoverTaskResponse{}, nil
}
func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) {
handle, net, err := b.impl.StartTask(taskConfigFromProto(req.Task))
if err != nil {
if rec, ok := err.(structs.Recoverable); ok {
st := status.New(codes.FailedPrecondition, rec.Error())
st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
if err != nil {
// If this error, it will always error
panic(err)
}
return nil, st.Err()
}
return nil, err
}
var pbNet *proto.NetworkOverride
if net != nil {
pbNet = &proto.NetworkOverride{
PortMap: map[string]int32{},
Addr: net.IP,
AutoAdvertise: net.AutoAdvertise,
}
for k, v := range net.PortMap {
fix integer bounds checks (#11815) * driver: fix integer conversion error The shared executor incorrectly parsed the user's group into int32 and then cast to uint32 without bounds checking. This is harmless because an out-of-bounds gid will throw an error later, but it triggers security and code quality scans. Parse directly to uint32 so that we get correct error handling. * helper: fix integer conversion error The autopilot flags helper incorrectly parses a uint64 to a uint which is machine specific size. Although we don't have 32-bit builds, this sets off security and code quality scaans. Parse to the machine sized uint. * driver: restrict bounds of port map The plugin server doesn't constrain the maximum integer for port maps. This could result in a user-visible misconfiguration, but it also triggers security and code quality scans. Restrict the bounds before casting to int32 and return an error. * cpuset: restrict upper bounds of cpuset values Our cpuset configuration expects values in the range of uint16 to match the expectations set by the kernel, but we don't constrain the values before downcasting. An underflow could lead to allocations failing on the client rather than being caught earlier. This also make security and code quality scanners happy. * http: fix integer downcast for per_page parameter The parser for the `per_page` query parameter downcasts to int32 without bounds checking. This could result in underflow and nonsensical paging, but there's no server-side consequences for this. Fixing this will silence some security and code quality scanners though.
2022-01-25 16:16:48 +00:00
if v > math.MaxInt32 {
return nil, fmt.Errorf("port map out of bounds")
}
pbNet.PortMap[k] = int32(v)
}
}
resp := &proto.StartTaskResponse{
Handle: taskHandleToProto(handle),
NetworkOverride: pbNet,
}
return resp, nil
}
func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) {
ch, err := b.impl.WaitTask(ctx, req.TaskId)
if err != nil {
return nil, err
}
2018-10-16 02:37:58 +00:00
var ok bool
var result *ExitResult
select {
case <-ctx.Done():
return nil, ctx.Err()
case result, ok = <-ch:
if !ok {
return &proto.WaitTaskResponse{
Err: "channel closed",
}, nil
}
}
var errStr string
if result.Err != nil {
errStr = result.Err.Error()
}
resp := &proto.WaitTaskResponse{
Err: errStr,
Result: &proto.ExitResult{
ExitCode: int32(result.ExitCode),
Signal: int32(result.Signal),
OomKilled: result.OOMKilled,
},
}
return resp, nil
}
// StopTask decodes the proto timeout and asks the wrapped driver to stop the
// task with that timeout and the requested signal. A timeout that cannot be
// decoded is returned as an error before the driver is called.
func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) {
	stopTimeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	if err = b.impl.StopTask(req.TaskId, stopTimeout, req.Signal); err != nil {
		return nil, err
	}

	return &proto.StopTaskResponse{}, nil
}
// DestroyTask forwards the task ID and force flag to the wrapped driver's
// DestroyTask and returns an empty response on success.
func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) {
	if err := b.impl.DestroyTask(req.TaskId, req.Force); err != nil {
		return nil, err
	}

	return &proto.DestroyTaskResponse{}, nil
}
// InspectTask returns the wrapped driver's status for a task, including the
// driver attributes and any network override.
//
// Port map values that do not fit in an int32 are rejected with an error
// instead of being silently truncated, matching StartTask.
func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) {
	// Named taskStatus so it does not shadow the imported grpc status package.
	taskStatus, err := b.impl.InspectTask(req.TaskId)
	if err != nil {
		return nil, err
	}

	protoStatus, err := taskStatusToProto(taskStatus)
	if err != nil {
		return nil, err
	}

	var pbNet *proto.NetworkOverride
	if override := taskStatus.NetworkOverride; override != nil {
		pbNet = &proto.NetworkOverride{
			PortMap:       make(map[string]int32, len(override.PortMap)),
			Addr:          override.IP,
			AutoAdvertise: override.AutoAdvertise,
		}
		for k, v := range override.PortMap {
			// The proto field is int32; check both bounds before narrowing.
			if v > math.MaxInt32 || v < math.MinInt32 {
				return nil, fmt.Errorf("port map out of bounds")
			}
			pbNet.PortMap[k] = int32(v)
		}
	}

	resp := &proto.InspectTaskResponse{
		Task: protoStatus,
		Driver: &proto.TaskDriverStatus{
			Attributes: taskStatus.DriverAttributes,
		},
		NetworkOverride: pbNet,
	}

	return resp, nil
}
// TaskStats streams resource usage for a task at the requested collection
// interval until the wrapped driver closes its stats channel.
//
// If the driver error implements structs.Recoverable, it is returned as a gRPC
// FailedPrecondition status carrying a RecoverableError detail. An io.EOF from
// Send ends the stream without an error.
func (b *driverPluginServer) TaskStats(req *proto.TaskStatsRequest, srv proto.Driver_TaskStatsServer) error {
	interval, err := ptypes.Duration(req.CollectionInterval)
	if err != nil {
		return fmt.Errorf("failed to parse collection interval: %v", err)
	}

	statsCh, err := b.impl.TaskStats(srv.Context(), req.TaskId, interval)
	if err != nil {
		rec, recoverable := err.(structs.Recoverable)
		if !recoverable {
			return err
		}

		st, detailErr := status.New(codes.FailedPrecondition, rec.Error()).
			WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
		if detailErr != nil {
			// If this error, it will always error
			panic(detailErr)
		}
		return st.Err()
	}

	for usage := range statsCh {
		pbStats, encErr := TaskStatsToProto(usage)
		if encErr != nil {
			return fmt.Errorf("failed to encode task stats: %v", encErr)
		}

		switch sendErr := srv.Send(&proto.TaskStatsResponse{Stats: pbStats}); {
		case sendErr == io.EOF:
			return nil
		case sendErr != nil:
			return sendErr
		}
	}

	return nil
}
// ExecTask runs a command in the task through the wrapped driver and returns
// its stdout, stderr and exit result. A timeout that cannot be decoded is
// returned as an error before the driver is called.
func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) {
	execTimeout, err := ptypes.Duration(req.Timeout)
	if err != nil {
		return nil, err
	}

	out, err := b.impl.ExecTask(req.TaskId, req.Command, execTimeout)
	if err != nil {
		return nil, err
	}

	return &proto.ExecTaskResponse{
		Stdout: out.Stdout,
		Stderr: out.Stderr,
		Result: exitResultToProto(out.ExitResult),
	}, nil
}
// ExecTaskStreaming runs an interactive exec session over a bidirectional
// stream.
//
// The first message received must carry the Setup payload. Drivers that
// implement ExecTaskStreamingRawDriver are handed the stream directly.
// Otherwise the driver must implement ExecTaskStreamingDriver: the stream is
// adapted into exec options, the command is run, and a final message carrying
// the exit result is sent once the output copy has finished.
//
// Errors come from the initial receive, a missing setup message, a driver
// without exec support, the driver itself, the output copy, cancellation of
// the stream context, or the final Send.
func (b *driverPluginServer) ExecTaskStreaming(server proto.Driver_ExecTaskStreamingServer) error {
	msg, err := server.Recv()
	if err != nil {
		return fmt.Errorf("failed to receive initial message: %v", err)
	}

	if msg.Setup == nil {
		return fmt.Errorf("first message should always be setup")
	}

	if impl, ok := b.impl.(ExecTaskStreamingRawDriver); ok {
		return impl.ExecTaskStreamingRaw(server.Context(),
			msg.Setup.TaskId, msg.Setup.Command, msg.Setup.Tty,
			server)
	}

	d, ok := b.impl.(ExecTaskStreamingDriver)
	if !ok {
		return fmt.Errorf("driver does not support exec")
	}

	execOpts, errCh := StreamToExecOptions(server.Context(),
		msg.Setup.Command, msg.Setup.Tty,
		server)

	result, err := d.ExecTaskStreaming(server.Context(),
		msg.Setup.TaskId, execOpts)

	// Close both output writers whether or not the driver call succeeded.
	// Close errors are deliberately ignored (best effort).
	execOpts.Stdout.Close()
	execOpts.Stderr.Close()

	if err != nil {
		return err
	}

	// wait for copy to be done
	select {
	case err = <-errCh:
	case <-server.Context().Done():
		err = fmt.Errorf("exec timed out: %v", server.Context().Err())
	}

	if err != nil {
		return err
	}

	// Return the Send error so a failed delivery of the exit result is not
	// reported as success.
	return server.Send(&ExecTaskStreamingResponseMsg{
		Exited: true,
		Result: exitResultToProto(result),
	})
}
// SignalTask forwards the requested signal for the given task to the wrapped
// driver and returns an empty response on success.
func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) {
	if err := b.impl.SignalTask(req.TaskId, req.Signal); err != nil {
		return nil, err
	}

	return &proto.SignalTaskResponse{}, nil
}
// TaskEvents streams the wrapped driver's task events to the client.
//
// The stream ends without error when the driver closes its event channel,
// when a nil event is received, or when Send reports io.EOF. A timestamp that
// cannot be converted, or any other Send error, is returned.
func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error {
	events, err := b.impl.TaskEvents(srv.Context())
	if err != nil {
		return err
	}

	for ev := range events {
		// A nil event ends the stream just like a closed channel.
		if ev == nil {
			return nil
		}

		ts, tsErr := ptypes.TimestampProto(ev.Timestamp)
		if tsErr != nil {
			return tsErr
		}

		sendErr := srv.Send(&proto.DriverTaskEvent{
			TaskId:      ev.TaskID,
			AllocId:     ev.AllocID,
			TaskName:    ev.TaskName,
			Timestamp:   ts,
			Message:     ev.Message,
			Annotations: ev.Annotations,
		})
		if sendErr == io.EOF {
			return nil
		}
		if sendErr != nil {
			return sendErr
		}
	}

	return nil
}
// CreateNetwork asks the wrapped driver to create a network for the
// allocation. It fails if the driver does not implement DriverNetworkManager.
// The response carries the resulting isolation spec and the driver's created
// flag.
func (b *driverPluginServer) CreateNetwork(ctx context.Context, req *proto.CreateNetworkRequest) (*proto.CreateNetworkResponse, error) {
	manager, supported := b.impl.(DriverNetworkManager)
	if !supported {
		return nil, fmt.Errorf("CreateNetwork RPC not supported by driver")
	}

	isolation, wasCreated, err := manager.CreateNetwork(req.GetAllocId(), networkCreateRequestFromProto(req))
	if err != nil {
		return nil, err
	}

	resp := &proto.CreateNetworkResponse{
		IsolationSpec: NetworkIsolationSpecToProto(isolation),
		Created:       wasCreated,
	}
	return resp, nil
}
// DestroyNetwork asks the wrapped driver to destroy the allocation's network
// described by the request's isolation spec. It fails if the driver does not
// implement DriverNetworkManager.
func (b *driverPluginServer) DestroyNetwork(ctx context.Context, req *proto.DestroyNetworkRequest) (*proto.DestroyNetworkResponse, error) {
	manager, supported := b.impl.(DriverNetworkManager)
	if !supported {
		return nil, fmt.Errorf("DestroyNetwork RPC not supported by driver")
	}

	isolation := NetworkIsolationSpecFromProto(req.IsolationSpec)
	if err := manager.DestroyNetwork(req.AllocId, isolation); err != nil {
		return nil, err
	}

	return &proto.DestroyNetworkResponse{}, nil
}