From ed147b7ae7c057dde8daa139cc05989dab30917b Mon Sep 17 00:00:00 2001
From: ncabatoff
Date: Tue, 3 Sep 2019 11:59:56 -0400
Subject: [PATCH] Make clusterListener an atomic.Value to avoid races with getGRPCDialer. (#7408)

---
// Reviewer notes. This section sits between "---" and the first "diff --git"
// line, so it is commentary only and not part of the change itself.
//
// What the patch does (as shown by the hunks below)
// - Core.clusterListener changes type from *cluster.Listener to *atomic.Value.
//   NewCore allocates it with new(atomic.Value) and stores a typed nil,
//   (*cluster.Listener)(nil), before returning.
// - Adds Core.getClusterListener(): it Loads the value, returns nil when Load
//   yields nil, and otherwise type-asserts the result to *cluster.Listener.
//   Every read of c.clusterListener in the hunks shown now goes through it.
// - stopClusterListener stores a typed nil pointer (var nilCL *cluster.Listener)
//   instead of assigning nil. sync/atomic documents that Value.Store panics on
//   a nil interface and on a value whose concrete type differs from earlier
//   stores, which is why a typed nil is used.
//
// Points to follow up
// - startClusterListener, startForwarding and stopForwarding call
//   getClusterListener() more than once (nil check, then use). Each call is a
//   separate Load, so the value that was checked and the value that is used
//   are not guaranteed to be the same. getGRPCDialer, stopClusterListener,
//   refreshRequestForwardingConnection and clearForwardingClients load once
//   into a local variable; the remaining call sites could follow that pattern.
//   NOTE(review): whether a Store can actually interleave at those sites
//   depends on locking that is not visible in this patch - confirm.
// - TestCluster_CustomCipherSuites (context lines, untouched by this patch)
//   evaluates netConn.(*tls.Conn) before checking err. getGRPCDialer can return
//   (nil, err), e.g. "clustering disabled"; a single-value type assertion on a
//   nil interface panics, so the test would panic before reaching t.Fatal(err).
// - Line breaks and indentation in this copy were restored from Go syntax and
//   the hunk line counts. Column alignment inside struct definitions/literals
//   and the exact position of a few blank context lines are best-effort.

 vault/cluster.go            | 29 ++++++++++++++++++++---------
 vault/cluster_test.go       | 10 +++++-----
 vault/core.go               |  9 ++++++---
 vault/raft.go               |  6 +++---
 vault/request_forwarding.go | 22 ++++++++++++----------
 5 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/vault/cluster.go b/vault/cluster.go
index 3e5a30203..983afdbb8 100644
--- a/vault/cluster.go
+++ b/vault/cluster.go
@@ -289,7 +289,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 		return nil
 	}
 
-	if c.clusterListener != nil {
+	if c.getClusterListener() != nil {
 		c.logger.Warn("cluster listener is already started")
 		return nil
 	}
@@ -301,15 +301,15 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 
 	c.logger.Debug("starting cluster listeners")
 
-	c.clusterListener = cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener"))
+	c.clusterListener.Store(cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener")))
 
-	err := c.clusterListener.Run(ctx)
+	err := c.getClusterListener().Run(ctx)
 	if err != nil {
 		return err
 	}
 	if strings.HasSuffix(c.ClusterAddr(), ":0") {
 		// If we listened on port 0, record the port the OS gave us.
-		c.clusterAddr.Store(fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0]))
+		c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addrs()[0]))
 	}
 	return nil
 }
@@ -318,18 +318,28 @@ func (c *Core) ClusterAddr() string {
 	return c.clusterAddr.Load().(string)
 }
 
