From 0e09b120e49e3c5914ce8496c28013e8db018270 Mon Sep 17 00:00:00 2001 From: Luiz Aoqui Date: Fri, 4 Feb 2022 20:35:20 -0500 Subject: [PATCH] fix mTLS certificate check on agent to agent RPCs (#11998) PR #11956 implemented a new mTLS RPC check to validate the role of the certificate used in the request, but further testing revealed two flaws: 1. client-only endpoints did not accept server certificates so the request would fail when forwarded from one server to another. 2. the certificate was being checked after the request was forwarded, so the check would happen over the server certificate, not the actual source. This commit checks for the desired mTLS level, where the client level accepts either a server or a client certificate. It also validates the certificate before the request is forwarded. --- .semgrep/rpc_endpoint.yml | 19 +-------- nomad/alloc_endpoint.go | 12 +++--- nomad/deployment_endpoint.go | 12 +++--- nomad/eval_endpoint.go | 83 +++++++++++++++++++++--------------- nomad/node_endpoint.go | 22 +++++----- nomad/plan_endpoint.go | 11 ++--- nomad/rpc.go | 12 +++--- nomad/rpc_test.go | 53 ++++++++++++++++------- nomad/util.go | 42 +++++++++++++++++- 9 files changed, 165 insertions(+), 101 deletions(-) diff --git a/.semgrep/rpc_endpoint.yml b/.semgrep/rpc_endpoint.yml index 2277a6b19..9f22f67a2 100644 --- a/.semgrep/rpc_endpoint.yml +++ b/.semgrep/rpc_endpoint.yml @@ -30,26 +30,11 @@ rules: # Pattern used by endpoints called exclusively between agents # (server -> server or client -> server) - pattern-not-inside: | + ... := validateTLSCertificateLevel(...) + ... if done, err := $A.$B.forward($METHOD, ...); done { return err } - ... - ... := validateLocalClientTLSCertificate(...) - ... - - pattern-not-inside: | - if done, err := $A.$B.forward($METHOD, ...); done { - return err - } - ... - ... := validateLocalServerTLSCertificate(...) - ... - - pattern-not-inside: | - if done, err := $A.$B.forward($METHOD, ...); done { - return err - } - ... - ... 
:= validateTLSCertificate(...) - ... # Pattern used by some Node endpoints. - pattern-not-inside: | if done, err := $A.$B.forward($METHOD, ...); done { diff --git a/nomad/alloc_endpoint.go b/nomad/alloc_endpoint.go index 6c8231c6b..92abee62f 100644 --- a/nomad/alloc_endpoint.go +++ b/nomad/alloc_endpoint.go @@ -222,16 +222,18 @@ func (a *Alloc) GetAlloc(args *structs.AllocSpecificRequest, // GetAllocs is used to lookup a set of allocations func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest, reply *structs.AllocsGetResponse) error { + + // Ensure the connection was initiated by a client if TLS is used. + err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelClient) + if err != nil { + return err + } + if done, err := a.srv.forward("Alloc.GetAllocs", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err) - } - allocs := make([]*structs.Allocation, len(args.AllocIDs)) // Setup the blocking query. We wait for at least one of the requested diff --git a/nomad/deployment_endpoint.go b/nomad/deployment_endpoint.go index 0bc073768..2c18de98d 100644 --- a/nomad/deployment_endpoint.go +++ b/nomad/deployment_endpoint.go @@ -504,16 +504,18 @@ func (d *Deployment) Allocations(args *structs.DeploymentSpecificRequest, reply // Reap is used to cleanup terminal deployments func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. 
+ err := validateTLSCertificateLevel(d.srv, d.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := d.srv.forward("Deployment.Reap", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err) - } - // Update via Raft _, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args) if err != nil { diff --git a/nomad/eval_endpoint.go b/nomad/eval_endpoint.go index 18b83c45d..8a48e27c1 100644 --- a/nomad/eval_endpoint.go +++ b/nomad/eval_endpoint.go @@ -85,16 +85,18 @@ func (e *Eval) GetEval(args *structs.EvalSpecificRequest, // Dequeue is used to dequeue a pending evaluation func (e *Eval) Dequeue(args *structs.EvalDequeueRequest, reply *structs.EvalDequeueResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Dequeue", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. 
- if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is at least one scheduler if len(args.Schedulers) == 0 { return fmt.Errorf("dequeue requires at least one scheduler type") @@ -175,16 +177,18 @@ func (e *Eval) getWaitIndex(namespace, job string, evalModifyIndex uint64) (uint // Ack is used to acknowledge completion of a dequeued evaluation func (e *Eval) Ack(args *structs.EvalAckRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Ack", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ack the EvalID if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil { return err @@ -195,16 +199,18 @@ func (e *Eval) Ack(args *structs.EvalAckRequest, // Nack is used to negative acknowledge completion of a dequeued evaluation. func (e *Eval) Nack(args *structs.EvalAckRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Nack", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. 
- if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Nack the EvalID if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil { return err @@ -215,16 +221,18 @@ func (e *Eval) Nack(args *structs.EvalAckRequest, // Update is used to perform an update of an Eval if it is outstanding. func (e *Eval) Update(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Update", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be updated") @@ -250,16 +258,18 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest, // Create is used to make a new evaluation func (e *Eval) Create(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Create", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. 
- if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be created") @@ -300,16 +310,17 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest, // Reblock is used to reinsert an existing blocked evaluation into the blocked // evaluation tracker. func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error { + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Reblock", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. - if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be reblocked") @@ -347,16 +358,18 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe // Reap is used to cleanup dead evaluations and allocations func (e *Eval) Reap(args *structs.EvalDeleteRequest, reply *structs.GenericResponse) error { + + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := e.srv.forward("Eval.Reap", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. 
- if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) - } - // Update via Raft _, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args) if err != nil { diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 8ed43b2dc..d5a1725b4 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1098,16 +1098,17 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, // UpdateAlloc is used to update the client status of an allocation func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.GenericResponse) error { + // Ensure the connection was initiated by another client if TLS is used. + err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient) + if err != nil { + return err + } + if done, err := n.srv.forward("Node.UpdateAlloc", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) - } - // Ensure at least a single alloc if len(args.Alloc) == 0 { return fmt.Errorf("must update at least one allocation") @@ -1920,16 +1921,17 @@ func taskUsesConnect(task *structs.Task) bool { } func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.EmitNodeEventsResponse) error { + // Ensure the connection was initiated by another client if TLS is used. 
+ err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient) + if err != nil { + return err + } + if done, err := n.srv.forward("Node.EmitEvents", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now()) - // Ensure the connection was initiated by a client if TLS is used. - if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { - return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) - } - if len(args.NodeEvents) == 0 { return fmt.Errorf("no node events given") } diff --git a/nomad/plan_endpoint.go b/nomad/plan_endpoint.go index a6cd8dbef..4979270e4 100644 --- a/nomad/plan_endpoint.go +++ b/nomad/plan_endpoint.go @@ -21,16 +21,17 @@ type Plan struct { // Submit is used to submit a plan to the leader func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) error { + // Ensure the connection was initiated by another server if TLS is used. + err := validateTLSCertificateLevel(p.srv, p.ctx, tlsCertificateLevelServer) + if err != nil { + return err + } + if done, err := p.srv.forward("Plan.Submit", args, args, reply); done { return err } defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now()) - // Ensure the connection was initiated by another server if TLS is used. 
- if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil { - return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err) - } - if args.Plan == nil { return fmt.Errorf("cannot submit nil plan") } diff --git a/nomad/rpc.go b/nomad/rpc.go index 37446d53e..96db49b64 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -127,16 +127,16 @@ func (ctx *RPCContext) ValidateCertificateForName(name string) error { if cert == nil { return errors.New("missing certificate information") } - for _, dnsName := range cert.DNSNames { - if dnsName == name { + + validNames := []string{cert.Subject.CommonName} + validNames = append(validNames, cert.DNSNames...) + for _, valid := range validNames { + if name == valid { return nil } } - if cert.Subject.CommonName == name { - return nil - } - return fmt.Errorf("certificate not valid for %q", name) + return fmt.Errorf("invalid certificate, %s not in %s", name, strings.Join(validNames, ",")) } // listen is used to listen for incoming RPC connections diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 07f2d9492..bd738f279 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -1081,7 +1081,7 @@ func TestRPC_TLS_Enforcement_Raft(t *testing.T) { } t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) { - err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg) + err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer1, cfg) // the expected error depends on location of failure. // We expect "bad certificate" if connection fails during handshake, @@ -1186,7 +1186,7 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { name: "local server/clients only rpc", cn: "server.global.nomad", rpcs: localClientsOnlyRPCs, - canRPC: false, + canRPC: true, }, // Local client. 
{ @@ -1274,18 +1274,22 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { } for method, arg := range tc.rpcs { - t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) { - err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg) + for _, srv := range []*Server{tlsHelper.mtlsServer1, tlsHelper.mtlsServer2} { + name := fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true leader=%v", method, srv.IsLeader()) + t.Run(name, func(t *testing.T) { + err := tlsHelper.nomadRPC(t, srv, cfg, method, arg) - if tc.canRPC { - if err != nil { - require.NotContains(t, err, "certificate") + if tc.canRPC { + if err != nil { + require.NotContains(t, err, "certificate") + } + } else { + require.Error(t, err) + require.Contains(t, err.Error(), "certificate") } - } else { - require.Error(t, err) - require.Contains(t, err.Error(), "certificate") - } - }) + }) + } + t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) { err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg) if err != nil { @@ -1301,8 +1305,10 @@ type tlsTestHelper struct { dir string nodeID int - mtlsServer *Server - mtlsServerCleanup func() + mtlsServer1 *Server + mtlsServer1Cleanup func() + mtlsServer2 *Server + mtlsServer2Cleanup func() nonVerifyServer *Server nonVerifyServerCleanup func() @@ -1329,7 +1335,8 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { // Generate servers and their certificate. 
h.serverCert = h.newCert(t, "server.global.nomad") - h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) { + h.mtlsServer1, h.mtlsServer1Cleanup = TestServer(t, func(c *Config) { + c.BootstrapExpect = 2 c.TLSConfig = &config.TLSConfig{ EnableRPC: true, VerifyServerHostname: true, @@ -1338,6 +1345,19 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { KeyFile: h.serverCert + ".key", } }) + h.mtlsServer2, h.mtlsServer2Cleanup = TestServer(t, func(c *Config) { + c.BootstrapExpect = 2 + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: true, + CAFile: filepath.Join(h.dir, "ca.pem"), + CertFile: h.serverCert + ".pem", + KeyFile: h.serverCert + ".key", + } + }) + TestJoin(t, h.mtlsServer1, h.mtlsServer2) + testutil.WaitForLeader(t, h.mtlsServer1.RPC) + testutil.WaitForLeader(t, h.mtlsServer2.RPC) h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) { c.TLSConfig = &config.TLSConfig{ @@ -1353,7 +1373,8 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { } func (h tlsTestHelper) cleanup() { - h.mtlsServerCleanup() + h.mtlsServer1Cleanup() + h.mtlsServer2Cleanup() h.nonVerifyServerCleanup() os.RemoveAll(h.dir) } diff --git a/nomad/util.go b/nomad/util.go index daa6999f8..210a202d9 100644 --- a/nomad/util.go +++ b/nomad/util.go @@ -302,18 +302,56 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) { return alloc, nil } +// tlsCertificateLevel represents a role level for mTLS certificates. +type tlsCertificateLevel int8 + +const ( + tlsCertificateLevelServer tlsCertificateLevel = iota + tlsCertificateLevelClient +) + +// validateTLSCertificateLevel checks if the provided RPC connection was +// initiated with a certificate that matches the given TLS role level. +// +// - tlsCertificateLevelServer requires a server certificate. +// - tlsCertificateLevelClient requires a client or server certificate. 
+func validateTLSCertificateLevel(srv *Server, ctx *RPCContext, lvl tlsCertificateLevel) error { + switch lvl { + case tlsCertificateLevelClient: + err := validateLocalClientTLSCertificate(srv, ctx) + if err != nil { + return validateLocalServerTLSCertificate(srv, ctx) + } + return nil + case tlsCertificateLevelServer: + return validateLocalServerTLSCertificate(srv, ctx) + } + + return fmt.Errorf("invalid TLS certificate level %v", lvl) +} + // validateLocalClientTLSCertificate checks if the provided RPC connection was // initiated by a client in the same region as the target server. func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error { expected := fmt.Sprintf("client.%s.nomad", srv.Region()) - return validateTLSCertificate(srv, ctx, expected) + + err := validateTLSCertificate(srv, ctx, expected) + if err != nil { + return fmt.Errorf("invalid client connection in region %s: %v", srv.Region(), err) + } + return nil } // validateLocalServerTLSCertificate checks if the provided RPC connection was // initiated by a server in the same region as the target server. func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error { expected := fmt.Sprintf("server.%s.nomad", srv.Region()) - return validateTLSCertificate(srv, ctx, expected) + + err := validateTLSCertificate(srv, ctx, expected) + if err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", srv.Region(), err) + } + return nil } // validateTLSCertificate checks if the RPC connection mTLS certificates are