finish implementing better state store

This commit is contained in:
Timo Kösters 2021-08-12 23:04:00 +02:00
parent 31f60ad6fd
commit 3eabaa2a95
No known key found for this signature in database
GPG key ID: 356E705610F626D5
10 changed files with 645 additions and 526 deletions

View file

@ -249,6 +249,8 @@ pub async fn register_route(
let room_id = RoomId::new(db.globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let mutex_state = Arc::clone(
db.globals
.roomid_mutex_state

View file

@ -44,7 +44,7 @@ pub async fn get_context_route(
let events_before = db
.rooms
.pdus_until(&sender_user, &body.room_id, base_token)
.pdus_until(&sender_user, &body.room_id, base_token)?
.take(
u32::try_from(body.limit).map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.")
@ -66,7 +66,7 @@ pub async fn get_context_route(
let events_after = db
.rooms
.pdus_after(&sender_user, &body.room_id, base_token)
.pdus_after(&sender_user, &body.room_id, base_token)?
.take(
u32::try_from(body.limit).map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.")

View file

@ -609,6 +609,8 @@ async fn join_room_by_id_helper(
)
.await?;
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let pdu = PduEvent::from_id_val(&event_id, join_event.clone())
.map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?;

View file

@ -128,7 +128,7 @@ pub async fn get_message_events_route(
get_message_events::Direction::Forward => {
let events_after = db
.rooms
.pdus_after(&sender_user, &body.room_id, from)
.pdus_after(&sender_user, &body.room_id, from)?
.take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| {
@ -158,7 +158,7 @@ pub async fn get_message_events_route(
get_message_events::Direction::Backward => {
let events_before = db
.rooms
.pdus_until(&sender_user, &body.room_id, from)
.pdus_until(&sender_user, &body.room_id, from)?
.take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| {

View file

@ -33,6 +33,8 @@ pub async fn create_room_route(
let room_id = RoomId::new(db.globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let mutex_state = Arc::clone(
db.globals
.roomid_mutex_state
@ -173,7 +175,6 @@ pub async fn create_room_route(
)?;
// 4. Canonical room alias
if let Some(room_alias_id) = &alias {
db.rooms.build_and_append_pdu(
PduBuilder {
@ -193,7 +194,7 @@ pub async fn create_room_route(
&room_id,
&db,
&state_lock,
);
)?;
}
// 5. Events set by preset

View file

@ -205,7 +205,7 @@ async fn sync_helper(
let mut non_timeline_pdus = db
.rooms
.pdus_until(&sender_user, &room_id, u64::MAX)
.pdus_until(&sender_user, &room_id, u64::MAX)?
.filter_map(|r| {
// Filter out buggy events
if r.is_err() {
@ -248,13 +248,13 @@ async fn sync_helper(
let first_pdu_before_since = db
.rooms
.pdus_until(&sender_user, &room_id, since)
.pdus_until(&sender_user, &room_id, since)?
.next()
.transpose()?;
let pdus_after_since = db
.rooms
.pdus_after(&sender_user, &room_id, since)
.pdus_after(&sender_user, &room_id, since)?
.next()
.is_some();
@ -286,7 +286,7 @@ async fn sync_helper(
for hero in db
.rooms
.all_pdus(&sender_user, &room_id)
.all_pdus(&sender_user, &room_id)?
.filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
.filter(|(_, pdu)| pdu.kind == EventType::RoomMember)
.map(|(_, pdu)| {
@ -328,11 +328,11 @@ async fn sync_helper(
}
}
(
Ok::<_, Error>((
Some(joined_member_count),
Some(invited_member_count),
heroes,
)
))
};
let (
@ -343,7 +343,7 @@ async fn sync_helper(
state_events,
) = if since_shortstatehash.is_none() {
// Probably since = 0, we will do an initial sync
let (joined_member_count, invited_member_count, heroes) = calculate_counts();
let (joined_member_count, invited_member_count, heroes) = calculate_counts()?;
let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?;
let state_events = current_state_ids
@ -510,7 +510,7 @@ async fn sync_helper(
}
let (joined_member_count, invited_member_count, heroes) = if send_member_count {
calculate_counts()
calculate_counts()?
} else {
(None, None, Vec::new())
};

View file

@ -28,7 +28,7 @@ use ruma::{DeviceId, EventId, RoomId, ServerName, UserId};
use serde::{de::IgnoredAny, Deserialize};
use std::{
collections::{BTreeMap, HashMap, HashSet},
convert::TryFrom,
convert::{TryFrom, TryInto},
fs::{self, remove_dir_all},
io::Write,
mem::size_of,
@ -266,7 +266,6 @@ impl Database {
shortroomid_roomid: builder.open_tree("shortroomid_roomid")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
stateid_shorteventid: builder.open_tree("stateid_shorteventid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
@ -431,7 +430,6 @@ impl Database {
}
if db.globals.database_version()? < 6 {
// TODO update to 6
// Set room member count
for (roomid, _) in db.rooms.roomid_shortstatehash.iter() {
let room_id =
@ -445,263 +443,98 @@ impl Database {
println!("Migration: 5 -> 6 finished");
}
fn load_shortstatehash_info(
shortstatehash: &[u8],
db: &Database,
lru: &mut LruCache<
Vec<u8>,
Vec<(
Vec<u8>,
HashSet<Vec<u8>>,
HashSet<Vec<u8>>,
HashSet<Vec<u8>>,
)>,
>,
) -> Result<
Vec<(
Vec<u8>, // sstatehash
HashSet<Vec<u8>>, // full state
HashSet<Vec<u8>>, // added
HashSet<Vec<u8>>, // removed
)>,
> {
if let Some(result) = lru.get_mut(shortstatehash) {
return Ok(result.clone());
}
let value = db
.rooms
.shortstatehash_statediff
.get(shortstatehash)?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = value[0..size_of::<u64>()].to_vec();
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.to_vec());
} else {
removed.insert(v.to_vec());
}
i += 2 * size_of::<u64>();
}
if parent != 0_u64.to_be_bytes() {
let mut response = load_shortstatehash_info(&parent, db, lru)?;
let mut state = response.last().unwrap().1.clone();
state.extend(added.iter().cloned());
for r in &removed {
state.remove(r);
}
response.push((shortstatehash.to_vec(), state, added, removed));
lru.insert(shortstatehash.to_vec(), response.clone());
Ok(response)
} else {
let mut response = Vec::new();
response.push((shortstatehash.to_vec(), added.clone(), added, removed));
lru.insert(shortstatehash.to_vec(), response.clone());
Ok(response)
}
}
fn update_shortstatehash_level(
current_shortstatehash: &[u8],
statediffnew: HashSet<Vec<u8>>,
statediffremoved: HashSet<Vec<u8>>,
diff_to_sibling: usize,
mut parent_states: Vec<(
Vec<u8>, // sstatehash
HashSet<Vec<u8>>, // full state
HashSet<Vec<u8>>, // added
HashSet<Vec<u8>>, // removed
)>,
db: &Database,
) -> Result<()> {
let diffsum = statediffnew.len() + statediffremoved.len();
if parent_states.len() > 3 {
// Number of layers
// To many layers, we have to go deeper
let parent = parent_states.pop().unwrap();
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
update_shortstatehash_level(
current_shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
db,
)?;
return Ok(());
}
if parent_states.len() == 0 {
// There is no parent layer, create a new state
let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent
for new in &statediffnew {
value.extend_from_slice(&new);
}
if !statediffremoved.is_empty() {
warn!("Tried to create new state with removals");
}
db.rooms
.shortstatehash_statediff
.insert(&current_shortstatehash, &value)?;
return Ok(());
};
// Else we have two options.
// 1. We add the current diff on top of the parent layer.
// 2. We replace a layer above
let parent = parent_states.pop().unwrap();
let parent_diff = parent.2.len() + parent.3.len();
if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff {
// Diff too big, we replace above layer(s)
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
update_shortstatehash_level(
current_shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
db,
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
let mut value = parent.0.clone();
for new in &statediffnew {
value.extend_from_slice(&new);
}
if !statediffremoved.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in &statediffremoved {
value.extend_from_slice(&removed);
}
}
db.rooms
.shortstatehash_statediff
.insert(&current_shortstatehash, &value)?;
}
Ok(())
}
if db.globals.database_version()? < 7 {
// Upgrade state store
let mut lru = LruCache::new(1000);
let mut last_roomstates: HashMap<RoomId, Vec<u8>> = HashMap::new();
let mut current_sstatehash: Vec<u8> = Vec::new();
let mut last_roomstates: HashMap<RoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
let mut counter = 0;
let mut handle_state =
|current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
counter += 1;
println!("counter: {}", counter);
let last_roomsstatehash = last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash))
},
)?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.cloned()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
db.rooms.save_state_from_diff(
dbg!(current_sstatehash),
statediffnew,
statediffremoved,
2, // every state change is 2 event changes on average
states_parents,
)?;
/*
let mut tmp = db.rooms.load_shortstatehash_info(&current_sstatehash, &db)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
Ok::<_, Error>(())
};
for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() {
let sstatehash = k[0..size_of::<u64>()].to_vec();
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()])
.expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if sstatehash != current_sstatehash {
if !current_sstatehash.is_empty() {
counter += 1;
println!("counter: {}", counter);
let current_room = current_room.as_ref().unwrap();
let last_roomsstatehash = last_roomstates.get(&current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|last_roomsstatehash| {
load_shortstatehash_info(&last_roomsstatehash, &db, &mut lru)
},
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_ref().unwrap(),
current_state,
&mut last_roomstates,
)?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.cloned()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
update_shortstatehash_level(
&current_sstatehash,
statediffnew,
statediffremoved,
2, // every state change is 2 event changes on average
states_parents,
&db,
)?;
/*
let mut tmp = load_shortstatehash_info(&current_sstatehash, &db)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
last_roomstates.insert(current_room.clone(), current_sstatehash);
last_roomstates
.insert(current_room.clone().unwrap(), current_sstatehash);
}
current_state = HashSet::new();
current_sstatehash = sstatehash;
current_sstatehash = Some(sstatehash);
let event_id = db
.rooms
@ -721,7 +554,16 @@ impl Database {
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state.insert(val);
current_state.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_ref().unwrap(),
current_state,
&mut last_roomstates,
)?;
}
db.globals.bump_database_version(7)?;
@ -761,11 +603,28 @@ impl Database {
db.rooms.pduid_pdu.insert_batch(&mut batch)?;
for (key, _) in db.rooms.pduid_pdu.iter() {
if key.starts_with(b"!") {
db.rooms.pduid_pdu.remove(&key);
let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| {
if !value.starts_with(b"!") {
return None;
}
}
let mut parts = value.splitn(2, |&b| b == 0xff);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(&room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
Some((k, new_value))
});
db.rooms.eventid_pduid.insert_batch(&mut batch2)?;
db.globals.bump_database_version(8)?;
@ -803,7 +662,7 @@ impl Database {
for (key, _) in db.rooms.tokenids.iter() {
if key.starts_with(b"!") {
db.rooms.pduid_pdu.remove(&key)?;
db.rooms.tokenids.remove(&key)?;
}
}
@ -811,8 +670,6 @@ impl Database {
println!("Migration: 8 -> 9 finished");
}
panic!();
}
let guard = db.read().await;

View file

@ -9,13 +9,13 @@ use std::{
path::{Path, PathBuf},
pin::Pin,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::oneshot::Sender;
use tracing::debug;
thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
}
struct PreparedStatementIterator<'a> {
@ -77,6 +77,21 @@ impl Engine {
})
}
fn read_lock_iterator(&self) -> &'static Connection {
READ_CONNECTION_ITERATOR.with(|cell| {
let connection = &mut cell.borrow_mut();
if (*connection).is_none() {
let c = Box::leak(Box::new(
Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(),
));
**connection = Some(c);
}
connection.unwrap()
})
}
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
self.write_lock()
.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?;
@ -151,6 +166,34 @@ impl SqliteTable {
)?;
Ok(())
}
pub fn iter_with_guard<'a>(
&'a self,
guard: &'a Connection,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let statement = Box::leak(Box::new(
guard
.prepare(&format!(
"SELECT key, value FROM {} ORDER BY key ASC",
&self.name
))
.unwrap(),
));
let statement_ref = NonAliasingBox(statement);
let iterator = Box::new(
statement
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap()
.map(|r| r.unwrap()),
);
Box::new(PreparedStatementIterator {
iterator,
statement_ref,
})
}
}
impl Tree for SqliteTable {
@ -219,30 +262,9 @@ impl Tree for SqliteTable {
#[tracing::instrument(skip(self))]
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock();
let guard = self.engine.read_lock_iterator();
let statement = Box::leak(Box::new(
guard
.prepare(&format!(
"SELECT key, value FROM {} ORDER BY key ASC",
&self.name
))
.unwrap(),
));
let statement_ref = NonAliasingBox(statement);
let iterator = Box::new(
statement
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap()
.map(|r| r.unwrap()),
);
Box::new(PreparedStatementIterator {
iterator,
statement_ref,
})
self.iter_with_guard(&guard)
}
#[tracing::instrument(skip(self, from, backwards))]
@ -251,7 +273,7 @@ impl Tree for SqliteTable {
from: &[u8],
backwards: bool,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock();
let guard = self.engine.read_lock_iterator();
let from = from.to_vec(); // TODO change interface?
if backwards {

View file

@ -24,7 +24,7 @@ use ruma::{
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
convert::{TryFrom, TryInto},
mem,
mem::size_of,
sync::{Arc, Mutex},
};
use tokio::sync::MutexGuard;
@ -37,10 +37,11 @@ use super::{abstraction::Tree, admin::AdminCommand, pusher};
/// This is created when a state group is added to the database by
/// hashing the entire state.
pub type StateHashId = Vec<u8>;
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Rooms {
pub edus: edus::RoomEdus,
pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = RoomId + Count
pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = ShortRoomId + Count
pub(super) eventid_pduid: Arc<dyn Tree>,
pub(super) roomid_pduleaves: Arc<dyn Tree>,
pub(super) alias_roomid: Arc<dyn Tree>,
@ -79,9 +80,6 @@ pub struct Rooms {
pub(super) eventid_shorteventid: Arc<dyn Tree>,
pub(super) statehash_shortstatehash: Arc<dyn Tree>,
/// ShortStateHash = Count
/// StateId = ShortStateHash
pub(super) stateid_shorteventid: Arc<dyn Tree>,
pub(super) shortstatehash_statediff: Arc<dyn Tree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
/// RoomId + EventId -> outlier PDU.
@ -100,29 +98,30 @@ impl Rooms {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> {
Ok(self
.stateid_shorteventid
.scan_prefix(shortstatehash.to_be_bytes().to_vec())
.map(|(_, bytes)| {
self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap())
.ok()
})
.flatten()
.collect())
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
full_state
.into_iter()
.map(|compressed| self.parse_compressed_state_event(compressed))
.collect()
}
pub fn state_full(
&self,
shortstatehash: u64,
) -> Result<HashMap<(EventType, String), Arc<PduEvent>>> {
let state = self
.stateid_shorteventid
.scan_prefix(shortstatehash.to_be_bytes().to_vec())
.map(|(_, bytes)| {
self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap())
.ok()
})
.flatten()
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.map(|compressed| self.parse_compressed_state_event(compressed))
.filter_map(|r| r.ok())
.map(|eventid| self.get_pdu(&eventid))
.filter_map(|r| r.ok().flatten())
.map(|pdu| {
@ -138,9 +137,7 @@ impl Rooms {
))
})
.filter_map(|r| r.ok())
.collect();
Ok(state)
.collect())
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
@ -151,27 +148,19 @@ impl Rooms {
event_type: &EventType,
state_key: &str,
) -> Result<Option<EventId>> {
let mut key = event_type.as_ref().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(&state_key.as_bytes());
let shortstatekey = self.statekey_shortstatekey.get(&key)?;
if let Some(shortstatekey) = shortstatekey {
let mut stateid = shortstatehash.to_be_bytes().to_vec();
stateid.extend_from_slice(&shortstatekey);
Ok(self
.stateid_shorteventid
.get(&stateid)?
.map(|bytes| {
self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap())
.ok()
})
.flatten())
} else {
Ok(None)
}
let shortstatekey = match self.get_shortstatekey(event_type, state_key)? {
Some(s) => s,
None => return Ok(None),
};
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| self.parse_compressed_state_event(compressed).ok()))
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
@ -260,8 +249,7 @@ impl Rooms {
/// Checks if a room exists.
pub fn exists(&self, room_id: &RoomId) -> Result<bool> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
// Look for PDUs in that room.
Ok(self
@ -274,8 +262,7 @@ impl Rooms {
/// Checks if a room exists.
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
// Look for PDUs in that room.
self.pduid_pdu
@ -292,74 +279,78 @@ impl Rooms {
/// Force the creation of a new StateHash and insert it into the db.
///
/// Whatever `state` is supplied to `force_state` __is__ the current room state snapshot.
/// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot.
pub fn force_state(
&self,
room_id: &RoomId,
state: HashMap<(EventType, String), EventId>,
new_state: HashMap<(EventType, String), EventId>,
db: &Database,
) -> Result<()> {
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let new_state_ids_compressed = new_state
.iter()
.filter_map(|((event_type, state_key), event_id)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, &db.globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, event_id, &db.globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash(
&state
&new_state
.values()
.map(|event_id| event_id.as_bytes())
.collect::<Vec<_>>(),
);
let (shortstatehash, already_existed) =
let (new_shortstatehash, already_existed) =
self.get_or_create_shortstatehash(&state_hash, &db.globals)?;
let new_state = if !already_existed {
let mut new_state = HashSet::new();
if Some(new_shortstatehash) == previous_shortstatehash {
return Ok(());
}
let batch = state
.iter()
.filter_map(|((event_type, state_key), eventid)| {
new_state.insert(eventid.clone());
let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last()
{
let statediffnew = new_state_ids_compressed
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? {
Some(shortstatekey) => shortstatekey.to_vec(),
None => {
let shortstatekey = db.globals.next_count().ok()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())
.ok()?;
shortstatekey.to_be_bytes().to_vec()
}
};
let statediffremoved = parent_stateinfo
.1
.difference(&new_state_ids_compressed)
.cloned()
.collect::<HashSet<_>>();
let shorteventid = self
.get_or_create_shorteventid(&eventid, &db.globals)
.ok()?;
let mut state_id = shortstatehash.to_be_bytes().to_vec();
state_id.extend_from_slice(&shortstatekey);
Some((state_id, shorteventid.to_be_bytes().to_vec()))
})
.collect::<Vec<_>>();
self.stateid_shorteventid
.insert_batch(&mut batch.into_iter())?;
new_state
(statediffnew, statediffremoved)
} else {
self.state_full_ids(shortstatehash)?.into_iter().collect()
(new_state_ids_compressed, HashSet::new())
};
let old_state = self
.current_shortstatehash(&room_id)?
.map(|s| self.state_full_ids(s))
.transpose()?
.map(|vec| vec.into_iter().collect::<HashSet<_>>())
.unwrap_or_default();
if !already_existed {
self.save_state_from_diff(
new_shortstatehash,
statediffnew.clone(),
statediffremoved.clone(),
2, // every state change is 2 event changes on average
states_parents,
)?;
};
for event_id in new_state.difference(&old_state) {
if let Some(pdu) = self.get_pdu_json(event_id)? {
for event_id in statediffnew
.into_iter()
.filter_map(|new| self.parse_compressed_state_event(new).ok())
{
if let Some(pdu) = self.get_pdu_json(&event_id)? {
if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") {
if let Ok(pdu) = serde_json::from_value::<PduEvent>(
serde_json::to_value(&pdu).expect("CanonicalJsonObj is a valid JsonValue"),
@ -392,7 +383,206 @@ impl Rooms {
}
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?;
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
/// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer.
pub fn load_shortstatehash_info(
&self,
shortstatehash: u64,
) -> Result<
Vec<(
u64, // sstatehash
HashSet<CompressedStateEvent>, // full state
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>, // removed
)>,
> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent =
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
if parent != 0_u64 {
let mut response = self.load_shortstatehash_info(parent)?;
let mut state = response.last().unwrap().1.clone();
state.extend(added.iter().cloned());
for r in &removed {
state.remove(r);
}
response.push((shortstatehash, state, added, removed));
Ok(response)
} else {
let mut response = Vec::new();
response.push((shortstatehash, added.clone(), added, removed));
Ok(response)
}
}
pub fn compress_state_event(
&self,
shortstatekey: u64,
event_id: &EventId,
globals: &super::globals::Globals,
) -> Result<CompressedStateEvent> {
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
&self
.get_or_create_shorteventid(event_id, globals)?
.to_be_bytes(),
);
Ok(v.try_into().expect("we checked the size above"))
}
pub fn parse_compressed_state_event(
&self,
compressed_event: CompressedStateEvent,
) -> Result<EventId> {
self.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..])
.expect("bytes have right length"),
)
}
/// Creates a new shortstatehash that often is just a diff to an already existing
/// shortstatehash and therefore very efficient.
///
/// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer
/// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0
/// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's
/// based on layer n-2. If that layer is also too big, it will recursively fix above layers too.
///
/// * `shortstatehash` - Shortstatehash of this state
/// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid
/// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid
/// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer
/// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer
pub fn save_state_from_diff(
&self,
shortstatehash: u64,
statediffnew: HashSet<CompressedStateEvent>,
statediffremoved: HashSet<CompressedStateEvent>,
diff_to_sibling: usize,
mut parent_states: Vec<(
u64, // sstatehash
HashSet<CompressedStateEvent>, // full state
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>, // removed
)>,
) -> Result<()> {
let diffsum = statediffnew.len() + statediffremoved.len();
if parent_states.len() > 3 {
// Number of layers
// To many layers, we have to go deeper
let parent = parent_states.pop().unwrap();
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
self.save_state_from_diff(
shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
)?;
return Ok(());
}
if parent_states.len() == 0 {
// There is no parent layer, create a new state
let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent
for new in &statediffnew {
value.extend_from_slice(&new[..]);
}
if !statediffremoved.is_empty() {
warn!("Tried to create new state with removals");
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)?;
return Ok(());
};
// Else we have two options.
// 1. We add the current diff on top of the parent layer.
// 2. We replace a layer above
let parent = parent_states.pop().unwrap();
let parent_diff = parent.2.len() + parent.3.len();
if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff {
// Diff too big, we replace above layer(s)
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
self.save_state_from_diff(
shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
let mut value = parent.0.to_be_bytes().to_vec();
for new in &statediffnew {
value.extend_from_slice(&new[..]);
}
if !statediffremoved.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in &statediffremoved {
value.extend_from_slice(&removed[..]);
}
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)?;
}
Ok(())
}
@ -418,7 +608,6 @@ impl Rooms {
})
}
/// Returns (shortstatehash, already_existed)
pub fn get_or_create_shorteventid(
&self,
event_id: &EventId,
@ -438,6 +627,71 @@ impl Rooms {
})
}
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
let bytes = self
.roomid_shortroomid
.get(&room_id.as_bytes())?
.expect("every room has a shortroomid");
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))
}
pub fn get_shortstatekey(
&self,
event_type: &EventType,
state_key: &str,
) -> Result<Option<u64>> {
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
self.statekey_shortstatekey
.get(&statekey)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()
}
pub fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
globals: &super::globals::Globals,
) -> Result<u64> {
Ok(match self.roomid_shortroomid.get(&room_id.as_bytes())? {
Some(short) => utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
None => {
let short = globals.next_count()?;
self.roomid_shortroomid
.insert(&room_id.as_bytes(), &short.to_be_bytes())?;
short
}
})
}
pub fn get_or_create_shortstatekey(
&self,
event_type: &EventType,
state_key: &str,
globals: &super::globals::Globals,
) -> Result<u64> {
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
Ok(match self.statekey_shortstatekey.get(&statekey)? {
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
None => {
let shortstatekey = globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())?;
shortstatekey
}
})
}
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> {
if let Some(id) = self
.shorteventid_cache
@ -514,7 +768,7 @@ impl Rooms {
#[tracing::instrument(skip(self))]
pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> {
Ok(
utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()])
utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?,
)
}
@ -527,8 +781,7 @@ impl Rooms {
}
pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -758,6 +1011,8 @@ impl Rooms {
///
/// By this point the incoming event should be fully authenticated, no auth happens
/// in `append_pdu`.
///
/// Returns pdu id
#[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))]
pub fn append_pdu(
&self,
@ -766,7 +1021,8 @@ impl Rooms {
leaves: &[EventId],
db: &Database,
) -> Result<Vec<u8>> {
// returns pdu id
let shortroomid = self.get_shortroomid(&pdu.room_id)?;
// Make unsigned fields correct. This is not properly documented in the spec, but state
// events need to have previous content in the unsigned field, so clients can easily
// interpret things like membership changes
@ -821,8 +1077,7 @@ impl Rooms {
self.reset_notification_counts(&pdu.sender, &pdu.room_id)?;
let count2 = db.globals.next_count()?;
let mut pdu_id = pdu.room_id.as_bytes().to_vec();
pdu_id.push(0xff);
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&count2.to_be_bytes());
// There's a brief moment of time here where the count is updated but the pdu does not
@ -968,8 +1223,7 @@ impl Rooms {
.filter(|word| word.len() <= 50)
.map(str::to_lowercase)
.map(|word| {
let mut key = pdu.room_id.as_bytes().to_vec();
key.push(0xff);
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xff);
key.extend_from_slice(&pdu_id);
@ -1152,11 +1406,27 @@ impl Rooms {
pub fn set_event_state(
&self,
event_id: &EventId,
room_id: &RoomId,
state: &StateMap<Arc<PduEvent>>,
globals: &super::globals::Globals,
) -> Result<()> {
let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let state_ids_compressed = state
.iter()
.filter_map(|((event_type, state_key), pdu)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, &pdu.event_id, globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash(
&state
.values()
@ -1168,37 +1438,33 @@ impl Rooms {
self.get_or_create_shortstatehash(&state_hash, globals)?;
if !already_existed {
let batch = state
.iter()
.filter_map(|((event_type, state_key), pdu)| {
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? {
Some(shortstatekey) => shortstatekey.to_vec(),
None => {
let shortstatekey = globals.next_count().ok()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())
.ok()?;
shortstatekey.to_be_bytes().to_vec()
}
};
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = state_ids_compressed
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let shorteventid = self
.get_or_create_shorteventid(&pdu.event_id, globals)
.ok()?;
let statediffremoved = parent_stateinfo
.1
.difference(&state_ids_compressed)
.cloned()
.collect::<HashSet<_>>();
let mut state_id = shortstatehash.to_be_bytes().to_vec();
state_id.extend_from_slice(&shortstatekey);
Some((state_id, shorteventid.to_be_bytes().to_vec()))
})
.collect::<Vec<_>>();
self.stateid_shorteventid
.insert_batch(&mut batch.into_iter())?;
(statediffnew, statediffremoved)
} else {
(state_ids_compressed, HashSet::new())
};
self.save_state_from_diff(
shortstatehash,
statediffnew.clone(),
statediffremoved.clone(),
1_000_000, // high number because no state will be based on this one
states_parents,
)?;
}
self.shorteventid_shortstatehash
@ -1219,82 +1485,52 @@ impl Rooms {
) -> Result<u64> {
let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?;
let old_state = if let Some(old_shortstatehash) =
self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())?
{
// Store state for event. The state does not include the event itself.
// Instead it's the state before the pdu, so the room's old state.
let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?;
if let Some(p) = previous_shortstatehash {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?;
if new_pdu.state_key.is_none() {
return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.")
});
}
self.stateid_shorteventid
.scan_prefix(old_shortstatehash.clone())
// Chop the old_shortstatehash out leaving behind the short state key
.map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v))
.collect::<HashMap<Vec<u8>, Vec<u8>>>()
} else {
HashMap::new()
};
.insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?;
}
if let Some(state_key) = &new_pdu.state_key {
let mut new_state: HashMap<Vec<u8>, Vec<u8>> = old_state;
let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec();
new_state_key.push(0xff);
new_state_key.extend_from_slice(state_key.as_bytes());
let shortstatekey =
self.get_or_create_shortstatekey(&new_pdu.kind, &state_key, globals)?;
let shortstatekey = match self.statekey_shortstatekey.get(&new_state_key)? {
Some(shortstatekey) => shortstatekey.to_vec(),
None => {
let shortstatekey = globals.next_count()?;
self.statekey_shortstatekey
.insert(&new_state_key, &shortstatekey.to_be_bytes())?;
shortstatekey.to_be_bytes().to_vec()
}
};
let replaces = states_parents
.last()
.map(|info| {
info.1
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
})
.unwrap_or_default();
new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec());
// TODO: statehash with deterministic inputs
let shortstatehash = globals.next_count()?;
let new_state_hash = self.calculate_hash(
&new_state
.values()
.map(|event_id| &**event_id)
.collect::<Vec<_>>(),
);
let mut statediffnew = HashSet::new();
let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?;
statediffnew.insert(new);
let shortstatehash = match self.statehash_shortstatehash.get(&new_state_hash)? {
Some(shortstatehash) => {
warn!("state hash already existed?!");
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?
}
None => {
let shortstatehash = globals.next_count()?;
self.statehash_shortstatehash
.insert(&new_state_hash, &shortstatehash.to_be_bytes())?;
shortstatehash
}
};
let mut statediffremoved = HashSet::new();
if let Some(replaces) = replaces {
statediffremoved.insert(replaces.clone());
}
let mut batch = new_state.into_iter().map(|(shortstatekey, shorteventid)| {
let mut state_id = shortstatehash.to_be_bytes().to_vec();
state_id.extend_from_slice(&shortstatekey);
(state_id, shorteventid)
});
self.stateid_shorteventid.insert_batch(&mut batch)?;
self.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
2,
states_parents,
)?;
Ok(shortstatehash)
} else {
Err(Error::bad_database(
"Tried to insert non-state event into room without a state.",
))
Ok(previous_shortstatehash.expect("first event in room must be a state event"))
}
}
@ -1597,7 +1833,7 @@ impl Rooms {
&'a self,
user_id: &UserId,
room_id: &RoomId,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a {
) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
self.pdus_since(user_id, room_id, 0)
}
@ -1609,16 +1845,17 @@ impl Rooms {
user_id: &UserId,
room_id: &RoomId,
since: u64,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
// Skip the first pdu if it's exactly at since, because we sent that last time
let mut first_pdu_id = prefix.clone();
first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes());
let user_id = user_id.clone();
self.pduid_pdu
Ok(self
.pduid_pdu
.iter_from(&first_pdu_id, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
@ -1628,7 +1865,7 @@ impl Rooms {
pdu.unsigned.remove("transaction_id");
}
Ok((pdu_id, pdu))
})
}))
}
/// Returns an iterator over all events and their tokens in a room that happened before the
@ -1639,10 +1876,9 @@ impl Rooms {
user_id: &UserId,
room_id: &RoomId,
until: u64,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a {
) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
// Create the first part of the full pdu id
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until`
@ -1650,7 +1886,9 @@ impl Rooms {
let current: &[u8] = &current;
let user_id = user_id.clone();
self.pduid_pdu
Ok(self
.pduid_pdu
.iter_from(current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
@ -1660,7 +1898,7 @@ impl Rooms {
pdu.unsigned.remove("transaction_id");
}
Ok((pdu_id, pdu))
})
}))
}
/// Returns an iterator over all events and their token in a room that happened after the event
@ -1671,10 +1909,9 @@ impl Rooms {
user_id: &UserId,
room_id: &RoomId,
from: u64,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a {
) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
// Create the first part of the full pdu id
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event
@ -1682,7 +1919,9 @@ impl Rooms {
let current: &[u8] = &current;
let user_id = user_id.clone();
self.pduid_pdu
Ok(self
.pduid_pdu
.iter_from(current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
@ -1692,7 +1931,7 @@ impl Rooms {
pdu.unsigned.remove("transaction_id");
}
Ok((pdu_id, pdu))
})
}))
}
/// Replace a PDU with the redacted form.
@ -2223,8 +2462,8 @@ impl Rooms {
room_id: &RoomId,
search_string: &str,
) -> Result<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
let prefix_clone = prefix.clone();
let words = search_string
.split_terminator(|c: char| !c.is_alphanumeric())
@ -2243,16 +2482,7 @@ impl Rooms {
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(|(key, _)| {
let pduid_index = key
.iter()
.enumerate()
.filter(|(_, &b)| b == 0xff)
.nth(1)
.ok_or_else(|| Error::bad_database("Invalid tokenid in db."))?
.0
+ 1; // +1 because the pdu id starts AFTER the separator
let pdu_id = key[pduid_index..].to_vec();
let pdu_id = key[key.len() - size_of::<u64>()..].to_vec();
Ok::<_, Error>(pdu_id)
})
@ -2264,7 +2494,12 @@ impl Rooms {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
})
.unwrap(),
.unwrap()
.map(move |id| {
let mut pduid = prefix_clone.clone();
pduid.extend_from_slice(&id);
pduid
}),
words,
))
}

View file

@ -1704,7 +1704,7 @@ fn append_incoming_pdu(
// We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail.
db.rooms
.set_event_state(&pdu.event_id, state, &db.globals)?;
.set_event_state(&pdu.event_id, &pdu.room_id, state, &db.globals)?;
let pdu_id = db.rooms.append_pdu(
pdu,