Verify TLS certificate on endpoints that are used between agents only (#11956)

This commit is contained in:
Luiz Aoqui 2022-02-02 15:03:18 -05:00 committed by GitHub
parent f6217fe424
commit c4cff5359f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 576 additions and 177 deletions

3
.changelog/11956.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:security
server: validate mTLS certificate names on agent to agent endpoints
```

80
.semgrep/rpc_endpoint.yml Normal file
View File

@ -0,0 +1,80 @@
rules:
# Check potentially unauthenticated RPC endpoints
- id: "rpc-potentially-unauthenticated"
patterns:
- pattern: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := $X.$Y.ResolveToken(...)
...
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := $U.requestACLToken(...)
...
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := $T.NamespaceValidator(...)
...
# Pattern used by endpoints called exclusively between agents
# (server -> server or client -> server)
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := validateLocalClientTLSCertificate(...)
...
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := validateLocalServerTLSCertificate(...)
...
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := validateTLSCertificate(...)
...
# Pattern used by some Node endpoints.
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
return $A.deregister(...)
...
- metavariable-pattern:
metavariable: $METHOD
patterns:
# Endpoints that are expected not to have authentication.
- pattern-not: '"ACL.Bootstrap"'
- pattern-not: '"ACL.ResolveToken"'
- pattern-not: '"ACL.UpsertOneTimeToken"'
- pattern-not: '"ACL.ExchangeOneTimeToken"'
- pattern-not: '"CSIPlugin.Get"'
- pattern-not: '"CSIPlugin.List"'
- pattern-not: '"Status.Leader"'
- pattern-not: '"Status.Peers"'
- pattern-not: '"Status.Version"'
message: "RPC method $METHOD appears to be unauthenticated"
languages:
- "go"
severity: "WARNING"
paths:
include:
- "*_endpoint.go"

View File

@ -20,6 +20,9 @@ import (
type Alloc struct {
srv *Server
logger log.Logger
// ctx provides context regarding the underlying connection
ctx *RPCContext
}
// List is used to list the allocations in the system
@ -224,6 +227,11 @@ func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest,
}
defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now())
// Ensure the connection was initiated by a client if TLS is used.
if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err)
}
allocs := make([]*structs.Allocation, len(args.AllocIDs))
// Setup the blocking query. We wait for at least one of the requested

View File

@ -17,6 +17,9 @@ import (
type Deployment struct {
srv *Server
logger log.Logger
// ctx provides context regarding the underlying connection
ctx *RPCContext
}
// GetDeployment is used to request information about a specific deployment
@ -506,6 +509,11 @@ func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest,
}
defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err)
}
// Update via Raft
_, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args)
if err != nil {

View File

@ -24,6 +24,9 @@ const (
type Eval struct {
srv *Server
logger log.Logger
// ctx provides context regarding the underlying connection
ctx *RPCContext
}
// GetEval is used to request information about a specific evaluation
@ -87,6 +90,11 @@ func (e *Eval) Dequeue(args *structs.EvalDequeueRequest,
}
defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is at least one scheduler
if len(args.Schedulers) == 0 {
return fmt.Errorf("dequeue requires at least one scheduler type")
@ -172,6 +180,11 @@ func (e *Eval) Ack(args *structs.EvalAckRequest,
}
defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ack the EvalID
if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil {
return err
@ -187,6 +200,11 @@ func (e *Eval) Nack(args *structs.EvalAckRequest,
}
defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Nack the EvalID
if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil {
return err
@ -202,6 +220,11 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest,
}
defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is only a single update with token
if len(args.Evals) != 1 {
return fmt.Errorf("only a single eval can be updated")
@ -232,6 +255,11 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest,
}
defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is only a single update with token
if len(args.Evals) != 1 {
return fmt.Errorf("only a single eval can be created")
@ -277,6 +305,11 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe
}
defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is only a single update with token
if len(args.Evals) != 1 {
return fmt.Errorf("only a single eval can be reblocked")
@ -319,6 +352,11 @@ func (e *Eval) Reap(args *structs.EvalDeleteRequest,
}
defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Update via Raft
_, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args)
if err != nil {

View File

@ -114,8 +114,8 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis
reply.Warnings = structs.MergeMultierrorWarnings(warnings...)
// Check job submission permissions
var aclObj *acl.ACL
if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil {
aclObj, err := j.srv.ResolveToken(args.AuthToken)
if err != nil {
return err
} else if aclObj != nil {
if !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilitySubmitJob) {
@ -1879,9 +1879,8 @@ func (j *Job) Dispatch(args *structs.JobDispatchRequest, reply *structs.JobDispa
defer metrics.MeasureSince([]string{"nomad", "job", "dispatch"}, time.Now())
// Check for submit-job permissions
var aclObj *acl.ACL
var err error
if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil {
aclObj, err := j.srv.ResolveToken(args.AuthToken)
if err != nil {
return err
} else if aclObj != nil && !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilityDispatchJob) {
return structs.ErrPermissionDenied

View File

@ -1103,6 +1103,11 @@ func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.Gene
}
defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now())
// Ensure the connection was initiated by a client if TLS is used.
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
}
// Ensure at least a single alloc
if len(args.Alloc) == 0 {
return fmt.Errorf("must update at least one allocation")
@ -1920,6 +1925,11 @@ func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.Em
}
defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now())
// Ensure the connection was initiated by a client if TLS is used.
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
}
if len(args.NodeEvents) == 0 {
return fmt.Errorf("no node events given")
}

View File

@ -14,6 +14,9 @@ import (
type Plan struct {
srv *Server
logger log.Logger
// ctx provides context regarding the underlying connection
ctx *RPCContext
}
// Submit is used to submit a plan to the leader
@ -23,6 +26,11 @@ func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) er
}
defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err)
}
if args.Plan == nil {
return fmt.Errorf("cannot submit nil plan")
}

View File

@ -107,6 +107,38 @@ type RPCContext struct {
NodeID string
}
// Certificate returns the first certificate of the first verified chain
// stored in the RPC context. It returns nil when the context itself is nil,
// when there are no verified chains, or when the first chain is empty.
func (ctx *RPCContext) Certificate() *x509.Certificate {
	if ctx == nil {
		return nil
	}

	chains := ctx.VerifiedChains
	if len(chains) == 0 {
		return nil
	}

	if first := chains[0]; len(first) > 0 {
		return first[0]
	}
	return nil
}
// ValidateCertificateForName checks that the certificate attached to the RPC
// context is valid for the given name. The name is compared for exact
// equality against each of the certificate's DNS names and against its
// subject common name.
//
// It returns nil when one of those names matches. It also returns nil,
// without checking anything, when the context is nil or the connection is not
// marked as TLS. An error is returned when TLS is in use but no certificate is
// available, or when none of the certificate names equals name.
func (ctx *RPCContext) ValidateCertificateForName(name string) error {
	// Nothing to validate unless the connection is using TLS.
	if ctx == nil || !ctx.TLS {
		return nil
	}

	crt := ctx.Certificate()
	if crt == nil {
		return errors.New("missing certificate information")
	}

	// Accept the name if it is the common name or any of the DNS names.
	matched := crt.Subject.CommonName == name
	for i := 0; !matched && i < len(crt.DNSNames); i++ {
		matched = crt.DNSNames[i] == name
	}
	if !matched {
		return fmt.Errorf("certificate not valid for %q", name)
	}
	return nil
}
// listen is used to listen for incoming RPC connections
func (r *rpcHandler) listen(ctx context.Context) {
defer close(r.listenerCh)
@ -838,30 +870,18 @@ func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error {
return nil
}
// defensive conditions: these should have already been enforced by handleConn
if rpcCtx == nil || !rpcCtx.TLS {
return errors.New("non-TLS connection attempted")
}
if len(rpcCtx.VerifiedChains) == 0 || len(rpcCtx.VerifiedChains[0]) == 0 {
// this should never happen, as rpcNameAndRegionValidate should have enforced it
return errors.New("missing cert info")
}
// check that `server.<region>.nomad` is present in cert
expected := "server." + r.Region() + ".nomad"
cert := rpcCtx.VerifiedChains[0][0]
for _, dnsName := range cert.DNSNames {
if dnsName == expected {
// Certificate is valid for the expected name
return nil
}
}
if cert.Subject.CommonName == expected {
// Certificate is valid for the expected name
return nil
err := rpcCtx.ValidateCertificateForName(expected)
if err != nil {
cert := rpcCtx.Certificate()
if cert != nil {
err = fmt.Errorf("request certificate is only valid for %s: %v", cert.DNSNames, err)
}
r.logger.Warn("unauthorized raft connection", "remote_addr", rpcCtx.Conn.RemoteAddr(), "required_hostname", expected, "found", cert.DNSNames)
return fmt.Errorf("certificate is invalid for expected role or region: %q", expected)
return fmt.Errorf("unauthorized raft connection from %s: %v", rpcCtx.Conn.RemoteAddr(), err)
}
// Certificate is valid for the expected name
return nil
}

View File

@ -1018,7 +1018,7 @@ func TestRPC_Limits_Streaming(t *testing.T) {
})
}
func TestRPC_TLS_Enforcement(t *testing.T) {
func TestRPC_TLS_Enforcement_Raft(t *testing.T) {
t.Parallel()
defer func() {
@ -1026,27 +1026,349 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
time.Sleep(1 * time.Second)
}()
dir := tmpDir(t)
defer os.RemoveAll(dir)
tlsHelper := newTLSTestHelper(t)
defer tlsHelper.cleanup()
caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
// When VerifyServerHostname is enabled:
// Only local servers can connect to the Raft layer
cases := []struct {
name string
cn string
canRaft bool
}{
{
name: "local server",
cn: "server.global.nomad",
canRaft: true,
},
{
name: "local client",
cn: "client.global.nomad",
canRaft: false,
},
{
name: "other region server",
cn: "server.other.nomad",
canRaft: false,
},
{
name: "other region client",
cn: "client.other.nomad",
canRaft: false,
},
{
name: "irrelevant cert",
cn: "nomad.example.com",
canRaft: false,
},
{
name: "globs",
cn: "*.global.nomad",
canRaft: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
certPath := tlsHelper.newCert(t, tc.cn)
cfg := &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(tlsHelper.dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
}
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg)
// the expected error depends on location of failure.
// We expect "bad certificate" if connection fails during handshake,
// or EOF when connection is closed after RaftRPC byte.
if tc.canRaft {
require.NoError(t, err)
} else {
require.Error(t, err)
require.Regexp(t, "(bad certificate|EOF)", err.Error())
}
})
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
err := tlsHelper.raftRPC(t, tlsHelper.nonVerifyServer, cfg)
require.NoError(t, err)
})
})
}
}
// TestRPC_TLS_Enforcement_RPC verifies which certificate names are allowed to
// call which groups of Nomad RPC endpoints.
//
// Each case generates a certificate for the given name and calls every RPC in
// the case's set against two servers: one with VerifyServerHostname enabled
// and one with it disabled. With verification enabled, a call that must be
// rejected is expected to fail with an error mentioning "certificate"; a call
// that is allowed may still fail for unrelated reasons (the requests are
// mostly empty), but never with a certificate error. With verification
// disabled no call may fail with a certificate error.
func TestRPC_TLS_Enforcement_RPC(t *testing.T) {
	t.Parallel()

	defer func() {
		//TODO Avoid panics from logging during shutdown
		time.Sleep(1 * time.Second)
	}()

	tlsHelper := newTLSTestHelper(t)
	defer tlsHelper.cleanup()

	// RPCs without any agent-role restriction.
	standardRPCs := map[string]interface{}{
		"Status.Ping": struct{}{},
	}

	// RPCs that only servers of the local region may call.
	localServersOnlyRPCs := map[string]interface{}{
		"Eval.Update": &structs.EvalUpdateRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Eval.Ack": &structs.EvalAckRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Eval.Nack": &structs.EvalAckRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Eval.Dequeue": &structs.EvalDequeueRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Eval.Create": &structs.EvalUpdateRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Eval.Reblock": &structs.EvalUpdateRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Eval.Reap": &structs.EvalDeleteRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Plan.Submit": &structs.PlanRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Deployment.Reap": &structs.DeploymentDeleteRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
	}

	// RPCs that only clients of the local region may call.
	localClientsOnlyRPCs := map[string]interface{}{
		"Alloc.GetAllocs": &structs.AllocsGetRequest{
			QueryOptions: structs.QueryOptions{Region: "global"},
		},
		"Node.EmitEvents": &structs.EmitNodeEventsRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
		"Node.UpdateAlloc": &structs.AllocUpdateRequest{
			WriteRequest: structs.WriteRequest{Region: "global"},
		},
	}

	// When VerifyServerHostname is enabled:
	// All servers can make RPC requests
	// Only local clients can make RPC requests
	// Some endpoints can only be called server -> server
	// Some endpoints can only be called client -> server
	cases := []struct {
		name   string
		cn     string
		rpcs   map[string]interface{}
		canRPC bool
	}{
		// Local server.
		{
			name:   "local server/standard rpc",
			cn:     "server.global.nomad",
			rpcs:   standardRPCs,
			canRPC: true,
		},
		{
			name:   "local server/servers only rpc",
			cn:     "server.global.nomad",
			rpcs:   localServersOnlyRPCs,
			canRPC: true,
		},
		{
			name:   "local server/clients only rpc",
			cn:     "server.global.nomad",
			rpcs:   localClientsOnlyRPCs,
			canRPC: false,
		},
		// Local client.
		{
			name:   "local client/standard rpc",
			cn:     "client.global.nomad",
			rpcs:   standardRPCs,
			canRPC: true,
		},
		{
			name:   "local client/servers only rpc",
			cn:     "client.global.nomad",
			rpcs:   localServersOnlyRPCs,
			canRPC: false,
		},
		{
			name:   "local client/clients only rpc",
			cn:     "client.global.nomad",
			rpcs:   localClientsOnlyRPCs,
			canRPC: true,
		},
		// Other region server.
		{
			name:   "other region server/standard rpc",
			cn:     "server.other.nomad",
			rpcs:   standardRPCs,
			canRPC: true,
		},
		{
			name:   "other region server/servers only rpc",
			cn:     "server.other.nomad",
			rpcs:   localServersOnlyRPCs,
			canRPC: false,
		},
		{
			name:   "other region server/clients only rpc",
			cn:     "server.other.nomad",
			rpcs:   localClientsOnlyRPCs,
			canRPC: false,
		},
		// Other region client.
		{
			name:   "other region client/standard rpc",
			cn:     "client.other.nomad",
			rpcs:   standardRPCs,
			canRPC: false,
		},
		{
			name:   "other region client/servers only rpc",
			cn:     "client.other.nomad",
			rpcs:   localServersOnlyRPCs,
			canRPC: false,
		},
		{
			name:   "other region client/clients only rpc",
			cn:     "client.other.nomad",
			rpcs:   localClientsOnlyRPCs,
			canRPC: false,
		},
		// Wrong certs.
		{
			name:   "irrelevant cert",
			cn:     "nomad.example.com",
			rpcs:   standardRPCs,
			canRPC: false,
		},
		{
			name:   "globs",
			cn:     "*.global.nomad",
			rpcs:   standardRPCs,
			canRPC: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			certPath := tlsHelper.newCert(t, tc.cn)

			cfg := &config.TLSConfig{
				EnableRPC:            true,
				VerifyServerHostname: true,
				CAFile:               filepath.Join(tlsHelper.dir, "ca.pem"),
				CertFile:             certPath + ".pem",
				KeyFile:              certPath + ".key",
			}

			for method, arg := range tc.rpcs {
				t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) {
					err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg)

					if tc.canRPC {
						// The request itself may be rejected for other
						// reasons, but never because of the certificate.
						// Compare against the message: NotContains cannot
						// inspect arbitrary error values.
						if err != nil {
							require.NotContains(t, err.Error(), "certificate")
						}
					} else {
						require.Error(t, err)
						require.Contains(t, err.Error(), "certificate")
					}
				})

				t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) {
					err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg)
					if err != nil {
						require.NotContains(t, err.Error(), "certificate")
					}
				})
			}
		})
	}
}
// tlsTestHelper bundles the state shared by the TLS enforcement tests: a
// scratch directory, a generated CA, and two test servers that differ only in
// whether VerifyServerHostname is enabled.
type tlsTestHelper struct {
// dir is the directory where ca.pem and the generated certificate and key
// files are written.
dir string
// nodeID is the number used by newCert to build the "node<N>" name of a
// generated certificate.
nodeID int
// mtlsServer is the test server configured with VerifyServerHostname set
// to true; mtlsServerCleanup is the cleanup function returned with it.
mtlsServer *Server
mtlsServerCleanup func()
// nonVerifyServer is the test server configured with VerifyServerHostname
// set to false; nonVerifyServerCleanup is the cleanup function returned
// with it.
nonVerifyServer *Server
nonVerifyServerCleanup func()
// caPEM and pk are the two values returned by tlsutil.GenerateCA. caPEM is
// written to ca.pem and pk is parsed as the signer for new certificates.
caPEM string
pk string
// serverCert is the path, without the ".pem"/".key" extension, of the
// certificate generated for "server.global.nomad" and used by both servers.
serverCert string
}
func newTLSTestHelper(t *testing.T) tlsTestHelper {
var err error
h := tlsTestHelper{
dir: tmpDir(t),
nodeID: 1,
}
// Generate CA certificate and write it to disk.
h.caPEM, h.pk, err = tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
err = ioutil.WriteFile(filepath.Join(h.dir, "ca.pem"), []byte(h.caPEM), 0600)
require.NoError(t, err)
nodeID := 1
newCert := func(t *testing.T, name string) string {
// Generate servers and their certificate.
h.serverCert = h.newCert(t, "server.global.nomad")
h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(h.dir, "ca.pem"),
CertFile: h.serverCert + ".pem",
KeyFile: h.serverCert + ".key",
}
})
h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: false,
CAFile: filepath.Join(h.dir, "ca.pem"),
CertFile: h.serverCert + ".pem",
KeyFile: h.serverCert + ".key",
}
})
return h
}
// cleanup runs the cleanup functions of both test servers and then removes
// the helper's directory together with everything written into it.
func (h tlsTestHelper) cleanup() {
h.mtlsServerCleanup()
h.nonVerifyServerCleanup()
// Remove the directory last, after both server cleanups have run.
os.RemoveAll(h.dir)
}
func (h tlsTestHelper) newCert(t *testing.T, name string) string {
t.Helper()
node := fmt.Sprintf("node%d", nodeID)
nodeID++
signer, err := tlsutil.ParseSigner(pk)
node := fmt.Sprintf("node%d", h.nodeID)
h.nodeID++
signer, err := tlsutil.ParseSigner(h.pk)
require.NoError(t, err)
pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
Signer: signer,
CA: caPEM,
CA: h.caPEM,
Name: name,
Days: 5,
DNSNames: []string{node + "." + name, name, "localhost"},
@ -1054,15 +1376,15 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
})
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".pem"), []byte(pem), 0600)
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".key"), []byte(key), 0600)
require.NoError(t, err)
return filepath.Join(dir, node+"-"+name)
}
return filepath.Join(h.dir, node+"-"+name)
}
connect := func(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
func (h tlsTestHelper) connect(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
require.NoError(t, err)
@ -1082,23 +1404,22 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
require.NoError(t, tlsConn.Handshake())
return tlsConn
}
}
nomadRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
conn := connect(t, s, c)
func (h tlsTestHelper) nomadRPC(t *testing.T, s *Server, c *config.TLSConfig, method string, arg interface{}) error {
conn := h.connect(t, s, c)
defer conn.Close()
_, err := conn.Write([]byte{byte(pool.RpcNomad)})
require.NoError(t, err)
codec := pool.NewClientCodec(conn)
arg := struct{}{}
var out struct{}
return msgpackrpc.CallWithCodec(codec, "Status.Ping", arg, &out)
}
return msgpackrpc.CallWithCodec(codec, method, arg, &out)
}
raftRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
conn := connect(t, s, c)
func (h tlsTestHelper) raftRPC(t *testing.T, s *Server, c *config.TLSConfig) error {
conn := h.connect(t, s, c)
defer conn.Close()
_, err := conn.Write([]byte{byte(pool.RpcRaft)})
@ -1106,129 +1427,6 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
_, err = doRaftRPC(conn, s.config.NodeName)
return err
}
// generate server cert
serverCert := newCert(t, "server.global.nomad")
mtlsS, cleanup := TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: serverCert + ".pem",
KeyFile: serverCert + ".key",
}
})
defer cleanup()
nonVerifyS, cleanup := TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: false,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: serverCert + ".pem",
KeyFile: serverCert + ".key",
}
})
defer cleanup()
// When VerifyServerHostname is enabled:
// Only all servers and local clients can make RPC requests
// Only local servers can connect to the Raft layer
cases := []struct {
name string
cn string
canRPC bool
canRaft bool
}{
{
name: "local server",
cn: "server.global.nomad",
canRPC: true,
canRaft: true,
},
{
name: "local client",
cn: "client.global.nomad",
canRPC: true,
canRaft: false,
},
{
name: "other region server",
cn: "server.other.nomad",
canRPC: true,
canRaft: false,
},
{
name: "other client server",
cn: "client.other.nomad",
canRPC: false,
canRaft: false,
},
{
name: "irrelevant cert",
cn: "nomad.example.com",
canRPC: false,
canRaft: false,
},
{
name: "globs",
cn: "*.global.nomad",
canRPC: false,
canRaft: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
certPath := newCert(t, tc.cn)
cfg := &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
}
t.Run("nomad RPC: verify_hostname=true", func(t *testing.T) {
err := nomadRPC(t, mtlsS, cfg)
if tc.canRPC {
require.NoError(t, err)
} else {
require.Error(t, err)
require.Contains(t, err.Error(), "bad certificate")
}
})
t.Run("nomad RPC: verify_hostname=false", func(t *testing.T) {
err := nomadRPC(t, nonVerifyS, cfg)
require.NoError(t, err)
})
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
err := raftRPC(t, mtlsS, cfg)
// the expected error depends on location of failure.
// We expect "bad certificate" if connection fails during handshake,
// or EOF when connection is closed after RaftRPC byte.
if tc.canRaft {
require.NoError(t, err)
} else if !tc.canRPC {
require.Error(t, err)
require.Contains(t, err.Error(), "bad certificate")
} else {
require.Error(t, err)
require.Contains(t, err.Error(), "EOF")
}
})
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
err := raftRPC(t, nonVerifyS, cfg)
require.NoError(t, err)
})
})
}
}
func doRaftRPC(conn net.Conn, leader string) (*raft.AppendEntriesResponse, error) {

View File

@ -265,9 +265,6 @@ type endpoints struct {
Status *Status
Node *Node
Job *Job
Eval *Eval
Plan *Plan
Alloc *Alloc
CSIVolume *CSIVolume
CSIPlugin *CSIPlugin
Deployment *Deployment
@ -1151,18 +1148,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
if s.staticEndpoints.Status == nil {
// Initialize the list just once
s.staticEndpoints.ACL = &ACL{srv: s, logger: s.logger.Named("acl")}
s.staticEndpoints.Alloc = &Alloc{srv: s, logger: s.logger.Named("alloc")}
s.staticEndpoints.Eval = &Eval{srv: s, logger: s.logger.Named("eval")}
s.staticEndpoints.Job = NewJobEndpoints(s)
s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")} // Add but don't register
s.staticEndpoints.CSIVolume = &CSIVolume{srv: s, logger: s.logger.Named("csi_volume")}
s.staticEndpoints.CSIPlugin = &CSIPlugin{srv: s, logger: s.logger.Named("csi_plugin")}
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
s.staticEndpoints.Operator = &Operator{srv: s, logger: s.logger.Named("operator")}
s.staticEndpoints.Operator.register()
s.staticEndpoints.Periodic = &Periodic{srv: s, logger: s.logger.Named("periodic")}
s.staticEndpoints.Plan = &Plan{srv: s, logger: s.logger.Named("plan")}
s.staticEndpoints.Region = &Region{srv: s, logger: s.logger.Named("region")}
s.staticEndpoints.Scaling = &Scaling{srv: s, logger: s.logger.Named("scaling")}
s.staticEndpoints.Status = &Status{srv: s, logger: s.logger.Named("status")}
@ -1171,6 +1163,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
s.staticEndpoints.Namespace = &Namespace{srv: s}
s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s)
// These endpoints are dynamic because they need access to the
// RPCContext, but they also need to be called directly in some cases,
// so store them into staticEndpoints for later access, but don't
// register them as static.
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")}
// Client endpoints
s.staticEndpoints.ClientStats = &ClientStats{srv: s, logger: s.logger.Named("client_stats")}
s.staticEndpoints.ClientAllocations = &ClientAllocations{srv: s, logger: s.logger.Named("client_allocs")}
@ -1191,15 +1190,11 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
// Register the static handlers
server.Register(s.staticEndpoints.ACL)
server.Register(s.staticEndpoints.Alloc)
server.Register(s.staticEndpoints.Eval)
server.Register(s.staticEndpoints.Job)
server.Register(s.staticEndpoints.CSIVolume)
server.Register(s.staticEndpoints.CSIPlugin)
server.Register(s.staticEndpoints.Deployment)
server.Register(s.staticEndpoints.Operator)
server.Register(s.staticEndpoints.Periodic)
server.Register(s.staticEndpoints.Plan)
server.Register(s.staticEndpoints.Region)
server.Register(s.staticEndpoints.Scaling)
server.Register(s.staticEndpoints.Status)
@ -1214,10 +1209,18 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
server.Register(s.staticEndpoints.Namespace)
// Create new dynamic endpoints and add them to the RPC server.
alloc := &Alloc{srv: s, ctx: ctx, logger: s.logger.Named("alloc")}
deployment := &Deployment{srv: s, ctx: ctx, logger: s.logger.Named("deployment")}
eval := &Eval{srv: s, ctx: ctx, logger: s.logger.Named("eval")}
node := &Node{srv: s, ctx: ctx, logger: s.logger.Named("client")}
plan := &Plan{srv: s, ctx: ctx, logger: s.logger.Named("plan")}
// Register the dynamic endpoints
server.Register(alloc)
server.Register(deployment)
server.Register(eval)
server.Register(node)
server.Register(plan)
}
// setupRaft is used to setup and initialize Raft

View File

@ -301,3 +301,27 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) {
return alloc, nil
}
// validateLocalClientTLSCertificate verifies that the certificate of the RPC
// connection is valid for "client.<region>.nomad", where <region> is the
// region of srv. That is the name expected from a client in the same region
// as the target server. The check itself is delegated to
// validateTLSCertificate.
func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error {
	return validateTLSCertificate(srv, ctx, fmt.Sprintf("client.%s.nomad", srv.Region()))
}
// validateLocalServerTLSCertificate verifies that the certificate of the RPC
// connection is valid for "server.<region>.nomad", where <region> is the
// region of srv. That is the name expected from a server in the same region
// as the target server. The check itself is delegated to
// validateTLSCertificate.
func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error {
	return validateTLSCertificate(srv, ctx, fmt.Sprintf("server.%s.nomad", srv.Region()))
}
// validateTLSCertificate checks that the certificate of the RPC connection is
// valid for the given name. The check only runs when the server has a TLS
// configuration with VerifyServerHostname enabled; otherwise it returns nil
// without inspecting the connection.
func validateTLSCertificate(srv *Server, ctx *RPCContext, name string) error {
	tlsCfg := srv.config.TLSConfig
	if tlsCfg != nil && tlsCfg.VerifyServerHostname {
		return ctx.ValidateCertificateForName(name)
	}
	return nil
}