From 7bdd9660aa51b2d3d0d39b35b14e49c3e4d6a23a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Mon, 20 Feb 2023 22:59:45 +0100 Subject: [PATCH] feat: ask for backfill --- src/api/client_server/context.rs | 33 ++-- src/api/client_server/membership.rs | 6 +- src/api/client_server/message.rs | 52 +++-- src/api/client_server/read_marker.rs | 60 ++++-- src/api/client_server/sync.rs | 34 ++-- src/api/server_server.rs | 64 +++--- src/database/key_value/rooms/timeline.rs | 237 ++++++++++++++--------- src/database/mod.rs | 11 +- src/service/rooms/lazy_loading/mod.rs | 8 +- src/service/rooms/state_accessor/mod.rs | 2 +- src/service/rooms/timeline/data.rs | 36 ++-- src/service/rooms/timeline/mod.rs | 228 +++++++++++++++++++--- 12 files changed, 502 insertions(+), 269 deletions(-) diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index 1e62f910..fa3c7543 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -27,25 +27,24 @@ pub async fn get_context_route( let mut lazy_loaded = HashSet::new(); - let base_pdu_id = services() + let base_token = services() .rooms .timeline - .get_pdu_id(&body.event_id)? + .get_pdu_count(&body.event_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "Base event id not found.", ))?; - let base_token = services().rooms.timeline.pdu_count(&base_pdu_id)?; - - let base_event = services() - .rooms - .timeline - .get_pdu_from_id(&base_pdu_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Base event not found.", - ))?; + let base_event = + services() + .rooms + .timeline + .get_pdu(&body.event_id)? 
+ .ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Base event not found.", + ))?; let room_id = base_event.room_id.clone(); @@ -97,10 +96,7 @@ pub async fn get_context_route( } } - let start_token = events_before - .last() - .and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok()) - .map(|count| count.to_string()); + let start_token = events_before.last().map(|(count, _)| count.stringify()); let events_before: Vec<_> = events_before .into_iter() @@ -151,10 +147,7 @@ pub async fn get_context_route( .state_full_ids(shortstatehash) .await?; - let end_token = events_after - .last() - .and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok()) - .map(|count| count.to_string()); + let end_token = events_after.last().map(|(count, _)| count.stringify()); let events_after: Vec<_> = events_after .into_iter() diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 61c67cbc..cd0cc7a7 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -714,8 +714,10 @@ async fn join_room_by_id_helper( .ok()? }, ) - .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed"))? - { + .map_err(|e| { + warn!("Auth check failed: {e}"); + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") + })? 
{ return Err(Error::BadRequest( ErrorKind::InvalidParam, "Auth check failed", diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index 6ad07517..a0c9571b 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -1,4 +1,7 @@ -use crate::{service::pdu::PduBuilder, services, utils, Error, Result, Ruma}; +use crate::{ + service::{pdu::PduBuilder, rooms::timeline::PduCount}, + services, utils, Error, Result, Ruma, +}; use ruma::{ api::client::{ error::ErrorKind, @@ -122,17 +125,17 @@ pub async fn get_message_events_route( } let from = match body.from.clone() { - Some(from) => from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from` value."))?, - + Some(from) => PduCount::try_from_string(&from)?, None => match body.dir { - ruma::api::client::Direction::Forward => 0, - ruma::api::client::Direction::Backward => u64::MAX, + ruma::api::client::Direction::Forward => PduCount::min(), + ruma::api::client::Direction::Backward => PduCount::max(), }, }; - let to = body.to.as_ref().map(|t| t.parse()); + let to = body + .to + .as_ref() + .and_then(|t| PduCount::try_from_string(&t).ok()); services().rooms.lazy_loading.lazy_load_confirm_delivery( sender_user, @@ -158,15 +161,7 @@ pub async fn get_message_events_route( .pdus_after(sender_user, &body.room_id, from)? 
.take(limit) .filter_map(|r| r.ok()) // Filter out buggy events - .filter_map(|(pdu_id, pdu)| { - services() - .rooms - .timeline - .pdu_count(&pdu_id) - .map(|pdu_count| (pdu_count, pdu)) - .ok() - }) - .take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to` + .take_while(|&(k, _)| Some(k) != to) // Stop at `to` .collect(); for (_, event) in &events_after { @@ -192,26 +187,23 @@ pub async fn get_message_events_route( .map(|(_, pdu)| pdu.to_room_event()) .collect(); - resp.start = from.to_string(); - resp.end = next_token.map(|count| count.to_string()); + resp.start = from.stringify(); + resp.end = next_token.map(|count| count.stringify()); resp.chunk = events_after; } ruma::api::client::Direction::Backward => { + services() + .rooms + .timeline + .backfill_if_required(&body.room_id, from) + .await?; let events_before: Vec<_> = services() .rooms .timeline .pdus_until(sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events - .filter_map(|(pdu_id, pdu)| { - services() - .rooms - .timeline - .pdu_count(&pdu_id) - .map(|pdu_count| (pdu_count, pdu)) - .ok() - }) - .take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to` + .take_while(|&(k, _)| Some(k) != to) // Stop at `to` .collect(); for (_, event) in &events_before { @@ -237,8 +229,8 @@ pub async fn get_message_events_route( .map(|(_, pdu)| pdu.to_room_event()) .collect(); - resp.start = from.to_string(); - resp.end = next_token.map(|count| count.to_string()); + resp.start = from.stringify(); + resp.end = next_token.map(|count| count.stringify()); resp.chunk = events_before; } } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index b12468a7..a5553d25 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -1,4 +1,4 @@ -use crate::{services, Error, Result, Ruma}; +use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma}; use ruma::{ api::client::{error::ErrorKind, 
read_marker::set_read_marker, receipt::create_receipt}, events::{ @@ -42,18 +42,28 @@ pub async fn set_read_marker_route( } if let Some(event) = &body.private_read_receipt { - services().rooms.edus.read_receipt.private_read_set( - &body.room_id, - sender_user, - services() - .rooms - .timeline - .get_pdu_count(event)? - .ok_or(Error::BadRequest( + let count = services() + .rooms + .timeline + .get_pdu_count(event)? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event does not exist.", + ))?; + let count = match count { + PduCount::Backfilled(_) => { + return Err(Error::BadRequest( ErrorKind::InvalidParam, - "Event does not exist.", - ))?, - )?; + "Read receipt is in backfilled timeline", + )) + } + PduCount::Normal(c) => c, + }; + services() + .rooms + .edus + .read_receipt + .private_read_set(&body.room_id, sender_user, count)?; } if let Some(event) = &body.read_receipt { @@ -142,17 +152,27 @@ pub async fn create_receipt_route( )?; } create_receipt::v3::ReceiptType::ReadPrivate => { + let count = services() + .rooms + .timeline + .get_pdu_count(&body.event_id)? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event does not exist.", + ))?; + let count = match count { + PduCount::Backfilled(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Read receipt is in backfilled timeline", + )) + } + PduCount::Normal(c) => c, + }; services().rooms.edus.read_receipt.private_read_set( &body.room_id, sender_user, - services() - .rooms - .timeline - .get_pdu_count(&body.event_id)? 
- .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event does not exist.", - ))?, + count, )?; } _ => return Err(Error::bad_database("Unsupported receipt type")), diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 568a23ce..834438c9 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -1,4 +1,4 @@ -use crate::{services, Error, Result, Ruma, RumaResponse}; +use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma, RumaResponse}; use ruma::{ api::client::{ filter::{FilterDefinition, LazyLoadOptions}, @@ -172,6 +172,7 @@ async fn sync_helper( let watcher = services().globals.watch(&sender_user, &sender_device); let next_batch = services().globals.current_count()?; + let next_batchcount = PduCount::Normal(next_batch); let next_batch_string = next_batch.to_string(); // Load filter @@ -197,6 +198,7 @@ async fn sync_helper( .clone() .and_then(|string| string.parse().ok()) .unwrap_or(0); + let sincecount = PduCount::Normal(since); let mut presence_updates = HashMap::new(); let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in @@ -241,12 +243,12 @@ async fn sync_helper( .rooms .timeline .last_timeline_count(&sender_user, &room_id)? - > since + > sincecount { let mut non_timeline_pdus = services() .rooms .timeline - .pdus_until(&sender_user, &room_id, u64::MAX)? + .pdus_until(&sender_user, &room_id, PduCount::max())? 
.filter_map(|r| { // Filter out buggy events if r.is_err() { @@ -254,13 +256,7 @@ async fn sync_helper( } r.ok() }) - .take_while(|(pduid, _)| { - services() - .rooms - .timeline - .pdu_count(pduid) - .map_or(false, |count| count > since) - }); + .take_while(|(pducount, _)| pducount > &sincecount); // Take the last 10 events for the timeline timeline_pdus = non_timeline_pdus @@ -295,7 +291,7 @@ async fn sync_helper( &sender_user, &sender_device, &room_id, - since, + sincecount, )?; // Database queries: @@ -492,7 +488,7 @@ async fn sync_helper( &sender_device, &room_id, lazy_loaded, - next_batch, + next_batchcount, ); ( @@ -582,7 +578,7 @@ async fn sync_helper( &sender_device, &room_id, lazy_loaded, - next_batch, + next_batchcount, ); let encrypted_room = services() @@ -711,10 +707,14 @@ async fn sync_helper( let prev_batch = timeline_pdus .first() - .map_or(Ok::<_, Error>(None), |(pdu_id, _)| { - Ok(Some( - services().rooms.timeline.pdu_count(pdu_id)?.to_string(), - )) + .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { + Ok(Some(match pdu_count { + PduCount::Backfilled(_) => { + error!("timeline in backfill state?!"); + "0".to_owned() + } + PduCount::Normal(c) => c.to_string(), + })) })?; let room_events: Vec<_> = timeline_pdus diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 11a6cbf4..e95a560b 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -629,6 +629,37 @@ pub async fn get_public_rooms_route( }) } +pub fn parse_incoming_pdu( + pdu: &RawJsonValue, +) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(|id| RoomId::parse(id.as_str()?).ok()) + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid room id in pdu", + ))?; + + let 
room_version_id = services().rooms.state.get_room_version(&room_id)?; + + let (event_id, value) = match gen_event_id_canonical_json(&pdu, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + } + }; + Ok((event_id, value, room_id)) +} + /// # `PUT /_matrix/federation/v1/send/{txnId}` /// /// Push EDUs and PDUs to this server. @@ -657,36 +688,7 @@ pub async fn send_transaction_message_route( // let mut auth_cache = EventMap::new(); for pdu in &body.pdus { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - - let room_id: OwnedRoomId = match value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - { - Some(id) => id, - None => { - // Event is invalid - continue; - } - }; - - let room_version_id = match services().rooms.state.get_room_version(&room_id) { - Ok(v) => v, - Err(_) => { - continue; - } - }; - - let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - continue; - } - }; + let (event_id, value, room_id) = parse_incoming_pdu(&pdu)?; // We do not add the event_id field to the pdu here because of signature and hashes checks services() @@ -1017,7 +1019,7 @@ pub async fn get_backfill_route( Ok(true), ) }) - .map(|(pdu_id, _)| services().rooms.timeline.get_pdu_json_from_id(&pdu_id)) + .map(|(_, pdu)| services().rooms.timeline.get_pdu_json(&pdu.event_id)) .filter_map(|r| r.ok().flatten()) .map(|pdu| PduEvent::convert_to_outgoing_federation_event(pdu)) .collect(); diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 336317da..9f2c6074 100644 --- 
a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -7,6 +7,8 @@ use tracing::error; use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; +use service::rooms::timeline::PduCount; + impl service::rooms::timeline::Data for KeyValueDatabase { fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { let prefix = services() @@ -30,7 +32,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { .transpose() } - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache .lock() @@ -39,20 +41,18 @@ impl service::rooms::timeline::Data for KeyValueDatabase { { hash_map::Entry::Vacant(v) => { if let Some(last_count) = self - .pdus_until(sender_user, room_id, u64::MAX)? - .filter_map(|r| { + .pdus_until(sender_user, room_id, PduCount::max())? + .find_map(|r| { // Filter out buggy events if r.is_err() { error!("Bad pdu in pdus_since: {:?}", r); } r.ok() }) - .map(|(pduid, _)| self.pdu_count(&pduid)) - .next() { - Ok(*v.insert(last_count?)) + Ok(*v.insert(last_count.0)) } else { - Ok(0) + Ok(PduCount::Normal(0)) } } hash_map::Entry::Occupied(o) => Ok(*o.get()), @@ -60,11 +60,23 @@ impl service::rooms::timeline::Data for KeyValueDatabase { } /// Returns the `count` of this pdu's id. - fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.eventid_pduid + fn get_pdu_count(&self, event_id: &EventId) -> Result> { + Ok(self + .eventid_pduid .get(event_id.as_bytes())? - .map(|pdu_id| self.pdu_count(&pdu_id)) - .transpose() + .map(|pdu_id| Ok::<_, Error>(PduCount::Normal(pdu_count(&pdu_id)?))) + .transpose()? + .map_or_else( + || { + Ok::<_, Error>( + self.eventid_backfillpduid + .get(event_id.as_bytes())? + .map(|pdu_id| Ok::<_, Error>(PduCount::Backfilled(pdu_count(&pdu_id)?))) + .transpose()?, + ) + }, + |x| Ok(Some(x)), + )?) 
} /// Returns the json of a pdu. @@ -182,12 +194,6 @@ impl service::rooms::timeline::Data for KeyValueDatabase { }) } - /// Returns the `count` of this pdu's id. - fn pdu_count(&self, pdu_id: &[u8]) -> Result { - utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) - } - fn append_pdu( &self, pdu_id: &[u8], @@ -203,7 +209,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { self.lasttimelinecount_cache .lock() .unwrap() - .insert(pdu.room_id.clone(), count); + .insert(pdu.room_id.clone(), PduCount::Normal(count)); self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; @@ -211,6 +217,24 @@ impl service::rooms::timeline::Data for KeyValueDatabase { Ok(()) } + fn prepend_backfill_pdu( + &self, + pdu_id: &[u8], + event_id: &EventId, + json: &CanonicalJsonObject, + ) -> Result<()> { + self.pduid_backfillpdu.insert( + pdu_id, + &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + )?; + + self.eventid_backfillpduid + .insert(event_id.as_bytes(), pdu_id)?; + self.eventid_outlierpdu.remove(event_id.as_bytes())?; + + Ok(()) + } + /// Removes a pdu and creates a new one with the same id. fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { if self.pduid_pdu.get(pdu_id)?.is_some() { @@ -227,51 +251,14 @@ impl service::rooms::timeline::Data for KeyValueDatabase { } } - /// Returns an iterator over all events in a room that happened after the event with id `since` - /// in chronological order. - fn pdus_since<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - since: u64, - ) -> Result, PduEvent)>> + 'a>> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? 
- .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Skip the first pdu if it's exactly at since, because we sent that last time - let mut first_pdu_id = prefix.clone(); - first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(&first_pdu_id, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - }), - )) - } - /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, - until: u64, - ) -> Result, PduEvent)>> + 'a>> { + until: PduCount, + ) -> Result> + 'a>> { // Create the first part of the full pdu id let prefix = services() .rooms @@ -281,34 +268,63 @@ impl service::rooms::timeline::Data for KeyValueDatabase { .to_be_bytes() .to_vec(); - let mut current = prefix.clone(); - current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` - - let current: &[u8] = &current; + let mut current_backfill = prefix.clone(); + // +1 so we don't send the base event + let backfill_count = match until { + PduCount::Backfilled(x) => x + 1, + PduCount::Normal(_) => 0, + }; + current_backfill.extend_from_slice(&backfill_count.to_be_bytes()); let user_id = user_id.to_owned(); + let user_id2 = user_id.to_owned(); + let prefix2 = prefix.clone(); - Ok(Box::new( - self.pduid_pdu - .iter_from(current, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, 
pdu)) - }), - )) + let backfill_iter = self + .pduid_backfillpdu + .iter_from(&current_backfill, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + let count = PduCount::Backfilled(pdu_count(&pdu_id)?); + Ok((count, pdu)) + }); + + match until { + PduCount::Backfilled(_) => Ok(Box::new(backfill_iter)), + PduCount::Normal(x) => { + let mut current_normal = prefix2.clone(); + // -1 so we don't send the base event + current_normal.extend_from_slice(&x.saturating_sub(1).to_be_bytes()); + let normal_iter = self + .pduid_pdu + .iter_from(&current_normal, true) + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id2 { + pdu.remove_transaction_id()?; + } + let count = PduCount::Normal(pdu_count(&pdu_id)?); + Ok((count, pdu)) + }); + + Ok(Box::new(normal_iter.chain(backfill_iter))) + } + } } fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, - from: u64, - ) -> Result, PduEvent)>> + 'a>> { + from: PduCount, + ) -> Result> + 'a>> { // Create the first part of the full pdu id let prefix = services() .rooms @@ -318,26 +334,55 @@ impl service::rooms::timeline::Data for KeyValueDatabase { .to_be_bytes() .to_vec(); - let mut current = prefix.clone(); - current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event - - let current: &[u8] = &current; + let mut current_normal = prefix.clone(); + // +1 so we don't send the base event + let normal_count = match from { + PduCount::Normal(x) => x + 1, + PduCount::Backfilled(_) => 0, + }; + current_normal.extend_from_slice(&normal_count.to_be_bytes()); let user_id = user_id.to_owned(); + let user_id2 = user_id.to_owned(); + let prefix2 = 
prefix.clone(); - Ok(Box::new( - self.pduid_pdu - .iter_from(current, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - }), - )) + let normal_iter = self + .pduid_pdu + .iter_from(&current_normal, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + let count = PduCount::Normal(pdu_count(&pdu_id)?); + Ok((count, pdu)) + }); + + match from { + PduCount::Normal(_) => Ok(Box::new(normal_iter)), + PduCount::Backfilled(x) => { + let mut current_backfill = prefix2.clone(); + // -1 so we don't send the base event + current_backfill.extend_from_slice(&x.saturating_sub(1).to_be_bytes()); + let backfill_iter = self + .pduid_backfillpdu + .iter_from(&current_backfill, true) + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id2 { + pdu.remove_transaction_id()?; + } + let count = PduCount::Backfilled(pdu_count(&pdu_id)?); + Ok((count, pdu)) + }); + + Ok(Box::new(backfill_iter.chain(normal_iter))) + } + } } fn increment_notification_counts( @@ -368,3 +413,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase { Ok(()) } } + +/// Returns the `count` of this pdu's id. 
+fn pdu_count(pdu_id: &[u8]) -> Result { + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 46ba5b33..f07ad879 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,7 +1,10 @@ pub mod abstraction; pub mod key_value; -use crate::{services, utils, Config, Error, PduEvent, Result, Services, SERVICES}; +use crate::{ + service::rooms::timeline::PduCount, services, utils, Config, Error, PduEvent, Result, Services, + SERVICES, +}; use abstraction::{KeyValueDatabaseEngine, KvTree}; use directories::ProjectDirs; use lru_cache::LruCache; @@ -71,7 +74,9 @@ pub struct KeyValueDatabase { //pub rooms: rooms::Rooms, pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count + pub(super) pduid_backfillpdu: Arc, // PduId = ShortRoomId + Count pub(super) eventid_pduid: Arc, + pub(super) eventid_backfillpduid: Arc, pub(super) roomid_pduleaves: Arc, pub(super) alias_roomid: Arc, pub(super) aliasid_alias: Arc, // AliasId = RoomId + Count @@ -161,7 +166,7 @@ pub struct KeyValueDatabase { pub(super) shortstatekey_cache: Mutex>, pub(super) our_real_users_cache: RwLock>>>, pub(super) appservice_in_room_cache: RwLock>>, - pub(super) lasttimelinecount_cache: Mutex>, + pub(super) lasttimelinecount_cache: Mutex>, } impl KeyValueDatabase { @@ -292,7 +297,9 @@ impl KeyValueDatabase { presenceid_presence: builder.open_tree("presenceid_presence")?, userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, pduid_pdu: builder.open_tree("pduid_pdu")?, + pduid_backfillpdu: builder.open_tree("pduid_backfillpdu")?, eventid_pduid: builder.open_tree("eventid_pduid")?, + eventid_backfillpduid: builder.open_tree("eventid_backfillpduid")?, roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, alias_roomid: builder.open_tree("alias_roomid")?, diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs 
index 701a7340..e6e4f896 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -9,11 +9,13 @@ use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::Result; +use super::timeline::PduCount; + pub struct Service { pub db: &'static dyn Data, pub lazy_load_waiting: - Mutex>>, + Mutex>>, } impl Service { @@ -36,7 +38,7 @@ impl Service { device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, - count: u64, + count: PduCount, ) { self.lazy_load_waiting.lock().unwrap().insert( ( @@ -55,7 +57,7 @@ impl Service { user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - since: u64, + since: PduCount, ) -> Result<()> { if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( user_id.to_owned(), diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index e940ffa1..bd9ef889 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -14,7 +14,7 @@ use ruma::{ }, StateEventType, }, - EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + EventId, OwnedServerName, RoomId, ServerName, UserId, }; use tracing::error; diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 9377af07..c8021055 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -4,12 +4,14 @@ use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; use crate::{PduEvent, Result}; +use super::PduCount; + pub trait Data: Send + Sync { fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>>; - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; /// Returns the `count` of this pdu's id. - fn get_pdu_count(&self, event_id: &EventId) -> Result>; + fn get_pdu_count(&self, event_id: &EventId) -> Result>; /// Returns the json of a pdu. 
fn get_pdu_json(&self, event_id: &EventId) -> Result>; @@ -38,9 +40,6 @@ pub trait Data: Send + Sync { /// Returns the pdu as a `BTreeMap`. fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; - /// Returns the `count` of this pdu's id. - fn pdu_count(&self, pdu_id: &[u8]) -> Result; - /// Adds a new pdu to the timeline fn append_pdu( &self, @@ -50,33 +49,34 @@ pub trait Data: Send + Sync { count: u64, ) -> Result<()>; + // Adds a new pdu to the backfilled timeline + fn prepend_backfill_pdu( + &self, + pdu_id: &[u8], + event_id: &EventId, + json: &CanonicalJsonObject, + ) -> Result<()>; + /// Removes a pdu and creates a new one with the same id. fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()>; - /// Returns an iterator over all events in a room that happened after the event with id `since` - /// in chronological order. - fn pdus_since<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - since: u64, - ) -> Result, PduEvent)>> + 'a>>; - /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, - until: u64, - ) -> Result, PduEvent)>> + 'a>>; + until: PduCount, + ) -> Result> + 'a>>; + /// Returns an iterator over all events in a room that happened after the event with id `from` + /// in chronological order. 
fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, - from: u64, - ) -> Result, PduEvent)>> + 'a>>; + from: PduCount, + ) -> Result> + 'a>>; fn increment_notification_counts( &self, diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index cc58e6f4..b407dfde 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,7 +1,9 @@ mod data; -use std::collections::HashMap; +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashMap}; +use std::sync::RwLock; use std::{ collections::HashSet, sync::{Arc, Mutex}, @@ -9,6 +11,8 @@ use std::{ pub use data::Data; use regex::Regex; +use ruma::api::federation; +use ruma::serde::Base64; use ruma::{ api::client::error::ErrorKind, canonical_json::to_canonical_value, @@ -27,11 +31,13 @@ use ruma::{ uint, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, UserId, }; +use ruma::{user_id, ServerName}; use serde::Deserialize; -use serde_json::value::to_raw_value; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::MutexGuard; -use tracing::{error, warn}; +use tracing::{error, info, warn}; +use crate::api::server_server; use crate::{ service::pdu::{EventHash, PduBuilder}, services, utils, Error, PduEvent, Result, @@ -39,10 +45,70 @@ use crate::{ use super::state_compressor::CompressedStateEvent; +#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] +pub enum PduCount { + Backfilled(u64), + Normal(u64), +} + +impl PduCount { + pub fn min() -> Self { + Self::Backfilled(u64::MAX) + } + pub fn max() -> Self { + Self::Normal(u64::MAX) + } + + pub fn try_from_string(token: &str) -> Result { + if token.starts_with('-') { + token[1..].parse().map(PduCount::Backfilled) + } else { + token.parse().map(PduCount::Normal) + } + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) + } + + pub fn stringify(&self) -> String { + match self { + 
PduCount::Backfilled(x) => format!("-{x}"), + PduCount::Normal(x) => x.to_string(), + } + } +} + +impl PartialOrd for PduCount { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PduCount { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o), + (PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s), + (PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater, + (PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less, + } + } +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn comparisons() { + assert!(PduCount::Normal(1) < PduCount::Normal(2)); + assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1)); + assert!(PduCount::Normal(1) > PduCount::Backfilled(1)); + assert!(PduCount::Backfilled(1) < PduCount::Normal(1)); + } +} + pub struct Service { pub db: &'static dyn Data, - pub lasttimelinecount_cache: Mutex>, + pub lasttimelinecount_cache: Mutex>, } impl Service { @@ -52,10 +118,15 @@ impl Service { } #[tracing::instrument(skip(self))] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { self.db.last_timeline_count(sender_user, room_id) } + /// Returns the `count` of this pdu's id. + pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { + self.db.get_pdu_count(event_id) + } + // TODO Is this the same as the function above? /* #[tracing::instrument(skip(self))] @@ -79,11 +150,6 @@ impl Service { } */ - /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_count(event_id) - } - /// Returns the json of a pdu. 
pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_pdu_json(event_id) @@ -128,11 +194,6 @@ impl Service { self.db.get_pdu_json_from_id(pdu_id) } - /// Returns the `count` of this pdu's id. - pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { - self.db.pdu_count(pdu_id) - } - /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self))] fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { @@ -863,19 +924,8 @@ impl Service { &'a self, user_id: &UserId, room_id: &RoomId, - ) -> Result, PduEvent)>> + 'a> { - self.pdus_since(user_id, room_id, 0) - } - - /// Returns an iterator over all events in a room that happened after the event with id `since` - /// in chronological order. - pub fn pdus_since<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - since: u64, - ) -> Result, PduEvent)>> + 'a> { - self.db.pdus_since(user_id, room_id, since) + ) -> Result> + 'a> { + self.pdus_after(user_id, room_id, PduCount::min()) } /// Returns an iterator over all events and their tokens in a room that happened before the @@ -885,8 +935,8 @@ impl Service { &'a self, user_id: &UserId, room_id: &RoomId, - until: u64, - ) -> Result, PduEvent)>> + 'a> { + until: PduCount, + ) -> Result> + 'a> { self.db.pdus_until(user_id, room_id, until) } @@ -897,8 +947,8 @@ impl Service { &'a self, user_id: &UserId, room_id: &RoomId, - from: u64, - ) -> Result, PduEvent)>> + 'a> { + from: PduCount, + ) -> Result> + 'a> { self.db.pdus_after(user_id, room_id, from) } @@ -915,4 +965,118 @@ impl Service { // If event does not exist, just noop Ok(()) } + + #[tracing::instrument(skip(self, room_id))] + pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { + let first_pdu = self + .all_pdus(&user_id!("@doesntmatter:conduit.rs"), &room_id)? 
+ .next() + .expect("Room is not empty")?; + + if first_pdu.0 < from { + // No backfill required, there are still events between them + return Ok(()); + } + + let power_levels: RoomPowerLevelsEventContent = services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + let mut admin_servers = power_levels + .users + .iter() + .filter(|(_, level)| **level > power_levels.users_default) + .map(|(user_id, _)| user_id.server_name()) + .collect::>(); + admin_servers.remove(services().globals.server_name()); + + // Request backfill + for backfill_server in admin_servers { + info!("Asking {backfill_server} for backfill"); + let response = services() + .sending + .send_federation_request( + backfill_server, + federation::backfill::get_backfill::v1::Request { + room_id: room_id.to_owned(), + v: vec![first_pdu.1.event_id.as_ref().to_owned()], + limit: uint!(100), + }, + ) + .await; + match response { + Ok(response) => { + let mut pub_key_map = RwLock::new(BTreeMap::new()); + for pdu in response.pdus { + if let Err(e) = self + .backfill_pdu(backfill_server, pdu, &mut pub_key_map) + .await + { + warn!("Failed to add backfilled pdu: {e}"); + } + } + return Ok(()); + } + Err(e) => { + warn!("{backfill_server} could not provide backfill: {e}"); + } + } + } + + info!("No servers could backfill"); + Ok(()) + } + + #[tracing::instrument(skip(self, pdu))] + pub async fn backfill_pdu( + &self, + origin: &ServerName, + pdu: Box, + pub_key_map: &RwLock>>, + ) -> Result<()> { + let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?; + + services() + .rooms + .event_handler + .handle_incoming_pdu(origin, &event_id, &room_id, value, false, &pub_key_map) + .await?; + + let value = self.get_pdu_json(&event_id)?.expect("We just created it"); + + let shortroomid 
= services() + .rooms + .short + .get_shortroomid(&room_id)? + .expect("room exists"); + + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + + let count = services().globals.next_count()?; + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); + pdu_id.extend_from_slice(&count.to_be_bytes()); + + // Insert pdu + self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; + + drop(insert_lock); + + info!("Appended incoming pdu"); + Ok(()) + } }