rpc: refactor sessionTimers and fix racy tests

The sessionTimers map was secured by a lock which wasn't used
properly in the tests. This lead to data races and failing tests
when accessing the length or the members of the map.

This patch adds a separate SessionTimers struct which is safe
for concurrent use and which ecapsulates the behavior of the
sessionTimers map.
This commit is contained in:
Frank Schroeder 2017-06-27 15:25:25 +02:00 committed by Frank Schröder
parent 06ad8e96be
commit b3189a566a
7 changed files with 253 additions and 106 deletions

View File

@ -172,8 +172,7 @@ type Server struct {
// sessionTimers track the expiration time of each Session that has // sessionTimers track the expiration time of each Session that has
// a TTL. On expiration, a SessionDestroy event will occur, and // a TTL. On expiration, a SessionDestroy event will occur, and
// destroy the session via standard session destroy processing // destroy the session via standard session destroy processing
sessionTimers map[string]*time.Timer sessionTimers *SessionTimers
sessionTimersLock sync.Mutex
// statsFetcher is used by autopilot to check the status of the other // statsFetcher is used by autopilot to check the status of the other
// Consul servers. // Consul servers.
@ -296,6 +295,7 @@ func NewServerLogger(config *Config, logger *log.Logger) (*Server, error) {
rpcServer: rpc.NewServer(), rpcServer: rpc.NewServer(),
rpcTLS: incomingTLS, rpcTLS: incomingTLS,
reassertLeaderCh: make(chan chan error), reassertLeaderCh: make(chan chan error),
sessionTimers: NewSessionTimers(),
tombstoneGC: gc, tombstoneGC: gc,
shutdownCh: shutdownCh, shutdownCh: shutdownCh,
} }

View File

@ -514,7 +514,7 @@ func TestSession_ApplyTimers(t *testing.T) {
} }
// Check the session map // Check the session map
if _, ok := s1.sessionTimers[out]; !ok { if s1.sessionTimers.Get(out) == nil {
t.Fatalf("missing session timer") t.Fatalf("missing session timer")
} }
@ -526,7 +526,7 @@ func TestSession_ApplyTimers(t *testing.T) {
} }
// Check the session map // Check the session map
if _, ok := s1.sessionTimers[out]; ok { if s1.sessionTimers.Get(out) != nil {
t.Fatalf("session timer exists") t.Fatalf("session timer exists")
} }
} }
@ -564,7 +564,7 @@ func TestSession_Renew(t *testing.T) {
} }
// Verify the timer map is setup // Verify the timer map is setup
if len(s1.sessionTimers) != 5 { if s1.sessionTimers.Len() != 5 {
t.Fatalf("missing session timers") t.Fatalf("missing session timers")
} }

View File

@ -0,0 +1,82 @@
package consul
import (
"sync"
"time"
)
// SessionTimers provides a map of named timers which
// is safe for concurrent use.
type SessionTimers struct {
sync.RWMutex
m map[string]*time.Timer
}
func NewSessionTimers() *SessionTimers {
return &SessionTimers{m: make(map[string]*time.Timer)}
}
// Get returns the timer with the given id or nil.
func (t *SessionTimers) Get(id string) *time.Timer {
t.RLock()
defer t.RUnlock()
return t.m[id]
}
// Set stores the timer under given id. If tm is nil the timer
// witht the given id is removed.
func (t *SessionTimers) Set(id string, tm *time.Timer) {
t.Lock()
defer t.Unlock()
if tm == nil {
// todo(fs): shouldn't we call Stop() here?
delete(t.m, id)
} else {
t.m[id] = tm
}
}
// Del removes the timer with the given id.
func (t *SessionTimers) Del(id string) {
t.Set(id, nil)
}
// Len returns the number of registered timers.
func (t *SessionTimers) Len() int {
t.RLock()
defer t.RUnlock()
return len(t.m)
}
// ResetOrCreate sets the ttl of the timer with the given id or creates a new
// one if it does not exist.
func (t *SessionTimers) ResetOrCreate(id string, ttl time.Duration, afterFunc func()) {
t.Lock()
defer t.Unlock()
if tm := t.m[id]; tm != nil {
tm.Reset(ttl)
return
}
t.m[id] = time.AfterFunc(ttl, afterFunc)
}
// Stop stops the timer with the given id and removes it.
func (t *SessionTimers) Stop(id string) {
t.Lock()
defer t.Unlock()
if tm := t.m[id]; tm != nil {
tm.Stop()
delete(t.m, id)
}
}
// StopAll stops and removes all registered timers.
func (t *SessionTimers) StopAll() {
t.Lock()
defer t.Unlock()
for _, tm := range t.m {
tm.Stop()
}
t.m = make(map[string]*time.Timer)
}

