// open-nomad/plugins/device/cmd/nvidia/device.go
//
// 145 lines
// 4.2 KiB
// Go

package nvidia
import (
"context"
"fmt"
"sync"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device"
"github.com/hashicorp/nomad/plugins/device/cmd/nvidia/nvml"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
)
const (
	// pluginName is the name this plugin reports in pluginInfo and uses to
	// name its logger.
	pluginName = "nvidia-gpu"

	// vendor identifies who provides the devices.
	vendor = "nvidia"

	// deviceType is the kind of device this plugin returns.
	deviceType = device.DeviceTypeGPU

	// notAvailable is the placeholder sent to the nomad server for any
	// property the nvml driver could not detect.
	notAvailable = "N/A"
)
var (
	// pluginInfo is the static description handed back by PluginInfo.
	pluginInfo = &base.PluginInfoResponse{
		Type: base.PluginTypeDevice,
		// XXX: PluginApiVersion should be an array, and the versions should
		// be declared as constants.
		PluginApiVersion: "0.0.1",
		PluginVersion:    "0.1.0",
		Name:             pluginName,
	}

	// configSpec is the schema of the plugin's configuration, handed back by
	// ConfigSchema. Both attributes are optional and carry a default.
	configSpec = hclspec.NewObject(map[string]*hclspec.Spec{
		// ignored_gpu_ids: list of strings, defaults to the empty list.
		"ignored_gpu_ids": hclspec.NewDefault(
			hclspec.NewAttr("ignored_gpu_ids", "list(string)", false),
			hclspec.NewLiteral("[]"),
		),
		// fingerprint_period: duration written as a string, defaults to "5s".
		"fingerprint_period": hclspec.NewDefault(
			hclspec.NewAttr("fingerprint_period", "string", false),
			hclspec.NewLiteral("\"5s\""),
		),
	})
)
// Config is the decoded form of the plugin's configuration.
type Config struct {
	// IgnoredGPUIDs is populated from the "ignored_gpu_ids" key.
	IgnoredGPUIDs []string `codec:"ignored_gpu_ids"`

	// FingerprintPeriod is populated from the "fingerprint_period" key. It is
	// kept as the raw string; SetConfig converts it with time.ParseDuration.
	FingerprintPeriod string `codec:"fingerprint_period"`
}
// NvidiaDevice holds all state owned by the nvidia device plugin.
type NvidiaDevice struct {
	// nvmlClient is the handle used to query nvidia for data.
	nvmlClient nvml.NvmlClient

	// nvmlClientInitializationError records the error, if any, that was
	// returned while creating nvmlClient.
	nvmlClientInitializationError error

	// ignoredGPUIDs is the set of UUIDs that must not be exposed to nomad.
	ignoredGPUIDs map[string]struct{}

	// fingerprintPeriod is the interval between nvml calls that list devices.
	fingerprintPeriod time.Duration

	// devices is the set of detected eligible devices.
	devices map[string]struct{}

	// deviceLock presumably guards devices; the code that takes it is not in
	// this part of the file -- TODO confirm.
	deviceLock sync.RWMutex

	// logger is the plugin's logger, named after pluginName by
	// NewNvidiaDevice.
	logger log.Logger
}
// NewNvidiaDevice builds the nvidia device plugin around a newly created nvml
// client. A failure to create that client is not returned to the caller: it
// is logged and stored on the plugin in nvmlClientInitializationError.
//
// Both ID sets start out empty and the fingerprint period is left at its zero
// value until SetConfig is called.
func NewNvidiaDevice(log log.Logger) *NvidiaDevice {
	client, initErr := nvml.NewNvmlClient()

	// Inside this function "log" is the parameter, which shadows the hclog
	// package alias of the same name.
	pluginLogger := log.Named(pluginName)
	if initErr != nil {
		pluginLogger.Error("unable to initialize Nvidia driver", "error", initErr)
	}

	plugin := &NvidiaDevice{
		nvmlClient:                    client,
		nvmlClientInitializationError: initErr,
		ignoredGPUIDs:                 make(map[string]struct{}),
		devices:                       make(map[string]struct{}),
		logger:                        pluginLogger,
	}
	return plugin
}
// PluginInfo reports the plugin's static description. Every call returns the
// same package-level pluginInfo pointer, and the error is always nil.
func (d *NvidiaDevice) PluginInfo() (*base.PluginInfoResponse, error) {
	info := pluginInfo
	return info, nil
}
// ConfigSchema reports the schema of the plugin's configuration. Every call
// returns the same package-level configSpec pointer, and the error is always
// nil.
func (d *NvidiaDevice) ConfigSchema() (*hclspec.Spec, error) {
	schema := configSpec
	return schema, nil
}
// SetConfig decodes the msgpack-encoded configuration in data and applies it
// to the plugin.
//
// The configuration is fully validated before any plugin state is modified,
// so a rejected configuration leaves the previously applied settings intact.
// On success the set of ignored GPU IDs is replaced by the configured list
// instead of being merged into IDs left over from an earlier call.
//
// It returns the decode error unchanged, or a descriptive error when
// fingerprint_period is not a valid time.ParseDuration string.
func (d *NvidiaDevice) SetConfig(data []byte) error {
	var config Config
	if err := base.MsgPackDecode(data, &config); err != nil {
		return err
	}

	// Parse the period before touching d so that a bad value cannot leave the
	// plugin half-configured.
	period, err := time.ParseDuration(config.FingerprintPeriod)
	if err != nil {
		return fmt.Errorf("failed to parse fingerprint period %q: %v", config.FingerprintPeriod, err)
	}

	// Build a fresh set so that IDs dropped from the configuration stop being
	// ignored, and so that a nil map on d cannot cause a panic.
	ignored := make(map[string]struct{}, len(config.IgnoredGPUIDs))
	for _, id := range config.IgnoredGPUIDs {
		ignored[id] = struct{}{}
	}

	d.ignoredGPUIDs = ignored
	d.fingerprintPeriod = period
	return nil
}
// Fingerprint streams detected devices: messages are emitted when device
// changes are detected or when the devices' health changes.
//
// It creates an unbuffered channel, starts d.fingerprint in its own goroutine
// with ctx and that channel, and returns the receive side. The returned error
// is always nil.
func (d *NvidiaDevice) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) {
	responses := make(chan *device.FingerprintResponse)

	go d.fingerprint(ctx, responses)

	return responses, nil
}
// Reserve is meant to return information on how to mount the given devices.
// It is not implemented yet: deviceIDs is ignored and the call always returns
// a nil reservation together with a nil error.
func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation, error) {
	var reservation *device.ContainerReservation
	return reservation, nil
}
// Stats is meant to stream statistics for the detected devices. It is not
// implemented yet: ctx is ignored and the call always returns a nil channel
// together with a nil error. A receive on a nil channel blocks forever, so
// callers cannot range over the result.
func (d *NvidiaDevice) Stats(ctx context.Context) (<-chan *device.StatsResponse, error) {
	var statsCh <-chan *device.StatsResponse
	return statsCh, nil
}