Move local cluster parameters to atomic values to fix some potential data races (#4036)

This commit is contained in:
Jeff Mitchell 2018-02-23 14:47:07 -05:00 committed by GitHub
parent cb08fb92d2
commit d4a431b298
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 66 deletions

View File

@ -1,7 +1,6 @@
package vault
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
@ -91,11 +90,9 @@ func (c *Core) Cluster(ctx context.Context) (*Cluster, error) {
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
defer func() {
if retErr != nil {
c.clusterParamsLock.Lock()
c.localClusterCert = nil
c.localClusterPrivateKey = nil
c.localClusterParsedCert = nil
c.clusterParamsLock.Unlock()
c.localClusterCert.Store(([]byte)(nil))
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))
c.requestForwardingConnectionLock.Lock()
c.clearForwardingClients()
@ -126,28 +123,26 @@ func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
}
// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()
c.localClusterPrivateKey = &ecdsa.PrivateKey{
c.localClusterPrivateKey.Store(&ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P521(),
X: adv.ClusterKeyParams.X,
Y: adv.ClusterKeyParams.Y,
},
D: adv.ClusterKeyParams.D,
}
})
c.localClusterCert = adv.ClusterCert
locCert := make([]byte, len(adv.ClusterCert))
copy(locCert, adv.ClusterCert)
c.localClusterCert.Store(locCert)
cert, err := x509.ParseCertificate(c.localClusterCert)
cert, err := x509.ParseCertificate(adv.ClusterCert)
if err != nil {
c.logger.Error("core: failed parsing local cluster certificate", "error", err)
return fmt.Errorf("error parsing local cluster certificate: %v", err)
}
c.localClusterParsedCert = cert
c.localClusterParsedCert.Store(cert)
return nil
}
@ -210,7 +205,7 @@ func (c *Core) setupCluster(ctx context.Context) error {
// If we're using HA, generate server-to-server parameters
if c.ha != nil {
// Create a private key
if c.localClusterPrivateKey == nil {
if c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey) == nil {
c.logger.Trace("core: generating cluster private key")
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
@ -218,11 +213,11 @@ func (c *Core) setupCluster(ctx context.Context) error {
return err
}
c.localClusterPrivateKey = key
c.localClusterPrivateKey.Store(key)
}
// Create a certificate
if c.localClusterCert == nil {
if c.localClusterCert.Load().([]byte) == nil {
c.logger.Trace("core: generating local cluster certificate")
host, err := uuid.GenerateUUID()
@ -248,7 +243,7 @@ func (c *Core) setupCluster(ctx context.Context) error {
IsCA: true,
}
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey).Public(), c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey))
if err != nil {
c.logger.Error("core: error generating self-signed cert", "error", err)
return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
@ -260,8 +255,8 @@ func (c *Core) setupCluster(ctx context.Context) error {
return errwrap.Wrapf("error parsing generated certificate: {{err}}", err)
}
c.localClusterCert = certBytes
c.localClusterParsedCert = parsedCert
c.localClusterCert.Store(certBytes)
c.localClusterParsedCert.Store(parsedCert)
}
}
@ -349,24 +344,20 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
switch {
default:
var localCert bytes.Buffer
c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
localSigner := c.localClusterPrivateKey
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
if localCert.Len() == 0 {
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("got forwarding connection but no local cert")
}
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
//c.logger.Trace("core: performing cert name lookup", "hello_server_name", clientHello.ServerName, "local_cluster_cert_name", parsedCert.Subject.CommonName)
return &tls.Certificate{
Certificate: [][]byte{localCert.Bytes()},
PrivateKey: localSigner,
Leaf: parsedCert,
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
}
@ -377,22 +368,19 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
if len(requestInfo.AcceptableCAs) != 1 {
return nil, fmt.Errorf("expected only a single acceptable CA")
}
var localCert bytes.Buffer
c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
localSigner := c.localClusterPrivateKey
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
if localCert.Len() == 0 {
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("forwarding connection client but no local cert")
}
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
return &tls.Certificate{
Certificate: [][]byte{localCert.Bytes()},
PrivateKey: localSigner,
Leaf: parsedCert,
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
@ -421,9 +409,7 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
switch {
default:
c.clusterParamsLock.RLock()
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
if parsedCert == nil {
return nil, fmt.Errorf("forwarding connection client but no local cert")
@ -444,11 +430,10 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
CipherSuites: c.clusterCipherSuites,
}
var localCert bytes.Buffer
c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
currCert := c.localClusterCert.Load().([]byte)
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
if parsedCert != nil {
tlsConfig.ServerName = parsedCert.Subject.CommonName

View File

@ -2,7 +2,6 @@ package vault
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/subtle"
"crypto/x509"
@ -307,11 +306,11 @@ type Core struct {
clusterParamsLock sync.RWMutex
// The private key stored in the barrier used for establishing
// mutually-authenticated connections between Vault cluster members
localClusterPrivateKey crypto.Signer
localClusterPrivateKey *atomic.Value
// The local cluster cert
localClusterCert []byte
localClusterCert *atomic.Value
// The parsed form of the local cluster cert
localClusterParsedCert *x509.Certificate
localClusterParsedCert *atomic.Value
// The TCP addresses we should use for clustering
clusterListenerAddrs []*net.TCPAddr
// The handler to use for request forwarding
@ -497,10 +496,16 @@ func NewCore(conf *CoreConfig) (*Core, error) {
rpcServerActive: new(uint32),
atomicPrimaryClusterAddrs: new(atomic.Value),
atomicPrimaryFailoverAddrs: new(atomic.Value),
localClusterPrivateKey: new(atomic.Value),
localClusterCert: new(atomic.Value),
localClusterParsedCert: new(atomic.Value),
activeNodeReplicationState: new(uint32),
}
atomic.StoreUint32(c.replicationState, uint32(consts.ReplicationDRDisabled|consts.ReplicationPerformanceDisabled))
c.localClusterCert.Store(([]byte)(nil))
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))
if conf.ClusterCipherSuites != "" {
suites, err := tlsutil.ParseCiphers(conf.ClusterCipherSuites)
@ -1816,11 +1821,9 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) {
// Clear previous local cluster cert info so we generate new. Since the
// UUID will have changed, standbys will know to look for new info
c.clusterParamsLock.Lock()
c.localClusterCert = nil
c.localClusterParsedCert = nil
c.localClusterPrivateKey = nil
c.clusterParamsLock.Unlock()
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterCert.Store(([]byte)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))
if err := c.setupCluster(ctx); err != nil {
c.stateLock.Unlock()
@ -2049,12 +2052,12 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
go c.cleanLeaderPrefix(ctx, uuid, leaderLostCh)
var key *ecdsa.PrivateKey
switch c.localClusterPrivateKey.(type) {
switch c.localClusterPrivateKey.Load().(type) {
case *ecdsa.PrivateKey:
key = c.localClusterPrivateKey.(*ecdsa.PrivateKey)
key = c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey)
default:
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey))
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey)
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey.Load()))
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey.Load())
}
keyParams := &clusterKeyParams{
@ -2064,10 +2067,13 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
D: key.D,
}
locCert := c.localClusterCert.Load().([]byte)
localCert := make([]byte, len(locCert))
copy(localCert, locCert)
adv := &activeAdvertisement{
RedirectAddr: c.redirectAddr,
ClusterAddr: c.clusterAddr,
ClusterCert: c.localClusterCert,
ClusterCert: localCert,
ClusterKeyParams: keyParams,
}
val, err := jsonutil.EncodeJSON(adv)