Verify TLS certificate on endpoints that are used between agents only (#11956)
This commit is contained in:
parent
f6217fe424
commit
c4cff5359f
|
@ -0,0 +1,3 @@
|
|||
```release-note:security
|
||||
server: validate mTLS certificate names on agent to agent endpoints
|
||||
```
|
|
@ -0,0 +1,80 @@
|
|||
rules:
|
||||
# Check potentially unauthenticated RPC endpoints
|
||||
- id: "rpc-potentially-unauthenticated"
|
||||
patterns:
|
||||
- pattern: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
- pattern-not-inside: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
...
|
||||
... := $X.$Y.ResolveToken(...)
|
||||
...
|
||||
- pattern-not-inside: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
...
|
||||
... := $U.requestACLToken(...)
|
||||
...
|
||||
- pattern-not-inside: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
...
|
||||
... := $T.NamespaceValidator(...)
|
||||
...
|
||||
# Pattern used by endpoints called exclusively between agents
|
||||
# (server -> server or client -> server)
|
||||
- pattern-not-inside: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
...
|
||||
... := validateLocalClientTLSCertificate(...)
|
||||
...
|
||||
- pattern-not-inside: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
...
|
||||
... := validateLocalServerTLSCertificate(...)
|
||||
...
|
||||
- pattern-not-inside: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
...
|
||||
... := validateTLSCertificate(...)
|
||||
...
|
||||
# Pattern used by some Node endpoints.
|
||||
- pattern-not-inside: |
|
||||
if done, err := $A.$B.forward($METHOD, ...); done {
|
||||
return err
|
||||
}
|
||||
...
|
||||
return $A.deregister(...)
|
||||
...
|
||||
- metavariable-pattern:
|
||||
metavariable: $METHOD
|
||||
patterns:
|
||||
# Endpoints that are expected not to have authentication.
|
||||
- pattern-not: '"ACL.Bootstrap"'
|
||||
- pattern-not: '"ACL.ResolveToken"'
|
||||
- pattern-not: '"ACL.UpsertOneTimeToken"'
|
||||
- pattern-not: '"ACL.ExchangeOneTimeToken"'
|
||||
- pattern-not: '"CSIPlugin.Get"'
|
||||
- pattern-not: '"CSIPlugin.List"'
|
||||
- pattern-not: '"Status.Leader"'
|
||||
- pattern-not: '"Status.Peers"'
|
||||
- pattern-not: '"Status.Version"'
|
||||
message: "RPC method $METHOD appears to be unauthenticated"
|
||||
languages:
|
||||
- "go"
|
||||
severity: "WARNING"
|
||||
paths:
|
||||
include:
|
||||
- "*_endpoint.go"
|
|
@ -20,6 +20,9 @@ import (
|
|||
type Alloc struct {
|
||||
srv *Server
|
||||
logger log.Logger
|
||||
|
||||
// ctx provides context regarding the underlying connection
|
||||
ctx *RPCContext
|
||||
}
|
||||
|
||||
// List is used to list the allocations in the system
|
||||
|
@ -224,6 +227,11 @@ func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by a client if TLS is used.
|
||||
if err := validateLocalClientTLSCertificate(a.srv, a.ctx); err != nil {
|
||||
return fmt.Errorf("invalid client connection in region %s: %v", a.srv.Region(), err)
|
||||
}
|
||||
|
||||
allocs := make([]*structs.Allocation, len(args.AllocIDs))
|
||||
|
||||
// Setup the blocking query. We wait for at least one of the requested
|
||||
|
|
|
@ -17,6 +17,9 @@ import (
|
|||
type Deployment struct {
|
||||
srv *Server
|
||||
logger log.Logger
|
||||
|
||||
// ctx provides context regarding the underlying connection
|
||||
ctx *RPCContext
|
||||
}
|
||||
|
||||
// GetDeployment is used to request information about a specific deployment
|
||||
|
@ -506,6 +509,11 @@ func (d *Deployment) Reap(args *structs.DeploymentDeleteRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "deployment", "reap"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(d.srv, d.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", d.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Update via Raft
|
||||
_, index, err := d.srv.raftApply(structs.DeploymentDeleteRequestType, args)
|
||||
if err != nil {
|
||||
|
|
|
@ -24,6 +24,9 @@ const (
|
|||
type Eval struct {
|
||||
srv *Server
|
||||
logger log.Logger
|
||||
|
||||
// ctx provides context regarding the underlying connection
|
||||
ctx *RPCContext
|
||||
}
|
||||
|
||||
// GetEval is used to request information about a specific evaluation
|
||||
|
@ -87,6 +90,11 @@ func (e *Eval) Dequeue(args *structs.EvalDequeueRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "eval", "dequeue"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Ensure there is at least one scheduler
|
||||
if len(args.Schedulers) == 0 {
|
||||
return fmt.Errorf("dequeue requires at least one scheduler type")
|
||||
|
@ -172,6 +180,11 @@ func (e *Eval) Ack(args *structs.EvalAckRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "eval", "ack"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Ack the EvalID
|
||||
if err := e.srv.evalBroker.Ack(args.EvalID, args.Token); err != nil {
|
||||
return err
|
||||
|
@ -187,6 +200,11 @@ func (e *Eval) Nack(args *structs.EvalAckRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "eval", "nack"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Nack the EvalID
|
||||
if err := e.srv.evalBroker.Nack(args.EvalID, args.Token); err != nil {
|
||||
return err
|
||||
|
@ -202,6 +220,11 @@ func (e *Eval) Update(args *structs.EvalUpdateRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "eval", "update"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Ensure there is only a single update with token
|
||||
if len(args.Evals) != 1 {
|
||||
return fmt.Errorf("only a single eval can be updated")
|
||||
|
@ -232,6 +255,11 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "eval", "create"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Ensure there is only a single update with token
|
||||
if len(args.Evals) != 1 {
|
||||
return fmt.Errorf("only a single eval can be created")
|
||||
|
@ -277,6 +305,11 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "eval", "reblock"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Ensure there is only a single update with token
|
||||
if len(args.Evals) != 1 {
|
||||
return fmt.Errorf("only a single eval can be reblocked")
|
||||
|
@ -319,6 +352,11 @@ func (e *Eval) Reap(args *structs.EvalDeleteRequest,
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "eval", "reap"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(e.srv, e.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", e.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Update via Raft
|
||||
_, index, err := e.srv.raftApply(structs.EvalDeleteRequestType, args)
|
||||
if err != nil {
|
||||
|
|
|
@ -114,8 +114,8 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis
|
|||
reply.Warnings = structs.MergeMultierrorWarnings(warnings...)
|
||||
|
||||
// Check job submission permissions
|
||||
var aclObj *acl.ACL
|
||||
if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil {
|
||||
aclObj, err := j.srv.ResolveToken(args.AuthToken)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if aclObj != nil {
|
||||
if !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilitySubmitJob) {
|
||||
|
@ -1879,9 +1879,8 @@ func (j *Job) Dispatch(args *structs.JobDispatchRequest, reply *structs.JobDispa
|
|||
defer metrics.MeasureSince([]string{"nomad", "job", "dispatch"}, time.Now())
|
||||
|
||||
// Check for submit-job permissions
|
||||
var aclObj *acl.ACL
|
||||
var err error
|
||||
if aclObj, err = j.srv.ResolveToken(args.AuthToken); err != nil {
|
||||
aclObj, err := j.srv.ResolveToken(args.AuthToken)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if aclObj != nil && !aclObj.AllowNsOp(args.RequestNamespace(), acl.NamespaceCapabilityDispatchJob) {
|
||||
return structs.ErrPermissionDenied
|
||||
|
|
|
@ -1103,6 +1103,11 @@ func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.Gene
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "client", "update_alloc"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by a client if TLS is used.
|
||||
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
|
||||
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
|
||||
}
|
||||
|
||||
// Ensure at least a single alloc
|
||||
if len(args.Alloc) == 0 {
|
||||
return fmt.Errorf("must update at least one allocation")
|
||||
|
@ -1920,6 +1925,11 @@ func (n *Node) EmitEvents(args *structs.EmitNodeEventsRequest, reply *structs.Em
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "client", "emit_events"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by a client if TLS is used.
|
||||
if err := validateLocalClientTLSCertificate(n.srv, n.ctx); err != nil {
|
||||
return fmt.Errorf("invalid client connection in region %s: %v", n.srv.Region(), err)
|
||||
}
|
||||
|
||||
if len(args.NodeEvents) == 0 {
|
||||
return fmt.Errorf("no node events given")
|
||||
}
|
||||
|
|
|
@ -14,6 +14,9 @@ import (
|
|||
type Plan struct {
|
||||
srv *Server
|
||||
logger log.Logger
|
||||
|
||||
// ctx provides context regarding the underlying connection
|
||||
ctx *RPCContext
|
||||
}
|
||||
|
||||
// Submit is used to submit a plan to the leader
|
||||
|
@ -23,6 +26,11 @@ func (p *Plan) Submit(args *structs.PlanRequest, reply *structs.PlanResponse) er
|
|||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "plan", "submit"}, time.Now())
|
||||
|
||||
// Ensure the connection was initiated by another server if TLS is used.
|
||||
if err := validateLocalServerTLSCertificate(p.srv, p.ctx); err != nil {
|
||||
return fmt.Errorf("invalid server connection in region %s: %v", p.srv.Region(), err)
|
||||
}
|
||||
|
||||
if args.Plan == nil {
|
||||
return fmt.Errorf("cannot submit nil plan")
|
||||
}
|
||||
|
|
64
nomad/rpc.go
64
nomad/rpc.go
|
@ -107,6 +107,38 @@ type RPCContext struct {
|
|||
NodeID string
|
||||
}
|
||||
|
||||
// Certificate returns the first certificate available in the chain.
|
||||
func (ctx *RPCContext) Certificate() *x509.Certificate {
|
||||
if ctx == nil || len(ctx.VerifiedChains) == 0 || len(ctx.VerifiedChains[0]) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ctx.VerifiedChains[0][0]
|
||||
}
|
||||
|
||||
// ValidateCertificateForName returns true if the RPC context certificate is valid
|
||||
// for the given domain name.
|
||||
func (ctx *RPCContext) ValidateCertificateForName(name string) error {
|
||||
if ctx == nil || !ctx.TLS {
|
||||
return nil
|
||||
}
|
||||
|
||||
cert := ctx.Certificate()
|
||||
if cert == nil {
|
||||
return errors.New("missing certificate information")
|
||||
}
|
||||
for _, dnsName := range cert.DNSNames {
|
||||
if dnsName == name {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if cert.Subject.CommonName == name {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("certificate not valid for %q", name)
|
||||
}
|
||||
|
||||
// listen is used to listen for incoming RPC connections
|
||||
func (r *rpcHandler) listen(ctx context.Context) {
|
||||
defer close(r.listenerCh)
|
||||
|
@ -838,30 +870,18 @@ func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// defensive conditions: these should have already been enforced by handleConn
|
||||
if rpcCtx == nil || !rpcCtx.TLS {
|
||||
return errors.New("non-TLS connection attempted")
|
||||
}
|
||||
if len(rpcCtx.VerifiedChains) == 0 || len(rpcCtx.VerifiedChains[0]) == 0 {
|
||||
// this should never happen, as rpcNameAndRegionValidate should have enforced it
|
||||
return errors.New("missing cert info")
|
||||
}
|
||||
|
||||
// check that `server.<region>.nomad` is present in cert
|
||||
expected := "server." + r.Region() + ".nomad"
|
||||
|
||||
cert := rpcCtx.VerifiedChains[0][0]
|
||||
for _, dnsName := range cert.DNSNames {
|
||||
if dnsName == expected {
|
||||
// Certificate is valid for the expected name
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if cert.Subject.CommonName == expected {
|
||||
// Certificate is valid for the expected name
|
||||
return nil
|
||||
err := rpcCtx.ValidateCertificateForName(expected)
|
||||
if err != nil {
|
||||
cert := rpcCtx.Certificate()
|
||||
if cert != nil {
|
||||
err = fmt.Errorf("request certificate is only valid for %s: %v", cert.DNSNames, err)
|
||||
}
|
||||
|
||||
r.logger.Warn("unauthorized raft connection", "remote_addr", rpcCtx.Conn.RemoteAddr(), "required_hostname", expected, "found", cert.DNSNames)
|
||||
return fmt.Errorf("certificate is invalid for expected role or region: %q", expected)
|
||||
return fmt.Errorf("unauthorized raft connection from %s: %v", rpcCtx.Conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
// Certificate is valid for the expected name
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1018,7 +1018,7 @@ func TestRPC_Limits_Streaming(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestRPC_TLS_Enforcement(t *testing.T) {
|
||||
func TestRPC_TLS_Enforcement_Raft(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defer func() {
|
||||
|
@ -1026,27 +1026,349 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
|
|||
time.Sleep(1 * time.Second)
|
||||
}()
|
||||
|
||||
dir := tmpDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
tlsHelper := newTLSTestHelper(t)
|
||||
defer tlsHelper.cleanup()
|
||||
|
||||
caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
|
||||
// When VerifyServerHostname is enabled:
|
||||
// Only local servers can connect to the Raft layer
|
||||
cases := []struct {
|
||||
name string
|
||||
cn string
|
||||
canRaft bool
|
||||
}{
|
||||
{
|
||||
name: "local server",
|
||||
cn: "server.global.nomad",
|
||||
canRaft: true,
|
||||
},
|
||||
{
|
||||
name: "local client",
|
||||
cn: "client.global.nomad",
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "other region server",
|
||||
cn: "server.other.nomad",
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "other region client",
|
||||
cn: "client.other.nomad",
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "irrelevant cert",
|
||||
cn: "nomad.example.com",
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "globs",
|
||||
cn: "*.global.nomad",
|
||||
canRaft: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
certPath := tlsHelper.newCert(t, tc.cn)
|
||||
|
||||
cfg := &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: filepath.Join(tlsHelper.dir, "ca.pem"),
|
||||
CertFile: certPath + ".pem",
|
||||
KeyFile: certPath + ".key",
|
||||
}
|
||||
|
||||
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
|
||||
err := tlsHelper.raftRPC(t, tlsHelper.mtlsServer, cfg)
|
||||
|
||||
// the expected error depends on location of failure.
|
||||
// We expect "bad certificate" if connection fails during handshake,
|
||||
// or EOF when connection is closed after RaftRPC byte.
|
||||
if tc.canRaft {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Regexp(t, "(bad certificate|EOF)", err.Error())
|
||||
}
|
||||
})
|
||||
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
|
||||
err := tlsHelper.raftRPC(t, tlsHelper.nonVerifyServer, cfg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPC_TLS_Enforcement_RPC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defer func() {
|
||||
//TODO Avoid panics from logging during shutdown
|
||||
time.Sleep(1 * time.Second)
|
||||
}()
|
||||
|
||||
tlsHelper := newTLSTestHelper(t)
|
||||
defer tlsHelper.cleanup()
|
||||
|
||||
standardRPCs := map[string]interface{}{
|
||||
"Status.Ping": struct{}{},
|
||||
}
|
||||
|
||||
localServersOnlyRPCs := map[string]interface{}{
|
||||
"Eval.Update": &structs.EvalUpdateRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Eval.Ack": &structs.EvalAckRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Eval.Nack": &structs.EvalAckRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Eval.Dequeue": &structs.EvalDequeueRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Eval.Create": &structs.EvalUpdateRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Eval.Reblock": &structs.EvalUpdateRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Eval.Reap": &structs.EvalDeleteRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Plan.Submit": &structs.PlanRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Deployment.Reap": &structs.DeploymentDeleteRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
}
|
||||
|
||||
localClientsOnlyRPCs := map[string]interface{}{
|
||||
"Alloc.GetAllocs": &structs.AllocsGetRequest{
|
||||
QueryOptions: structs.QueryOptions{Region: "global"},
|
||||
},
|
||||
"Node.EmitEvents": &structs.EmitNodeEventsRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
"Node.UpdateAlloc": &structs.AllocUpdateRequest{
|
||||
WriteRequest: structs.WriteRequest{Region: "global"},
|
||||
},
|
||||
}
|
||||
|
||||
// When VerifyServerHostname is enabled:
|
||||
// All servers can make RPC requests
|
||||
// Only local clients can make RPC requests
|
||||
// Some endpoints can only be called server -> server
|
||||
// Some endpoints can only be called client -> server
|
||||
cases := []struct {
|
||||
name string
|
||||
cn string
|
||||
rpcs map[string]interface{}
|
||||
canRPC bool
|
||||
}{
|
||||
// Local server.
|
||||
{
|
||||
name: "local server/standard rpc",
|
||||
cn: "server.global.nomad",
|
||||
rpcs: standardRPCs,
|
||||
canRPC: true,
|
||||
},
|
||||
{
|
||||
name: "local server/servers only rpc",
|
||||
cn: "server.global.nomad",
|
||||
rpcs: localServersOnlyRPCs,
|
||||
canRPC: true,
|
||||
},
|
||||
{
|
||||
name: "local server/clients only rpc",
|
||||
cn: "server.global.nomad",
|
||||
rpcs: localClientsOnlyRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
// Local client.
|
||||
{
|
||||
name: "local client/standard rpc",
|
||||
cn: "client.global.nomad",
|
||||
rpcs: standardRPCs,
|
||||
canRPC: true,
|
||||
},
|
||||
{
|
||||
name: "local client/servers only rpc",
|
||||
cn: "client.global.nomad",
|
||||
rpcs: localServersOnlyRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
{
|
||||
name: "local client/clients only rpc",
|
||||
cn: "client.global.nomad",
|
||||
rpcs: localClientsOnlyRPCs,
|
||||
canRPC: true,
|
||||
},
|
||||
// Other region server.
|
||||
{
|
||||
name: "other region server/standard rpc",
|
||||
cn: "server.other.nomad",
|
||||
rpcs: standardRPCs,
|
||||
canRPC: true,
|
||||
},
|
||||
{
|
||||
name: "other region server/servers only rpc",
|
||||
cn: "server.other.nomad",
|
||||
rpcs: localServersOnlyRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
{
|
||||
name: "other region server/clients only rpc",
|
||||
cn: "server.other.nomad",
|
||||
rpcs: localClientsOnlyRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
// Other region client.
|
||||
{
|
||||
name: "other region client/standard rpc",
|
||||
cn: "client.other.nomad",
|
||||
rpcs: standardRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
{
|
||||
name: "other region client/servers only rpc",
|
||||
cn: "client.other.nomad",
|
||||
rpcs: localServersOnlyRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
{
|
||||
name: "other region client/clients only rpc",
|
||||
cn: "client.other.nomad",
|
||||
rpcs: localClientsOnlyRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
// Wrong certs.
|
||||
{
|
||||
name: "irrelevant cert",
|
||||
cn: "nomad.example.com",
|
||||
rpcs: standardRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
{
|
||||
name: "globs",
|
||||
cn: "*.global.nomad",
|
||||
rpcs: standardRPCs,
|
||||
canRPC: false,
|
||||
},
|
||||
{},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
certPath := tlsHelper.newCert(t, tc.cn)
|
||||
|
||||
cfg := &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: filepath.Join(tlsHelper.dir, "ca.pem"),
|
||||
CertFile: certPath + ".pem",
|
||||
KeyFile: certPath + ".key",
|
||||
}
|
||||
|
||||
for method, arg := range tc.rpcs {
|
||||
t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=true", method), func(t *testing.T) {
|
||||
err := tlsHelper.nomadRPC(t, tlsHelper.mtlsServer, cfg, method, arg)
|
||||
|
||||
if tc.canRPC {
|
||||
if err != nil {
|
||||
require.NotContains(t, err, "certificate")
|
||||
}
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "certificate")
|
||||
}
|
||||
})
|
||||
t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) {
|
||||
err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg)
|
||||
if err != nil {
|
||||
require.NotContains(t, err, "certificate")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type tlsTestHelper struct {
|
||||
dir string
|
||||
nodeID int
|
||||
|
||||
mtlsServer *Server
|
||||
mtlsServerCleanup func()
|
||||
nonVerifyServer *Server
|
||||
nonVerifyServerCleanup func()
|
||||
|
||||
caPEM string
|
||||
pk string
|
||||
serverCert string
|
||||
}
|
||||
|
||||
func newTLSTestHelper(t *testing.T) tlsTestHelper {
|
||||
var err error
|
||||
|
||||
h := tlsTestHelper{
|
||||
dir: tmpDir(t),
|
||||
nodeID: 1,
|
||||
}
|
||||
|
||||
// Generate CA certificate and write it to disk.
|
||||
h.caPEM, h.pk, err = tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
|
||||
err = ioutil.WriteFile(filepath.Join(h.dir, "ca.pem"), []byte(h.caPEM), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodeID := 1
|
||||
newCert := func(t *testing.T, name string) string {
|
||||
// Generate servers and their certificate.
|
||||
h.serverCert = h.newCert(t, "server.global.nomad")
|
||||
|
||||
h.mtlsServer, h.mtlsServerCleanup = TestServer(t, func(c *Config) {
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: filepath.Join(h.dir, "ca.pem"),
|
||||
CertFile: h.serverCert + ".pem",
|
||||
KeyFile: h.serverCert + ".key",
|
||||
}
|
||||
})
|
||||
|
||||
h.nonVerifyServer, h.nonVerifyServerCleanup = TestServer(t, func(c *Config) {
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: false,
|
||||
CAFile: filepath.Join(h.dir, "ca.pem"),
|
||||
CertFile: h.serverCert + ".pem",
|
||||
KeyFile: h.serverCert + ".key",
|
||||
}
|
||||
})
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func (h tlsTestHelper) cleanup() {
|
||||
h.mtlsServerCleanup()
|
||||
h.nonVerifyServerCleanup()
|
||||
os.RemoveAll(h.dir)
|
||||
}
|
||||
|
||||
func (h tlsTestHelper) newCert(t *testing.T, name string) string {
|
||||
t.Helper()
|
||||
|
||||
node := fmt.Sprintf("node%d", nodeID)
|
||||
nodeID++
|
||||
signer, err := tlsutil.ParseSigner(pk)
|
||||
node := fmt.Sprintf("node%d", h.nodeID)
|
||||
h.nodeID++
|
||||
signer, err := tlsutil.ParseSigner(h.pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
|
||||
Signer: signer,
|
||||
CA: caPEM,
|
||||
CA: h.caPEM,
|
||||
Name: name,
|
||||
Days: 5,
|
||||
DNSNames: []string{node + "." + name, name, "localhost"},
|
||||
|
@ -1054,15 +1376,15 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
|
||||
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".pem"), []byte(pem), 0600)
|
||||
require.NoError(t, err)
|
||||
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
|
||||
err = ioutil.WriteFile(filepath.Join(h.dir, node+"-"+name+".key"), []byte(key), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
return filepath.Join(dir, node+"-"+name)
|
||||
}
|
||||
return filepath.Join(h.dir, node+"-"+name)
|
||||
}
|
||||
|
||||
connect := func(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
|
||||
func (h tlsTestHelper) connect(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
|
||||
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -1082,23 +1404,22 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
|
|||
require.NoError(t, tlsConn.Handshake())
|
||||
|
||||
return tlsConn
|
||||
}
|
||||
}
|
||||
|
||||
nomadRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
|
||||
conn := connect(t, s, c)
|
||||
func (h tlsTestHelper) nomadRPC(t *testing.T, s *Server, c *config.TLSConfig, method string, arg interface{}) error {
|
||||
conn := h.connect(t, s, c)
|
||||
defer conn.Close()
|
||||
_, err := conn.Write([]byte{byte(pool.RpcNomad)})
|
||||
require.NoError(t, err)
|
||||
|
||||
codec := pool.NewClientCodec(conn)
|
||||
|
||||
arg := struct{}{}
|
||||
var out struct{}
|
||||
return msgpackrpc.CallWithCodec(codec, "Status.Ping", arg, &out)
|
||||
}
|
||||
return msgpackrpc.CallWithCodec(codec, method, arg, &out)
|
||||
}
|
||||
|
||||
raftRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
|
||||
conn := connect(t, s, c)
|
||||
func (h tlsTestHelper) raftRPC(t *testing.T, s *Server, c *config.TLSConfig) error {
|
||||
conn := h.connect(t, s, c)
|
||||
defer conn.Close()
|
||||
|
||||
_, err := conn.Write([]byte{byte(pool.RpcRaft)})
|
||||
|
@ -1106,129 +1427,6 @@ func TestRPC_TLS_Enforcement(t *testing.T) {
|
|||
|
||||
_, err = doRaftRPC(conn, s.config.NodeName)
|
||||
return err
|
||||
}
|
||||
|
||||
// generate server cert
|
||||
serverCert := newCert(t, "server.global.nomad")
|
||||
|
||||
mtlsS, cleanup := TestServer(t, func(c *Config) {
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: filepath.Join(dir, "ca.pem"),
|
||||
CertFile: serverCert + ".pem",
|
||||
KeyFile: serverCert + ".key",
|
||||
}
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
nonVerifyS, cleanup := TestServer(t, func(c *Config) {
|
||||
c.TLSConfig = &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: false,
|
||||
CAFile: filepath.Join(dir, "ca.pem"),
|
||||
CertFile: serverCert + ".pem",
|
||||
KeyFile: serverCert + ".key",
|
||||
}
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
// When VerifyServerHostname is enabled:
|
||||
// Only all servers and local clients can make RPC requests
|
||||
// Only local servers can connect to the Raft layer
|
||||
cases := []struct {
|
||||
name string
|
||||
cn string
|
||||
canRPC bool
|
||||
canRaft bool
|
||||
}{
|
||||
{
|
||||
name: "local server",
|
||||
cn: "server.global.nomad",
|
||||
canRPC: true,
|
||||
canRaft: true,
|
||||
},
|
||||
{
|
||||
name: "local client",
|
||||
cn: "client.global.nomad",
|
||||
canRPC: true,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "other region server",
|
||||
cn: "server.other.nomad",
|
||||
canRPC: true,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "other client server",
|
||||
cn: "client.other.nomad",
|
||||
canRPC: false,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "irrelevant cert",
|
||||
cn: "nomad.example.com",
|
||||
canRPC: false,
|
||||
canRaft: false,
|
||||
},
|
||||
{
|
||||
name: "globs",
|
||||
cn: "*.global.nomad",
|
||||
canRPC: false,
|
||||
canRaft: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
certPath := newCert(t, tc.cn)
|
||||
|
||||
cfg := &config.TLSConfig{
|
||||
EnableRPC: true,
|
||||
VerifyServerHostname: true,
|
||||
CAFile: filepath.Join(dir, "ca.pem"),
|
||||
CertFile: certPath + ".pem",
|
||||
KeyFile: certPath + ".key",
|
||||
}
|
||||
|
||||
t.Run("nomad RPC: verify_hostname=true", func(t *testing.T) {
|
||||
err := nomadRPC(t, mtlsS, cfg)
|
||||
|
||||
if tc.canRPC {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "bad certificate")
|
||||
}
|
||||
})
|
||||
t.Run("nomad RPC: verify_hostname=false", func(t *testing.T) {
|
||||
err := nomadRPC(t, nonVerifyS, cfg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
|
||||
err := raftRPC(t, mtlsS, cfg)
|
||||
|
||||
// the expected error depends on location of failure.
|
||||
// We expect "bad certificate" if connection fails during handshake,
|
||||
// or EOF when connection is closed after RaftRPC byte.
|
||||
if tc.canRaft {
|
||||
require.NoError(t, err)
|
||||
} else if !tc.canRPC {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "bad certificate")
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "EOF")
|
||||
}
|
||||
})
|
||||
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
|
||||
err := raftRPC(t, nonVerifyS, cfg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func doRaftRPC(conn net.Conn, leader string) (*raft.AppendEntriesResponse, error) {
|
||||
|
|
|
@ -265,9 +265,6 @@ type endpoints struct {
|
|||
Status *Status
|
||||
Node *Node
|
||||
Job *Job
|
||||
Eval *Eval
|
||||
Plan *Plan
|
||||
Alloc *Alloc
|
||||
CSIVolume *CSIVolume
|
||||
CSIPlugin *CSIPlugin
|
||||
Deployment *Deployment
|
||||
|
@ -1151,18 +1148,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||
if s.staticEndpoints.Status == nil {
|
||||
// Initialize the list just once
|
||||
s.staticEndpoints.ACL = &ACL{srv: s, logger: s.logger.Named("acl")}
|
||||
s.staticEndpoints.Alloc = &Alloc{srv: s, logger: s.logger.Named("alloc")}
|
||||
s.staticEndpoints.Eval = &Eval{srv: s, logger: s.logger.Named("eval")}
|
||||
s.staticEndpoints.Job = NewJobEndpoints(s)
|
||||
s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")} // Add but don't register
|
||||
s.staticEndpoints.CSIVolume = &CSIVolume{srv: s, logger: s.logger.Named("csi_volume")}
|
||||
s.staticEndpoints.CSIPlugin = &CSIPlugin{srv: s, logger: s.logger.Named("csi_plugin")}
|
||||
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
|
||||
s.staticEndpoints.Operator = &Operator{srv: s, logger: s.logger.Named("operator")}
|
||||
s.staticEndpoints.Operator.register()
|
||||
|
||||
s.staticEndpoints.Periodic = &Periodic{srv: s, logger: s.logger.Named("periodic")}
|
||||
s.staticEndpoints.Plan = &Plan{srv: s, logger: s.logger.Named("plan")}
|
||||
s.staticEndpoints.Region = &Region{srv: s, logger: s.logger.Named("region")}
|
||||
s.staticEndpoints.Scaling = &Scaling{srv: s, logger: s.logger.Named("scaling")}
|
||||
s.staticEndpoints.Status = &Status{srv: s, logger: s.logger.Named("status")}
|
||||
|
@ -1171,6 +1163,13 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||
s.staticEndpoints.Namespace = &Namespace{srv: s}
|
||||
s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s)
|
||||
|
||||
// These endpoints are dynamic because they need access to the
|
||||
// RPCContext, but they also need to be called directly in some cases,
|
||||
// so store them into staticEndpoints for later access, but don't
|
||||
// register them as static.
|
||||
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
|
||||
s.staticEndpoints.Node = &Node{srv: s, logger: s.logger.Named("client")}
|
||||
|
||||
// Client endpoints
|
||||
s.staticEndpoints.ClientStats = &ClientStats{srv: s, logger: s.logger.Named("client_stats")}
|
||||
s.staticEndpoints.ClientAllocations = &ClientAllocations{srv: s, logger: s.logger.Named("client_allocs")}
|
||||
|
@ -1191,15 +1190,11 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||
|
||||
// Register the static handlers
|
||||
server.Register(s.staticEndpoints.ACL)
|
||||
server.Register(s.staticEndpoints.Alloc)
|
||||
server.Register(s.staticEndpoints.Eval)
|
||||
server.Register(s.staticEndpoints.Job)
|
||||
server.Register(s.staticEndpoints.CSIVolume)
|
||||
server.Register(s.staticEndpoints.CSIPlugin)
|
||||
server.Register(s.staticEndpoints.Deployment)
|
||||
server.Register(s.staticEndpoints.Operator)
|
||||
server.Register(s.staticEndpoints.Periodic)
|
||||
server.Register(s.staticEndpoints.Plan)
|
||||
server.Register(s.staticEndpoints.Region)
|
||||
server.Register(s.staticEndpoints.Scaling)
|
||||
server.Register(s.staticEndpoints.Status)
|
||||
|
@ -1214,10 +1209,18 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
|||
server.Register(s.staticEndpoints.Namespace)
|
||||
|
||||
// Create new dynamic endpoints and add them to the RPC server.
|
||||
alloc := &Alloc{srv: s, ctx: ctx, logger: s.logger.Named("alloc")}
|
||||
deployment := &Deployment{srv: s, ctx: ctx, logger: s.logger.Named("deployment")}
|
||||
eval := &Eval{srv: s, ctx: ctx, logger: s.logger.Named("eval")}
|
||||
node := &Node{srv: s, ctx: ctx, logger: s.logger.Named("client")}
|
||||
plan := &Plan{srv: s, ctx: ctx, logger: s.logger.Named("plan")}
|
||||
|
||||
// Register the dynamic endpoints
|
||||
server.Register(alloc)
|
||||
server.Register(deployment)
|
||||
server.Register(eval)
|
||||
server.Register(node)
|
||||
server.Register(plan)
|
||||
}
|
||||
|
||||
// setupRaft is used to setup and initialize Raft
|
||||
|
|
|
@ -301,3 +301,27 @@ func getAlloc(state AllocGetter, allocID string) (*structs.Allocation, error) {
|
|||
|
||||
return alloc, nil
|
||||
}
|
||||
|
||||
// validateLocalClientTLSCertificate checks if the provided RPC connection was
|
||||
// initiated by a client in the same region as the target server.
|
||||
func validateLocalClientTLSCertificate(srv *Server, ctx *RPCContext) error {
|
||||
expected := fmt.Sprintf("client.%s.nomad", srv.Region())
|
||||
return validateTLSCertificate(srv, ctx, expected)
|
||||
}
|
||||
|
||||
// validateLocalServerTLSCertificate checks if the provided RPC connection was
|
||||
// initiated by a server in the same region as the target server.
|
||||
func validateLocalServerTLSCertificate(srv *Server, ctx *RPCContext) error {
|
||||
expected := fmt.Sprintf("server.%s.nomad", srv.Region())
|
||||
return validateTLSCertificate(srv, ctx, expected)
|
||||
}
|
||||
|
||||
// validateTLSCertificate checks if the RPC connection mTLS certificates are
|
||||
// valid for the given name.
|
||||
func validateTLSCertificate(srv *Server, ctx *RPCContext, name string) error {
|
||||
if srv.config.TLSConfig == nil || !srv.config.TLSConfig.VerifyServerHostname {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ctx.ValidateCertificateForName(name)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue