fix mTLS certificate check on agent to agent RPCs (#11998)

PR #11956 implemented a new mTLS RPC check to validate the role of the
certificate used in the request, but further testing revealed two flaws:

  1. client-only endpoints did not accept server certificates so the
     request would fail when forwarded from one server to another.
  2. the certificate was being checked after the request was forwarded,
     so the check would happen over the server certificate, not the
     actual source.

This commit checks for the desired mTLS level, where the client level
accepts both, a server or a client certificate. It also validates the
cercertificate before the request is forwarded.
This commit is contained in:
Luiz Aoqui 2022-02-04 20:35:20 -05:00 committed by GitHub
parent 5faf344152
commit 0e09b120e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 165 additions and 101 deletions

View File

@ -30,26 +30,11 @@ rules:
# Pattern used by endpoints called exclusively between agents
# (server -> server or client -> server)
- pattern-not-inside: |
... := validateTLSCertificateLevel(...)
...
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := validateLocalClientTLSCertificate(...)
...
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := validateLocalServerTLSCertificate(...)
...
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {
return err
}
...
... := validateTLSCertificate(...)
...
# Pattern used by some Node endpoints.
- pattern-not-inside: |
if done, err := $A.$B.forward($METHOD, ...); done {

View File

@ -222,16 +222,18 @@ func (a *Alloc) GetAlloc(args *structs.AllocSpecificRequest,
// GetAllocs is used to lookup a set of allocations
func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest,
reply *structs.AllocsGetResponse) error {
// Ensure the connection was initiated by a client if TLS is used.
err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelClient)
if err != nil {
return err
}
if done, err := a.srv.forward("Alloc.GetAllocs", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now())
// Ensure the connection was initiated by a client if TLS is used.
if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err)
}
allocs := make([]*structs.Allocation, len(args.AllocIDs))
// Setup the blocking query. We wait for at least one of the requested

View File

