rpc: refactor sessionTimers and fix racy tests

The sessionTimers map was secured by a lock which wasn't used
properly in the tests. This led to data races and failing tests
when accessing the length or the members of the map.

This patch adds a separate SessionTimers struct which is safe
for concurrent use and which encapsulates the behavior of the
sessionTimers map.
This commit is contained in:
Frank Schroeder 2017-06-27 15:25:25 +02:00 committed by Frank Schröder
parent 06ad8e96be
commit b3189a566a
7 changed files with 253 additions and 106 deletions

View File

@ -172,8 +172,7 @@ type Server struct {
// sessionTimers track the expiration time of each Session that has
// a TTL. On expiration, a SessionDestroy event will occur, and
// destroy the session via standard session destroy processing
sessionTimers map[string]*time.Timer
sessionTimersLock sync.Mutex
sessionTimers *SessionTimers
// statsFetcher is used by autopilot to check the status of the other
// Consul servers.
@ -296,6 +295,7 @@ func NewServerLogger(config *Config, logger *log.Logger) (*Server, error) {
rpcServer: rpc.NewServer(),
rpcTLS: incomingTLS,
reassertLeaderCh: make(chan chan error),
sessionTimers: NewSessionTimers(),
tombstoneGC: gc,
shutdownCh: shutdownCh,
}

View File

@ -514,7 +514,7 @@ func TestSession_ApplyTimers(t *testing.T) {
}
// Check the session map
if _, ok := s1.sessionTimers[out]; !ok {
if s1.sessionTimers.Get(out) == nil {
t.Fatalf("missing session timer")
}
@ -526,7 +526,7 @@ func TestSession_ApplyTimers(t *testing.T) {
}
// Check the session map
if _, ok := s1.sessionTimers[out]; ok {
if s1.sessionTimers.Get(out) != nil {
t.Fatalf("session timer exists")
}
}
@ -564,7 +564,7 @@ func TestSession_Renew(t *testing.T) {
}
// Verify the timer map is setup
if len(s1.sessionTimers) != 5 {
if s1.sessionTimers.Len() != 5 {
t.Fatalf("missing session timers")
}

View File

@ -0,0 +1,82 @@
package consul
import (
"sync"
"time"
)
// SessionTimers provides a map of named timers which
// is safe for concurrent use.
//
// The zero value is an empty set of timers that is ready to use; the
// backing map is allocated on the first write.
type SessionTimers struct {
	sync.RWMutex
	m map[string]*time.Timer
}

// NewSessionTimers returns an empty SessionTimers that is ready to use.
func NewSessionTimers() *SessionTimers {
	return &SessionTimers{m: make(map[string]*time.Timer)}
}

// ensureMapLocked allocates the backing map if it does not exist yet.
// The caller must hold the write lock.
func (t *SessionTimers) ensureMapLocked() {
	if t.m == nil {
		t.m = make(map[string]*time.Timer)
	}
}

// Get returns the timer with the given id or nil.
func (t *SessionTimers) Get(id string) *time.Timer {
	t.RLock()
	defer t.RUnlock()
	return t.m[id]
}

// Set stores the timer under the given id. If tm is nil the timer
// with the given id is removed.
//
// Removing an entry this way does not stop its timer, so a pending
// timer still fires. Use Stop to stop a timer and remove it.
func (t *SessionTimers) Set(id string, tm *time.Timer) {
	t.Lock()
	defer t.Unlock()
	if tm == nil {
		// todo(fs): shouldn't we call Stop() here?
		delete(t.m, id)
		return
	}
	// Writing to a nil map panics, so make sure the map exists first.
	t.ensureMapLocked()
	t.m[id] = tm
}

// Del removes the timer with the given id. Like Set(id, nil) it does
// not stop the timer.
func (t *SessionTimers) Del(id string) {
	t.Set(id, nil)
}

// Len returns the number of registered timers.
func (t *SessionTimers) Len() int {
	t.RLock()
	defer t.RUnlock()
	return len(t.m)
}

// ResetOrCreate sets the ttl of the timer with the given id or creates a new
// one if it does not exist.
//
// afterFunc is only used when a new timer is created. An existing timer is
// reset to ttl and keeps the function it was created with.
func (t *SessionTimers) ResetOrCreate(id string, ttl time.Duration, afterFunc func()) {
	t.Lock()
	defer t.Unlock()
	if tm := t.m[id]; tm != nil {
		tm.Reset(ttl)
		return
	}
	// Writing to a nil map panics, so make sure the map exists first.
	t.ensureMapLocked()
	t.m[id] = time.AfterFunc(ttl, afterFunc)
}

// Stop stops the timer with the given id and removes it.
// It does nothing if no timer is registered under id.
func (t *SessionTimers) Stop(id string) {
	t.Lock()
	defer t.Unlock()
	if tm := t.m[id]; tm != nil {
		tm.Stop()
		delete(t.m, id)
	}
}

// StopAll stops and removes all registered timers.
func (t *SessionTimers) StopAll() {
	t.Lock()
	defer t.Unlock()
	for _, tm := range t.m {
		tm.Stop()
	}
	t.m = make(map[string]*time.Timer)
}

View File

@ -0,0 +1,105 @@
package consul
import (
"testing"
"time"
)
// TestSessionTimers walks through the SessionTimers API in sequence:
// lookup of a missing id, Set/Del, ResetOrCreate (create and reset),
// Stop and StopAll. The steps share one map and one channel, so their
// order matters.
func TestSessionTimers(t *testing.T) {
	timers := NewSessionTimers()
	fired := make(chan int)

	// notify is the callback used by most timers in this test.
	notify := func() { fired <- 1 }

	// newTm creates a standalone timer that signals on the channel.
	newTm := func(d time.Duration) *time.Timer {
		return time.AfterFunc(d, notify)
	}

	// waitForTimer fails the test unless a timer signals within 100ms.
	waitForTimer := func() {
		select {
		case <-fired:
			return
		case <-time.After(100 * time.Millisecond):
			t.Fatal("timer did not fire")
		}
	}

	// expectTimer checks the timer registered under id.
	expectTimer := func(id string, want *time.Timer) {
		if got := timers.Get(id); got != want {
			t.Fatalf("got %v want %v", got, want)
		}
	}

	// expectLen checks the number of registered timers.
	expectLen := func(want int) {
		if got := timers.Len(); got != want {
			t.Fatalf("got len %d want %d", got, want)
		}
	}

	var none *time.Timer

	// a non-existent id yields nil
	expectTimer("foo", none)

	// register a timer, look it up, and remove it via Set(id, nil);
	// the removed timer is not stopped and still fires
	tm := newTm(time.Millisecond)
	timers.Set("foo", tm)
	expectLen(1)
	expectTimer("foo", tm)
	timers.Set("foo", nil)
	expectTimer("foo", none)
	waitForTimer()

	// the same, but removing via Del(id)
	tm = newTm(time.Millisecond)
	timers.Set("foo", tm)
	expectTimer("foo", tm)
	timers.Del("foo")
	expectLen(0)
	waitForTimer()

	// ResetOrCreate creates the timer when it does not exist
	timers.ResetOrCreate("foo", time.Millisecond, notify)
	expectLen(1)
	waitForTimer()

	// firing does not unregister the timer
	expectLen(1)

	// resetting the existing timer makes it fire again
	timers.ResetOrCreate("foo", time.Millisecond, nil)
	waitForTimer()

	// reset with a longer ttl, then stop before it can fire
	timers.ResetOrCreate("foo", 20*time.Millisecond, notify)
	timers.Stop("foo")
	select {
	case <-fired:
		t.Fatal("timer fired although it shouldn't")
	case <-time.After(100 * time.Millisecond):
		// expected: nothing fired
	}

	// a second Stop on the same id must be harmless
	timers.Stop("foo")

	// Stop also unregisters the timer
	expectLen(0)

	// register two timers and stop them all at once
	timers.ResetOrCreate("foo1", 20*time.Millisecond, func() { fired <- 1 })
	timers.ResetOrCreate("foo2", 30*time.Millisecond, func() { fired <- 2 })
	timers.StopAll()
	select {
	case which := <-fired:
		t.Fatalf("timer %d fired although it shouldn't", which)
	case <-time.After(100 * time.Millisecond):
		// expected: nothing fired
	}

	// StopAll unregisters every timer
	expectLen(0)
}

View File

@ -66,49 +66,28 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
return nil
}
// Reset the session timer
s.sessionTimersLock.Lock()
defer s.sessionTimersLock.Unlock()
s.resetSessionTimerLocked(id, ttl)
s.createSessionTimer(session.ID, ttl)
return nil
}
// resetSessionTimerLocked is used to reset a session timer
// assuming the sessionTimerLock is already held
func (s *Server) resetSessionTimerLocked(id string, ttl time.Duration) {
// Ensure a timer map exists
if s.sessionTimers == nil {
s.sessionTimers = make(map[string]*time.Timer)
}
func (s *Server) createSessionTimer(id string, ttl time.Duration) {
// Reset the session timer
// Adjust the given TTL by the TTL multiplier. This is done
// to give a client a grace period and to compensate for network
// and processing delays. The contract is that a session is not expired
// before the TTL, but there is no explicit promise about the upper
// bound so this is allowable.
ttl = ttl * structs.SessionTTLMultiplier
// Renew the session timer if it exists
if timer, ok := s.sessionTimers[id]; ok {
timer.Reset(ttl)
return
}
// Create a new timer to track expiration of this ssession
timer := time.AfterFunc(ttl, func() {
s.invalidateSession(id)
})
s.sessionTimers[id] = timer
s.sessionTimers.ResetOrCreate(id, ttl, func() { s.invalidateSession(id) })
}
// invalidateSession is invoked when a session TTL is reached and we
// need to invalidate the session.
func (s *Server) invalidateSession(id string) {
defer metrics.MeasureSince([]string{"consul", "session_ttl", "invalidate"}, time.Now())
// Clear the session timer
s.sessionTimersLock.Lock()
delete(s.sessionTimers, id)
s.sessionTimersLock.Unlock()
s.sessionTimers.Del(id)
// Create a session destroy request
args := structs.SessionRequest{
@ -137,26 +116,14 @@ func (s *Server) invalidateSession(id string) {
// a single session. This is used when a session is destroyed
// explicitly and no longer needed.
func (s *Server) clearSessionTimer(id string) error {
s.sessionTimersLock.Lock()
defer s.sessionTimersLock.Unlock()
if timer, ok := s.sessionTimers[id]; ok {
timer.Stop()
delete(s.sessionTimers, id)
}
s.sessionTimers.Stop(id)
return nil
}
// clearAllSessionTimers is used when a leader is stepping
// down and we no longer need to track any session timers.
func (s *Server) clearAllSessionTimers() error {
s.sessionTimersLock.Lock()
defer s.sessionTimersLock.Unlock()
for _, t := range s.sessionTimers {
t.Stop()
}
s.sessionTimers = nil
s.sessionTimers.StopAll()
return nil
}
@ -166,10 +133,7 @@ func (s *Server) sessionStats() {
for {
select {
case <-time.After(5 * time.Second):
s.sessionTimersLock.Lock()
num := len(s.sessionTimers)
s.sessionTimersLock.Unlock()
metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(num))
metrics.SetGauge([]string{"consul", "session_ttl", "active"}, float32(s.sessionTimers.Len()))
case <-s.shutdownCh:
return

View File

@ -39,8 +39,7 @@ func TestInitializeSessionTimers(t *testing.T) {
}
// Check that we have a timer
_, ok := s1.sessionTimers[session.ID]
if !ok {
if s1.sessionTimers.Get(session.ID) == nil {
t.Fatalf("missing session timer")
}
}
@ -79,8 +78,7 @@ func TestResetSessionTimer_Fault(t *testing.T) {
}
// Check that we have a timer
_, ok := s1.sessionTimers[session.ID]
if !ok {
if s1.sessionTimers.Get(session.ID) == nil {
t.Fatalf("missing session timer")
}
}
@ -113,8 +111,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) {
}
// Check that we have a timer
_, ok := s1.sessionTimers[session.ID]
if ok {
if s1.sessionTimers.Get(session.ID) != nil {
t.Fatalf("should not have session timer")
}
}
@ -145,17 +142,13 @@ func TestResetSessionTimerLocked(t *testing.T) {
testrpc.WaitForLeader(t, s1.RPC, "dc1")
s1.sessionTimersLock.Lock()
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
s1.sessionTimersLock.Unlock()
if _, ok := s1.sessionTimers["foo"]; !ok {
s1.createSessionTimer("foo", 5*time.Millisecond)
if s1.sessionTimers.Get("foo") == nil {
t.Fatalf("missing timer")
}
time.Sleep(10 * time.Millisecond * structs.SessionTTLMultiplier)
if _, ok := s1.sessionTimers["foo"]; ok {
if s1.sessionTimers.Get("foo") != nil {
t.Fatalf("timer should be gone")
}
}
@ -165,39 +158,46 @@ func TestResetSessionTimerLocked_Renew(t *testing.T) {
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
ttl := 100 * time.Millisecond
s1.sessionTimersLock.Lock()
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
s1.sessionTimersLock.Unlock()
if _, ok := s1.sessionTimers["foo"]; !ok {
// create the timer
s1.createSessionTimer("foo", ttl)
if s1.sessionTimers.Get("foo") == nil {
t.Fatalf("missing timer")
}
time.Sleep(5 * time.Millisecond)
// wait until it is "expired" but at this point
// the session still exists.
time.Sleep(ttl)
if s1.sessionTimers.Get("foo") == nil {
t.Fatal("missing timer")
}
// Renew the session
s1.sessionTimersLock.Lock()
renew := time.Now()
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
s1.sessionTimersLock.Unlock()
// renew the session which will reset the TTL to 2*ttl
// since that is the current SessionTTLMultiplier
s1.createSessionTimer("foo", ttl)
// Watch for invalidation
for time.Now().Sub(renew) < 20*time.Millisecond {
s1.sessionTimersLock.Lock()
_, ok := s1.sessionTimers["foo"]
s1.sessionTimersLock.Unlock()
if !ok {
end := time.Now()
if end.Sub(renew) < 5*time.Millisecond {
t.Fatalf("early invalidate")
}
return
renew := time.Now()
deadline := renew.Add(2 * structs.SessionTTLMultiplier * ttl)
for {
now := time.Now()
if now.After(deadline) {
t.Fatal("should have expired by now")
}
time.Sleep(time.Millisecond)
// timer still exists
if s1.sessionTimers.Get("foo") != nil {
time.Sleep(time.Millisecond)
continue
}
// timer gone
if now.Sub(renew) < ttl {
t.Fatalf("early invalidate")
}
break
}
t.Fatalf("should have expired")
}
func TestInvalidateSession(t *testing.T) {
@ -239,16 +239,14 @@ func TestClearSessionTimer(t *testing.T) {
defer os.RemoveAll(dir1)
defer s1.Shutdown()
s1.sessionTimersLock.Lock()
s1.resetSessionTimerLocked("foo", 5*time.Millisecond)
s1.sessionTimersLock.Unlock()
s1.createSessionTimer("foo", 5*time.Millisecond)
err := s1.clearSessionTimer("foo")
if err != nil {
t.Fatalf("err: %v", err)
}
if _, ok := s1.sessionTimers["foo"]; ok {
if s1.sessionTimers.Get("foo") != nil {
t.Fatalf("timer should be gone")
}
}
@ -258,18 +256,17 @@ func TestClearAllSessionTimers(t *testing.T) {
defer os.RemoveAll(dir1)
defer s1.Shutdown()
s1.sessionTimersLock.Lock()
s1.resetSessionTimerLocked("foo", 10*time.Millisecond)
s1.resetSessionTimerLocked("bar", 10*time.Millisecond)
s1.resetSessionTimerLocked("baz", 10*time.Millisecond)
s1.sessionTimersLock.Unlock()
s1.createSessionTimer("foo", 10*time.Millisecond)
s1.createSessionTimer("bar", 10*time.Millisecond)
s1.createSessionTimer("baz", 10*time.Millisecond)
err := s1.clearAllSessionTimers()
if err != nil {
t.Fatalf("err: %v", err)
}
if len(s1.sessionTimers) != 0 {
// sessionTimers is guarded by the lock
if s1.sessionTimers.Len() != 0 {
t.Fatalf("timers should be gone")
}
}
@ -297,7 +294,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
var leader *Server
for _, s := range servers {
// Check that s.sessionTimers is empty
if len(s.sessionTimers) != 0 {
if s.sessionTimers.Len() != 0 {
t.Fatalf("should have no sessionTimers")
}
// Find the leader too
@ -338,7 +335,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
}
// Check that sessionTimers has the session ID
if _, ok := leader.sessionTimers[id1]; !ok {
if leader.sessionTimers.Get(id1) == nil {
t.Fatalf("missing session timer")
}
@ -346,12 +343,11 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
leader.Shutdown()
// sessionTimers should be cleared on leader shutdown
if len(leader.sessionTimers) != 0 {
if leader.sessionTimers.Len() != 0 {
t.Fatalf("session timers should be empty on the shutdown leader")
}
// Find the new leader
retry.Run(t, func(r *retry.R) {
leader = nil
for _, s := range servers {
if s.IsLeader() {
@ -363,7 +359,7 @@ func TestServer_SessionTTL_Failover(t *testing.T) {
}
// Ensure session timer is restored
if _, ok := leader.sessionTimers[id1]; !ok {
if leader.sessionTimers.Get(id1) == nil {
r.Fatal("missing session timer")
}
})

View File

@ -211,10 +211,10 @@ func TestSnapshot_LeaderState(t *testing.T) {
}
// Make sure the leader has timers setup.
if _, ok := s1.sessionTimers[before]; !ok {
if s1.sessionTimers.Get(before) == nil {
t.Fatalf("missing session timer")
}
if _, ok := s1.sessionTimers[after]; !ok {
if s1.sessionTimers.Get(after) == nil {
t.Fatalf("missing session timer")
}
@ -229,10 +229,10 @@ func TestSnapshot_LeaderState(t *testing.T) {
// Make sure the before time is still there, and that the after timer
// got reverted. This proves we fully cycled the leader state.
if _, ok := s1.sessionTimers[before]; !ok {
if s1.sessionTimers.Get(before) == nil {
t.Fatalf("missing session timer")
}
if _, ok := s1.sessionTimers[after]; ok {
if s1.sessionTimers.Get(after) != nil {
t.Fatalf("unexpected session timer")
}
}