Make clusterListener an atomic.Value to avoid races with getGRPCDialer. (#7408)
parent 8fdb5f62c4
commit ed147b7ae7
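For context, here is a minimal sketch (simplified stand-in names, not the actual Vault code) of the pattern this commit adopts: keep the listener pointer in an atomic.Value, always store the same concrete pointer type (including a typed nil when the listener is unset), and have readers such as getGRPCDialer go through an accessor instead of reading the field directly, so concurrent start/stop and dialing don't race.

// sketch.go: illustrative only; Listener and Core are stand-ins,
// not the real vault/cluster types.
package main

import (
	"fmt"
	"sync/atomic"
)

type Listener struct{ addr string }

type Core struct {
	// clusterListener always holds a *Listener value (possibly a typed nil),
	// so Load/Store never mix types and readers need no mutex.
	clusterListener *atomic.Value
}

func NewCore() *Core {
	c := &Core{clusterListener: new(atomic.Value)}
	// Seed with a typed nil so the first Load has a consistent type to assert.
	c.clusterListener.Store((*Listener)(nil))
	return c
}

func (c *Core) getClusterListener() *Listener {
	cl := c.clusterListener.Load()
	if cl == nil {
		return nil
	}
	return cl.(*Listener) // may still be a nil *Listener, which callers check
}

func main() {
	c := NewCore()
	fmt.Println(c.getClusterListener() == nil) // true: listener not started yet

	c.clusterListener.Store(&Listener{addr: "127.0.0.1:8201"})
	fmt.Println(c.getClusterListener().addr) // 127.0.0.1:8201

	// "Stopping" stores a typed nil rather than assigning the field to nil,
	// mirroring the var nilCL / Store(nilCL) lines in the diff below.
	var nilCL *Listener
	c.clusterListener.Store(nilCL)
	fmt.Println(c.getClusterListener() == nil) // true again
}

The typed-nil Store matters because atomic.Value panics if successive Stores use different concrete types, and a plain Store(nil) also panics; callers therefore compare the returned pointer against nil rather than storing an untyped nil.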
@@ -289,7 +289,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 		return nil
 	}
 
-	if c.clusterListener != nil {
+	if c.getClusterListener() != nil {
 		c.logger.Warn("cluster listener is already started")
 		return nil
 	}
@@ -301,15 +301,15 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 
 	c.logger.Debug("starting cluster listeners")
 
-	c.clusterListener = cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener"))
+	c.clusterListener.Store(cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener")))
 
-	err := c.clusterListener.Run(ctx)
+	err := c.getClusterListener().Run(ctx)
 	if err != nil {
 		return err
 	}
 	if strings.HasSuffix(c.ClusterAddr(), ":0") {
 		// If we listened on port 0, record the port the OS gave us.
-		c.clusterAddr.Store(fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0]))
+		c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addrs()[0]))
 	}
 	return nil
 }
@@ -318,18 +318,28 @@ func (c *Core) ClusterAddr() string {
 	return c.clusterAddr.Load().(string)
 }
 
+func (c *Core) getClusterListener() *cluster.Listener {
+	cl := c.clusterListener.Load()
+	if cl == nil {
+		return nil
+	}
+	return cl.(*cluster.Listener)
+}
+
 // stopClusterListener stops any existing listeners during seal. It is
 // assumed that the state lock is held while this is run.
 func (c *Core) stopClusterListener() {
-	if c.clusterListener == nil {
+	clusterListener := c.getClusterListener()
+	if clusterListener == nil {
 		c.logger.Debug("clustering disabled, not stopping listeners")
 		return
 	}
 
 	c.logger.Info("stopping cluster listeners")
 
-	c.clusterListener.Stop()
-	c.clusterListener = nil
+	clusterListener.Stop()
+	var nilCL *cluster.Listener
+	c.clusterListener.Store(nilCL)
 
 	c.logger.Info("cluster listeners successfully shut down")
 }
@@ -350,11 +360,12 @@ func (c *Core) SetClusterHandler(handler http.Handler) {
 // NextProtos.
 func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
 	return func(addr string, timeout time.Duration) (net.Conn, error) {
-		if c.clusterListener == nil {
+		clusterListener := c.getClusterListener()
+		if clusterListener == nil {
 			return nil, errors.New("clustering disabled")
 		}
 
-		tlsConfig, err := c.clusterListener.TLSConfig(ctx)
+		tlsConfig, err := clusterListener.TLSConfig(ctx)
 		if err != nil {
 			c.logger.Error("failed to get tls configuration", "error", err)
 			return nil, err
@@ -101,8 +101,8 @@ func TestCluster_ListenForRequests(t *testing.T) {
 	// Wait for core to become active
 	TestWaitActive(t, cores[0].Core)
 
-	cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
-	addrs := cores[0].clusterListener.Addrs()
+	cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
+	addrs := cores[0].getClusterListener().Addrs()
 
 	// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
 	checkListenersFunc := func(expectFail bool) {
@@ -157,7 +157,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
 
 	// After this period it should be active again
 	TestWaitActive(t, cores[0].Core)
-	cores[0].clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
+	cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
 	checkListenersFunc(false)
 
 	err = cores[0].Core.Seal(cluster.RootToken)
@@ -384,12 +384,12 @@ func TestCluster_CustomCipherSuites(t *testing.T) {
 	// Wait for core to become active
 	TestWaitActive(t, core.Core)
 
-	core.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
+	core.getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{core.Core})
 
 	parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
 	dialer := core.getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
 
-	netConn, err := dialer(core.clusterListener.Addrs()[0].String(), 0)
+	netConn, err := dialer(core.getClusterListener().Addrs()[0].String(), 0)
 	conn := netConn.(*tls.Conn)
 	if err != nil {
 		t.Fatal(err)
@@ -442,7 +442,7 @@ type Core struct {
 	loadCaseSensitiveIdentityStore bool
 
 	// clusterListener starts up and manages connections on the cluster ports
-	clusterListener *cluster.Listener
+	clusterListener *atomic.Value
 
 	// Telemetry objects
 	metricsHelper *metricsutil.MetricsHelper
@@ -615,6 +615,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 		underlyingPhysical: conf.Physical,
 		redirectAddr: conf.RedirectAddr,
 		clusterAddr: new(atomic.Value),
+		clusterListener: new(atomic.Value),
 		seal: conf.Seal,
 		router: NewRouter(),
 		sealed: new(uint32),
@@ -781,6 +782,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 	uiStoragePrefix := systemBarrierPrefix + "ui"
 	c.uiConfig = NewUIConfig(conf.EnableUI, physical.NewView(c.physical, uiStoragePrefix), NewBarrierView(c.barrier, uiStoragePrefix))
 
+	c.clusterListener.Store((*cluster.Listener)(nil))
+
 	return c, nil
 }
 
@@ -1515,7 +1518,7 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, shutdownRaft b
 	// If the storage backend needs to be sealed
 	if shutdownRaft {
 		if raftStorage, ok := c.underlyingPhysical.(*raft.RaftBackend); ok {
-			if err := raftStorage.TeardownCluster(c.clusterListener); err != nil {
+			if err := raftStorage.TeardownCluster(c.getClusterListener()); err != nil {
				c.logger.Error("error stopping storage cluster", "error", err)
				return err
			}
@@ -1624,7 +1627,7 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
 		c.auditBroker = NewAuditBroker(c.logger)
 	}
 
-	if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) {
+	if c.getClusterListener() != nil && (c.ha != nil || shouldStartClusterListener(c)) {
 		if err := c.setupRaftActiveNode(ctx); err != nil {
 			return err
 		}
@@ -124,7 +124,7 @@ func (c *Core) startRaftStorage(ctx context.Context) (retErr error) {
 	raftStorage.SetRestoreCallback(c.raftSnapshotRestoreCallback(true, true))
 	if err := raftStorage.SetupCluster(ctx, raft.SetupOpts{
 		TLSKeyring: raftTLS,
-		ClusterListener: c.clusterListener,
+		ClusterListener: c.getClusterListener(),
 		StartAsLeader: creating,
 	}); err != nil {
 		return err
@@ -133,7 +133,7 @@ func (c *Core) startRaftStorage(ctx context.Context) (retErr error) {
 	defer func() {
 		if retErr != nil {
 			c.logger.Info("stopping raft server")
-			if err := raftStorage.TeardownCluster(c.clusterListener); err != nil {
+			if err := raftStorage.TeardownCluster(c.getClusterListener()); err != nil {
				c.logger.Error("failed to stop raft server", "error", err)
			}
		}
@@ -715,7 +715,7 @@ func (c *Core) joinRaftSendAnswer(ctx context.Context, leaderClient *api.Client,
 	raftStorage.SetRestoreCallback(c.raftSnapshotRestoreCallback(true, true))
 	err = raftStorage.SetupCluster(ctx, raft.SetupOpts{
 		TLSKeyring: answerResp.Data.TLSKeyring,
-		ClusterListener: c.clusterListener,
+		ClusterListener: c.getClusterListener(),
 	})
 	if err != nil {
 		return errwrap.Wrapf("failed to setup raft cluster: {{err}}", err)
@@ -189,7 +189,7 @@ func (c *Core) startForwarding(ctx context.Context) error {
 	c.clearForwardingClients()
 	c.requestForwardingConnectionLock.Unlock()
 
-	if c.ha == nil || c.clusterListener == nil {
+	if c.ha == nil || c.getClusterListener() == nil {
 		c.logger.Debug("request forwarding not setup")
 		return nil
 	}
@@ -199,20 +199,20 @@ func (c *Core) startForwarding(ctx context.Context) error {
 		return err
 	}
 
-	handler, err := NewRequestForwardingHandler(c, c.clusterListener.Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
+	handler, err := NewRequestForwardingHandler(c, c.getClusterListener().Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
 	if err != nil {
 		return err
 	}
 
-	c.clusterListener.AddHandler(consts.RequestForwardingALPN, handler)
+	c.getClusterListener().AddHandler(consts.RequestForwardingALPN, handler)
 
 	return nil
 }
 
 func (c *Core) stopForwarding() {
-	if c.clusterListener != nil {
-		c.clusterListener.StopHandler(consts.RequestForwardingALPN)
-		c.clusterListener.StopHandler(consts.PerfStandbyALPN)
+	if c.getClusterListener() != nil {
+		c.getClusterListener().StopHandler(consts.RequestForwardingALPN)
+		c.getClusterListener().StopHandler(consts.PerfStandbyALPN)
 	}
 	c.removeAllPerfStandbySecondaries()
 }
@@ -247,8 +247,9 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 		return errors.New("no request forwarding cluster certificate found")
 	}
 
-	if c.clusterListener != nil {
-		c.clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
+	clusterListener := c.getClusterListener()
+	if clusterListener != nil {
+		clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
			core: c,
		})
	}
@@ -302,8 +303,9 @@ func (c *Core) clearForwardingClients() {
 	c.rpcClientConnContext = nil
 	c.rpcForwardingClient = nil
 
-	if c.clusterListener != nil {
-		c.clusterListener.RemoveClient(consts.RequestForwardingALPN)
+	clusterListener := c.getClusterListener()
+	if clusterListener != nil {
+		clusterListener.RemoveClient(consts.RequestForwardingALPN)
 	}
 	c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
 }