Make clusterListener an atomic.Value to avoid races with getGRPCDialer. (#7408)

This commit is contained in:
ncabatoff 2019-09-03 11:59:56 -04:00 committed by GitHub
parent 8fdb5f62c4
commit ed147b7ae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 30 deletions

View File

@ -289,7 +289,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
return nil
}
if c.clusterListener != nil {
if c.getClusterListener() != nil {
c.logger.Warn("cluster listener is already started")
return nil
}
@ -301,15 +301,15 @@ func (c *Core) startClusterListener(ctx context.Context) error {
c.logger.Debug("starting cluster listeners")
c.clusterListener = cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener"))
c.clusterListener.Store(cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener")))
err := c.clusterListener.Run(ctx)
err := c.getClusterListener().Run(ctx)
if err != nil {
return err
}
if strings.HasSuffix(c.ClusterAddr(), ":0") {
// If we listened on port 0, record the port the OS gave us.
c.clusterAddr.Store(fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0]))
c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addrs()[0]))
}
return nil
}
@ -318,18 +318,28 @@ func (c *Core) ClusterAddr() string {
return c.clusterAddr.Load().(string)
}
// getClusterListener returns the cluster listener currently held in the
// c.clusterListener atomic value. It returns nil when nothing has been
// stored yet; a stored nil *cluster.Listener is returned unchanged, so
// callers can compare the result against nil in either case.
func (c *Core) getClusterListener() *cluster.Listener {
	// Load exactly once so the nil check and the type assertion observe
	// the same stored value.
	if stored := c.clusterListener.Load(); stored != nil {
		return stored.(*cluster.Listener)
	}
	return nil
}
// stopClusterListener stops any existing listeners during seal. It is
// assumed that the state lock is held while this is run.
func (c *Core) stopClusterListener() {
if c.clusterListener == nil {
clusterListener := c.getClusterListener()
if clusterListener == nil {
c.logger.Debug("clustering disabled, not stopping listeners")
return
}
c.logger.Info("stopping cluster listeners")
c.clusterListener.Stop()
c.clusterListener = nil
clusterListener.Stop()
var nilCL *cluster.Listener
c.clusterListener.Store(nilCL)
c.logger.Info("cluster listeners successfully shut down")
}
@ -350,11 +360,12 @@ func (c *Core) SetClusterHandler(handler http.Handler) {
// NextProtos.
func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
return func(addr string, timeout time.Duration) (net.Conn, error) {
if c.clusterListener == nil {
clusterListener := c.getClusterListener()
if clusterListener == nil {
return nil, errors.New("clustering disabled")
}
tlsConfig, err := c.clusterListener.TLSConfig(ctx)
tlsConfig, err := clusterListener.TLSConfig(ctx)
if err != nil {
c.logger.Error("failed to get tls configuration", "error", err)
return nil, err

View File

@ -101,8 +101,8 @@ func TestCluster_ListenForRequests(t *testing.T) {
// Wait for core to become active
TestWaitActive(t, cores[0].Core)
cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
addrs := cores[0].clusterListener.Addrs()
cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
addrs := cores[0].getClusterListener().Addrs()
// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
checkListenersFunc := func(expectFail bool) {
@ -157,7 +157,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
// After this period it should be active again
TestWaitActive(t, cores[0].Core)
cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
checkListenersFunc(false)
err = cores[0].Core.Seal(cluster.RootToken)
@ -384,12 +384,12 @@ func TestCluster_CustomCipherSuites(t *testing.T) {
// Wait for core to become active
TestWaitActive(t, core.Core)
core.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
core.getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
dialer := core.getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
netConn, err := dialer(core.clusterListener.Addrs()[0].String(), 0)
netConn, err := dialer(core.getClusterListener().Addrs()[0].String(), 0)
conn := netConn.(*tls.Conn)
if err != nil {
t.Fatal(err)

View File

@ -442,7 +442,7 @@ type Core struct {
loadCaseSensitiveIdentityStore bool
// clusterListener starts up and manages connections on the cluster ports
clusterListener *cluster.Listener
clusterListener *atomic.Value
// Telemetry objects
metricsHelper *metricsutil.MetricsHelper
@ -615,6 +615,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
underlyingPhysical: conf.Physical,
redirectAddr: conf.RedirectAddr,
clusterAddr: new(atomic.Value),
clusterListener: new(atomic.Value),
seal: conf.Seal,
router: NewRouter(),
sealed: new(uint32),
@ -781,6 +782,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
uiStoragePrefix := systemBarrierPrefix + "ui"
c.uiConfig = NewUIConfig(conf.EnableUI, physical.NewView(c.physical, uiStoragePrefix), NewBarrierView(c.barrier, uiStoragePrefix))
c.clusterListener.Store((*cluster.Listener)(nil))
return c, nil
}
@ -1515,7 +1518,7 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, shutdownRaft b
// If the storage backend needs to be sealed
if shutdownRaft {
if raftStorage, ok := c.underlyingPhysical.(*raft.RaftBackend); ok {
if err := raftStorage.TeardownCluster(c.clusterListener); err != nil {
if err := raftStorage.TeardownCluster(c.getClusterListener()); err != nil {
c.logger.Error("error stopping storage cluster", "error", err)
return err
}
@ -1624,7 +1627,7 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
c.auditBroker = NewAuditBroker(c.logger)
}
if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) {
if c.getClusterListener() != nil && (c.ha != nil || shouldStartClusterListener(c)) {
if err := c.setupRaftActiveNode(ctx); err != nil {
return err
}

View File

@ -124,7 +124,7 @@ func (c *Core) startRaftStorage(ctx context.Context) (retErr error) {
raftStorage.SetRestoreCallback(c.raftSnapshotRestoreCallback(true, true))
if err := raftStorage.SetupCluster(ctx, raft.SetupOpts{
TLSKeyring: raftTLS,
ClusterListener: c.clusterListener,
ClusterListener: c.getClusterListener(),
StartAsLeader: creating,
}); err != nil {
return err
@ -133,7 +133,7 @@ func (c *Core) startRaftStorage(ctx context.Context) (retErr error) {
defer func() {
if retErr != nil {
c.logger.Info("stopping raft server")
if err := raftStorage.TeardownCluster(c.clusterListener); err != nil {
if err := raftStorage.TeardownCluster(c.getClusterListener()); err != nil {
c.logger.Error("failed to stop raft server", "error", err)
}
}
@ -715,7 +715,7 @@ func (c *Core) joinRaftSendAnswer(ctx context.Context, leaderClient *api.Client,
raftStorage.SetRestoreCallback(c.raftSnapshotRestoreCallback(true, true))
err = raftStorage.SetupCluster(ctx, raft.SetupOpts{
TLSKeyring: answerResp.Data.TLSKeyring,
ClusterListener: c.clusterListener,
ClusterListener: c.getClusterListener(),
})
if err != nil {
return errwrap.Wrapf("failed to setup raft cluster: {{err}}", err)

View File

@ -189,7 +189,7 @@ func (c *Core) startForwarding(ctx context.Context) error {
c.clearForwardingClients()
c.requestForwardingConnectionLock.Unlock()
if c.ha == nil || c.clusterListener == nil {
if c.ha == nil || c.getClusterListener() == nil {
c.logger.Debug("request forwarding not setup")
return nil
}
@ -199,20 +199,20 @@ func (c *Core) startForwarding(ctx context.Context) error {
return err
}
handler, err := NewRequestForwardingHandler(c, c.clusterListener.Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
handler, err := NewRequestForwardingHandler(c, c.getClusterListener().Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
if err != nil {
return err
}
c.clusterListener.AddHandler(consts.RequestForwardingALPN, handler)
c.getClusterListener().AddHandler(consts.RequestForwardingALPN, handler)
return nil
}
func (c *Core) stopForwarding() {
if c.clusterListener != nil {
c.clusterListener.StopHandler(consts.RequestForwardingALPN)
c.clusterListener.StopHandler(consts.PerfStandbyALPN)
if c.getClusterListener() != nil {
c.getClusterListener().StopHandler(consts.RequestForwardingALPN)
c.getClusterListener().StopHandler(consts.PerfStandbyALPN)
}
c.removeAllPerfStandbySecondaries()
}
@ -247,8 +247,9 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
return errors.New("no request forwarding cluster certificate found")
}
if c.clusterListener != nil {
c.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
clusterListener := c.getClusterListener()
if clusterListener != nil {
clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
core: c,
})
}
@ -302,8 +303,9 @@ func (c *Core) clearForwardingClients() {
c.rpcClientConnContext = nil
c.rpcForwardingClient = nil
if c.clusterListener != nil {
c.clusterListener.RemoveClient(consts.RequestForwardingALPN)
clusterListener := c.getClusterListener()
if clusterListener != nil {
clusterListener.RemoveClient(consts.RequestForwardingALPN)
}
c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
}