From b3189a566acd4b481d2920866f4b6230af9461f7 Mon Sep 17 00:00:00 2001 From: Frank Schroeder Date: Tue, 27 Jun 2017 15:25:25 +0200 Subject: [PATCH] rpc: refactor sessionTimers and fix racy tests The sessionTimers map was secured by a lock which wasn't used properly in the tests. This led to data races and failing tests when accessing the length or the members of the map. This patch adds a separate SessionTimers struct which is safe for concurrent use and which encapsulates the behavior of the sessionTimers map. --- agent/consul/server.go | 4 +- agent/consul/session_endpoint_test.go | 6 +- agent/consul/session_timers.go | 82 +++++++++++++++++++ agent/consul/session_timers_test.go | 105 +++++++++++++++++++++++++ agent/consul/session_ttl.go | 54 +++---------- agent/consul/session_ttl_test.go | 100 +++++++++++------------ agent/consul/snapshot_endpoint_test.go | 8 +- 7 files changed, 253 insertions(+), 106 deletions(-) create mode 100644 agent/consul/session_timers.go create mode 100644 agent/consul/session_timers_test.go diff --git a/agent/consul/server.go b/agent/consul/server.go index d6f00dd3d..e5f9ac1dd 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -172,8 +172,7 @@ type Server struct { // sessionTimers track the expiration time of each Session that has // a TTL. On expiration, a SessionDestroy event will occur, and // destroy the session via standard session destroy processing - sessionTimers map[string]*time.Timer - sessionTimersLock sync.Mutex + sessionTimers *SessionTimers // statsFetcher is used by autopilot to check the status of the other // Consul servers. 
@@ -296,6 +295,7 @@ func NewServerLogger(config *Config, logger *log.Logger) (*Server, error) { rpcServer: rpc.NewServer(), rpcTLS: incomingTLS, reassertLeaderCh: make(chan chan error), + sessionTimers: NewSessionTimers(), tombstoneGC: gc, shutdownCh: shutdownCh, } diff --git a/agent/consul/session_endpoint_test.go b/agent/consul/session_endpoint_test.go index d10f5c0ec..776e95a5f 100644 --- a/agent/consul/session_endpoint_test.go +++ b/agent/consul/session_endpoint_test.go @@ -514,7 +514,7 @@ func TestSession_ApplyTimers(t *testing.T) { } // Check the session map - if _, ok := s1.sessionTimers[out]; !ok { + if s1.sessionTimers.Get(out) == nil { t.Fatalf("missing session timer") } @@ -526,7 +526,7 @@ func TestSession_ApplyTimers(t *testing.T) { } // Check the session map - if _, ok := s1.sessionTimers[out]; ok { + if s1.sessionTimers.Get(out) != nil { t.Fatalf("session timer exists") } } @@ -564,7 +564,7 @@ func TestSession_Renew(t *testing.T) { } // Verify the timer map is setup - if len(s1.sessionTimers) != 5 { + if s1.sessionTimers.Len() != 5 { t.Fatalf("missing session timers") } diff --git a/agent/consul/session_timers.go b/agent/consul/session_timers.go new file mode 100644 index 000000000..227d43e85 --- /dev/null +++ b/agent/consul/session_timers.go @@ -0,0 +1,82 @@ +package consul + +import ( + "sync" + "time" +) + +// SessionTimers provides a map of named timers which +// is safe for concurrent use. +type SessionTimers struct { + sync.RWMutex + m map[string]*time.Timer +} + +func NewSessionTimers() *SessionTimers { + return &SessionTimers{m: make(map[string]*time.Timer)} +} + +// Get returns the timer with the given id or nil. +func (t *SessionTimers) Get(id string) *time.Timer { + t.RLock() + defer t.RUnlock() + return t.m[id] +} + +// Set stores the timer under the given id. If tm is nil the timer +// with the given id is removed. 
+func (t *SessionTimers) Set(id string, tm *time.Timer) { + t.Lock() + defer t.Unlock() + if tm == nil { + // todo(fs): shouldn't we call Stop() here? + delete(t.m, id) + } else { + t.m[id] = tm + } +} + +// Del removes the timer with the given id. +func (t *SessionTimers) Del(id string) { + t.Set(id, nil) +} + +// Len returns the number of registered timers. +func (t *SessionTimers) Len() int { + t.RLock() + defer t.RUnlock() + return len(t.m) +} + +// ResetOrCreate sets the ttl of the timer with the given id or creates a new +// one if it does not exist. +func (t *SessionTimers) ResetOrCreate(id string, ttl time.Duration, afterFunc func()) { + t.Lock() + defer t.Unlock() + + if tm := t.m[id]; tm != nil { + tm.Reset(ttl) + return + } + t.m[id] = time.AfterFunc(ttl, afterFunc) +} + +// Stop stops the timer with the given id and removes it. +func (t *SessionTimers) Stop(id string) { + t.Lock() + defer t.Unlock() + if tm := t.m[id]; tm != nil { + tm.Stop() + delete(t.m, id) + } +} + +// StopAll stops and removes all registered timers. 
+func (t *SessionTimers) StopAll() { + t.Lock() + defer t.Unlock() + for _, tm := range t.m { + tm.Stop() + } + t.m = make(map[string]*time.Timer) +} diff --git a/agent/consul/session_timers_test.go b/agent/consul/session_timers_test.go new file mode 100644 index 000000000..8f49763f7 --- /dev/null +++ b/agent/consul/session_timers_test.go @@ -0,0 +1,105 @@ +package consul + +import ( + "testing" + "time" +) + +func TestSessionTimers(t *testing.T) { + m := NewSessionTimers() + ch := make(chan int) + newTm := func(d time.Duration) *time.Timer { + return time.AfterFunc(d, func() { ch <- 1 }) + } + + waitForTimer := func() { + select { + case <-ch: + return + case <-time.After(100 * time.Millisecond): + t.Fatal("timer did not fire") + } + } + + // check that non-existent id returns nil + if got, want := m.Get("foo"), (*time.Timer)(nil); got != want { + t.Fatalf("got %v want %v", got, want) + } + + // add a timer and look it up and delete via Set(id, nil) + tm := newTm(time.Millisecond) + m.Set("foo", tm) + if got, want := m.Len(), 1; got != want { + t.Fatalf("got len %d want %d", got, want) + } + if got, want := m.Get("foo"), tm; got != want { + t.Fatalf("got %v want %v", got, want) + } + m.Set("foo", nil) + if got, want := m.Get("foo"), (*time.Timer)(nil); got != want { + t.Fatalf("got %v want %v", got, want) + } + waitForTimer() + + // same thing via Del(id) + tm = newTm(time.Millisecond) + m.Set("foo", tm) + if got, want := m.Get("foo"), tm; got != want { + t.Fatalf("got %v want %v", got, want) + } + m.Del("foo") + if got, want := m.Len(), 0; got != want { + t.Fatalf("got len %d want %d", got, want) + } + waitForTimer() + + // create timer via ResetOrCreate + m.ResetOrCreate("foo", time.Millisecond, func() { ch <- 1 }) + if got, want := m.Len(), 1; got != want { + t.Fatalf("got len %d want %d", got, want) + } + waitForTimer() + + // timer is still there + if got, want := m.Len(), 1; got != want { + t.Fatalf("got len %d want %d", got, want) + } + + // reset the timer 
and check that it fires again + m.ResetOrCreate("foo", time.Millisecond, nil) + waitForTimer() + + // reset the timer with a long ttl and then stop it + m.ResetOrCreate("foo", 20*time.Millisecond, func() { ch <- 1 }) + m.Stop("foo") + select { + case <-ch: + t.Fatal("timer fired although it shouldn't") + case <-time.After(100 * time.Millisecond): + // want + } + + // stopping a stopped timer should not break + m.Stop("foo") + + // stop should also remove the timer + if got, want := m.Len(), 0; got != want { + t.Fatalf("got len %d want %d", got, want) + } + + // create two timers and stop and then stop all + m.ResetOrCreate("foo1", 20*time.Millisecond, func() { ch <- 1 }) + m.ResetOrCreate("foo2", 30*time.Millisecond, func() { ch <- 2 }) + m.StopAll() + select { + case x := <-ch: + t.Fatalf("timer %d fired although it shouldn't", x) + case <-time.After(100 * time.Millisecond): + // want + } + + // stopall should remove all timers + if got, want := m.Len(), 0; got != want { + t.Fatalf("got len %d want %d", got, want) + } +} diff --git a/agent/consul/session_ttl.go b/agent/consul/session_ttl.go index 2f4819ac3..35441d954 100644 --- a/agent/consul/session_ttl.go +++ b/agent/consul/session_ttl.go @@ -66,49 +66,28 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { return nil } - // Reset the session timer - s.sessionTimersLock.Lock() - defer s.sessionTimersLock.Unlock() - s.resetSessionTimerLocked(id, ttl) + s.createSessionTimer(session.ID, ttl) return nil } -// resetSessionTimerLocked is used to reset a session timer -// assuming the sessionTimerLock is already held -func (s *Server) resetSessionTimerLocked(id string, ttl time.Duration) { - // Ensure a timer map exists - if s.sessionTimers == nil { - s.sessionTimers = make(map[string]*time.Timer) - } - +func (s *Server) createSessionTimer(id string, ttl time.Duration) { + // Reset the session timer // Adjust the given TTL by the TTL multiplier. 
This is done // to give a client a grace period and to compensate for network // and processing delays. The contract is that a session is not expired // before the TTL, but there is no explicit promise about the upper // bound so this is allowable. ttl = ttl * structs.SessionTTLMultiplier - - // Renew the session timer if it exists - if timer, ok := s.sessionTimers[id]; ok { - timer.Reset(ttl) - return - } - - // Create a new timer to track expiration of this ssession - timer := time.AfterFunc(ttl, func() { - s.invalidateSession(id) - }) - s.sessionTimers[id] = timer + s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id) }) } // invalidateSession is invoked when a session TTL is reached and we // need to invalidate the session. func (s *Server) invalidateSession(id string) { defer metrics.MeasureSince([]string{"consul", "session_ttl", "invalidate"}, time.Now()) + // Clear the session timer - s.sessionTimersLock.Lock() - delete(s.sessionTimers, id) - s.sessionTimersLock.Unlock() + s.sessionTimers.Del(id) // Create a session destroy request args := structs.SessionRequest{ @@ -137,26 +116,14 @@ func (s *Server) invalidateSession(id string) { // a single session. This is used when a session is destroyed // explicitly and no longer needed. func (s *Server) clearSessionTimer(id string) error { - s.sessionTimersLock.Lock() - defer s.sessionTimersLock.Unlock() - - if timer, ok := s.sessionTimers[id]; ok { - timer.Stop() - delete(s.sessionTimers, id) - } + s.sessionTimers.Stop(id) return nil } // clearAllSessionTimers is used when a leader is stepping // down and we no longer need to track any session timers. 
func (s *Server) clearAllSessionTimers() error { - s.sessionTimersLock.Lock() - defer s.sessionTimersLock.Unlock() - - for _, t := range s.sessionTimers { - t.Stop() - } - s.sessionTimers = nil + s.sessionTimers.StopAll() return nil } @@ -166,10 +133,7 @@ func (s *Server) sessionStats() { for { select { case <-time.After(5 * time.Second): - s.sessionTimersLock.Lock() - num := len(s.sessionTimers) - s.sessionTimersLock.Unlock() - metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(num)) + metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(s.sessionTimers.Len())) case <-s.shutdownCh: return diff --git a/agent/consul/session_ttl_test.go b/agent/consul/session_ttl_test.go index 5cf75129d..86e9855a6 100644 --- a/agent/consul/session_ttl_test.go +++ b/agent/consul/session_ttl_test.go @@ -39,8 +39,7 @@ func TestInitializeSessionTimers(t *testing.T) { } // Check that we have a timer - _, ok := s1.sessionTimers[session.ID] - if !ok { + if s1.sessionTimers.Get(session.ID) == nil { t.Fatalf("missing session timer") } } @@ -79,8 +78,7 @@ func TestResetSessionTimer_Fault(t *testing.T) { } // Check that we have a timer - _, ok := s1.sessionTimers[session.ID] - if !ok { + if s1.sessionTimers.Get(session.ID) == nil { t.Fatalf("missing session timer") } } @@ -113,8 +111,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) { } // Check that we have a timer - _, ok := s1.sessionTimers[session.ID] - if ok { + if s1.sessionTimers.Get(session.ID) != nil { t.Fatalf("should not have session timer") } } @@ -145,17 +142,13 @@ func TestResetSessionTimerLocked(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - s1.sessionTimersLock.Lock() - s1.resetSessionTimerLocked("foo", 5*time.Millisecond) - s1.sessionTimersLock.Unlock() - - if _, ok := s1.sessionTimers["foo"]; !ok { + s1.createSessionTimer("foo", 5*time.Millisecond) + if s1.sessionTimers.Get("foo") == nil { t.Fatalf("missing timer") } time.Sleep(10 * time.Millisecond * 
structs.SessionTTLMultiplier) - - if _, ok := s1.sessionTimers["foo"]; ok { + if s1.sessionTimers.Get("foo") != nil { t.Fatalf("timer should be gone") } } @@ -165,39 +158,46 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + ttl := 100 * time.Millisecond - s1.sessionTimersLock.Lock() - s1.resetSessionTimerLocked("foo", 5*time.Millisecond) - s1.sessionTimersLock.Unlock() - - if _, ok := s1.sessionTimers["foo"]; !ok { + // create the timer + s1.createSessionTimer("foo", ttl) + if s1.sessionTimers.Get("foo") == nil { t.Fatalf("missing timer") } - time.Sleep(5 * time.Millisecond) + // wait until it is "expired" but at this point + // the session still exists. + time.Sleep(ttl) + if s1.sessionTimers.Get("foo") == nil { + t.Fatal("missing timer") + } - // Renew the session - s1.sessionTimersLock.Lock() - renew := time.Now() - s1.resetSessionTimerLocked("foo", 5*time.Millisecond) - s1.sessionTimersLock.Unlock() + // renew the session which will reset the TTL to 2*ttl + // since that is the current SessionTTLMultiplier + s1.createSessionTimer("foo", ttl) // Watch for invalidation - for time.Now().Sub(renew) < 20*time.Millisecond { - s1.sessionTimersLock.Lock() - _, ok := s1.sessionTimers["foo"] - s1.sessionTimersLock.Unlock() - if !ok { - end := time.Now() - if end.Sub(renew) < 5*time.Millisecond { - t.Fatalf("early invalidate") - } - return + renew := time.Now() + deadline := renew.Add(2 * structs.SessionTTLMultiplier * ttl) + for { + now := time.Now() + if now.After(deadline) { + t.Fatal("should have expired by now") } - time.Sleep(time.Millisecond) + + // timer still exists + if s1.sessionTimers.Get("foo") != nil { + time.Sleep(time.Millisecond) + continue + } + + // timer gone + if now.Sub(renew) < ttl { + t.Fatalf("early invalidate") + } + break } - t.Fatalf("should have expired") } func TestInvalidateSession(t *testing.T) { @@ -239,16 +239,14 @@ func 
TestClearSessionTimer(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - s1.sessionTimersLock.Lock() - s1.resetSessionTimerLocked("foo", 5*time.Millisecond) - s1.sessionTimersLock.Unlock() + s1.createSessionTimer("foo", 5*time.Millisecond) err := s1.clearSessionTimer("foo") if err != nil { t.Fatalf("err: %v", err) } - if _, ok := s1.sessionTimers["foo"]; ok { + if s1.sessionTimers.Get("foo") != nil { t.Fatalf("timer should be gone") } } @@ -258,18 +256,17 @@ func TestClearAllSessionTimers(t *testing.T) { defer os.RemoveAll(dir1) defer s1.Shutdown() - s1.sessionTimersLock.Lock() - s1.resetSessionTimerLocked("foo", 10*time.Millisecond) - s1.resetSessionTimerLocked("bar", 10*time.Millisecond) - s1.resetSessionTimerLocked("baz", 10*time.Millisecond) - s1.sessionTimersLock.Unlock() + s1.createSessionTimer("foo", 10*time.Millisecond) + s1.createSessionTimer("bar", 10*time.Millisecond) + s1.createSessionTimer("baz", 10*time.Millisecond) err := s1.clearAllSessionTimers() if err != nil { t.Fatalf("err: %v", err) } - if len(s1.sessionTimers) != 0 { + // sessionTimers is guarded by the lock + if s1.sessionTimers.Len() != 0 { t.Fatalf("timers should be gone") } } @@ -297,7 +294,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) { var leader *Server for _, s := range servers { // Check that s.sessionTimers is empty - if len(s.sessionTimers) != 0 { + if s.sessionTimers.Len() != 0 { t.Fatalf("should have no sessionTimers") } // Find the leader too @@ -338,7 +335,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) { } // Check that sessionTimers has the session ID - if _, ok := leader.sessionTimers[id1]; !ok { + if leader.sessionTimers.Get(id1) == nil { t.Fatalf("missing session timer") } @@ -346,12 +343,11 @@ func TestServer_SessionTTL_Failover(t *testing.T) { leader.Shutdown() // sessionTimers should be cleared on leader shutdown - if len(leader.sessionTimers) != 0 { + if leader.sessionTimers.Len() != 0 { t.Fatalf("session timers should be empty on the 
shutdown leader") } // Find the new leader retry.Run(t, func(r *retry.R) { - leader = nil for _, s := range servers { if s.IsLeader() { @@ -363,7 +359,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) { } // Ensure session timer is restored - if _, ok := leader.sessionTimers[id1]; !ok { + if leader.sessionTimers.Get(id1) == nil { r.Fatal("missing session timer") } }) diff --git a/agent/consul/snapshot_endpoint_test.go b/agent/consul/snapshot_endpoint_test.go index d745d158d..eba80bc92 100644 --- a/agent/consul/snapshot_endpoint_test.go +++ b/agent/consul/snapshot_endpoint_test.go @@ -211,10 +211,10 @@ func TestSnapshot_LeaderState(t *testing.T) { } // Make sure the leader has timers setup. - if _, ok := s1.sessionTimers[before]; !ok { + if s1.sessionTimers.Get(before) == nil { t.Fatalf("missing session timer") } - if _, ok := s1.sessionTimers[after]; !ok { + if s1.sessionTimers.Get(after) == nil { t.Fatalf("missing session timer") } @@ -229,10 +229,10 @@ func TestSnapshot_LeaderState(t *testing.T) { // Make sure the before time is still there, and that the after timer // got reverted. This proves we fully cycled the leader state. - if _, ok := s1.sessionTimers[before]; !ok { + if s1.sessionTimers.Get(before) == nil { t.Fatalf("missing session timer") } - if _, ok := s1.sessionTimers[after]; ok { + if s1.sessionTimers.Get(after) != nil { t.Fatalf("unexpected session timer") } }