From 4d0903f781bb30b05991fc9af089e11b91966d9c Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Fri, 12 Dec 2014 21:42:59 -0800 Subject: [PATCH] consul: Adding more tests for session TTLs --- consul/session_ttl.go | 11 +- consul/session_ttl_test.go | 362 +++++++++++++++++++++++++++++-------- consul/state_store_test.go | 2 +- 3 files changed, 287 insertions(+), 88 deletions(-) diff --git a/consul/session_ttl.go b/consul/session_ttl.go index 5fad0f9eb..1d9bad93c 100644 --- a/consul/session_ttl.go +++ b/consul/session_ttl.go @@ -11,9 +11,6 @@ import ( // a new map to track session expiration and to reset all the timers from // the previously known set of timers. func (s *Server) initializeSessionTimers() error { - s.sessionTimersLock.Lock() - defer s.sessionTimersLock.Unlock() - // Scan all sessions and reset their timer state := s.fsm.State() _, sessions, err := state.SessionList() @@ -45,9 +42,9 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { session = s } - // Bail if the session has no TTL + // Bail if the session has no TTL, fast-path some common inputs switch session.TTL { - case "", "0s", "0m", "0h": + case "", "0", "0s", "0m", "0h": return nil } @@ -62,8 +59,8 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error { // Reset the session timer s.sessionTimersLock.Lock() + defer s.sessionTimersLock.Unlock() s.resetSessionTimerLocked(id, ttl) - s.sessionTimersLock.Unlock() return nil } @@ -111,7 +108,7 @@ func (s *Server) invalidateSession(id string) { ID: id, }, } - s.logger.Printf("[DEBUG] consul.state: Invalidating session %s due to TTL timeout", id) + s.logger.Printf("[DEBUG] consul.state: Session %s TTL expired", id) // Apply the update to destroy the session if _, err := s.raftApply(structs.SessionRequestType, args); err != nil { diff --git a/consul/session_ttl_test.go b/consul/session_ttl_test.go index d26264f03..8ae75dddc 100644 --- a/consul/session_ttl_test.go +++ b/consul/session_ttl_test.go @@ -1,9 +1,9 @@ package consul import ( - "errors" "fmt" "os" + "strings" "testing" "time" @@ -11,7 +11,259 @@ import ( "github.com/hashicorp/consul/testutil" ) -func TestServer_sessionTTL(t *testing.T) { +func TestInitializeSessionTimers(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testutil.WaitForLeader(t, s1.RPC, "dc1") + + state := s1.fsm.State() + state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + session := &structs.Session{ + ID: generateUUID(), + Node: "foo", + TTL: "10s", + } + if err := state.SessionCreate(100, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Reset the session timers + err := s1.initializeSessionTimers() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Check that we have a timer + _, ok := s1.sessionTimers[session.ID] + if !ok { + t.Fatalf("missing session timer") + } +} + +func TestResetSessionTimer_Fault(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testutil.WaitForLeader(t, s1.RPC, "dc1") + + // Should not exist + err := s1.resetSessionTimer("nope", nil) + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err: %v", err) + } + + // Create a session + state := s1.fsm.State() + state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + session := &structs.Session{ + ID: generateUUID(), + Node: "foo", + TTL: "10s", + } + if err := state.SessionCreate(100, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Reset the session timer + err = s1.resetSessionTimer(session.ID, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Check that we have a timer + _, ok := s1.sessionTimers[session.ID] + if !ok { + t.Fatalf("missing session timer") + } +} + +func TestResetSessionTimer_NoTTL(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testutil.WaitForLeader(t, s1.RPC, "dc1") + + // Create a session + state := s1.fsm.State() + state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + session := &structs.Session{ + ID: generateUUID(), + Node: "foo", + TTL: "0000s", + } + if err := state.SessionCreate(100, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Reset the session timer + err := s1.resetSessionTimer(session.ID, session) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Check that we have a timer + _, ok := s1.sessionTimers[session.ID] + if ok { + t.Fatalf("should not have session timer") + } +} + +func TestResetSessionTimer_InvalidTTL(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + // Create a session + session := &structs.Session{ + ID: generateUUID(), + Node: "foo", + TTL: "foo", + } + + // Reset the session timer + err := s1.resetSessionTimer(session.ID, session) + if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") { + t.Fatalf("err: %v", err) + } +} + +func TestResetSessionTimerLocked(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + s1.sessionTimersLock.Lock() + s1.resetSessionTimerLocked("foo", 5*time.Millisecond) + s1.sessionTimersLock.Unlock() + + if _, ok := s1.sessionTimers["foo"]; !ok { + t.Fatalf("missing timer") + } + + time.Sleep(10 * time.Millisecond * structs.SessionTTLMultiplier) + + if _, ok := s1.sessionTimers["foo"]; ok { + t.Fatalf("timer should be gone") + } +} + +func TestResetSessionTimerLocked_Renew(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testutil.WaitForLeader(t, s1.RPC, "dc1") + + s1.sessionTimersLock.Lock() + s1.resetSessionTimerLocked("foo", 5*time.Millisecond) + s1.sessionTimersLock.Unlock() + + if _, ok := s1.sessionTimers["foo"]; !ok { + t.Fatalf("missing timer") + } + + time.Sleep(5 * time.Millisecond) + + // Renew the session + s1.sessionTimersLock.Lock() + renew := time.Now() + s1.resetSessionTimerLocked("foo", 5*time.Millisecond) + s1.sessionTimersLock.Unlock() + + // Watch for invalidation + for time.Now().Sub(renew) < 20*time.Millisecond { + s1.sessionTimersLock.Lock() + _, ok := s1.sessionTimers["foo"] + s1.sessionTimersLock.Unlock() + if !ok { + end := time.Now() + if end.Sub(renew) < 5*time.Millisecond { + t.Fatalf("early invalidate") + } + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("should have expired") +} + +func TestInvalidateSession(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + testutil.WaitForLeader(t, s1.RPC, "dc1") + + // Create a session + state := s1.fsm.State() + state.EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + session := &structs.Session{ + ID: generateUUID(), + Node: "foo", + TTL: "10s", + } + if err := state.SessionCreate(100, session); err != nil { + t.Fatalf("err: %v", err) + } + + // This should cause a destroy + s1.invalidateSession(session.ID) + + // Check it is gone + _, sess, err := state.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if sess != nil { + t.Fatalf("should destroy session") + } +} + +func TestClearSessionTimer(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + s1.sessionTimersLock.Lock() + s1.resetSessionTimerLocked("foo", 5*time.Millisecond) + s1.sessionTimersLock.Unlock() + + err := s1.clearSessionTimer("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + + if _, ok := s1.sessionTimers["foo"]; ok { + t.Fatalf("timer should be gone") + } +} + +func TestClearAllSessionTimers(t *testing.T) { + dir1, s1 := testServer(t) + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + s1.sessionTimersLock.Lock() + s1.resetSessionTimerLocked("foo", 10*time.Millisecond) + s1.resetSessionTimerLocked("bar", 10*time.Millisecond) + s1.resetSessionTimerLocked("baz", 10*time.Millisecond) + s1.sessionTimersLock.Unlock() + + err := s1.clearAllSessionTimers() + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(s1.sessionTimers) != 0 { + t.Fatalf("timers should be gone") + } +} + +func TestServer_SessionTTL_Failover(t *testing.T) { dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -35,28 +287,25 @@ func TestServer_sessionTTL(t *testing.T) { t.Fatalf("err: %v", err) } - for _, s := range servers { - testutil.WaitForResult(func() (bool, error) { - peers, _ := s.raftPeers.Peers() - return len(peers) == 3, nil - }, func(err error) { - t.Fatalf("should have 3 peers") - }) - } + testutil.WaitForResult(func() (bool, error) { + peers, _ := s1.raftPeers.Peers() + return len(peers) == 3, nil + }, func(err error) { + t.Fatalf("should have 3 peers") + }) // Find the leader var leader *Server for _, s := range servers { - // check that s.sessionTimers is empty + // Check that s.sessionTimers is empty if len(s.sessionTimers) != 0 { t.Fatalf("should have no sessionTimers") } - // find the leader too + // Find the leader too if s.IsLeader() { leader = s } } - if leader == nil { t.Fatalf("Should have a leader") } @@ -64,9 +313,18 @@ func TestServer_sessionTTL(t *testing.T) { client := rpcClient(t, leader) defer client.Close() - leader.fsm.State().EnsureNode(1, structs.Node{"foo", "127.0.0.1"}) + // Register a node + node := structs.RegisterRequest{ + Datacenter: s1.config.Datacenter, + Node: "foo", + Address: "127.0.0.1", + } + var out struct{} + if err := s1.RPC("Catalog.Register", &node, &out); err != nil { + t.Fatalf("err: %v", err) + } - // create a TTL session + // Create a TTL session arg := structs.SessionRequest{ Datacenter: "dc1", Op: structs.SessionCreate, @@ -80,89 +338,33 @@ func TestServer_sessionTTL(t *testing.T) { t.Fatalf("err: %v", err) } - // check that leader.sessionTimers has the session id in it - // means initializeSessionTimers was called and resetSessionTimer was called - if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil { - t.Fatalf("sessionTimers not initialized and does not contain session timer for session") + // Check that sessionTimers has the session ID + if _, ok := leader.sessionTimers[id1]; !ok { + t.Fatalf("missing session timer") } - time.Sleep(100 * time.Millisecond) - leader.Leave() + // Shutdown the leader! leader.Shutdown() - // leader.sessionTimers should be empty due to clearAllSessionTimers getting called + // sessionTimers should be cleared on leader shutdown if len(leader.sessionTimers) != 0 { t.Fatalf("session timers should be empty on the shutdown leader") } - time.Sleep(100 * time.Millisecond) - - var remain *Server - for _, s := range servers { - if s == leader { - continue - } - remain = s - testutil.WaitForResult(func() (bool, error) { - peers, _ := s.raftPeers.Peers() - return len(peers) == 2, errors.New(fmt.Sprintf("%v", peers)) - }, func(err error) { - t.Fatalf("should have 2 peers: %v", err) - }) - } - - // Verify the old leader is deregistered - state := remain.fsm.State() - testutil.WaitForResult(func() (bool, error) { - _, found, _ := state.GetNode(leader.config.NodeName) - return !found, nil - }, func(err error) { - t.Fatalf("leader should be deregistered") - }) - // Find the new leader + time.Sleep(200 * time.Millisecond) leader = nil for _, s := range servers { - // find the leader too if s.IsLeader() { leader = s } } - if leader == nil { t.Fatalf("Should have a new leader") } - // check that new leader.sessionTimers has the session id in it - if len(leader.sessionTimers) == 0 || leader.sessionTimers[id1] == nil { - t.Fatalf("sessionTimers not initialized and does not contain session timer for session") - } - - // create another TTL session with the same parameters - var id2 string - if err := client.Call("Session.Apply", &arg, &id2); err != nil { - t.Fatalf("err: %v", err) - } - - if len(leader.sessionTimers) != 2 { - t.Fatalf("sessionTimes length should be 2") - } - - // destroy the via invalidateSession as if on TTL expiry - leader.invalidateSession(id2) - - if len(leader.sessionTimers) != 1 { - t.Fatalf("sessionTimers length should 1") - } - - // destroy the id2 session (test clearSessionTimer) - arg.Op = structs.SessionDestroy - arg.Session.ID = id2 - if err := client.Call("Session.Apply", &arg, &id2); err != nil { - t.Fatalf("err: %v", err) - } - - if len(leader.sessionTimers) != 0 { - t.Fatalf("sessionTimers length should be 0") + // Ensure session timer is restored + if _, ok := leader.sessionTimers[id1]; !ok { + t.Fatalf("missing session timer") } } diff --git a/consul/state_store_test.go b/consul/state_store_test.go index c933dbdb0..888b4a343 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -703,7 +703,7 @@ func TestStoreSnapshot(t *testing.T) { if ok, err := store.KVSLock(18, d); err != nil || !ok { t.Fatalf("err: %v", err) } - session = &structs.Session{ID: generateUUID(), Node: "baz", TTL: "60s"} + session = &structs.Session{ID: generateUUID(), Node: "bar", TTL: "60s"} if err := store.SessionCreate(19, session); err != nil { t.Fatalf("err: %v", err) }