open-nomad/devices/gpu/nvidia/device_test.go

141 lines
3.1 KiB
Go

package nvidia
import (
"testing"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/devices/gpu/nvidia/nvml"
"github.com/hashicorp/nomad/plugins/device"
"github.com/stretchr/testify/require"
)
// MockNvmlClient is a test double whose methods hand back the canned values
// stored in its fields instead of querying real hardware.
// NOTE(review): presumably satisfies the NVML client interface consumed by
// NvidiaDevice — confirm against the nvml package.
type MockNvmlClient struct {
	// FingerprintError is the error returned by GetFingerprintData.
	FingerprintError error
	// FingerprintResponseReturned is the data returned by GetFingerprintData.
	FingerprintResponseReturned *nvml.FingerprintData

	// StatsError is the error returned by GetStatsData.
	StatsError error
	// StatsResponseReturned is the data returned by GetStatsData.
	StatsResponseReturned []*nvml.StatsData
}
// GetFingerprintData returns the preconfigured FingerprintResponseReturned
// together with the preconfigured FingerprintError, unmodified.
func (c *MockNvmlClient) GetFingerprintData() (*nvml.FingerprintData, error) {
	fingerprint, err := c.FingerprintResponseReturned, c.FingerprintError
	return fingerprint, err
}
// GetStatsData returns the preconfigured StatsResponseReturned together with
// the preconfigured StatsError, unmodified.
func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) {
	stats, err := c.StatsResponseReturned, c.StatsError
	return stats, err
}
func TestReserve(t *testing.T) {
cases := []struct {
Name string
ExpectedReservation *device.ContainerReservation
ExpectedError error
Device *NvidiaDevice
RequestedIDs []string
}{
{
Name: "All RequestedIDs are not managed by Device",
ExpectedReservation: nil,
ExpectedError: &reservationError{[]string{
"UUID1",
"UUID2",
"UUID3",
}},
RequestedIDs: []string{
"UUID1",
"UUID2",
"UUID3",
},
Device: &NvidiaDevice{
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Name: "Some RequestedIDs are not managed by Device",
ExpectedReservation: nil,
ExpectedError: &reservationError{[]string{
"UUID1",
"UUID2",
}},
RequestedIDs: []string{
"UUID1",
"UUID2",
"UUID3",
},
Device: &NvidiaDevice{
devices: map[string]struct{}{
"UUID3": {},
},
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Name: "All RequestedIDs are managed by Device",
ExpectedReservation: &device.ContainerReservation{
Envs: map[string]string{
NvidiaVisibleDevices: "UUID1,UUID2,UUID3",
},
},
ExpectedError: nil,
RequestedIDs: []string{
"UUID1",
"UUID2",
"UUID3",
},
Device: &NvidiaDevice{
devices: map[string]struct{}{
"UUID1": {},
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Name: "No IDs requested",
ExpectedReservation: &device.ContainerReservation{},
ExpectedError: nil,
RequestedIDs: nil,
Device: &NvidiaDevice{
devices: map[string]struct{}{
"UUID1": {},
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Name: "Device is disabled",
ExpectedReservation: nil,
ExpectedError: device.ErrPluginDisabled,
RequestedIDs: []string{
"UUID1",
"UUID2",
"UUID3",
},
Device: &NvidiaDevice{
devices: map[string]struct{}{
"UUID1": {},
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
enabled: false,
},
},
}
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
actualReservation, actualError := c.Device.Reserve(c.RequestedIDs)
require.Equal(t, c.ExpectedReservation, actualReservation)
require.Equal(t, c.ExpectedError, actualError)
})
}
}