View File

@ -0,0 +1,105 @@
package consul
import (
"testing"
"time"
)
func TestSessionTimers(t *testing.T) {
m := NewSessionTimers()
ch := make(chan int)
newTm := func(d time.Duration) *time.Timer {
return time.AfterFunc(d, func() { ch <- 1 })
}
waitForTimer := func() {
select {
case <-ch:
return
case <-time.After(100 * time.Millisecond):
t.Fatal("timer did not fire")
}
}
// check that non-existent id returns nil
if got, want := m.Get("foo"), (*time.Timer)(nil); got != want {
t.Fatalf("got %v want %v", got, want)
}
// add a timer and look it up and delete via Set(id, nil)
tm := newTm(time.Millisecond)
m.Set("foo", tm)
if got, want := m.Len(), 1; got != want {
t.Fatalf("got len %d want %d", got, want)
}
if got, want := m.Get("foo"), tm; got != want {
t.Fatalf("got %v want %v", got, want)
}
m.Set("foo", nil)
if got, want := m.Get("foo"), (*time.Timer)(nil); got != want {
t.Fatalf("got %v want %v", got, want)
}
waitForTimer()
// same thing via Del(id)
tm = newTm(time.Millisecond)
m.Set("foo", tm)
if got, want := m.Get("foo"), tm; got != want {
t.Fatalf("got %v want %v", got, want)
}
m.Del("foo")
if got, want := m.Len(), 0; got != want {
t.Fatalf("got len %d want %d", got, want)
}
waitForTimer()
// create timer via ResetOrCreate
m.ResetOrCreate("foo", time.Millisecond, func() { ch <- 1 })
if got, want := m.Len(), 1; got != want {
t.Fatalf("got len %d want %d", got, want)
}
waitForTimer()
// timer is still there
if got, want := m.Len(), 1; got != want {
t.Fatalf("got len %d want %d", got, want)
}
// reset the timer and check that it fires again
m.ResetOrCreate("foo", time.Millisecond, nil)
waitForTimer()
// reset the timer with a long ttl and then stop it
m.ResetOrCreate("foo", 20*time.Millisecond, func() { ch <- 1 })
m.Stop("foo")
select {
case <-ch:
t.Fatal("timer fired although it shouldn't")
case <-time.After(100 * time.Millisecond):
// want
}
// stopping a stopped timer should not break
m.Stop("foo")
// stop should also remove the timer
if got, want := m.Len(), 0; got != want {
t.Fatalf("got len %d want %d", got, want)
}
// create two timers and stop and then stop all
m.ResetOrCreate("foo1", 20*time.Millisecond, func() { ch <- 1 })
m.ResetOrCreate("foo2", 30*time.Millisecond, func() { ch <- 2 })
m.StopAll()
select {
case x := <-ch:
t.Fatalf("timer %d fired although it shouldn't", x)
case <-time.After(100 * time.Millisecond):
// want
}
// stopall should remove all timers
if got, want := m.Len(), 0; got != want {
t.Fatalf("got len %d want %d", got, want)
}
}

View File

