test infrastructure for mock client RPCs (#10193)

This commit includes a new test client that allows overriding the RPC
protocols. Only the RPCs that are passed in are registered, which lets you
implement a mock RPC in the server tests. This commit includes an example of
this for the ClientCSI RPC server.
This commit is contained in:
Tim Gross 2021-03-19 10:52:43 -04:00
parent d97401f60e
commit 43622680fa
6 changed files with 166 additions and 46 deletions

View file

@ -308,8 +308,12 @@ var (
noServersErr = errors.New("no servers") noServersErr = errors.New("no servers")
) )
// NewClient is used to create a new client from the given configuration // NewClient is used to create a new client from the given configuration.
func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxies consulApi.SupportedProxiesAPI, consulService consulApi.ConsulServiceAPI) (*Client, error) { // `rpcs` is a map of RPC names to RPC structs that, if non-nil, will be
// registered via https://golang.org/pkg/net/rpc/#Server.RegisterName in place
// of the client's normal RPC handlers. This allows server tests to override
// the behavior of the client.
func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxies consulApi.SupportedProxiesAPI, consulService consulApi.ConsulServiceAPI, rpcs map[string]interface{}) (*Client, error) {
// Create the tls wrapper // Create the tls wrapper
var tlsWrap tlsutil.RegionWrapper var tlsWrap tlsutil.RegionWrapper
if cfg.TLSConfig.EnableRPC { if cfg.TLSConfig.EnableRPC {
@ -384,7 +388,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie
}) })
// Setup the clients RPC server // Setup the clients RPC server
c.setupClientRpc() c.setupClientRpc(rpcs)
// Initialize the ACL state // Initialize the ACL state
if err := c.clientACLResolver.init(); err != nil { if err := c.clientACLResolver.init(); err != nil {

View file

@ -622,7 +622,7 @@ func TestClient_SaveRestoreState(t *testing.T) {
c1.config.PluginLoader = catalog.TestPluginLoaderWithOptions(t, "", c1.config.Options, nil) c1.config.PluginLoader = catalog.TestPluginLoaderWithOptions(t, "", c1.config.Options, nil)
c1.config.PluginSingletonLoader = singleton.NewSingletonLoader(logger, c1.config.PluginLoader) c1.config.PluginSingletonLoader = singleton.NewSingletonLoader(logger, c1.config.PluginLoader)
c2, err := NewClient(c1.config, consulCatalog, nil, mockService) c2, err := NewClient(c1.config, consulCatalog, nil, mockService, nil)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View file

@ -245,19 +245,24 @@ func (c *Client) streamingRpcConn(server *servers.Server, method string) (net.Co
} }
// setupClientRpc is used to setup the Client's RPC endpoints // setupClientRpc is used to setup the Client's RPC endpoints
func (c *Client) setupClientRpc() { func (c *Client) setupClientRpc(rpcs map[string]interface{}) {
// Create the RPC Server
c.rpcServer = rpc.NewServer()
// Initialize the RPC handlers // Initialize the RPC handlers
if rpcs != nil {
// override RPCs
for name, rpc := range rpcs {
c.rpcServer.RegisterName(name, rpc)
}
} else {
c.endpoints.ClientStats = &ClientStats{c} c.endpoints.ClientStats = &ClientStats{c}
c.endpoints.CSI = &CSI{c} c.endpoints.CSI = &CSI{c}
c.endpoints.FileSystem = NewFileSystemEndpoint(c) c.endpoints.FileSystem = NewFileSystemEndpoint(c)
c.endpoints.Allocations = NewAllocationsEndpoint(c) c.endpoints.Allocations = NewAllocationsEndpoint(c)
c.endpoints.Agent = NewAgentEndpoint(c) c.endpoints.Agent = NewAgentEndpoint(c)
// Create the RPC Server
c.rpcServer = rpc.NewServer()
// Register the endpoints with the RPC server
c.setupClientRpcServer(c.rpcServer) c.setupClientRpcServer(c.rpcServer)
}
go c.rpcConnListener() go c.rpcConnListener()
} }

View file

@ -2,14 +2,18 @@ package client
import ( import (
"fmt" "fmt"
"net"
"net/rpc"
"time" "time"
"github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/config"
consulapi "github.com/hashicorp/nomad/client/consul" consulapi "github.com/hashicorp/nomad/client/consul"
"github.com/hashicorp/nomad/client/fingerprint" "github.com/hashicorp/nomad/client/fingerprint"
"github.com/hashicorp/nomad/client/servers"
agentconsul "github.com/hashicorp/nomad/command/agent/consul" agentconsul "github.com/hashicorp/nomad/command/agent/consul"
"github.com/hashicorp/nomad/helper/pluginutils/catalog" "github.com/hashicorp/nomad/helper/pluginutils/catalog"
"github.com/hashicorp/nomad/helper/pluginutils/singleton" "github.com/hashicorp/nomad/helper/pluginutils/singleton"
"github.com/hashicorp/nomad/helper/pool"
"github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/testlog"
testing "github.com/mitchellh/go-testing-interface" testing "github.com/mitchellh/go-testing-interface"
) )
@ -21,6 +25,10 @@ import (
// and removed in the returned cleanup function. If they are overridden in the // and removed in the returned cleanup function. If they are overridden in the
// callback then the caller still must run the returned cleanup func. // callback then the caller still must run the returned cleanup func.
func TestClient(t testing.T, cb func(c *config.Config)) (*Client, func() error) { func TestClient(t testing.T, cb func(c *config.Config)) (*Client, func() error) {
return TestClientWithRPCs(t, cb, nil)
}
func TestClientWithRPCs(t testing.T, cb func(c *config.Config), rpcs map[string]interface{}) (*Client, func() error) {
conf, cleanup := config.TestClientConfig(t) conf, cleanup := config.TestClientConfig(t)
// Tighten the fingerprinter timeouts (must be done in client package // Tighten the fingerprinter timeouts (must be done in client package
@ -46,7 +54,7 @@ func TestClient(t testing.T, cb func(c *config.Config)) (*Client, func() error)
} }
mockCatalog := agentconsul.NewMockCatalog(logger) mockCatalog := agentconsul.NewMockCatalog(logger)
mockService := consulapi.NewMockConsulServiceClient(t, logger) mockService := consulapi.NewMockConsulServiceClient(t, logger)
client, err := NewClient(conf, mockCatalog, nil, mockService) client, err := NewClient(conf, mockCatalog, nil, mockService, rpcs)
if err != nil { if err != nil {
cleanup() cleanup()
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -75,3 +83,51 @@ func TestClient(t testing.T, cb func(c *config.Config)) (*Client, func() error)
} }
} }
} }
// TestRPCOnlyClient is a client that only pings to establish a connection
// with the server and then returns mock RPC responses for those interfaces
// passed in the `rpcs` parameter. Useful for testing client RPCs from the
// server. Returns the Client, a shutdown function, and any error.
func TestRPCOnlyClient(t testing.T, srvAddr net.Addr, rpcs map[string]interface{}) (*Client, func() error, error) {
var err error
conf, cleanup := config.TestClientConfig(t)
client := &Client{config: conf, logger: testlog.HCLogger(t)}
client.servers = servers.New(client.logger, client.shutdownCh, client)
client.configCopy = client.config.Copy()
client.rpcServer = rpc.NewServer()
for name, rpc := range rpcs {
client.rpcServer.RegisterName(name, rpc)
}
client.connPool = pool.NewPool(testlog.HCLogger(t), 10*time.Second, 10, nil)
cancelFunc := func() error {
ch := make(chan error)
go func() {
defer close(ch)
client.connPool.Shutdown()
client.shutdownGroup.Wait()
cleanup()
}()
select {
case <-ch:
return nil
case <-time.After(1 * time.Minute):
return fmt.Errorf("timed out while shutting down client")
}
}
go client.rpcConnListener()
_, err = client.SetServers([]string{srvAddr.String()})
if err != nil {
return nil, cancelFunc, fmt.Errorf("could not set servers: %v", err)
}
client.shutdownGroup.Go(client.registerAndHeartbeat)
return client, cancelFunc, nil
}

