diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 0ffd9659..41ab79f1 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -169,7 +169,7 @@ impl Service { .await?; // Procure the room version - let room_version_id = Self::get_room_version_id(&create_event)?; + let room_version_id = get_room_version_id(&create_event)?; let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; @@ -178,7 +178,7 @@ impl Service { .boxed() .await?; - Self::check_room_id(room_id, &incoming_pdu)?; + check_room_id(room_id, &incoming_pdu)?; // 8. if not timeline event: stop if !is_timeline_event { @@ -341,7 +341,7 @@ impl Service { // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match - let room_version_id = Self::get_room_version_id(create_event)?; + let room_version_id = get_room_version_id(create_event)?; let mut val = match self .services .server_keys @@ -378,7 +378,7 @@ impl Service { ) .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - Self::check_room_id(room_id, &incoming_pdu)?; + check_room_id(room_id, &incoming_pdu)?; if !auth_events_known { // 4. fetch any missing auth events doing all checks listed here starting at 1. @@ -414,7 +414,7 @@ impl Service { continue; }; - Self::check_room_id(room_id, &auth_event)?; + check_room_id(room_id, &auth_event)?; match auth_events.entry(( auth_event.kind.to_string().into(), @@ -454,7 +454,7 @@ impl Service { }; let auth_check = state_res::event_auth::auth_check( - &Self::to_room_version(&room_version_id), + &to_room_version(&room_version_id), &incoming_pdu, None, // TODO: third party invite state_fetch, @@ -502,8 +502,8 @@ impl Service { } debug!("Upgrading to timeline pdu"); - let timer = tokio::time::Instant::now(); - let room_version_id = Self::get_room_version_id(create_event)?; + let timer = Instant::now(); + let room_version_id = get_room_version_id(create_event)?; // 10. Fetch missing state and auth chain events by calling /state_ids at // backwards extremities doing all the checks in this list starting at 1. @@ -524,7 +524,7 @@ impl Service { } let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); - let room_version = Self::to_room_version(&room_version_id); + let room_version = to_room_version(&room_version_id); debug!("Performing auth check"); // 11. Check the auth of the event passes based on the state of the event @@ -1278,7 +1278,7 @@ impl Service { .await .pop() { - Self::check_room_id(room_id, &pdu)?; + check_room_id(room_id, &pdu)?; let limit = self.services.globals.max_fetch_prev_events(); if amount > limit { @@ -1370,31 +1370,34 @@ impl Service { } } - fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> { - if pdu.room_id != room_id { - return Err!(Request(InvalidParam( - warn!(pdu_event_id = ?pdu.event_id, pdu_room_id = ?pdu.room_id, ?room_id, "Found event from room in room") - ))); - } - - Ok(()) - } - - fn get_room_version_id(create_event: &PduEvent) -> Result { - let content: RoomCreateEventContent = create_event.get_content()?; - let room_version = content.room_version; - - Ok(room_version) - } - - #[inline] - fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { - RoomVersion::new(room_version_id).expect("room version is supported") - } - async fn event_exists(&self, event_id: Arc) -> bool { self.services.timeline.pdu_exists(&event_id).await } async fn event_fetch(&self, event_id: Arc) -> Option> { self.services.timeline.get_pdu(&event_id).await.ok() } } + +fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result { + if pdu.room_id != room_id { + return Err!(Request(InvalidParam(error!( + pdu_event_id = ?pdu.event_id, + pdu_room_id = ?pdu.room_id, + ?room_id, + "Found event from room in room", + )))); + } + + Ok(()) +} + +fn get_room_version_id(create_event: &PduEvent) -> Result { + let content: RoomCreateEventContent = create_event.get_content()?; + let room_version = content.room_version; + + Ok(room_version) +} + +#[inline] +fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { + RoomVersion::new(room_version_id).expect("room version is supported") +} diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index c51b7856..5428a3b9 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -8,11 +8,11 @@ use conduit::{ err, expected, result::{LogErr, NotFound}, utils, - utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, + utils::{future::TryExtExt, stream::TryIgnore, u64_from_u8, ReadyExt}, Err, PduCount, PduEvent, Result, }; use database::{Database, Deserialized, Json, KeyVal, Map}; -use futures::{FutureExt, Stream, StreamExt}; +use futures::{Stream, StreamExt}; use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; @@ -115,12 +115,10 @@ impl Data { /// Like get_non_outlier_pdu(), but without the expense of fetching and /// parsing the PduEvent - pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { + pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result { let pduid = self.get_pdu_id(event_id).await?; - self.pduid_pdu.get(&pduid).await?; - - Ok(()) + self.pduid_pdu.get(&pduid).await.map(|_| ()) } /// Returns the pdu. @@ -140,16 +138,14 @@ impl Data { /// Like get_non_outlier_pdu(), but without the expense of fetching and /// parsing the PduEvent - pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> { - self.eventid_outlierpdu.get(event_id).await?; - - Ok(()) + pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result { + self.eventid_outlierpdu.get(event_id).await.map(|_| ()) } /// Like get_pdu(), but without the expense of fetching and parsing the data pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool { - let non_outlier = self.non_outlier_pdu_exists(event_id).map(|res| res.is_ok()); - let outlier = self.outlier_pdu_exists(event_id).map(|res| res.is_ok()); + let non_outlier = self.non_outlier_pdu_exists(event_id).is_ok(); + let outlier = self.outlier_pdu_exists(event_id).is_ok(); //TODO: parallelize non_outlier.await || outlier.await @@ -169,7 +165,6 @@ impl Data { pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { self.pduid_pdu.raw_put(pdu_id, Json(json)); - self.lasttimelinecount_cache .lock() .await @@ -181,21 +176,17 @@ impl Data { pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { self.pduid_pdu.raw_put(pdu_id, Json(json)); - self.eventid_pduid.insert(event_id, pdu_id); self.eventid_outlierpdu.remove(event_id); } /// Removes a pdu and creates a new one with the same id. - pub(super) async fn replace_pdu( - &self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, - ) -> Result<()> { + pub(super) async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result { if self.pduid_pdu.get(pdu_id).await.is_not_found() { return Err!(Request(NotFound("PDU does not exist."))); } - let pdu = serde_json::to_vec(pdu_json)?; - self.pduid_pdu.insert(pdu_id, &pdu); + self.pduid_pdu.raw_put(pdu_id, Json(pdu_json)); Ok(()) }