diff --git a/client/client.go b/client/client.go index a4a85c031..757f32799 100644 --- a/client/client.go +++ b/client/client.go @@ -113,6 +113,11 @@ type Client struct { connPool *pool.ConnPool + // tlsWrap is used to wrap outbound connections using TLS. It should be + // accessed using the lock. + tlsWrap tlsutil.RegionWrapper + tlsWrapLock sync.RWMutex + // servers is the list of nomad servers servers *servers.Manager @@ -197,6 +202,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic consulService: consulService, start: time.Now(), connPool: pool.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap), + tlsWrap: tlsWrap, streamingRpcs: structs.NewStreamingRpcRegistery(), logger: logger, allocs: make(map[string]*AllocRunner), @@ -263,7 +269,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic // Set the preconfigured list of static servers c.configLock.RLock() if len(c.configCopy.Servers) > 0 { - if err := c.SetServers(c.configCopy.Servers); err != nil { + if err := c.setServersImpl(c.configCopy.Servers, true); err != nil { logger.Printf("[WARN] client: None of the configured servers are valid: %v", err) } } @@ -389,6 +395,11 @@ func (c *Client) reloadTLSConnections(newConfig *nconfig.TLSConfig) error { tlsWrap = tw } + // Store the new tls wrapper. + c.tlsWrapLock.Lock() + c.tlsWrap = tlsWrap + c.tlsWrapLock.Unlock() + // Keep the client configuration up to date as we use configuration values to // decide on what type of connections to accept c.configLock.Lock() @@ -594,6 +605,16 @@ func (c *Client) GetServers() []string { // SetServers sets a new list of nomad servers to connect to. As long as one // server is resolvable no error is returned. func (c *Client) SetServers(in []string) error { + return c.setServersImpl(in, false) +} + +// setServersImpl sets a new list of nomad servers to connect to. If force is +// set, we add the server to the internal server list even if the server could not +// be pinged. An error is returned if no endpoints were valid when not forcing. +// +// Force should be used when setting the servers from the initial configuration +// since the server may be starting up in parallel and initial pings may fail. +func (c *Client) setServersImpl(in []string, force bool) error { var mu sync.Mutex var wg sync.WaitGroup var merr multierror.Error @@ -614,7 +635,12 @@ func (c *Client) SetServers(in []string) error { // Try to ping to check if it is a real server if err := c.Ping(addr); err != nil { merr.Errors = append(merr.Errors, fmt.Errorf("Server at address %s failed ping: %v", addr, err)) - return + + // If we are forcing the setting of the servers, inject it into + // the server list even if we can't ping immediately.
+ if !force { + return + } } mu.Lock() diff --git a/client/client_test.go b/client/client_test.go index a612ba71b..13d4debfb 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -905,10 +905,9 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err) } return true, nil + }, func(err error) { + t.Fatalf(err.Error()) }, - func(err error) { - t.Fatalf(err.Error()) - }, ) } @@ -931,10 +930,9 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err) } return true, nil + }, func(err error) { + t.Fatalf(err.Error()) }, - func(err error) { - t.Fatalf(err.Error()) - }, ) } } diff --git a/client/client_test.go.orig b/client/client_test.go.orig deleted file mode 100644 index 68664ebf1..000000000 --- a/client/client_test.go.orig +++ /dev/null @@ -1,1092 +0,0 @@ -package client - -import ( - "fmt" - "io/ioutil" - "log" - "os" - "path/filepath" - "testing" - "time" - - memdb "github.com/hashicorp/go-memdb" - "github.com/hashicorp/nomad/client/config" -<<<<<<< ours - "github.com/hashicorp/nomad/client/driver" - "github.com/hashicorp/nomad/client/fingerprint" -======= ->>>>>>> theirs - "github.com/hashicorp/nomad/command/agent/consul" - "github.com/hashicorp/nomad/helper/uuid" - "github.com/hashicorp/nomad/nomad" - "github.com/hashicorp/nomad/nomad/mock" - "github.com/hashicorp/nomad/nomad/structs" - nconfig "github.com/hashicorp/nomad/nomad/structs/config" - "github.com/hashicorp/nomad/testutil" - "github.com/mitchellh/hashstructure" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - ctestutil "github.com/hashicorp/nomad/client/testutil" -) - -func testACLServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string, *structs.ACLToken) { - server, token := nomad.TestACLServer(t, cb) - return server, server.GetConfig().RPCAddr.String(), token -} - -func testServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string) { - server := nomad.TestServer(t, cb) - return server, server.GetConfig().RPCAddr.String() -} - -func TestClient_StartStop(t *testing.T) { - t.Parallel() - client := TestClient(t, nil) - if err := client.Shutdown(); err != nil { - t.Fatalf("err: %v", err) - } -} - -// Certain labels for metrics are dependant on client initial setup. 
This tests -// that the client has properly initialized before we assign values to labels -func TestClient_BaseLabels(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - client := TestClient(t, nil) - if err := client.Shutdown(); err != nil { - t.Fatalf("err: %v", err) - } - - // directly invoke this function, as otherwise this will fail on a CI build - // due to a race condition - client.emitStats() - - baseLabels := client.baseLabels - assert.NotEqual(0, len(baseLabels)) - - nodeID := client.Node().ID - for _, e := range baseLabels { - if e.Name == "node_id" { - assert.Equal(nodeID, e.Value) - } - } -} - -func TestClient_RPC(t *testing.T) { - t.Parallel() - s1, addr := testServer(t, nil) - defer s1.Shutdown() - - c1 := TestClient(t, func(c *config.Config) { - c.Servers = []string{addr} - }) - defer c1.Shutdown() - - // RPC should succeed - testutil.WaitForResult(func() (bool, error) { - var out struct{} - err := c1.RPC("Status.Ping", struct{}{}, &out) - return err == nil, err - }, func(err error) { - t.Fatalf("err: %v", err) - }) -} - -func TestClient_RPC_Passthrough(t *testing.T) { - t.Parallel() - s1, _ := testServer(t, nil) - defer s1.Shutdown() - - c1 := TestClient(t, func(c *config.Config) { - c.RPCHandler = s1 - }) - defer c1.Shutdown() - - // RPC should succeed - testutil.WaitForResult(func() (bool, error) { - var out struct{} - err := c1.RPC("Status.Ping", struct{}{}, &out) - return err == nil, err - }, func(err error) { - t.Fatalf("err: %v", err) - }) -} - -func TestClient_Fingerprint(t *testing.T) { - t.Parallel() -<<<<<<< ours - require := require.New(t) - - driver.CheckForMockDriver(t) - - c := testClient(t, nil) -======= - c := TestClient(t, nil) ->>>>>>> theirs - defer c.Shutdown() - - // Ensure default values are present - node := c.Node() - require.NotEqual("", node.Attributes["kernel.name"]) - require.NotEqual("", node.Attributes["cpu.arch"]) - require.NotEqual("", node.Attributes["driver.mock_driver"]) -} - -func TestClient_HasNodeChanged(t *testing.T) { - t.Parallel() - c := TestClient(t, nil) - defer c.Shutdown() - - node := c.config.Node - attrHash, err := hashstructure.Hash(node.Attributes, nil) - if err != nil { - c.logger.Printf("[DEBUG] client: unable to calculate node attributes hash: %v", err) - } - // Calculate node meta map hash - metaHash, err := hashstructure.Hash(node.Meta, nil) - if err != nil { - c.logger.Printf("[DEBUG] client: unable to calculate node meta hash: %v", err) - } - if changed, _, _ := c.hasNodeChanged(attrHash, metaHash); changed { - t.Fatalf("Unexpected hash change.") - } - - // Change node attribute - node.Attributes["arch"] = "xyz_86" - if changed, newAttrHash, _ := c.hasNodeChanged(attrHash, metaHash); !changed { - t.Fatalf("Expected hash change in attributes: %d vs %d", attrHash, newAttrHash) - } - - // Change node meta map - node.Meta["foo"] = "bar" - if changed, _, newMetaHash := c.hasNodeChanged(attrHash, metaHash); !changed { - t.Fatalf("Expected hash change in meta map: %d vs %d", metaHash, newMetaHash) - } -} - -func TestClient_Fingerprint_Periodic(t *testing.T) { - driver.CheckForMockDriver(t) - t.Parallel() -<<<<<<< ours - - // these constants are only defined when nomad_test is enabled, so these fail - // our linter without explicit disabling. 
- c1 := testClient(t, func(c *config.Config) { - c.Options = map[string]string{ - driver.ShutdownPeriodicAfter: "true", // nolint: varcheck - driver.ShutdownPeriodicDuration: "3", // nolint: varcheck -======= - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) - } - - // Weird spacing to test trimming. Whitelist all modules expect cpu. - c.Options["fingerprint.whitelist"] = " arch, consul,cpu,env_aws,env_gce,host,memory,network,storage,foo,bar " - }) - defer c.Shutdown() - - node := c.Node() - if node.Attributes["cpu.frequency"] == "" { - t.Fatalf("missing cpu fingerprint module") - } -} - -func TestClient_Fingerprint_InBlacklist(t *testing.T) { - t.Parallel() - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) - } - - // Weird spacing to test trimming. Blacklist cpu. - c.Options["fingerprint.blacklist"] = " cpu " - }) - defer c.Shutdown() - - node := c.Node() - if node.Attributes["cpu.frequency"] != "" { - t.Fatalf("cpu fingerprint module loaded despite blacklisting") - } -} - -func TestClient_Fingerprint_OutOfWhitelist(t *testing.T) { - t.Parallel() - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) - } - - c.Options["fingerprint.whitelist"] = "arch,consul,env_aws,env_gce,host,memory,network,storage,foo,bar" - }) - defer c.Shutdown() - - node := c.Node() - if node.Attributes["cpu.frequency"] != "" { - t.Fatalf("found cpu fingerprint module") - } -} - -func TestClient_Fingerprint_WhitelistBlacklistCombination(t *testing.T) { - t.Parallel() - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) - } - - // With both white- and blacklist, should return the set difference of modules (arch, cpu) - c.Options["fingerprint.whitelist"] = "arch,memory,cpu" - c.Options["fingerprint.blacklist"] = "memory,nomad" - }) - defer c.Shutdown() - - node := c.Node() - // Check expected modules are present - if node.Attributes["cpu.frequency"] == "" { - t.Fatalf("missing cpu fingerprint module") - } - if node.Attributes["cpu.arch"] == "" { - t.Fatalf("missing arch fingerprint module") - } - // Check remainder _not_ present - if node.Attributes["memory.totalbytes"] != "" { - t.Fatalf("found memory fingerprint module") - } - if node.Attributes["nomad.version"] != "" { - t.Fatalf("found nomad fingerprint module") - } -} - -func TestClient_Drivers_InWhitelist(t *testing.T) { - t.Parallel() - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) ->>>>>>> theirs - } - }) - defer c1.Shutdown() - - node := c1.config.Node - mockDriverName := "driver.mock_driver" - -<<<<<<< ours - // Ensure the mock driver is registered on the client - testutil.WaitForResult(func() (bool, error) { - mockDriverStatus := node.Attributes[mockDriverName] - if mockDriverStatus == "" { - return false, fmt.Errorf("mock driver attribute should be set on the client") -======= -func TestClient_Drivers_InBlacklist(t *testing.T) { - t.Parallel() - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) ->>>>>>> theirs - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - -<<<<<<< ours - // Ensure that the client fingerprinter eventually removes this attribute - testutil.WaitForResult(func() (bool, error) { - mockDriverStatus := node.Attributes[mockDriverName] - if mockDriverStatus != "" { - return false, 
fmt.Errorf("mock driver attribute should not be set on the client") -======= - node := c.Node() - if node.Attributes["driver.raw_exec"] != "" { - t.Fatalf("raw_exec driver loaded despite blacklist") - } -} - -func TestClient_Drivers_OutOfWhitelist(t *testing.T) { - t.Parallel() - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) ->>>>>>> theirs - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) -<<<<<<< ours -======= - defer c.Shutdown() - - node := c.Node() - if node.Attributes["driver.exec"] != "" { - t.Fatalf("found exec driver") - } -} - -func TestClient_Drivers_WhitelistBlacklistCombination(t *testing.T) { - t.Parallel() - c := TestClient(t, func(c *config.Config) { - if c.Options == nil { - c.Options = make(map[string]string) - } - - // Expected output is set difference (raw_exec) - c.Options["driver.whitelist"] = "raw_exec,exec" - c.Options["driver.blacklist"] = "exec" - }) - defer c.Shutdown() - - node := c.Node() - // Check expected present - if node.Attributes["driver.raw_exec"] == "" { - t.Fatalf("missing raw_exec driver") - } - // Check expected absent - if node.Attributes["driver.exec"] != "" { - t.Fatalf("exec driver loaded despite blacklist") - } ->>>>>>> theirs -} - -// TestClient_MixedTLS asserts that when a server is running with TLS enabled -// it will reject any RPC connections from clients that lack TLS. See #2525 -func TestClient_MixedTLS(t *testing.T) { - t.Parallel() - const ( - cafile = "../helper/tlsutil/testdata/ca.pem" - foocert = "../helper/tlsutil/testdata/nomad-foo.pem" - fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" - ) - s1, addr := testServer(t, func(c *nomad.Config) { - c.TLSConfig = &nconfig.TLSConfig{ - EnableHTTP: true, - EnableRPC: true, - VerifyServerHostname: true, - CAFile: cafile, - CertFile: foocert, - KeyFile: fookey, - } - }) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.Servers = []string{addr} - }) - defer c1.Shutdown() - - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "global"}, - } - var out structs.SingleNodeResponse - testutil.AssertUntil(100*time.Millisecond, - func() (bool, error) { - err := c1.RPC("Node.GetNode", &req, &out) - if err == nil { - return false, fmt.Errorf("client RPC succeeded when it should have failed:\n%+v", out) - } - return true, nil - }, - func(err error) { - t.Fatalf(err.Error()) - }, - ) -} - -// TestClient_BadTLS asserts that when a client and server are running with TLS -// enabled -- but their certificates are signed by different CAs -- they're -// unable to communicate. 
-func TestClient_BadTLS(t *testing.T) { - t.Parallel() - const ( - cafile = "../helper/tlsutil/testdata/ca.pem" - foocert = "../helper/tlsutil/testdata/nomad-foo.pem" - fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" - badca = "../helper/tlsutil/testdata/ca-bad.pem" - badcert = "../helper/tlsutil/testdata/nomad-bad.pem" - badkey = "../helper/tlsutil/testdata/nomad-bad-key.pem" - ) - s1, addr := testServer(t, func(c *nomad.Config) { - c.TLSConfig = &nconfig.TLSConfig{ - EnableHTTP: true, - EnableRPC: true, - VerifyServerHostname: true, - CAFile: cafile, - CertFile: foocert, - KeyFile: fookey, - } - }) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.Servers = []string{addr} - c.TLSConfig = &nconfig.TLSConfig{ - EnableHTTP: true, - EnableRPC: true, - VerifyServerHostname: true, - CAFile: badca, - CertFile: badcert, - KeyFile: badkey, - } - }) - defer c1.Shutdown() - - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "global"}, - } - var out structs.SingleNodeResponse - testutil.AssertUntil(100*time.Millisecond, - func() (bool, error) { - err := c1.RPC("Node.GetNode", &req, &out) - if err == nil { - return false, fmt.Errorf("client RPC succeeded when it should have failed:\n%+v", out) - } - return true, nil - }, - func(err error) { - t.Fatalf(err.Error()) - }, - ) -} - -func TestClient_Register(t *testing.T) { - t.Parallel() - s1, _ := testServer(t, nil) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.RPCHandler = s1 - }) - defer c1.Shutdown() - - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "global"}, - } - var out structs.SingleNodeResponse - - // Register should succeed - testutil.WaitForResult(func() (bool, error) { - err := s1.RPC("Node.GetNode", &req, &out) - if err != nil { - return false, err - } - if out.Node == nil { - return false, fmt.Errorf("missing reg") - } - return out.Node.ID == req.NodeID, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) -} - -func TestClient_Heartbeat(t *testing.T) { - t.Parallel() - s1, _ := testServer(t, func(c *nomad.Config) { - c.MinHeartbeatTTL = 50 * time.Millisecond - }) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.RPCHandler = s1 - }) - defer c1.Shutdown() - - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "global"}, - } - var out structs.SingleNodeResponse - - // Register should succeed - testutil.WaitForResult(func() (bool, error) { - err := s1.RPC("Node.GetNode", &req, &out) - if err != nil { - return false, err - } - if out.Node == nil { - return false, fmt.Errorf("missing reg") - } - return out.Node.Status == structs.NodeStatusReady, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) -} - -func TestClient_UpdateAllocStatus(t *testing.T) { - t.Parallel() - s1, _ := testServer(t, nil) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.RPCHandler = s1 - }) - defer c1.Shutdown() - - // Wait til the node is ready - waitTilNodeReady(c1, t) - - job := mock.Job() - alloc := mock.Alloc() - alloc.NodeID = c1.Node().ID - alloc.Job = job - alloc.JobID = job.ID - originalStatus := "foo" - alloc.ClientStatus = originalStatus - - // Insert at zero so they are pulled - state := s1.State() - if err := 
state.UpsertJob(0, job); err != nil { - t.Fatal(err) - } - if err := state.UpsertJobSummary(100, mock.JobSummary(alloc.JobID)); err != nil { - t.Fatal(err) - } - state.UpsertAllocs(101, []*structs.Allocation{alloc}) - - testutil.WaitForResult(func() (bool, error) { - ws := memdb.NewWatchSet() - out, err := state.AllocByID(ws, alloc.ID) - if err != nil { - return false, err - } - if out == nil { - return false, fmt.Errorf("no such alloc") - } - if out.ClientStatus == originalStatus { - return false, fmt.Errorf("Alloc client status not updated; got %v", out.ClientStatus) - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) -} - -func TestClient_WatchAllocs(t *testing.T) { - t.Parallel() - ctestutil.ExecCompatible(t) - s1, _ := testServer(t, nil) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.RPCHandler = s1 - }) - defer c1.Shutdown() - - // Wait til the node is ready - waitTilNodeReady(c1, t) - - // Create mock allocations - job := mock.Job() - alloc1 := mock.Alloc() - alloc1.JobID = job.ID - alloc1.Job = job - alloc1.NodeID = c1.Node().ID - alloc2 := mock.Alloc() - alloc2.NodeID = c1.Node().ID - alloc2.JobID = job.ID - alloc2.Job = job - - state := s1.State() - if err := state.UpsertJob(100, job); err != nil { - t.Fatal(err) - } - if err := state.UpsertJobSummary(101, mock.JobSummary(alloc1.JobID)); err != nil { - t.Fatal(err) - } - err := state.UpsertAllocs(102, []*structs.Allocation{alloc1, alloc2}) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Both allocations should get registered - testutil.WaitForResult(func() (bool, error) { - c1.allocLock.RLock() - num := len(c1.allocs) - c1.allocLock.RUnlock() - return num == 2, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // Delete one allocation - if err := state.DeleteEval(103, nil, []string{alloc1.ID}); err != nil { - t.Fatalf("err: %v", err) - } - - // Update the other allocation. Have to make a copy because the allocs are - // shared in memory in the test and the modify index would be updated in the - // alloc runner. 
- alloc2_2 := alloc2.Copy() - alloc2_2.DesiredStatus = structs.AllocDesiredStatusStop - if err := state.UpsertAllocs(104, []*structs.Allocation{alloc2_2}); err != nil { - t.Fatalf("err upserting stopped alloc: %v", err) - } - - // One allocation should get GC'd and removed - testutil.WaitForResult(func() (bool, error) { - c1.allocLock.RLock() - num := len(c1.allocs) - c1.allocLock.RUnlock() - return num == 1, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // One allocations should get updated - testutil.WaitForResult(func() (bool, error) { - c1.allocLock.RLock() - ar := c1.allocs[alloc2.ID] - c1.allocLock.RUnlock() - return ar.Alloc().DesiredStatus == structs.AllocDesiredStatusStop, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) -} - -func waitTilNodeReady(client *Client, t *testing.T) { - testutil.WaitForResult(func() (bool, error) { - n := client.Node() - if n.Status != structs.NodeStatusReady { - return false, fmt.Errorf("node not registered") - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) -} - -func TestClient_SaveRestoreState(t *testing.T) { - t.Parallel() - ctestutil.ExecCompatible(t) - s1, _ := testServer(t, nil) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.DevMode = false - c.RPCHandler = s1 - }) - defer c1.Shutdown() - - // Wait til the node is ready - waitTilNodeReady(c1, t) - - // Create mock allocations - job := mock.Job() - alloc1 := mock.Alloc() - alloc1.NodeID = c1.Node().ID - alloc1.Job = job - alloc1.JobID = job.ID - alloc1.Job.TaskGroups[0].Tasks[0].Driver = "mock_driver" - task := alloc1.Job.TaskGroups[0].Tasks[0] - task.Config["run_for"] = "10s" - - state := s1.State() - if err := state.UpsertJob(100, job); err != nil { - t.Fatal(err) - } - if err := state.UpsertJobSummary(101, mock.JobSummary(alloc1.JobID)); err != nil { - t.Fatal(err) - } - if err := state.UpsertAllocs(102, []*structs.Allocation{alloc1}); err != nil { - t.Fatalf("err: %v", err) - } - - // Allocations should get registered - testutil.WaitForResult(func() (bool, error) { - c1.allocLock.RLock() - ar := c1.allocs[alloc1.ID] - c1.allocLock.RUnlock() - if ar == nil { - return false, fmt.Errorf("nil alloc runner") - } - if ar.Alloc().ClientStatus != structs.AllocClientStatusRunning { - return false, fmt.Errorf("client status: got %v; want %v", ar.Alloc().ClientStatus, structs.AllocClientStatusRunning) - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // Shutdown the client, saves state - if err := c1.Shutdown(); err != nil { - t.Fatalf("err: %v", err) - } - - // Create a new client - logger := log.New(c1.config.LogOutput, "", log.LstdFlags) - catalog := consul.NewMockCatalog(logger) - mockService := newMockConsulServiceClient() - mockService.logger = logger - c2, err := NewClient(c1.config, catalog, mockService, logger) - if err != nil { - t.Fatalf("err: %v", err) - } - defer c2.Shutdown() - - // Ensure the allocation is running - testutil.WaitForResult(func() (bool, error) { - c2.allocLock.RLock() - ar := c2.allocs[alloc1.ID] - c2.allocLock.RUnlock() - status := ar.Alloc().ClientStatus - alive := status == structs.AllocClientStatusRunning || status == structs.AllocClientStatusPending - if !alive { - return false, fmt.Errorf("incorrect client status: %#v", ar.Alloc()) - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // Destroy all the allocations - for _, ar := range c2.getAllocRunners() { - ar.Destroy() - } - - for _, ar 
:= range c2.getAllocRunners() { - <-ar.WaitCh() - } -} - -func TestClient_Init(t *testing.T) { - t.Parallel() - dir, err := ioutil.TempDir("", "nomad") - if err != nil { - t.Fatalf("err: %s", err) - } - defer os.RemoveAll(dir) - allocDir := filepath.Join(dir, "alloc") - - client := &Client{ - config: &config.Config{ - AllocDir: allocDir, - }, - logger: log.New(os.Stderr, "", log.LstdFlags), - } - if err := client.init(); err != nil { - t.Fatalf("err: %s", err) - } - - if _, err := os.Stat(allocDir); err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestClient_BlockedAllocations(t *testing.T) { - t.Parallel() - s1, _ := testServer(t, nil) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - c1 := TestClient(t, func(c *config.Config) { - c.RPCHandler = s1 - }) - defer c1.Shutdown() - - // Wait for the node to be ready - state := s1.State() - testutil.WaitForResult(func() (bool, error) { - ws := memdb.NewWatchSet() - out, err := state.NodeByID(ws, c1.Node().ID) - if err != nil { - return false, err - } - if out == nil || out.Status != structs.NodeStatusReady { - return false, fmt.Errorf("bad node: %#v", out) - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // Add an allocation - alloc := mock.Alloc() - alloc.NodeID = c1.Node().ID - alloc.Job.TaskGroups[0].Tasks[0].Driver = "mock_driver" - alloc.Job.TaskGroups[0].Tasks[0].Config = map[string]interface{}{ - "kill_after": "1s", - "run_for": "100s", - "exit_code": 0, - "exit_signal": 0, - "exit_err": "", - } - - state.UpsertJobSummary(99, mock.JobSummary(alloc.JobID)) - state.UpsertAllocs(100, []*structs.Allocation{alloc}) - - // Wait until the client downloads and starts the allocation - testutil.WaitForResult(func() (bool, error) { - ws := memdb.NewWatchSet() - out, err := state.AllocByID(ws, alloc.ID) - if err != nil { - return false, err - } - if out == nil || out.ClientStatus != structs.AllocClientStatusRunning { - return false, fmt.Errorf("bad alloc: %#v", out) - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // Add a new chained alloc - alloc2 := alloc.Copy() - alloc2.ID = uuid.Generate() - alloc2.Job = alloc.Job - alloc2.JobID = alloc.JobID - alloc2.PreviousAllocation = alloc.ID - if err := state.UpsertAllocs(200, []*structs.Allocation{alloc2}); err != nil { - t.Fatalf("err: %v", err) - } - - // Enusre that the chained allocation is being tracked as blocked - testutil.WaitForResult(func() (bool, error) { - ar := c1.getAllocRunners()[alloc2.ID] - if ar == nil { - return false, fmt.Errorf("alloc 2's alloc runner does not exist") - } - if !ar.IsWaiting() { - return false, fmt.Errorf("alloc 2 is not blocked") - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // Change the desired state of the parent alloc to stop - alloc1 := alloc.Copy() - alloc1.DesiredStatus = structs.AllocDesiredStatusStop - if err := state.UpsertAllocs(300, []*structs.Allocation{alloc1}); err != nil { - t.Fatalf("err: %v", err) - } - - // Ensure that there are no blocked allocations - testutil.WaitForResult(func() (bool, error) { - for id, ar := range c1.getAllocRunners() { - if ar.IsWaiting() { - return false, fmt.Errorf("%q still blocked", id) - } - if ar.IsMigrating() { - return false, fmt.Errorf("%q still migrating", id) - } - } - return true, nil - }, func(err error) { - t.Fatalf("err: %v", err) - }) - - // Destroy all the allocations - for _, ar := range c1.getAllocRunners() { - ar.Destroy() - } - - for _, ar := range c1.getAllocRunners() { - <-ar.WaitCh() 
- } -} - -func TestClient_ValidateMigrateToken_ValidToken(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - c := TestClient(t, func(c *config.Config) { - c.ACLEnabled = true - }) - defer c.Shutdown() - - alloc := mock.Alloc() - validToken, err := nomad.GenerateMigrateToken(alloc.ID, c.secretNodeID()) - assert.Nil(err) - - assert.Equal(c.ValidateMigrateToken(alloc.ID, validToken), true) -} - -func TestClient_ValidateMigrateToken_InvalidToken(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - c := TestClient(t, func(c *config.Config) { - c.ACLEnabled = true - }) - defer c.Shutdown() - - assert.Equal(c.ValidateMigrateToken("", ""), false) - - alloc := mock.Alloc() - assert.Equal(c.ValidateMigrateToken(alloc.ID, alloc.ID), false) - assert.Equal(c.ValidateMigrateToken(alloc.ID, ""), false) -} - -func TestClient_ValidateMigrateToken_ACLDisabled(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - c := TestClient(t, func(c *config.Config) {}) - defer c.Shutdown() - - assert.Equal(c.ValidateMigrateToken("", ""), true) -} - -func TestClient_ReloadTLS_UpgradePlaintextToTLS(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - s1, addr := testServer(t, func(c *nomad.Config) { - c.Region = "regionFoo" - }) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - const ( - cafile = "../helper/tlsutil/testdata/ca.pem" - foocert = "../helper/tlsutil/testdata/nomad-foo.pem" - fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" - ) - - c1 := testClient(t, func(c *config.Config) { - c.Servers = []string{addr} - }) - defer c1.Shutdown() - - // Registering a node over plaintext should succeed - { - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "regionFoo"}, - } - - testutil.WaitForResult(func() (bool, error) { - var out structs.SingleNodeResponse - err := c1.RPC("Node.GetNode", &req, &out) - if err != nil { - return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err) - } - return true, nil - }, - func(err error) { - t.Fatalf(err.Error()) - }, - ) - } - - newConfig := &nconfig.TLSConfig{ - EnableHTTP: true, - EnableRPC: true, - VerifyServerHostname: true, - CAFile: cafile, - CertFile: foocert, - KeyFile: fookey, - } - - err := c1.reloadTLSConnections(newConfig) - assert.Nil(err) - - // Registering a node over plaintext should fail after the node has upgraded - // to TLS - { - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "regionFoo"}, - } - testutil.WaitForResult(func() (bool, error) { - var out structs.SingleNodeResponse - err := c1.RPC("Node.GetNode", &req, &out) - if err == nil { - return false, fmt.Errorf("client RPC succeeded when it should have failed:\n%+v", err) - } - return true, nil - }, - func(err error) { - t.Fatalf(err.Error()) - }, - ) - } -} - -func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - s1, addr := testServer(t, func(c *nomad.Config) { - c.Region = "regionFoo" - }) - defer s1.Shutdown() - testutil.WaitForLeader(t, s1.RPC) - - const ( - cafile = "../helper/tlsutil/testdata/ca.pem" - foocert = "../helper/tlsutil/testdata/nomad-foo.pem" - fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" - ) - - c1 := testClient(t, func(c *config.Config) { - c.Servers = []string{addr} - c.TLSConfig = &nconfig.TLSConfig{ - EnableHTTP: true, - EnableRPC: true, - VerifyServerHostname: true, - CAFile: cafile, - CertFile: foocert, - KeyFile: fookey, - } - }) - 
defer c1.Shutdown() - - // assert that when one node is running in encrypted mode, a RPC request to a - // node running in plaintext mode should fail - { - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "regionFoo"}, - } - testutil.WaitForResult(func() (bool, error) { - var out structs.SingleNodeResponse - err := c1.RPC("Node.GetNode", &req, &out) - if err == nil { - return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err) - } - return true, nil - }, - func(err error) { - t.Fatalf(err.Error()) - }, - ) - } - - newConfig := &nconfig.TLSConfig{} - - err := c1.reloadTLSConnections(newConfig) - assert.Nil(err) - - // assert that when both nodes are in plaintext mode, a RPC request should - // succeed - { - req := structs.NodeSpecificRequest{ - NodeID: c1.Node().ID, - QueryOptions: structs.QueryOptions{Region: "regionFoo"}, - } - testutil.WaitForResult(func() (bool, error) { - var out structs.SingleNodeResponse - err := c1.RPC("Node.GetNode", &req, &out) - if err != nil { - return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err) - } - return true, nil - }, - func(err error) { - t.Fatalf(err.Error()) - }, - ) - } -} diff --git a/client/driver/mock_driver.go b/client/driver/mock_driver.go index 07c262852..29d6a4a9d 100644 --- a/client/driver/mock_driver.go +++ b/client/driver/mock_driver.go @@ -1,5 +1,3 @@ -//+build nomad_test - package driver import ( @@ -34,11 +32,6 @@ const ( ShutdownPeriodicDuration = "test.shutdown_periodic_duration" ) -// Add the mock driver to the list of builtin drivers -func init() { - BuiltinDrivers["mock_driver"] = NewMockDriver -} - // MockDriverConfig is the driver configuration for the MockDriver type MockDriverConfig struct { diff --git a/client/driver/mock_driver_testing.go b/client/driver/mock_driver_testing.go new file mode 100644 index 000000000..1b1e861a8 --- /dev/null +++ b/client/driver/mock_driver_testing.go @@ -0,0 +1,8 @@ +//+build nomad_test + +package driver + +// Add the mock driver to the list of builtin drivers +func init() { + BuiltinDrivers["mock_driver"] = NewMockDriver +} diff --git a/client/rpc.go b/client/rpc.go index 1fe52288b..90a1eec47 100644 --- a/client/rpc.go +++ b/client/rpc.go @@ -151,23 +151,26 @@ func (c *Client) streamingRpcConn(server *servers.Server, method string) (net.Co tcp.SetNoDelay(true) } - // TODO TLS // Check if TLS is enabled - //if p.tlsWrap != nil { - //// Switch the connection into TLS mode - //if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { - //conn.Close() - //return nil, err - //} + c.tlsWrapLock.RLock() + tlsWrap := c.tlsWrap + c.tlsWrapLock.RUnlock() - //// Wrap the connection in a TLS client - //tlsConn, err := p.tlsWrap(region, conn) - //if err != nil { - //conn.Close() - //return nil, err - //} - //conn = tlsConn - //} + if tlsWrap != nil { + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := tlsWrap(c.Region(), conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + } // Write the multiplex byte to set the mode if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { diff --git a/client/rpc_test.go b/client/rpc_test.go index 09984a3b6..c25033923 100644 --- a/client/rpc_test.go +++ b/client/rpc_test.go @@ -7,6 +7,7 @@ import ( "github.com/hashicorp/nomad/client/config" 
"github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" + sconfig "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/require" ) @@ -45,5 +46,70 @@ func TestRpc_streamingRpcConn_badEndpoint(t *testing.T) { conn, err := c.streamingRpcConn(server, "Bogus") require.Nil(conn) require.NotNil(err) - require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") + require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"") +} + +func TestRpc_streamingRpcConn_badEndpoint_TLS(t *testing.T) { + t.Parallel() + require := require.New(t) + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + s1 := nomad.TestServer(t, func(c *nomad.Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 1 + c.DevDisableBootstrap = true + c.TLSConfig = &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() + testutil.WaitForLeader(t, s1.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Region = "regionFoo" + c.Servers = []string{s1.GetConfig().RPCAddr.String()} + c.TLSConfig = &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer c.Shutdown() + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s1.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, errors.New("no node") + } + + return node.Status == structs.NodeStatusReady, errors.New("wrong status") + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Get the server + server := c.servers.FindServer() + require.NotNil(server) + + conn, err := c.streamingRpcConn(server, "Bogus") + require.Nil(conn) + require.NotNil(err) + require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"") } diff --git a/client/servers/manager.go b/client/servers/manager.go index 604b109a2..6dac0c7e4 100644 --- a/client/servers/manager.go +++ b/client/servers/manager.go @@ -98,18 +98,6 @@ func (s Servers) cycle() { s[numServers-1] = start } -// removeServerByKey performs an inline removal of the first matching server -func (s Servers) removeServerByKey(targetKey string) { - for i, srv := range s { - if targetKey == srv.String() { - copy(s[i:], s[i+1:]) - s[len(s)-1] = nil - s = s[:len(s)-1] - return - } - } -} - // shuffle shuffles the server list in place func (s Servers) shuffle() { for i := len(s) - 1; i > 0; i-- { diff --git a/helper/testlog/testlog.go b/helper/testlog/testlog.go index 7f6c6cb04..b72fcfb28 100644 --- a/helper/testlog/testlog.go +++ b/helper/testlog/testlog.go @@ -42,5 +42,5 @@ func WithPrefix(t LogPrinter, prefix string) *log.Logger { // NewLog logger with "TEST" prefix and the Lmicroseconds flag. 
func Logger(t LogPrinter) *log.Logger { - return WithPrefix(t, "TEST ") + return WithPrefix(t, "") } diff --git a/nomad/client_rpc_test.go b/nomad/client_rpc_test.go index fd6b35902..c64eecec0 100644 --- a/nomad/client_rpc_test.go +++ b/nomad/client_rpc_test.go @@ -278,5 +278,6 @@ func TestNodeStreamingRpc_badEndpoint(t *testing.T) { conn, err := NodeStreamingRpc(state.Session, "Bogus") require.Nil(conn) require.NotNil(err) - require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") + require.Contains(err.Error(), "Bogus") + require.True(structs.IsErrUnknownMethod(err)) } diff --git a/nomad/rpc.go b/nomad/rpc.go index 159c1f272..537e73d9d 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -172,7 +172,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCConte s.handleStreamingConn(conn) case pool.RpcMultiplexV2: - s.handleMultiplexV2(conn, ctx) + s.handleMultiplexV2(ctx, conn, rpcCtx) default: s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) @@ -286,11 +286,11 @@ func (s *Server) handleStreamingConn(conn net.Conn) { // handleMultiplexV2 is used to multiplex a single incoming connection // using the Yamux multiplexer. Version 2 handling allows a single connection to // switch streams between regulars RPCs and Streaming RPCs. -func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) { +func (s *Server) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) { defer func() { // Remove any potential mapping between a NodeID to this connection and // close the underlying connection. - s.removeNodeConn(ctx) + s.removeNodeConn(rpcCtx) conn.Close() }() @@ -303,11 +303,11 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) { } // Update the context to store the yamux session - ctx.Session = server + rpcCtx.Session = server // Create the RPC server for this connection rpcServer := rpc.NewServer() - s.setupRpcServer(rpcServer, ctx) + s.setupRpcServer(rpcServer, rpcCtx) for { // Accept a new stream @@ -331,7 +331,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) { // Determine which handler to use switch pool.RPCType(buf[0]) { case pool.RpcNomad: - go s.handleNomadConn(sub, rpcServer) + go s.handleNomadConn(ctx, sub, rpcServer) case pool.RpcStreaming: go s.handleStreamingConn(sub) @@ -476,7 +476,7 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err tcp.SetNoDelay(true) } - if err := s.streamingRpcImpl(conn, method); err != nil { + if err := s.streamingRpcImpl(conn, server.Region, method); err != nil { return nil, err } @@ -487,24 +487,27 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err // the handshake to establish a streaming RPC for the given method. If an error // is returned, the underlying connection has been closed. Otherwise it is // assumed that the connection has been hijacked by the RPC method. 
-func (s *Server) streamingRpcImpl(conn net.Conn, method string) error { - // TODO TLS +func (s *Server) streamingRpcImpl(conn net.Conn, region, method string) error { // Check if TLS is enabled - //if p.tlsWrap != nil { - //// Switch the connection into TLS mode - //if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { - //conn.Close() - //return nil, err - //} + s.tlsWrapLock.RLock() + tlsWrap := s.tlsWrap + s.tlsWrapLock.RUnlock() - //// Wrap the connection in a TLS client - //tlsConn, err := p.tlsWrap(region, conn) - //if err != nil { - //conn.Close() - //return nil, err - //} - //conn = tlsConn - //} + if tlsWrap != nil { + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil { + conn.Close() + return err + } + + // Wrap the connection in a TLS client + tlsConn, err := tlsWrap(region, conn) + if err != nil { + conn.Close() + return err + } + conn = tlsConn + } // Write the multiplex byte to set the mode if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index eb85af57e..c876c6adb 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "net" "net/rpc" "os" @@ -201,7 +202,69 @@ func TestRPC_streamingRpcConn_badMethod(t *testing.T) { conn, err := s1.streamingRpc(server, "Bogus") require.Nil(conn) require.NotNil(err) - require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") + require.Contains(err.Error(), "Bogus") + require.True(structs.IsErrUnknownMethod(err)) +} + +func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) { + t.Parallel() + require := require.New(t) + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node1") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() + + s2 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s2.Shutdown() + + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + + s1.peerLock.RLock() + ok, parts := isNomadServer(s2.LocalMember()) + require.True(ok) + server := s1.localPeers[raft.ServerAddress(parts.Addr.String())] + require.NotNil(server) + s1.peerLock.RUnlock() + + conn, err := s1.streamingRpc(server, "Bogus") + require.Nil(conn) + require.NotNil(err) + require.Contains(err.Error(), "Bogus") + require.True(structs.IsErrUnknownMethod(err)) } // COMPAT: Remove in 0.10 @@ -224,7 +287,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) { // Start the handler doneCh := make(chan struct{}) go func() { - s.handleConn(p2, &RPCContext{Conn: p2}) + s.handleConn(context.Background(), p2, &RPCContext{Conn: p2}) close(doneCh) }() @@ -257,8 +320,9 @@ func TestRPC_handleMultiplexV2(t *testing.T) { require.NotEmpty(l) // Make a streaming RPC - err = s.streamingRpcImpl(s2, "Bogus") + err = s.streamingRpcImpl(s2, 
s.Region(), "Bogus") require.NotNil(err) - require.Contains(err.Error(), "unknown rpc") + require.Contains(err.Error(), "Bogus") + require.True(structs.IsErrUnknownMethod(err)) } diff --git a/nomad/server.go b/nomad/server.go index 5b68a85ac..a70c2f890 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -112,6 +112,11 @@ type Server struct { rpcListener net.Listener listenerCh chan struct{} + // tlsWrap is used to wrap outbound connections using TLS. It should be + // accessed using the lock. + tlsWrap tlsutil.RegionWrapper + tlsWrapLock sync.RWMutex + // rpcServer is the static RPC server that is used by the local agent. rpcServer *rpc.Server @@ -276,6 +281,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg consulCatalog: consulCatalog, connPool: pool.NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), logger: logger, + tlsWrap: tlsWrap, rpcServer: rpc.NewServer(), streamingRpcs: structs.NewStreamingRpcRegistery(), nodeConns: make(map[string]*nodeConnState), @@ -435,6 +441,11 @@ func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error { return err } + // Store the new tls wrapper. + s.tlsWrapLock.Lock() + s.tlsWrap = tlsWrap + s.tlsWrapLock.Unlock() + if s.rpcCancel == nil { err = fmt.Errorf("No existing RPC server to reset.") s.logger.Printf("[ERR] nomad: %s", err) diff --git a/nomad/structs/streaming_rpc.go b/nomad/structs/streaming_rpc.go index 949a31e23..6172c05e6 100644 --- a/nomad/structs/streaming_rpc.go +++ b/nomad/structs/streaming_rpc.go @@ -16,7 +16,7 @@ type StreamingRpcHeader struct { // StreamingRpcAck is used to acknowledge receiving the StreamingRpcHeader and // routing to the requirested handler. type StreamingRpcAck struct { - // Error is used to return whether an error occured establishing the + // Error is used to return whether an error occurred establishing the // streaming RPC. This error occurs before entering the RPC handler. Error string } diff --git a/nomad/testing.go b/nomad/testing.go index 6111ea596..2859dfb63 100644 --- a/nomad/testing.go +++ b/nomad/testing.go @@ -38,7 +38,7 @@ func TestACLServer(t testing.T, cb func(*Config)) (*Server, *structs.ACLToken) { func TestServer(t testing.T, cb func(*Config)) *Server { // Setup the default settings config := DefaultConfig() - config.Build = "0.7.0+unittest" + config.Build = "0.8.0+unittest" config.DevMode = true nodeNum := atomic.AddUint32(&nodeNumber, 1) config.NodeName = fmt.Sprintf("nomad-%03d", nodeNum) @@ -64,6 +64,11 @@ func TestServer(t testing.T, cb func(*Config)) *Server { // Squelch output when -v isn't specified config.LogOutput = testlog.NewWriter(t) + // Tighten the autopilot timing + config.AutopilotConfig.ServerStabilizationTime = 100 * time.Millisecond + config.ServerHealthInterval = 50 * time.Millisecond + config.AutopilotInterval = 100 * time.Millisecond + // Invoke the callback if any if cb != nil { cb(config)
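Reviewer note: the handshake order that both the client-side (client/rpc.go) and server-side (nomad/rpc.go) hunks above rely on is the same: write the pool.RpcTLS byte, upgrade the connection through the region-aware tlsutil.RegionWrapper, and only then write the pool.RpcStreaming byte on the wrapped connection. The following is a minimal sketch of that sequence gathered from the hunks above; dialStreaming is a hypothetical helper name, the import paths are those used elsewhere in the tree, and error handling is abbreviated.

package client

import (
	"net"

	"github.com/hashicorp/nomad/helper/pool"
	"github.com/hashicorp/nomad/helper/tlsutil"
)

// dialStreaming mirrors the TLS upgrade performed before entering streaming
// mode. It is illustrative only, not part of this change.
func dialStreaming(conn net.Conn, region string, tlsWrap tlsutil.RegionWrapper) (net.Conn, error) {
	if tlsWrap != nil {
		// Tell the remote side the next bytes are a TLS handshake.
		if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
			conn.Close()
			return nil, err
		}
		// Upgrade the connection using the region-aware wrapper.
		tlsConn, err := tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}
	// Only after the (optional) TLS upgrade is the connection switched into
	// streaming mode.
	if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}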