View file

@ -861,7 +861,8 @@ func (a *Agent) setupClient() error {
conf.StateDBFactory = state.GetStateDBFactory(conf.DevMode) conf.StateDBFactory = state.GetStateDBFactory(conf.DevMode)
} }
nomadClient, err := client.NewClient(conf, a.consulCatalog, a.consulProxies, a.consulService) nomadClient, err := client.NewClient(
conf, a.consulCatalog, a.consulProxies, a.consulService, nil)
if err != nil { if err != nil {
return fmt.Errorf("client setup failed: %v", err) return fmt.Errorf("client setup failed: %v", err)
} }

View file

@ -18,6 +18,42 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// MockClientCSI is a mock for the nomad.ClientCSI RPC server (see
// nomad/client_csi_endpoint.go). This can be used with a TestRPCOnlyClient to
// return specific plugin responses back to server RPCs for testing. Note that
// responses that have no bodies have no "Next*Response" field and will always
// return an empty response body.
type MockClientCSI struct {
NextValidateError error
NextAttachError error
NextAttachResponse *cstructs.ClientCSIControllerAttachVolumeResponse
NextDetachError error
NextNodeDetachError error
}
func newMockClientCSI() *MockClientCSI {
return &MockClientCSI{
NextAttachResponse: &cstructs.ClientCSIControllerAttachVolumeResponse{},
}
}
func (c *MockClientCSI) ControllerValidateVolume(req *cstructs.ClientCSIControllerValidateVolumeRequest, resp *cstructs.ClientCSIControllerValidateVolumeResponse) error {
return c.NextValidateError
}
func (c *MockClientCSI) ControllerAttachVolume(req *cstructs.ClientCSIControllerAttachVolumeRequest, resp *cstructs.ClientCSIControllerAttachVolumeResponse) error {
*resp = *c.NextAttachResponse
return c.NextAttachError
}
func (c *MockClientCSI) ControllerDetachVolume(req *cstructs.ClientCSIControllerDetachVolumeRequest, resp *cstructs.ClientCSIControllerDetachVolumeResponse) error {
return c.NextDetachError
}
func (c *MockClientCSI) NodeDetachVolume(req *cstructs.ClientCSINodeDetachVolumeRequest, resp *cstructs.ClientCSINodeDetachVolumeResponse) error {
return c.NextNodeDetachError
}
func TestClientCSIController_AttachVolume_Local(t *testing.T) { func TestClientCSIController_AttachVolume_Local(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) require := require.New(t)
@ -30,7 +66,7 @@ func TestClientCSIController_AttachVolume_Local(t *testing.T) {
var resp structs.GenericResponse var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp) err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp)
require.NotNil(err) require.Error(err)
require.Contains(err.Error(), "no plugins registered for type") require.Contains(err.Error(), "no plugins registered for type")
} }
@ -46,7 +82,7 @@ func TestClientCSIController_AttachVolume_Forwarded(t *testing.T) {
var resp structs.GenericResponse var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp) err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp)
require.NotNil(err) require.Error(err)
require.Contains(err.Error(), "no plugins registered for type") require.Contains(err.Error(), "no plugins registered for type")
} }
@ -62,7 +98,7 @@ func TestClientCSIController_DetachVolume_Local(t *testing.T) {
var resp structs.GenericResponse var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp) err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp)
require.NotNil(err) require.Error(err)
require.Contains(err.Error(), "no plugins registered for type") require.Contains(err.Error(), "no plugins registered for type")
} }
@ -78,7 +114,7 @@ func TestClientCSIController_DetachVolume_Forwarded(t *testing.T) {
var resp structs.GenericResponse var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp) err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp)
require.NotNil(err) require.Error(err)
require.Contains(err.Error(), "no plugins registered for type") require.Contains(err.Error(), "no plugins registered for type")
} }
@ -95,7 +131,7 @@ func TestClientCSIController_ValidateVolume_Local(t *testing.T) {
var resp structs.GenericResponse var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp) err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp)
require.NotNil(err) require.Error(err)
require.Contains(err.Error(), "no plugins registered for type") require.Contains(err.Error(), "no plugins registered for type")
} }
@ -112,7 +148,7 @@ func TestClientCSIController_ValidateVolume_Forwarded(t *testing.T) {
var resp structs.GenericResponse var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp) err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp)
require.NotNil(err) require.Error(err)
require.Contains(err.Error(), "no plugins registered for type") require.Contains(err.Error(), "no plugins registered for type")
} }
@ -163,9 +199,12 @@ func TestClientCSI_NodeForControllerPlugin(t *testing.T) {
// returns a RPC client to the leader and a cleanup function. // returns a RPC client to the leader and a cleanup function.
func setupForward(t *testing.T) (rpc.ClientCodec, func()) { func setupForward(t *testing.T) (rpc.ClientCodec, func()) {
s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 1 }) s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
s2, cleanupS2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC) testutil.WaitForLeader(t, s1.RPC)
testutil.WaitForLeader(t, s2.RPC)
codec := rpcClient(t, s1) codec := rpcClient(t, s1)
c1, cleanupC1 := client.TestClient(t, func(c *config.Config) { c1, cleanupC1 := client.TestClient(t, func(c *config.Config) {
@ -176,24 +215,22 @@ func setupForward(t *testing.T) (rpc.ClientCodec, func()) {
select { select {
case <-c1.Ready(): case <-c1.Ready():
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
cleanupS1()
cleanupC1() cleanupC1()
cleanupS1()
cleanupS2()
t.Fatal("client timedout on initialize") t.Fatal("client timedout on initialize")
} }
waitForNodes(t, s1, 1, 1)
s2, cleanupS2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
TestJoin(t, s1, s2)
c2, cleanupC2 := client.TestClient(t, func(c *config.Config) { c2, cleanupC2 := client.TestClient(t, func(c *config.Config) {
c.Servers = []string{s2.config.RPCAddr.String()} c.Servers = []string{s2.config.RPCAddr.String()}
}) })
select { select {
case <-c2.Ready(): case <-c2.Ready():
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
cleanupS1()
cleanupC1() cleanupC1()
cleanupC2()
cleanupS1()
cleanupS2()
t.Fatal("client timedout on initialize") t.Fatal("client timedout on initialize")
} }
@ -224,10 +261,10 @@ func setupForward(t *testing.T) (rpc.ClientCodec, func()) {
s1.fsm.state.UpsertNode(structs.MsgTypeTestSetup, 1000, node1) s1.fsm.state.UpsertNode(structs.MsgTypeTestSetup, 1000, node1)
cleanup := func() { cleanup := func() {
cleanupS1()
cleanupC1() cleanupC1()
cleanupS2()
cleanupC2() cleanupC2()
cleanupS2()
cleanupS1()
} }
return codec, cleanup return codec, cleanup
@ -235,23 +272,43 @@ func setupForward(t *testing.T) (rpc.ClientCodec, func()) {
// sets up a single server with a client, and registers a plugin to the client. // sets up a single server with a client, and registers a plugin to the client.
func setupLocal(t *testing.T) (rpc.ClientCodec, func()) { func setupLocal(t *testing.T) (rpc.ClientCodec, func()) {
var err error
s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 1 }) s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 1 })
testutil.WaitForLeader(t, s1.RPC) testutil.WaitForLeader(t, s1.RPC)
codec := rpcClient(t, s1) codec := rpcClient(t, s1)
c1, cleanupC1 := client.TestClient(t, func(c *config.Config) { mockCSI := newMockClientCSI()
c.Servers = []string{s1.config.RPCAddr.String()} mockCSI.NextValidateError = fmt.Errorf("no plugins registered for type")
}) mockCSI.NextAttachError = fmt.Errorf("no plugins registered for type")
mockCSI.NextDetachError = fmt.Errorf("no plugins registered for type")
// Wait for client initialization c1, cleanupC1 := client.TestClientWithRPCs(t,
select { func(c *config.Config) {
case <-c1.Ready(): c.Servers = []string{s1.config.RPCAddr.String()}
case <-time.After(10 * time.Second): },
cleanupS1() map[string]interface{}{"CSI": mockCSI},
)
if err != nil {
cleanupC1() cleanupC1()
t.Fatal("client timedout on initialize") cleanupS1()
require.NoError(t, err, "could not setup test client")
}
node1 := c1.Node()
node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions
req := &structs.NodeRegisterRequest{
Node: node1,
WriteRequest: structs.WriteRequest{Region: "global"},
}
var resp structs.NodeUpdateResponse
err = c1.RPC("Node.Register", req, &resp)
if err != nil {
cleanupC1()
cleanupS1()
require.NoError(t, err, "could not register client node")
} }
waitForNodes(t, s1, 1, 1) waitForNodes(t, s1, 1, 1)
@ -266,15 +323,12 @@ func setupLocal(t *testing.T) (rpc.ClientCodec, func()) {
} }
// update w/ plugin // update w/ plugin
node1 := c1.Node()
node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions
node1.CSIControllerPlugins = plugins node1.CSIControllerPlugins = plugins
s1.fsm.state.UpsertNode(structs.MsgTypeTestSetup, 1000, node1) s1.fsm.state.UpsertNode(structs.MsgTypeTestSetup, 1000, node1)
cleanup := func() { cleanup := func() {
cleanupS1()
cleanupC1() cleanupC1()
cleanupS1()
} }
return codec, cleanup return codec, cleanup