port oss changes (#11736)

This commit is contained in:
Dhia Ayachi 2021-12-03 17:23:55 -05:00 committed by GitHub
parent 3791d6d7da
commit e38ccf0a22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 66 deletions

View File

@ -151,7 +151,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
if args.Op == structs.SessionCreate && args.Session.TTL != "" {
// If we created a session with a TTL, reset the expiration timer
s.srv.resetSessionTimer(args.Session.ID, &args.Session)
s.srv.resetSessionTimer(&args.Session)
} else if args.Op == structs.SessionDestroy {
// If we destroyed a session, it might potentially have a TTL,
// and we need to clear the timer
@ -308,7 +308,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest,
// Reset the session TTL timer.
reply.Sessions = structs.Sessions{session}
if err := s.srv.resetSessionTimer(args.SessionID, session); err != nil {
if err := s.srv.resetSessionTimer(session); err != nil {
s.logger.Error("Session renew failed", "error", err)
return err
}

View File

@ -47,13 +47,12 @@ func (s *Server) initializeSessionTimers() error {
// Scan all sessions and reset their timer
state := s.fsm.State()
// TODO(partitions): track all session timers in all partitions
_, sessions, err := state.SessionList(nil, structs.WildcardEnterpriseMetaInDefaultPartition())
_, sessions, err := state.SessionListAll(nil)
if err != nil {
return err
}
for _, session := range sessions {
if err := s.resetSessionTimer(session.ID, session); err != nil {
if err := s.resetSessionTimer(session); err != nil {
return err
}
}
@ -63,20 +62,7 @@ func (s *Server) initializeSessionTimers() error {
// resetSessionTimer is used to renew the TTL of a session.
// This can be used for new sessions and existing ones. A session
// will be faulted in if not given.
func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
// Fault the session in if not given
if session == nil {
state := s.fsm.State()
_, s, err := state.SessionGet(nil, id, nil)
if err != nil {
return err
}
if s == nil {
return fmt.Errorf("Session '%s' not found", id)
}
session = s
}
func (s *Server) resetSessionTimer(session *structs.Session) error {
// Bail if the session has no TTL, fast-path some common inputs
switch session.TTL {
case "", "0", "0s", "0m", "0h":

View File

@ -11,7 +11,7 @@ import (
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/net-rpc-msgpackrpc"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
)
func generateUUID() (ret string) {
@ -59,50 +59,6 @@ func TestInitializeSessionTimers(t *testing.T) {
}
}
func TestResetSessionTimer_Fault(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
dir1, s1 := testServer(t)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
// Should not exist
err := s1.resetSessionTimer(generateUUID(), nil)
if err == nil || !strings.Contains(err.Error(), "not found") {
t.Fatalf("err: %v", err)
}
// Create a session
state := s1.fsm.State()
if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
t.Fatalf("err: %s", err)
}
session := &structs.Session{
ID: generateUUID(),
Node: "foo",
TTL: "10s",
}
if err := state.SessionCreate(100, session); err != nil {
t.Fatalf("err: %v", err)
}
// Reset the session timer
err = s1.resetSessionTimer(session.ID, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
// Check that we have a timer
if s1.sessionTimers.Get(session.ID) == nil {
t.Fatalf("missing session timer")
}
}
func TestResetSessionTimer_NoTTL(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
@ -130,7 +86,7 @@ func TestResetSessionTimer_NoTTL(t *testing.T) {
}
// Reset the session timer
err := s1.resetSessionTimer(session.ID, session)
err := s1.resetSessionTimer(session)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -155,7 +111,7 @@ func TestResetSessionTimer_InvalidTTL(t *testing.T) {
}
// Reset the session timer
err := s1.resetSessionTimer(session.ID, session)
err := s1.resetSessionTimer(session)
if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
t.Fatalf("err: %v", err)
}

View File

@ -187,3 +187,7 @@ func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta)
func maxIndexTxnSessions(tx *memdb.Txn, _ *structs.EnterpriseMeta) uint64 {
return maxIndexTxn(tx, tableSessions)
}
func (s *Store) SessionListAll(ws memdb.WatchSet) (uint64, structs.Sessions, error) {
return s.SessionList(ws, nil)
}