@ -66,49 +66,28 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
return nil return nil
} }
// Reset the session timer s.createSessionTimer(session.ID, ttl)
s.sessionTimersLock.Lock()
defer s.sessionTimersLock.Unlock()
s.resetSessionTimerLocked(id, ttl)
return nil return nil
} }
// resetSessionTimerLocked is used to reset a session timer func (s *Server) createSessionTimer(id string, ttl time.Duration) {
// assuming the sessionTimerLock is already held // Reset the session timer
func (s *Server) resetSessionTimerLocked(id string, ttl time.Duration) {
// Ensure a timer map exists
if s.sessionTimers == nil {
s.sessionTimers = make(map[string]*time.Timer)
}
// Adjust the given TTL by the TTL multiplier. This is done // Adjust the given TTL by the TTL multiplier. This is done
// to give a client a grace period and to compensate for network // to give a client a grace period and to compensate for network
// and processing delays. The contract is that a session is not expired // and processing delays. The contract is that a session is not expired
// before the TTL, but there is no explicit promise about the upper // before the TTL, but there is no explicit promise about the upper
// bound so this is allowable. // bound so this is allowable.
ttl = ttl * structs.SessionTTLMultiplier ttl = ttl * structs.SessionTTLMultiplier
s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id) })
// Renew the session timer if it exists
if timer, ok := s.sessionTimers[id]; ok {
timer.Reset(ttl)
return
}
// Create a new timer to track expiration of this ssession
timer := time.AfterFunc(ttl, func() {
s.invalidateSession(id)
})
s.sessionTimers[id] = timer
} }
// invalidateSession is invoked when a session TTL is reached and we // invalidateSession is invoked when a session TTL is reached and we
// need to invalidate the session. // need to invalidate the session.
func (s *Server) invalidateSession(id string) { func (s *Server) invalidateSession(id string) {
defer metrics.MeasureSince([]string{"consul", "session_ttl", "invalidate"}, time.Now()) defer metrics.MeasureSince([]string{"consul", "session_ttl", "invalidate"}, time.Now())
// Clear the session timer // Clear the session timer
s.sessionTimersLock.Lock() s.sessionTimers.Del(id)
delete(s.sessionTimers, id)
s.sessionTimersLock.Unlock()
// Create a session destroy request // Create a session destroy request
args := structs.SessionRequest{ args := structs.SessionRequest{
@ -137,26 +116,14 @@ func (s *Server) invalidateSession(id string) {
// a single session. This is used when a session is destroyed // a single session. This is used when a session is destroyed
// explicitly and no longer needed. // explicitly and no longer needed.
func (s *Server) clearSessionTimer(id string) error { func (s *Server) clearSessionTimer(id string) error {
s.sessionTimersLock.Lock() s.sessionTimers.Stop(id)
defer s.sessionTimersLock.Unlock()
if timer, ok := s.sessionTimers[id]; ok {
timer.Stop()
delete(s.sessionTimers, id)
}
return nil return nil
} }
// clearAllSessionTimers is used when a leader is stepping // clearAllSessionTimers is used when a leader is stepping
// down and we no longer need to track any session timers. // down and we no longer need to track any session timers.
func (s *Server) clearAllSessionTimers() error { func (s *Server) clearAllSessionTimers() error {
s.sessionTimersLock.Lock() s.sessionTimers.StopAll()
defer s.sessionTimersLock.Unlock()
for _, t := range s.sessionTimers {
t.Stop()
}
s.sessionTimers = nil
return nil return nil
} }
@ -166,10 +133,7 @@ func (s *Server) sessionStats() {
for { for {
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
s.sessionTimersLock.Lock() metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(s.sessionTimers.Len()))
num := len(s.sessionTimers)
s.sessionTimersLock.Unlock()
metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(num))
case <-s.shutdownCh: case <-s.shutdownCh:
return return

View File

@ -39,8 +39,7 @@ func TestInitializeSessionTimers(t *testing.T) {
} }
// Check that we have a timer // Check that we have a timer
_, ok := s1.sessionTimers[session.ID] if s1.sessionTimers.Get(session.ID) == nil {
if !ok {
t.Fatalf("missing session timer") t.Fatalf("missing session timer")
} }
} }
@ -79,8 +78,7 @@ func TestResetSessionTimer_Fault(t *testing.T) {
} }
// Check that we have a timer // Check that we have a timer
_, ok := s1.sessionTimers[session.ID] if s1.sessionTimers.Get(session.ID) == nil {
if !ok {
t.Fatalf("missing session timer") t.Fatalf("missing session timer")
} }
} }
@ -113,8 +111,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) {
} }
// Check that we have a timer // Check that we have a timer
_, ok := s1.sessionTimers[session.ID] if s1.sessionTimers.Get(session.ID) != nil {
if ok {
t.Fatalf("should not have session timer") t.Fatalf("should not have session timer")
} }
} }
@ -145,17 +142,13 @@ func TestResetSessionTimerLocked(t *testing.T) {
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")
s1.sessionTimersLock.Lock() s1.createSessionTimer("foo", 5*time.Millisecond)
s1.resetSessionTimerLocked("foo", 5*time.Millisecond) if s1.sessionTimers.Get("foo") == nil {
s1.sessionTimersLock.Unlock()
if _, ok := s1.sessionTimers["foo"]; !ok {
t.Fatalf("missing timer") t.Fatalf("missing timer")
} }
time.Sleep(10 * time.Millisecond * structs.SessionTTLMultiplier) time.Sleep(10 * time.Millisecond * structs.SessionTTLMultiplier)
if s1.sessionTimers.Get("foo") != nil {
if _, ok := s1.sessionTimers["foo"]; ok {
t.Fatalf("timer should be gone") t.Fatalf("timer should be gone")
} }
} }
@ -165,39 +158,46 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) {
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1") ttl := 100 * time.Millisecond
s1.sessionTimersLock.Lock() // create the timer
s1.resetSessionTimerLocked("foo", 5*time.Millisecond) s1.createSessionTimer("foo", ttl)
s1.sessionTimersLock.Unlock() if s1.sessionTimers.Get("foo") == nil {
if _, ok := s1.sessionTimers["foo"]; !ok {
t.Fatalf("missing timer") t.Fatalf("missing timer")
} }
time.Sleep(5 * time.Millisecond) // wait until it is "expired" but at this point
// the session still exists.
time.Sleep(ttl)
if s1.sessionTimers.Get("foo") == nil {
t.Fatal("missing timer")
}
// Renew the session // renew the session which will reset the TTL to 2*ttl
s1.sessionTimersLock.Lock() // since that is the current SessionTTLMultiplier
renew := time.Now() s1.createSessionTimer("foo", ttl)
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
s1.sessionTimersLock.Unlock()
// Watch for invalidation // Watch for invalidation
for time.Now().Sub(renew) < 20*time.Millisecond { renew := time.Now()
s1.sessionTimersLock.Lock() deadline := renew.Add(2 * structs.SessionTTLMultiplier * ttl)
_, ok := s1.sessionTimers["foo"] for {
s1.sessionTimersLock.Unlock() now := time.Now()
if !ok { if now.After(deadline) {
end := time.Now() t.Fatal("should have expired by now")
if end.Sub(renew) < 5*time.Millisecond { }
// timer still exists
if s1.sessionTimers.Get("foo") != nil {
time.Sleep(time.Millisecond)
continue
}
// timer gone
if now.Sub(renew) < ttl {
t.Fatalf("early invalidate") t.Fatalf("early invalidate")
} }
return break
} }
time.Sleep(time.Millisecond)
}
t.Fatalf("should have expired")
} }
func TestInvalidateSession(t *testing.T) { func TestInvalidateSession(t *testing.T) {
@ -239,16 +239,14 @@ func TestClearSessionTimer(t *testing.T) {
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
s1.sessionTimersLock.Lock() s1.createSessionTimer("foo", 5*time.Millisecond)
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
s1.sessionTimersLock.Unlock()
err := s1.clearSessionTimer("foo") err := s1.clearSessionTimer("foo")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if _, ok := s1.sessionTimers["foo"]; ok { if s1.sessionTimers.Get("foo") != nil {
t.Fatalf("timer should be gone") t.Fatalf("timer should be gone")
} }
} }
@ -258,18 +256,17 @@ func TestClearAllSessionTimers(t *testing.T) {
defer os.RemoveAll(dir1) defer os.RemoveAll(dir1)
defer s1.Shutdown() defer s1.Shutdown()
s1.sessionTimersLock.Lock() s1.createSessionTimer("foo", 10*time.Millisecond)
s1.resetSessionTimerLocked("foo", 10*time.Millisecond) s1.createSessionTimer("bar", 10*time.Millisecond)
s1.resetSessionTimerLocked("bar", 10*time.Millisecond) s1.createSessionTimer("baz", 10*time.Millisecond)
s1.resetSessionTimerLocked("baz", 10*time.Millisecond)
s1.sessionTimersLock.Unlock()
err := s1.clearAllSessionTimers() err := s1.clearAllSessionTimers()
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if len(s1.sessionTimers) != 0 { // sessionTimers is guarded by the lock
if s1.sessionTimers.Len() != 0 {
t.Fatalf("timers should be gone") t.Fatalf("timers should be gone")
} }
} }
@ -297,7 +294,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
var leader *Server var leader *Server
for _, s := range servers { for _, s := range servers {
// Check that s.sessionTimers is empty // Check that s.sessionTimers is empty
if len(s.sessionTimers) != 0 { if s.sessionTimers.Len() != 0 {
t.Fatalf("should have no sessionTimers") t.Fatalf("should have no sessionTimers")
} }
// Find the leader too // Find the leader too
@ -338,7 +335,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
} }
// Check that sessionTimers has the session ID // Check that sessionTimers has the session ID
if _, ok := leader.sessionTimers[id1]; !ok { if leader.sessionTimers.Get(id1) == nil {
t.Fatalf("missing session timer") t.Fatalf("missing session timer")
} }
@ -346,12 +343,11 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
leader.Shutdown() leader.Shutdown()
// sessionTimers should be cleared on leader shutdown // sessionTimers should be cleared on leader shutdown
if len(leader.sessionTimers) != 0 { if leader.sessionTimers.Len() != 0 {
t.Fatalf("session timers should be empty on the shutdown leader") t.Fatalf("session timers should be empty on the shutdown leader")
} }
// Find the new leader // Find the new leader
retry.Run(t, func(r *retry.R) { retry.Run(t, func(r *retry.R) {
leader = nil leader = nil
for _, s := range servers { for _, s := range servers {
if s.IsLeader() { if s.IsLeader() {
@ -363,7 +359,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
} }
// Ensure session timer is restored // Ensure session timer is restored
if _, ok := leader.sessionTimers[id1]; !ok { if leader.sessionTimers.Get(id1) == nil {
r.Fatal("missing session timer") r.Fatal("missing session timer")
} }
}) })

View File

@ -211,10 +211,10 @@ func TestSnapshot_LeaderState(t *testing.T) {
} }
// Make sure the leader has timers setup. // Make sure the leader has timers setup.
if _, ok := s1.sessionTimers[before]; !ok { if s1.sessionTimers.Get(before) == nil {
t.Fatalf("missing session timer") t.Fatalf("missing session timer")
} }
if _, ok := s1.sessionTimers[after]; !ok { if s1.sessionTimers.Get(after) == nil {
t.Fatalf("missing session timer") t.Fatalf("missing session timer")
} }
@ -229,10 +229,10 @@ func TestSnapshot_LeaderState(t *testing.T) {
// Make sure the before time is still there, and that the after timer // Make sure the before time is still there, and that the after timer
// got reverted. This proves we fully cycled the leader state. // got reverted. This proves we fully cycled the leader state.
if _, ok := s1.sessionTimers[before]; !ok { if s1.sessionTimers.Get(before) == nil {
t.Fatalf("missing session timer") t.Fatalf("missing session timer")
} }
if _, ok := s1.sessionTimers[after]; ok { if s1.sessionTimers.Get(after) != nil {
t.Fatalf("unexpected session timer") t.Fatalf("unexpected session timer")
} }
} }