@ -504,16 +504,18 @@ func (d *Deployment) Allocations(args *structs.DeploymentSpecificRequest, reply
// Reap is used to cleanup terminal deployments
func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest,
reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(d.srv, d.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := d.srv.forward("Deployment.Reap", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err)
}
// Update via Raft
_, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args)
if err != nil {

View File

@ -85,16 +85,18 @@ func (e *Eval) GetEval(args *structs.EvalSpecificRequest,
// Dequeue is used to dequeue a pending evaluation
func (e *Eval) Dequeue(args *structs.EvalDequeueRequest,
reply *structs.EvalDequeueResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := e.srv.forward("Eval.Dequeue", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is at least one scheduler
if len(args.Schedulers) == 0 {
return fmt.Errorf("dequeue requires at least one scheduler type")
@ -175,16 +177,18 @@ func (e *Eval) getWaitIndex(namespace, job string, evalModifyIndex uint64) (uint
// Ack is used to acknowledge completion of a dequeued evaluation
func (e *Eval) Ack(args *structs.EvalAckRequest,
reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := e.srv.forward("Eval.Ack", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ack the EvalID
if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil {
return err
@ -195,16 +199,18 @@ func (e *Eval) Ack(args *structs.EvalAckRequest,
// Nack is used to negative acknowledge completion of a dequeued evaluation.
func (e *Eval) Nack(args *structs.EvalAckRequest,
reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := e.srv.forward("Eval.Nack", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Nack the EvalID
if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil {
return err
@ -215,16 +221,18 @@ func (e *Eval) Nack(args *structs.EvalAckRequest,
// Update is used to perform an update of an Eval if it is outstanding.
func (e *Eval) Update(args *structs.EvalUpdateRequest,
reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := e.srv.forward("Eval.Update", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is only a single update with token
if len(args.Evals) != 1 {
return fmt.Errorf("only a single eval can be updated")
@ -250,16 +258,18 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest,
// Create is used to make a new evaluation
func (e *Eval) Create(args *structs.EvalUpdateRequest,
reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := e.srv.forward("Eval.Create", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is only a single update with token
if len(args.Evals) != 1 {
return fmt.Errorf("only a single eval can be created")
@ -300,16 +310,17 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest,
// Reblock is used to reinsert an existing blocked evaluation into the blocked
// evaluation tracker.
func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := e.srv.forward("Eval.Reblock", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Ensure there is only a single update with token
if len(args.Evals) != 1 {
return fmt.Errorf("only a single eval can be reblocked")
@ -347,16 +358,18 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe
// Reap is used to cleanup dead evaluations and allocations
func (e *Eval) Reap(args *structs.EvalDeleteRequest,
reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(e.srv, e.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := e.srv.forward("Eval.Reap", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
}
// Update via Raft
_, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args)
if err != nil {

View File

@ -1098,16 +1098,17 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest,
// UpdateAlloc is used to update the client status of an allocation
func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.GenericResponse) error {
// Ensure the connection was initiated by another client if TLS is used.
err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient)
if err != nil {
return err
}
if done, err := n.srv.forward("Node.UpdateAlloc", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now())
// Ensure the connection was initiated by a client if TLS is used.
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
}
// Ensure at least a single alloc
if len(args.Alloc) == 0 {
return fmt.Errorf("must update at least one allocation")
@ -1920,16 +1921,17 @@ func taskUsesConnect(task *structs.Task) bool {
}
func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.EmitNodeEventsResponse) error {
// Ensure the connection was initiated by another client if TLS is used.
err := validateTLSCertificateLevel(n.srv, n.ctx, tlsCertificateLevelClient)
if err != nil {
return err
}
if done, err := n.srv.forward("Node.EmitEvents", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now())
// Ensure the connection was initiated by a client if TLS is used.
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
}
if len(args.NodeEvents) == 0 {
return fmt.Errorf("no node events given")
}

View File

@ -21,16 +21,17 @@ type Plan struct {
// Submit is used to submit a plan to the leader
func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) error {
// Ensure the connection was initiated by another server if TLS is used.
err := validateTLSCertificateLevel(p.srv, p.ctx, tlsCertificateLevelServer)
if err != nil {
return err
}
if done, err := p.srv.forward("Plan.Submit", args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now())
// Ensure the connection was initiated by another server if TLS is used.
if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err)
}
if args.Plan == nil {
return fmt.Errorf("cannot submit nil plan")
}

View File

@ -127,16 +127,16 @@ func (ctx *RPCContext) ValidateCertificateForName(name string) error {
if cert == nil {
return errors.New("missing certificate information")
}
for _, dnsName := range cert.DNSNames {
if dnsName == name {
validNames := []string{cert.Subject.CommonName}
validNames = append(validNames, cert.DNSNames...)
for _, valid := range validNames {
if name == valid {
return nil
}
}
if cert.Subject.CommonName == name {
return nil
}
return fmt.Errorf("certificate not valid for %q", name)
return fmt.Errorf("invalid certificate, %s not in %s", name, strings.Join(validNames, ","))
}
// listen is used to listen for incoming RPC connections

View File

@ -1081,7 +1081,7 @@ func TestRPC_TLS_Enforcement_Raft(t *testing.T) {
}
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg)
err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer1, cfg)
// the expected error depends on location of failure.
// We expect "bad certificate" if connection fails during handshake,
@ -1186,7 +1186,7 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) {
name: "local server/clients only rpc",
cn: "server.global.nomad",
rpcs: localClientsOnlyRPCs,
canRPC: false,
canRPC: true,
},
// Local client.
{
@ -1274,18 +1274,22 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) {
}
for method, arg := range tc.rpcs {
t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) {
err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg)
for _, srv := range []*Server{tlsHelper.mtlsServer1, tlsHelper.mtlsServer2} {
name := fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true leader=%v", method, srv.IsLeader())
t.Run(name, func(t *testing.T) {
err := tlsHelper.nomadRPC(t, srv, cfg, method, arg)
if tc.canRPC {
if err != nil {
require.NotContains(t, err, "certificate")
if tc.canRPC {
if err != nil {
require.NotContains(t, err, "certificate")
}
} else {
require.Error(t, err)
require.Contains(t, err.Error(), "certificate")
}
} else {
require.Error(t, err)
require.Contains(t, err.Error(), "certificate")
}
})
})
}
t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) {
err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg)
if err != nil {
@ -1301,8 +1305,10 @@ type tlsTestHelper struct {
dir string
nodeID int
mtlsServer *Server
mtlsServerCleanup func()
mtlsServer1 *Server
mtlsServer1Cleanup func()
mtlsServer2 *Server
mtlsServer2Cleanup func()
nonVerifyServer *Server
nonVerifyServerCleanup func()
@ -1329,7 +1335,8 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper {
// Generate servers and their certificate.
h.serverCert = h.newCert(t, "server.global.nomad")
h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) {
h.mtlsServer1, h.mtlsServer1Cleanup = TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
@ -1338,6 +1345,19 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper {
KeyFile: h.serverCert + ".key",
}
})
h.mtlsServer2, h.mtlsServer2Cleanup = TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(h.dir, "ca.pem"),
CertFile: h.serverCert + ".pem",
KeyFile: h.serverCert + ".key",
}
})
TestJoin(t, h.mtlsServer1, h.mtlsServer2)
testutil.WaitForLeader(t, h.mtlsServer1.RPC)
testutil.WaitForLeader(t, h.mtlsServer2.RPC)
h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
@ -1353,7 +1373,8 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper {
}
func (h tlsTestHelper) cleanup() {
h.mtlsServerCleanup()
h.mtlsServer1Cleanup()
h.mtlsServer2Cleanup()
h.nonVerifyServerCleanup()
os.RemoveAll(h.dir)
}

View File

@ -302,18 +302,56 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) {
return alloc, nil
}
// tlsCertificateLevel represents a role level for mTLS certificates.
type tlsCertificateLevel int8
const (
tlsCertificateLevelServer tlsCertificateLevel = iota
tlsCertificateLevelClient
)
// validateTLSCertificateLevel checks if the provided RPC connection was
// initiated with a certificate that matches the given TLS role level.
//
// - tlsCertificateLevelServer requires a server certificate.
// - tlsCertificateLevelServer requires a client or server certificate.
func validateTLSCertificateLevel(srv *Server, ctx *RPCContext, lvl tlsCertificateLevel) error {
switch lvl {
case tlsCertificateLevelClient:
err := validateLocalClientTLSCertificate(srv, ctx)
if err != nil {
return validateLocalServerTLSCertificate(srv, ctx)
}
return nil
case tlsCertificateLevelServer:
return validateLocalServerTLSCertificate(srv, ctx)
}
return fmt.Errorf("invalid TLS certificate level %v", lvl)
}
// validateLocalClientTLSCertificate checks if the provided RPC connection was
// initiated by a client in the same region as the target server.
func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error {
expected := fmt.Sprintf("client.%s.nomad", srv.Region())
return validateTLSCertificate(srv, ctx, expected)
err := validateTLSCertificate(srv, ctx, expected)
if err != nil {
return fmt.Errorf("invalid client connection in region %s: %v", srv.Region(), err)
}
return nil
}
// validateLocalServerTLSCertificate checks if the provided RPC connection was
// initiated by a server in the same region as the target server.
func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error {
expected := fmt.Sprintf("server.%s.nomad", srv.Region())
return validateTLSCertificate(srv, ctx, expected)
err := validateTLSCertificate(srv, ctx, expected)
if err != nil {
return fmt.Errorf("invalid server connection in region %s: %v", srv.Region(), err)
}
return nil
}
// validateTLSCertificate checks if the RPC connection mTLS certificates are