From c4cff5359fb35a597fbbca15ad793b26463ff059 Mon Sep 17 00:00:00 2001 From: Luiz Aoqui Date: Wed, 2 Feb 2022 15:03:18 -0500 Subject: [PATCH] Verify TLS certificate on endpoints that are used between agents only (#11956) --- .changelog/11956.txt | 3 + .semgrep/rpc_endpoint.yml | 80 ++++++ nomad/alloc_endpoint.go | 8 + nomad/deployment_endpoint.go | 8 + nomad/eval_endpoint.go | 38 +++ nomad/job_endpoint.go | 9 +- nomad/node_endpoint.go | 10 + nomad/plan_endpoint.go | 8 + nomad/rpc.go | 62 +++-- nomad/rpc_test.go | 476 +++++++++++++++++++++++++---------- nomad/server.go | 27 +- nomad/util.go | 24 ++ 12 files changed, 576 insertions(+), 177 deletions(-) create mode 100644 .changelog/11956.txt create mode 100644 .semgrep/rpc_endpoint.yml diff --git a/.changelog/11956.txt b/.changelog/11956.txt new file mode 100644 index 000000000..6174cb824 --- /dev/null +++ b/.changelog/11956.txt @@ -0,0 +1,3 @@ +```release-note:security +server: validate mTLS certificate names on agent to agent endpoints +``` diff --git a/.semgrep/rpc_endpoint.yml b/.semgrep/rpc_endpoint.yml new file mode 100644 index 000000000..2277a6b19 --- /dev/null +++ b/.semgrep/rpc_endpoint.yml @@ -0,0 +1,80 @@ +rules: + # Check potentially unauthenticated RPC endpoints + - id: "rpc-potentially-unauthenticated" + patterns: + - pattern: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + - pattern-not-inside: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + ... + ... := $X.$Y.ResolveToken(...) + ... + - pattern-not-inside: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + ... + ... := $U.requestACLToken(...) + ... + - pattern-not-inside: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + ... + ... := $T.NamespaceValidator(...) + ... + # Pattern used by endpoints called exclusively between agents + # (server -> server or client -> server) + - pattern-not-inside: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + ... + ... := validateLocalClientTLSCertificate(...) + ... + - pattern-not-inside: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + ... + ... := validateLocalServerTLSCertificate(...) + ... + - pattern-not-inside: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + ... + ... := validateTLSCertificate(...) + ... + # Pattern used by some Node endpoints. + - pattern-not-inside: | + if done, err := $A.$B.forward($METHOD, ...); done { + return err + } + ... + return $A.deregister(...) + ... + - metavariable-pattern: + metavariable: $METHOD + patterns: + # Endpoints that are expected not to have authentication. + - pattern-not: '"ACL.Bootstrap"' + - pattern-not: '"ACL.ResolveToken"' + - pattern-not: '"ACL.UpsertOneTimeToken"' + - pattern-not: '"ACL.ExchangeOneTimeToken"' + - pattern-not: '"CSIPlugin.Get"' + - pattern-not: '"CSIPlugin.List"' + - pattern-not: '"Status.Leader"' + - pattern-not: '"Status.Peers"' + - pattern-not: '"Status.Version"' + message: "RPC method $METHOD appears to be unauthenticated" + languages: + - "go" + severity: "WARNING" + paths: + include: + - "*_endpoint.go" diff --git a/nomad/alloc_endpoint.go b/nomad/alloc_endpoint.go index 0b44175ad..6c8231c6b 100644 --- a/nomad/alloc_endpoint.go +++ b/nomad/alloc_endpoint.go @@ -20,6 +20,9 @@ import ( type Alloc struct { srv *Server logger log.Logger + + // ctx provides context regarding the underlying connection + ctx *RPCContext } // List is used to list the allocations in the system @@ -224,6 +227,11 @@ func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest, } defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now()) + // Ensure the connection was initiated by a client if TLS is used. + if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil { + return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err) + } + allocs := make([]*structs.Allocation, len(args.AllocIDs)) // Setup the blocking query. We wait for at least one of the requested diff --git a/nomad/deployment_endpoint.go b/nomad/deployment_endpoint.go index d6e06cc93..0bc073768 100644 --- a/nomad/deployment_endpoint.go +++ b/nomad/deployment_endpoint.go @@ -17,6 +17,9 @@ import ( type Deployment struct { srv *Server logger log.Logger + + // ctx provides context regarding the underlying connection + ctx *RPCContext } // GetDeployment is used to request information about a specific deployment @@ -506,6 +509,11 @@ func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest, } defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err) + } + // Update via Raft _, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args) if err != nil { diff --git a/nomad/eval_endpoint.go b/nomad/eval_endpoint.go index 5fe4d3658..18b83c45d 100644 --- a/nomad/eval_endpoint.go +++ b/nomad/eval_endpoint.go @@ -24,6 +24,9 @@ const ( type Eval struct { srv *Server logger log.Logger + + // ctx provides context regarding the underlying connection + ctx *RPCContext } // GetEval is used to request information about a specific evaluation @@ -87,6 +90,11 @@ func (e *Eval) Dequeue(args *structs.EvalDequeueRequest, } defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + } + // Ensure there is at least one scheduler if len(args.Schedulers) == 0 { return fmt.Errorf("dequeue requires at least one scheduler type") @@ -172,6 +180,11 @@ func (e *Eval) Ack(args *structs.EvalAckRequest, } defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + } + // Ack the EvalID if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil { return err @@ -187,6 +200,11 @@ func (e *Eval) Nack(args *structs.EvalAckRequest, } defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + } + // Nack the EvalID if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil { return err @@ -202,6 +220,11 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest, } defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + } + // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be updated") @@ -232,6 +255,11 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest, } defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + } + // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be created") @@ -277,6 +305,11 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe } defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + } + // Ensure there is only a single update with token if len(args.Evals) != 1 { return fmt.Errorf("only a single eval can be reblocked") @@ -319,6 +352,11 @@ func (e *Eval) Reap(args *structs.EvalDeleteRequest, } defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err) + } + // Update via Raft _, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args) if err != nil { diff --git a/nomad/job_endpoint.go b/nomad/job_endpoint.go index 097c48309..564fe4bf3 100644 --- a/nomad/job_endpoint.go +++ b/nomad/job_endpoint.go @@ -114,8 +114,8 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis reply.Warnings = structs.MergeMultierrorWarnings(warnings...) // Check job submission permissions - var aclObj *acl.ACL - if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil { + aclObj, err := j.srv.ResolveToken(args.AuthToken) + if err != nil { return err } else if aclObj != nil { if !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilitySubmitJob) { @@ -1879,9 +1879,8 @@ func (j *Job) Dispatch(args *structs.JobDispatchRequest, reply *structs.JobDispa defer metrics.MeasureSince([]string{"nomad", "job", "dispatch"}, time.Now()) // Check for submit-job permissions - var aclObj *acl.ACL - var err error - if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil { + aclObj, err := j.srv.ResolveToken(args.AuthToken) + if err != nil { return err } else if aclObj != nil && !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilityDispatchJob) { return structs.ErrPermissionDenied diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 75e236947..8ed43b2dc 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1103,6 +1103,11 @@ func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.Gene } defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now()) + // Ensure the connection was initiated by a client if TLS is used. + if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { + return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) + } + // Ensure at least a single alloc if len(args.Alloc) == 0 { return fmt.Errorf("must update at least one allocation") @@ -1920,6 +1925,11 @@ func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.Em } defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now()) + // Ensure the connection was initiated by a client if TLS is used. + if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil { + return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err) + } + if len(args.NodeEvents) == 0 { return fmt.Errorf("no node events given") } diff --git a/nomad/plan_endpoint.go b/nomad/plan_endpoint.go index 9d2ea30ed..a6cd8dbef 100644 --- a/nomad/plan_endpoint.go +++ b/nomad/plan_endpoint.go @@ -14,6 +14,9 @@ import ( type Plan struct { srv *Server logger log.Logger + + // ctx provides context regarding the underlying connection + ctx *RPCContext } // Submit is used to submit a plan to the leader @@ -23,6 +26,11 @@ func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) er } defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now()) + // Ensure the connection was initiated by another server if TLS is used. + if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil { + return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err) + } + if args.Plan == nil { return fmt.Errorf("cannot submit nil plan") } diff --git a/nomad/rpc.go b/nomad/rpc.go index 869c5bacb..37446d53e 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -107,6 +107,38 @@ type RPCContext struct { NodeID string } +// Certificate returns the first certificate available in the chain. +func (ctx *RPCContext) Certificate() *x509.Certificate { + if ctx == nil || len(ctx.VerifiedChains) == 0 || len(ctx.VerifiedChains[0]) == 0 { + return nil + } + + return ctx.VerifiedChains[0][0] +} + +// ValidateCertificateForName returns true if the RPC context certificate is valid +// for the given domain name. +func (ctx *RPCContext) ValidateCertificateForName(name string) error { + if ctx == nil || !ctx.TLS { + return nil + } + + cert := ctx.Certificate() + if cert == nil { + return errors.New("missing certificate information") + } + for _, dnsName := range cert.DNSNames { + if dnsName == name { + return nil + } + } + if cert.Subject.CommonName == name { + return nil + } + + return fmt.Errorf("certificate not valid for %q", name) +} + // listen is used to listen for incoming RPC connections func (r *rpcHandler) listen(ctx context.Context) { defer close(r.listenerCh) @@ -838,30 +870,18 @@ func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error { return nil } - // defensive conditions: these should have already been enforced by handleConn - if rpcCtx == nil || !rpcCtx.TLS { - return errors.New("non-TLS connection attempted") - } - if len(rpcCtx.VerifiedChains) == 0 || len(rpcCtx.VerifiedChains[0]) == 0 { - // this should never happen, as rpcNameAndRegionValidate should have enforced it - return errors.New("missing cert info") - } - // check that `server..nomad` is present in cert expected := "server." + r.Region() + ".nomad" - - cert := rpcCtx.VerifiedChains[0][0] - for _, dnsName := range cert.DNSNames { - if dnsName == expected { - // Certificate is valid for the expected name - return nil + err := rpcCtx.ValidateCertificateForName(expected) + if err != nil { + cert := rpcCtx.Certificate() + if cert != nil { + err = fmt.Errorf("request certificate is only valid for %s: %v", cert.DNSNames, err) } - } - if cert.Subject.CommonName == expected { - // Certificate is valid for the expected name - return nil + + return fmt.Errorf("unauthorized raft connection from %s: %v", rpcCtx.Conn.RemoteAddr(), err) } - r.logger.Warn("unauthorized raft connection", "remote_addr", rpcCtx.Conn.RemoteAddr(), "required_hostname", expected, "found", cert.DNSNames) - return fmt.Errorf("certificate is invalid for expected role or region: %q", expected) + // Certificate is valid for the expected name + return nil } diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 8b3ea0f6a..07f2d9492 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -1018,7 +1018,7 @@ func TestRPC_Limits_Streaming(t *testing.T) { }) } -func TestRPC_TLS_Enforcement(t *testing.T) { +func TestRPC_TLS_Enforcement_Raft(t *testing.T) { t.Parallel() defer func() { @@ -1026,211 +1026,409 @@ func TestRPC_TLS_Enforcement(t *testing.T) { time.Sleep(1 * time.Second) }() - dir := tmpDir(t) - defer os.RemoveAll(dir) - - caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"}) - require.NoError(t, err) - - err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600) - require.NoError(t, err) - - nodeID := 1 - newCert := func(t *testing.T, name string) string { - t.Helper() - - node := fmt.Sprintf("node%d", nodeID) - nodeID++ - signer, err := tlsutil.ParseSigner(pk) - require.NoError(t, err) - - pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{ - Signer: signer, - CA: caPEM, - Name: name, - Days: 5, - DNSNames: []string{node + "." + name, name, "localhost"}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - }) - require.NoError(t, err) - - err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600) - require.NoError(t, err) - err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600) - require.NoError(t, err) - - return filepath.Join(dir, node+"-"+name) - } - - connect := func(t *testing.T, s *Server, c *config.TLSConfig) net.Conn { - conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second) - require.NoError(t, err) - - // configure TLS - _, err = conn.Write([]byte{byte(pool.RpcTLS)}) - require.NoError(t, err) - - // Client TLS verification isn't necessary for - // our assertions - tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true) - require.NoError(t, err) - outTLSConf, err := tlsConf.OutgoingTLSConfig() - require.NoError(t, err) - outTLSConf.InsecureSkipVerify = true - - tlsConn := tls.Client(conn, outTLSConf) - require.NoError(t, tlsConn.Handshake()) - - return tlsConn - } - - nomadRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error { - conn := connect(t, s, c) - defer conn.Close() - _, err := conn.Write([]byte{byte(pool.RpcNomad)}) - require.NoError(t, err) - - codec := pool.NewClientCodec(conn) - - arg := struct{}{} - var out struct{} - return msgpackrpc.CallWithCodec(codec, "Status.Ping", arg, &out) - } - - raftRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error { - conn := connect(t, s, c) - defer conn.Close() - - _, err := conn.Write([]byte{byte(pool.RpcRaft)}) - require.NoError(t, err) - - _, err = doRaftRPC(conn, s.config.NodeName) - return err - } - - // generate server cert - serverCert := newCert(t, "server.global.nomad") - - mtlsS, cleanup := TestServer(t, func(c *Config) { - c.TLSConfig = &config.TLSConfig{ - EnableRPC: true, - VerifyServerHostname: true, - CAFile: filepath.Join(dir, "ca.pem"), - CertFile: serverCert + ".pem", - KeyFile: serverCert + ".key", - } - }) - defer cleanup() - - nonVerifyS, cleanup := TestServer(t, func(c *Config) { - c.TLSConfig = &config.TLSConfig{ - EnableRPC: true, - VerifyServerHostname: false, - CAFile: filepath.Join(dir, "ca.pem"), - CertFile: serverCert + ".pem", - KeyFile: serverCert + ".key", - } - }) - defer cleanup() + tlsHelper := newTLSTestHelper(t) + defer tlsHelper.cleanup() // When VerifyServerHostname is enabled: - // Only all servers and local clients can make RPC requests // Only local servers can connect to the Raft layer cases := []struct { name string cn string - canRPC bool canRaft bool }{ { name: "local server", cn: "server.global.nomad", - canRPC: true, canRaft: true, }, { name: "local client", cn: "client.global.nomad", - canRPC: true, canRaft: false, }, { name: "other region server", cn: "server.other.nomad", - canRPC: true, canRaft: false, }, { - name: "other client server", + name: "other region client", cn: "client.other.nomad", - canRPC: false, canRaft: false, }, { name: "irrelevant cert", cn: "nomad.example.com", - canRPC: false, canRaft: false, }, { name: "globs", cn: "*.global.nomad", - canRPC: false, canRaft: false, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - certPath := newCert(t, tc.cn) + certPath := tlsHelper.newCert(t, tc.cn) cfg := &config.TLSConfig{ EnableRPC: true, VerifyServerHostname: true, - CAFile: filepath.Join(dir, "ca.pem"), + CAFile: filepath.Join(tlsHelper.dir, "ca.pem"), CertFile: certPath + ".pem", KeyFile: certPath + ".key", } - t.Run("nomad RPC: verify_hostname=true", func(t *testing.T) { - err := nomadRPC(t, mtlsS, cfg) - - if tc.canRPC { - require.NoError(t, err) - } else { - require.Error(t, err) - require.Contains(t, err.Error(), "bad certificate") - } - }) - t.Run("nomad RPC: verify_hostname=false", func(t *testing.T) { - err := nomadRPC(t, nonVerifyS, cfg) - require.NoError(t, err) - }) - t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) { - err := raftRPC(t, mtlsS, cfg) + err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg) // the expected error depends on location of failure. // We expect "bad certificate" if connection fails during handshake, // or EOF when connection is closed after RaftRPC byte. if tc.canRaft { require.NoError(t, err) - } else if !tc.canRPC { - require.Error(t, err) - require.Contains(t, err.Error(), "bad certificate") } else { require.Error(t, err) - require.Contains(t, err.Error(), "EOF") + require.Regexp(t, "(bad certificate|EOF)", err.Error()) } }) t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) { - err := raftRPC(t, nonVerifyS, cfg) + err := tlsHelper.raftRPC(t, tlsHelper.nonVerifyServer, cfg) require.NoError(t, err) }) }) } } +func TestRPC_TLS_Enforcement_RPC(t *testing.T) { + t.Parallel() + + defer func() { + //TODO Avoid panics from logging during shutdown + time.Sleep(1 * time.Second) + }() + + tlsHelper := newTLSTestHelper(t) + defer tlsHelper.cleanup() + + standardRPCs := map[string]interface{}{ + "Status.Ping": struct{}{}, + } + + localServersOnlyRPCs := map[string]interface{}{ + "Eval.Update": &structs.EvalUpdateRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Eval.Ack": &structs.EvalAckRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Eval.Nack": &structs.EvalAckRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Eval.Dequeue": &structs.EvalDequeueRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Eval.Create": &structs.EvalUpdateRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Eval.Reblock": &structs.EvalUpdateRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Eval.Reap": &structs.EvalDeleteRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Plan.Submit": &structs.PlanRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Deployment.Reap": &structs.DeploymentDeleteRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + } + + localClientsOnlyRPCs := map[string]interface{}{ + "Alloc.GetAllocs": &structs.AllocsGetRequest{ + QueryOptions: structs.QueryOptions{Region: "global"}, + }, + "Node.EmitEvents": &structs.EmitNodeEventsRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + "Node.UpdateAlloc": &structs.AllocUpdateRequest{ + WriteRequest: structs.WriteRequest{Region: "global"}, + }, + } + + // When VerifyServerHostname is enabled: + // All servers can make RPC requests + // Only local clients can make RPC requests + // Some endpoints can only be called server -> server + // Some endpoints can only be called client -> server + cases := []struct { + name string + cn string + rpcs map[string]interface{} + canRPC bool + }{ + // Local server. + { + name: "local server/standard rpc", + cn: "server.global.nomad", + rpcs: standardRPCs, + canRPC: true, + }, + { + name: "local server/servers only rpc", + cn: "server.global.nomad", + rpcs: localServersOnlyRPCs, + canRPC: true, + }, + { + name: "local server/clients only rpc", + cn: "server.global.nomad", + rpcs: localClientsOnlyRPCs, + canRPC: false, + }, + // Local client. + { + name: "local client/standard rpc", + cn: "client.global.nomad", + rpcs: standardRPCs, + canRPC: true, + }, + { + name: "local client/servers only rpc", + cn: "client.global.nomad", + rpcs: localServersOnlyRPCs, + canRPC: false, + }, + { + name: "local client/clients only rpc", + cn: "client.global.nomad", + rpcs: localClientsOnlyRPCs, + canRPC: true, + }, + // Other region server. + { + name: "other region server/standard rpc", + cn: "server.other.nomad", + rpcs: standardRPCs, + canRPC: true, + }, + { + name: "other region server/servers only rpc", + cn: "server.other.nomad", + rpcs: localServersOnlyRPCs, + canRPC: false, + }, + { + name: "other region server/clients only rpc", + cn: "server.other.nomad", + rpcs: localClientsOnlyRPCs, + canRPC: false, + }, + // Other region client. + { + name: "other region client/standard rpc", + cn: "client.other.nomad", + rpcs: standardRPCs, + canRPC: false, + }, + { + name: "other region client/servers only rpc", + cn: "client.other.nomad", + rpcs: localServersOnlyRPCs, + canRPC: false, + }, + { + name: "other region client/clients only rpc", + cn: "client.other.nomad", + rpcs: localClientsOnlyRPCs, + canRPC: false, + }, + // Wrong certs. + { + name: "irrelevant cert", + cn: "nomad.example.com", + rpcs: standardRPCs, + canRPC: false, + }, + { + name: "globs", + cn: "*.global.nomad", + rpcs: standardRPCs, + canRPC: false, + }, + {}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + certPath := tlsHelper.newCert(t, tc.cn) + + cfg := &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: true, + CAFile: filepath.Join(tlsHelper.dir, "ca.pem"), + CertFile: certPath + ".pem", + KeyFile: certPath + ".key", + } + + for method, arg := range tc.rpcs { + t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) { + err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg) + + if tc.canRPC { + if err != nil { + require.NotContains(t, err, "certificate") + } + } else { + require.Error(t, err) + require.Contains(t, err.Error(), "certificate") + } + }) + t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) { + err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg) + if err != nil { + require.NotContains(t, err, "certificate") + } + }) + } + }) + } +} + +type tlsTestHelper struct { + dir string + nodeID int + + mtlsServer *Server + mtlsServerCleanup func() + nonVerifyServer *Server + nonVerifyServerCleanup func() + + caPEM string + pk string + serverCert string +} + +func newTLSTestHelper(t *testing.T) tlsTestHelper { + var err error + + h := tlsTestHelper{ + dir: tmpDir(t), + nodeID: 1, + } + + // Generate CA certificate and write it to disk. + h.caPEM, h.pk, err = tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"}) + require.NoError(t, err) + + err = ioutil.WriteFile(filepath.Join(h.dir, "ca.pem"), []byte(h.caPEM), 0600) + require.NoError(t, err) + + // Generate servers and their certificate. + h.serverCert = h.newCert(t, "server.global.nomad") + + h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) { + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: true, + CAFile: filepath.Join(h.dir, "ca.pem"), + CertFile: h.serverCert + ".pem", + KeyFile: h.serverCert + ".key", + } + }) + + h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) { + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: false, + CAFile: filepath.Join(h.dir, "ca.pem"), + CertFile: h.serverCert + ".pem", + KeyFile: h.serverCert + ".key", + } + }) + + return h +} + +func (h tlsTestHelper) cleanup() { + h.mtlsServerCleanup() + h.nonVerifyServerCleanup() + os.RemoveAll(h.dir) +} + +func (h tlsTestHelper) newCert(t *testing.T, name string) string { + t.Helper() + + node := fmt.Sprintf("node%d", h.nodeID) + h.nodeID++ + signer, err := tlsutil.ParseSigner(h.pk) + require.NoError(t, err) + + pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{ + Signer: signer, + CA: h.caPEM, + Name: name, + Days: 5, + DNSNames: []string{node + "." + name, name, "localhost"}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + }) + require.NoError(t, err) + + err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".pem"), []byte(pem), 0600) + require.NoError(t, err) + err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".key"), []byte(key), 0600) + require.NoError(t, err) + + return filepath.Join(h.dir, node+"-"+name) +} + +func (h tlsTestHelper) connect(t *testing.T, s *Server, c *config.TLSConfig) net.Conn { + conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second) + require.NoError(t, err) + + // configure TLS + _, err = conn.Write([]byte{byte(pool.RpcTLS)}) + require.NoError(t, err) + + // Client TLS verification isn't necessary for + // our assertions + tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true) + require.NoError(t, err) + outTLSConf, err := tlsConf.OutgoingTLSConfig() + require.NoError(t, err) + outTLSConf.InsecureSkipVerify = true + + tlsConn := tls.Client(conn, outTLSConf) + require.NoError(t, tlsConn.Handshake()) + + return tlsConn +} + +func (h tlsTestHelper) nomadRPC(t *testing.T, s *Server, c *config.TLSConfig, method string, arg interface{}) error { + conn := h.connect(t, s, c) + defer conn.Close() + _, err := conn.Write([]byte{byte(pool.RpcNomad)}) + require.NoError(t, err) + + codec := pool.NewClientCodec(conn) + + var out struct{} + return msgpackrpc.CallWithCodec(codec, method, arg, &out) +} + +func (h tlsTestHelper) raftRPC(t *testing.T, s *Server, c *config.TLSConfig) error { + conn := h.connect(t, s, c) + defer conn.Close() + + _, err := conn.Write([]byte{byte(pool.RpcRaft)}) + require.NoError(t, err) + + _, err = doRaftRPC(conn, s.config.NodeName) + return err +} + func doRaftRPC(conn net.Conn, leader string) (*raft.AppendEntriesResponse, error) { req := raft.AppendEntriesRequest{ RPCHeader: raft.RPCHeader{ProtocolVersion: 3}, diff --git a/nomad/server.go b/nomad/server.go index 1ae341af8..1dfbd6ef9 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -265,9 +265,6 @@ type endpoints struct { Status *Status Node *Node Job *Job - Eval *Eval - Plan *Plan - Alloc *Alloc CSIVolume *CSIVolume CSIPlugin *CSIPlugin Deployment *Deployment @@ -1151,18 +1148,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { if s.staticEndpoints.Status == nil { // Initialize the list just once s.staticEndpoints.ACL = &ACL{srv: s, logger: s.logger.Named("acl")} - s.staticEndpoints.Alloc = &Alloc{srv: s, logger: s.logger.Named("alloc")} - s.staticEndpoints.Eval = &Eval{srv: s, logger: s.logger.Named("eval")} s.staticEndpoints.Job = NewJobEndpoints(s) - s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")} // Add but don't register s.staticEndpoints.CSIVolume = &CSIVolume{srv: s, logger: s.logger.Named("csi_volume")} s.staticEndpoints.CSIPlugin = &CSIPlugin{srv: s, logger: s.logger.Named("csi_plugin")} - s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")} s.staticEndpoints.Operator = &Operator{srv: s, logger: s.logger.Named("operator")} s.staticEndpoints.Operator.register() s.staticEndpoints.Periodic = &Periodic{srv: s, logger: s.logger.Named("periodic")} - s.staticEndpoints.Plan = &Plan{srv: s, logger: s.logger.Named("plan")} s.staticEndpoints.Region = &Region{srv: s, logger: s.logger.Named("region")} s.staticEndpoints.Scaling = &Scaling{srv: s, logger: s.logger.Named("scaling")} s.staticEndpoints.Status = &Status{srv: s, logger: s.logger.Named("status")} @@ -1171,6 +1163,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { s.staticEndpoints.Namespace = &Namespace{srv: s} s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s) + // These endpoints are dynamic because they need access to the + // RPCContext, but they also need to be called directly in some cases, + // so store them into staticEndpoints for later access, but don't + // register them as static. + s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")} + s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")} + // Client endpoints s.staticEndpoints.ClientStats = &ClientStats{srv: s, logger: s.logger.Named("client_stats")} s.staticEndpoints.ClientAllocations = &ClientAllocations{srv: s, logger: s.logger.Named("client_allocs")} @@ -1191,15 +1190,11 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { // Register the static handlers server.Register(s.staticEndpoints.ACL) - server.Register(s.staticEndpoints.Alloc) - server.Register(s.staticEndpoints.Eval) server.Register(s.staticEndpoints.Job) server.Register(s.staticEndpoints.CSIVolume) server.Register(s.staticEndpoints.CSIPlugin) - server.Register(s.staticEndpoints.Deployment) server.Register(s.staticEndpoints.Operator) server.Register(s.staticEndpoints.Periodic) - server.Register(s.staticEndpoints.Plan) server.Register(s.staticEndpoints.Region) server.Register(s.staticEndpoints.Scaling) server.Register(s.staticEndpoints.Status) @@ -1214,10 +1209,18 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { server.Register(s.staticEndpoints.Namespace) // Create new dynamic endpoints and add them to the RPC server. + alloc := &Alloc{srv: s, ctx: ctx, logger: s.logger.Named("alloc")} + deployment := &Deployment{srv: s, ctx: ctx, logger: s.logger.Named("deployment")} + eval := &Eval{srv: s, ctx: ctx, logger: s.logger.Named("eval")} node := &Node{srv: s, ctx: ctx, logger: s.logger.Named("client")} + plan := &Plan{srv: s, ctx: ctx, logger: s.logger.Named("plan")} // Register the dynamic endpoints + server.Register(alloc) + server.Register(deployment) + server.Register(eval) server.Register(node) + server.Register(plan) } // setupRaft is used to setup and initialize Raft diff --git a/nomad/util.go b/nomad/util.go index 2a5f5b0da..daa6999f8 100644 --- a/nomad/util.go +++ b/nomad/util.go @@ -301,3 +301,27 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) { return alloc, nil } + +// validateLocalClientTLSCertificate checks if the provided RPC connection was +// initiated by a client in the same region as the target server. +func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error { + expected := fmt.Sprintf("client.%s.nomad", srv.Region()) + return validateTLSCertificate(srv, ctx, expected) +} + +// validateLocalServerTLSCertificate checks if the provided RPC connection was +// initiated by a server in the same region as the target server. +func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error { + expected := fmt.Sprintf("server.%s.nomad", srv.Region()) + return validateTLSCertificate(srv, ctx, expected) +} + +// validateTLSCertificate checks if the RPC connection mTLS certificates are +// valid for the given name. +func validateTLSCertificate(srv *Server, ctx *RPCContext, name string) error { + if srv.config.TLSConfig == nil || !srv.config.TLSConfig.VerifyServerHostname { + return nil + } + + return ctx.ValidateCertificateForName(name) +}