open-nomad/devices/gpu/nvidia/device_test.go

115 lines
2.6 KiB
Go
Raw Normal View History

package nvidia
import (
"testing"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/devices/gpu/nvidia/nvml"
"github.com/hashicorp/nomad/plugins/device"
"github.com/stretchr/testify/require"
)
// MockNvmlClient is a scriptable stand-in for an NVML client in tests
// (presumably satisfying the nvml client interface that NvidiaDevice consumes;
// confirm against the nvml package). It performs no real NVML calls: each
// accessor hands back whatever canned values are stored in the fields below.
type MockNvmlClient struct {
	// FingerprintError is the error returned by GetFingerprintData.
	FingerprintError error
	// FingerprintResponseReturned is the data returned by GetFingerprintData.
	FingerprintResponseReturned *nvml.FingerprintData

	// StatsError is the error returned by GetStatsData.
	StatsError error
	// StatsResponseReturned is the data returned by GetStatsData.
	StatsResponseReturned []*nvml.StatsData
}
// GetFingerprintData returns the canned fingerprint payload and error that
// were configured on the mock, without inspecting or modifying either.
func (c *MockNvmlClient) GetFingerprintData() (*nvml.FingerprintData, error) {
	data := c.FingerprintResponseReturned
	err := c.FingerprintError
	return data, err
}
// GetStatsData returns the canned stats slice and error that were configured
// on the mock. The slice is returned as-is (not copied).
func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) {
	stats := c.StatsResponseReturned
	err := c.StatsError
	return stats, err
}
// TestReserve drives NvidiaDevice.Reserve through a table of cases and checks
// both the returned ContainerReservation and the returned error.
//
// Each case runs as its own subtest named after testCase.Name. A failure
// therefore identifies the offending case, and one failing case does not stop
// the remaining cases from running.
func TestReserve(t *testing.T) {
	for _, testCase := range []struct {
		Name                string
		ExpectedReservation *device.ContainerReservation
		ExpectedError       error
		Device              *NvidiaDevice
		RequestedIDs        []string
	}{
		{
			// The device has no managed IDs, so every requested ID is rejected.
			Name:                "All RequestedIDs are not managed by Device",
			ExpectedReservation: nil,
			ExpectedError: &reservationError{[]string{
				"UUID1",
				"UUID2",
				"UUID3",
			}},
			RequestedIDs: []string{
				"UUID1",
				"UUID2",
				"UUID3",
			},
			Device: &NvidiaDevice{
				logger: hclog.NewNullLogger(),
			},
		},
		{
			// Only UUID3 is managed; the error lists just the unmanaged IDs.
			Name:                "Some RequestedIDs are not managed by Device",
			ExpectedReservation: nil,
			ExpectedError: &reservationError{[]string{
				"UUID1",
				"UUID2",
			}},
			RequestedIDs: []string{
				"UUID1",
				"UUID2",
				"UUID3",
			},
			Device: &NvidiaDevice{
				devices: map[string]struct{}{
					"UUID3": {},
				},
				logger: hclog.NewNullLogger(),
			},
		},
		{
			// Every requested ID is managed, so they are exposed via the env var.
			Name: "All RequestedIDs are managed by Device",
			ExpectedReservation: &device.ContainerReservation{
				Envs: map[string]string{
					NvidiaVisibleDevices: "UUID1,UUID2,UUID3",
				},
			},
			ExpectedError: nil,
			RequestedIDs: []string{
				"UUID1",
				"UUID2",
				"UUID3",
			},
			Device: &NvidiaDevice{
				devices: map[string]struct{}{
					"UUID1": {},
					"UUID2": {},
					"UUID3": {},
				},
				logger: hclog.NewNullLogger(),
			},
		},
		{
			// A nil request yields an empty (non-nil) reservation and no error.
			Name:                "No IDs requested",
			ExpectedReservation: &device.ContainerReservation{},
			ExpectedError:       nil,
			RequestedIDs:        nil,
			Device: &NvidiaDevice{
				devices: map[string]struct{}{
					"UUID1": {},
					"UUID2": {},
					"UUID3": {},
				},
				logger: hclog.NewNullLogger(),
			},
		},
	} {
		// Subtests run synchronously (no t.Parallel), so capturing the loop
		// variable in the closure is safe on every Go version.
		t.Run(testCase.Name, func(t *testing.T) {
			req := require.New(t)
			actualReservation, actualError := testCase.Device.Reserve(testCase.RequestedIDs)
			req.Equal(testCase.ExpectedReservation, actualReservation)
			req.Equal(testCase.ExpectedError, actualError)
		})
	}
}