From 62e14c280d79aee268202478f6323d340e70488b Mon Sep 17 00:00:00 2001
From: Brian Kassouf
Date: Thu, 27 Jun 2019 10:00:03 -0700
Subject: [PATCH] storage/raft: fix races in tests (#6996)

* storage/raft: fix races in tests

* Fix another test race
---
 helper/testhelpers/testhelpers.go |  3 ++-
 physical/raft/fsm.go              |  8 +++++++-
 physical/raft/raft.go             | 22 +++++++++++++---------
 vault/cluster.go                  | 12 ++++++------
 vault/core.go                     | 13 +++++++------
 vault/ha.go                       |  4 ++--
 vault/init.go                     |  2 +-
 vault/raft.go                     |  9 +++++----
 vault/request_forwarding.go       |  1 +
 vault/request_forwarding_rpc.go   |  7 ++++---
 10 files changed, 48 insertions(+), 33 deletions(-)

diff --git a/helper/testhelpers/testhelpers.go b/helper/testhelpers/testhelpers.go
index 62b48c163..be54594c8 100644
--- a/helper/testhelpers/testhelpers.go
+++ b/helper/testhelpers/testhelpers.go
@@ -12,6 +12,7 @@ import (
 	"os"
 	"reflect"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/hashicorp/go-hclog"
@@ -759,7 +760,7 @@ func RaftClusterJoinNodes(t testing.T, cluster *vault.TestCluster) {
 
 	leaderCore := cluster.Cores[0]
 	leaderAPI := leaderCore.Client.Address()
-	vault.UpdateClusterAddrForTests = true
+	atomic.StoreUint32(&vault.UpdateClusterAddrForTests, 1)
 
 	// Seal the leader so we can install an address provider
 	{
diff --git a/physical/raft/fsm.go b/physical/raft/fsm.go
index e67be47f7..45595447c 100644
--- a/physical/raft/fsm.go
+++ b/physical/raft/fsm.go
@@ -364,6 +364,7 @@ func (f *FSM) Apply(log *raft.Log) interface{} {
 	command := &LogData{}
 	err := proto.Unmarshal(log.Data, command)
 	if err != nil {
+		f.logger.Error("error proto unmarshaling log data", "error", err)
 		panic("error proto unmarshaling log data")
 	}
 
@@ -380,7 +381,8 @@ func (f *FSM) Apply(log *raft.Log) interface{} {
 			Index: log.Index,
 		})
 		if err != nil {
-			panic("failed to store data")
+			f.logger.Error("unable to marshal latest index", "error", err)
+			panic("unable to marshal latest index")
 		}
 	}
 
@@ -418,6 +420,7 @@ func (f *FSM) Apply(log *raft.Log) interface{} {
 		return nil
 	})
 	if err != nil {
+		f.logger.Error("failed to store data", "error", err)
 		panic("failed to store data")
 	}
 
@@ -575,6 +578,7 @@ func (f *FSM) StoreConfiguration(index uint64, configuration raft.Configuration)
 		var err error
 		indexBytes, err = proto.Marshal(latestIndex)
 		if err != nil {
+			f.logger.Error("unable to marshal latest index", "error", err)
 			panic(fmt.Sprintf("unable to marshal latest index: %v", err))
 		}
 	}
@@ -582,6 +586,7 @@ func (f *FSM) StoreConfiguration(index uint64, configuration raft.Configuration)
 	protoConfig := raftConfigurationToProtoConfiguration(index, configuration)
 	configBytes, err := proto.Marshal(protoConfig)
 	if err != nil {
+		f.logger.Error("unable to marshal config", "error", err)
 		panic(fmt.Sprintf("unable to marshal config: %v", err))
 	}
 
@@ -604,6 +609,7 @@ func (f *FSM) StoreConfiguration(index uint64, configuration raft.Configuration)
 		return nil
 	})
 	if err != nil {
+		f.logger.Error("unable to store latest configuration", "error", err)
 		panic(fmt.Sprintf("unable to store latest configuration: %v", err))
 	}
 }
diff --git a/physical/raft/raft.go b/physical/raft/raft.go
index 5db6877b6..af8cc6143 100644
--- a/physical/raft/raft.go
+++ b/physical/raft/raft.go
@@ -111,7 +111,7 @@ func EnsurePath(path string, dir bool) error {
 func NewRaftBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
 	// Create the FSM.
 	var err error
-	fsm, err := NewFSM(conf, logger)
+	fsm, err := NewFSM(conf, logger.Named("fsm"))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create fsm: %v", err)
 	}
@@ -379,8 +379,8 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, raftTLSKeyring *RaftTLSK
 	case raftTLSKeyring == nil && clusterListener == nil:
 		// If we don't have a provided network we use an in-memory one.
 		// This allows us to bootstrap a node without bringing up a cluster
-		// network. This will be true during bootstrap and dev modes.
-		_, b.raftTransport = raft.NewInmemTransport(raft.ServerAddress(b.localID))
+		// network. This will be true during bootstrap, tests and dev modes.
+		_, b.raftTransport = raft.NewInmemTransportWithTimeout(raft.ServerAddress(b.localID), time.Second)
 	case raftTLSKeyring == nil:
 		return errors.New("no keyring provided")
 	case clusterListener == nil:
@@ -819,11 +819,11 @@ type RaftLock struct {
 
 // monitorLeadership waits until we receive an update on the raftNotifyCh and
 // closes the leaderLost channel.
-func (l *RaftLock) monitorLeadership(stopCh <-chan struct{}) <-chan struct{} {
+func (l *RaftLock) monitorLeadership(stopCh <-chan struct{}, leaderNotifyCh <-chan bool) <-chan struct{} {
 	leaderLost := make(chan struct{})
 	go func() {
 		select {
-		case <-l.b.raftNotifyCh:
+		case <-leaderNotifyCh:
 			close(leaderLost)
 		case <-stopCh:
 		}
@@ -834,8 +834,12 @@ func (l *RaftLock) monitorLeadership(stopCh <-chan struct{}) <-chan struct{} {
 // Lock blocks until we become leader or are shutdown. It returns a channel that
 // is closed when we detect a loss of leadership.
 func (l *RaftLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
-	// Check to see if we are already leader.
 	l.b.l.RLock()
+
+	// Cache the notifyCh locally
+	leaderNotifyCh := l.b.raftNotifyCh
+
+	// Check to see if we are already leader.
 	if l.b.raft.State() == raft.Leader {
 		err := l.b.applyLog(context.Background(), &LogData{
 			Operations: []*LogOperation{
@@ -851,13 +855,13 @@ func (l *RaftLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
 			return nil, err
 		}
 
-		return l.monitorLeadership(stopCh), nil
+		return l.monitorLeadership(stopCh, leaderNotifyCh), nil
 	}
 	l.b.l.RUnlock()
 
 	for {
 		select {
-		case isLeader := <-l.b.raftNotifyCh:
+		case isLeader := <-leaderNotifyCh:
 			if isLeader {
 				// We are leader, set the key
 				l.b.l.RLock()
@@ -875,7 +879,7 @@ func (l *RaftLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
 					return nil, err
 				}
 
-				return l.monitorLeadership(stopCh), nil
+				return l.monitorLeadership(stopCh, leaderNotifyCh), nil
 			}
 		case <-stopCh:
 			return nil, nil
diff --git a/vault/cluster.go b/vault/cluster.go
index 627b5f553..3e5a30203 100644
--- a/vault/cluster.go
+++ b/vault/cluster.go
@@ -284,7 +284,7 @@ func (c *Core) setupCluster(ctx context.Context) error {
 // only starts cluster listeners. Once the listener is started handlers/clients
 // can start being registered to it.
 func (c *Core) startClusterListener(ctx context.Context) error {
-	if c.clusterAddr == "" {
+	if c.ClusterAddr() == "" {
 		c.logger.Info("clustering disabled, not starting listeners")
 		return nil
 	}
@@ -307,15 +307,15 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 	if err != nil {
 		return err
 	}
-	if strings.HasSuffix(c.clusterAddr, ":0") {
+	if strings.HasSuffix(c.ClusterAddr(), ":0") {
 		// If we listened on port 0, record the port the OS gave us.
-		c.clusterAddr = fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0])
+		c.clusterAddr.Store(fmt.Sprintf("https://%s", c.clusterListener.Addrs()[0]))
 	}
 	return nil
 }
 
 func (c *Core) ClusterAddr() string {
-	return c.clusterAddr
+	return c.clusterAddr.Load().(string)
 }
 
 // stopClusterListener stops any existing listeners during seal. It is
@@ -336,8 +336,8 @@ func (c *Core) stopClusterListener() {
 
 func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
 	c.clusterListenerAddrs = addrs
-	if c.clusterAddr == "" && len(addrs) == 1 {
-		c.clusterAddr = fmt.Sprintf("https://%s", addrs[0].String())
+	if c.ClusterAddr() == "" && len(addrs) == 1 {
+		c.clusterAddr.Store(fmt.Sprintf("https://%s", addrs[0].String()))
 	}
 }
 
diff --git a/vault/core.go b/vault/core.go
index 758e43f4b..88b2a5101 100644
--- a/vault/core.go
+++ b/vault/core.go
@@ -173,7 +173,7 @@ type Core struct {
 	redirectAddr string
 
 	// clusterAddr is the address we use for clustering
-	clusterAddr string
+	clusterAddr *atomic.Value
 
 	// physical backend is the un-trusted backend with durable data
 	physical physical.Backend
@@ -609,7 +609,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 		physical: conf.Physical,
 		underlyingPhysical: conf.Physical,
 		redirectAddr: conf.RedirectAddr,
-		clusterAddr: conf.ClusterAddr,
+		clusterAddr: new(atomic.Value),
 		seal: conf.Seal,
 		router: NewRouter(),
 		sealed: new(uint32),
@@ -654,7 +654,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 	c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))
 
 	c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
-
+	c.clusterAddr.Store(conf.ClusterAddr)
 	c.activeContextCancelFunc.Store((context.CancelFunc)(nil))
 
 	if conf.ClusterCipherSuites != "" {
@@ -1617,11 +1617,12 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
 	}
 
 	if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) {
+		c.startPeriodicRaftTLSRotate(ctx)
+
 		if err := c.startForwarding(ctx); err != nil {
 			return err
 		}
 
-		c.startPeriodicRaftTLSRotate(ctx)
 	}
 
 	c.clusterParamsLock.Lock()
@@ -1704,10 +1705,10 @@ func (c *Core) preSeal() error {
 	}
 	var result error
 
-	c.stopPeriodicRaftTLSRotate()
-	c.stopForwarding()
+	c.stopForwarding()
+	c.stopPeriodicRaftTLSRotate()
 
 	c.clusterParamsLock.Lock()
 	if err := stopReplication(c); err != nil {
 		result = multierror.Append(result, errwrap.Wrapf("error stopping replication: {{err}}", err))
 	}
diff --git a/vault/ha.go b/vault/ha.go
index 73d0d94c1..020a1543f 100644
--- a/vault/ha.go
+++ b/vault/ha.go
@@ -91,7 +91,7 @@ func (c *Core) Leader() (isLeader bool, leaderAddr, clusterAddr string, err erro
 	// Check if we are the leader
 	if !c.standby {
 		c.stateLock.RUnlock()
-		return true, c.redirectAddr, c.clusterAddr, nil
+		return true, c.redirectAddr, c.ClusterAddr(), nil
 	}
 
 	// Initialize a lock
@@ -877,7 +877,7 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
 	copy(localCert, locCert)
 	adv := &activeAdvertisement{
 		RedirectAddr: c.redirectAddr,
-		ClusterAddr: c.clusterAddr,
+		ClusterAddr: c.ClusterAddr(),
 		ClusterCert: localCert,
 		ClusterKeyParams: keyParams,
 	}
diff --git a/vault/init.go b/vault/init.go
index beaba20d7..dc083d6bc 100644
--- a/vault/init.go
+++ b/vault/init.go
@@ -146,7 +146,7 @@ func (c *Core) Initialize(ctx context.Context, initParams *InitParams) (*InitRes
 
 	// If we have clustered storage, set it up now
 	if raftStorage, ok := c.underlyingPhysical.(*raft.RaftBackend); ok {
-		parsedClusterAddr, err := url.Parse(c.clusterAddr)
+		parsedClusterAddr, err := url.Parse(c.ClusterAddr())
 		if err != nil {
 			return nil, errwrap.Wrapf("error parsing cluster address: {{err}}", err)
 		}
diff --git a/vault/raft.go b/vault/raft.go
index 8c15cafc8..9307f8c15 100644
--- a/vault/raft.go
+++ b/vault/raft.go
@@ -11,6 +11,7 @@ import (
 	"net/url"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	proto "github.com/golang/protobuf/proto"
@@ -145,7 +146,7 @@ func (c *Core) startPeriodicRaftTLSRotate(ctx context.Context) error {
 		return err
 	}
 	for _, server := range raftConfig.Servers {
-		if !server.Leader {
+		if server.NodeID != raftStorage.NodeID() {
 			followerStates.update(server.NodeID, 0)
 		}
 	}
@@ -586,7 +587,7 @@ func (c *Core) JoinRaftCluster(ctx context.Context, leaderAddr string, tlsConfig
 }
 
 // This is used in tests to override the cluster address
-var UpdateClusterAddrForTests bool
+var UpdateClusterAddrForTests uint32
 
 func (c *Core) joinRaftSendAnswer(ctx context.Context, leaderClient *api.Client, challenge *physical.EncryptedBlobInfo, sealAccess seal.Access) error {
 	if challenge == nil {
@@ -607,12 +608,12 @@ func (c *Core) joinRaftSendAnswer(ctx context.Context, leaderClient *api.Client,
 		return errwrap.Wrapf("error decrypting challenge: {{err}}", err)
 	}
 
-	parsedClusterAddr, err := url.Parse(c.clusterAddr)
+	parsedClusterAddr, err := url.Parse(c.ClusterAddr())
 	if err != nil {
 		return errwrap.Wrapf("error parsing cluster address: {{err}}", err)
 	}
 	clusterAddr := parsedClusterAddr.Host
-	if UpdateClusterAddrForTests && strings.HasSuffix(clusterAddr, ":0") {
+	if atomic.LoadUint32(&UpdateClusterAddrForTests) == 1 && strings.HasSuffix(clusterAddr, ":0") {
 		// We are testing and have an address provider, so just create a random
 		// addr, it will be overwritten later.
 		var err error
diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go
index eaea4e04a..eb8f021d1 100644
--- a/vault/request_forwarding.go
+++ b/vault/request_forwarding.go
@@ -60,6 +60,7 @@ func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots ch
 		perfStandbySlots: perfStandbySlots,
 		perfStandbyRepCluster: perfStandbyRepCluster,
 		perfStandbyCache: perfStandbyCache,
+		raftFollowerStates: c.raftFollowerStates,
 	})
 }
 
diff --git a/vault/request_forwarding_rpc.go b/vault/request_forwarding_rpc.go
index 117e7fb4a..4004fdb18 100644
--- a/vault/request_forwarding_rpc.go
+++ b/vault/request_forwarding_rpc.go
@@ -20,6 +20,7 @@ type forwardedRequestRPCServer struct {
 	perfStandbySlots chan struct{}
 	perfStandbyRepCluster *replication.Cluster
 	perfStandbyCache *cache.Cache
+	raftFollowerStates *raftFollowerStates
 }
 
 func (s *forwardedRequestRPCServer) ForwardRequest(ctx context.Context, freq *forwarding.Request) (*forwarding.Response, error) {
@@ -74,8 +75,8 @@ func (s *forwardedRequestRPCServer) Echo(ctx context.Context, in *EchoRequest) (
 		s.core.clusterPeerClusterAddrsCache.Set(in.ClusterAddr, nil, 0)
 	}
 
-	if in.RaftAppliedIndex > 0 && len(in.RaftNodeID) > 0 && s.core.raftFollowerStates != nil {
-		s.core.raftFollowerStates.update(in.RaftNodeID, in.RaftAppliedIndex)
+	if in.RaftAppliedIndex > 0 && len(in.RaftNodeID) > 0 && s.raftFollowerStates != nil {
+		s.raftFollowerStates.update(in.RaftNodeID, in.RaftAppliedIndex)
 	}
 
 	reply := &EchoReply{
@@ -106,7 +107,7 @@ func (c *forwardingClient) startHeartbeat() {
 	go func() {
 		tick := func() {
 			c.core.stateLock.RLock()
-			clusterAddr := c.core.clusterAddr
+			clusterAddr := c.core.ClusterAddr()
 			c.core.stateLock.RUnlock()
 
 			req := &EchoRequest{
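For readers unfamiliar with the pattern, the central fix above is to route every read and write of Core.clusterAddr and of the UpdateClusterAddrForTests flag through sync/atomic. The sketch below is a minimal, self-contained illustration of those two idioms, not Vault code: the core type, newCore, ClusterAddr, updateAddrForTests, and the example address are stand-ins chosen for this example, assuming a string is always stored before the first Load.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// core is a stand-in for vault.Core; only the cluster-address handling is shown.
type core struct {
	// clusterAddr is read by heartbeats and leadership checks while another
	// goroutine rewrites it after binding to port 0, so it lives in an
	// atomic.Value rather than a plain string.
	clusterAddr *atomic.Value
}

func newCore(addr string) *core {
	c := &core{clusterAddr: new(atomic.Value)}
	// Always store a string up front so Load never returns nil.
	c.clusterAddr.Store(addr)
	return c
}

// ClusterAddr mirrors the accessor the patch routes all readers through.
func (c *core) ClusterAddr() string {
	return c.clusterAddr.Load().(string)
}

// updateAddrForTests is a stand-in for vault.UpdateClusterAddrForTests: a
// uint32 toggled with StoreUint32 and read with LoadUint32 instead of a bool.
var updateAddrForTests uint32

func main() {
	c := newCore("")

	var wg sync.WaitGroup
	wg.Add(2)

	// Writer: records the address the OS assigned, as startClusterListener does.
	go func() {
		defer wg.Done()
		c.clusterAddr.Store("https://127.0.0.1:8201")
		atomic.StoreUint32(&updateAddrForTests, 1)
	}()

	// Reader: concurrently consults the address and the test flag.
	go func() {
		defer wg.Done()
		_ = c.ClusterAddr()
		_ = atomic.LoadUint32(&updateAddrForTests) == 1
	}()

	wg.Wait()
	fmt.Println("cluster addr:", c.ClusterAddr())
}

Running this with go run -race stays quiet; swapping the atomic.Value and uint32 for a plain string field and bool reproduces the kind of concurrent read/write report that the patch removes from the raft cluster tests.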