Adds fine-grained watches to session endpoints.

This commit is contained in:
James Phillips 2017-01-24 10:08:14 -08:00
parent 997934c94f
commit e59f398d80
No known key found for this signature in database
GPG Key ID: 77183E682AC5FC11
7 changed files with 167 additions and 107 deletions

View File

@ -500,7 +500,7 @@ func TestFSM_SnapshotRestore(t *testing.T) {
}
// Verify session is restored
idx, s, err := fsm2.state.SessionGet(session.ID)
idx, s, err := fsm2.state.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -875,7 +875,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) {
// Get the session
id := resp.(string)
_, session, err := fsm.state.SessionGet(id)
_, session, err := fsm.state.SessionGet(nil, id)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -911,7 +911,7 @@ func TestFSM_SessionCreate_Destroy(t *testing.T) {
t.Fatalf("resp: %v", resp)
}
_, session, err = fsm.state.SessionGet(id)
_, session, err = fsm.state.SessionGet(nil, id)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -6,6 +6,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-uuid"
)
@ -39,7 +40,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
switch args.Op {
case structs.SessionDestroy:
state := s.srv.fsm.State()
_, existing, err := state.SessionGet(args.Session.ID)
_, existing, err := state.SessionGet(nil, args.Session.ID)
if err != nil {
return fmt.Errorf("Unknown session %q", args.Session.ID)
}
@ -94,7 +95,7 @@ func (s *Session) Apply(args *structs.SessionRequest, reply *string) error {
s.srv.logger.Printf("[ERR] consul.session: UUID generation failed: %v", err)
return err
}
_, sess, err := state.SessionGet(args.Session.ID)
_, sess, err := state.SessionGet(nil, args.Session.ID)
if err != nil {
s.srv.logger.Printf("[ERR] consul.session: Session lookup failed: %v", err)
return err
@ -141,12 +142,11 @@ func (s *Session) Get(args *structs.SessionSpecificRequest,
// Get the local state
state := s.srv.fsm.State()
return s.srv.blockingRPC(
return s.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
state.GetQueryWatch("SessionGet"),
func() error {
index, session, err := state.SessionGet(args.Session)
func(ws memdb.WatchSet) error {
index, session, err := state.SessionGet(ws, args.Session)
if err != nil {
return err
}
@ -173,12 +173,11 @@ func (s *Session) List(args *structs.DCSpecificRequest,
// Get the local state
state := s.srv.fsm.State()
return s.srv.blockingRPC(
return s.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
state.GetQueryWatch("SessionList"),
func() error {
index, sessions, err := state.SessionList()
func(ws memdb.WatchSet) error {
index, sessions, err := state.SessionList(ws)
if err != nil {
return err
}
@ -200,12 +199,11 @@ func (s *Session) NodeSessions(args *structs.NodeSpecificRequest,
// Get the local state
state := s.srv.fsm.State()
return s.srv.blockingRPC(
return s.srv.blockingQuery(
&args.QueryOptions,
&reply.QueryMeta,
state.GetQueryWatch("NodeSessions"),
func() error {
index, sessions, err := state.NodeSessions(args.Node)
func(ws memdb.WatchSet) error {
index, sessions, err := state.NodeSessions(ws, args.Node)
if err != nil {
return err
}
@ -228,7 +226,7 @@ func (s *Session) Renew(args *structs.SessionSpecificRequest,
// Get the session, from local state.
state := s.srv.fsm.State()
index, session, err := state.SessionGet(args.Session)
index, session, err := state.SessionGet(nil, args.Session)
if err != nil {
return err
}

View File

@ -40,7 +40,7 @@ func TestSession_Apply(t *testing.T) {
// Verify
state := s1.fsm.State()
_, s, err := state.SessionGet(out)
_, s, err := state.SessionGet(nil, out)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -62,7 +62,7 @@ func TestSession_Apply(t *testing.T) {
}
// Verify
_, s, err = state.SessionGet(id)
_, s, err = state.SessionGet(nil, id)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -100,7 +100,7 @@ func TestSession_DeleteApply(t *testing.T) {
// Verify
state := s1.fsm.State()
_, s, err := state.SessionGet(out)
_, s, err := state.SessionGet(nil, out)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -125,7 +125,7 @@ func TestSession_DeleteApply(t *testing.T) {
}
// Verify
_, s, err = state.SessionGet(id)
_, s, err = state.SessionGet(nil, id)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -22,7 +22,7 @@ const (
func (s *Server) initializeSessionTimers() error {
// Scan all sessions and reset their timer
state := s.fsm.State()
_, sessions, err := state.SessionList()
_, sessions, err := state.SessionList(nil)
if err != nil {
return err
}
@ -41,7 +41,7 @@ func (s *Server) resetSessionTimer(id string, session *structs.Session) error {
// Fault the session in if not given
if session == nil {
state := s.fsm.State()
_, s, err := state.SessionGet(id)
_, s, err := state.SessionGet(nil, id)
if err != nil {
return err
}

View File

@ -225,7 +225,7 @@ func TestInvalidateSession(t *testing.T) {
s1.invalidateSession(session.ID)
// Check it is gone
_, sess, err := state.SessionGet(session.ID)
_, sess, err := state.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -145,18 +145,19 @@ func (s *StateStore) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.S
}
// SessionGet is used to retrieve an active session from the state store.
func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, error) {
func (s *StateStore) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *structs.Session, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("SessionGet")...)
idx := maxIndexTxn(tx, "sessions")
// Look up the session by its ID
session, err := tx.First("sessions", "id", sessionID)
watchCh, session, err := tx.FirstWatch("sessions", "id", sessionID)
if err != nil {
return 0, nil, fmt.Errorf("failed session lookup: %s", err)
}
ws.Add(watchCh)
if session != nil {
return idx, session.(*structs.Session), nil
}
@ -164,18 +165,19 @@ func (s *StateStore) SessionGet(sessionID string) (uint64, *structs.Session, err
}
// SessionList returns a slice containing all of the active sessions.
func (s *StateStore) SessionList() (uint64, structs.Sessions, error) {
func (s *StateStore) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("SessionList")...)
idx := maxIndexTxn(tx, "sessions")
// Query all of the active sessions.
sessions, err := tx.Get("sessions", "id")
if err != nil {
return 0, nil, fmt.Errorf("failed session lookup: %s", err)
}
ws.Add(sessions.WatchCh())
// Go over the sessions and create a slice of them.
var result structs.Sessions
@ -188,18 +190,19 @@ func (s *StateStore) SessionList() (uint64, structs.Sessions, error) {
// NodeSessions returns a set of active sessions associated
// with the given node ID. The returned index is the highest
// index seen from the result set.
func (s *StateStore) NodeSessions(nodeID string) (uint64, structs.Sessions, error) {
func (s *StateStore) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs.Sessions, error) {
tx := s.db.Txn(false)
defer tx.Abort()
// Get the table index.
idx := maxIndexTxn(tx, s.getWatchTables("NodeSessions")...)
idx := maxIndexTxn(tx, "sessions")
// Get all of the sessions which belong to the node
sessions, err := tx.Get("sessions", "node", nodeID)
if err != nil {
return 0, nil, fmt.Errorf("failed session lookup: %s", err)
}
ws.Add(sessions.WatchCh())
// Go over all of the sessions and return them as a slice
var result structs.Sessions

View File

@ -9,13 +9,15 @@ import (
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/consul/types"
"github.com/hashicorp/go-memdb"
)
func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
s := testStateStore(t)
// SessionGet returns nil if the session doesn't exist
idx, session, err := s.SessionGet(testUUID())
ws := memdb.NewWatchSet()
idx, session, err := s.SessionGet(ws, testUUID())
if session != nil || err != nil {
t.Fatalf("expected (nil, nil), got: (%#v, %#v)", session, err)
}
@ -49,6 +51,9 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
if idx := s.maxIndex("sessions"); idx != 0 {
t.Fatalf("bad index: %d", idx)
}
if watchFired(ws) {
t.Fatalf("bad")
}
// Valid session is able to register
testRegisterNode(t, s, 1, "node1")
@ -62,9 +67,13 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
if idx := s.maxIndex("sessions"); idx != 2 {
t.Fatalf("bad index: %s", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Retrieve the session again
idx, session, err = s.SessionGet(sess.ID)
ws = memdb.NewWatchSet()
idx, session, err = s.SessionGet(ws, sess.ID)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -104,12 +113,19 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), structs.HealthCritical) {
t.Fatalf("expected critical state error, got: %#v", err)
}
if watchFired(ws) {
t.Fatalf("bad")
}
// Registering with a healthy check succeeds
// Registering with a healthy check succeeds (doesn't hit the watch since
// we are looking at the old session).
testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing)
if err := s.SessionCreate(5, sess); err != nil {
t.Fatalf("err: %s", err)
}
if watchFired(ws) {
t.Fatalf("bad")
}
// Register a session against two checks.
testRegisterCheck(t, s, 5, "node1", "", "check2", structs.HealthPassing)
@ -159,7 +175,7 @@ func TestStateStore_SessionCreate_SessionGet(t *testing.T) {
}
// Pulling a nonexistent session gives the table index.
idx, session, err = s.SessionGet(testUUID())
idx, session, err = s.SessionGet(nil, testUUID())
if err != nil {
t.Fatalf("err: %s", err)
}
@ -175,7 +191,8 @@ func TegstStateStore_SessionList(t *testing.T) {
s := testStateStore(t)
// Listing when no sessions exist returns nil
idx, res, err := s.SessionList()
ws := memdb.NewWatchSet()
idx, res, err := s.SessionList(ws)
if idx != 0 || res != nil || err != nil {
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
}
@ -208,9 +225,12 @@ func TegstStateStore_SessionList(t *testing.T) {
t.Fatalf("err: %s", err)
}
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// List out all of the sessions
idx, sessionList, err := s.SessionList()
idx, sessionList, err := s.SessionList(nil)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -226,7 +246,8 @@ func TestStateStore_NodeSessions(t *testing.T) {
s := testStateStore(t)
// Listing sessions with no results returns nil
idx, res, err := s.NodeSessions("node1")
ws := memdb.NewWatchSet()
idx, res, err := s.NodeSessions(ws, "node1")
if idx != 0 || res != nil || err != nil {
t.Fatalf("expected (0, nil, nil), got: (%d, %#v, %#v)", idx, res, err)
}
@ -261,10 +282,14 @@ func TestStateStore_NodeSessions(t *testing.T) {
t.Fatalf("err: %s", err)
}
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Query all of the sessions associated with a specific
// node in the state store.
idx, res, err = s.NodeSessions("node1")
ws1 := memdb.NewWatchSet()
idx, res, err = s.NodeSessions(ws1, "node1")
if err != nil {
t.Fatalf("err: %s", err)
}
@ -275,7 +300,8 @@ func TestStateStore_NodeSessions(t *testing.T) {
t.Fatalf("bad index: %d", idx)
}
idx, res, err = s.NodeSessions("node2")
ws2 := memdb.NewWatchSet()
idx, res, err = s.NodeSessions(ws2, "node2")
if err != nil {
t.Fatalf("err: %s", err)
}
@ -285,6 +311,17 @@ func TestStateStore_NodeSessions(t *testing.T) {
if idx != 6 {
t.Fatalf("bad index: %d", idx)
}
// Destroying a session on node1 should not affect node2's watch.
if err := s.SessionDestroy(100, sessions1[0].ID); err != nil {
t.Fatalf("err: %s", err)
}
if !watchFired(ws1) {
t.Fatalf("bad")
}
if watchFired(ws2) {
t.Fatalf("bad")
}
}
func TestStateStore_SessionDestroy(t *testing.T) {
@ -418,7 +455,7 @@ func TestStateStore_Session_Snapshot_Restore(t *testing.T) {
// Read the restored sessions back out and verify that they
// match.
idx, res, err := s.SessionList()
idx, res, err := s.SessionList(nil)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -520,17 +557,21 @@ func TestStateStore_Session_Invalidate_DeleteNode(t *testing.T) {
t.Fatalf("err: %v", err)
}
// Delete the node and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() {
verifyWatch(t, s.getTableWatch("nodes"), func() {
if err := s.DeleteNode(15, "foo"); err != nil {
t.Fatalf("err: %v", err)
}
})
})
// Delete the node and make sure the watch fires.
ws := memdb.NewWatchSet()
idx, s2, err := s.SessionGet(ws, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.DeleteNode(15, "foo"); err != nil {
t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID)
idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -571,19 +612,21 @@ func TestStateStore_Session_Invalidate_DeleteService(t *testing.T) {
t.Fatalf("err: %v", err)
}
// Delete the service and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() {
verifyWatch(t, s.getTableWatch("services"), func() {
verifyWatch(t, s.getTableWatch("checks"), func() {
if err := s.DeleteService(15, "foo", "api"); err != nil {
t.Fatalf("err: %v", err)
}
})
})
})
// Delete the service and make sure the watch fires.
ws := memdb.NewWatchSet()
idx, s2, err := s.SessionGet(ws, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.DeleteService(15, "foo", "api"); err != nil {
t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID)
idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -620,17 +663,21 @@ func TestStateStore_Session_Invalidate_Critical_Check(t *testing.T) {
}
// Invalidate the check and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() {
verifyWatch(t, s.getTableWatch("checks"), func() {
check.Status = structs.HealthCritical
if err := s.EnsureCheck(15, check); err != nil {
t.Fatalf("err: %v", err)
}
})
})
ws := memdb.NewWatchSet()
idx, s2, err := s.SessionGet(ws, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
check.Status = structs.HealthCritical
if err := s.EnsureCheck(15, check); err != nil {
t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID)
idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -667,16 +714,20 @@ func TestStateStore_Session_Invalidate_DeleteCheck(t *testing.T) {
}
// Delete the check and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() {
verifyWatch(t, s.getTableWatch("checks"), func() {
if err := s.DeleteCheck(15, "foo", "bar"); err != nil {
t.Fatalf("err: %v", err)
}
})
})
ws := memdb.NewWatchSet()
idx, s2, err := s.SessionGet(ws, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.DeleteCheck(15, "foo", "bar"); err != nil {
t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID)
idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -731,18 +782,20 @@ func TestStateStore_Session_Invalidate_Key_Unlock_Behavior(t *testing.T) {
}
// Delete the node and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() {
verifyWatch(t, s.getTableWatch("nodes"), func() {
verifyWatch(t, s.GetKVSWatch("/f"), func() {
if err := s.DeleteNode(6, "foo"); err != nil {
t.Fatalf("err: %v", err)
}
})
})
})
ws := memdb.NewWatchSet()
idx, s2, err := s.SessionGet(ws, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.DeleteNode(6, "foo"); err != nil {
t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID)
idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -811,18 +864,20 @@ func TestStateStore_Session_Invalidate_Key_Delete_Behavior(t *testing.T) {
}
// Delete the node and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() {
verifyWatch(t, s.getTableWatch("nodes"), func() {
verifyWatch(t, s.GetKVSWatch("/b"), func() {
if err := s.DeleteNode(6, "foo"); err != nil {
t.Fatalf("err: %v", err)
}
})
})
})
ws := memdb.NewWatchSet()
idx, s2, err := s.SessionGet(ws, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.DeleteNode(6, "foo"); err != nil {
t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Lookup by ID, should be nil.
idx, s2, err := s.SessionGet(session.ID)
idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -877,16 +932,20 @@ func TestStateStore_Session_Invalidate_PreparedQuery_Delete(t *testing.T) {
}
// Invalidate the session and make sure the watches fire.
verifyWatch(t, s.getTableWatch("sessions"), func() {
verifyWatch(t, s.getTableWatch("prepared-queries"), func() {
if err := s.SessionDestroy(5, session.ID); err != nil {
t.Fatalf("err: %v", err)
}
})
})
ws := memdb.NewWatchSet()
idx, s2, err := s.SessionGet(ws, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.SessionDestroy(5, session.ID); err != nil {
t.Fatalf("err: %v", err)
}
if !watchFired(ws) {
t.Fatalf("bad")
}
// Make sure the session is gone.
idx, s2, err := s.SessionGet(session.ID)
idx, s2, err = s.SessionGet(nil, session.ID)
if err != nil {
t.Fatalf("err: %v", err)
}