open-nomad/plugins/shared/cmd/launcher/command/device.go
2023-04-10 15:36:59 +00:00

369 lines
8.7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package command
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"time"
hclog "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast"
"github.com/hashicorp/hcl/v2/hcldec"
"github.com/hashicorp/nomad/helper/pluginutils/hclspecutils"
"github.com/hashicorp/nomad/helper/pluginutils/hclutils"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device"
"github.com/kr/pretty"
"github.com/mitchellh/cli"
"github.com/zclconf/go-cty/cty/msgpack"
)
// DeviceCommandFactory returns a cli.CommandFactory that builds a new Device
// command carrying the supplied Meta each time it is invoked. The factory
// itself never fails.
func DeviceCommandFactory(meta Meta) cli.CommandFactory {
	factory := func() (cli.Command, error) {
		cmd := &Device{Meta: meta}
		return cmd, nil
	}
	return factory
}
// Device is the command that launches a device plugin binary and drives it
// through an interactive REPL. It embeds Meta and keeps the per-run state that
// its methods share.
type Device struct {
Meta
// dev is the device plugin dispensed by getDevicePlugin; Run assigns it
// before any other method uses it.
dev device.DevicePlugin
// spec is the plugin's config schema as converted by getSpec; Run assigns
// it once the conversion succeeds.
spec hcldec.Spec
}
// Help returns the long-form usage text for the device command, with the
// general options spliced in and surrounding whitespace trimmed.
//
// The config file is shown as optional ([config_file]) because Run accepts
// either one argument (the device binary) or two (binary plus config file).
func (c *Device) Help() string {
	helpText := `
Usage: nomad-plugin-launcher device <device-binary> [config_file]
Device launches the given device binary and provides a REPL for interacting
with it.
General Options:
` + generalOptionsUsage() + `
Device Options:
-trace
Enable trace level log output.
`
	return strings.TrimSpace(helpText)
}
// Synopsis returns the one-line description of the device command.
func (c *Device) Synopsis() string {
	const synopsis = "REPL for interacting with device plugins"
	return synopsis
}
// Run executes the device command: it parses flags, validates the positional
// arguments (a device binary and an optional config file), launches the
// plugin, fetches and converts its config schema, applies the configuration
// and finally hands control to the REPL.
//
// It returns 0 when the REPL ends without error and 1 on any failure; every
// failure is reported through the command's logger before returning.
func (c *Device) Run(args []string) int {
	var trace bool

	flags := c.FlagSet("device")
	flags.Usage = func() { c.Ui.Output(c.Help()) }
	flags.BoolVar(&trace, "trace", false, "")
	if err := flags.Parse(args); err != nil {
		c.logger.Error("failed to parse flags:", "error", err)
		return 1
	}

	// -trace takes precedence over the verbose setting.
	switch {
	case trace:
		c.logger.SetLevel(hclog.Trace)
	case c.verbose:
		c.logger.SetLevel(hclog.Debug)
	}

	// Only the positional arguments remain of interest from here on.
	args = flags.Args()
	switch n := len(args); {
	case n < 1:
		c.logger.Error("expected at least 1 args (device binary)", "args", args)
		return 1
	case n > 2:
		c.logger.Error("expected at most 2 args (device binary and config file)", "args", args)
		return 1
	}

	// The config file is optional; without it the config stays nil.
	var config []byte
	if len(args) == 2 {
		contents, err := os.ReadFile(args[1])
		if err != nil {
			c.logger.Error("failed to read config file", "error", err)
			return 1
		}
		config = contents
	}

	// Launch the plugin and make sure it is cleaned up when Run returns.
	devPlugin, cleanup, err := c.getDevicePlugin(args[0])
	if err != nil {
		c.logger.Error("failed to launch device plugin", "error", err)
		return 1
	}
	defer cleanup()
	c.dev = devPlugin

	schema, err := c.getSpec()
	if err != nil {
		c.logger.Error("failed to get config spec", "error", err)
		return 1
	}
	c.spec = schema

	if err := c.setConfig(schema, device.ApiVersion010, config, nil); err != nil {
		c.logger.Error("failed to set config", "error", err)
		return 1
	}

	if err := c.startRepl(); err != nil {
		c.logger.Error("error interacting with plugin", "error", err)
		return 1
	}

	return 0
}
// getDevicePlugin launches the given plugin binary over gRPC and dispenses its
// device plugin implementation.
//
// On success it returns the plugin together with a cleanup function that kills
// the plugin client; the caller must invoke it when done. On any failure the
// client is killed before returning, so no cleanup function is handed back.
func (c *Device) getDevicePlugin(binary string) (device.DevicePlugin, func(), error) {
	// Launch the plugin
	client := plugin.NewClient(&plugin.ClientConfig{
		HandshakeConfig: base.Handshake,
		Plugins: map[string]plugin.Plugin{
			base.PluginTypeBase:   &base.PluginBase{},
			base.PluginTypeDevice: &device.PluginDevice{},
		},
		Cmd:              exec.Command(binary),
		AllowedProtocols: []plugin.Protocol{plugin.ProtocolGRPC},
		Logger:           c.logger,
	})

	// Connect via RPC
	rpcClient, err := client.Client()
	if err != nil {
		client.Kill()
		return nil, nil, err
	}

	// Request the plugin
	raw, err := rpcClient.Dispense(base.PluginTypeDevice)
	if err != nil {
		client.Kill()
		return nil, nil, err
	}

	// The dispensed value looks like a normal interface implementation but is
	// backed by the RPC connection. Check the assertion instead of panicking
	// so a mismatched plugin yields an error and the client is still killed.
	dev, ok := raw.(device.DevicePlugin)
	if !ok {
		client.Kill()
		return nil, nil, fmt.Errorf("plugin %q dispensed unexpected type %T, expected a device plugin", binary, raw)
	}

	return dev, client.Kill, nil
}
// getSpec asks the plugin for its config schema and converts it into an
// hcldec.Spec usable for decoding the user's configuration.
//
// It returns an error if the plugin cannot provide the schema, or if the
// conversion reports errors; in the latter case every reported error is listed
// on its own "* " bullet line in the message.
func (c *Device) getSpec() (hcldec.Spec, error) {
	rawSchema, err := c.dev.ConfigSchema()
	if err != nil {
		return nil, fmt.Errorf("failed to get config schema: %v", err)
	}

	converted, diags := hclspecutils.Convert(rawSchema)
	if !diags.HasErrors() {
		return converted, nil
	}

	// Collect every conversion error into a single multi-line message.
	var msg strings.Builder
	msg.WriteString("failed to convert HCL schema: ")
	for _, convErr := range diags.Errs() {
		msg.WriteString("\n* ")
		msg.WriteString(convErr.Error())
	}
	return nil, errors.New(msg.String())
}
// setConfig decodes the raw HCL config against spec, msgpack-encodes the
// decoded value and sends it to the plugin via SetConfig, together with the
// given API version and agent config.
//
// It returns the first error hit while parsing, decoding or encoding, or
// whatever error the plugin's SetConfig call reports.
func (c *Device) setConfig(spec hcldec.Spec, apiVersion string, config []byte, nmdCfg *base.AgentConfig) error {
	// Turn the config bytes into a generic value first.
	rawCfg, err := hclConfigToInterface(config)
	if err != nil {
		return err
	}

	// Decode the generic value against the plugin's schema.
	decoded, diags, diagErrs := hclutils.ParseHclInterface(rawCfg, spec, nil)
	if diags.HasErrors() {
		return multierror.Append(errors.New("failed to parse config: "), diagErrs...)
	}

	encoded, err := msgpack.Marshal(decoded, decoded.Type())
	if err != nil {
		return err
	}

	return c.dev.SetConfig(&base.Config{
		PluginConfig: encoded,
		AgentConfig:  nmdCfg,
		ApiVersion:   apiVersion,
	})
}
// hclConfigToInterface parses an HCL config file and returns the value stored
// under the "config" key of its first top-level item.
//
// An empty input, or an input that parses to no top-level items, yields an
// empty map so that callers can treat "no configuration" uniformly. If the
// first item has no "config" key the returned value is nil. An error is
// returned when the input cannot be parsed or decoded.
func hclConfigToInterface(config []byte) (interface{}, error) {
	if len(config) == 0 {
		return map[string]interface{}{}, nil
	}

	// Parse as we do in the jobspec parser
	root, err := hcl.Parse(string(config))
	if err != nil {
		return nil, fmt.Errorf("failed to hcl parse the config: %w", err)
	}

	// Top-level item should be a list
	list, ok := root.Node.(*ast.ObjectList)
	if !ok {
		return nil, errors.New("root should be an object")
	}

	// A non-empty input can still produce no items (for example a file with
	// only whitespace or comments). Indexing Items[0] would panic in that
	// case, so treat it the same as an empty config.
	if len(list.Items) == 0 {
		return map[string]interface{}{}, nil
	}

	// Only the first top-level item is decoded.
	var m map[string]interface{}
	if err := hcl.DecodeObject(&m, list.Items[0]); err != nil {
		return nil, fmt.Errorf("failed to decode object: %w", err)
	}

	return m["config"], nil
}
// startRepl runs the interactive command loop. It starts replOutput in a
// goroutine and forwards user commands to it over channels:
//
//   - fingerprint() / stop_fingerprint() start and cancel a fingerprint stream
//   - stats() / stop_stats() start and cancel a stats stream
//   - reserve(id1, id2, ...) requests a reservation for the listed device IDs
//   - exit() leaves the loop
//
// It returns nil on exit() and the underlying error if reading input fails.
// All contexts created here are cancelled before it returns.
func (c *Device) startRepl() error {
	// Start the output goroutine; cancelling ctx stops it.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	fingerprint := make(chan context.Context)
	stats := make(chan context.Context)
	reserve := make(chan []string)
	go c.replOutput(ctx, fingerprint, stats, reserve)

	c.Ui.Output("> Available commands are: exit(), fingerprint(), stop_fingerprint(), stats(), stop_stats(), reserve(id1, id2, ...)")

	// A non-nil *Ctx means the corresponding stream has been started and not
	// yet stopped by the user.
	var fingerprintCtx, statsCtx context.Context
	var fingerprintCancel, statsCancel context.CancelFunc

	// Release any stream contexts on every return path.
	defer func() {
		if fingerprintCancel != nil {
			fingerprintCancel()
		}
		if statsCancel != nil {
			statsCancel()
		}
	}()

	for {
		in, err := c.Ui.Ask("> ")
		if err != nil {
			return err
		}

		// Tolerate stray whitespace around the command.
		in = strings.TrimSpace(in)

		switch {
		case in == "exit()":
			return nil

		case in == "fingerprint()":
			if fingerprintCtx != nil {
				continue
			}
			fingerprintCtx, fingerprintCancel = context.WithCancel(ctx)
			fingerprint <- fingerprintCtx

		case in == "stop_fingerprint()":
			if fingerprintCtx == nil {
				continue
			}
			fingerprintCancel()
			fingerprintCtx = nil

		case in == "stats()":
			if statsCtx != nil {
				continue
			}
			statsCtx, statsCancel = context.WithCancel(ctx)
			stats <- statsCtx

		case in == "stop_stats()":
			if statsCtx == nil {
				continue
			}
			statsCancel()
			statsCtx = nil

		case strings.HasPrefix(in, "reserve(") && strings.HasSuffix(in, ")"):
			listString := strings.TrimSuffix(strings.TrimPrefix(in, "reserve("), ")")

			// Trim each ID so that "reserve(id1, id2)" - the form shown in
			// the banner - does not produce IDs with leading spaces, and
			// drop empty entries so "reserve()" does not send a blank ID.
			var ids []string
			for _, id := range strings.Split(listString, ",") {
				if id = strings.TrimSpace(id); id != "" {
					ids = append(ids, id)
				}
			}
			reserve <- ids

		default:
			c.Ui.Error(fmt.Sprintf("> Unknown command %q", in))
		}
	}
}
// replOutput is the REPL's output loop. It runs until ctx is cancelled and
// reacts to the following events:
//
//   - a context on startFingerprint / startStats opens the corresponding
//     plugin stream using that context
//   - a message on an open stream is pretty-printed to the UI
//   - a list of IDs on reserve triggers a Reserve call whose result or error
//     is printed
//
// A failure to open a stream, or a nil message on a stream, terminates the
// whole process with exit status 1. A closed stream is reported and then
// ignored until it is started again.
func (c *Device) replOutput(ctx context.Context, startFingerprint, startStats <-chan context.Context, reserve <-chan []string) {
	// Both stream channels start out nil; a receive on a nil channel blocks
	// forever, so their select cases stay dormant until a stream is opened.
	var (
		fingerprintCh <-chan *device.FingerprintResponse
		statsCh       <-chan *device.StatsResponse
	)

	for {
		select {
		case <-ctx.Done():
			return

		case streamCtx := <-startFingerprint:
			out, err := c.dev.Fingerprint(streamCtx)
			if err != nil {
				c.Ui.Error(fmt.Sprintf("fingerprint: %s", err))
				os.Exit(1)
			}
			fingerprintCh = out

		case msg, open := <-fingerprintCh:
			switch {
			case !open:
				c.Ui.Output("> fingerprint: fingerprint output closed")
				fingerprintCh = nil
			case msg == nil:
				c.Ui.Warn("> fingerprint: received nil result")
				os.Exit(1)
			default:
				c.Ui.Output(fmt.Sprintf("> fingerprint: % #v", pretty.Formatter(msg)))
			}

		case streamCtx := <-startStats:
			out, err := c.dev.Stats(streamCtx, 1*time.Second)
			if err != nil {
				c.Ui.Error(fmt.Sprintf("stats: %s", err))
				os.Exit(1)
			}
			statsCh = out

		case msg, open := <-statsCh:
			switch {
			case !open:
				c.Ui.Output("> stats: stats output closed")
				statsCh = nil
			case msg == nil:
				c.Ui.Warn("> stats: received nil result")
				os.Exit(1)
			default:
				c.Ui.Output(fmt.Sprintf("> stats: % #v", pretty.Formatter(msg)))
			}

		case deviceIDs := <-reserve:
			joined := strings.Join(deviceIDs, ", ")
			reservation, err := c.dev.Reserve(deviceIDs)
			if err != nil {
				c.Ui.Warn(fmt.Sprintf("> reserve(%s): %v", joined, err))
				continue
			}
			c.Ui.Output(fmt.Sprintf("> reserve(%s): % #v", joined, pretty.Formatter(reservation)))
		}
	}
}