diff --git a/agent/consul/session_endpoint.go b/agent/consul/session_endpoint.go index e15b05227..ae39a6fc5 100644 --- a/agent/consul/session_endpoint.go +++ b/agent/consul/session_endpoint.go @@ -151,7 +151,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error { if args.Op == structs.SessionCreate && args.Session.TTL != "" { // If we created a session with a TTL, reset the expiration timer - s.srv.resetSessionTimer(args.Session.ID, &args.Session) + s.srv.resetSessionTimer(&args.Session) } else if args.Op == structs.SessionDestroy { // If we destroyed a session, it might potentially have a TTL, // and we need to clear the timer @@ -308,7 +308,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest, // Reset the session TTL timer. reply.Sessions = structs.Sessions{session} - if err := s.srv.resetSessionTimer(args.SessionID, session); err != nil { + if err := s.srv.resetSessionTimer(session); err != nil { s.logger.Error("Session renew failed", "error", err) return err } diff --git a/agent/consul/session_ttl.go b/agent/consul/session_ttl.go index 426179d96..0bb1cb3f1 100644 --- a/agent/consul/session_ttl.go +++ b/agent/consul/session_ttl.go @@ -47,13 +47,12 @@ func (s *Server) initializeSessionTimers() error { // Scan all sessions and reset their timer state := s.fsm.State() - // TODO(partitions): track all session timers in all partitions - _, sessions, err := state.SessionList(nil, structs.WildcardEnterpriseMetaInDefaultPartition()) + _, sessions, err := state.SessionListAll(nil) if err != nil { return err } for _, session := range sessions { - if err := s.resetSessionTimer(session.ID, session); err != nil { + if err := s.resetSessionTimer(session); err != nil { return err } } @@ -63,20 +62,7 @@ func (s *Server) initializeSessionTimers() error { // resetSessionTimer is used to renew the TTL of a session. // This can be used for new sessions and existing ones. A session // will be faulted in if not given. -func (s *Server) resetSessionTimer(id string, session *structs.Session) error { - // Fault the session in if not given - if session == nil { - state := s.fsm.State() - _, s, err := state.SessionGet(nil, id, nil) - if err != nil { - return err - } - if s == nil { - return fmt.Errorf("Session '%s' not found", id) - } - session = s - } - +func (s *Server) resetSessionTimer(session *structs.Session) error { // Bail if the session has no TTL, fast-path some common inputs switch session.TTL { case "", "0", "0s", "0m", "0h": diff --git a/agent/consul/session_ttl_test.go b/agent/consul/session_ttl_test.go index 160a5b69e..5fc4b09f3 100644 --- a/agent/consul/session_ttl_test.go +++ b/agent/consul/session_ttl_test.go @@ -11,7 +11,7 @@ import ( "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/go-uuid" - "github.com/hashicorp/net-rpc-msgpackrpc" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" ) func generateUUID() (ret string) { @@ -59,50 +59,6 @@ func TestInitializeSessionTimers(t *testing.T) { } } -func TestResetSessionTimer_Fault(t *testing.T) { - if testing.Short() { - t.Skip("too slow for testing.Short") - } - - t.Parallel() - dir1, s1 := testServer(t) - defer os.RemoveAll(dir1) - defer s1.Shutdown() - - testrpc.WaitForLeader(t, s1.RPC, "dc1") - - // Should not exist - err := s1.resetSessionTimer(generateUUID(), nil) - if err == nil || !strings.Contains(err.Error(), "not found") { - t.Fatalf("err: %v", err) - } - - // Create a session - state := s1.fsm.State() - if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil { - t.Fatalf("err: %s", err) - } - session := &structs.Session{ - ID: generateUUID(), - Node: "foo", - TTL: "10s", - } - if err := state.SessionCreate(100, session); err != nil { - t.Fatalf("err: %v", err) - } - - // Reset the session timer - err = s1.resetSessionTimer(session.ID, nil) - if err != nil { - t.Fatalf("err: %v", err) - } - - // Check that we have a timer - if s1.sessionTimers.Get(session.ID) == nil { - t.Fatalf("missing session timer") - } -} - func TestResetSessionTimer_NoTTL(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -130,7 +86,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) { } // Reset the session timer - err := s1.resetSessionTimer(session.ID, session) + err := s1.resetSessionTimer(session) if err != nil { t.Fatalf("err: %v", err) } @@ -155,7 +111,7 @@ func TestResetSessionTimer_InvalidTTL(t *testing.T) { } // Reset the session timer - err := s1.resetSessionTimer(session.ID, session) + err := s1.resetSessionTimer(session) if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") { t.Fatalf("err: %v", err) } diff --git a/agent/consul/state/session_oss.go b/agent/consul/state/session_oss.go index a706f2c14..d313fb5f9 100644 --- a/agent/consul/state/session_oss.go +++ b/agent/consul/state/session_oss.go @@ -187,3 +187,7 @@ func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta) func maxIndexTxnSessions(tx *memdb.Txn, _ *structs.EnterpriseMeta) uint64 { return maxIndexTxn(tx, tableSessions) } + +func (s *Store) SessionListAll(ws memdb.WatchSet) (uint64, structs.Sessions, error) { + return s.SessionList(ws, nil) +}