diff --git a/client/client.go b/client/client.go index 1dca496ed..6a647c2bf 100644 --- a/client/client.go +++ b/client/client.go @@ -30,6 +30,7 @@ import ( "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" + nconfig "github.com/hashicorp/nomad/nomad/structs/config" vaultapi "github.com/hashicorp/vault/api" "github.com/mitchellh/hashstructure" "github.com/shirou/gopsutil/host" @@ -364,6 +365,25 @@ func (c *Client) init() error { return nil } +// ReloadTLSConnectoins allows a client to reload RPC connections if the +// client's TLS configuration changes from plaintext to TLS +func (c *Client) ReloadTLSConnections(newConfig *nconfig.TLSConfig) error { + c.configLock.Lock() + defer c.configLock.Unlock() + + c.config.TLSConfig = newConfig + + if c.config.TLSConfig.EnableRPC { + tw, err := c.config.TLSConfiguration().OutgoingTLSWrapper() + if err != nil { + return err + } + c.connPool.ReloadTLS(tw) + } + + return nil +} + // Leave is used to prepare the client to leave the cluster func (c *Client) Leave() error { // TODO diff --git a/client/client_test.go b/client/client_test.go index 3492557f7..4550ebcdc 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1001,3 +1001,108 @@ func TestClient_ValidateMigrateToken_ACLDisabled(t *testing.T) { assert.Equal(c.ValidateMigrateToken("", ""), true) } + +func TestClient_ReloadTLS_UpgradePlaintextToTLS(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + s1, addr := testServer(t, func(c *nomad.Config) { + c.Region = "dc1" + }) + defer s1.Shutdown() + testutil.WaitForLeader(t, s1.RPC) + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + c1 := testClient(t, func(c *config.Config) { + c.Servers = []string{addr} + }) + defer c1.Shutdown() + + newConfig := &nconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + + err := c1.ReloadTLSConnections(newConfig) + assert.Nil(err) + + req := structs.NodeSpecificRequest{ + NodeID: c1.Node().ID, + QueryOptions: structs.QueryOptions{Region: "dc1"}, + } + var out structs.SingleNodeResponse + testutil.AssertUntil(100*time.Millisecond, + func() (bool, error) { + err := c1.RPC("Node.GetNode", &req, &out) + if err == nil { + return false, fmt.Errorf("client RPC succeeded when it should have failed:\n%+v", err) + } + return true, nil + }, + func(err error) { + t.Fatalf(err.Error()) + }, + ) +} + +func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + s1, addr := testServer(t, func(c *nomad.Config) { + c.Region = "dc1" + }) + defer s1.Shutdown() + testutil.WaitForLeader(t, s1.RPC) + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + c1 := testClient(t, func(c *config.Config) { + c.Servers = []string{addr} + c.TLSConfig = &nconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer c1.Shutdown() + + newConfig := &nconfig.TLSConfig{} + + err := c1.ReloadTLSConnections(newConfig) + assert.Nil(err) + + req := structs.NodeSpecificRequest{ + NodeID: c1.Node().ID, + QueryOptions: structs.QueryOptions{Region: "dc1"}, + } + var out structs.SingleNodeResponse + testutil.AssertUntil(100*time.Millisecond, + func() (bool, error) { + err := c1.RPC("Node.GetNode", &req, &out) + if err != nil { + return false, fmt.Errorf("client RPC succeeded when it should have failed:\n%+v", err) + } + return true, nil + }, + func(err error) { + t.Fatalf(err.Error()) + }, + ) +} diff --git a/command/agent/agent.go b/command/agent/agent.go index de05400fa..788085a50 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -730,29 +730,71 @@ func (a *Agent) Stats() map[string]map[string]string { return stats } +// ShouldReload determines if we should reload the configuration and agent +// connections. If the TLS Configuration has not changed, we shouldn't reload. +func (a *Agent) ShouldReload(newConfig *Config) (bool, func(*Config) error) { + if a.config.TLSConfig.Equals(newConfig.TLSConfig) { + return false, nil + } + + return true, a.Reload +} + // Reload handles configuration changes for the agent. Provides a method that // is easier to unit test, as this action is invoked via SIGHUP. func (a *Agent) Reload(newConfig *Config) error { a.configLock.Lock() defer a.configLock.Unlock() - if newConfig.TLSConfig != nil { + if newConfig == nil || newConfig.TLSConfig == nil { + return fmt.Errorf("cannot reload agent with nil configuration") + } - // TODO(chelseakomlo) In a later PR, we will introduce the ability to reload - // TLS configuration if the agent is not running with TLS enabled. - if a.config.TLSConfig != nil { - // Reload the certificates on the keyloader and on success store the - // updated TLS config. It is important to reuse the same keyloader - // as this allows us to dynamically reload configurations not only - // on the Agent but on the Server and Client too (they are - // referencing the same keyloader). - keyloader := a.config.TLSConfig.GetKeyLoader() - _, err := keyloader.LoadKeyPair(newConfig.TLSConfig.CertFile, newConfig.TLSConfig.KeyFile) - if err != nil { - return err - } - a.config.TLSConfig = newConfig.TLSConfig - a.config.TLSConfig.KeyLoader = keyloader + // This is just a TLS configuration reload, we don't need to refresh + // existing network connections + if !a.config.TLSConfig.IsEmpty() && !newConfig.TLSConfig.IsEmpty() { + + // Reload the certificates on the keyloader and on success store the + // updated TLS config. It is important to reuse the same keyloader + // as this allows us to dynamically reload configurations not only + // on the Agent but on the Server and Client too (they are + // referencing the same keyloader). + keyloader := a.config.TLSConfig.GetKeyLoader() + _, err := keyloader.LoadKeyPair(newConfig.TLSConfig.CertFile, newConfig.TLSConfig.KeyFile) + if err != nil { + return err + } + a.config.TLSConfig = newConfig.TLSConfig + a.config.TLSConfig.KeyLoader = keyloader + return nil + } + + // Completely reload the agent's TLS configuration (moving from non-TLS to + // TLS, or vice versa) + // This does not handle errors in loading the new TLS configuration + a.config.TLSConfig = newConfig.TLSConfig.Copy() + + if newConfig.TLSConfig.IsEmpty() { + a.logger.Println("[WARN] Downgrading agent's existing TLS configuration to plaintext") + } else { + a.logger.Println("[INFO] Upgrading from plaintext configuration to TLS") + } + + // Reload the TLS configuration for the client or server, depending on how + // the agent is configured to run. + if s := a.Server(); s != nil { + err := s.ReloadTLSConnections(a.config.TLSConfig) + if err != nil { + a.logger.Printf("[WARN] agent: Issue reloading the server's TLS Configuration, consider a full system restart: %v", err.Error()) + return err + } + } + if c := a.Client(); c != nil { + + err := c.ReloadTLSConnections(a.config.TLSConfig) + if err != nil { + a.logger.Printf("[ERR] agent: Issue reloading the client's TLS Configuration, consider a full system restart: %v", err.Error()) + return err } } diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 5dc8d8ce7..023060401 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -746,3 +746,197 @@ func Test_GetConfig(t *testing.T) { actualAgentConfig := agent.GetConfig() assert.Equal(actualAgentConfig, agentConfig) } + +func TestServer_Reload_TLS_WithNilConfiguration(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + logger := log.New(ioutil.Discard, "", 0) + + agent := &Agent{ + logger: logger, + config: &Config{}, + } + + err := agent.Reload(nil) + assert.NotNil(err) + assert.Equal(err.Error(), "cannot reload agent with nil configuration") +} + +func TestServer_Reload_TLS_UpgradeToTLS(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + const ( + cafile = "../../helper/tlsutil/testdata/ca.pem" + foocert = "../../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + + logger := log.New(ioutil.Discard, "", 0) + + agentConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{}, + } + + agent := &Agent{ + logger: logger, + config: agentConfig, + } + + newConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + }, + } + + err := agent.Reload(newConfig) + assert.Nil(err) + + assert.Equal(agent.config.TLSConfig.CAFile, newConfig.TLSConfig.CAFile) + assert.Equal(agent.config.TLSConfig.CertFile, newConfig.TLSConfig.CertFile) + assert.Equal(agent.config.TLSConfig.KeyFile, newConfig.TLSConfig.KeyFile) +} + +func TestServer_Reload_TLS_DowngradeFromTLS(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + const ( + cafile = "../../helper/tlsutil/testdata/ca.pem" + foocert = "../../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + + logger := log.New(ioutil.Discard, "", 0) + + agentConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + }, + } + + agent := &Agent{ + logger: logger, + config: agentConfig, + } + + newConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{}, + } + + assert.False(agentConfig.TLSConfig.IsEmpty()) + + err := agent.Reload(newConfig) + assert.Nil(err) + + assert.True(agentConfig.TLSConfig.IsEmpty()) +} + +func TestServer_ShouldReload_ReturnFalseForNoChanges(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + const ( + cafile = "../../helper/tlsutil/testdata/ca.pem" + foocert = "../../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + + logger := log.New(ioutil.Discard, "", 0) + + agentConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + }, + } + + sameAgentConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + }, + } + + agent := &Agent{ + logger: logger, + config: agentConfig, + } + + shouldReload, reloadFunc := agent.ShouldReload(sameAgentConfig) + assert.False(shouldReload) + assert.Nil(reloadFunc) +} + +func TestServer_ShouldReload_ReturnTrueForConfigChanges(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + const ( + cafile = "../../helper/tlsutil/testdata/ca.pem" + foocert = "../../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem" + foocert2 = "any_cert_path" + fookey2 = "any_key_path" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + + logger := log.New(ioutil.Discard, "", 0) + + agentConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + }, + } + + newConfig := &Config{ + TLSConfig: &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert2, + KeyFile: fookey2, + }, + } + + agent := &Agent{ + logger: logger, + config: agentConfig, + } + + shouldReload, reloadFunc := agent.ShouldReload(newConfig) + assert.True(shouldReload) + assert.NotNil(reloadFunc) +} diff --git a/command/agent/command.go b/command/agent/command.go index c1be2286e..78d7c838b 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -598,6 +598,24 @@ WAIT: } } +func (c *Command) reloadHTTPServerOnConfigChange(newConfig *Config) error { + c.agent.logger.Println("[INFO] agent: Reloading HTTP server with new TLS configuration") + err := c.httpServer.Shutdown() + if err != nil { + return err + } + + // Wait some time to ensure a clean shutdown + time.Sleep(5 * time.Second) + http, err := NewHTTPServer(c.agent, c.agent.config) + if err != nil { + return err + } + c.httpServer = http + + return nil +} + // handleReload is invoked when we should reload our configs, e.g. SIGHUP func (c *Command) handleReload() { c.Ui.Output("Reloading configuration...") @@ -620,10 +638,29 @@ func (c *Command) handleReload() { newConf.LogLevel = c.agent.GetConfig().LogLevel } - // Reloads configuration for an agent running in both client and server mode - err := c.agent.Reload(newConf) - if err != nil { - c.agent.logger.Printf("[ERR] agent: failed to reload the config: %v", err) + shouldReload, reloadFunc := c.agent.ShouldReload(newConf) + if shouldReload && reloadFunc != nil { + // Reloads configuration for an agent running in both client and server mode + err := reloadFunc(newConf) + if err != nil { + c.agent.logger.Printf("[ERR] agent: failed to reload the config: %v", err) + } + + err = c.httpServer.Shutdown() + if err != nil { + c.agent.logger.Printf("[ERR] agent: failed to stop HTTP server: %v", err) + return + } + + // Wait some time to ensure a clean shutdown + time.Sleep(5 * time.Second) + http, err := NewHTTPServer(c.agent, c.agent.config) + if err != nil { + c.agent.logger.Printf("[ERR] agent: failed to reload http server: %v", err) + return + } + c.agent.logger.Println("[INFO] agent: successfully restarted the HTTP server") + c.httpServer = http } if s := c.agent.Server(); s != nil { diff --git a/command/agent/http.go b/command/agent/http.go index 4146cffd4..2c8fa27ff 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -126,11 +126,12 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { } // Shutdown is used to shutdown the HTTP server -func (s *HTTPServer) Shutdown() { +func (s *HTTPServer) Shutdown() error { if s != nil { s.logger.Printf("[DEBUG] http: Shutting down http server") - s.listener.Close() + return s.listener.Close() } + return nil } // registerHandlers is used to attach our handlers to the mux diff --git a/command/agent/http_test.go b/command/agent/http_test.go index 5d4004c18..b145d4357 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -673,3 +673,74 @@ func encodeReq(obj interface{}) io.ReadCloser { enc.Encode(obj) return ioutil.NopCloser(buf) } + +func TestHTTP_VerifyHTTPSClientUpgrade_AfterConfigReload(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + const ( + cafile = "../../helper/tlsutil/testdata/ca.pem" + foocert = "../../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + newConfig := &Config{ + TLSConfig: &config.TLSConfig{ + EnableHTTP: true, + VerifyHTTPSClient: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + }, + } + + s := makeHTTPServer(t, func(c *Config) { + c.TLSConfig = newConfig.TLSConfig + }) + defer s.Shutdown() + + // HTTP plaintext request should succeed + reqURL := fmt.Sprintf("http://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP) + + // First test with a plaintext request + transport := &http.Transport{} + client := &http.Client{Transport: transport} + _, err := http.NewRequest("GET", reqURL, nil) + assert.Nil(err) + + // Next, reload the TLS configuration + err = s.Agent.Reload(newConfig) + assert.Nil(err) + + // PASS: Requests that specify a valid hostname, CA cert, and client + // certificate succeed. + tlsConf := &tls.Config{ + ServerName: "client.regionFoo.nomad", + RootCAs: x509.NewCertPool(), + GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + c, err := tls.LoadX509KeyPair(foocert, fookey) + if err != nil { + return nil, err + } + return &c, nil + }, + } + + // HTTPS request should succeed + httpsReqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP) + + cacertBytes, err := ioutil.ReadFile(cafile) + assert.Nil(err) + tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes) + + transport = &http.Transport{TLSClientConfig: tlsConf} + client = &http.Client{Transport: transport} + req, err := http.NewRequest("GET", httpsReqURL, nil) + assert.Nil(err) + + resp, err := client.Do(req) + assert.Nil(err) + + resp.Body.Close() + assert.Equal(resp.StatusCode, 200) +} diff --git a/nomad/pool.go b/nomad/pool.go index 320e7a320..017621c99 100644 --- a/nomad/pool.go +++ b/nomad/pool.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/helper/tlsutil" "github.com/hashicorp/yamux" ) @@ -175,6 +175,19 @@ func (p *ConnPool) Shutdown() error { return nil } +// ReloadTLS reloads TLS configuration on the fly +func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) { + p.Lock() + defer p.Unlock() + + oldPool := p.pool + for _, conn := range oldPool { + conn.Close() + } + p.pool = make(map[string]*Conn) + p.tlsWrap = tlsWrap +} + // Acquire is used to get a connection that is // pooled or to return a new connection func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) { diff --git a/nomad/raft_rpc.go b/nomad/raft_rpc.go index 31ac4d0a7..e769b9998 100644 --- a/nomad/raft_rpc.go +++ b/nomad/raft_rpc.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "fmt" "net" "sync" @@ -43,12 +44,14 @@ func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer { // Handoff is used to hand off a connection to the // RaftLayer. This allows it to be Accept()'ed -func (l *RaftLayer) Handoff(c net.Conn) error { +func (l *RaftLayer) Handoff(c net.Conn, ctx context.Context) error { select { case l.connCh <- c: return nil case <-l.closeCh: return fmt.Errorf("Raft RPC layer closed") + case <-ctx.Done(): + return fmt.Errorf("[INFO] nomad.rpc: Closing server RPC connection") } } @@ -110,3 +113,16 @@ func (l *RaftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net } return conn, err } + +// ReloadTLS will re-initialize the TLS wrapper on the fly +func (l *RaftLayer) ReloadTLS(tlsWrap tlsutil.Wrapper) { + l.closeLock.Lock() + defer l.closeLock.Unlock() + + if !l.closed { + close(l.closeCh) + } + + l.tlsWrap = tlsWrap + l.closeCh = make(chan struct{}) +} diff --git a/nomad/rpc.go b/nomad/rpc.go index 828ee0c94..8deedd194 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -68,8 +68,15 @@ func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { } // listen is used to listen for incoming RPC connections -func (s *Server) listen() { +func (s *Server) listen(ctx context.Context) { for { + select { + case <-ctx.Done(): + s.logger.Println("[INFO] nomad.rpc: Closing server RPC connection") + return + default: + } + // Accept a connection conn, err := s.rpcListener.Accept() if err != nil { @@ -80,14 +87,14 @@ func (s *Server) listen() { continue } - go s.handleConn(conn, false) + go s.handleConn(conn, false, ctx) metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1) } } // handleConn is used to determine if this is a Raft or // Nomad type RPC connection and invoke the correct handler -func (s *Server) handleConn(conn net.Conn, isTLS bool) { +func (s *Server) handleConn(conn net.Conn, isTLS bool, ctx context.Context) { // Read a single byte buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { @@ -110,14 +117,14 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { // Switch on the byte switch RPCType(buf[0]) { case rpcNomad: - s.handleNomadConn(conn) + s.handleNomadConn(conn, ctx) case rpcRaft: metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1) - s.raftLayer.Handoff(conn) + s.raftLayer.Handoff(conn, ctx) case rpcMultiplex: - s.handleMultiplex(conn) + s.handleMultiplex(conn, ctx) case rpcTLS: if s.rpcTLS == nil { @@ -126,7 +133,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { return } conn = tls.Server(conn, s.rpcTLS) - s.handleConn(conn, true) + s.handleConn(conn, true, ctx) default: s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) @@ -137,7 +144,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { // handleMultiplex is used to multiplex a single incoming connection // using the Yamux multiplexer -func (s *Server) handleMultiplex(conn net.Conn) { +func (s *Server) handleMultiplex(conn net.Conn, ctx context.Context) { defer conn.Close() conf := yamux.DefaultConfig() conf.LogOutput = s.config.LogOutput @@ -150,16 +157,19 @@ func (s *Server) handleMultiplex(conn net.Conn) { } return } - go s.handleNomadConn(sub) + go s.handleNomadConn(sub, ctx) } } // handleNomadConn is used to service a single Nomad RPC connection -func (s *Server) handleNomadConn(conn net.Conn) { +func (s *Server) handleNomadConn(conn net.Conn, ctx context.Context) { defer conn.Close() rpcCodec := NewServerCodec(conn) for { select { + case <-ctx.Done(): + s.logger.Println("[INFO] nomad.rpc: Closing server RPC connection") + return case <-s.shutdownCh: return default: diff --git a/nomad/server.go b/nomad/server.go index 7648c7436..dad73148a 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "crypto/tls" "errors" "fmt" @@ -26,6 +27,7 @@ import ( "github.com/hashicorp/nomad/nomad/deploymentwatcher" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/raft" raftboltdb "github.com/hashicorp/raft-boltdb" "github.com/hashicorp/serf/serf" @@ -109,7 +111,8 @@ type Server struct { rpcAdvertise net.Addr // rpcTLS is the TLS config for incoming TLS requests - rpcTLS *tls.Config + rpcTLS *tls.Config + rpcCancel context.CancelFunc // peers is used to track the known Nomad servers. This is // used for region forwarding and clustering. @@ -329,7 +332,9 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg go s.serfEventHandler() // Start the RPC listeners - go s.listen() + ctx, cancel := context.WithCancel(context.Background()) + s.rpcCancel = cancel + go s.listen(ctx) // Emit metrics for the eval broker go evalBroker.EmitStats(time.Second, s.shutdownCh) @@ -353,6 +358,62 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg return s, nil } +// ReloadTLSConnections updates a server's TLS configuration and reloads RPC +// connections +func (s *Server) ReloadTLSConnections(newTLSConfig *config.TLSConfig) error { + s.logger.Printf("[INFO] nomad: reloading server connections due to configuration changes") + + s.config.TLSConfig = newTLSConfig + + var tlsWrap tlsutil.RegionWrapper + var incomingTLS *tls.Config + if s.config.TLSConfig.EnableRPC { + tlsConf := s.config.tlsConfig() + tw, err := tlsConf.OutgoingTLSWrapper() + if err != nil { + return err + } + tlsWrap = tw + + itls, err := tlsConf.IncomingTLSConfig() + if err != nil { + return err + } + incomingTLS = itls + } + + if s.rpcCancel == nil { + s.logger.Printf("[ERR] nomad: No TLS Context to reset") + return fmt.Errorf("Unable to reset tls context") + } + + s.rpcTLS = incomingTLS + + s.rpcCancel() + s.connPool.ReloadTLS(tlsWrap) + + // reinitialize our rpc listener + s.rpcListener.Close() + time.Sleep(500 * time.Millisecond) + list, err := net.ListenTCP("tcp", s.config.RPCAddr) + if err != nil || list == nil { + s.logger.Printf("[ERR] nomad: No TLS listener to reload") + return err + } + s.rpcListener = list + + // reinitialize the cancel context + ctx, cancel := context.WithCancel(context.Background()) + s.rpcCancel = cancel + go s.listen(ctx) + + wrapper := tlsutil.RegionSpecificWrapper(s.config.Region, tlsWrap) + s.raftLayer.ReloadTLS(wrapper) + + s.logger.Printf("[INFO] nomad: finished reloading server connections") + return nil +} + // Shutdown is used to shutdown the server func (s *Server) Shutdown() error { s.logger.Printf("[INFO] nomad: shutting down server") @@ -497,9 +558,10 @@ func (s *Server) Leave() error { return nil } -// Reload handles a config reload. Not all config fields can handle a reload. -func (s *Server) Reload(config *Config) error { - if config == nil { +// Reload handles a config reload specific to server-only configuration. Not +// all config fields can handle a reload. +func (s *Server) Reload(newConfig *Config) error { + if newConfig == nil { return fmt.Errorf("Reload given a nil config") } @@ -507,7 +569,7 @@ func (s *Server) Reload(config *Config) error { // Handle the Vault reload. Vault should never be nil but just guard. if s.vault != nil { - if err := s.vault.SetConfig(config.VaultConfig); err != nil { + if err := s.vault.SetConfig(newConfig.VaultConfig); err != nil { multierror.Append(&mErr, err) } } diff --git a/nomad/server_test.go b/nomad/server_test.go index 04175a290..bc7eee0a1 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -13,12 +13,14 @@ import ( "time" "github.com/hashicorp/consul/lib/freeport" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/assert" ) var ( @@ -276,3 +278,100 @@ func TestServer_Reload_Vault(t *testing.T) { t.Fatalf("Vault client should be running") } } + +// Tests that the server will successfully reload its network connections, +// upgrading from plaintext to TLS if the server's TLS configuration changes. +func TestServer_Reload_TLSConnections_PlaintextToTLS(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := testServer(t, func(c *Config) { + c.DataDir = path.Join(dir, "nodeA") + }) + defer s1.Shutdown() + + // assert that the server started in plaintext mode + assert.Equal(s1.config.TLSConfig.CertFile, "") + + newTLSConfig := &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + + err := s1.ReloadTLSConnections(newTLSConfig) + assert.Nil(err) + + assert.True(s1.config.TLSConfig.Equals(newTLSConfig)) + + time.Sleep(10 * time.Second) + codec := rpcClient(t, s1) + + node := mock.Node() + req := &structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + var resp structs.GenericResponse + err = msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp) + assert.NotNil(err) +} + +// Tests that the server will successfully reload its network connections, +// downgrading from TLS to plaintext if the server's TLS configuration changes. +func TestServer_Reload_TLSConnections_TLSToPlaintext(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := testServer(t, func(c *Config) { + c.DataDir = path.Join(dir, "nodeB") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() + + newTLSConfig := &config.TLSConfig{} + + err := s1.ReloadTLSConnections(newTLSConfig) + assert.Nil(err) + assert.True(s1.config.TLSConfig.Equals(newTLSConfig)) + + time.Sleep(10 * time.Second) + + codec := rpcClient(t, s1) + + node := mock.Node() + req := &structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + var resp structs.GenericResponse + err = msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp) + assert.Nil(err) +} diff --git a/nomad/structs/config/tls.go b/nomad/structs/config/tls.go index 651781685..c93946e34 100644 --- a/nomad/structs/config/tls.go +++ b/nomad/structs/config/tls.go @@ -141,6 +141,20 @@ func (t *TLSConfig) Copy() *TLSConfig { return new } +func (t *TLSConfig) IsEmpty() bool { + if t == nil { + return true + } + + return t.EnableHTTP == false && + t.EnableRPC == false && + t.VerifyServerHostname == false && + t.CAFile == "" && + t.CertFile == "" && + t.KeyFile == "" && + t.VerifyHTTPSClient == false +} + // Merge is used to merge two TLS configs together func (t *TLSConfig) Merge(b *TLSConfig) *TLSConfig { result := t.Copy() @@ -171,3 +185,22 @@ func (t *TLSConfig) Merge(b *TLSConfig) *TLSConfig { } return result } + +// Equals compares the fields of two TLS configuration objects, returning a +// boolean indicating if they are the same. +func (t *TLSConfig) Equals(newConfig *TLSConfig) bool { + if t == nil && newConfig == nil { + return true + } + + if t != nil && newConfig == nil { + return false + } + + return t.EnableRPC == newConfig.EnableRPC && + t.CAFile == newConfig.CAFile && + t.CertFile == newConfig.CertFile && + t.KeyFile == newConfig.KeyFile && + t.RPCUpgradeMode == newConfig.RPCUpgradeMode && + t.VerifyHTTPSClient == newConfig.VerifyHTTPSClient +}