141 lines
3.1 KiB
Go
141 lines
3.1 KiB
Go
package nvidia
|
|
|
|
import (
|
|
"testing"
|
|
|
|
hclog "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/nomad/devices/gpu/nvidia/nvml"
|
|
"github.com/hashicorp/nomad/plugins/device"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type MockNvmlClient struct {
|
|
FingerprintError error
|
|
FingerprintResponseReturned *nvml.FingerprintData
|
|
|
|
StatsError error
|
|
StatsResponseReturned []*nvml.StatsData
|
|
}
|
|
|
|
func (c *MockNvmlClient) GetFingerprintData() (*nvml.FingerprintData, error) {
|
|
return c.FingerprintResponseReturned, c.FingerprintError
|
|
}
|
|
|
|
func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) {
|
|
return c.StatsResponseReturned, c.StatsError
|
|
}
|
|
|
|
func TestReserve(t *testing.T) {
|
|
cases := []struct {
|
|
Name string
|
|
ExpectedReservation *device.ContainerReservation
|
|
ExpectedError error
|
|
Device *NvidiaDevice
|
|
RequestedIDs []string
|
|
}{
|
|
{
|
|
Name: "All RequestedIDs are not managed by Device",
|
|
ExpectedReservation: nil,
|
|
ExpectedError: &reservationError{[]string{
|
|
"UUID1",
|
|
"UUID2",
|
|
"UUID3",
|
|
}},
|
|
RequestedIDs: []string{
|
|
"UUID1",
|
|
"UUID2",
|
|
"UUID3",
|
|
},
|
|
Device: &NvidiaDevice{
|
|
logger: hclog.NewNullLogger(),
|
|
enabled: true,
|
|
},
|
|
},
|
|
{
|
|
Name: "Some RequestedIDs are not managed by Device",
|
|
ExpectedReservation: nil,
|
|
ExpectedError: &reservationError{[]string{
|
|
"UUID1",
|
|
"UUID2",
|
|
}},
|
|
RequestedIDs: []string{
|
|
"UUID1",
|
|
"UUID2",
|
|
"UUID3",
|
|
},
|
|
Device: &NvidiaDevice{
|
|
devices: map[string]struct{}{
|
|
"UUID3": {},
|
|
},
|
|
logger: hclog.NewNullLogger(),
|
|
enabled: true,
|
|
},
|
|
},
|
|
{
|
|
Name: "All RequestedIDs are managed by Device",
|
|
ExpectedReservation: &device.ContainerReservation{
|
|
Envs: map[string]string{
|
|
NvidiaVisibleDevices: "UUID1,UUID2,UUID3",
|
|
},
|
|
},
|
|
ExpectedError: nil,
|
|
RequestedIDs: []string{
|
|
"UUID1",
|
|
"UUID2",
|
|
"UUID3",
|
|
},
|
|
Device: &NvidiaDevice{
|
|
devices: map[string]struct{}{
|
|
"UUID1": {},
|
|
"UUID2": {},
|
|
"UUID3": {},
|
|
},
|
|
logger: hclog.NewNullLogger(),
|
|
enabled: true,
|
|
},
|
|
},
|
|
{
|
|
Name: "No IDs requested",
|
|
ExpectedReservation: &device.ContainerReservation{},
|
|
ExpectedError: nil,
|
|
RequestedIDs: nil,
|
|
Device: &NvidiaDevice{
|
|
devices: map[string]struct{}{
|
|
"UUID1": {},
|
|
"UUID2": {},
|
|
"UUID3": {},
|
|
},
|
|
logger: hclog.NewNullLogger(),
|
|
enabled: true,
|
|
},
|
|
},
|
|
{
|
|
Name: "Device is disabled",
|
|
ExpectedReservation: nil,
|
|
ExpectedError: device.ErrPluginDisabled,
|
|
RequestedIDs: []string{
|
|
"UUID1",
|
|
"UUID2",
|
|
"UUID3",
|
|
},
|
|
Device: &NvidiaDevice{
|
|
devices: map[string]struct{}{
|
|
"UUID1": {},
|
|
"UUID2": {},
|
|
"UUID3": {},
|
|
},
|
|
logger: hclog.NewNullLogger(),
|
|
enabled: false,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.Name, func(t *testing.T) {
|
|
actualReservation, actualError := c.Device.Reserve(c.RequestedIDs)
|
|
require.Equal(t, c.ExpectedReservation, actualReservation)
|
|
require.Equal(t, c.ExpectedError, actualError)
|
|
})
|
|
}
|
|
}
|