369 lines
8.7 KiB
Go
369 lines
8.7 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package command
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"time"
|
|
|
|
hclog "github.com/hashicorp/go-hclog"
|
|
multierror "github.com/hashicorp/go-multierror"
|
|
plugin "github.com/hashicorp/go-plugin"
|
|
"github.com/hashicorp/hcl"
|
|
"github.com/hashicorp/hcl/hcl/ast"
|
|
"github.com/hashicorp/hcl/v2/hcldec"
|
|
"github.com/hashicorp/nomad/helper/pluginutils/hclspecutils"
|
|
"github.com/hashicorp/nomad/helper/pluginutils/hclutils"
|
|
"github.com/hashicorp/nomad/plugins/base"
|
|
"github.com/hashicorp/nomad/plugins/device"
|
|
"github.com/kr/pretty"
|
|
"github.com/mitchellh/cli"
|
|
"github.com/zclconf/go-cty/cty/msgpack"
|
|
)
|
|
|
|
func DeviceCommandFactory(meta Meta) cli.CommandFactory {
|
|
return func() (cli.Command, error) {
|
|
return &Device{Meta: meta}, nil
|
|
}
|
|
}
|
|
|
|
type Device struct {
|
|
Meta
|
|
|
|
// dev is the plugin device
|
|
dev device.DevicePlugin
|
|
|
|
// spec is the returned and parsed spec.
|
|
spec hcldec.Spec
|
|
}
|
|
|
|
func (c *Device) Help() string {
|
|
helpText := `
|
|
Usage: nomad-plugin-launcher device <device-binary> <config_file>
|
|
|
|
Device launches the given device binary and provides a REPL for interacting
|
|
with it.
|
|
|
|
General Options:
|
|
|
|
` + generalOptionsUsage() + `
|
|
|
|
Device Options:
|
|
|
|
-trace
|
|
Enable trace level log output.
|
|
`
|
|
|
|
return strings.TrimSpace(helpText)
|
|
}
|
|
|
|
func (c *Device) Synopsis() string {
|
|
return "REPL for interacting with device plugins"
|
|
}
|
|
|
|
func (c *Device) Run(args []string) int {
|
|
var trace bool
|
|
cmdFlags := c.FlagSet("device")
|
|
cmdFlags.Usage = func() { c.Ui.Output(c.Help()) }
|
|
cmdFlags.BoolVar(&trace, "trace", false, "")
|
|
|
|
if err := cmdFlags.Parse(args); err != nil {
|
|
c.logger.Error("failed to parse flags:", "error", err)
|
|
return 1
|
|
}
|
|
if trace {
|
|
c.logger.SetLevel(hclog.Trace)
|
|
} else if c.verbose {
|
|
c.logger.SetLevel(hclog.Debug)
|
|
}
|
|
|
|
args = cmdFlags.Args()
|
|
numArgs := len(args)
|
|
if numArgs < 1 {
|
|
c.logger.Error("expected at least 1 args (device binary)", "args", args)
|
|
return 1
|
|
} else if numArgs > 2 {
|
|
c.logger.Error("expected at most 2 args (device binary and config file)", "args", args)
|
|
return 1
|
|
}
|
|
|
|
binary := args[0]
|
|
var config []byte
|
|
if numArgs == 2 {
|
|
var err error
|
|
config, err = os.ReadFile(args[1])
|
|
if err != nil {
|
|
c.logger.Error("failed to read config file", "error", err)
|
|
return 1
|
|
}
|
|
}
|
|
|
|
// Get the plugin
|
|
dev, cleanup, err := c.getDevicePlugin(binary)
|
|
if err != nil {
|
|
c.logger.Error("failed to launch device plugin", "error", err)
|
|
return 1
|
|
}
|
|
defer cleanup()
|
|
c.dev = dev
|
|
|
|
spec, err := c.getSpec()
|
|
if err != nil {
|
|
c.logger.Error("failed to get config spec", "error", err)
|
|
return 1
|
|
}
|
|
c.spec = spec
|
|
|
|
if err := c.setConfig(spec, device.ApiVersion010, config, nil); err != nil {
|
|
c.logger.Error("failed to set config", "error", err)
|
|
return 1
|
|
}
|
|
|
|
if err := c.startRepl(); err != nil {
|
|
c.logger.Error("error interacting with plugin", "error", err)
|
|
return 1
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
func (c *Device) getDevicePlugin(binary string) (device.DevicePlugin, func(), error) {
|
|
// Launch the plugin
|
|
client := plugin.NewClient(&plugin.ClientConfig{
|
|
HandshakeConfig: base.Handshake,
|
|
Plugins: map[string]plugin.Plugin{
|
|
base.PluginTypeBase: &base.PluginBase{},
|
|
base.PluginTypeDevice: &device.PluginDevice{},
|
|
},
|
|
Cmd: exec.Command(binary),
|
|
AllowedProtocols: []plugin.Protocol{plugin.ProtocolGRPC},
|
|
Logger: c.logger,
|
|
})
|
|
|
|
// Connect via RPC
|
|
rpcClient, err := client.Client()
|
|
if err != nil {
|
|
client.Kill()
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Request the plugin
|
|
raw, err := rpcClient.Dispense(base.PluginTypeDevice)
|
|
if err != nil {
|
|
client.Kill()
|
|
return nil, nil, err
|
|
}
|
|
|
|
// We should have a KV store now! This feels like a normal interface
|
|
// implementation but is in fact over an RPC connection.
|
|
dev := raw.(device.DevicePlugin)
|
|
return dev, func() { client.Kill() }, nil
|
|
}
|
|
|
|
func (c *Device) getSpec() (hcldec.Spec, error) {
|
|
// Get the schema so we can parse the config
|
|
spec, err := c.dev.ConfigSchema()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get config schema: %v", err)
|
|
}
|
|
|
|
// Convert the schema
|
|
schema, diag := hclspecutils.Convert(spec)
|
|
if diag.HasErrors() {
|
|
errStr := "failed to convert HCL schema: "
|
|
for _, err := range diag.Errs() {
|
|
errStr = fmt.Sprintf("%s\n* %s", errStr, err.Error())
|
|
}
|
|
return nil, errors.New(errStr)
|
|
}
|
|
|
|
return schema, nil
|
|
}
|
|
|
|
func (c *Device) setConfig(spec hcldec.Spec, apiVersion string, config []byte, nmdCfg *base.AgentConfig) error {
|
|
// Parse the config into hcl
|
|
configVal, err := hclConfigToInterface(config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
val, diag, diagErrs := hclutils.ParseHclInterface(configVal, spec, nil)
|
|
if diag.HasErrors() {
|
|
return multierror.Append(errors.New("failed to parse config: "), diagErrs...)
|
|
}
|
|
|
|
cdata, err := msgpack.Marshal(val, val.Type())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req := &base.Config{
|
|
PluginConfig: cdata,
|
|
AgentConfig: nmdCfg,
|
|
ApiVersion: apiVersion,
|
|
}
|
|
|
|
if err := c.dev.SetConfig(req); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func hclConfigToInterface(config []byte) (interface{}, error) {
|
|
if len(config) == 0 {
|
|
return map[string]interface{}{}, nil
|
|
}
|
|
|
|
// Parse as we do in the jobspec parser
|
|
root, err := hcl.Parse(string(config))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to hcl parse the config: %v", err)
|
|
}
|
|
|
|
// Top-level item should be a list
|
|
list, ok := root.Node.(*ast.ObjectList)
|
|
if !ok {
|
|
return nil, fmt.Errorf("root should be an object")
|
|
}
|
|
|
|
var m map[string]interface{}
|
|
if err := hcl.DecodeObject(&m, list.Items[0]); err != nil {
|
|
return nil, fmt.Errorf("failed to decode object: %v", err)
|
|
}
|
|
|
|
return m["config"], nil
|
|
}
|
|
|
|
func (c *Device) startRepl() error {
|
|
// Start the output goroutine
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
fingerprint := make(chan context.Context)
|
|
stats := make(chan context.Context)
|
|
reserve := make(chan []string)
|
|
go c.replOutput(ctx, fingerprint, stats, reserve)
|
|
|
|
c.Ui.Output("> Availabile commands are: exit(), fingerprint(), stop_fingerprint(), stats(), stop_stats(), reserve(id1, id2, ...)")
|
|
var fingerprintCtx, statsCtx context.Context
|
|
var fingerprintCancel, statsCancel context.CancelFunc
|
|
|
|
for {
|
|
in, err := c.Ui.Ask("> ")
|
|
if err != nil {
|
|
if fingerprintCancel != nil {
|
|
fingerprintCancel()
|
|
}
|
|
if statsCancel != nil {
|
|
statsCancel()
|
|
}
|
|
return err
|
|
}
|
|
|
|
switch {
|
|
case in == "exit()":
|
|
if fingerprintCancel != nil {
|
|
fingerprintCancel()
|
|
}
|
|
if statsCancel != nil {
|
|
statsCancel()
|
|
}
|
|
return nil
|
|
case in == "fingerprint()":
|
|
if fingerprintCtx != nil {
|
|
continue
|
|
}
|
|
fingerprintCtx, fingerprintCancel = context.WithCancel(ctx)
|
|
fingerprint <- fingerprintCtx
|
|
case in == "stop_fingerprint()":
|
|
if fingerprintCtx == nil {
|
|
continue
|
|
}
|
|
fingerprintCancel()
|
|
fingerprintCtx = nil
|
|
case in == "stats()":
|
|
if statsCtx != nil {
|
|
continue
|
|
}
|
|
statsCtx, statsCancel = context.WithCancel(ctx)
|
|
stats <- statsCtx
|
|
case in == "stop_stats()":
|
|
if statsCtx == nil {
|
|
continue
|
|
}
|
|
statsCancel()
|
|
statsCtx = nil
|
|
case strings.HasPrefix(in, "reserve(") && strings.HasSuffix(in, ")"):
|
|
listString := strings.TrimSuffix(strings.TrimPrefix(in, "reserve("), ")")
|
|
ids := strings.Split(strings.TrimSpace(listString), ",")
|
|
reserve <- ids
|
|
default:
|
|
c.Ui.Error(fmt.Sprintf("> Unknown command %q", in))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Device) replOutput(ctx context.Context, startFingerprint, startStats <-chan context.Context, reserve <-chan []string) {
|
|
var fingerprint <-chan *device.FingerprintResponse
|
|
var stats <-chan *device.StatsResponse
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case ctx := <-startFingerprint:
|
|
var err error
|
|
fingerprint, err = c.dev.Fingerprint(ctx)
|
|
if err != nil {
|
|
c.Ui.Error(fmt.Sprintf("fingerprint: %s", err))
|
|
os.Exit(1)
|
|
}
|
|
case resp, ok := <-fingerprint:
|
|
if !ok {
|
|
c.Ui.Output("> fingerprint: fingerprint output closed")
|
|
fingerprint = nil
|
|
continue
|
|
}
|
|
|
|
if resp == nil {
|
|
c.Ui.Warn("> fingerprint: received nil result")
|
|
os.Exit(1)
|
|
}
|
|
|
|
c.Ui.Output(fmt.Sprintf("> fingerprint: % #v", pretty.Formatter(resp)))
|
|
case ctx := <-startStats:
|
|
var err error
|
|
stats, err = c.dev.Stats(ctx, 1*time.Second)
|
|
if err != nil {
|
|
c.Ui.Error(fmt.Sprintf("stats: %s", err))
|
|
os.Exit(1)
|
|
}
|
|
case resp, ok := <-stats:
|
|
if !ok {
|
|
c.Ui.Output("> stats: stats output closed")
|
|
stats = nil
|
|
continue
|
|
}
|
|
|
|
if resp == nil {
|
|
c.Ui.Warn("> stats: received nil result")
|
|
os.Exit(1)
|
|
}
|
|
|
|
c.Ui.Output(fmt.Sprintf("> stats: % #v", pretty.Formatter(resp)))
|
|
case ids := <-reserve:
|
|
resp, err := c.dev.Reserve(ids)
|
|
if err != nil {
|
|
c.Ui.Warn(fmt.Sprintf("> reserve(%s): %v", strings.Join(ids, ", "), err))
|
|
} else {
|
|
c.Ui.Output(fmt.Sprintf("> reserve(%s): % #v", strings.Join(ids, ", "), pretty.Formatter(resp)))
|
|
}
|
|
}
|
|
}
|
|
}
|