Move local cluster parameters to atomic values to fix some potential data races (#4036)
This commit is contained in:
parent
cb08fb92d2
commit
d4a431b298
|
@ -1,7 +1,6 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
|
@ -91,11 +90,9 @@ func (c *Core) Cluster(ctx context.Context) (*Cluster, error) {
|
|||
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
c.clusterParamsLock.Lock()
|
||||
c.localClusterCert = nil
|
||||
c.localClusterPrivateKey = nil
|
||||
c.localClusterParsedCert = nil
|
||||
c.clusterParamsLock.Unlock()
|
||||
c.localClusterCert.Store(([]byte)(nil))
|
||||
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
|
||||
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))
|
||||
|
||||
c.requestForwardingConnectionLock.Lock()
|
||||
c.clearForwardingClients()
|
||||
|
@ -126,28 +123,26 @@ func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
|
|||
|
||||
}
|
||||
|
||||
// Prevent data races with the TLS parameters
|
||||
c.clusterParamsLock.Lock()
|
||||
defer c.clusterParamsLock.Unlock()
|
||||
|
||||
c.localClusterPrivateKey = &ecdsa.PrivateKey{
|
||||
c.localClusterPrivateKey.Store(&ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
Curve: elliptic.P521(),
|
||||
X: adv.ClusterKeyParams.X,
|
||||
Y: adv.ClusterKeyParams.Y,
|
||||
},
|
||||
D: adv.ClusterKeyParams.D,
|
||||
}
|
||||
})
|
||||
|
||||
c.localClusterCert = adv.ClusterCert
|
||||
locCert := make([]byte, len(adv.ClusterCert))
|
||||
copy(locCert, adv.ClusterCert)
|
||||
c.localClusterCert.Store(locCert)
|
||||
|
||||
cert, err := x509.ParseCertificate(c.localClusterCert)
|
||||
cert, err := x509.ParseCertificate(adv.ClusterCert)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed parsing local cluster certificate", "error", err)
|
||||
return fmt.Errorf("error parsing local cluster certificate: %v", err)
|
||||
}
|
||||
|
||||
c.localClusterParsedCert = cert
|
||||
c.localClusterParsedCert.Store(cert)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -210,7 +205,7 @@ func (c *Core) setupCluster(ctx context.Context) error {
|
|||
// If we're using HA, generate server-to-server parameters
|
||||
if c.ha != nil {
|
||||
// Create a private key
|
||||
if c.localClusterPrivateKey == nil {
|
||||
if c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey) == nil {
|
||||
c.logger.Trace("core: generating cluster private key")
|
||||
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
if err != nil {
|
||||
|
@ -218,11 +213,11 @@ func (c *Core) setupCluster(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
c.localClusterPrivateKey = key
|
||||
c.localClusterPrivateKey.Store(key)
|
||||
}
|
||||
|
||||
// Create a certificate
|
||||
if c.localClusterCert == nil {
|
||||
if c.localClusterCert.Load().([]byte) == nil {
|
||||
c.logger.Trace("core: generating local cluster certificate")
|
||||
|
||||
host, err := uuid.GenerateUUID()
|
||||
|
@ -248,7 +243,7 @@ func (c *Core) setupCluster(ctx context.Context) error {
|
|||
IsCA: true,
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey).Public(), c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey))
|
||||
if err != nil {
|
||||
c.logger.Error("core: error generating self-signed cert", "error", err)
|
||||
return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
|
||||
|
@ -260,8 +255,8 @@ func (c *Core) setupCluster(ctx context.Context) error {
|
|||
return errwrap.Wrapf("error parsing generated certificate: {{err}}", err)
|
||||
}
|
||||
|
||||
c.localClusterCert = certBytes
|
||||
c.localClusterParsedCert = parsedCert
|
||||
c.localClusterCert.Store(certBytes)
|
||||
c.localClusterParsedCert.Store(parsedCert)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -349,24 +344,20 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
|
|||
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
switch {
|
||||
default:
|
||||
var localCert bytes.Buffer
|
||||
|
||||
c.clusterParamsLock.RLock()
|
||||
localCert.Write(c.localClusterCert)
|
||||
localSigner := c.localClusterPrivateKey
|
||||
parsedCert := c.localClusterParsedCert
|
||||
c.clusterParamsLock.RUnlock()
|
||||
|
||||
if localCert.Len() == 0 {
|
||||
currCert := c.localClusterCert.Load().([]byte)
|
||||
if len(currCert) == 0 {
|
||||
return nil, fmt.Errorf("got forwarding connection but no local cert")
|
||||
}
|
||||
|
||||
localCert := make([]byte, len(currCert))
|
||||
copy(localCert, currCert)
|
||||
|
||||
//c.logger.Trace("core: performing cert name lookup", "hello_server_name", clientHello.ServerName, "local_cluster_cert_name", parsedCert.Subject.CommonName)
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{localCert.Bytes()},
|
||||
PrivateKey: localSigner,
|
||||
Leaf: parsedCert,
|
||||
Certificate: [][]byte{localCert},
|
||||
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
|
||||
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
@ -377,22 +368,19 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
|
|||
if len(requestInfo.AcceptableCAs) != 1 {
|
||||
return nil, fmt.Errorf("expected only a single acceptable CA")
|
||||
}
|
||||
var localCert bytes.Buffer
|
||||
|
||||
c.clusterParamsLock.RLock()
|
||||
localCert.Write(c.localClusterCert)
|
||||
localSigner := c.localClusterPrivateKey
|
||||
parsedCert := c.localClusterParsedCert
|
||||
c.clusterParamsLock.RUnlock()
|
||||
|
||||
if localCert.Len() == 0 {
|
||||
currCert := c.localClusterCert.Load().([]byte)
|
||||
if len(currCert) == 0 {
|
||||
return nil, fmt.Errorf("forwarding connection client but no local cert")
|
||||
}
|
||||
|
||||
localCert := make([]byte, len(currCert))
|
||||
copy(localCert, currCert)
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{localCert.Bytes()},
|
||||
PrivateKey: localSigner,
|
||||
Leaf: parsedCert,
|
||||
Certificate: [][]byte{localCert},
|
||||
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
|
||||
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -421,9 +409,7 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
|
|||
|
||||
switch {
|
||||
default:
|
||||
c.clusterParamsLock.RLock()
|
||||
parsedCert := c.localClusterParsedCert
|
||||
c.clusterParamsLock.RUnlock()
|
||||
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
|
||||
|
||||
if parsedCert == nil {
|
||||
return nil, fmt.Errorf("forwarding connection client but no local cert")
|
||||
|
@ -444,11 +430,10 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
|
|||
CipherSuites: c.clusterCipherSuites,
|
||||
}
|
||||
|
||||
var localCert bytes.Buffer
|
||||
c.clusterParamsLock.RLock()
|
||||
localCert.Write(c.localClusterCert)
|
||||
parsedCert := c.localClusterParsedCert
|
||||
c.clusterParamsLock.RUnlock()
|
||||
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
|
||||
currCert := c.localClusterCert.Load().([]byte)
|
||||
localCert := make([]byte, len(currCert))
|
||||
copy(localCert, currCert)
|
||||
|
||||
if parsedCert != nil {
|
||||
tlsConfig.ServerName = parsedCert.Subject.CommonName
|
||||
|
|
|
@ -2,7 +2,6 @@ package vault
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
|
@ -307,11 +306,11 @@ type Core struct {
|
|||
clusterParamsLock sync.RWMutex
|
||||
// The private key stored in the barrier used for establishing
|
||||
// mutually-authenticated connections between Vault cluster members
|
||||
localClusterPrivateKey crypto.Signer
|
||||
localClusterPrivateKey *atomic.Value
|
||||
// The local cluster cert
|
||||
localClusterCert []byte
|
||||
localClusterCert *atomic.Value
|
||||
// The parsed form of the local cluster cert
|
||||
localClusterParsedCert *x509.Certificate
|
||||
localClusterParsedCert *atomic.Value
|
||||
// The TCP addresses we should use for clustering
|
||||
clusterListenerAddrs []*net.TCPAddr
|
||||
// The handler to use for request forwarding
|
||||
|
@ -497,10 +496,16 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
rpcServerActive: new(uint32),
|
||||
atomicPrimaryClusterAddrs: new(atomic.Value),
|
||||
atomicPrimaryFailoverAddrs: new(atomic.Value),
|
||||
localClusterPrivateKey: new(atomic.Value),
|
||||
localClusterCert: new(atomic.Value),
|
||||
localClusterParsedCert: new(atomic.Value),
|
||||
activeNodeReplicationState: new(uint32),
|
||||
}
|
||||
|
||||
atomic.StoreUint32(c.replicationState, uint32(consts.ReplicationDRDisabled|consts.ReplicationPerformanceDisabled))
|
||||
c.localClusterCert.Store(([]byte)(nil))
|
||||
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
|
||||
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))
|
||||
|
||||
if conf.ClusterCipherSuites != "" {
|
||||
suites, err := tlsutil.ParseCiphers(conf.ClusterCipherSuites)
|
||||
|
@ -1816,11 +1821,9 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) {
|
|||
|
||||
// Clear previous local cluster cert info so we generate new. Since the
|
||||
// UUID will have changed, standbys will know to look for new info
|
||||
c.clusterParamsLock.Lock()
|
||||
c.localClusterCert = nil
|
||||
c.localClusterParsedCert = nil
|
||||
c.localClusterPrivateKey = nil
|
||||
c.clusterParamsLock.Unlock()
|
||||
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
|
||||
c.localClusterCert.Store(([]byte)(nil))
|
||||
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))
|
||||
|
||||
if err := c.setupCluster(ctx); err != nil {
|
||||
c.stateLock.Unlock()
|
||||
|
@ -2049,12 +2052,12 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
|
|||
go c.cleanLeaderPrefix(ctx, uuid, leaderLostCh)
|
||||
|
||||
var key *ecdsa.PrivateKey
|
||||
switch c.localClusterPrivateKey.(type) {
|
||||
switch c.localClusterPrivateKey.Load().(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
key = c.localClusterPrivateKey.(*ecdsa.PrivateKey)
|
||||
key = c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey)
|
||||
default:
|
||||
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey))
|
||||
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey)
|
||||
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey.Load()))
|
||||
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey.Load())
|
||||
}
|
||||
|
||||
keyParams := &clusterKeyParams{
|
||||
|
@ -2064,10 +2067,13 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
|
|||
D: key.D,
|
||||
}
|
||||
|
||||
locCert := c.localClusterCert.Load().([]byte)
|
||||
localCert := make([]byte, len(locCert))
|
||||
copy(localCert, locCert)
|
||||
adv := &activeAdvertisement{
|
||||
RedirectAddr: c.redirectAddr,
|
||||
ClusterAddr: c.clusterAddr,
|
||||
ClusterCert: c.localClusterCert,
|
||||
ClusterCert: localCert,
|
||||
ClusterKeyParams: keyParams,
|
||||
}
|
||||
val, err := jsonutil.EncodeJSON(adv)
|
||||
|
|
Loading…
Reference in New Issue