open-nomad/nomad/rpc_test.go
Luiz Aoqui 0e09b120e4
fix mTLS certificate check on agent to agent RPCs (#11998)
PR #11956 implemented a new mTLS RPC check to validate the role of the
certificate used in the request, but further testing revealed two flaws:

  1. client-only endpoints did not accept server certificates so the
     request would fail when forwarded from one server to another.
  2. the certificate was being checked after the request was forwarded,
     so the check would happen over the server certificate, not the
     actual source.

This commit checks for the desired mTLS level, where the client level
accepts both, a server or a client certificate. It also validates the
cercertificate before the request is forwarded.
2022-02-04 20:35:20 -05:00

1489 lines
38 KiB
Go

package nomad
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/rpc"
"os"
"path"
"path/filepath"
"testing"
"time"
"github.com/hashicorp/go-msgpack/codec"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pool"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/tlsutil"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
"github.com/hashicorp/nomad/testutil"
"github.com/hashicorp/raft"
"github.com/hashicorp/yamux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// rpcClient is a test helper method to return a ClientCodec to use to make rpc
// calls to the passed server.
func rpcClient(t *testing.T, s *Server) rpc.ClientCodec {
addr := s.config.RPCAddr
conn, err := net.DialTimeout("tcp", addr.String(), time.Second)
if err != nil {
t.Fatalf("err: %v", err)
}
// Write the Nomad RPC byte to set the mode
conn.Write([]byte{byte(pool.RpcNomad)})
return pool.NewClientCodec(conn)
}
func TestRPC_forwardLeader(t *testing.T) {
t.Parallel()
s1, cleanupS1 := TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
})
defer cleanupS1()
s2, cleanupS2 := TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
})
defer cleanupS2()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
testutil.WaitForLeader(t, s2.RPC)
isLeader, remote := s1.getLeader()
if !isLeader && remote == nil {
t.Fatalf("missing leader")
}
if remote != nil {
var out struct{}
err := s1.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
if err != nil {
t.Fatalf("err: %v", err)
}
}
isLeader, remote = s2.getLeader()
if !isLeader && remote == nil {
t.Fatalf("missing leader")
}
if remote != nil {
var out struct{}
err := s2.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
if err != nil {
t.Fatalf("err: %v", err)
}
}
}
func TestRPC_WaitForConsistentReads(t *testing.T) {
t.Parallel()
s1, cleanupS2 := TestServer(t, func(c *Config) {
c.RPCHoldTimeout = 20 * time.Millisecond
})
defer cleanupS2()
testutil.WaitForLeader(t, s1.RPC)
isLeader, _ := s1.getLeader()
require.True(t, isLeader)
require.True(t, s1.isReadyForConsistentReads())
s1.resetConsistentReadReady()
require.False(t, s1.isReadyForConsistentReads())
codec := rpcClient(t, s1)
get := &structs.JobListRequest{
QueryOptions: structs.QueryOptions{
Region: "global",
Namespace: "default",
},
}
// check timeout while waiting for consistency
var resp structs.JobListResponse
err := msgpackrpc.CallWithCodec(codec, "Job.List", get, &resp)
require.Error(t, err)
require.Contains(t, err.Error(), structs.ErrNotReadyForConsistentReads.Error())
// check we wait and block
go func() {
time.Sleep(5 * time.Millisecond)
s1.setConsistentReadReady()
}()
err = msgpackrpc.CallWithCodec(codec, "Job.List", get, &resp)
require.NoError(t, err)
}
func TestRPC_forwardRegion(t *testing.T) {
t.Parallel()
s1, cleanupS1 := TestServer(t, nil)
defer cleanupS1()
s2, cleanupS2 := TestServer(t, func(c *Config) {
c.Region = "global"
})
defer cleanupS2()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
testutil.WaitForLeader(t, s2.RPC)
var out struct{}
err := s1.forwardRegion("global", "Status.Ping", struct{}{}, &out)
if err != nil {
t.Fatalf("err: %v", err)
}
err = s2.forwardRegion("global", "Status.Ping", struct{}{}, &out)
if err != nil {
t.Fatalf("err: %v", err)
}
}
func TestRPC_getServer(t *testing.T) {
t.Parallel()
s1, cleanupS1 := TestServer(t, nil)
defer cleanupS1()
s2, cleanupS2 := TestServer(t, func(c *Config) {
c.Region = "global"
})
defer cleanupS2()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
testutil.WaitForLeader(t, s2.RPC)
// Lookup by name
srv, err := s1.getServer("global", s2.serf.LocalMember().Name)
require.NoError(t, err)
require.Equal(t, srv.Name, s2.serf.LocalMember().Name)
// Lookup by id
srv, err = s2.getServer("global", s1.serf.LocalMember().Tags["id"])
require.NoError(t, err)
require.Equal(t, srv.Name, s1.serf.LocalMember().Name)
}
func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) {
t.Parallel()
assert := assert.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1, cleanupS1 := TestServer(t, func(c *Config) {
c.DataDir = path.Join(dir, "node1")
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
RPCUpgradeMode: true,
}
})
defer cleanupS1()
codec := rpcClient(t, s1)
// Create the register request
node := mock.Node()
req := &structs.NodeRegisterRequest{
Node: node,
WriteRequest: structs.WriteRequest{Region: "global"},
}
var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
assert.Nil(err)
// Check that heartbeatTimers has the heartbeat ID
_, ok := s1.heartbeatTimers[node.ID]
assert.True(ok)
}
func TestRPC_PlaintextRPCFailsWhenNotInUpgradeMode(t *testing.T) {
t.Parallel()
assert := assert.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1, cleanupS1 := TestServer(t, func(c *Config) {
c.DataDir = path.Join(dir, "node1")
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer cleanupS1()
codec := rpcClient(t, s1)
node := mock.Node()
req := &structs.NodeRegisterRequest{
Node: node,
WriteRequest: structs.WriteRequest{Region: "global"},
}
var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
assert.NotNil(err)
}
func TestRPC_streamingRpcConn_badMethod(t *testing.T) {
t.Parallel()
require := require.New(t)
s1, cleanupS1 := TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
})
defer cleanupS1()
s2, cleanupS2 := TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
})
defer cleanupS2()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
testutil.WaitForLeader(t, s2.RPC)
s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()
conn, err := s1.streamingRpc(server, "Bogus")
require.Nil(conn)
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
}
func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
t.Parallel()
require := require.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1, cleanupS1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "node1")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer cleanupS1()
s2, cleanupS2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "node2")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer cleanupS2()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()
conn, err := s1.streamingRpc(server, "Bogus")
require.Nil(conn)
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
}
func TestRPC_streamingRpcConn_goodMethod_Plaintext(t *testing.T) {
t.Parallel()
require := require.New(t)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1, cleanupS1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "node1")
})
defer cleanupS1()
s2, cleanupS2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "node2")
})
defer cleanupS2()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()
conn, err := s1.streamingRpc(server, "FileSystem.Logs")
require.NotNil(conn)
require.NoError(err)
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
allocID := uuid.Generate()
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
AllocID: allocID,
QueryOptions: structs.QueryOptions{
Region: "regionFoo",
},
}))
var result cstructs.StreamErrWrapper
require.NoError(decoder.Decode(&result))
require.Empty(result.Payload)
require.True(structs.IsErrUnknownAllocation(result.Error))
}
func TestRPC_streamingRpcConn_goodMethod_TLS(t *testing.T) {
t.Parallel()
require := require.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1, cleanupS1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "node1")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer cleanupS1()
s2, cleanupS2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "node2")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer cleanupS2()
TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)
s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()
conn, err := s1.streamingRpc(server, "FileSystem.Logs")
require.NotNil(conn)
require.NoError(err)
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
allocID := uuid.Generate()
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
AllocID: allocID,
QueryOptions: structs.QueryOptions{
Region: "regionFoo",
},
}))
var result cstructs.StreamErrWrapper
require.NoError(decoder.Decode(&result))
require.Empty(result.Payload)
require.True(structs.IsErrUnknownAllocation(result.Error))
}
// COMPAT: Remove in 0.10
// This is a very low level test to assert that the V2 handling works. It is
// making manual RPC calls since no helpers exist at this point since we are
// only implementing support for v2 but not using it yet. In the future we can
// switch the conn pool to establishing v2 connections and we can deprecate this
// test.
func TestRPC_handleMultiplexV2(t *testing.T) {
t.Parallel()
require := require.New(t)
s, cleanupS := TestServer(t, nil)
defer cleanupS()
testutil.WaitForLeader(t, s.RPC)
p1, p2 := net.Pipe()
defer p1.Close()
defer p2.Close()
// Start the handler
doneCh := make(chan struct{})
go func() {
s.handleConn(context.Background(), p2, &RPCContext{Conn: p2})
close(doneCh)
}()
// Establish the MultiplexV2 connection
_, err := p1.Write([]byte{byte(pool.RpcMultiplexV2)})
require.Nil(err)
// Make two streams
conf := yamux.DefaultConfig()
conf.LogOutput = nil
conf.Logger = testlog.Logger(t)
session, err := yamux.Client(p1, conf)
require.Nil(err)
s1, err := session.Open()
require.Nil(err)
defer s1.Close()
s2, err := session.Open()
require.Nil(err)
defer s2.Close()
// Make an RPC
_, err = s1.Write([]byte{byte(pool.RpcNomad)})
require.Nil(err)
args := &structs.GenericRequest{}
var l string
err = msgpackrpc.CallWithCodec(pool.NewClientCodec(s1), "Status.Leader", args, &l)
require.Nil(err)
require.NotEmpty(l)
// Make a streaming RPC
_, err = s2.Write([]byte{byte(pool.RpcStreaming)})
require.Nil(err)
_, err = s.streamingRpcImpl(s2, "Bogus")
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
}
// TestRPC_TLS_in_TLS asserts that trying to nest TLS connections fails.
func TestRPC_TLS_in_TLS(t *testing.T) {
t.Parallel()
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
s, cleanup := TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer func() {
cleanup()
//TODO Avoid panics from logging during shutdown
time.Sleep(1 * time.Second)
}()
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte{byte(pool.RpcTLS)})
require.NoError(t, err)
// Client TLS verification isn't necessary for
// our assertions
tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true)
require.NoError(t, err)
outTLSConf, err := tlsConf.OutgoingTLSConfig()
require.NoError(t, err)
outTLSConf.InsecureSkipVerify = true
// Do initial handshake
tlsConn := tls.Client(conn, outTLSConf)
require.NoError(t, tlsConn.Handshake())
conn = tlsConn
// Try to create a nested TLS connection
_, err = conn.Write([]byte{byte(pool.RpcTLS)})
require.NoError(t, err)
// Attempts at nested TLS connections should cause a disconnect
buf := []byte{0}
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
n, err := conn.Read(buf)
require.Zero(t, n)
require.Equal(t, io.EOF, err)
}
// TestRPC_Limits_OK asserts that all valid limits combinations
// (tls/timeout/conns) work.
//
// Invalid limits are tested in command/agent/agent_test.go
func TestRPC_Limits_OK(t *testing.T) {
t.Parallel()
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
maxConns = 10 // limit must be < this for testing
)
cases := []struct {
tls bool
timeout time.Duration
limit int
assertTimeout bool
assertLimit bool
}{
{
tls: false,
timeout: 5 * time.Second,
limit: 0,
assertTimeout: true,
assertLimit: false,
},
{
tls: true,
timeout: 5 * time.Second,
limit: 0,
assertTimeout: true,
assertLimit: false,
},
{
tls: false,
timeout: 0,
limit: 0,
assertTimeout: false,
assertLimit: false,
},
{
tls: true,
timeout: 0,
limit: 0,
assertTimeout: false,
assertLimit: false,
},
{
tls: false,
timeout: 0,
limit: 2,
assertTimeout: false,
assertLimit: true,
},
{
tls: true,
timeout: 0,
limit: 2,
assertTimeout: false,
assertLimit: true,
},
{
tls: false,
timeout: 5 * time.Second,
limit: 2,
assertTimeout: true,
assertLimit: true,
},
{
tls: true,
timeout: 5 * time.Second,
limit: 2,
assertTimeout: true,
assertLimit: true,
},
}
assertTimeout := func(t *testing.T, s *Server, useTLS bool, timeout time.Duration) {
// Increase timeout to detect timeouts
clientTimeout := timeout + time.Second
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
require.NoError(t, err)
defer conn.Close()
buf := []byte{0}
readDeadline := time.Now().Add(clientTimeout)
conn.SetReadDeadline(readDeadline)
n, err := conn.Read(buf)
require.Zero(t, n)
if timeout == 0 {
// Server should *not* have timed out.
// Now() should always be after the client read deadline, but
// isn't a sufficient assertion for correctness as slow tests
// may cause this to be true even if the server timed out.
now := time.Now()
require.Truef(t, now.After(readDeadline),
"Client read deadline (%s) should be in the past (before %s)", readDeadline, now)
require.Truef(t, errors.Is(err, os.ErrDeadlineExceeded),
"error does not wrap os.ErrDeadlineExceeded: (%T) %v", err, err)
return
}
// Server *should* have timed out (EOF)
require.Equal(t, io.EOF, err)
// Create a new connection to assert timeout doesn't
// apply after first byte.
conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
require.NoError(t, err)
defer conn.Close()
if useTLS {
_, err := conn.Write([]byte{byte(pool.RpcTLS)})
require.NoError(t, err)
// Client TLS verification isn't necessary for
// our assertions
tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true)
require.NoError(t, err)
outTLSConf, err := tlsConf.OutgoingTLSConfig()
require.NoError(t, err)
outTLSConf.InsecureSkipVerify = true
tlsConn := tls.Client(conn, outTLSConf)
require.NoError(t, tlsConn.Handshake())
conn = tlsConn
}
// Writing the Nomad RPC byte should be sufficient to
// disable the handshake timeout
n, err = conn.Write([]byte{byte(pool.RpcNomad)})
require.NoError(t, err)
require.Equal(t, 1, n)
// Read should timeout due to client timeout, not
// server's timeout
readDeadline = time.Now().Add(clientTimeout)
conn.SetReadDeadline(readDeadline)
n, err = conn.Read(buf)
require.Zero(t, n)
require.Truef(t, errors.Is(err, os.ErrDeadlineExceeded),
"error does not wrap os.ErrDeadlineExceeded: (%T) %v", err, err)
}
assertNoLimit := func(t *testing.T, addr string) {
var err error
// Create max connections
conns := make([]net.Conn, maxConns)
errCh := make(chan error, maxConns)
for i := 0; i < maxConns; i++ {
conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
require.NoError(t, err)
defer conns[i].Close()
go func(i int) {
buf := []byte{0}
readDeadline := time.Now().Add(1 * time.Second)
conns[i].SetReadDeadline(readDeadline)
n, err := conns[i].Read(buf)
if n > 0 {
errCh <- fmt.Errorf("n > 0: %d", n)
return
}
errCh <- err
}(i)
}
// Now assert each error is a clientside read deadline error
deadline := time.After(10 * time.Second)
for i := 0; i < maxConns; i++ {
select {
case <-deadline:
t.Fatalf("timed out waiting for conn error %d/%d", i+1, maxConns)
case err := <-errCh:
require.Truef(t, errors.Is(err, os.ErrDeadlineExceeded),
"error does not wrap os.ErrDeadlineExceeded: (%T) %v", err, err)
}
}
}
assertLimit := func(t *testing.T, addr string, limit int) {
var err error
// Create limit connections
conns := make([]net.Conn, limit)
errCh := make(chan error, limit)
for i := range conns {
conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
require.NoError(t, err)
defer conns[i].Close()
go func(i int) {
buf := []byte{0}
n, err := conns[i].Read(buf)
if n > 0 {
errCh <- fmt.Errorf("n > 0: %d", n)
return
}
errCh <- err
}(i)
}
// Assert a new connection is dropped
conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
require.NoError(t, err)
defer conn.Close()
buf := []byte{0}
deadline := time.Now().Add(6 * time.Second)
conn.SetReadDeadline(deadline)
n, err := conn.Read(buf)
require.Zero(t, n)
require.Equal(t, io.EOF, err)
// Assert existing connections are ok
ERRCHECK:
select {
case err := <-errCh:
t.Errorf("unexpected error from idle connection: (%T) %v", err, err)
goto ERRCHECK
default:
}
// Cleanup
for _, conn := range conns {
conn.Close()
}
for i := range conns {
select {
case err := <-errCh:
require.Contains(t, err.Error(), "use of closed network connection")
case <-time.After(10 * time.Second):
t.Fatalf("timed out waiting for connection %d/%d to close", i, len(conns))
}
}
}
for i := range cases {
tc := cases[i]
name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
t.Run(name, func(t *testing.T) {
t.Parallel()
if tc.limit >= maxConns {
t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", tc.limit, maxConns)
}
if tc.assertTimeout && tc.timeout == 0 {
t.Fatalf("test fixture failure: cannot assert timeout when no timeout set (0)")
}
s, cleanup := TestServer(t, func(c *Config) {
if tc.tls {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
}
c.RPCHandshakeTimeout = tc.timeout
c.RPCMaxConnsPerClient = tc.limit
})
defer func() {
cleanup()
//TODO Avoid panics from logging during shutdown
time.Sleep(1 * time.Second)
}()
assertTimeout(t, s, tc.tls, tc.timeout)
if tc.assertLimit {
// There's a race between assertTimeout(false) closing
// its connection and the HTTP server noticing and
// untracking it. Since there's no way to coordiante
// when this occurs, sleeping is the only way to avoid
// asserting limits before the timed out connection is
// untracked.
time.Sleep(1 * time.Second)
assertLimit(t, s.config.RPCAddr.String(), tc.limit)
} else {
assertNoLimit(t, s.config.RPCAddr.String())
}
})
}
}
// TestRPC_Limits_Streaming asserts that the streaming RPC limit is lower than
// the overall connection limit to prevent DOS via server-routed streaming API
// calls.
func TestRPC_Limits_Streaming(t *testing.T) {
t.Parallel()
s, cleanup := TestServer(t, func(c *Config) {
limits := config.DefaultLimits()
c.RPCMaxConnsPerClient = *limits.RPCMaxConnsPerClient
})
defer func() {
cleanup()
//TODO Avoid panics from logging during shutdown
time.Sleep(1 * time.Second)
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
// Create a streaming connection
dialStreamer := func() net.Conn {
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
require.NoError(t, err)
_, err = conn.Write([]byte{byte(pool.RpcStreaming)})
require.NoError(t, err)
return conn
}
// Create up to the limit streaming connections
streamers := make([]net.Conn, s.config.RPCMaxConnsPerClient-config.LimitsNonStreamingConnsPerClient)
for i := range streamers {
streamers[i] = dialStreamer()
go func(i int) {
// Streamer should never die until test exits
buf := []byte{0}
_, err := streamers[i].Read(buf)
if ctx.Err() != nil {
// Error is expected when test finishes
return
}
t.Logf("connection %d died with error: (%T) %v", i, err, err)
// Send unexpected errors back
if err != nil {
select {
case errCh <- err:
case <-ctx.Done():
default:
// Only send first error
}
}
}(i)
}
defer func() {
cancel()
for _, conn := range streamers {
conn.Close()
}
}()
// Assert no streamer errors have occurred
select {
case err := <-errCh:
t.Fatalf("unexpected error from blocking streaming RPCs: (%T) %v", err, err)
case <-time.After(500 * time.Millisecond):
// Ok! No connections were rejected immediately.
}
// Assert subsequent streaming RPC are rejected
conn := dialStreamer()
t.Logf("expect connection to be rejected due to limit")
buf := []byte{0}
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
_, err := conn.Read(buf)
require.Equalf(t, io.EOF, err, "expected io.EOF but found: (%T) %v", err, err)
// Assert no streamer errors have occurred
select {
case err := <-errCh:
t.Fatalf("unexpected error from blocking streaming RPCs: %v", err)
default:
}
// Subsequent non-streaming RPC should be OK
conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
require.NoError(t, err)
_, err = conn.Write([]byte{byte(pool.RpcNomad)})
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
_, err = conn.Read(buf)
require.Truef(t, errors.Is(err, os.ErrDeadlineExceeded),
"error does not wrap os.ErrDeadlineExceeded: (%T) %v", err, err)
// Close 1 streamer and assert another is allowed
t.Logf("expect streaming connection 0 to exit with error")
streamers[0].Close()
<-errCh
// Assert that new connections are allowed.
// Due to the distributed nature here, server may not immediately recognize
// the connection closure, so first attempts may be rejections (i.e. EOF)
// but the first non-EOF request must be a read-deadline error
testutil.WaitForResult(func() (bool, error) {
conn = dialStreamer()
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
_, err = conn.Read(buf)
if err == io.EOF {
return false, fmt.Errorf("connection was rejected")
}
require.True(t, errors.Is(err, os.ErrDeadlineExceeded))
return true, nil
}, func(err error) {
require.NoError(t, err)
})
}
func TestRPC_TLS_Enforcement_Raft(t *testing.T) {
t.Parallel()
defer func() {
//TODO Avoid panics from logging during shutdown
time.Sleep(1 * time.Second)
}()
tlsHelper := newTLSTestHelper(t)
defer tlsHelper.cleanup()
// When VerifyServerHostname is enabled:
// Only local servers can connect to the Raft layer
cases := []struct {
name string
cn string
canRaft bool
}{
{
name: "local server",
cn: "server.global.nomad",
canRaft: true,
},
{
name: "local client",
cn: "client.global.nomad",
canRaft: false,
},
{
name: "other region server",
cn: "server.other.nomad",
canRaft: false,
},
{
name: "other region client",
cn: "client.other.nomad",
canRaft: false,
},
{
name: "irrelevant cert",
cn: "nomad.example.com",
canRaft: false,
},
{
name: "globs",
cn: "*.global.nomad",
canRaft: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
certPath := tlsHelper.newCert(t, tc.cn)
cfg := &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(tlsHelper.dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
}
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer1, cfg)
// the expected error depends on location of failure.
// We expect "bad certificate" if connection fails during handshake,
// or EOF when connection is closed after RaftRPC byte.
if tc.canRaft {
require.NoError(t, err)
} else {
require.Error(t, err)
require.Regexp(t, "(bad certificate|EOF)", err.Error())
}
})
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
err := tlsHelper.raftRPC(t, tlsHelper.nonVerifyServer, cfg)
require.NoError(t, err)
})
})
}
}
func TestRPC_TLS_Enforcement_RPC(t *testing.T) {
t.Parallel()
defer func() {
//TODO Avoid panics from logging during shutdown
time.Sleep(1 * time.Second)
}()
tlsHelper := newTLSTestHelper(t)
defer tlsHelper.cleanup()
standardRPCs := map[string]interface{}{
"Status.Ping": struct{}{},
}
localServersOnlyRPCs := map[string]interface{}{
"Eval.Update": &structs.EvalUpdateRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Eval.Ack": &structs.EvalAckRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Eval.Nack": &structs.EvalAckRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Eval.Dequeue": &structs.EvalDequeueRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Eval.Create": &structs.EvalUpdateRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Eval.Reblock": &structs.EvalUpdateRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Eval.Reap": &structs.EvalDeleteRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Plan.Submit": &structs.PlanRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Deployment.Reap": &structs.DeploymentDeleteRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
}
localClientsOnlyRPCs := map[string]interface{}{
"Alloc.GetAllocs": &structs.AllocsGetRequest{
QueryOptions: structs.QueryOptions{Region: "global"},
},
"Node.EmitEvents": &structs.EmitNodeEventsRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
"Node.UpdateAlloc": &structs.AllocUpdateRequest{
WriteRequest: structs.WriteRequest{Region: "global"},
},
}
// When VerifyServerHostname is enabled:
// All servers can make RPC requests
// Only local clients can make RPC requests
// Some endpoints can only be called server -> server
// Some endpoints can only be called client -> server
cases := []struct {
name string
cn string
rpcs map[string]interface{}
canRPC bool
}{
// Local server.
{
name: "local server/standard rpc",
cn: "server.global.nomad",
rpcs: standardRPCs,
canRPC: true,
},
{
name: "local server/servers only rpc",
cn: "server.global.nomad",
rpcs: localServersOnlyRPCs,
canRPC: true,
},
{
name: "local server/clients only rpc",
cn: "server.global.nomad",
rpcs: localClientsOnlyRPCs,
canRPC: true,
},
// Local client.
{
name: "local client/standard rpc",
cn: "client.global.nomad",
rpcs: standardRPCs,
canRPC: true,
},
{
name: "local client/servers only rpc",
cn: "client.global.nomad",
rpcs: localServersOnlyRPCs,
canRPC: false,
},
{
name: "local client/clients only rpc",
cn: "client.global.nomad",
rpcs: localClientsOnlyRPCs,
canRPC: true,
},
// Other region server.
{
name: "other region server/standard rpc",
cn: "server.other.nomad",
rpcs: standardRPCs,
canRPC: true,
},
{
name: "other region server/servers only rpc",
cn: "server.other.nomad",
rpcs: localServersOnlyRPCs,
canRPC: false,
},
{
name: "other region server/clients only rpc",
cn: "server.other.nomad",
rpcs: localClientsOnlyRPCs,
canRPC: false,
},
// Other region client.
{
name: "other region client/standard rpc",
cn: "client.other.nomad",
rpcs: standardRPCs,
canRPC: false,
},
{
name: "other region client/servers only rpc",
cn: "client.other.nomad",
rpcs: localServersOnlyRPCs,
canRPC: false,
},
{
name: "other region client/clients only rpc",
cn: "client.other.nomad",
rpcs: localClientsOnlyRPCs,
canRPC: false,
},
// Wrong certs.
{
name: "irrelevant cert",
cn: "nomad.example.com",
rpcs: standardRPCs,
canRPC: false,
},
{
name: "globs",
cn: "*.global.nomad",
rpcs: standardRPCs,
canRPC: false,
},
{},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
certPath := tlsHelper.newCert(t, tc.cn)
cfg := &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(tlsHelper.dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
}
for method, arg := range tc.rpcs {
for _, srv := range []*Server{tlsHelper.mtlsServer1, tlsHelper.mtlsServer2} {
name := fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true leader=%v", method, srv.IsLeader())
t.Run(name, func(t *testing.T) {
err := tlsHelper.nomadRPC(t, srv, cfg, method, arg)
if tc.canRPC {
if err != nil {
require.NotContains(t, err, "certificate")
}
} else {
require.Error(t, err)
require.Contains(t, err.Error(), "certificate")
}
})
}
t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) {
err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg)
if err != nil {
require.NotContains(t, err, "certificate")
}
})
}
})
}
}
type tlsTestHelper struct {
dir string
nodeID int
mtlsServer1 *Server
mtlsServer1Cleanup func()
mtlsServer2 *Server
mtlsServer2Cleanup func()
nonVerifyServer *Server
nonVerifyServerCleanup func()
caPEM string
pk string
serverCert string
}
func newTLSTestHelper(t *testing.T) tlsTestHelper {
var err error
h := tlsTestHelper{
dir: tmpDir(t),
nodeID: 1,
}
// Generate CA certificate and write it to disk.
h.caPEM, h.pk, err = tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(h.dir, "ca.pem"), []byte(h.caPEM), 0600)
require.NoError(t, err)
// Generate servers and their certificate.
h.serverCert = h.newCert(t, "server.global.nomad")
h.mtlsServer1, h.mtlsServer1Cleanup = TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(h.dir, "ca.pem"),
CertFile: h.serverCert + ".pem",
KeyFile: h.serverCert + ".key",
}
})
h.mtlsServer2, h.mtlsServer2Cleanup = TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(h.dir, "ca.pem"),
CertFile: h.serverCert + ".pem",
KeyFile: h.serverCert + ".key",
}
})
TestJoin(t, h.mtlsServer1, h.mtlsServer2)
testutil.WaitForLeader(t, h.mtlsServer1.RPC)
testutil.WaitForLeader(t, h.mtlsServer2.RPC)
h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: false,
CAFile: filepath.Join(h.dir, "ca.pem"),
CertFile: h.serverCert + ".pem",
KeyFile: h.serverCert + ".key",
}
})
return h
}
func (h tlsTestHelper) cleanup() {
h.mtlsServer1Cleanup()
h.mtlsServer2Cleanup()
h.nonVerifyServerCleanup()
os.RemoveAll(h.dir)
}
func (h tlsTestHelper) newCert(t *testing.T, name string) string {
t.Helper()
node := fmt.Sprintf("node%d", h.nodeID)
h.nodeID++
signer, err := tlsutil.ParseSigner(h.pk)
require.NoError(t, err)
pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
Signer: signer,
CA: h.caPEM,
Name: name,
Days: 5,
DNSNames: []string{node + "." + name, name, "localhost"},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
})
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".pem"), []byte(pem), 0600)
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".key"), []byte(key), 0600)
require.NoError(t, err)
return filepath.Join(h.dir, node+"-"+name)
}
func (h tlsTestHelper) connect(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
require.NoError(t, err)
// configure TLS
_, err = conn.Write([]byte{byte(pool.RpcTLS)})
require.NoError(t, err)
// Client TLS verification isn't necessary for
// our assertions
tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true)
require.NoError(t, err)
outTLSConf, err := tlsConf.OutgoingTLSConfig()
require.NoError(t, err)
outTLSConf.InsecureSkipVerify = true
tlsConn := tls.Client(conn, outTLSConf)
require.NoError(t, tlsConn.Handshake())
return tlsConn
}
func (h tlsTestHelper) nomadRPC(t *testing.T, s *Server, c *config.TLSConfig, method string, arg interface{}) error {
conn := h.connect(t, s, c)
defer conn.Close()
_, err := conn.Write([]byte{byte(pool.RpcNomad)})
require.NoError(t, err)
codec := pool.NewClientCodec(conn)
var out struct{}
return msgpackrpc.CallWithCodec(codec, method, arg, &out)
}
func (h tlsTestHelper) raftRPC(t *testing.T, s *Server, c *config.TLSConfig) error {
conn := h.connect(t, s, c)
defer conn.Close()
_, err := conn.Write([]byte{byte(pool.RpcRaft)})
require.NoError(t, err)
_, err = doRaftRPC(conn, s.config.NodeName)
return err
}
func doRaftRPC(conn net.Conn, leader string) (*raft.AppendEntriesResponse, error) {
req := raft.AppendEntriesRequest{
RPCHeader: raft.RPCHeader{ProtocolVersion: 3},
Term: 0,
Leader: []byte(leader),
PrevLogEntry: 0,
PrevLogTerm: 0xc,
LeaderCommitIndex: 50,
}
enc := codec.NewEncoder(conn, &codec.MsgpackHandle{})
dec := codec.NewDecoder(conn, &codec.MsgpackHandle{})
const rpcAppendEntries = 0
if _, err := conn.Write([]byte{rpcAppendEntries}); err != nil {
return nil, fmt.Errorf("failed to write raft-RPC byte: %w", err)
}
if err := enc.Encode(req); err != nil {
return nil, fmt.Errorf("failed to send append entries RPC: %w", err)
}
var rpcError string
var resp raft.AppendEntriesResponse
if err := dec.Decode(&rpcError); err != nil {
return nil, fmt.Errorf("failed to decode response error: %w", err)
}
if rpcError != "" {
return nil, fmt.Errorf("rpc error: %v", rpcError)
}
if err := dec.Decode(&resp); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return &resp, nil
}