112 lines
2.8 KiB
Go
112 lines
2.8 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package device
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/hashicorp/nomad/plugins/base"
|
|
)
|
|
|
|
type FingerprintFn func(context.Context) (<-chan *FingerprintResponse, error)
|
|
type ReserveFn func([]string) (*ContainerReservation, error)
|
|
type StatsFn func(context.Context, time.Duration) (<-chan *StatsResponse, error)
|
|
|
|
// MockDevicePlugin is used for testing.
|
|
// Each function can be set as a closure to make assertions about how data
|
|
// is passed through the base plugin layer.
|
|
type MockDevicePlugin struct {
|
|
*base.MockPlugin
|
|
FingerprintF FingerprintFn
|
|
ReserveF ReserveFn
|
|
StatsF StatsFn
|
|
}
|
|
|
|
func (p *MockDevicePlugin) Fingerprint(ctx context.Context) (<-chan *FingerprintResponse, error) {
|
|
return p.FingerprintF(ctx)
|
|
}
|
|
|
|
func (p *MockDevicePlugin) Reserve(devices []string) (*ContainerReservation, error) {
|
|
return p.ReserveF(devices)
|
|
}
|
|
|
|
func (p *MockDevicePlugin) Stats(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
|
|
return p.StatsF(ctx, interval)
|
|
}
|
|
|
|
// Below are static implementations of the device functions
|
|
|
|
// StaticFingerprinter fingerprints the passed devices just once
|
|
func StaticFingerprinter(devices []*DeviceGroup) FingerprintFn {
|
|
return func(_ context.Context) (<-chan *FingerprintResponse, error) {
|
|
outCh := make(chan *FingerprintResponse, 1)
|
|
outCh <- &FingerprintResponse{
|
|
Devices: devices,
|
|
}
|
|
return outCh, nil
|
|
}
|
|
}
|
|
|
|
// ErrorChFingerprinter returns an error fingerprinting over the channel
|
|
func ErrorChFingerprinter(err error) FingerprintFn {
|
|
return func(_ context.Context) (<-chan *FingerprintResponse, error) {
|
|
outCh := make(chan *FingerprintResponse, 1)
|
|
outCh <- &FingerprintResponse{
|
|
Error: err,
|
|
}
|
|
return outCh, nil
|
|
}
|
|
}
|
|
|
|
// StaticReserve returns the passed container reservation
|
|
func StaticReserve(out *ContainerReservation) ReserveFn {
|
|
return func(_ []string) (*ContainerReservation, error) {
|
|
return out, nil
|
|
}
|
|
}
|
|
|
|
// ErrorReserve returns the passed error
|
|
func ErrorReserve(err error) ReserveFn {
|
|
return func(_ []string) (*ContainerReservation, error) {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// StaticStats returns the passed statistics
|
|
func StaticStats(out []*DeviceGroupStats) StatsFn {
|
|
return func(ctx context.Context, intv time.Duration) (<-chan *StatsResponse, error) {
|
|
outCh := make(chan *StatsResponse, 1)
|
|
|
|
go func() {
|
|
ticker := time.NewTimer(0)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
ticker.Reset(intv)
|
|
}
|
|
|
|
outCh <- &StatsResponse{
|
|
Groups: out,
|
|
}
|
|
}
|
|
}()
|
|
|
|
return outCh, nil
|
|
}
|
|
}
|
|
|
|
// ErrorChStats returns an error collecting stats over the channel
|
|
func ErrorChStats(err error) StatsFn {
|
|
return func(_ context.Context, _ time.Duration) (<-chan *StatsResponse, error) {
|
|
outCh := make(chan *StatsResponse, 1)
|
|
outCh <- &StatsResponse{
|
|
Error: err,
|
|
}
|
|
return outCh, nil
|
|
}
|
|
}
|