rpc: authorize raft requests (#10925)

This commit is contained in:
Evan Culver 2021-08-27 00:04:32 +02:00 committed by GitHub
parent a758581ab6
commit 93f94ac24f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 365 additions and 33 deletions

3
.changelog/10925.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:security
rpc: authorize raft requests [CVE-2021-37219](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37219)
```

View File

@ -6,9 +6,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/raft"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/raft"
) )
// RaftLayer implements the raft.StreamLayer interface, // RaftLayer implements the raft.StreamLayer interface,

View File

@ -194,8 +194,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) {
s.handleConsulConn(conn) s.handleConsulConn(conn)
case pool.RPCRaft: case pool.RPCRaft:
metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1) s.handleRaftRPC(conn)
s.raftLayer.Handoff(conn)
case pool.RPCTLS: case pool.RPCTLS:
// Don't allow malicious client to create TLS-in-TLS for ever. // Don't allow malicious client to create TLS-in-TLS for ever.
@ -283,8 +282,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
s.handleConsulConn(tlsConn) s.handleConsulConn(tlsConn)
case pool.ALPN_RPCRaft: case pool.ALPN_RPCRaft:
metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1) s.handleRaftRPC(tlsConn)
s.raftLayer.Handoff(tlsConn)
case pool.ALPN_RPCMultiplexV2: case pool.ALPN_RPCMultiplexV2:
s.handleMultiplexV2(tlsConn) s.handleMultiplexV2(tlsConn)
@ -455,6 +453,20 @@ func (s *Server) handleSnapshotConn(conn net.Conn) {
}() }()
} }
func (s *Server) handleRaftRPC(conn net.Conn) {
if tlsConn, ok := conn.(*tls.Conn); ok {
err := s.tlsConfigurator.AuthorizeServerConn(s.config.Datacenter, tlsConn)
if err != nil {
s.rpcLogger().Warn(err.Error(), "from", conn.RemoteAddr(), "operation", "raft RPC")
conn.Close()
return
}
}
metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(conn)
}
func (s *Server) handleALPN_WANGossipPacketStream(conn net.Conn) error { func (s *Server) handleALPN_WANGossipPacketStream(conn net.Conn) error {
defer conn.Close() defer conn.Close()

View File

@ -1,32 +1,43 @@
package consul package consul
import ( import (
"bufio"
"bytes" "bytes"
"crypto/x509"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"net" "net"
"os" "os"
"path/filepath"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-msgpack/codec"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
tokenStore "github.com/hashicorp/consul/agent/token" tokenStore "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/tlsutil"
) )
func TestRPC_NoLeader_Fail(t *testing.T) { func TestRPC_NoLeader_Fail(t *testing.T) {
@ -681,10 +692,10 @@ func TestRPC_RPCMaxConnsPerClient(t *testing.T) {
magicByte pool.RPCType magicByte pool.RPCType
tlsEnabled bool tlsEnabled bool
}{ }{
{"RPC", pool.RPCMultiplexV2, false}, {"RPC v2", pool.RPCMultiplexV2, false},
{"RPC TLS", pool.RPCMultiplexV2, true}, {"RPC v2 TLS", pool.RPCMultiplexV2, true},
{"Raft", pool.RPCRaft, false}, {"RPC", pool.RPCConsul, false},
{"Raft TLS", pool.RPCRaft, true}, {"RPC TLS", pool.RPCConsul, true},
} }
for _, tc := range cases { for _, tc := range cases {
@ -1059,3 +1070,263 @@ func (r isReadRequest) IsRead() bool {
func (r isReadRequest) HasTimedOut(since time.Time, rpcHoldTimeout, maxQueryTime, defaultQueryTime time.Duration) bool { func (r isReadRequest) HasTimedOut(since time.Time, rpcHoldTimeout, maxQueryTime, defaultQueryTime time.Duration) bool {
return false return false
} }
func TestRPC_AuthorizeRaftRPC(t *testing.T) {
caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "consul"})
require.NoError(t, err)
dir := testutil.TempDir(t, "certs")
err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
require.NoError(t, err)
newCert := func(t *testing.T, caPEM, pk, node, name string) {
t.Helper()
signer, err := tlsutil.ParseSigner(pk)
require.NoError(t, err)
pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
Signer: signer,
CA: caPEM,
Name: name,
Days: 5,
DNSNames: []string{node + "." + name, name, "localhost"},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
})
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
require.NoError(t, err)
}
newCert(t, caPEM, pk, "srv1", "server.dc1.consul")
_, connectCApk, err := connect.GeneratePrivateKey()
require.NoError(t, err)
_, srv := testServerWithConfig(t, func(c *Config) {
c.TLSConfig.Domain = "consul." // consul. is the default value in agent/config
c.TLSConfig.CAFile = filepath.Join(dir, "ca.pem")
c.TLSConfig.CertFile = filepath.Join(dir, "srv1-server.dc1.consul.pem")
c.TLSConfig.KeyFile = filepath.Join(dir, "srv1-server.dc1.consul.key")
c.TLSConfig.VerifyIncoming = true
c.TLSConfig.VerifyServerHostname = true
// Enable Auto-Encrypt so that Conenct CA roots are added to the
// tlsutil.Configurator.
c.AutoEncryptAllowTLS = true
c.CAConfig = &structs.CAConfiguration{
ClusterID: connect.TestClusterID,
Provider: structs.ConsulCAProvider,
Config: map[string]interface{}{"PrivateKey": connectCApk},
}
})
defer srv.Shutdown()
// Wait for ConnectCA initiation to complete.
retry.Run(t, func(r *retry.R) {
_, root := srv.caManager.getCAProvider()
if root == nil {
r.Fatal("ConnectCA root is still nil")
}
})
useTLSByte := func(t *testing.T, c *tlsutil.Configurator) net.Conn {
wrapper := tlsutil.SpecificDC("dc1", c.OutgoingRPCWrapper())
tlsEnabled := func(_ raft.ServerAddress) bool {
return true
}
rl := NewRaftLayer(nil, nil, wrapper, tlsEnabled)
conn, err := rl.Dial(raft.ServerAddress(srv.Listener.Addr().String()), 100*time.Millisecond)
require.NoError(t, err)
return conn
}
useNativeTLS := func(t *testing.T, c *tlsutil.Configurator) net.Conn {
wrapper := c.OutgoingALPNRPCWrapper()
dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
rawConn, err := dialer.Dial("tcp", srv.Listener.Addr().String())
require.NoError(t, err)
tlsConn, err := wrapper("dc1", "srv1", pool.ALPN_RPCRaft, rawConn)
require.NoError(t, err)
return tlsConn
}
setupAgentTLSCert := func(name string) func(t *testing.T) string {
return func(t *testing.T) string {
newCert(t, caPEM, pk, "node1", name)
return filepath.Join(dir, "node1-"+name)
}
}
setupConnectCACert := func(name string) func(t *testing.T) string {
return func(t *testing.T) string {
_, caRoot := srv.caManager.getCAProvider()
newCert(t, caRoot.RootCert, connectCApk, "node1", name)
return filepath.Join(dir, "node1-"+name)
}
}
type testCase struct {
name string
conn func(t *testing.T, c *tlsutil.Configurator) net.Conn
setupCert func(t *testing.T) string
expectError bool
}
run := func(t *testing.T, tc testCase) {
certPath := tc.setupCert(t)
cfg := tlsutil.Config{
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
Domain: "consul",
}
c, err := tlsutil.NewConfigurator(cfg, hclog.New(nil))
require.NoError(t, err)
_, err = doRaftRPC(tc.conn(t, c), srv.config.NodeName)
if tc.expectError {
if !isConnectionClosedError(err) {
t.Fatalf("expected a connection closed error, got: %v", err)
}
return
}
require.NoError(t, err)
}
var testCases = []testCase{
{
name: "TLS byte with client cert",
setupCert: setupAgentTLSCert("client.dc1.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "TLS byte with server cert in different DC",
setupCert: setupAgentTLSCert("server.dc2.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "TLS byte with server cert in same DC",
setupCert: setupAgentTLSCert("server.dc1.consul"),
conn: useTLSByte,
},
{
name: "TLS byte with ConnectCA leaf cert",
setupCert: setupConnectCACert("server.dc1.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "native TLS with client cert",
setupCert: setupAgentTLSCert("client.dc1.consul"),
conn: useNativeTLS,
expectError: true,
},
{
name: "native TLS with server cert in different DC",
setupCert: setupAgentTLSCert("server.dc2.consul"),
conn: useNativeTLS,
expectError: true,
},
{
name: "native TLS with server cert in same DC",
setupCert: setupAgentTLSCert("server.dc1.consul"),
conn: useNativeTLS,
},
{
name: "native TLS with ConnectCA leaf cert",
setupCert: setupConnectCACert("server.dc1.consul"),
conn: useNativeTLS,
expectError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func doRaftRPC(conn net.Conn, leader string) (raft.AppendEntriesResponse, error) {
var resp raft.AppendEntriesResponse
var term uint64 = 0xc
a := raft.AppendEntriesRequest{
RPCHeader: raft.RPCHeader{ProtocolVersion: 3},
Term: 0,
Leader: []byte(leader),
PrevLogEntry: 0,
PrevLogTerm: term,
LeaderCommitIndex: 50,
}
if err := appendEntries(conn, a, &resp); err != nil {
return resp, err
}
return resp, nil
}
func appendEntries(conn net.Conn, req raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) error {
w := bufio.NewWriter(conn)
enc := codec.NewEncoder(w, &codec.MsgpackHandle{})
const rpcAppendEntries = 0
if err := w.WriteByte(rpcAppendEntries); err != nil {
return fmt.Errorf("failed to write raft-RPC byte: %w", err)
}
if err := enc.Encode(req); err != nil {
return fmt.Errorf("failed to send append entries RPC: %w", err)
}
if err := w.Flush(); err != nil {
return fmt.Errorf("failed to flush RPC: %w", err)
}
if err := decodeRaftRPCResponse(conn, resp); err != nil {
return fmt.Errorf("response error: %w", err)
}
return nil
}
// copied and modified from raft/net_transport.go
func decodeRaftRPCResponse(conn net.Conn, resp *raft.AppendEntriesResponse) error {
r := bufio.NewReader(conn)
dec := codec.NewDecoder(r, &codec.MsgpackHandle{})
var rpcError string
if err := dec.Decode(&rpcError); err != nil {
return fmt.Errorf("failed to decode response error: %w", err)
}
if err := dec.Decode(resp); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
if rpcError != "" {
return fmt.Errorf("rpc error: %v", rpcError)
}
return nil
}
func isConnectionClosedError(err error) bool {
switch {
case err == nil:
return false
case errors.Is(err, io.EOF):
return true
case strings.Contains(err.Error(), "connection reset by peer"):
return true
default:
return false
}
}

View File

@ -153,21 +153,20 @@ func SpecificDC(dc string, tlsWrap DCWrapper) Wrapper {
// autoTLS stores configuration that is received from the auto-encrypt or // autoTLS stores configuration that is received from the auto-encrypt or
// auto-config features. // auto-config features.
type autoTLS struct { type autoTLS struct {
manualCAPems []string extraCAPems []string
connectCAPems []string connectCAPems []string
cert *tls.Certificate cert *tls.Certificate
verifyServerHostname bool verifyServerHostname bool
} }
func (a autoTLS) caPems() []string {
return append(a.manualCAPems, a.connectCAPems...)
}
// manual stores the TLS CA and cert received from Configurator.Update which // manual stores the TLS CA and cert received from Configurator.Update which
// generally comes from the agent configuration. // generally comes from the agent configuration.
type manual struct { type manual struct {
caPems []string caPems []string
cert *tls.Certificate cert *tls.Certificate
// caPool containing only the caPems. This CertPool should be used instead of
// the Configurator.caPool when only the Agent TLS CA is allowed.
caPool *x509.CertPool
} }
// Configurator provides tls.Config and net.Dial wrappers to enable TLS for // Configurator provides tls.Config and net.Dial wrappers to enable TLS for
@ -215,13 +214,6 @@ func NewConfigurator(config Config, logger hclog.Logger) (*Configurator, error)
return c, nil return c, nil
} }
// CAPems returns the currently loaded CAs in PEM format.
func (c *Configurator) CAPems() []string {
c.lock.RLock()
defer c.lock.RUnlock()
return append(c.manual.caPems, c.autoTLS.caPems()...)
}
// ManualCAPems returns the currently loaded CAs in PEM format. // ManualCAPems returns the currently loaded CAs in PEM format.
func (c *Configurator) ManualCAPems() []string { func (c *Configurator) ManualCAPems() []string {
c.lock.RLock() c.lock.RLock()
@ -244,17 +236,23 @@ func (c *Configurator) Update(config Config) error {
if err != nil { if err != nil {
return err return err
} }
pool, err := pool(append(pems, c.autoTLS.caPems()...)) caPool, err := newX509CertPool(pems, c.autoTLS.extraCAPems, c.autoTLS.connectCAPems)
if err != nil { if err != nil {
return err return err
} }
if err = validateConfig(config, pool, cert); err != nil { if err = validateConfig(config, caPool, cert); err != nil {
return err return err
} }
manualCAPool, err := newX509CertPool(pems)
if err != nil {
return err
}
c.base = &config c.base = &config
c.manual.cert = cert c.manual.cert = cert
c.manual.caPems = pems c.manual.caPems = pems
c.caPool = pool c.manual.caPool = manualCAPool
c.caPool = caPool
atomic.AddUint64(&c.version, 1) atomic.AddUint64(&c.version, 1)
c.log("Update") c.log("Update")
return nil return nil
@ -268,7 +266,7 @@ func (c *Configurator) UpdateAutoTLSCA(connectCAPems []string) error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
pool, err := pool(append(c.manual.caPems, append(c.autoTLS.manualCAPems, connectCAPems...)...)) pool, err := newX509CertPool(c.manual.caPems, c.autoTLS.extraCAPems, connectCAPems)
if err != nil { if err != nil {
return err return err
} }
@ -309,11 +307,11 @@ func (c *Configurator) UpdateAutoTLS(manualCAPems, connectCAPems []string, pub,
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
pool, err := pool(append(c.manual.caPems, append(manualCAPems, connectCAPems...)...)) pool, err := newX509CertPool(c.manual.caPems, manualCAPems, connectCAPems)
if err != nil { if err != nil {
return err return err
} }
c.autoTLS.manualCAPems = manualCAPems c.autoTLS.extraCAPems = manualCAPems
c.autoTLS.connectCAPems = connectCAPems c.autoTLS.connectCAPems = connectCAPems
c.autoTLS.cert = &cert c.autoTLS.cert = &cert
c.caPool = pool c.caPool = pool
@ -346,11 +344,21 @@ func (c *Configurator) Base() Config {
return *c.base return *c.base
} }
func pool(pems []string) (*x509.CertPool, error) { // newX509CertPool loads all the groups of PEM encoded certificates into a
// single x509.CertPool.
//
// The groups argument is a varargs of slices so that callers do not need to
// append slices together. In some cases append can modify the backing array
// of the first slice passed to append, which will often result in hard to
// find bugs. By accepting a varargs of slices we remove the need for the
// caller to append the groups, which should prevent any such bugs.
func newX509CertPool(groups ...[]string) (*x509.CertPool, error) {
pool := x509.NewCertPool() pool := x509.NewCertPool()
for _, pem := range pems { for _, group := range groups {
for _, pem := range group {
if !pool.AppendCertsFromPEM([]byte(pem)) { if !pool.AppendCertsFromPEM([]byte(pem)) {
return nil, fmt.Errorf("Couldn't parse PEM %s", pem) return nil, fmt.Errorf("failed to parse PEM %s", pem)
}
} }
} }
if len(pool.Subjects()) == 0 { if len(pool.Subjects()) == 0 {
@ -888,6 +896,43 @@ func (c *Configurator) wrapALPNTLSClient(dc, nodeName, alpnProto string, conn ne
return tlsConn, nil return tlsConn, nil
} }
// AuthorizeServerConn is used to validate that the connection is being established
// by a Consul server in the same datacenter.
//
// The identity of the connection is checked by verifying that the certificate
// presented is signed by the Agent TLS CA, and has a DNSName that matches the
// local ServerSNI name.
//
// Note this check is only performed if VerifyServerHostname is enabled, otherwise
// it does no authorization.
func (c *Configurator) AuthorizeServerConn(dc string, conn *tls.Conn) error {
if !c.VerifyServerHostname() {
return nil
}
c.lock.RLock()
caPool := c.manual.caPool
c.lock.RUnlock()
expected := c.ServerSNI(dc, "")
for _, chain := range conn.ConnectionState().VerifiedChains {
if len(chain) == 0 {
continue
}
clientCert := chain[0]
_, err := clientCert.Verify(x509.VerifyOptions{
DNSName: expected,
Roots: caPool,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
})
if err == nil {
return nil
}
c.logger.Debug("AuthorizeServerConn failed certificate validation", "error", err)
}
return fmt.Errorf("a TLS certificate with a CommonName of %v is required for this operation", expected)
}
// ParseCiphers parse ciphersuites from the comma-separated string into // ParseCiphers parse ciphersuites from the comma-separated string into
// recognized slice // recognized slice
func ParseCiphers(cipherStr string) ([]uint16, error) { func ParseCiphers(cipherStr string) ([]uint16, error) {

View File

@ -520,7 +520,7 @@ func TestConfigurator_ErrorPropagation(t *testing.T) {
require.NoError(t, err, info) require.NoError(t, err, info)
pems, err := LoadCAs(v.config.CAFile, v.config.CAPath) pems, err := LoadCAs(v.config.CAFile, v.config.CAPath)
require.NoError(t, err, info) require.NoError(t, err, info)
pool, err := pool(pems) pool, err := newX509CertPool(pems)
require.NoError(t, err, info) require.NoError(t, err, info)
err3 = validateConfig(v.config, pool, cert) err3 = validateConfig(v.config, pool, cert)
} }
@ -579,7 +579,7 @@ func TestConfigurator_LoadCAs(t *testing.T) {
} }
for i, v := range variants { for i, v := range variants {
pems, err1 := LoadCAs(v.cafile, v.capath) pems, err1 := LoadCAs(v.cafile, v.capath)
pool, err2 := pool(pems) pool, err2 := newX509CertPool(pems)
info := fmt.Sprintf("case %d", i) info := fmt.Sprintf("case %d", i)
if v.shouldErr { if v.shouldErr {
if err1 == nil && err2 == nil { if err1 == nil && err2 == nil {