2018-08-13 17:29:29 +00:00
|
|
|
package device
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
pb "github.com/golang/protobuf/proto"
|
|
|
|
plugin "github.com/hashicorp/go-plugin"
|
2018-10-13 18:43:06 +00:00
|
|
|
"github.com/hashicorp/nomad/helper"
|
2018-08-13 17:29:29 +00:00
|
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
|
|
"github.com/hashicorp/nomad/plugins/base"
|
|
|
|
"github.com/hashicorp/nomad/plugins/shared/hclspec"
|
2018-09-28 17:09:01 +00:00
|
|
|
psstructs "github.com/hashicorp/nomad/plugins/shared/structs"
|
2018-08-13 17:29:29 +00:00
|
|
|
"github.com/hashicorp/nomad/testutil"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/zclconf/go-cty/cty"
|
|
|
|
"github.com/zclconf/go-cty/cty/msgpack"
|
|
|
|
"google.golang.org/grpc/status"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestDevicePlugin_PluginInfo(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
2018-12-18 00:40:58 +00:00
|
|
|
var (
|
|
|
|
apiVersions = []string{"v0.1.0", "v0.2.0"}
|
|
|
|
)
|
|
|
|
|
2018-08-13 17:29:29 +00:00
|
|
|
const (
|
|
|
|
pluginVersion = "v0.2.1"
|
2018-11-13 02:39:11 +00:00
|
|
|
pluginName = "mock_device"
|
2018-08-13 17:29:29 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
knownType := func() (*base.PluginInfoResponse, error) {
|
|
|
|
info := &base.PluginInfoResponse{
|
2018-12-18 00:40:58 +00:00
|
|
|
Type: base.PluginTypeDevice,
|
|
|
|
PluginApiVersions: apiVersions,
|
|
|
|
PluginVersion: pluginVersion,
|
|
|
|
Name: pluginName,
|
2018-08-13 17:29:29 +00:00
|
|
|
}
|
|
|
|
return info, nil
|
|
|
|
}
|
|
|
|
unknownType := func() (*base.PluginInfoResponse, error) {
|
|
|
|
info := &base.PluginInfoResponse{
|
2018-12-18 00:40:58 +00:00
|
|
|
Type: "bad",
|
|
|
|
PluginApiVersions: apiVersions,
|
|
|
|
PluginVersion: pluginVersion,
|
|
|
|
Name: pluginName,
|
2018-08-13 17:29:29 +00:00
|
|
|
}
|
|
|
|
return info, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
mock := &MockDevicePlugin{
|
|
|
|
MockPlugin: &base.MockPlugin{
|
|
|
|
PluginInfoF: knownType,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := impl.PluginInfo()
|
|
|
|
require.NoError(err)
|
2018-12-18 00:40:58 +00:00
|
|
|
require.Equal(apiVersions, resp.PluginApiVersions)
|
2018-08-13 17:29:29 +00:00
|
|
|
require.Equal(pluginVersion, resp.PluginVersion)
|
|
|
|
require.Equal(pluginName, resp.Name)
|
|
|
|
require.Equal(base.PluginTypeDevice, resp.Type)
|
|
|
|
|
|
|
|
// Swap the implementation to return an unknown type
|
|
|
|
mock.PluginInfoF = unknownType
|
|
|
|
_, err = impl.PluginInfo()
|
|
|
|
require.Error(err)
|
|
|
|
require.Contains(err.Error(), "unknown type")
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_ConfigSchema(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
mock := &MockDevicePlugin{
|
|
|
|
MockPlugin: &base.MockPlugin{
|
|
|
|
ConfigSchemaF: func() (*hclspec.Spec, error) {
|
|
|
|
return base.TestSpec, nil
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
specOut, err := impl.ConfigSchema()
|
|
|
|
require.NoError(err)
|
|
|
|
require.True(pb.Equal(base.TestSpec, specOut))
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_SetConfig(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
var receivedData []byte
|
|
|
|
mock := &MockDevicePlugin{
|
|
|
|
MockPlugin: &base.MockPlugin{
|
2018-11-09 03:36:25 +00:00
|
|
|
PluginInfoF: func() (*base.PluginInfoResponse, error) {
|
|
|
|
return &base.PluginInfoResponse{
|
2018-12-18 00:40:58 +00:00
|
|
|
Type: base.PluginTypeDevice,
|
|
|
|
PluginApiVersions: []string{"v0.0.1"},
|
|
|
|
PluginVersion: "v0.0.1",
|
|
|
|
Name: "mock_device",
|
2018-11-09 03:36:25 +00:00
|
|
|
}, nil
|
|
|
|
},
|
2018-08-13 17:29:29 +00:00
|
|
|
ConfigSchemaF: func() (*hclspec.Spec, error) {
|
|
|
|
return base.TestSpec, nil
|
|
|
|
},
|
2018-12-18 00:40:58 +00:00
|
|
|
SetConfigF: func(cfg *base.Config) error {
|
|
|
|
receivedData = cfg.PluginConfig
|
2018-08-13 17:29:29 +00:00
|
|
|
return nil
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
config := cty.ObjectVal(map[string]cty.Value{
|
|
|
|
"foo": cty.StringVal("v1"),
|
|
|
|
"bar": cty.NumberIntVal(1337),
|
|
|
|
"baz": cty.BoolVal(true),
|
|
|
|
})
|
|
|
|
cdata, err := msgpack.Marshal(config, config.Type())
|
|
|
|
require.NoError(err)
|
2018-12-18 00:40:58 +00:00
|
|
|
require.NoError(impl.SetConfig(&base.Config{PluginConfig: cdata}))
|
2018-08-13 17:29:29 +00:00
|
|
|
require.Equal(cdata, receivedData)
|
|
|
|
|
|
|
|
// Decode the value back
|
|
|
|
var actual base.TestConfig
|
|
|
|
require.NoError(structs.Decode(receivedData, &actual))
|
|
|
|
require.Equal("v1", actual.Foo)
|
|
|
|
require.EqualValues(1337, actual.Bar)
|
|
|
|
require.True(actual.Baz)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_Fingerprint(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
devices1 := []*DeviceGroup{
|
|
|
|
{
|
|
|
|
Vendor: "nvidia",
|
|
|
|
Type: DeviceTypeGPU,
|
|
|
|
Name: "foo",
|
2018-10-13 18:43:06 +00:00
|
|
|
Attributes: map[string]*psstructs.Attribute{
|
|
|
|
"memory": {
|
|
|
|
Int: helper.Int64ToPtr(4),
|
|
|
|
Unit: "GiB",
|
|
|
|
},
|
|
|
|
},
|
2018-08-13 17:29:29 +00:00
|
|
|
},
|
|
|
|
}
|
|
|
|
devices2 := []*DeviceGroup{
|
|
|
|
{
|
|
|
|
Vendor: "nvidia",
|
|
|
|
Type: DeviceTypeGPU,
|
|
|
|
Name: "foo",
|
|
|
|
},
|
|
|
|
{
|
|
|
|
Vendor: "nvidia",
|
|
|
|
Type: DeviceTypeGPU,
|
|
|
|
Name: "bar",
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
mock := &MockDevicePlugin{
|
|
|
|
FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
|
|
|
|
outCh := make(chan *FingerprintResponse, 1)
|
|
|
|
go func() {
|
|
|
|
// Send two messages
|
|
|
|
for _, devs := range [][]*DeviceGroup{devices1, devices2} {
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
return
|
|
|
|
case outCh <- &FingerprintResponse{Devices: devs}:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
close(outCh)
|
|
|
|
return
|
|
|
|
}()
|
|
|
|
return outCh, nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
// Get the stream
|
|
|
|
stream, err := impl.Fingerprint(ctx)
|
|
|
|
require.NoError(err)
|
|
|
|
|
|
|
|
// Get the first message
|
|
|
|
var first *FingerprintResponse
|
|
|
|
select {
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("timeout")
|
|
|
|
case first = <-stream:
|
|
|
|
}
|
|
|
|
|
|
|
|
require.NoError(first.Error)
|
|
|
|
require.EqualValues(devices1, first.Devices)
|
|
|
|
|
|
|
|
// Get the second message
|
|
|
|
var second *FingerprintResponse
|
|
|
|
select {
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("timeout")
|
|
|
|
case second = <-stream:
|
|
|
|
}
|
|
|
|
|
|
|
|
require.NoError(second.Error)
|
|
|
|
require.EqualValues(devices2, second.Devices)
|
|
|
|
|
|
|
|
select {
|
|
|
|
case _, ok := <-stream:
|
|
|
|
require.False(ok)
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("stream should be closed")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_Fingerprint_StreamErr(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
ferr := fmt.Errorf("mock fingerprinting failed")
|
|
|
|
mock := &MockDevicePlugin{
|
|
|
|
FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
|
|
|
|
outCh := make(chan *FingerprintResponse, 1)
|
|
|
|
go func() {
|
|
|
|
// Send the error
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
return
|
|
|
|
case outCh <- &FingerprintResponse{Error: ferr}:
|
|
|
|
}
|
|
|
|
|
|
|
|
close(outCh)
|
|
|
|
return
|
|
|
|
}()
|
|
|
|
return outCh, nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
// Get the stream
|
|
|
|
stream, err := impl.Fingerprint(ctx)
|
|
|
|
require.NoError(err)
|
|
|
|
|
|
|
|
// Get the first message
|
|
|
|
var first *FingerprintResponse
|
|
|
|
select {
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("timeout")
|
|
|
|
case first = <-stream:
|
|
|
|
}
|
|
|
|
|
|
|
|
errStatus := status.Convert(ferr)
|
|
|
|
require.EqualError(first.Error, errStatus.Err().Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_Fingerprint_CancelCtx(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
mock := &MockDevicePlugin{
|
|
|
|
FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
|
|
|
|
outCh := make(chan *FingerprintResponse, 1)
|
|
|
|
go func() {
|
|
|
|
<-ctx.Done()
|
|
|
|
close(outCh)
|
|
|
|
return
|
|
|
|
}()
|
|
|
|
return outCh, nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
|
|
|
|
// Get the stream
|
|
|
|
stream, err := impl.Fingerprint(ctx)
|
|
|
|
require.NoError(err)
|
|
|
|
|
|
|
|
// Get the first message
|
|
|
|
select {
|
|
|
|
case <-time.After(testutil.Timeout(10 * time.Millisecond)):
|
|
|
|
case _ = <-stream:
|
|
|
|
t.Fatal("bad value")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Cancel the context
|
|
|
|
cancel()
|
|
|
|
|
|
|
|
// Make sure we are done
|
|
|
|
select {
|
|
|
|
case <-time.After(100 * time.Millisecond):
|
|
|
|
t.Fatalf("timeout")
|
|
|
|
case v := <-stream:
|
|
|
|
require.Error(v.Error)
|
|
|
|
require.EqualError(v.Error, context.Canceled.Error())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_Reserve(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
reservation := &ContainerReservation{
|
|
|
|
Envs: map[string]string{
|
|
|
|
"foo": "bar",
|
|
|
|
},
|
|
|
|
Mounts: []*Mount{
|
|
|
|
{
|
|
|
|
TaskPath: "foo",
|
|
|
|
HostPath: "bar",
|
|
|
|
ReadOnly: true,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
Devices: []*DeviceSpec{
|
|
|
|
{
|
|
|
|
TaskPath: "foo",
|
|
|
|
HostPath: "bar",
|
|
|
|
CgroupPerms: "rx",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
var received []string
|
|
|
|
mock := &MockDevicePlugin{
|
|
|
|
ReserveF: func(devices []string) (*ContainerReservation, error) {
|
|
|
|
received = devices
|
|
|
|
return reservation, nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
req := []string{"a", "b"}
|
|
|
|
containerRes, err := impl.Reserve(req)
|
|
|
|
require.NoError(err)
|
|
|
|
require.EqualValues(req, received)
|
|
|
|
require.EqualValues(reservation, containerRes)
|
|
|
|
}
|
2018-08-27 23:11:07 +00:00
|
|
|
|
|
|
|
func TestDevicePlugin_Stats(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
devices1 := []*DeviceGroupStats{
|
|
|
|
{
|
|
|
|
Vendor: "nvidia",
|
|
|
|
Type: DeviceTypeGPU,
|
|
|
|
Name: "foo",
|
|
|
|
InstanceStats: map[string]*DeviceStats{
|
2018-08-28 17:32:03 +00:00
|
|
|
"1": {
|
2018-11-12 00:36:20 +00:00
|
|
|
Summary: &psstructs.StatValue{
|
2018-11-13 15:06:44 +00:00
|
|
|
IntNumeratorVal: helper.Int64ToPtr(10),
|
|
|
|
IntDenominatorVal: helper.Int64ToPtr(20),
|
2018-08-27 23:11:07 +00:00
|
|
|
Unit: "MB",
|
|
|
|
Desc: "Unit test",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
devices2 := []*DeviceGroupStats{
|
|
|
|
{
|
|
|
|
Vendor: "nvidia",
|
|
|
|
Type: DeviceTypeGPU,
|
|
|
|
Name: "foo",
|
|
|
|
InstanceStats: map[string]*DeviceStats{
|
2018-08-28 17:32:03 +00:00
|
|
|
"1": {
|
2018-11-12 00:36:20 +00:00
|
|
|
Summary: &psstructs.StatValue{
|
2018-11-13 15:06:44 +00:00
|
|
|
FloatNumeratorVal: helper.Float64ToPtr(10.0),
|
|
|
|
FloatDenominatorVal: helper.Float64ToPtr(20.0),
|
2018-08-27 23:11:07 +00:00
|
|
|
Unit: "MB",
|
|
|
|
Desc: "Unit test",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
Vendor: "nvidia",
|
|
|
|
Type: DeviceTypeGPU,
|
|
|
|
Name: "bar",
|
|
|
|
InstanceStats: map[string]*DeviceStats{
|
2018-08-28 17:32:03 +00:00
|
|
|
"1": {
|
2018-11-12 00:36:20 +00:00
|
|
|
Summary: &psstructs.StatValue{
|
2018-11-13 15:06:44 +00:00
|
|
|
StringVal: helper.StringToPtr("foo"),
|
2018-08-27 23:11:07 +00:00
|
|
|
Unit: "MB",
|
|
|
|
Desc: "Unit test",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
Vendor: "nvidia",
|
|
|
|
Type: DeviceTypeGPU,
|
|
|
|
Name: "baz",
|
|
|
|
InstanceStats: map[string]*DeviceStats{
|
2018-08-28 17:32:03 +00:00
|
|
|
"1": {
|
2018-11-12 00:36:20 +00:00
|
|
|
Summary: &psstructs.StatValue{
|
2018-11-13 15:06:44 +00:00
|
|
|
BoolVal: helper.BoolToPtr(true),
|
2018-08-27 23:11:07 +00:00
|
|
|
Unit: "MB",
|
|
|
|
Desc: "Unit test",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
mock := &MockDevicePlugin{
|
2018-09-28 17:09:01 +00:00
|
|
|
StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
|
2018-08-27 23:11:07 +00:00
|
|
|
outCh := make(chan *StatsResponse, 1)
|
|
|
|
go func() {
|
|
|
|
// Send two messages
|
|
|
|
for _, devs := range [][]*DeviceGroupStats{devices1, devices2} {
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
return
|
|
|
|
case outCh <- &StatsResponse{Groups: devs}:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
close(outCh)
|
|
|
|
return
|
|
|
|
}()
|
|
|
|
return outCh, nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
// Get the stream
|
2018-09-28 17:09:01 +00:00
|
|
|
stream, err := impl.Stats(ctx, time.Millisecond)
|
2018-08-27 23:11:07 +00:00
|
|
|
require.NoError(err)
|
|
|
|
|
|
|
|
// Get the first message
|
|
|
|
var first *StatsResponse
|
|
|
|
select {
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("timeout")
|
|
|
|
case first = <-stream:
|
|
|
|
}
|
|
|
|
|
|
|
|
require.NoError(first.Error)
|
|
|
|
require.EqualValues(devices1, first.Groups)
|
|
|
|
|
|
|
|
// Get the second message
|
|
|
|
var second *StatsResponse
|
|
|
|
select {
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("timeout")
|
|
|
|
case second = <-stream:
|
|
|
|
}
|
|
|
|
|
|
|
|
require.NoError(second.Error)
|
|
|
|
require.EqualValues(devices2, second.Groups)
|
|
|
|
|
|
|
|
select {
|
|
|
|
case _, ok := <-stream:
|
|
|
|
require.False(ok)
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("stream should be closed")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_Stats_StreamErr(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
ferr := fmt.Errorf("mock stats failed")
|
|
|
|
mock := &MockDevicePlugin{
|
2018-09-28 17:09:01 +00:00
|
|
|
StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
|
2018-08-27 23:11:07 +00:00
|
|
|
outCh := make(chan *StatsResponse, 1)
|
|
|
|
go func() {
|
|
|
|
// Send the error
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
return
|
|
|
|
case outCh <- &StatsResponse{Error: ferr}:
|
|
|
|
}
|
|
|
|
|
|
|
|
close(outCh)
|
|
|
|
return
|
|
|
|
}()
|
|
|
|
return outCh, nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
// Get the stream
|
2018-09-28 17:09:01 +00:00
|
|
|
stream, err := impl.Stats(ctx, time.Millisecond)
|
2018-08-27 23:11:07 +00:00
|
|
|
require.NoError(err)
|
|
|
|
|
|
|
|
// Get the first message
|
|
|
|
var first *StatsResponse
|
|
|
|
select {
|
|
|
|
case <-time.After(1 * time.Second):
|
|
|
|
t.Fatal("timeout")
|
|
|
|
case first = <-stream:
|
|
|
|
}
|
|
|
|
|
|
|
|
errStatus := status.Convert(ferr)
|
|
|
|
require.EqualError(first.Error, errStatus.Err().Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDevicePlugin_Stats_CancelCtx(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
require := require.New(t)
|
|
|
|
|
|
|
|
mock := &MockDevicePlugin{
|
2018-09-28 17:09:01 +00:00
|
|
|
StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
|
2018-08-27 23:11:07 +00:00
|
|
|
outCh := make(chan *StatsResponse, 1)
|
|
|
|
go func() {
|
|
|
|
<-ctx.Done()
|
|
|
|
close(outCh)
|
|
|
|
return
|
|
|
|
}()
|
|
|
|
return outCh, nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
|
|
|
|
base.PluginTypeBase: &base.PluginBase{Impl: mock},
|
|
|
|
base.PluginTypeDevice: &PluginDevice{Impl: mock},
|
|
|
|
})
|
|
|
|
defer server.Stop()
|
|
|
|
defer client.Close()
|
|
|
|
|
|
|
|
raw, err := client.Dispense(base.PluginTypeDevice)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("err: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
impl, ok := raw.(DevicePlugin)
|
|
|
|
if !ok {
|
|
|
|
t.Fatalf("bad: %#v", raw)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a context
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
|
|
|
|
// Get the stream
|
2018-09-28 17:09:01 +00:00
|
|
|
stream, err := impl.Stats(ctx, time.Millisecond)
|
2018-08-27 23:11:07 +00:00
|
|
|
require.NoError(err)
|
|
|
|
|
|
|
|
// Get the first message
|
|
|
|
select {
|
|
|
|
case <-time.After(testutil.Timeout(10 * time.Millisecond)):
|
|
|
|
case _ = <-stream:
|
|
|
|
t.Fatal("bad value")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Cancel the context
|
|
|
|
cancel()
|
|
|
|
|
|
|
|
// Make sure we are done
|
|
|
|
select {
|
|
|
|
case <-time.After(100 * time.Millisecond):
|
|
|
|
t.Fatalf("timeout")
|
|
|
|
case v := <-stream:
|
|
|
|
require.Error(v.Error)
|
|
|
|
require.EqualError(v.Error, context.Canceled.Error())
|
|
|
|
}
|
|
|
|
}
|