Making Nomad TLS configs region aware

This commit is contained in:
Diptanu Choudhury 2016-11-01 11:55:29 -07:00
parent c1f9d3ed61
commit 1a8fa8c8d5
4 changed files with 44 additions and 11 deletions

View File

@ -166,7 +166,7 @@ var (
// NewClient is used to create a new client from the given configuration
func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logger) (*Client, error) {
// Create the tls wrapper
var tlsWrap tlsutil.Wrapper
var tlsWrap tlsutil.RegionWrapper
if cfg.TLSConfig.EnableRPC {
tw, err := cfg.TLSConfiguration().OutgoingTLSWrapper()
if err != nil {

View File

@ -9,6 +9,22 @@ import (
"time"
)
// RegionSpecificWrapper is used to invoke a static Region and turns a
// RegionWrapper into a Wrapper type.
func RegionSpecificWrapper(region string, tlsWrap RegionWrapper) Wrapper {
if tlsWrap == nil {
return nil
}
return func(conn net.Conn) (net.Conn, error) {
return tlsWrap(region, conn)
}
}
// RegionWrapper is a function that is used to wrap a non-TLS connection and
// returns an appropriate TLS connection or error. This takes a Region as an
// argument.
type RegionWrapper func(region string, conn net.Conn) (net.Conn, error)
// Wrapper wraps a connection and enables TLS on it.
type Wrapper func(conn net.Conn) (net.Conn, error)
@ -102,6 +118,11 @@ func (c *Config) OutgoingTLSConfig() (*tls.Config, error) {
tlsConfig.ServerName = c.ServerName
tlsConfig.InsecureSkipVerify = false
}
if c.VerifyServerHostname {
// ServerName is filled in dynamically based on the target DC
tlsConfig.ServerName = "VerifyServerHostname"
tlsConfig.InsecureSkipVerify = false
}
// Ensure we have a CA if VerifyOutgoing is set
if c.VerifyOutgoing && c.CAFile == "" {
@ -128,7 +149,7 @@ func (c *Config) OutgoingTLSConfig() (*tls.Config, error) {
// OutgoingTLSWrapper returns a a Wrapper based on the OutgoingTLS
// configuration. If hostname verification is on, the wrapper
// will properly generate the dynamic server name for verification.
func (c *Config) OutgoingTLSWrapper() (Wrapper, error) {
func (c *Config) OutgoingTLSWrapper() (RegionWrapper, error) {
// Get the TLS config
tlsConfig, err := c.OutgoingTLSConfig()
if err != nil {
@ -140,10 +161,21 @@ func (c *Config) OutgoingTLSWrapper() (Wrapper, error) {
return nil, nil
}
wrapper := func(conn net.Conn) (net.Conn, error) {
return WrapTLSClient(conn, tlsConfig)
// Generate the wrapper based on hostname verification
if c.VerifyServerHostname {
wrapper := func(region string, conn net.Conn) (net.Conn, error) {
conf := *tlsConfig
conf.ServerName = "server." + region + ".nomad"
return WrapTLSClient(conn, &conf)
}
return wrapper, nil
} else {
wrapper := func(dc string, c net.Conn) (net.Conn, error) {
return WrapTLSClient(c, tlsConfig)
}
return wrapper, nil
}
return wrapper, nil
}
// Wrap a net.Conn into a client tls connection, performing any

View File

@ -129,7 +129,7 @@ type ConnPool struct {
limiter map[string]chan struct{}
// TLS wrapper
tlsWrap tlsutil.Wrapper
tlsWrap tlsutil.RegionWrapper
// Used to indicate the pool is shutdown
shutdown bool
@ -141,7 +141,7 @@ type ConnPool struct {
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided outgoing connections use TLS.
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.Wrapper) *ConnPool {
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
pool := &ConnPool{
logOutput: logOutput,
maxTime: maxTime,
@ -261,7 +261,7 @@ func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn,
}
// Wrap the connection in a TLS client
tlsConn, err := p.tlsWrap(conn)
tlsConn, err := p.tlsWrap(region, conn)
if err != nil {
conn.Close()
return nil, err

View File

@ -188,7 +188,7 @@ func NewServer(config *Config, consulSyncer *consul.Syncer, logger *log.Logger)
}
// Configure TLS
var tlsWrap tlsutil.Wrapper
var tlsWrap tlsutil.RegionWrapper
var incomingTLS *tls.Config
if config.TLSConfig.EnableRPC {
tlsConf := config.tlsConfig()
@ -594,7 +594,7 @@ func (s *Server) setupVaultClient() error {
}
// setupRPC is used to setup the RPC listener
func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error {
func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error {
// Create endpoints
s.endpoints.Status = &Status{s}
s.endpoints.Node = &Node{srv: s}
@ -640,7 +640,8 @@ func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
}
s.raftLayer = NewRaftLayer(s.rpcAdvertise, tlsWrap)
wrapper := tlsutil.RegionSpecificWrapper(s.config.Region, tlsWrap)
s.raftLayer = NewRaftLayer(s.rpcAdvertise, wrapper)
return nil
}