diff --git a/consul/state/state_store.go b/consul/state/state_store.go index 495e54ec2..978edf4ad 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -253,7 +253,12 @@ func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error { return fmt.Errorf("failed to retrieve existing index: %s", err) } - if cur, ok := ti.(*IndexEntry); ok && idx > cur.Value { + // Always take the first update, otherwise do the > check. + if ti == nil { + if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil { + return fmt.Errorf("failed updating index %s", err) + } + } else if cur, ok := ti.(*IndexEntry); ok && idx > cur.Value { if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil { return fmt.Errorf("failed updating index %s", err) } @@ -1752,15 +1757,15 @@ func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, watches *DumbWa // SessionRestore is used when restoring from a snapshot. For general inserts, // use SessionCreate. func (s *StateStore) SessionRestore(sess *structs.Session) error { - tx := s.db.Txn(false) + tx := s.db.Txn(true) defer tx.Abort() - // Insert the session + // Insert the session. if err := tx.Insert("sessions", sess); err != nil { return fmt.Errorf("failed inserting session: %s", err) } - // Insert the check mappings + // Insert the check mappings. for _, checkID := range sess.Checks { mapping := &sessionCheck{ Node: sess.Node, @@ -1772,7 +1777,7 @@ func (s *StateStore) SessionRestore(sess *structs.Session) error { } } - // Update the index + // Update the index. if err := indexUpdateMaxTxn(tx, sess.ModifyIndex, "sessions"); err != nil { return fmt.Errorf("failed updating index: %s", err) } diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go index 36e2d01fd..c0bc34ecb 100644 --- a/consul/state/state_store_test.go +++ b/consul/state/state_store_test.go @@ -2559,6 +2559,11 @@ func TestStateStore_KVS_Snapshot_Restore(t *testing.T) { t.Fatalf("bad: %#v", entry) } } + + // Check that the index was updated. + if idx := s.maxIndex("kvs"); idx != 7 { + t.Fatalf("bad index: %d", idx) + } }() } @@ -2865,7 +2870,7 @@ func TestStateStore_SessionCreate_GetSession(t *testing.T) { } } -func TestStateStore_SessionList(t *testing.T) { +func TegstStateStore_SessionList(t *testing.T) { s := testStateStore(t) // Listing when no sessions exist returns nil @@ -3021,6 +3026,142 @@ func TestStateStore_SessionDestroy(t *testing.T) { } } +func TestStateStore_Session_Snapshot_Restore(t *testing.T) { + s := testStateStore(t) + + // Register some nodes and checks. + testRegisterNode(t, s, 1, "node1") + testRegisterNode(t, s, 2, "node2") + testRegisterNode(t, s, 3, "node3") + testRegisterCheck(t, s, 4, "node1", "", "check1", structs.HealthPassing) + + // Create some sessions in the state store. + sessions := structs.Sessions{ + &structs.Session{ + ID: "session1", + Node: "node1", + Behavior: structs.SessionKeysDelete, + Checks: []string{"check1"}, + }, + &structs.Session{ + ID: "session2", + Node: "node2", + Behavior: structs.SessionKeysRelease, + LockDelay: 10 * time.Second, + }, + &structs.Session{ + ID: "session3", + Node: "node3", + Behavior: structs.SessionKeysDelete, + TTL: "1.5s", + }, + } + for i, session := range sessions { + if err := s.SessionCreate(uint64(5+i), session); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Snapshot the sessions. + snap := s.Snapshot() + defer snap.Close() + + // Verify the snapshot. + if idx := snap.LastIndex(); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + dump, err := snap.SessionDump() + if err != nil { + t.Fatalf("err: %s", err) + } + if !reflect.DeepEqual(dump, sessions) { + t.Fatalf("bad: %#v", dump) + } + + // Restore the sessions into a new state store. + func() { + s := testStateStore(t) + for _, session := range dump { + if err := s.SessionRestore(session); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Read the restored sessions back out and verify that they + // match. + idx, res, err := s.SessionList() + if err != nil { + t.Fatalf("err: %s", err) + } + if idx != 7 { + t.Fatalf("bad index: %d", idx) + } + if !reflect.DeepEqual(res, sessions) { + t.Fatalf("bad: %#v", res) + } + + // Check that the index was updated. + if idx := s.maxIndex("sessions"); idx != 7 { + t.Fatalf("bad index: %d", idx) + } + + // Manually verify that the session check mapping got restored. + tx := s.db.Txn(false) + defer tx.Abort() + + check, err := tx.First("session_checks", "session", "session1") + if err != nil { + t.Fatalf("err: %s", err) + } + if check == nil { + t.Fatalf("missing session check") + } + expectCheck := &sessionCheck{ + Node: "node1", + CheckID: "check1", + Session: "session1", + } + if actual := check.(*sessionCheck); !reflect.DeepEqual(actual, expectCheck) { + t.Fatalf("expected %#v, got: %#v", expectCheck, actual) + } + }() +} + +func TestStateStore_Session_Watches(t *testing.T) { + s := testStateStore(t) + + // Register a test node. + testRegisterNode(t, s, 1, "node1") + + // This just covers the basics. The session invalidation tests above + // cover the more nuanced multiple table watches. + verifyWatch(t, s.GetTableWatch("sessions"), func() { + session := &structs.Session{ + ID: "session1", + Node: "node1", + Behavior: structs.SessionKeysDelete, + } + if err := s.SessionCreate(2, session); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.GetTableWatch("sessions"), func() { + if err := s.SessionDestroy(3, "session1"); err != nil { + t.Fatalf("err: %s", err) + } + }) + verifyWatch(t, s.GetTableWatch("sessions"), func() { + session := &structs.Session{ + ID: "session1", + Node: "node1", + Behavior: structs.SessionKeysDelete, + } + if err := s.SessionRestore(session); err != nil { + t.Fatalf("err: %s", err) + } + }) +} + func TestStateStore_ACLSet_ACLGet(t *testing.T) { s := testStateStore(t) @@ -3053,7 +3194,7 @@ func TestStateStore_ACLSet_ACLGet(t *testing.T) { // Check that the index was updated if idx := s.maxIndex("acls"); idx != 1 { - t.Fatalf("err: %s", err) + t.Fatalf("bad index: %d", idx) } // Retrieve the ACL again @@ -3265,6 +3406,11 @@ func TestStateStore_ACL_Snapshot_Restore(t *testing.T) { if !reflect.DeepEqual(res, acls) { t.Fatalf("bad: %#v", res) } + + // Check that the index was updated. + if idx := s.maxIndex("acls"); idx != 2 { + t.Fatalf("bad index: %d", idx) + } }() }