From 679e4e6e702102c87f302c18e92ef0f90bc644d8 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 15 May 2014 11:51:31 -0700 Subject: [PATCH] consul: Adding session invalidation --- consul/state_store.go | 102 +++++++++++++++++----- consul/state_store_test.go | 171 ++++++++++++++++++++++++++++++++++++- 2 files changed, 250 insertions(+), 23 deletions(-) diff --git a/consul/state_store.go b/consul/state_store.go index 02567b686..730ff3e87 100644 --- a/consul/state_store.go +++ b/consul/state_store.go @@ -366,8 +366,7 @@ func (s *StateStore) Nodes() (uint64, structs.Nodes) { // EnsureService is used to ensure a given node exposes a service func (s *StateStore) EnsureService(index uint64, node string, ns *structs.NodeService) error { - tables := MDBTables{s.nodeTable, s.serviceTable} - tx, err := tables.StartTxn(false) + tx, err := s.tables.StartTxn(false) if err != nil { panic(fmt.Errorf("Failed to start txn: %v", err)) } @@ -461,8 +460,7 @@ func (s *StateStore) parseNodeServices(tables MDBTables, tx *MDBTxn, name string // DeleteNodeService is used to delete a node service func (s *StateStore) DeleteNodeService(index uint64, node, id string) error { - tables := MDBTables{s.serviceTable, s.checkTable} - tx, err := tables.StartTxn(false) + tx, err := s.tables.StartTxn(false) if err != nil { panic(fmt.Errorf("Failed to start txn: %v", err)) } @@ -476,6 +474,19 @@ func (s *StateStore) DeleteNodeService(index uint64, node, id string) error { } defer s.watch[s.serviceTable].Notify() } + + // Invalidate any sessions using these checks + checks, err := s.checkTable.GetTxn(tx, "node", node, id) + if err != nil { + return err + } + for _, c := range checks { + check := c.(*structs.HealthCheck) + if err := s.invalidateCheck(index, tx, node, check.CheckID); err != nil { + return err + } + } + if n, err := s.checkTable.DeleteTxn(tx, "node", node, id); err != nil { return err } else if n > 0 { @@ -489,13 +500,17 @@ func (s *StateStore) DeleteNodeService(index uint64, node, id string) error { // DeleteNode is used to delete a node and all it's services func (s *StateStore) DeleteNode(index uint64, node string) error { - tables := MDBTables{s.nodeTable, s.serviceTable, s.checkTable} - tx, err := tables.StartTxn(false) + tx, err := s.tables.StartTxn(false) if err != nil { panic(fmt.Errorf("Failed to start txn: %v", err)) } defer tx.Abort() + // Invalidate any sessions held by the node + if err := s.invalidateNode(index, tx, node); err != nil { + return err + } + if n, err := s.serviceTable.DeleteTxn(tx, "id", node); err != nil { return err } else if n > 0 { @@ -633,8 +648,7 @@ func (s *StateStore) EnsureCheck(index uint64, check *structs.HealthCheck) error } // Start the txn - tables := MDBTables{s.nodeTable, s.serviceTable, s.checkTable} - tx, err := tables.StartTxn(false) + tx, err := s.tables.StartTxn(false) if err != nil { panic(fmt.Errorf("Failed to start txn: %v", err)) } @@ -663,6 +677,14 @@ func (s *StateStore) EnsureCheck(index uint64, check *structs.HealthCheck) error check.ServiceName = srv.ServiceName } + // Invalidate any sessions if status is critical + if check.Status == structs.HealthCritical { + err := s.invalidateCheck(index, tx, check.Node, check.CheckID) + if err != nil { + return err + } + } + // Ensure the check is set if err := s.checkTable.InsertTxn(tx, check); err != nil { return err @@ -676,12 +698,17 @@ func (s *StateStore) EnsureCheck(index uint64, check *structs.HealthCheck) error // DeleteNodeCheck is used to delete a node health check func (s *StateStore) DeleteNodeCheck(index uint64, node, id string) error { - tx, err := s.checkTable.StartTxn(false, nil) + tx, err := s.tables.StartTxn(false) if err != nil { return err } defer tx.Abort() + // Invalidate any sessions held by this check + if err := s.invalidateCheck(index, tx, node, id); err != nil { + return err + } + if n, err := s.checkTable.DeleteTxn(tx, "id", node, id); err != nil { return err } else if n > 0 { @@ -1101,9 +1128,7 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error session.CreateIndex = index // Start the transaction - tables := MDBTables{s.nodeTable, s.checkTable, - s.sessionTable, s.sessionCheckTable} - tx, err := tables.StartTxn(false) + tx, err := s.tables.StartTxn(false) if err != nil { panic(fmt.Errorf("Failed to start txn: %v", err)) } @@ -1172,9 +1197,7 @@ func (s *StateStore) SessionCreate(index uint64, session *structs.Session) error // doing a restore, otherwise SessionCreate should be used. func (s *StateStore) SessionRestore(session *structs.Session) error { // Start the transaction - tables := MDBTables{s.nodeTable, s.checkTable, - s.sessionTable, s.sessionCheckTable} - tx, err := tables.StartTxn(false) + tx, err := s.tables.StartTxn(false) if err != nil { panic(fmt.Errorf("Failed to start txn: %v", err)) } @@ -1235,14 +1258,53 @@ func (s *StateStore) NodeSessions(node string) (uint64, []*structs.Session, erro // SessionDelete is used to destroy a session. func (s *StateStore) SessionDestroy(index uint64, id string) error { - // Start the transaction - tables := MDBTables{s.sessionTable, s.sessionCheckTable} - tx, err := tables.StartTxn(false) + tx, err := s.tables.StartTxn(false) if err != nil { panic(fmt.Errorf("Failed to start txn: %v", err)) } defer tx.Abort() + if err := s.invalidateSession(index, tx, id); err != nil { + return err + } + return tx.Commit() +} + +// invalideNode is used to invalide all sessions belonging to a node +// All tables should be locked in the tx. +func (s *StateStore) invalidateNode(index uint64, tx *MDBTxn, node string) error { + sessions, err := s.sessionTable.GetTxn(tx, "node", node) + if err != nil { + return err + } + for _, sess := range sessions { + session := sess.(*structs.Session).ID + if err := s.invalidateSession(index, tx, session); err != nil { + return err + } + } + return nil +} + +// invalidateCheck is used to invalide all sessions belonging to a check +// All tables should be locked in the tx. +func (s *StateStore) invalidateCheck(index uint64, tx *MDBTxn, node, check string) error { + sessionChecks, err := s.sessionCheckTable.GetTxn(tx, "id", node, check) + if err != nil { + return err + } + for _, sc := range sessionChecks { + session := sc.(*sessionCheck).Session + if err := s.invalidateSession(index, tx, session); err != nil { + return err + } + } + return nil +} + +// invalidateSession is used to invalide a session within a given txn +// All tables should be locked in the tx. +func (s *StateStore) invalidateSession(index uint64, tx *MDBTxn, id string) error { // Get the session res, err := s.sessionTable.GetTxn(tx, "id", id) if err != nil { @@ -1272,8 +1334,8 @@ func (s *StateStore) SessionDestroy(index uint64, id string) error { if err := s.sessionTable.SetLastIndexTxn(tx, index); err != nil { return err } - defer s.watch[s.sessionTable].Notify() - return tx.Commit() + tx.Defer(func() { s.watch[s.sessionTable].Notify() }) + return nil } // Snapshot is used to create a point in time snapshot diff --git a/consul/state_store_test.go b/consul/state_store_test.go index 41f9375fd..d048b2fc2 100644 --- a/consul/state_store_test.go +++ b/consul/state_store_test.go @@ -717,11 +717,11 @@ func TestStoreSnapshot(t *testing.T) { ServiceID: "db", } if err := store.EnsureCheck(17, checkAfter); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } if err := store.KVSDelete(18, "/web/a"); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } // Check snapshot has old values @@ -1630,7 +1630,7 @@ func TestSessionCreate_Invalid(t *testing.T) { Status: structs.HealthCritical, } if err := store.EnsureCheck(13, check); err != nil { - t.Fatalf("err: %v") + t.Fatalf("err: %v", err) } if err := store.SessionCreate(1000, session); err.Error() != "Check 'bar' is in critical state" { t.Fatalf("err: %v", err) @@ -1719,3 +1719,168 @@ func TestSession_Lookups(t *testing.T) { t.Fatalf("bad: %v %v", ids, out) } } + +func TestSessionInvalidate_CriticalHealthCheck(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { + t.Fatalf("err: %v") + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "bar", + Status: structs.HealthPassing, + } + if err := store.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v") + } + + session := &structs.Session{ + Node: "foo", + Checks: []string{"bar"}, + } + if err := store.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Invalidate the check + check.Status = structs.HealthCritical + if err := store.EnsureCheck(15, check); err != nil { + t.Fatalf("err: %v", err) + } + + // Lookup by ID, should be nil + _, s2, err := store.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } +} + +func TestSessionInvalidate_DeleteHealthCheck(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { + t.Fatalf("err: %v") + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "bar", + Status: structs.HealthPassing, + } + if err := store.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v") + } + + session := &structs.Session{ + Node: "foo", + Checks: []string{"bar"}, + } + if err := store.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Delete the check + if err := store.DeleteNodeCheck(15, "foo", "bar"); err != nil { + t.Fatalf("err: %v", err) + } + + // Lookup by ID, should be nil + _, s2, err := store.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } +} + +func TestSessionInvalidate_DeleteNode(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + if err := store.EnsureNode(3, structs.Node{"foo", "127.0.0.1"}); err != nil { + t.Fatalf("err: %v") + } + + session := &structs.Session{ + Node: "foo", + } + if err := store.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Delete the node + if err := store.DeleteNode(15, "foo"); err != nil { + t.Fatalf("err: %v") + } + + // Lookup by ID, should be nil + _, s2, err := store.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } +} + +func TestSessionInvalidate_DeleteNodeService(t *testing.T) { + store, err := testStateStore() + if err != nil { + t.Fatalf("err: %v", err) + } + defer store.Close() + + if err := store.EnsureNode(11, structs.Node{"foo", "127.0.0.1"}); err != nil { + t.Fatalf("err: %v", err) + } + if err := store.EnsureService(12, "foo", &structs.NodeService{"api", "api", nil, 5000}); err != nil { + t.Fatalf("err: %v", err) + } + check := &structs.HealthCheck{ + Node: "foo", + CheckID: "api", + Name: "Can connect", + Status: structs.HealthPassing, + ServiceID: "api", + } + if err := store.EnsureCheck(13, check); err != nil { + t.Fatalf("err: %v") + } + + session := &structs.Session{ + Node: "foo", + Checks: []string{"api"}, + } + if err := store.SessionCreate(14, session); err != nil { + t.Fatalf("err: %v", err) + } + + // Should invalidate the session + if err := store.DeleteNodeService(15, "foo", "api"); err != nil { + t.Fatalf("err: %v", err) + } + + // Lookup by ID, should be nil + _, s2, err := store.SessionGet(session.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if s2 != nil { + t.Fatalf("session should be invalidated") + } +}