Merge pull request #3877 from hashicorp/f-tls

Add TLS to streaming RPCs
This commit is contained in:
Alex Dadgar 2018-02-20 16:09:45 -08:00 committed by GitHub
commit aa6b3acfa7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 240 additions and 1166 deletions

View file

@ -113,6 +113,11 @@ type Client struct {
connPool *pool.ConnPool connPool *pool.ConnPool
// tlsWrap is used to wrap outbound connections using TLS. It should be
// accessed using the lock.
tlsWrap tlsutil.RegionWrapper
tlsWrapLock sync.RWMutex
// servers is the list of nomad servers // servers is the list of nomad servers
servers *servers.Manager servers *servers.Manager
@ -197,6 +202,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic
consulService: consulService, consulService: consulService,
start: time.Now(), start: time.Now(),
connPool: pool.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap), connPool: pool.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
tlsWrap: tlsWrap,
streamingRpcs: structs.NewStreamingRpcRegistery(), streamingRpcs: structs.NewStreamingRpcRegistery(),
logger: logger, logger: logger,
allocs: make(map[string]*AllocRunner), allocs: make(map[string]*AllocRunner),
@ -263,7 +269,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic
// Set the preconfigured list of static servers // Set the preconfigured list of static servers
c.configLock.RLock() c.configLock.RLock()
if len(c.configCopy.Servers) > 0 { if len(c.configCopy.Servers) > 0 {
if err := c.SetServers(c.configCopy.Servers); err != nil { if err := c.setServersImpl(c.configCopy.Servers, true); err != nil {
logger.Printf("[WARN] client: None of the configured servers are valid: %v", err) logger.Printf("[WARN] client: None of the configured servers are valid: %v", err)
} }
} }
@ -389,6 +395,11 @@ func (c *Client) reloadTLSConnections(newConfig *nconfig.TLSConfig) error {
tlsWrap = tw tlsWrap = tw
} }
// Store the new tls wrapper.
c.tlsWrapLock.Lock()
c.tlsWrap = tlsWrap
c.tlsWrapLock.Unlock()
// Keep the client configuration up to date as we use configuration values to // Keep the client configuration up to date as we use configuration values to
// decide on what type of connections to accept // decide on what type of connections to accept
c.configLock.Lock() c.configLock.Lock()
@ -594,6 +605,16 @@ func (c *Client) GetServers() []string {
// SetServers sets a new list of nomad servers to connect to. As long as one // SetServers sets a new list of nomad servers to connect to. As long as one
// server is resolvable no error is returned. // server is resolvable no error is returned.
func (c *Client) SetServers(in []string) error { func (c *Client) SetServers(in []string) error {
return c.setServersImpl(in, false)
}
// setServersImpl sets a new list of nomad servers to connect to. If force is
// set, we add the server to the internal severlist even if the server could not
// be pinged. An error is returned if no endpoints were valid when non-forcing.
//
// Force should be used when setting the servers from the initial configuration
// since the server may be starting up in parallel and initial pings may fail.
func (c *Client) setServersImpl(in []string, force bool) error {
var mu sync.Mutex var mu sync.Mutex
var wg sync.WaitGroup var wg sync.WaitGroup
var merr multierror.Error var merr multierror.Error
@ -614,8 +635,13 @@ func (c *Client) SetServers(in []string) error {
// Try to ping to check if it is a real server // Try to ping to check if it is a real server
if err := c.Ping(addr); err != nil { if err := c.Ping(addr); err != nil {
merr.Errors = append(merr.Errors, fmt.Errorf("Server at address %s failed ping: %v", addr, err)) merr.Errors = append(merr.Errors, fmt.Errorf("Server at address %s failed ping: %v", addr, err))
// If we are forcing the setting of the servers, inject it to
// the serverlist even if we can't ping immediately.
if !force {
return return
} }
}
mu.Lock() mu.Lock()
endpoints = append(endpoints, &servers.Server{Addr: addr}) endpoints = append(endpoints, &servers.Server{Addr: addr})

View file

@ -905,8 +905,7 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) {
return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err) return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err)
} }
return true, nil return true, nil
}, }, func(err error) {
func(err error) {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
}, },
) )
@ -931,8 +930,7 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) {
return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err) return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err)
} }
return true, nil return true, nil
}, }, func(err error) {
func(err error) {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
}, },
) )

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,3 @@
//+build nomad_test
package driver package driver
import ( import (
@ -34,11 +32,6 @@ const (
ShutdownPeriodicDuration = "test.shutdown_periodic_duration" ShutdownPeriodicDuration = "test.shutdown_periodic_duration"
) )
// Add the mock driver to the list of builtin drivers
func init() {
BuiltinDrivers["mock_driver"] = NewMockDriver
}
// MockDriverConfig is the driver configuration for the MockDriver // MockDriverConfig is the driver configuration for the MockDriver
type MockDriverConfig struct { type MockDriverConfig struct {

View file

@ -0,0 +1,8 @@
//+build nomad_test
package driver
// Add the mock driver to the list of builtin drivers
func init() {
BuiltinDrivers["mock_driver"] = NewMockDriver
}

View file

@ -151,23 +151,26 @@ func (c *Client) streamingRpcConn(server *servers.Server, method string) (net.Co
tcp.SetNoDelay(true) tcp.SetNoDelay(true)
} }
// TODO TLS
// Check if TLS is enabled // Check if TLS is enabled
//if p.tlsWrap != nil { c.tlsWrapLock.RLock()
//// Switch the connection into TLS mode tlsWrap := c.tlsWrap
//if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { c.tlsWrapLock.RUnlock()
//conn.Close()
//return nil, err
//}
//// Wrap the connection in a TLS client if tlsWrap != nil {
//tlsConn, err := p.tlsWrap(region, conn) // Switch the connection into TLS mode
//if err != nil { if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
//conn.Close() conn.Close()
//return nil, err return nil, err
//} }
//conn = tlsConn
//} // Wrap the connection in a TLS client
tlsConn, err := tlsWrap(c.Region(), conn)
if err != nil {
conn.Close()
return nil, err
}
conn = tlsConn
}
// Write the multiplex byte to set the mode // Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {

View file

@ -7,6 +7,7 @@ import (
"github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad"
"github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs"
sconfig "github.com/hashicorp/nomad/nomad/structs/config"
"github.com/hashicorp/nomad/testutil" "github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -45,5 +46,70 @@ func TestRpc_streamingRpcConn_badEndpoint(t *testing.T) {
conn, err := c.streamingRpcConn(server, "Bogus") conn, err := c.streamingRpcConn(server, "Bogus")
require.Nil(conn) require.Nil(conn)
require.NotNil(err) require.NotNil(err)
require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"")
}
func TestRpc_streamingRpcConn_badEndpoint_TLS(t *testing.T) {
t.Parallel()
require := require.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
s1 := nomad.TestServer(t, func(c *nomad.Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 1
c.DevDisableBootstrap = true
c.TLSConfig = &sconfig.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer s1.Shutdown()
testutil.WaitForLeader(t, s1.RPC)
c := TestClient(t, func(c *config.Config) {
c.Region = "regionFoo"
c.Servers = []string{s1.GetConfig().RPCAddr.String()}
c.TLSConfig = &sconfig.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer c.Shutdown()
// Wait for the client to connect
testutil.WaitForResult(func() (bool, error) {
node, err := s1.State().NodeByID(nil, c.NodeID())
if err != nil {
return false, err
}
if node == nil {
return false, errors.New("no node")
}
return node.Status == structs.NodeStatusReady, errors.New("wrong status")
}, func(err error) {
t.Fatalf("should have a clients")
})
// Get the server
server := c.servers.FindServer()
require.NotNil(server)
conn, err := c.streamingRpcConn(server, "Bogus")
require.Nil(conn)
require.NotNil(err)
require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"")
} }

View file

@ -98,18 +98,6 @@ func (s Servers) cycle() {
s[numServers-1] = start s[numServers-1] = start
} }
// removeServerByKey performs an inline removal of the first matching server
func (s Servers) removeServerByKey(targetKey string) {
for i, srv := range s {
if targetKey == srv.String() {
copy(s[i:], s[i+1:])
s[len(s)-1] = nil
s = s[:len(s)-1]
return
}
}
}
// shuffle shuffles the server list in place // shuffle shuffles the server list in place
func (s Servers) shuffle() { func (s Servers) shuffle() {
for i := len(s) - 1; i > 0; i-- { for i := len(s) - 1; i > 0; i-- {

View file

@ -42,5 +42,5 @@ func WithPrefix(t LogPrinter, prefix string) *log.Logger {
// NewLog logger with "TEST" prefix and the Lmicroseconds flag. // NewLog logger with "TEST" prefix and the Lmicroseconds flag.
func Logger(t LogPrinter) *log.Logger { func Logger(t LogPrinter) *log.Logger {
return WithPrefix(t, "TEST ") return WithPrefix(t, "")
} }

View file

@ -278,5 +278,6 @@ func TestNodeStreamingRpc_badEndpoint(t *testing.T) {
conn, err := NodeStreamingRpc(state.Session, "Bogus") conn, err := NodeStreamingRpc(state.Session, "Bogus")
require.Nil(conn) require.Nil(conn)
require.NotNil(err) require.NotNil(err)
require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
} }

View file

@ -172,7 +172,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCConte
s.handleStreamingConn(conn) s.handleStreamingConn(conn)
case pool.RpcMultiplexV2: case pool.RpcMultiplexV2:
s.handleMultiplexV2(conn, ctx) s.handleMultiplexV2(ctx, conn, rpcCtx)
default: default:
s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0])
@ -286,11 +286,11 @@ func (s *Server) handleStreamingConn(conn net.Conn) {
// handleMultiplexV2 is used to multiplex a single incoming connection // handleMultiplexV2 is used to multiplex a single incoming connection
// using the Yamux multiplexer. Version 2 handling allows a single connection to // using the Yamux multiplexer. Version 2 handling allows a single connection to
// switch streams between regulars RPCs and Streaming RPCs. // switch streams between regulars RPCs and Streaming RPCs.
func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) { func (s *Server) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
defer func() { defer func() {
// Remove any potential mapping between a NodeID to this connection and // Remove any potential mapping between a NodeID to this connection and
// close the underlying connection. // close the underlying connection.
s.removeNodeConn(ctx) s.removeNodeConn(rpcCtx)
conn.Close() conn.Close()
}() }()
@ -303,11 +303,11 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) {
} }
// Update the context to store the yamux session // Update the context to store the yamux session
ctx.Session = server rpcCtx.Session = server
// Create the RPC server for this connection // Create the RPC server for this connection
rpcServer := rpc.NewServer() rpcServer := rpc.NewServer()
s.setupRpcServer(rpcServer, ctx) s.setupRpcServer(rpcServer, rpcCtx)
for { for {
// Accept a new stream // Accept a new stream
@ -331,7 +331,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) {
// Determine which handler to use // Determine which handler to use
switch pool.RPCType(buf[0]) { switch pool.RPCType(buf[0]) {
case pool.RpcNomad: case pool.RpcNomad:
go s.handleNomadConn(sub, rpcServer) go s.handleNomadConn(ctx, sub, rpcServer)
case pool.RpcStreaming: case pool.RpcStreaming:
go s.handleStreamingConn(sub) go s.handleStreamingConn(sub)
@ -476,7 +476,7 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err
tcp.SetNoDelay(true) tcp.SetNoDelay(true)
} }
if err := s.streamingRpcImpl(conn, method); err != nil { if err := s.streamingRpcImpl(conn, server.Region, method); err != nil {
return nil, err return nil, err
} }
@ -487,24 +487,27 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err
// the handshake to establish a streaming RPC for the given method. If an error // the handshake to establish a streaming RPC for the given method. If an error
// is returned, the underlying connection has been closed. Otherwise it is // is returned, the underlying connection has been closed. Otherwise it is
// assumed that the connection has been hijacked by the RPC method. // assumed that the connection has been hijacked by the RPC method.
func (s *Server) streamingRpcImpl(conn net.Conn, method string) error { func (s *Server) streamingRpcImpl(conn net.Conn, region, method string) error {
// TODO TLS
// Check if TLS is enabled // Check if TLS is enabled
//if p.tlsWrap != nil { s.tlsWrapLock.RLock()
//// Switch the connection into TLS mode tlsWrap := s.tlsWrap
//if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { s.tlsWrapLock.RUnlock()
//conn.Close()
//return nil, err
//}
//// Wrap the connection in a TLS client if tlsWrap != nil {
//tlsConn, err := p.tlsWrap(region, conn) // Switch the connection into TLS mode
//if err != nil { if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
//conn.Close() conn.Close()
//return nil, err return err
//} }
//conn = tlsConn
//} // Wrap the connection in a TLS client
tlsConn, err := tlsWrap(region, conn)
if err != nil {
conn.Close()
return err
}
conn = tlsConn
}
// Write the multiplex byte to set the mode // Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {

View file

@ -1,6 +1,7 @@
package nomad package nomad
import ( import (
"context"
"net" "net"
"net/rpc" "net/rpc"
"os" "os"
@ -201,7 +202,69 @@ func TestRPC_streamingRpcConn_badMethod(t *testing.T) {
conn, err := s1.streamingRpc(server, "Bogus") conn, err := s1.streamingRpc(server, "Bogus")
require.Nil(conn) require.Nil(conn)
require.NotNil(err) require.NotNil(err)
require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
}
func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
t.Parallel()
require := require.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node1")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer s1.Shutdown()
s2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node2")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer s2.Shutdown()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()
conn, err := s1.streamingRpc(server, "Bogus")
require.Nil(conn)
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
} }
// COMPAT: Remove in 0.10 // COMPAT: Remove in 0.10
@ -224,7 +287,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
// Start the handler // Start the handler
doneCh := make(chan struct{}) doneCh := make(chan struct{})
go func() { go func() {
s.handleConn(p2, &RPCContext{Conn: p2}) s.handleConn(context.Background(), p2, &RPCContext{Conn: p2})
close(doneCh) close(doneCh)
}() }()
@ -257,8 +320,9 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
require.NotEmpty(l) require.NotEmpty(l)
// Make a streaming RPC // Make a streaming RPC
err = s.streamingRpcImpl(s2, "Bogus") err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
require.NotNil(err) require.NotNil(err)
require.Contains(err.Error(), "unknown rpc") require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
} }

View file

@ -112,6 +112,11 @@ type Server struct {
rpcListener net.Listener rpcListener net.Listener
listenerCh chan struct{} listenerCh chan struct{}
// tlsWrap is used to wrap outbound connections using TLS. It should be
// accessed using the lock.
tlsWrap tlsutil.RegionWrapper
tlsWrapLock sync.RWMutex
// rpcServer is the static RPC server that is used by the local agent. // rpcServer is the static RPC server that is used by the local agent.
rpcServer *rpc.Server rpcServer *rpc.Server
@ -276,6 +281,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg
consulCatalog: consulCatalog, consulCatalog: consulCatalog,
connPool: pool.NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), connPool: pool.NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
logger: logger, logger: logger,
tlsWrap: tlsWrap,
rpcServer: rpc.NewServer(), rpcServer: rpc.NewServer(),
streamingRpcs: structs.NewStreamingRpcRegistery(), streamingRpcs: structs.NewStreamingRpcRegistery(),
nodeConns: make(map[string]*nodeConnState), nodeConns: make(map[string]*nodeConnState),
@ -435,6 +441,11 @@ func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error {
return err return err
} }
// Store the new tls wrapper.
s.tlsWrapLock.Lock()
s.tlsWrap = tlsWrap
s.tlsWrapLock.Unlock()
if s.rpcCancel == nil { if s.rpcCancel == nil {
err = fmt.Errorf("No existing RPC server to reset.") err = fmt.Errorf("No existing RPC server to reset.")
s.logger.Printf("[ERR] nomad: %s", err) s.logger.Printf("[ERR] nomad: %s", err)

View file

@ -16,7 +16,7 @@ type StreamingRpcHeader struct {
// StreamingRpcAck is used to acknowledge receiving the StreamingRpcHeader and // StreamingRpcAck is used to acknowledge receiving the StreamingRpcHeader and
// routing to the requirested handler. // routing to the requirested handler.
type StreamingRpcAck struct { type StreamingRpcAck struct {
// Error is used to return whether an error occured establishing the // Error is used to return whether an error occurred establishing the
// streaming RPC. This error occurs before entering the RPC handler. // streaming RPC. This error occurs before entering the RPC handler.
Error string Error string
} }

View file

@ -38,7 +38,7 @@ func TestACLServer(t testing.T, cb func(*Config)) (*Server, *structs.ACLToken) {
func TestServer(t testing.T, cb func(*Config)) *Server { func TestServer(t testing.T, cb func(*Config)) *Server {
// Setup the default settings // Setup the default settings
config := DefaultConfig() config := DefaultConfig()
config.Build = "0.7.0+unittest" config.Build = "0.8.0+unittest"
config.DevMode = true config.DevMode = true
nodeNum := atomic.AddUint32(&nodeNumber, 1) nodeNum := atomic.AddUint32(&nodeNumber, 1)
config.NodeName = fmt.Sprintf("nomad-%03d", nodeNum) config.NodeName = fmt.Sprintf("nomad-%03d", nodeNum)
@ -64,6 +64,11 @@ func TestServer(t testing.T, cb func(*Config)) *Server {
// Squelch output when -v isn't specified // Squelch output when -v isn't specified
config.LogOutput = testlog.NewWriter(t) config.LogOutput = testlog.NewWriter(t)
// Tighten the autopilot timing
config.AutopilotConfig.ServerStabilizationTime = 100 * time.Millisecond
config.ServerHealthInterval = 50 * time.Millisecond
config.AutopilotInterval = 100 * time.Millisecond
// Invoke the callback if any // Invoke the callback if any
if cb != nil { if cb != nil {
cb(config) cb(config)