+func (c *Core) getClusterListener() *cluster.Listener {
+	cl := c.clusterListener.Load()
+	if cl == nil {
+		return nil
+	}
+	return cl.(*cluster.Listener)
+}
+
 // stopClusterListener stops any existing listeners during seal. It is
 // assumed that the state lock is held while this is run.
 func (c *Core) stopClusterListener() {
-	if c.clusterListener == nil {
+	clusterListener := c.getClusterListener()
+	if clusterListener == nil {
 		c.logger.Debug("clustering disabled, not stopping listeners")
 		return
 	}
 
 	c.logger.Info("stopping cluster listeners")
 
-	c.clusterListener.Stop()
-	c.clusterListener = nil
+	clusterListener.Stop()
+	var nilCL *cluster.Listener
+	c.clusterListener.Store(nilCL)
 
 	c.logger.Info("cluster listeners successfully shut down")
 }
@@ -350,11 +360,12 @@ func (c *Core) SetClusterHandler(handler http.Handler) {
 // NextProtos.
 func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
 	return func(addr string, timeout time.Duration) (net.Conn, error) {
-		if c.clusterListener == nil {
+		clusterListener := c.getClusterListener()
+		if clusterListener == nil {
 			return nil, errors.New("clustering disabled")
 		}
 
-		tlsConfig, err := c.clusterListener.TLSConfig(ctx)
+		tlsConfig, err := clusterListener.TLSConfig(ctx)
 		if err != nil {
 			c.logger.Error("failed to get tls configuration", "error", err)
 			return nil, err
diff --git a/vault/cluster_test.go b/vault/cluster_test.go
index 5344a729a..81ff81b95 100644
--- a/vault/cluster_test.go
+++ b/vault/cluster_test.go
@@ -101,8 +101,8 @@ func TestCluster_ListenForRequests(t *testing.T) {
 	// Wait for core to become active
 	TestWaitActive(t, cores[0].Core)
 
-	cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
-	addrs := cores[0].clusterListener.Addrs()
+	cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
+	addrs := cores[0].getClusterListener().Addrs()
 
 	// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
 	checkListenersFunc := func(expectFail bool) {
@@ -157,7 +157,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
 
 	// After this period it should be active again
 	TestWaitActive(t, cores[0].Core)
-	cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
+	cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
 	checkListenersFunc(false)
 
 	err = cores[0].Core.Seal(cluster.RootToken)
@@ -384,12 +384,12 @@ func TestCluster_CustomCipherSuites(t *testing.T) {
 	// Wait for core to become active
 	TestWaitActive(t, core.Core)
 
-	core.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
+	core.getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
 
 	parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
 	dialer := core.getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
 
-	netConn, err := dialer(core.clusterListener.Addrs()[0].String(), 0)
+	netConn, err := dialer(core.getClusterListener().Addrs()[0].String(), 0)
 	conn := netConn.(*tls.Conn)
 	if err != nil {
 		t.Fatal(err)
diff --git a/vault/core.go b/vault/core.go
index 7f8cf76e7..8da65417c 100644
--- a/vault/core.go
+++ b/vault/core.go
@@ -442,7 +442,7 @@ type Core struct {
 	loadCaseSensitiveIdentityStore bool
 
 	// clusterListener starts up and manages connections on the cluster ports
-	clusterListener *cluster.Listener
+	clusterListener *atomic.Value
 
 	// Telemetry objects
 	metricsHelper *metricsutil.MetricsHelper
@@ -615,6 +615,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 		underlyingPhysical: conf.Physical,
 		redirectAddr: conf.RedirectAddr,
 		clusterAddr: new(atomic.Value),
+		clusterListener: new(atomic.Value),
 		seal: conf.Seal,
 		router: NewRouter(),
 		sealed: new(uint32),
@@ -781,6 +782,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 	uiStoragePrefix := systemBarrierPrefix + "ui"
 	c.uiConfig = NewUIConfig(conf.EnableUI, physical.NewView(c.physical, uiStoragePrefix), NewBarrierView(c.barrier, uiStoragePrefix))
 
+	c.clusterListener.Store((*cluster.Listener)(nil))
+
 	return c, nil
 }
 
@@ -1515,7 +1518,7 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, shutdownRaft b
 	// If the storage backend needs to be sealed
 	if shutdownRaft {
 		if raftStorage, ok := c.underlyingPhysical.(*raft.RaftBackend); ok {
-			if err := raftStorage.TeardownCluster(c.clusterListener); err != nil {
+			if err := raftStorage.TeardownCluster(c.getClusterListener()); err != nil {
 				c.logger.Error("error stopping storage cluster", "error", err)
 				return err
 			}
@@ -1624,7 +1627,7 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
 		c.auditBroker = NewAuditBroker(c.logger)
 	}
 
-	if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) {
+	if c.getClusterListener() != nil && (c.ha != nil || shouldStartClusterListener(c)) {
 		if err := c.setupRaftActiveNode(ctx); err != nil {
 			return err
 		}
diff --git a/vault/raft.go b/vault/raft.go
index 0ac3c1edd..31d622501 100644
--- a/vault/raft.go
+++ b/vault/raft.go
@@ -124,7 +124,7 @@ func (c *Core) startRaftStorage(ctx context.Context) (retErr error) {
 	raftStorage.SetRestoreCallback(c.raftSnapshotRestoreCallback(true, true))
 	if err := raftStorage.SetupCluster(ctx, raft.SetupOpts{
 		TLSKeyring: raftTLS,
-		ClusterListener: c.clusterListener,
+		ClusterListener: c.getClusterListener(),
 		StartAsLeader: creating,
 	}); err != nil {
 		return err
@@ -133,7 +133,7 @@ func (c *Core) startRaftStorage(ctx context.Context) (retErr error) {
 	defer func() {
 		if retErr != nil {
 			c.logger.Info("stopping raft server")
-			if err := raftStorage.TeardownCluster(c.clusterListener); err != nil {
+			if err := raftStorage.TeardownCluster(c.getClusterListener()); err != nil {
 				c.logger.Error("failed to stop raft server", "error", err)
 			}
 		}
@@ -715,7 +715,7 @@ func (c *Core) joinRaftSendAnswer(ctx context.Context, leaderClient *api.Client,
 	raftStorage.SetRestoreCallback(c.raftSnapshotRestoreCallback(true, true))
 	err = raftStorage.SetupCluster(ctx, raft.SetupOpts{
 		TLSKeyring: answerResp.Data.TLSKeyring,
-		ClusterListener: c.clusterListener,
+		ClusterListener: c.getClusterListener(),
 	})
 	if err != nil {
 		return errwrap.Wrapf("failed to setup raft cluster: {{err}}", err)
diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go
index 6c8889179..26e3fa5da 100644
--- a/vault/request_forwarding.go
+++ b/vault/request_forwarding.go
@@ -189,7 +189,7 @@ func (c *Core) startForwarding(ctx context.Context) error {
 	c.clearForwardingClients()
 	c.requestForwardingConnectionLock.Unlock()
 
-	if c.ha == nil || c.clusterListener == nil {
+	if c.ha == nil || c.getClusterListener() == nil {
 		c.logger.Debug("request forwarding not setup")
 		return nil
 	}
@@ -199,20 +199,20 @@ func (c *Core) startForwarding(ctx context.Context) error {
 		return err
 	}
 
-	handler, err := NewRequestForwardingHandler(c, c.clusterListener.Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
+	handler, err := NewRequestForwardingHandler(c, c.getClusterListener().Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
 	if err != nil {
 		return err
 	}
 
-	c.clusterListener.AddHandler(consts.RequestForwardingALPN, handler)
+	c.getClusterListener().AddHandler(consts.RequestForwardingALPN, handler)
 
 	return nil
 }
 
 func (c *Core) stopForwarding() {
-	if c.clusterListener != nil {
-		c.clusterListener.StopHandler(consts.RequestForwardingALPN)
-		c.clusterListener.StopHandler(consts.PerfStandbyALPN)
+	if c.getClusterListener() != nil {
+		c.getClusterListener().StopHandler(consts.RequestForwardingALPN)
+		c.getClusterListener().StopHandler(consts.PerfStandbyALPN)
 	}
 	c.removeAllPerfStandbySecondaries()
 }
@@ -247,8 +247,9 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 		return errors.New("no request forwarding cluster certificate found")
 	}
 
-	if c.clusterListener != nil {
-		c.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
+	clusterListener := c.getClusterListener()
+	if clusterListener != nil {
+		clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
 			core: c,
 		})
 	}
@@ -302,8 +303,9 @@ func (c *Core) clearForwardingClients() {
 	c.rpcClientConnContext = nil
 	c.rpcForwardingClient = nil
 
-	if c.clusterListener != nil {
-		c.clusterListener.RemoveClient(consts.RequestForwardingALPN)
+	clusterListener := c.getClusterListener()
+	if clusterListener != nil {
+		clusterListener.RemoveClient(consts.RequestForwardingALPN)
 	}
 	c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
 }