Merge pull request #3877 from hashicorp/f-tls
Add TLS to streaming RPCs
This commit is contained in:
commit
aa6b3acfa7
|
@ -113,6 +113,11 @@ type Client struct {
|
||||||
|
|
||||||
connPool *pool.ConnPool
|
connPool *pool.ConnPool
|
||||||
|
|
||||||
|
// tlsWrap is used to wrap outbound connections using TLS. It should be
|
||||||
|
// accessed using the lock.
|
||||||
|
tlsWrap tlsutil.RegionWrapper
|
||||||
|
tlsWrapLock sync.RWMutex
|
||||||
|
|
||||||
// servers is the list of nomad servers
|
// servers is the list of nomad servers
|
||||||
servers *servers.Manager
|
servers *servers.Manager
|
||||||
|
|
||||||
|
@ -197,6 +202,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic
|
||||||
consulService: consulService,
|
consulService: consulService,
|
||||||
start: time.Now(),
|
start: time.Now(),
|
||||||
connPool: pool.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
|
connPool: pool.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
|
||||||
|
tlsWrap: tlsWrap,
|
||||||
streamingRpcs: structs.NewStreamingRpcRegistery(),
|
streamingRpcs: structs.NewStreamingRpcRegistery(),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
allocs: make(map[string]*AllocRunner),
|
allocs: make(map[string]*AllocRunner),
|
||||||
|
@ -263,7 +269,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic
|
||||||
// Set the preconfigured list of static servers
|
// Set the preconfigured list of static servers
|
||||||
c.configLock.RLock()
|
c.configLock.RLock()
|
||||||
if len(c.configCopy.Servers) > 0 {
|
if len(c.configCopy.Servers) > 0 {
|
||||||
if err := c.SetServers(c.configCopy.Servers); err != nil {
|
if err := c.setServersImpl(c.configCopy.Servers, true); err != nil {
|
||||||
logger.Printf("[WARN] client: None of the configured servers are valid: %v", err)
|
logger.Printf("[WARN] client: None of the configured servers are valid: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -389,6 +395,11 @@ func (c *Client) reloadTLSConnections(newConfig *nconfig.TLSConfig) error {
|
||||||
tlsWrap = tw
|
tlsWrap = tw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store the new tls wrapper.
|
||||||
|
c.tlsWrapLock.Lock()
|
||||||
|
c.tlsWrap = tlsWrap
|
||||||
|
c.tlsWrapLock.Unlock()
|
||||||
|
|
||||||
// Keep the client configuration up to date as we use configuration values to
|
// Keep the client configuration up to date as we use configuration values to
|
||||||
// decide on what type of connections to accept
|
// decide on what type of connections to accept
|
||||||
c.configLock.Lock()
|
c.configLock.Lock()
|
||||||
|
@ -594,6 +605,16 @@ func (c *Client) GetServers() []string {
|
||||||
// SetServers sets a new list of nomad servers to connect to. As long as one
|
// SetServers sets a new list of nomad servers to connect to. As long as one
|
||||||
// server is resolvable no error is returned.
|
// server is resolvable no error is returned.
|
||||||
func (c *Client) SetServers(in []string) error {
|
func (c *Client) SetServers(in []string) error {
|
||||||
|
return c.setServersImpl(in, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setServersImpl sets a new list of nomad servers to connect to. If force is
|
||||||
|
// set, we add the server to the internal severlist even if the server could not
|
||||||
|
// be pinged. An error is returned if no endpoints were valid when non-forcing.
|
||||||
|
//
|
||||||
|
// Force should be used when setting the servers from the initial configuration
|
||||||
|
// since the server may be starting up in parallel and initial pings may fail.
|
||||||
|
func (c *Client) setServersImpl(in []string, force bool) error {
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
var merr multierror.Error
|
var merr multierror.Error
|
||||||
|
@ -614,7 +635,12 @@ func (c *Client) SetServers(in []string) error {
|
||||||
// Try to ping to check if it is a real server
|
// Try to ping to check if it is a real server
|
||||||
if err := c.Ping(addr); err != nil {
|
if err := c.Ping(addr); err != nil {
|
||||||
merr.Errors = append(merr.Errors, fmt.Errorf("Server at address %s failed ping: %v", addr, err))
|
merr.Errors = append(merr.Errors, fmt.Errorf("Server at address %s failed ping: %v", addr, err))
|
||||||
return
|
|
||||||
|
// If we are forcing the setting of the servers, inject it to
|
||||||
|
// the serverlist even if we can't ping immediately.
|
||||||
|
if !force {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
|
|
|
@ -905,10 +905,9 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) {
|
||||||
return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err)
|
return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
|
}, func(err error) {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
},
|
},
|
||||||
func(err error) {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -931,10 +930,9 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) {
|
||||||
return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err)
|
return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
|
}, func(err error) {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
},
|
},
|
||||||
func(err error) {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,5 +1,3 @@
|
||||||
//+build nomad_test
|
|
||||||
|
|
||||||
package driver
|
package driver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -34,11 +32,6 @@ const (
|
||||||
ShutdownPeriodicDuration = "test.shutdown_periodic_duration"
|
ShutdownPeriodicDuration = "test.shutdown_periodic_duration"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Add the mock driver to the list of builtin drivers
|
|
||||||
func init() {
|
|
||||||
BuiltinDrivers["mock_driver"] = NewMockDriver
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockDriverConfig is the driver configuration for the MockDriver
|
// MockDriverConfig is the driver configuration for the MockDriver
|
||||||
type MockDriverConfig struct {
|
type MockDriverConfig struct {
|
||||||
|
|
||||||
|
|
8
client/driver/mock_driver_testing.go
Normal file
8
client/driver/mock_driver_testing.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
//+build nomad_test
|
||||||
|
|
||||||
|
package driver
|
||||||
|
|
||||||
|
// Add the mock driver to the list of builtin drivers
|
||||||
|
func init() {
|
||||||
|
BuiltinDrivers["mock_driver"] = NewMockDriver
|
||||||
|
}
|
|
@ -151,23 +151,26 @@ func (c *Client) streamingRpcConn(server *servers.Server, method string) (net.Co
|
||||||
tcp.SetNoDelay(true)
|
tcp.SetNoDelay(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO TLS
|
|
||||||
// Check if TLS is enabled
|
// Check if TLS is enabled
|
||||||
//if p.tlsWrap != nil {
|
c.tlsWrapLock.RLock()
|
||||||
//// Switch the connection into TLS mode
|
tlsWrap := c.tlsWrap
|
||||||
//if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
|
c.tlsWrapLock.RUnlock()
|
||||||
//conn.Close()
|
|
||||||
//return nil, err
|
|
||||||
//}
|
|
||||||
|
|
||||||
//// Wrap the connection in a TLS client
|
if tlsWrap != nil {
|
||||||
//tlsConn, err := p.tlsWrap(region, conn)
|
// Switch the connection into TLS mode
|
||||||
//if err != nil {
|
if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
|
||||||
//conn.Close()
|
conn.Close()
|
||||||
//return nil, err
|
return nil, err
|
||||||
//}
|
}
|
||||||
//conn = tlsConn
|
|
||||||
//}
|
// Wrap the connection in a TLS client
|
||||||
|
tlsConn, err := tlsWrap(c.Region(), conn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn = tlsConn
|
||||||
|
}
|
||||||
|
|
||||||
// Write the multiplex byte to set the mode
|
// Write the multiplex byte to set the mode
|
||||||
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
|
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"github.com/hashicorp/nomad/client/config"
|
"github.com/hashicorp/nomad/client/config"
|
||||||
"github.com/hashicorp/nomad/nomad"
|
"github.com/hashicorp/nomad/nomad"
|
||||||
"github.com/hashicorp/nomad/nomad/structs"
|
"github.com/hashicorp/nomad/nomad/structs"
|
||||||
|
sconfig "github.com/hashicorp/nomad/nomad/structs/config"
|
||||||
"github.com/hashicorp/nomad/testutil"
|
"github.com/hashicorp/nomad/testutil"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -45,5 +46,70 @@ func TestRpc_streamingRpcConn_badEndpoint(t *testing.T) {
|
||||||
conn, err := c.streamingRpcConn(server, "Bogus")
|
conn, err := c.streamingRpcConn(server, "Bogus")
|
||||||
require.Nil(conn)
|
require.Nil(conn)
|
||||||
require.NotNil(err)
|
require.NotNil(err)
|
||||||
require.Contains(err.Error(), "unknown rpc method: \"Bogus\"")
|
require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRpc_streamingRpcConn_badEndpoint_TLS(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cafile = "../helper/tlsutil/testdata/ca.pem"
|
||||||
|
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
|
||||||
|
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
|
||||||
|
)
|
||||||
|
|
||||||
|
s1 := nomad.TestServer(t, func(c *nomad.Config) {
|
||||||
|
c.Region = "regionFoo"
|
||||||
|
c.BootstrapExpect = 1
|
||||||
|
c.DevDisableBootstrap = true
|
||||||
|
c.TLSConfig = &sconfig.TLSConfig{
|
||||||
|
EnableHTTP: true,
|
||||||
|
EnableRPC: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: cafile,
|
||||||
|
CertFile: foocert,
|
||||||
|
KeyFile: fookey,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer s1.Shutdown()
|
||||||
|
testutil.WaitForLeader(t, s1.RPC)
|
||||||
|
|
||||||
|
c := TestClient(t, func(c *config.Config) {
|
||||||
|
c.Region = "regionFoo"
|
||||||
|
c.Servers = []string{s1.GetConfig().RPCAddr.String()}
|
||||||
|
c.TLSConfig = &sconfig.TLSConfig{
|
||||||
|
EnableHTTP: true,
|
||||||
|
EnableRPC: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: cafile,
|
||||||
|
CertFile: foocert,
|
||||||
|
KeyFile: fookey,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer c.Shutdown()
|
||||||
|
|
||||||
|
// Wait for the client to connect
|
||||||
|
testutil.WaitForResult(func() (bool, error) {
|
||||||
|
node, err := s1.State().NodeByID(nil, c.NodeID())
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if node == nil {
|
||||||
|
return false, errors.New("no node")
|
||||||
|
}
|
||||||
|
|
||||||
|
return node.Status == structs.NodeStatusReady, errors.New("wrong status")
|
||||||
|
}, func(err error) {
|
||||||
|
t.Fatalf("should have a clients")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get the server
|
||||||
|
server := c.servers.FindServer()
|
||||||
|
require.NotNil(server)
|
||||||
|
|
||||||
|
conn, err := c.streamingRpcConn(server, "Bogus")
|
||||||
|
require.Nil(conn)
|
||||||
|
require.NotNil(err)
|
||||||
|
require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"")
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,18 +98,6 @@ func (s Servers) cycle() {
|
||||||
s[numServers-1] = start
|
s[numServers-1] = start
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeServerByKey performs an inline removal of the first matching server
|
|
||||||
func (s Servers) removeServerByKey(targetKey string) {
|
|
||||||
for i, srv := range s {
|
|
||||||
if targetKey == srv.String() {
|
|
||||||
copy(s[i:], s[i+1:])
|
|
||||||
s[len(s)-1] = nil
|
|
||||||
s = s[:len(s)-1]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// shuffle shuffles the server list in place
|
// shuffle shuffles the server list in place
|
||||||
func (s Servers) shuffle() {
|
func (s Servers) shuffle() {
|
||||||
for i := len(s) - 1; i > 0; i-- {
|
for i := len(s) - 1; i > 0; i-- {
|
||||||
|
|
|
@ -42,5 +42,5 @@ func WithPrefix(t LogPrinter, prefix string) *log.Logger {
|
||||||
|
|
||||||
// NewLog logger with "TEST" prefix and the Lmicroseconds flag.
|
// NewLog logger with "TEST" prefix and the Lmicroseconds flag.
|
||||||
func Logger(t LogPrinter) *log.Logger {
|
func Logger(t LogPrinter) *log.Logger {
|
||||||
return WithPrefix(t, "TEST ")
|
return WithPrefix(t, "")
|
||||||
}
|
}
|
||||||
|
|
|
@ -278,5 +278,6 @@ func TestNodeStreamingRpc_badEndpoint(t *testing.T) {
|
||||||
conn, err := NodeStreamingRpc(state.Session, "Bogus")
|
conn, err := NodeStreamingRpc(state.Session, "Bogus")
|
||||||
require.Nil(conn)
|
require.Nil(conn)
|
||||||
require.NotNil(err)
|
require.NotNil(err)
|
||||||
require.Contains(err.Error(), "unknown rpc method: \"Bogus\"")
|
require.Contains(err.Error(), "Bogus")
|
||||||
|
require.True(structs.IsErrUnknownMethod(err))
|
||||||
}
|
}
|
||||||
|
|
49
nomad/rpc.go
49
nomad/rpc.go
|
@ -172,7 +172,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCConte
|
||||||
s.handleStreamingConn(conn)
|
s.handleStreamingConn(conn)
|
||||||
|
|
||||||
case pool.RpcMultiplexV2:
|
case pool.RpcMultiplexV2:
|
||||||
s.handleMultiplexV2(conn, ctx)
|
s.handleMultiplexV2(ctx, conn, rpcCtx)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0])
|
s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0])
|
||||||
|
@ -286,11 +286,11 @@ func (s *Server) handleStreamingConn(conn net.Conn) {
|
||||||
// handleMultiplexV2 is used to multiplex a single incoming connection
|
// handleMultiplexV2 is used to multiplex a single incoming connection
|
||||||
// using the Yamux multiplexer. Version 2 handling allows a single connection to
|
// using the Yamux multiplexer. Version 2 handling allows a single connection to
|
||||||
// switch streams between regulars RPCs and Streaming RPCs.
|
// switch streams between regulars RPCs and Streaming RPCs.
|
||||||
func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) {
|
func (s *Server) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
|
||||||
defer func() {
|
defer func() {
|
||||||
// Remove any potential mapping between a NodeID to this connection and
|
// Remove any potential mapping between a NodeID to this connection and
|
||||||
// close the underlying connection.
|
// close the underlying connection.
|
||||||
s.removeNodeConn(ctx)
|
s.removeNodeConn(rpcCtx)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -303,11 +303,11 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the context to store the yamux session
|
// Update the context to store the yamux session
|
||||||
ctx.Session = server
|
rpcCtx.Session = server
|
||||||
|
|
||||||
// Create the RPC server for this connection
|
// Create the RPC server for this connection
|
||||||
rpcServer := rpc.NewServer()
|
rpcServer := rpc.NewServer()
|
||||||
s.setupRpcServer(rpcServer, ctx)
|
s.setupRpcServer(rpcServer, rpcCtx)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Accept a new stream
|
// Accept a new stream
|
||||||
|
@ -331,7 +331,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) {
|
||||||
// Determine which handler to use
|
// Determine which handler to use
|
||||||
switch pool.RPCType(buf[0]) {
|
switch pool.RPCType(buf[0]) {
|
||||||
case pool.RpcNomad:
|
case pool.RpcNomad:
|
||||||
go s.handleNomadConn(sub, rpcServer)
|
go s.handleNomadConn(ctx, sub, rpcServer)
|
||||||
case pool.RpcStreaming:
|
case pool.RpcStreaming:
|
||||||
go s.handleStreamingConn(sub)
|
go s.handleStreamingConn(sub)
|
||||||
|
|
||||||
|
@ -476,7 +476,7 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err
|
||||||
tcp.SetNoDelay(true)
|
tcp.SetNoDelay(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.streamingRpcImpl(conn, method); err != nil {
|
if err := s.streamingRpcImpl(conn, server.Region, method); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -487,24 +487,27 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err
|
||||||
// the handshake to establish a streaming RPC for the given method. If an error
|
// the handshake to establish a streaming RPC for the given method. If an error
|
||||||
// is returned, the underlying connection has been closed. Otherwise it is
|
// is returned, the underlying connection has been closed. Otherwise it is
|
||||||
// assumed that the connection has been hijacked by the RPC method.
|
// assumed that the connection has been hijacked by the RPC method.
|
||||||
func (s *Server) streamingRpcImpl(conn net.Conn, method string) error {
|
func (s *Server) streamingRpcImpl(conn net.Conn, region, method string) error {
|
||||||
// TODO TLS
|
|
||||||
// Check if TLS is enabled
|
// Check if TLS is enabled
|
||||||
//if p.tlsWrap != nil {
|
s.tlsWrapLock.RLock()
|
||||||
//// Switch the connection into TLS mode
|
tlsWrap := s.tlsWrap
|
||||||
//if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
|
s.tlsWrapLock.RUnlock()
|
||||||
//conn.Close()
|
|
||||||
//return nil, err
|
|
||||||
//}
|
|
||||||
|
|
||||||
//// Wrap the connection in a TLS client
|
if tlsWrap != nil {
|
||||||
//tlsConn, err := p.tlsWrap(region, conn)
|
// Switch the connection into TLS mode
|
||||||
//if err != nil {
|
if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
|
||||||
//conn.Close()
|
conn.Close()
|
||||||
//return nil, err
|
return err
|
||||||
//}
|
}
|
||||||
//conn = tlsConn
|
|
||||||
//}
|
// Wrap the connection in a TLS client
|
||||||
|
tlsConn, err := tlsWrap(region, conn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn = tlsConn
|
||||||
|
}
|
||||||
|
|
||||||
// Write the multiplex byte to set the mode
|
// Write the multiplex byte to set the mode
|
||||||
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
|
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package nomad
|
package nomad
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
"os"
|
"os"
|
||||||
|
@ -201,7 +202,69 @@ func TestRPC_streamingRpcConn_badMethod(t *testing.T) {
|
||||||
conn, err := s1.streamingRpc(server, "Bogus")
|
conn, err := s1.streamingRpc(server, "Bogus")
|
||||||
require.Nil(conn)
|
require.Nil(conn)
|
||||||
require.NotNil(err)
|
require.NotNil(err)
|
||||||
require.Contains(err.Error(), "unknown rpc method: \"Bogus\"")
|
require.Contains(err.Error(), "Bogus")
|
||||||
|
require.True(structs.IsErrUnknownMethod(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require := require.New(t)
|
||||||
|
const (
|
||||||
|
cafile = "../helper/tlsutil/testdata/ca.pem"
|
||||||
|
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
|
||||||
|
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
|
||||||
|
)
|
||||||
|
dir := tmpDir(t)
|
||||||
|
defer os.RemoveAll(dir)
|
||||||
|
s1 := TestServer(t, func(c *Config) {
|
||||||
|
c.Region = "regionFoo"
|
||||||
|
c.BootstrapExpect = 2
|
||||||
|
c.DevMode = false
|
||||||
|
c.DevDisableBootstrap = true
|
||||||
|
c.DataDir = path.Join(dir, "node1")
|
||||||
|
c.TLSConfig = &config.TLSConfig{
|
||||||
|
EnableHTTP: true,
|
||||||
|
EnableRPC: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: cafile,
|
||||||
|
CertFile: foocert,
|
||||||
|
KeyFile: fookey,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer s1.Shutdown()
|
||||||
|
|
||||||
|
s2 := TestServer(t, func(c *Config) {
|
||||||
|
c.Region = "regionFoo"
|
||||||
|
c.BootstrapExpect = 2
|
||||||
|
c.DevMode = false
|
||||||
|
c.DevDisableBootstrap = true
|
||||||
|
c.DataDir = path.Join(dir, "node2")
|
||||||
|
c.TLSConfig = &config.TLSConfig{
|
||||||
|
EnableHTTP: true,
|
||||||
|
EnableRPC: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: cafile,
|
||||||
|
CertFile: foocert,
|
||||||
|
KeyFile: fookey,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
defer s2.Shutdown()
|
||||||
|
|
||||||
|
TestJoin(t, s1, s2)
|
||||||
|
testutil.WaitForLeader(t, s1.RPC)
|
||||||
|
|
||||||
|
s1.peerLock.RLock()
|
||||||
|
ok, parts := isNomadServer(s2.LocalMember())
|
||||||
|
require.True(ok)
|
||||||
|
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
|
||||||
|
require.NotNil(server)
|
||||||
|
s1.peerLock.RUnlock()
|
||||||
|
|
||||||
|
conn, err := s1.streamingRpc(server, "Bogus")
|
||||||
|
require.Nil(conn)
|
||||||
|
require.NotNil(err)
|
||||||
|
require.Contains(err.Error(), "Bogus")
|
||||||
|
require.True(structs.IsErrUnknownMethod(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// COMPAT: Remove in 0.10
|
// COMPAT: Remove in 0.10
|
||||||
|
@ -224,7 +287,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
|
||||||
// Start the handler
|
// Start the handler
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
s.handleConn(p2, &RPCContext{Conn: p2})
|
s.handleConn(context.Background(), p2, &RPCContext{Conn: p2})
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -257,8 +320,9 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
|
||||||
require.NotEmpty(l)
|
require.NotEmpty(l)
|
||||||
|
|
||||||
// Make a streaming RPC
|
// Make a streaming RPC
|
||||||
err = s.streamingRpcImpl(s2, "Bogus")
|
err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
|
||||||
require.NotNil(err)
|
require.NotNil(err)
|
||||||
require.Contains(err.Error(), "unknown rpc")
|
require.Contains(err.Error(), "Bogus")
|
||||||
|
require.True(structs.IsErrUnknownMethod(err))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,6 +112,11 @@ type Server struct {
|
||||||
rpcListener net.Listener
|
rpcListener net.Listener
|
||||||
listenerCh chan struct{}
|
listenerCh chan struct{}
|
||||||
|
|
||||||
|
// tlsWrap is used to wrap outbound connections using TLS. It should be
|
||||||
|
// accessed using the lock.
|
||||||
|
tlsWrap tlsutil.RegionWrapper
|
||||||
|
tlsWrapLock sync.RWMutex
|
||||||
|
|
||||||
// rpcServer is the static RPC server that is used by the local agent.
|
// rpcServer is the static RPC server that is used by the local agent.
|
||||||
rpcServer *rpc.Server
|
rpcServer *rpc.Server
|
||||||
|
|
||||||
|
@ -276,6 +281,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg
|
||||||
consulCatalog: consulCatalog,
|
consulCatalog: consulCatalog,
|
||||||
connPool: pool.NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
|
connPool: pool.NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
tlsWrap: tlsWrap,
|
||||||
rpcServer: rpc.NewServer(),
|
rpcServer: rpc.NewServer(),
|
||||||
streamingRpcs: structs.NewStreamingRpcRegistery(),
|
streamingRpcs: structs.NewStreamingRpcRegistery(),
|
||||||
nodeConns: make(map[string]*nodeConnState),
|
nodeConns: make(map[string]*nodeConnState),
|
||||||
|
@ -435,6 +441,11 @@ func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store the new tls wrapper.
|
||||||
|
s.tlsWrapLock.Lock()
|
||||||
|
s.tlsWrap = tlsWrap
|
||||||
|
s.tlsWrapLock.Unlock()
|
||||||
|
|
||||||
if s.rpcCancel == nil {
|
if s.rpcCancel == nil {
|
||||||
err = fmt.Errorf("No existing RPC server to reset.")
|
err = fmt.Errorf("No existing RPC server to reset.")
|
||||||
s.logger.Printf("[ERR] nomad: %s", err)
|
s.logger.Printf("[ERR] nomad: %s", err)
|
||||||
|
|
|
@ -16,7 +16,7 @@ type StreamingRpcHeader struct {
|
||||||
// StreamingRpcAck is used to acknowledge receiving the StreamingRpcHeader and
|
// StreamingRpcAck is used to acknowledge receiving the StreamingRpcHeader and
|
||||||
// routing to the requirested handler.
|
// routing to the requirested handler.
|
||||||
type StreamingRpcAck struct {
|
type StreamingRpcAck struct {
|
||||||
// Error is used to return whether an error occured establishing the
|
// Error is used to return whether an error occurred establishing the
|
||||||
// streaming RPC. This error occurs before entering the RPC handler.
|
// streaming RPC. This error occurs before entering the RPC handler.
|
||||||
Error string
|
Error string
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ func TestACLServer(t testing.T, cb func(*Config)) (*Server, *structs.ACLToken) {
|
||||||
func TestServer(t testing.T, cb func(*Config)) *Server {
|
func TestServer(t testing.T, cb func(*Config)) *Server {
|
||||||
// Setup the default settings
|
// Setup the default settings
|
||||||
config := DefaultConfig()
|
config := DefaultConfig()
|
||||||
config.Build = "0.7.0+unittest"
|
config.Build = "0.8.0+unittest"
|
||||||
config.DevMode = true
|
config.DevMode = true
|
||||||
nodeNum := atomic.AddUint32(&nodeNumber, 1)
|
nodeNum := atomic.AddUint32(&nodeNumber, 1)
|
||||||
config.NodeName = fmt.Sprintf("nomad-%03d", nodeNum)
|
config.NodeName = fmt.Sprintf("nomad-%03d", nodeNum)
|
||||||
|
@ -64,6 +64,11 @@ func TestServer(t testing.T, cb func(*Config)) *Server {
|
||||||
// Squelch output when -v isn't specified
|
// Squelch output when -v isn't specified
|
||||||
config.LogOutput = testlog.NewWriter(t)
|
config.LogOutput = testlog.NewWriter(t)
|
||||||
|
|
||||||
|
// Tighten the autopilot timing
|
||||||
|
config.AutopilotConfig.ServerStabilizationTime = 100 * time.Millisecond
|
||||||
|
config.ServerHealthInterval = 50 * time.Millisecond
|
||||||
|
config.AutopilotInterval = 100 * time.Millisecond
|
||||||
|
|
||||||
// Invoke the callback if any
|
// Invoke the callback if any
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
cb(config)
|
cb(config)
|
||||||
|
|
Loading…
Reference in a new issue