From 010e4ee35a372018e56e9572a927cf06394c4291 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 18 Jul 2024 06:37:47 +0000 Subject: [PATCH] de-global services for services Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 9 +- src/admin/mod.rs | 1 + src/api/client/alias.rs | 12 +- src/api/client/membership.rs | 17 +- src/api/client/sync.rs | 6 +- src/api/mod.rs | 2 + src/api/server/backfill.rs | 3 +- src/api/server/event.rs | 3 +- src/api/server/event_auth.rs | 3 +- src/api/server/get_missing_events.rs | 3 +- src/api/server/invite.rs | 6 +- src/api/server/send.rs | 3 +- src/api/server/send_join.rs | 8 +- src/api/server/state.rs | 9 +- src/router/mod.rs | 2 + src/service/account_data/data.rs | 17 +- src/service/account_data/mod.rs | 2 +- src/service/admin/console.rs | 14 +- src/service/admin/grant.rs | 19 +- src/service/admin/mod.rs | 63 +- src/service/appservice/mod.rs | 14 +- src/service/client/mod.rs | 2 +- src/service/emergency/mod.rs | 83 ++ src/service/globals/data.rs | 41 +- src/service/globals/emerg_access.rs | 54 -- src/service/globals/migrations.rs | 274 ++++--- src/service/globals/mod.rs | 15 +- src/service/key_backups/data.rs | 23 +- src/service/key_backups/mod.rs | 2 +- src/service/manager.rs | 20 +- src/service/media/data.rs | 4 +- src/service/media/mod.rs | 22 +- src/service/mod.rs | 4 +- src/service/presence/data.rs | 24 +- src/service/presence/mod.rs | 38 +- src/service/pusher/mod.rs | 46 +- src/service/resolver/actual.rs | 111 ++- src/service/resolver/cache.rs | 3 +- src/service/resolver/mod.rs | 15 +- src/service/rooms/alias/data.rs | 17 +- src/service/rooms/alias/mod.rs | 86 +- src/service/rooms/alias/remote.rs | 114 +-- src/service/rooms/auth_chain/data.rs | 9 +- src/service/rooms/auth_chain/mod.rs | 31 +- src/service/rooms/directory/mod.rs | 4 +- src/service/rooms/event_handler/mod.rs | 229 +++--- .../rooms/event_handler/parse_incoming_pdu.rs | 41 +- .../rooms/event_handler/signing_keys.rs | 57 +- src/service/rooms/lazy_loading/mod.rs | 4 +- src/service/rooms/metadata/data.rs | 17 +- src/service/rooms/metadata/mod.rs | 5 +- src/service/rooms/mod.rs | 4 +- src/service/rooms/pdu_metadata/data.rs | 21 +- src/service/rooms/pdu_metadata/mod.rs | 30 +- src/service/rooms/read_receipt/data.rs | 19 +- src/service/rooms/read_receipt/mod.rs | 14 +- src/service/rooms/search/data.rs | 19 +- src/service/rooms/search/mod.rs | 2 +- src/service/rooms/short/data.rs | 25 +- src/service/rooms/short/mod.rs | 5 +- src/service/rooms/spaces/mod.rs | 381 ++++----- src/service/rooms/state/mod.rs | 110 ++- src/service/rooms/state_accessor/data.rs | 60 +- src/service/rooms/state_accessor/mod.rs | 32 +- src/service/rooms/state_cache/data.rs | 73 +- src/service/rooms/state_cache/mod.rs | 55 +- src/service/rooms/state_compressor/mod.rs | 25 +- src/service/rooms/threads/data.rs | 27 +- src/service/rooms/threads/mod.rs | 39 +- src/service/rooms/timeline/data.rs | 99 +-- src/service/rooms/timeline/mod.rs | 303 +++---- src/service/rooms/typing/mod.rs | 44 +- src/service/rooms/user/data.rs | 27 +- src/service/rooms/user/mod.rs | 5 +- src/service/sending/appservice.rs | 22 +- src/service/sending/data.rs | 17 +- src/service/sending/mod.rs | 65 +- src/service/sending/send.rs | 158 ++-- src/service/sending/sender.rs | 745 +++++++++--------- src/service/service.rs | 66 +- src/service/services.rs | 118 +-- src/service/uiaa/mod.rs | 23 +- src/service/updates/mod.rs | 17 +- src/service/users/data.rs | 83 +- src/service/users/mod.rs | 23 +- 85 files changed, 2480 insertions(+), 1887 deletions(-) create mode 100644 src/service/emergency/mod.rs delete mode 100644 src/service/globals/emerg_access.rs diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index f0aa23cb..cbe52473 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -15,7 +15,7 @@ use ruma::{ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use service::{rooms::event_handler::parse_incoming_pdu, services, PduEvent}; +use service::services; use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; @@ -189,7 +189,10 @@ pub(super) async fn get_remote_pdu( debug!("Attempting to parse PDU: {:?}", &response.pdu); let parsed_pdu = { - let parsed_result = parse_incoming_pdu(&response.pdu); + let parsed_result = services() + .rooms + .event_handler + .parse_incoming_pdu(&response.pdu); let (event_id, value, room_id) = match parsed_result { Ok(t) => t, Err(e) => { @@ -510,7 +513,7 @@ pub(super) async fn force_set_room_state_from_server( let mut events = Vec::with_capacity(remote_state_response.pdus.len()); for pdu in remote_state_response.pdus.clone() { - events.push(match parse_incoming_pdu(&pdu) { + events.push(match services().rooms.event_handler.parse_incoming_pdu(&pdu) { Ok(t) => t, Err(e) => { warn!("Could not parse PDU, ignoring: {e}"); diff --git a/src/admin/mod.rs b/src/admin/mod.rs index e020ed43..c57659c1 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1,3 +1,4 @@ +#![recursion_limit = "168"] #![allow(clippy::wildcard_imports)] pub(crate) mod appservice; diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 88d1a4e6..11617a0e 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -22,7 +22,11 @@ pub(crate) async fn create_alias_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; + services + .rooms + .alias + .appservice_checks(&body.room_alias, &body.appservice_info) + .await?; // this isn't apart of alias_checks or delete alias route because we should // allow removing forbidden room aliases @@ -61,7 +65,11 @@ pub(crate) async fn delete_alias_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; + services + .rooms + .alias + .appservice_checks(&body.room_alias, &body.appservice_info) + .await?; if services .rooms diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 9fde99a4..e3630c72 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -43,7 +43,6 @@ use crate::{ service::{ pdu::{gen_event_id_canonical_json, PduBuilder}, rooms::state::RoomMutexGuard, - sending::convert_to_outgoing_federation_event, server_is_ours, user_is_local, Services, }, Ruma, @@ -791,7 +790,9 @@ async fn join_room_by_id_helper_remote( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: convert_to_outgoing_federation_event(join_event.clone()), + pdu: services + .sending + .convert_to_outgoing_federation_event(join_event.clone()), omit_members: false, }, ) @@ -1203,7 +1204,9 @@ async fn join_room_by_id_helper_local( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: convert_to_outgoing_federation_event(join_event.clone()), + pdu: services + .sending + .convert_to_outgoing_federation_event(join_event.clone()), omit_members: false, }, ) @@ -1431,7 +1434,9 @@ pub(crate) async fn invite_helper( room_id: room_id.to_owned(), event_id: (*pdu.event_id).to_owned(), room_version: room_version_id.clone(), - event: convert_to_outgoing_federation_event(pdu_json.clone()), + event: services + .sending + .convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state, via: services.rooms.state_cache.servers_route_via(room_id).ok(), }, @@ -1763,7 +1768,9 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room federation::membership::create_leave_event::v2::Request { room_id: room_id.to_owned(), event_id, - pdu: convert_to_outgoing_federation_event(leave_event.clone()), + pdu: services + .sending + .convert_to_outgoing_federation_event(leave_event.clone()), }, ) .await?; diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 5739052c..6eeb8fff 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -475,8 +475,6 @@ async fn handle_left_room( async fn process_presence_updates( services: &Services, presence_updates: &mut HashMap, since: u64, syncing_user: &UserId, ) -> Result<()> { - use crate::service::presence::Presence; - // Take presence updates for (user_id, _, presence_bytes) in services.presence.presence_since(since) { if !services @@ -487,7 +485,9 @@ async fn process_presence_updates( continue; } - let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?; + let presence_event = services + .presence + .from_json_bytes_to_event(&presence_bytes, &user_id)?; match presence_updates.entry(user_id) { Entry::Vacant(slot) => { slot.insert(presence_event); diff --git a/src/api/mod.rs b/src/api/mod.rs index 79382934..0d80e581 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "160"] + pub mod client; pub mod router; pub mod server; diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 8dd38cad..1b665c19 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -4,7 +4,6 @@ use ruma::{ api::{client::error::ErrorKind, federation::backfill::get_backfill}, uint, user_id, MilliSecondsSinceUnixEpoch, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -67,7 +66,7 @@ pub(crate) async fn get_backfill_route( }) .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) .filter_map(|r| r.ok().flatten()) - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect(); Ok(get_backfill::v1::Response { diff --git a/src/api/server/event.rs b/src/api/server/event.rs index e8e08c81..e11a01a2 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -4,7 +4,6 @@ use ruma::{ api::{client::error::ErrorKind, federation::event::get_event}, MilliSecondsSinceUnixEpoch, RoomId, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -50,6 +49,6 @@ pub(crate) async fn get_event_route( Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: convert_to_outgoing_federation_event(event), + pdu: services.sending.convert_to_outgoing_federation_event(event), }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 8d26b73a..4b0f6bc0 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -6,7 +6,6 @@ use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -60,7 +59,7 @@ pub(crate) async fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect(), }) } diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index 378cd4fe..e2c3c93c 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -4,7 +4,6 @@ use ruma::{ api::{client::error::ErrorKind, federation::event::get_missing_events}, OwnedEventId, RoomId, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -82,7 +81,7 @@ pub(crate) async fn get_missing_events_route( ) .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, ); - events.push(convert_to_outgoing_federation_event(pdu)); + events.push(services.sending.convert_to_outgoing_federation_event(pdu)); } i = i.saturating_add(1); } diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 982a2a01..17e21920 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -7,7 +7,7 @@ use ruma::{ serde::JsonObject, CanonicalJsonValue, EventId, OwnedUserId, }; -use service::{sending::convert_to_outgoing_federation_event, server_is_ours}; +use service::server_is_ours; use crate::Ruma; @@ -174,6 +174,8 @@ pub(crate) async fn create_invite_route( } Ok(create_invite::v2::Response { - event: convert_to_outgoing_federation_event(signed_event), + event: services + .sending + .convert_to_outgoing_federation_event(signed_event), }) } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index f2934480..2f698d33 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -21,7 +21,6 @@ use ruma::{ use tokio::sync::RwLock; use crate::{ - service::rooms::event_handler::parse_incoming_pdu, services::Services, utils::{self}, Error, Result, Ruma, @@ -89,7 +88,7 @@ async fn handle_pdus( ) -> Result { let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); for pdu in &body.pdus { - parsed_pdus.push(match parse_incoming_pdu(pdu) { + parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) { Ok(t) => t, Err(e) => { debug_warn!("Could not parse PDU: {e}"); diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index b72bfa03..7f79a1d9 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -13,9 +13,7 @@ use ruma::{ CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use service::{ - pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, user_is_local, Services, -}; +use service::{pdu::gen_event_id_canonical_json, user_is_local, Services}; use tokio::sync::RwLock; use tracing::warn; @@ -186,12 +184,12 @@ async fn create_join_event( Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect(), state: state_ids .iter() .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect(), // Event field is required if the room version supports restricted join rules. event: Some( diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 24a11cca..d215236a 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use axum::extract::State; use conduit::{Error, Result}; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -44,7 +43,11 @@ pub(crate) async fn get_room_state_route( .state_full_ids(shortstatehash) .await? .into_values() - .map(|id| convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap())) + .map(|id| { + services + .sending + .convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()) + }) .collect(); let auth_chain_ids = services @@ -61,7 +64,7 @@ pub(crate) async fn get_room_state_route( .timeline .get_pdu_json(&id) .ok()? - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) }) .collect(), pdus, diff --git a/src/router/mod.rs b/src/router/mod.rs index e9bae3c5..03c70f6d 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "160"] + mod layers; mod request; mod router; diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index 7b3a3dee..439603be 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use conduit::{utils, warn, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{ api::client::error::ErrorKind, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, @@ -9,18 +9,27 @@ use ruma::{ RoomId, UserId, }; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { roomuserdataid_accountdata: Arc, roomusertype_roomuserdataid: Arc, + services: Services, +} + +struct Services { + globals: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(), roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), + services: Services { + globals: args.depend::("globals"), + }, } } @@ -40,7 +49,7 @@ impl Data { prefix.push(0xFF); let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); roomuserdataid.push(0xFF); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 69d2f799..c569889e 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -17,7 +17,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 9c335791..c9a288d9 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -11,10 +11,11 @@ use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; use termimad::MadSkin; use tokio::task::JoinHandle; -use crate::services; +use crate::{admin, Dep}; pub struct Console { server: Arc, + admin: Dep, worker_join: Mutex>>, input_abort: Mutex>, command_abort: Mutex>, @@ -29,6 +30,7 @@ impl Console { pub(super) fn new(args: &crate::Args<'_>) -> Arc { Arc::new(Self { server: args.server.clone(), + admin: args.depend::("admin"), worker_join: None.into(), input_abort: None.into(), command_abort: None.into(), @@ -116,7 +118,8 @@ impl Console { let _suppression = log::Suppress::new(&self.server); let (mut readline, _writer) = Readline::new(PROMPT.to_owned())?; - readline.set_tab_completer(Self::tab_complete); + let self_ = Arc::clone(self); + readline.set_tab_completer(move |line| self_.tab_complete(line)); self.set_history(&mut readline); let future = readline.readline(); @@ -154,7 +157,7 @@ impl Console { } async fn process(self: Arc, line: String) { - match services().admin.command_in_place(line, None).await { + match self.admin.command_in_place(line, None).await { Ok(Some(content)) => self.output(content).await, Err(e) => error!("processing command: {e}"), _ => (), @@ -184,9 +187,8 @@ impl Console { history.truncate(HISTORY_LIMIT); } - fn tab_complete(line: &str) -> String { - services() - .admin + fn tab_complete(&self, line: &str) -> String { + self.admin .complete_command(line) .unwrap_or_else(|| line.to_owned()) } diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index 213225ac..c35f8c42 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -25,13 +25,14 @@ impl super::Service { return Ok(()); }; - let state_lock = self.state.mutex.lock(&room_id).await; + let state_lock = self.services.state.mutex.lock(&room_id).await; // Use the server user to grant the new admin's power level - let server_user = &self.globals.server_user; + let server_user = &self.services.globals.server_user; // Invite and join the real user - self.timeline + self.services + .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, @@ -55,7 +56,8 @@ impl super::Service { &state_lock, ) .await?; - self.timeline + self.services + .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, @@ -85,7 +87,8 @@ impl super::Service { users.insert(server_user.clone(), 100.into()); users.insert(user_id.to_owned(), 100.into()); - self.timeline + self.services + .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomPowerLevels, @@ -105,12 +108,12 @@ impl super::Service { .await?; // Send welcome message - self.timeline.build_and_append_pdu( + self.services.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_html( - format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.globals.server_name()), - format!("

Thank you for trying out conduwuit!

\n

conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.

\n

Helpful links:

\n
\n

Git and Documentation: https://github.com/girlbossceo/conduwuit
Report issues: https://github.com/girlbossceo/conduwuit/issues

\n
\n

For a list of available commands, send the following message in this room: @conduit:{}: --help

\n

Here are some rooms you can join (by typing the command):

\n

conduwuit room (Ask questions and get notified on updates):
/join #conduwuit:puppygock.gay

\n", self.globals.server_name()), + format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.services.globals.server_name()), + format!("

Thank you for trying out conduwuit!

\n

conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.

\n

Helpful links:

\n
\n

Git and Documentation: https://github.com/girlbossceo/conduwuit
Report issues: https://github.com/girlbossceo/conduwuit/issues

\n
\n

For a list of available commands, send the following message in this room: @conduit:{}: --help

\n

Here are some rooms you can join (by typing the command):

\n

conduwuit room (Ask questions and get notified on updates):
/join #conduwuit:puppygock.gay

\n", self.services.globals.server_name()), )) .expect("event is valid, we just created it"), unsigned: None, diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index fcb34212..8b9473a2 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -22,15 +22,10 @@ use ruma::{ use serde_json::value::to_raw_value; use tokio::sync::{Mutex, RwLock}; -use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local}; +use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local, Dep}; pub struct Service { - server: Arc, - globals: Arc, - alias: Arc, - timeline: Arc, - state: Arc, - state_cache: Arc, + services: Services, sender: Sender, receiver: Mutex>, pub handle: RwLock>, @@ -39,6 +34,15 @@ pub struct Service { pub console: Arc, } +struct Services { + server: Arc, + globals: Dep, + alias: Dep, + timeline: Dep, + state: Dep, + state_cache: Dep, +} + #[derive(Debug)] pub struct Command { pub command: String, @@ -58,12 +62,14 @@ impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT); Ok(Arc::new(Self { - server: args.server.clone(), - globals: args.require_service::("globals"), - alias: args.require_service::("rooms::alias"), - timeline: args.require_service::("rooms::timeline"), - state: args.require_service::("rooms::state"), - state_cache: args.require_service::("rooms::state_cache"), + services: Services { + server: args.server.clone(), + globals: args.depend::("globals"), + alias: args.depend::("rooms::alias"), + timeline: args.depend::("rooms::timeline"), + state: args.depend::("rooms::state"), + state_cache: args.depend::("rooms::state_cache"), + }, sender, receiver: Mutex::new(receiver), handle: RwLock::new(None), @@ -75,7 +81,7 @@ impl crate::Service for Service { async fn worker(self: Arc) -> Result<()> { let receiver = self.receiver.lock().await; - let mut signals = self.server.signal.subscribe(); + let mut signals = self.services.server.signal.subscribe(); loop { tokio::select! { command = receiver.recv_async() => match command { @@ -116,7 +122,7 @@ impl Service { pub async fn send_message(&self, message_content: RoomMessageEventContent) { if let Ok(Some(room_id)) = self.get_admin_room() { - let user_id = &self.globals.server_user; + let user_id = &self.services.globals.server_user; self.respond_to_room(message_content, &room_id, user_id) .await; } @@ -176,7 +182,7 @@ impl Service { /// Checks whether a given user is an admin of this server pub async fn user_is_admin(&self, user_id: &UserId) -> Result { if let Ok(Some(admin_room)) = self.get_admin_room() { - self.state_cache.is_joined(user_id, &admin_room) + self.services.state_cache.is_joined(user_id, &admin_room) } else { Ok(false) } @@ -187,10 +193,15 @@ impl Service { /// Errors are propagated from the database, and will have None if there is /// no admin room pub fn get_admin_room(&self) -> Result> { - if let Some(room_id) = self.alias.resolve_local_alias(&self.globals.admin_alias)? { + if let Some(room_id) = self + .services + .alias + .resolve_local_alias(&self.services.globals.admin_alias)? + { if self + .services .state_cache - .is_joined(&self.globals.server_user, &room_id)? + .is_joined(&self.services.globals.server_user, &room_id)? { return Ok(Some(room_id)); } @@ -207,12 +218,12 @@ impl Service { return; }; - let Ok(Some(pdu)) = self.timeline.get_pdu(&in_reply_to.event_id) else { + let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else { return; }; let response_sender = if self.is_admin_room(&pdu.room_id) { - &self.globals.server_user + &self.services.globals.server_user } else { &pdu.sender }; @@ -229,7 +240,7 @@ impl Service { "sender is not admin" ); - let state_lock = self.state.mutex.lock(room_id).await; + let state_lock = self.services.state.mutex.lock(room_id).await; let response_pdu = PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -239,6 +250,7 @@ impl Service { }; if let Err(e) = self + .services .timeline .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) .await @@ -266,7 +278,8 @@ impl Service { redacts: None, }; - self.timeline + self.services + .timeline .build_and_append_pdu(response_pdu, user_id, room_id, state_lock) .await?; @@ -279,7 +292,7 @@ impl Service { let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin"); // Admin command with public echo (in admin room) - let server_user = &self.globals.server_user; + let server_user = &self.services.globals.server_user; let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str()); // Expected backward branch @@ -293,7 +306,7 @@ impl Service { } // Check if server-side command-escape is disabled by configuration - if is_public_escape && !self.globals.config.admin_escape_commands { + if is_public_escape && !self.services.globals.config.admin_escape_commands { return false; } @@ -309,7 +322,7 @@ impl Service { // This will evaluate to false if the emergency password is set up so that // the administrator can execute commands as conduit - let emergency_password_set = self.globals.emergency_password().is_some(); + let emergency_password_set = self.services.globals.emergency_password().is_some(); let from_server = pdu.sender == *server_user && !emergency_password_set; if from_server && self.is_admin_room(&pdu.room_id) { return false; diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 24c9b8b0..c0752d56 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -12,7 +12,7 @@ use ruma::{ }; use tokio::sync::RwLock; -use crate::services; +use crate::{sending, Dep}; /// Compiled regular expressions for a namespace #[derive(Clone, Debug)] @@ -118,9 +118,14 @@ impl TryFrom for RegistrationInfo { pub struct Service { pub db: Data, + services: Services, registration_info: RwLock>, } +struct Services { + sending: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let mut registration_info = BTreeMap::new(); @@ -138,6 +143,9 @@ impl crate::Service for Service { Ok(Arc::new(Self { db, + services: Services { + sending: args.depend::("sending"), + }, registration_info: RwLock::new(registration_info), })) } @@ -178,7 +186,9 @@ impl Service { // deletes all active requests for the appservice if there are any so we stop // sending to the URL - services().sending.cleanup_events(service_name.to_owned())?; + self.services + .sending + .cleanup_events(service_name.to_owned())?; Ok(()) } diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index 03b0a142..386bd33c 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -18,7 +18,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; - let resolver = args.require_service::("resolver"); + let resolver = args.require::("resolver"); Ok(Arc::new(Self { default: base(config) diff --git a/src/service/emergency/mod.rs b/src/service/emergency/mod.rs new file mode 100644 index 00000000..1bb0843d --- /dev/null +++ b/src/service/emergency/mod.rs @@ -0,0 +1,83 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use conduit::{error, warn, Result}; +use ruma::{ + events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType}, + push::Ruleset, +}; + +use crate::{account_data, globals, users, Dep}; + +pub struct Service { + services: Services, +} + +struct Services { + account_data: Dep, + globals: Dep, + users: Dep, +} + +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + services: Services { + account_data: args.depend::("account_data"), + globals: args.depend::("globals"), + users: args.depend::("users"), + }, + })) + } + + async fn worker(self: Arc) -> Result<()> { + self.set_emergency_access() + .inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?; + + Ok(()) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + /// Sets the emergency password and push rules for the @conduit account in + /// case emergency password is set + fn set_emergency_access(&self) -> Result { + let conduit_user = &self.services.globals.server_user; + + self.services + .users + .set_password(conduit_user, self.services.globals.emergency_password().as_deref())?; + + let (ruleset, pwd_set) = match self.services.globals.emergency_password() { + Some(_) => (Ruleset::server_default(conduit_user), true), + None => (Ruleset::new(), false), + }; + + self.services.account_data.update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + )?; + + if pwd_set { + warn!( + "The server account emergency password is set! Please unset it as soon as you finish admin account \ + recovery! You will be logged out of the server service account when you finish." + ); + } else { + // logs out any users still in the server service account and removes sessions + self.services.users.deactivate_account(conduit_user)?; + } + + Ok(pwd_set) + } +} diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 281c2a94..5d6240cd 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, RwLock}, }; -use conduit::{trace, utils, Error, Result}; +use conduit::{trace, utils, Error, Result, Server}; use database::{Database, Map}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -12,7 +12,7 @@ use ruma::{ DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, }; -use crate::services; +use crate::{rooms, Dep}; pub struct Data { global: Arc, @@ -28,14 +28,23 @@ pub struct Data { server_signingkeys: Arc, readreceiptid_readreceipt: Arc, userid_lastonetimekeyupdate: Arc, - pub(super) db: Arc, counter: RwLock, + pub(super) db: Arc, + services: Services, +} + +struct Services { + server: Arc, + short: Dep, + state_cache: Dep, + typing: Dep, } const COUNTER: &[u8] = b"c"; impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { global: db["global"].clone(), todeviceid_events: db["todeviceid_events"].clone(), @@ -50,8 +59,14 @@ impl Data { server_signingkeys: db["server_signingkeys"].clone(), readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), - db: db.clone(), counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")), + db: args.db.clone(), + services: Services { + server: args.server.clone(), + short: args.depend::("rooms::short"), + state_cache: args.depend::("rooms::state_cache"), + typing: args.depend::("rooms::typing"), + }, } } @@ -118,14 +133,14 @@ impl Data { futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in - for room_id in services() - .rooms + for room_id in self + .services .state_cache .rooms_joined(user_id) .filter_map(Result::ok) { - let short_roomid = services() - .rooms + let short_roomid = self + .services .short .get_shortroomid(&room_id) .ok() @@ -143,7 +158,7 @@ impl Data { // EDUs futures.push(Box::pin(async move { - let _result = services().rooms.typing.wait_for_update(&room_id).await; + let _result = self.services.typing.wait_for_update(&room_id).await; })); futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); @@ -176,12 +191,12 @@ impl Data { futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); futures.push(Box::pin(async move { - while services().server.running() { - let _result = services().server.signal.subscribe().recv().await; + while self.services.server.running() { + let _result = self.services.server.signal.subscribe().recv().await; } })); - if !services().server.running() { + if !self.services.server.running() { return Ok(()); } diff --git a/src/service/globals/emerg_access.rs b/src/service/globals/emerg_access.rs deleted file mode 100644 index 50c5f8c3..00000000 --- a/src/service/globals/emerg_access.rs +++ /dev/null @@ -1,54 +0,0 @@ -use conduit::Result; -use ruma::{ - events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType}, - push::Ruleset, -}; -use tracing::{error, warn}; - -use crate::services; - -/// Set emergency access for the conduit user -pub(crate) fn init_emergency_access() { - if let Err(e) = set_emergency_access() { - error!("Could not set the configured emergency password for the conduit user: {e}"); - } -} - -/// Sets the emergency password and push rules for the @conduit account in case -/// emergency password is set -fn set_emergency_access() -> Result { - let conduit_user = &services().globals.server_user; - - services() - .users - .set_password(conduit_user, services().globals.emergency_password().as_deref())?; - - let (ruleset, pwd_set) = match services().globals.emergency_password() { - Some(_) => (Ruleset::server_default(conduit_user), true), - None => (Ruleset::new(), false), - }; - - services().account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; - - if pwd_set { - warn!( - "The server account emergency password is set! Please unset it as soon as you finish admin account \ - recovery! You will be logged out of the server service account when you finish." - ); - } else { - // logs out any users still in the server service account and removes sessions - services().users.deactivate_account(conduit_user)?; - } - - Ok(pwd_set) -} diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index d8c5f29b..2fe22b0e 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -10,7 +10,6 @@ use std::{ }; use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Config, Error, Result}; -use database::Database; use itertools::Itertools; use ruma::{ events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, @@ -18,7 +17,7 @@ use ruma::{ EventId, OwnedRoomId, RoomId, UserId, }; -use crate::services; +use crate::Services; /// The current schema version. /// - If database is opened at greater version we reject with error. The @@ -28,13 +27,13 @@ use crate::services; /// equal or lesser version. These are expected to be backward-compatible. const DATABASE_VERSION: u64 = 13; -pub(crate) async fn migrations(db: &Arc, config: &Config) -> Result<()> { +pub(crate) async fn migrations(services: &Services) -> Result<()> { // Matrix resource ownership is based on the server name; changing it // requires recreating the database from scratch. - if services().users.count()? > 0 { - let conduit_user = &services().globals.server_user; + if services.users.count()? > 0 { + let conduit_user = &services.globals.server_user; - if !services().users.exists(conduit_user)? { + if !services.users.exists(conduit_user)? { error!("The {} server user does not exist, and the database is not new.", conduit_user); return Err(Error::bad_database( "Cannot reuse an existing database after changing the server name, please delete the old one first.", @@ -42,15 +41,18 @@ pub(crate) async fn migrations(db: &Arc, config: &Config) -> Result<() } } - if services().users.count()? > 0 { - migrate(db, config).await + if services.users.count()? > 0 { + migrate(services).await } else { - fresh(db, config).await + fresh(services).await } } -async fn fresh(db: &Arc, config: &Config) -> Result<()> { - services() +async fn fresh(services: &Services) -> Result<()> { + let db = &services.db; + let config = &services.server.config; + + services .globals .db .bump_database_version(DATABASE_VERSION)?; @@ -70,97 +72,100 @@ async fn fresh(db: &Arc, config: &Config) -> Result<()> { } /// Apply any migrations -async fn migrate(db: &Arc, config: &Config) -> Result<()> { - if services().globals.db.database_version()? < 1 { - db_lt_1(db, config).await?; +async fn migrate(services: &Services) -> Result<()> { + let db = &services.db; + let config = &services.server.config; + + if services.globals.db.database_version()? < 1 { + db_lt_1(services).await?; } - if services().globals.db.database_version()? < 2 { - db_lt_2(db, config).await?; + if services.globals.db.database_version()? < 2 { + db_lt_2(services).await?; } - if services().globals.db.database_version()? < 3 { - db_lt_3(db, config).await?; + if services.globals.db.database_version()? < 3 { + db_lt_3(services).await?; } - if services().globals.db.database_version()? < 4 { - db_lt_4(db, config).await?; + if services.globals.db.database_version()? < 4 { + db_lt_4(services).await?; } - if services().globals.db.database_version()? < 5 { - db_lt_5(db, config).await?; + if services.globals.db.database_version()? < 5 { + db_lt_5(services).await?; } - if services().globals.db.database_version()? < 6 { - db_lt_6(db, config).await?; + if services.globals.db.database_version()? < 6 { + db_lt_6(services).await?; } - if services().globals.db.database_version()? < 7 { - db_lt_7(db, config).await?; + if services.globals.db.database_version()? < 7 { + db_lt_7(services).await?; } - if services().globals.db.database_version()? < 8 { - db_lt_8(db, config).await?; + if services.globals.db.database_version()? < 8 { + db_lt_8(services).await?; } - if services().globals.db.database_version()? < 9 { - db_lt_9(db, config).await?; + if services.globals.db.database_version()? < 9 { + db_lt_9(services).await?; } - if services().globals.db.database_version()? < 10 { - db_lt_10(db, config).await?; + if services.globals.db.database_version()? < 10 { + db_lt_10(services).await?; } - if services().globals.db.database_version()? < 11 { - db_lt_11(db, config).await?; + if services.globals.db.database_version()? < 11 { + db_lt_11(services).await?; } - if services().globals.db.database_version()? < 12 { - db_lt_12(db, config).await?; + if services.globals.db.database_version()? < 12 { + db_lt_12(services).await?; } // This migration can be reused as-is anytime the server-default rules are // updated. - if services().globals.db.database_version()? < 13 { - db_lt_13(db, config).await?; + if services.globals.db.database_version()? < 13 { + db_lt_13(services).await?; } if db["global"].get(b"feat_sha256_media")?.is_none() { - migrate_sha256_media(db, config).await?; + migrate_sha256_media(services).await?; } else if config.media_startup_check { - checkup_sha256_media(db, config).await?; + checkup_sha256_media(services).await?; } if db["global"] .get(b"fix_bad_double_separator_in_state_cache")? .is_none() { - fix_bad_double_separator_in_state_cache(db, config).await?; + fix_bad_double_separator_in_state_cache(services).await?; } if db["global"] .get(b"retroactively_fix_bad_data_from_roomuserid_joined")? .is_none() { - retroactively_fix_bad_data_from_roomuserid_joined(db, config).await?; + retroactively_fix_bad_data_from_roomuserid_joined(services).await?; } assert_eq!( - services().globals.db.database_version().unwrap(), + services.globals.db.database_version().unwrap(), DATABASE_VERSION, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", - services().globals.db.database_version().unwrap(), + services.globals.db.database_version().unwrap(), DATABASE_VERSION, ); { - let patterns = services().globals.forbidden_usernames(); + let patterns = services.globals.forbidden_usernames(); if !patterns.is_empty() { - for user_id in services() + for user_id in services .users .iter() .filter_map(Result::ok) - .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)) + .filter(|user| !services.users.is_deactivated(user).unwrap_or(true)) .filter(|user| user.server_name() == config.server_name) { let matches = patterns.matches(user_id.localpart()); @@ -179,11 +184,11 @@ async fn migrate(db: &Arc, config: &Config) -> Result<()> { } { - let patterns = services().globals.forbidden_alias_names(); + let patterns = services.globals.forbidden_alias_names(); if !patterns.is_empty() { - for address in services().rooms.metadata.iter_ids() { + for address in services.rooms.metadata.iter_ids() { let room_id = address?; - let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id); + let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id); for room_alias_result in room_aliases { let room_alias = room_alias_result?; let matches = patterns.matches(room_alias.alias()); @@ -211,7 +216,9 @@ async fn migrate(db: &Arc, config: &Config) -> Result<()> { Ok(()) } -async fn db_lt_1(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_1(services: &Services) -> Result<()> { + let db = &services.db; + let roomserverids = &db["roomserverids"]; let serverroomids = &db["serverroomids"]; for (roomserverid, _) in roomserverids.iter() { @@ -228,12 +235,14 @@ async fn db_lt_1(db: &Arc, _config: &Config) -> Result<()> { serverroomids.insert(&serverroomid, &[])?; } - services().globals.db.bump_database_version(1)?; + services.globals.db.bump_database_version(1)?; info!("Migration: 0 -> 1 finished"); Ok(()) } -async fn db_lt_2(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_2(services: &Services) -> Result<()> { + let db = &services.db; + // We accidentally inserted hashed versions of "" into the db instead of just "" let userid_password = &db["roomserverids"]; for (userid, password) in userid_password.iter() { @@ -245,12 +254,14 @@ async fn db_lt_2(db: &Arc, _config: &Config) -> Result<()> { } } - services().globals.db.bump_database_version(2)?; + services.globals.db.bump_database_version(2)?; info!("Migration: 1 -> 2 finished"); Ok(()) } -async fn db_lt_3(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_3(services: &Services) -> Result<()> { + let db = &services.db; + // Move media to filesystem let mediaid_file = &db["mediaid_file"]; for (key, content) in mediaid_file.iter() { @@ -259,41 +270,45 @@ async fn db_lt_3(db: &Arc, _config: &Config) -> Result<()> { } #[allow(deprecated)] - let path = services().media.get_media_file(&key); + let path = services.media.get_media_file(&key); let mut file = fs::File::create(path)?; file.write_all(&content)?; mediaid_file.insert(&key, &[])?; } - services().globals.db.bump_database_version(3)?; + services.globals.db.bump_database_version(3)?; info!("Migration: 2 -> 3 finished"); Ok(()) } -async fn db_lt_4(_db: &Arc, config: &Config) -> Result<()> { - // Add federated users to services() as deactivated - for our_user in services().users.iter() { +async fn db_lt_4(services: &Services) -> Result<()> { + let config = &services.server.config; + + // Add federated users to services as deactivated + for our_user in services.users.iter() { let our_user = our_user?; - if services().users.is_deactivated(&our_user)? { + if services.users.is_deactivated(&our_user)? { continue; } - for room in services().rooms.state_cache.rooms_joined(&our_user) { - for user in services().rooms.state_cache.room_members(&room?) { + for room in services.rooms.state_cache.rooms_joined(&our_user) { + for user in services.rooms.state_cache.room_members(&room?) { let user = user?; if user.server_name() != config.server_name { info!(?user, "Migration: creating user"); - services().users.create(&user, None)?; + services.users.create(&user, None)?; } } } } - services().globals.db.bump_database_version(4)?; + services.globals.db.bump_database_version(4)?; info!("Migration: 3 -> 4 finished"); Ok(()) } -async fn db_lt_5(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_5(services: &Services) -> Result<()> { + let db = &services.db; + // Upgrade user data store let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"]; let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"]; @@ -312,26 +327,30 @@ async fn db_lt_5(db: &Arc, _config: &Config) -> Result<()> { roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; } - services().globals.db.bump_database_version(5)?; + services.globals.db.bump_database_version(5)?; info!("Migration: 4 -> 5 finished"); Ok(()) } -async fn db_lt_6(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_6(services: &Services) -> Result<()> { + let db = &services.db; + // Set room member count let roomid_shortstatehash = &db["roomid_shortstatehash"]; for (roomid, _) in roomid_shortstatehash.iter() { let string = utils::string_from_bytes(&roomid).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services().rooms.state_cache.update_joined_count(room_id)?; + services.rooms.state_cache.update_joined_count(room_id)?; } - services().globals.db.bump_database_version(6)?; + services.globals.db.bump_database_version(6)?; info!("Migration: 5 -> 6 finished"); Ok(()) } -async fn db_lt_7(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_7(services: &Services) -> Result<()> { + let db = &services.db; + // Upgrade state store let mut last_roomstates: HashMap = HashMap::new(); let mut current_sstatehash: Option = None; @@ -347,7 +366,7 @@ async fn db_lt_7(db: &Arc, _config: &Config) -> Result<()> { let states_parents = last_roomsstatehash.map_or_else( || Ok(Vec::new()), |&last_roomsstatehash| { - services() + services .rooms .state_compressor .load_shortstatehash_info(last_roomsstatehash) @@ -371,7 +390,7 @@ async fn db_lt_7(db: &Arc, _config: &Config) -> Result<()> { (current_state, HashSet::new()) }; - services().rooms.state_compressor.save_state_from_diff( + services.rooms.state_compressor.save_state_from_diff( current_sstatehash, Arc::new(statediffnew), Arc::new(statediffremoved), @@ -380,7 +399,7 @@ async fn db_lt_7(db: &Arc, _config: &Config) -> Result<()> { )?; /* - let mut tmp = services().rooms.load_shortstatehash_info(¤t_sstatehash)?; + let mut tmp = services.rooms.load_shortstatehash_info(¤t_sstatehash)?; let state = tmp.pop().unwrap(); println!( "{}\t{}{:?}: {:?} + {:?} - {:?}", @@ -425,12 +444,7 @@ async fn db_lt_7(db: &Arc, _config: &Config) -> Result<()> { let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap(); let string = utils::string_from_bytes(&event_id).unwrap(); let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services() - .rooms - .timeline - .get_pdu(event_id) - .unwrap() - .unwrap(); + let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap(); if Some(&pdu.room_id) != current_room.as_ref() { current_room = Some(pdu.room_id.clone()); @@ -451,12 +465,14 @@ async fn db_lt_7(db: &Arc, _config: &Config) -> Result<()> { )?; } - services().globals.db.bump_database_version(7)?; + services.globals.db.bump_database_version(7)?; info!("Migration: 6 -> 7 finished"); Ok(()) } -async fn db_lt_8(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_8(services: &Services) -> Result<()> { + let db = &services.db; + let roomid_shortstatehash = &db["roomid_shortstatehash"]; let roomid_shortroomid = &db["roomid_shortroomid"]; let pduid_pdu = &db["pduid_pdu"]; @@ -464,7 +480,7 @@ async fn db_lt_8(db: &Arc, _config: &Config) -> Result<()> { // Generate short room ids for all rooms for (room_id, _) in roomid_shortstatehash.iter() { - let shortroomid = services().globals.next_count()?.to_be_bytes(); + let shortroomid = services.globals.next_count()?.to_be_bytes(); roomid_shortroomid.insert(&room_id, &shortroomid)?; info!("Migration: 8"); } @@ -517,12 +533,14 @@ async fn db_lt_8(db: &Arc, _config: &Config) -> Result<()> { eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; - services().globals.db.bump_database_version(8)?; + services.globals.db.bump_database_version(8)?; info!("Migration: 7 -> 8 finished"); Ok(()) } -async fn db_lt_9(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_9(services: &Services) -> Result<()> { + let db = &services.db; + let tokenids = &db["tokenids"]; let roomid_shortroomid = &db["roomid_shortroomid"]; @@ -574,12 +592,14 @@ async fn db_lt_9(db: &Arc, _config: &Config) -> Result<()> { tokenids.remove(&key)?; } - services().globals.db.bump_database_version(9)?; + services.globals.db.bump_database_version(9)?; info!("Migration: 8 -> 9 finished"); Ok(()) } -async fn db_lt_10(db: &Arc, _config: &Config) -> Result<()> { +async fn db_lt_10(services: &Services) -> Result<()> { + let db = &services.db; + let statekey_shortstatekey = &db["statekey_shortstatekey"]; let shortstatekey_statekey = &db["shortstatekey_statekey"]; @@ -589,28 +609,30 @@ async fn db_lt_10(db: &Arc, _config: &Config) -> Result<()> { } // Force E2EE device list updates so we can send them over federation - for user_id in services().users.iter().filter_map(Result::ok) { - services().users.mark_device_key_update(&user_id)?; + for user_id in services.users.iter().filter_map(Result::ok) { + services.users.mark_device_key_update(&user_id)?; } - services().globals.db.bump_database_version(10)?; + services.globals.db.bump_database_version(10)?; info!("Migration: 9 -> 10 finished"); Ok(()) } #[allow(unreachable_code)] -async fn db_lt_11(_db: &Arc, _config: &Config) -> Result<()> { - todo!("Dropping a column to clear data is not implemented yet."); +async fn db_lt_11(services: &Services) -> Result<()> { + error!("Dropping a column to clear data is not implemented yet."); //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; //userdevicesessionid_uiaarequest.clear()?; - services().globals.db.bump_database_version(11)?; + services.globals.db.bump_database_version(11)?; info!("Migration: 10 -> 11 finished"); Ok(()) } -async fn db_lt_12(_db: &Arc, config: &Config) -> Result<()> { - for username in services().users.list_local_users()? { +async fn db_lt_12(services: &Services) -> Result<()> { + let config = &services.server.config; + + for username in services.users.list_local_users()? { let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { Ok(u) => u, Err(e) => { @@ -619,7 +641,7 @@ async fn db_lt_12(_db: &Arc, config: &Config) -> Result<()> { }, }; - let raw_rules_list = services() + let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) .unwrap() @@ -664,7 +686,7 @@ async fn db_lt_12(_db: &Arc, config: &Config) -> Result<()> { } } - services().account_data.update( + services.account_data.update( None, &user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -672,13 +694,15 @@ async fn db_lt_12(_db: &Arc, config: &Config) -> Result<()> { )?; } - services().globals.db.bump_database_version(12)?; + services.globals.db.bump_database_version(12)?; info!("Migration: 11 -> 12 finished"); Ok(()) } -async fn db_lt_13(_db: &Arc, config: &Config) -> Result<()> { - for username in services().users.list_local_users()? { +async fn db_lt_13(services: &Services) -> Result<()> { + let config = &services.server.config; + + for username in services.users.list_local_users()? { let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { Ok(u) => u, Err(e) => { @@ -687,7 +711,7 @@ async fn db_lt_13(_db: &Arc, config: &Config) -> Result<()> { }, }; - let raw_rules_list = services() + let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) .unwrap() @@ -701,7 +725,7 @@ async fn db_lt_13(_db: &Arc, config: &Config) -> Result<()> { .global .update_with_server_default(user_default_rules); - services().account_data.update( + services.account_data.update( None, &user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -709,7 +733,7 @@ async fn db_lt_13(_db: &Arc, config: &Config) -> Result<()> { )?; } - services().globals.db.bump_database_version(13)?; + services.globals.db.bump_database_version(13)?; info!("Migration: 12 -> 13 finished"); Ok(()) } @@ -717,15 +741,17 @@ async fn db_lt_13(_db: &Arc, config: &Config) -> Result<()> { /// Migrates a media directory from legacy base64 file names to sha2 file names. /// All errors are fatal. Upon success the database is keyed to not perform this /// again. -async fn migrate_sha256_media(db: &Arc, _config: &Config) -> Result<()> { +async fn migrate_sha256_media(services: &Services) -> Result<()> { + let db = &services.db; + warn!("Migrating legacy base64 file names to sha256 file names"); let mediaid_file = &db["mediaid_file"]; // Move old media files to new names let mut changes = Vec::<(PathBuf, PathBuf)>::new(); for (key, _) in mediaid_file.iter() { - let old = services().media.get_media_file_b64(&key); - let new = services().media.get_media_file_sha256(&key); + let old = services.media.get_media_file_b64(&key); + let new = services.media.get_media_file_sha256(&key); debug!(?key, ?old, ?new, num = changes.len(), "change"); changes.push((old, new)); } @@ -739,8 +765,8 @@ async fn migrate_sha256_media(db: &Arc, _config: &Config) -> Result<() // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services().globals.db.database_version()? == 14 && DATABASE_VERSION == 13 { - services().globals.db.bump_database_version(13)?; + if services.globals.db.database_version()? == 14 && DATABASE_VERSION == 13 { + services.globals.db.bump_database_version(13)?; } db["global"].insert(b"feat_sha256_media", &[])?; @@ -752,14 +778,16 @@ async fn migrate_sha256_media(db: &Arc, _config: &Config) -> Result<() /// - Going back and forth to non-sha256 legacy binaries (e.g. upstream). /// - Deletion of artifacts in the media directory which will then fall out of /// sync with the database. -async fn checkup_sha256_media(db: &Arc, config: &Config) -> Result<()> { +async fn checkup_sha256_media(services: &Services) -> Result<()> { use crate::media::encode_key; debug!("Checking integrity of media directory"); + let db = &services.db; + let media = &services.media; + let config = &services.server.config; let mediaid_file = &db["mediaid_file"]; let mediaid_user = &db["mediaid_user"]; let dbs = (mediaid_file, mediaid_user); - let media = &services().media; let timer = Instant::now(); let dir = media.get_media_dir(); @@ -791,6 +819,7 @@ async fn handle_media_check( new_path: &OsStr, old_path: &OsStr, ) -> Result<()> { use crate::media::encode_key; + let (mediaid_file, mediaid_user) = dbs; let old_exists = files.contains(old_path); @@ -827,8 +856,10 @@ async fn handle_media_check( Ok(()) } -async fn fix_bad_double_separator_in_state_cache(db: &Arc, _config: &Config) -> Result<()> { +async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<()> { warn!("Fixing bad double separator in state_cache roomuserid_joined"); + + let db = &services.db; let roomuserid_joined = &db["roomuserid_joined"]; let _cork = db.cork_and_sync(); @@ -864,11 +895,13 @@ async fn fix_bad_double_separator_in_state_cache(db: &Arc, _config: &C Ok(()) } -async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc, _config: &Config) -> Result<()> { +async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result<()> { warn!("Retroactively fixing bad data from broken roomuserid_joined"); + + let db = &services.db; let _cork = db.cork_and_sync(); - let room_ids = services() + let room_ids = services .rooms .metadata .iter_ids() @@ -878,7 +911,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc, _ for room_id in room_ids.clone() { debug_info!("Fixing room {room_id}"); - let users_in_room = services() + let users_in_room = services .rooms .state_cache .room_members(&room_id) @@ -888,7 +921,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc, _ let joined_members = users_in_room .iter() .filter(|user_id| { - services() + services .rooms .state_accessor .get_member(&room_id, user_id) @@ -900,7 +933,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc, _ let non_joined_members = users_in_room .iter() .filter(|user_id| { - services() + services .rooms .state_accessor .get_member(&room_id, user_id) @@ -913,7 +946,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc, _ for user_id in joined_members { debug_info!("User is joined, marking as joined"); - services() + services .rooms .state_cache .mark_as_joined(user_id, &room_id)?; @@ -921,10 +954,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc, _ for user_id in non_joined_members { debug_info!("User is left or banned, marking as left"); - services() - .rooms - .state_cache - .mark_as_left(user_id, &room_id)?; + services.rooms.state_cache.mark_as_left(user_id, &room_id)?; } } @@ -933,7 +963,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc, _ "Updating joined count for room {room_id} to fix servers in room after correcting membership states" ); - services().rooms.state_cache.update_joined_count(&room_id)?; + services.rooms.state_cache.update_joined_count(&room_id)?; } db.db.cleanup()?; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 811aff3a..eab156ee 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,5 +1,4 @@ mod data; -mod emerg_access; pub(super) mod migrations; use std::{ @@ -9,7 +8,6 @@ use std::{ time::Instant, }; -use async_trait::async_trait; use conduit::{error, trace, Config, Result}; use data::Data; use ipaddress::IPAddress; @@ -43,11 +41,10 @@ pub struct Service { type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries -#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { + let db = Data::new(&args); let config = &args.server.config; - let db = Data::new(args.db); let keypair = db.load_keypair(); let keypair = match keypair { @@ -104,19 +101,13 @@ impl crate::Service for Service { .supported_room_versions() .contains(&s.config.default_room_version) { - error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); - s.config.default_room_version = crate::config::default_default_room_version(); + error!(config=?s.config.default_room_version, fallback=?conduit::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); + s.config.default_room_version = conduit::config::default_default_room_version(); }; Ok(Arc::new(s)) } - async fn worker(self: Arc) -> Result<()> { - emerg_access::init_emergency_access(); - - Ok(()) - } - fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { let bad_event_ratelimiter = self .bad_event_ratelimiter diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index f1794895..30ac593b 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, sync::Arc}; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{ api::client::{ backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, @@ -11,25 +11,34 @@ use ruma::{ OwnedRoomId, RoomId, UserId, }; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { backupid_algorithm: Arc, backupid_etag: Arc, backupkeyid_backup: Arc, + services: Services, +} + +struct Services { + globals: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { backupid_algorithm: db["backupid_algorithm"].clone(), backupid_etag: db["backupid_etag"].clone(), backupkeyid_backup: db["backupkeyid_backup"].clone(), + services: Services { + globals: args.depend::("globals"), + }, } } pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { - let version = services().globals.next_count()?.to_string(); + let version = self.services.globals.next_count()?.to_string(); let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); @@ -40,7 +49,7 @@ impl Data { &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; Ok(version) } @@ -75,7 +84,7 @@ impl Data { self.backupid_algorithm .insert(&key, backup_metadata.json().get().as_bytes())?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; Ok(version.to_owned()) } @@ -152,7 +161,7 @@ impl Data { } self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index d83d4497..65d3c065 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -17,7 +17,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/manager.rs b/src/service/manager.rs index 447cd6fe..087fd3fa 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -8,13 +8,13 @@ use tokio::{ time::sleep, }; -use crate::{service::Service, Services}; +use crate::{service, service::Service, Services}; pub(crate) struct Manager { manager: Mutex>>>, workers: Mutex, server: Arc, - services: &'static Services, + service: Arc, } type Workers = JoinSet; @@ -29,7 +29,7 @@ impl Manager { manager: Mutex::new(None), workers: Mutex::new(JoinSet::new()), server: services.server.clone(), - services: crate::services(), + service: services.service.clone(), }) } @@ -53,9 +53,19 @@ impl Manager { .spawn(async move { self_.worker().await }), ); + // we can't hold the lock during the iteration with start_worker so the values + // are snapshotted here + let services: Vec> = self + .service + .read() + .expect("locked for reading") + .values() + .map(|v| v.0.clone()) + .collect(); + debug!("Starting service workers..."); - for (service, ..) in self.services.service.values() { - self.start_worker(&mut workers, service).await?; + for service in services { + self.start_worker(&mut workers, &service).await?; } Ok(()) diff --git a/src/service/media/data.rs b/src/service/media/data.rs index e5856bbf..617ec526 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use conduit::{debug, debug_info, Error, Result}; +use conduit::{debug, debug_info, utils::string_from_bytes, Error, Result}; use database::{Database, Map}; use ruma::api::client::error::ErrorKind; -use crate::{media::UrlPreviewData, utils::string_from_bytes}; +use crate::media::UrlPreviewData; pub(crate) struct Data { mediaid_file: Arc, diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 4a7d38a4..d5d518dc 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -15,7 +15,7 @@ use tokio::{ io::{AsyncReadExt, AsyncWriteExt, BufReader}, }; -use crate::services; +use crate::{globals, Dep}; #[derive(Debug)] pub struct FileMeta { @@ -41,16 +41,24 @@ pub struct UrlPreviewData { } pub struct Service { - server: Arc, + services: Services, pub(crate) db: Data, pub url_preview_mutex: MutexMap, } +struct Services { + server: Arc, + globals: Dep, +} + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - server: args.server.clone(), + services: Services { + server: args.server.clone(), + globals: args.depend::("globals"), + }, db: Data::new(args.db), url_preview_mutex: MutexMap::new(), })) @@ -164,7 +172,7 @@ impl Service { debug!("Parsed MXC key to URL: {mxc_s}"); let mxc = OwnedMxcUri::from(mxc_s); - if mxc.server_name() == Ok(services().globals.server_name()) { + if mxc.server_name() == Ok(self.services.globals.server_name()) { debug!("Ignoring local media MXC: {mxc}"); // ignore our own MXC URLs as this would be local media. continue; @@ -246,7 +254,7 @@ impl Service { let legacy_rm = fs::remove_file(&legacy); let (file_rm, legacy_rm) = tokio::join!(file_rm, legacy_rm); if let Err(e) = legacy_rm { - if self.server.config.media_compat_file_link { + if self.services.server.config.media_compat_file_link { debug_error!(?key, ?legacy, "Failed to remove legacy media symlink: {e}"); } } @@ -259,7 +267,7 @@ impl Service { debug!(?key, ?path, "Creating media file"); let file = fs::File::create(&path).await?; - if self.server.config.media_compat_file_link { + if self.services.server.config.media_compat_file_link { let legacy = self.get_media_file_b64(key); if let Err(e) = fs::symlink(&path, &legacy).await { debug_error!( @@ -304,7 +312,7 @@ impl Service { #[must_use] pub fn get_media_dir(&self) -> PathBuf { let mut r = PathBuf::new(); - r.push(self.server.config.database_path.clone()); + r.push(self.services.server.config.database_path.clone()); r.push("media"); r } diff --git a/src/service/mod.rs b/src/service/mod.rs index 21d1f594..6e749c99 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,3 +1,4 @@ +#![recursion_limit = "160"] #![allow(refining_impl_trait)] mod manager; @@ -8,6 +9,7 @@ pub mod account_data; pub mod admin; pub mod appservice; pub mod client; +pub mod emergency; pub mod globals; pub mod key_backups; pub mod media; @@ -26,8 +28,8 @@ extern crate conduit_database as database; use std::sync::{Arc, RwLock}; -pub(crate) use conduit::{config, debug_error, debug_warn, utils, Error, Result, Server}; pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; +use conduit::{Result, Server}; use database::Database; pub(crate) use service::{Args, Dep, Service}; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index dac8becf..53f9d8c7 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,21 +1,32 @@ use std::sync::Arc; use conduit::{debug_warn, utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; -use crate::{presence::Presence, services}; +use crate::{globals, presence::Presence, users, Dep}; pub struct Data { presenceid_presence: Arc, userid_presenceid: Arc, + services: Services, +} + +struct Services { + globals: Dep, + users: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { presenceid_presence: db["presenceid_presence"].clone(), userid_presenceid: db["userid_presenceid"].clone(), + services: Services { + globals: args.depend::("globals"), + users: args.depend::("users"), + }, } } @@ -28,7 +39,10 @@ impl Data { self.presenceid_presence .get(&key)? .map(|presence_bytes| -> Result<(u64, PresenceEvent)> { - Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?)) + Ok(( + count, + Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?, + )) }) .transpose() } else { @@ -80,7 +94,7 @@ impl Data { last_active_ts, status_msg, ); - let count = services().globals.next_count()?; + let count = self.services.globals.next_count()?; let key = presenceid_key(count, user_id); self.presenceid_presence diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 254304ba..705ac4ff 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -3,8 +3,7 @@ mod data; use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{checked, debug, error, utils, Error, Result}; -use data::Data; +use conduit::{checked, debug, error, utils, Error, Result, Server}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -14,7 +13,8 @@ use ruma::{ use serde::{Deserialize, Serialize}; use tokio::{sync::Mutex, time::sleep}; -use crate::{services, user_is_local}; +use self::data::Data; +use crate::{user_is_local, users, Dep}; /// Represents data required to be kept in order to implement the presence /// specification. @@ -37,11 +37,6 @@ impl Presence { } } - pub fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result { - let presence = Self::from_json_bytes(bytes)?; - presence.to_presence_event(user_id) - } - pub fn from_json_bytes(bytes: &[u8]) -> Result { serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database")) } @@ -51,7 +46,7 @@ impl Presence { } /// Creates a PresenceEvent from available data. - pub fn to_presence_event(&self, user_id: &UserId) -> Result { + pub fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> Result { let now = utils::millis_since_unix_epoch(); let last_active_ago = if self.currently_active { None @@ -66,14 +61,15 @@ impl Presence { status_msg: self.status_msg.clone(), currently_active: Some(self.currently_active), last_active_ago, - displayname: services().users.displayname(user_id)?, - avatar_url: services().users.avatar_url(user_id)?, + displayname: users.displayname(user_id)?, + avatar_url: users.avatar_url(user_id)?, }, }) } } pub struct Service { + services: Services, pub db: Data, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, timer_receiver: Mutex>, @@ -82,6 +78,11 @@ pub struct Service { offline_timeout: u64, } +struct Services { + server: Arc, + users: Dep, +} + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -90,7 +91,11 @@ impl crate::Service for Service { let offline_timeout_s = config.presence_offline_timeout_s; let (timer_sender, timer_receiver) = loole::unbounded(); Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + server: args.server.clone(), + users: args.depend::("users"), + }, + db: Data::new(&args), timer_sender, timer_receiver: Mutex::new(timer_receiver), timeout_remote_users: config.presence_timeout_remote_users, @@ -182,8 +187,8 @@ impl Service { if self.timeout_remote_users || user_is_local(user_id) { let timeout = match presence_state { - PresenceState::Online => services().globals.config.presence_idle_timeout_s, - _ => services().globals.config.presence_offline_timeout_s, + PresenceState::Online => self.services.server.config.presence_idle_timeout_s, + _ => self.services.server.config.presence_offline_timeout_s, }; self.timer_sender @@ -210,6 +215,11 @@ impl Service { self.db.presence_since(since) } + pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result { + let presence = Presence::from_json_bytes(bytes)?; + presence.to_presence_event(user_id, &self.services.users) + } + fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 38ea5b9a..873f0f49 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -3,8 +3,7 @@ mod data; use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduit::{debug_info, info, trace, warn, Error, Result}; -use data::Data; +use conduit::{debug_info, info, trace, utils::string_from_bytes, warn, Error, PduEvent, Result}; use ipaddress::IPAddress; use ruma::{ api::{ @@ -23,15 +22,32 @@ use ruma::{ uint, RoomId, UInt, UserId, }; -use crate::{services, PduEvent}; +use self::data::Data; +use crate::{client, globals, rooms, users, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + globals: Dep, + client: Dep, + state_accessor: Dep, + state_cache: Dep, + users: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + services: Services { + globals: args.depend::("globals"), + client: args.depend::("client"), + state_accessor: args.depend::("rooms::state_accessor"), + state_cache: args.depend::("rooms::state_cache"), + users: args.depend::("users"), + }, db: Data::new(args.db), })) } @@ -62,7 +78,7 @@ impl Service { { const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0]; - let dest = dest.replace(services().globals.notification_push_path(), ""); + let dest = dest.replace(self.services.globals.notification_push_path(), ""); trace!("Push gateway destination: {dest}"); let http_request = request @@ -78,13 +94,13 @@ impl Service { if let Some(url_host) = reqwest_request.url().host_str() { trace!("Checking request URL for IP"); if let Ok(ip) = IPAddress::parse(url_host) { - if !services().globals.valid_cidr_range(&ip) { + if !self.services.globals.valid_cidr_range(&ip) { return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); } } } - let response = services().client.pusher.execute(reqwest_request).await; + let response = self.services.client.pusher.execute(reqwest_request).await; match response { Ok(mut response) => { @@ -93,7 +109,7 @@ impl Service { trace!("Checking response destination's IP"); if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - if !services().globals.valid_cidr_range(&ip) { + if !self.services.globals.valid_cidr_range(&ip) { return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); } } @@ -114,7 +130,7 @@ impl Service { if !status.is_success() { info!("Push gateway {dest} returned unsuccessful HTTP response ({status})"); - debug_info!("Push gateway response body: {:?}", crate::utils::string_from_bytes(&body)); + debug_info!("Push gateway response body: {:?}", string_from_bytes(&body)); return Err(Error::BadServerResponse("Push gateway returned unsuccessful response")); } @@ -143,8 +159,8 @@ impl Service { let mut notify = None; let mut tweaks = Vec::new(); - let power_levels: RoomPowerLevelsEventContent = services() - .rooms + let power_levels: RoomPowerLevelsEventContent = self + .services .state_accessor .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { @@ -195,15 +211,15 @@ impl Service { let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), member_count: UInt::try_from( - services() - .rooms + self.services .state_cache .room_joined_count(room_id)? .unwrap_or(1), ) .unwrap_or_else(|_| uint!(0)), user_id: user.to_owned(), - user_display_name: services() + user_display_name: self + .services .users .displayname(user)? .unwrap_or_else(|| user.localpart().to_owned()), @@ -263,9 +279,9 @@ impl Service { notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); } - notifi.sender_display_name = services().users.displayname(&event.sender)?; + notifi.sender_display_name = self.services.users.displayname(&event.sender)?; - notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; + notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?; self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) .await?; diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index b2a00023..4d5c132f 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -9,12 +9,9 @@ use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; use ruma::ServerName; -use crate::{ - resolver::{ - cache::{CachedDest, CachedOverride}, - fed::{add_port_to_hostname, get_ip_with_port, FedDest}, - }, - services, +use crate::resolver::{ + cache::{CachedDest, CachedOverride}, + fed::{add_port_to_hostname, get_ip_with_port, FedDest}, }; #[derive(Clone, Debug)] @@ -40,7 +37,7 @@ impl super::Service { result } else { cached = false; - validate_dest(server_name)?; + self.validate_dest(server_name)?; self.resolve_actual_dest(server_name, true).await? }; @@ -188,7 +185,8 @@ impl super::Service { self.query_and_cache_override(dest, dest, 8448).await?; } - let response = services() + let response = self + .services .client .well_known .get(&format!("https://{dest}/.well-known/matrix/server")) @@ -245,19 +243,14 @@ impl super::Service { #[tracing::instrument(skip_all, name = "ip")] async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { - match services() - .resolver - .raw() - .lookup_ip(hostname.to_owned()) - .await - { - Err(e) => handle_resolve_error(&e), + match self.raw().lookup_ip(hostname.to_owned()).await { + Err(e) => Self::handle_resolve_error(&e), Ok(override_ip) => { if hostname != overname { debug_info!("{overname:?} overriden by {hostname:?}"); } - services().resolver.set_cached_override( + self.set_cached_override( overname.to_owned(), CachedOverride { ips: override_ip.iter().collect(), @@ -295,62 +288,62 @@ impl super::Service { for hostname in hostnames { match lookup_srv(self.raw(), &hostname).await { Ok(result) => return Ok(handle_successful_srv(&result)), - Err(e) => handle_resolve_error(&e)?, + Err(e) => Self::handle_resolve_error(&e)?, } } Ok(None) } -} -#[allow(clippy::single_match_else)] -fn handle_resolve_error(e: &ResolveError) -> Result<()> { - use hickory_resolver::error::ResolveErrorKind; + #[allow(clippy::single_match_else)] + fn handle_resolve_error(e: &ResolveError) -> Result<()> { + use hickory_resolver::error::ResolveErrorKind; - match *e.kind() { - ResolveErrorKind::NoRecordsFound { - .. - } => { - // Raise to debug_warn if we can find out the result wasn't from cache - debug!("{e}"); - Ok(()) - }, - _ => Err!(error!("DNS {e}")), - } -} - -fn validate_dest(dest: &ServerName) -> Result<()> { - if dest == services().globals.server_name() { - return Err!("Won't send federation request to ourselves"); + match *e.kind() { + ResolveErrorKind::NoRecordsFound { + .. + } => { + // Raise to debug_warn if we can find out the result wasn't from cache + debug!("{e}"); + Ok(()) + }, + _ => Err!(error!("DNS {e}")), + } } - if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { - validate_dest_ip_literal(dest)?; + fn validate_dest(&self, dest: &ServerName) -> Result<()> { + if dest == self.services.server.config.server_name { + return Err!("Won't send federation request to ourselves"); + } + + if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { + self.validate_dest_ip_literal(dest)?; + } + + Ok(()) } - Ok(()) -} + fn validate_dest_ip_literal(&self, dest: &ServerName) -> Result<()> { + trace!("Destination is an IP literal, checking against IP range denylist.",); + debug_assert!( + dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), + "Destination is not an IP literal." + ); + let ip = IPAddress::parse(dest.host()).map_err(|e| { + debug_error!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; -fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> { - trace!("Destination is an IP literal, checking against IP range denylist.",); - debug_assert!( - dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), - "Destination is not an IP literal." - ); - let ip = IPAddress::parse(dest.host()).map_err(|e| { - debug_error!("Failed to parse IP literal from string: {}", e); - Error::BadServerResponse("Invalid IP address") - })?; + self.validate_ip(&ip)?; - validate_ip(&ip)?; - - Ok(()) -} - -pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> { - if !services().globals.valid_cidr_range(ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + Ok(()) } - Ok(()) + pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result<()> { + if !self.services.globals.valid_cidr_range(ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + + Ok(()) + } } diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index 0fba2400..465b5985 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -5,11 +5,10 @@ use std::{ time::SystemTime, }; -use conduit::trace; +use conduit::{trace, utils::rand}; use ruma::{OwnedServerName, ServerName}; use super::fed::FedDest; -use crate::utils::rand; pub struct Cache { pub destinations: RwLock, // actual_destination, host diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 48ff8813..457ea9cc 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -6,14 +6,22 @@ mod tests; use std::{fmt::Write, sync::Arc}; -use conduit::Result; +use conduit::{Result, Server}; use hickory_resolver::TokioAsyncResolver; use self::{cache::Cache, dns::Resolver}; +use crate::{client, globals, Dep}; pub struct Service { pub cache: Arc, pub resolver: Arc, + services: Services, +} + +struct Services { + server: Arc, + client: Dep, + globals: Dep, } impl crate::Service for Service { @@ -23,6 +31,11 @@ impl crate::Service for Service { Ok(Arc::new(Self { cache: cache.clone(), resolver: Resolver::build(args.server, cache)?, + services: Services { + server: args.server.clone(), + client: args.depend::("client"), + globals: args.depend::("globals"), + }, })) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index 302c21ae..efd2b5b7 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,23 +1,32 @@ use std::sync::Arc; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { alias_userid: Arc, alias_roomid: Arc, aliasid_alias: Arc, + services: Services, +} + +struct Services { + globals: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { alias_userid: db["alias_userid"].clone(), alias_roomid: db["alias_roomid"].clone(), aliasid_alias: db["aliasid_alias"].clone(), + services: Services { + globals: args.depend::("globals"), + }, } } @@ -31,7 +40,7 @@ impl Data { let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xFF); - aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; Ok(()) diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 792f5c98..344ab6d2 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -4,9 +4,8 @@ mod remote; use std::sync::Arc; use conduit::{err, Error, Result}; -use data::Data; use ruma::{ - api::{appservice, client::error::ErrorKind}, + api::client::error::ErrorKind, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, @@ -14,16 +13,33 @@ use ruma::{ OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId, }; -use crate::{appservice::RegistrationInfo, server_is_ours, services}; +use self::data::Data; +use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, server_is_ours, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Services { + admin: Dep, + appservice: Dep, + globals: Dep, + sending: Dep, + state_accessor: Dep, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), + services: Services { + admin: args.depend::("admin"), + appservice: args.depend::("appservice"), + globals: args.depend::("globals"), + sending: args.depend::("sending"), + state_accessor: args.depend::("rooms::state_accessor"), + }, })) } @@ -33,7 +49,7 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { - if alias == services().globals.admin_alias && user_id != services().globals.server_user { + if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user { Err(Error::BadRequest( ErrorKind::forbidden(), "Only the server user can set this alias", @@ -72,10 +88,10 @@ impl Service { if !server_is_ours(room_alias.server_name()) && (!servers .as_ref() - .is_some_and(|servers| servers.contains(&services().globals.server_name().to_owned())) + .is_some_and(|servers| servers.contains(&self.services.globals.server_name().to_owned())) || servers.as_ref().is_none()) { - return remote::resolve(room_alias, servers).await; + return self.remote_resolve(room_alias, servers).await; } let room_id: Option = match self.resolve_local_alias(room_alias)? { @@ -111,7 +127,7 @@ impl Service { return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found.")); }; - let server_user = &services().globals.server_user; + let server_user = &self.services.globals.server_user; // The creator of an alias can remove it if self @@ -119,7 +135,7 @@ impl Service { .who_created_alias(alias)? .is_some_and(|user| user == user_id) // Server admins can remove any local alias - || services().admin.user_is_admin(user_id).await? + || self.services.admin.user_is_admin(user_id).await? // Always allow the server service account to remove the alias, since there may not be an admin room || server_user == user_id { @@ -127,8 +143,7 @@ impl Service { // Checking whether the user is able to change canonical aliases of the // room } else if let Some(event) = - services() - .rooms + self.services .state_accessor .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? { @@ -140,8 +155,7 @@ impl Service { // If there is no power levels event, only the room creator can change // canonical aliases } else if let Some(event) = - services() - .rooms + self.services .state_accessor .room_state_get(&room_id, &StateEventType::RoomCreate, "")? { @@ -152,14 +166,16 @@ impl Service { } async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result> { - for appservice in services().appservice.read().await.values() { + use ruma::api::appservice::query::query_room_alias; + + for appservice in self.services.appservice.read().await.values() { if appservice.aliases.is_match(room_alias.as_str()) && matches!( - services() + self.services .sending .send_appservice_request( appservice.registration.clone(), - appservice::query::query_room_alias::v1::Request { + query_room_alias::v1::Request { room_alias: room_alias.to_owned(), }, ) @@ -167,10 +183,7 @@ impl Service { Ok(Some(_opt_result)) ) { return Ok(Some( - services() - .rooms - .alias - .resolve_local_alias(room_alias)? + self.resolve_local_alias(room_alias)? .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, )); } @@ -178,20 +191,27 @@ impl Service { Ok(None) } -} -pub async fn appservice_checks(room_alias: &RoomAliasId, appservice_info: &Option) -> Result<()> { - if !server_is_ours(room_alias.server_name()) { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); - } - - if let Some(ref info) = appservice_info { - if !info.aliases.is_match(room_alias.as_str()) { - return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace.")); + pub async fn appservice_checks( + &self, room_alias: &RoomAliasId, appservice_info: &Option, + ) -> Result<()> { + if !server_is_ours(room_alias.server_name()) { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); } - } else if services().appservice.is_exclusive_alias(room_alias).await { - return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice.")); - } - Ok(()) + if let Some(ref info) = appservice_info { + if !info.aliases.is_match(room_alias.as_str()) { + return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace.")); + } + } else if self + .services + .appservice + .is_exclusive_alias(room_alias) + .await + { + return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice.")); + } + + Ok(()) + } } diff --git a/src/service/rooms/alias/remote.rs b/src/service/rooms/alias/remote.rs index 7fcd27f5..5d835240 100644 --- a/src/service/rooms/alias/remote.rs +++ b/src/service/rooms/alias/remote.rs @@ -1,71 +1,75 @@ -use conduit::{debug, debug_info, debug_warn, Error, Result}; +use conduit::{debug, debug_warn, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation}, OwnedRoomId, OwnedServerName, RoomAliasId, }; -use crate::services; +impl super::Service { + pub(super) async fn remote_resolve( + &self, room_alias: &RoomAliasId, servers: Option<&Vec>, + ) -> Result<(OwnedRoomId, Option>)> { + debug!(?room_alias, ?servers, "resolve"); -pub(super) async fn resolve( - room_alias: &RoomAliasId, servers: Option<&Vec>, -) -> Result<(OwnedRoomId, Option>)> { - debug!(?room_alias, ?servers, "resolve"); + let mut response = self + .services + .sending + .send_federation_request( + room_alias.server_name(), + federation::query::get_room_information::v1::Request { + room_alias: room_alias.to_owned(), + }, + ) + .await; - let mut response = services() - .sending - .send_federation_request( - room_alias.server_name(), - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; + debug!("room alias server_name get_alias_helper response: {response:?}"); - debug!("room alias server_name get_alias_helper response: {response:?}"); + if let Err(ref e) = response { + debug_warn!( + "Server {} of the original room alias failed to assist in resolving room alias: {e}", + room_alias.server_name(), + ); + } - if let Err(ref e) = response { - debug_info!( - "Server {} of the original room alias failed to assist in resolving room alias: {e}", - room_alias.server_name() - ); - } + if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() { + if let Some(servers) = servers { + for server in servers { + response = self + .services + .sending + .send_federation_request( + server, + federation::query::get_room_information::v1::Request { + room_alias: room_alias.to_owned(), + }, + ) + .await; + debug!("Got response from server {server} for room aliases: {response:?}"); - if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() { - if let Some(servers) = servers { - for server in servers { - response = services() - .sending - .send_federation_request( - server, - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; - debug!("Got response from server {server} for room aliases: {response:?}"); - - if let Ok(ref response) = response { - if !response.servers.is_empty() { - break; + if let Ok(ref response) = response { + if !response.servers.is_empty() { + break; + } + debug_warn!( + "Server {server} responded with room aliases, but was empty? Response: {response:?}" + ); } - debug_warn!("Server {server} responded with room aliases, but was empty? Response: {response:?}"); } } } + + if let Ok(response) = response { + let room_id = response.room_id; + + let mut pre_servers = response.servers; + // since the room alis server responded, insert it into the list + pre_servers.push(room_alias.server_name().into()); + + return Ok((room_id, Some(pre_servers))); + } + + Err(Error::BadRequest( + ErrorKind::NotFound, + "No servers could assist in resolving the room alias", + )) } - - if let Ok(response) = response { - let room_id = response.room_id; - - let mut pre_servers = response.servers; - // since the room alis server responded, insert it into the list - pre_servers.push(room_alias.server_name().into()); - - return Ok((room_id, Some(pre_servers))); - } - - Err(Error::BadRequest( - ErrorKind::NotFound, - "No servers could assist in resolving the room alias", - )) } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 4e468234..6e7c7835 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,8 +3,8 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{utils, utils::math::usize_from_f64, Result, Server}; -use database::{Database, Map}; +use conduit::{utils, utils::math::usize_from_f64, Result}; +use database::Map; use lru_cache::LruCache; pub(super) struct Data { @@ -13,8 +13,9 @@ pub(super) struct Data { } impl Data { - pub(super) fn new(server: &Arc, db: &Arc) -> Self { - let config = &server.config; + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; + let config = &args.server.config; let cache_size = f64::from(config.auth_chain_cache_capacity); let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size"); Self { diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 4e8c7bb2..9a1e7e67 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -6,19 +6,29 @@ use std::{ }; use conduit::{debug, error, trace, validated, warn, Err, Result}; -use data::Data; use ruma::{EventId, RoomId}; -use crate::services; +use self::data::Data; +use crate::{rooms, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + short: Dep, + timeline: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.server, args.db), + services: Services { + short: args.depend::("rooms::short"), + timeline: args.depend::("rooms::timeline"), + }, + db: Data::new(&args), })) } @@ -27,7 +37,7 @@ impl crate::Service for Service { impl Service { pub async fn event_ids_iter<'a>( - &self, room_id: &RoomId, starting_events_: Vec>, + &'a self, room_id: &RoomId, starting_events_: Vec>, ) -> Result> + 'a> { let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); for starting_event in &starting_events_ { @@ -38,7 +48,7 @@ impl Service { .get_auth_chain(room_id, &starting_events) .await? .into_iter() - .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) + .filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok())) } #[tracing::instrument(skip_all, name = "auth_chain")] @@ -48,8 +58,8 @@ impl Service { let started = std::time::Instant::now(); let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, &short) in services() - .rooms + for (i, &short) in self + .services .short .multi_get_or_create_shorteventid(starting_events)? .iter() @@ -140,7 +150,7 @@ impl Service { while let Some(event_id) = todo.pop() { trace!(?event_id, "processing auth event"); - match services().rooms.timeline.get_pdu(&event_id) { + match self.services.timeline.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { return Err!(Request(Forbidden( @@ -150,10 +160,7 @@ impl Service { ))); } for auth_event in &pdu.auth_events { - let sauthevent = services() - .rooms - .short - .get_or_create_shorteventid(auth_event)?; + let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?; if found.insert(sauthevent) { trace!(?event_id, ?auth_event, "adding auth event to processing queue"); diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 23ec6b6b..706e6c2e 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -2,10 +2,10 @@ mod data; use std::sync::Arc; -use data::Data; +use conduit::Result; use ruma::{OwnedRoomId, RoomId}; -use crate::Result; +use self::data::Data; pub struct Service { db: Data, diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 6cb23b9f..fd8e2185 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -10,12 +10,11 @@ use std::{ }; use conduit::{ - debug, debug_error, debug_info, err, error, info, trace, + debug, debug_error, debug_info, err, error, info, pdu, trace, utils::{math::continue_exponential_backoff_secs, MutexMap}, - warn, Error, Result, + warn, Error, PduEvent, Result, }; use futures_util::Future; -pub use parse_incoming_pdu::parse_incoming_pdu; use ruma::{ api::{ client::error::ErrorKind, @@ -36,13 +35,28 @@ use ruma::{ use tokio::sync::RwLock; use super::state_compressor::CompressedStateEvent; -use crate::{pdu, services, PduEvent}; +use crate::{globals, rooms, sending, Dep}; pub struct Service { + services: Services, pub federation_handletime: StdRwLock, pub mutex_federation: RoomMutexMap, } +struct Services { + globals: Dep, + sending: Dep, + auth_chain: Dep, + metadata: Dep, + outlier: Dep, + pdu_metadata: Dep, + short: Dep, + state: Dep, + state_accessor: Dep, + state_compressor: Dep, + timeline: Dep, +} + type RoomMutexMap = MutexMap; type HandleTimeMap = HashMap; @@ -55,8 +69,21 @@ type AsyncRecursiveCanonicalJsonResult<'a> = AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; impl crate::Service for Service { - fn build(_args: crate::Args<'_>) -> Result> { + fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + services: Services { + globals: args.depend::("globals"), + sending: args.depend::("sending"), + auth_chain: args.depend::("rooms::auth_chain"), + metadata: args.depend::("rooms::metadata"), + outlier: args.depend::("rooms::outlier"), + pdu_metadata: args.depend::("rooms::pdu_metadata"), + short: args.depend::("rooms::short"), + state: args.depend::("rooms::state"), + state_accessor: args.depend::("rooms::state_accessor"), + state_compressor: args.depend::("rooms::state_compressor"), + timeline: args.depend::("rooms::timeline"), + }, federation_handletime: HandleTimeMap::new().into(), mutex_federation: RoomMutexMap::new(), })) @@ -114,17 +141,17 @@ impl Service { pub_key_map: &'a RwLock>>, ) -> Result>> { // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? { + if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? { return Ok(Some(pdu_id.to_vec())); } // 1.1 Check the server is in the room - if !services().rooms.metadata.exists(room_id)? { + if !self.services.metadata.exists(room_id)? { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } // 1.2 Check if the room is disabled - if services().rooms.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "Federation of this room is currently disabled on this server.", @@ -147,8 +174,8 @@ impl Service { self.acl_check(sender.server_name(), room_id)?; // Fetch create event - let create_event = services() - .rooms + let create_event = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; @@ -156,8 +183,8 @@ impl Service { // Procure the room version let room_version_id = Self::get_room_version_id(&create_event)?; - let first_pdu_in_room = services() - .rooms + let first_pdu_in_room = self + .services .timeline .first_pdu_in_room(room_id)? .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; @@ -208,7 +235,8 @@ impl Service { Ok(()) => continue, Err(e) => { warn!("Prev event {} failed: {}", prev_id, e); - match services() + match self + .services .globals .bad_event_ratelimiter .write() @@ -258,7 +286,7 @@ impl Service { create_event: &Arc, first_pdu_in_room: &Arc, prev_id: &EventId, ) -> Result<()> { // Check for disabled again because it might have changed - if services().rooms.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id)? { debug!( "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ event ID {event_id}" @@ -269,7 +297,8 @@ impl Service { )); } - if let Some((time, tries)) = services() + if let Some((time, tries)) = self + .services .globals .bad_event_ratelimiter .read() @@ -349,7 +378,7 @@ impl Service { }; // Skip the PDU if it is redacted and we already have it as an outlier event - if services().rooms.timeline.get_pdu_json(event_id)?.is_some() { + if self.services.timeline.get_pdu_json(event_id)?.is_some() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Event was redacted and we already knew about it", @@ -401,7 +430,7 @@ impl Service { // Build map of auth events let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); for id in &incoming_pdu.auth_events { - let Some(auth_event) = services().rooms.timeline.get_pdu(id)? else { + let Some(auth_event) = self.services.timeline.get_pdu(id)? else { warn!("Could not find auth event {}", id); continue; }; @@ -454,8 +483,7 @@ impl Service { trace!("Validation successful."); // 7. Persist the event as an outlier. - services() - .rooms + self.services .outlier .add_pdu_outlier(&incoming_pdu.event_id, &val)?; @@ -470,12 +498,12 @@ impl Service { origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>>, ) -> Result>> { // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) { + if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) { return Ok(Some(pduid.to_vec())); } - if services() - .rooms + if self + .services .pdu_metadata .is_event_soft_failed(&incoming_pdu.event_id)? { @@ -521,14 +549,13 @@ impl Service { &incoming_pdu, None::, // TODO: third party invite |k, s| { - services() - .rooms + self.services .short .get_shortstatekey(&k.to_string().into(), s) .ok() .flatten() .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) + .and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten()) }, ) .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?; @@ -541,7 +568,7 @@ impl Service { } debug!("Gathering auth events"); - let auth_events = services().rooms.state.get_auth_events( + let auth_events = self.services.state.get_auth_events( room_id, &incoming_pdu.kind, &incoming_pdu.sender, @@ -562,7 +589,7 @@ impl Service { && match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &incoming_pdu.redacts { - !services().rooms.state_accessor.user_can_redact( + !self.services.state_accessor.user_can_redact( redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, @@ -577,7 +604,7 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; if let Some(redact_id) = &content.redacts { - !services().rooms.state_accessor.user_can_redact( + !self.services.state_accessor.user_can_redact( redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, @@ -594,12 +621,12 @@ impl Service { // We start looking at current room state now, so lets lock the room trace!("Locking the room"); - let state_lock = services().rooms.state.mutex.lock(room_id).await; + let state_lock = self.services.state.mutex.lock(room_id).await; // Now we calculate the set of extremities this room has after the incoming // event has been applied. We start with the previous extremities (aka leaves) trace!("Calculating extremities"); - let mut extremities = services().rooms.state.get_forward_extremities(room_id)?; + let mut extremities = self.services.state.get_forward_extremities(room_id)?; trace!("Calculated {} extremities", extremities.len()); // Remove any forward extremities that are referenced by this incoming event's @@ -609,22 +636,13 @@ impl Service { } // Only keep those extremities were not referenced yet - extremities.retain(|id| { - !matches!( - services() - .rooms - .pdu_metadata - .is_event_referenced(room_id, id), - Ok(true) - ) - }); + extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); debug!("Retained {} extremities. Compressing state", extremities.len()); let state_ids_compressed = Arc::new( state_at_incoming_event .iter() .map(|(shortstatekey, id)| { - services() - .rooms + self.services .state_compressor .compress_state_event(*shortstatekey, id) }) @@ -637,8 +655,8 @@ impl Service { // We also add state after incoming event to the fork states let mut state_after = state_at_incoming_event.clone(); if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; @@ -651,13 +669,12 @@ impl Service { // Set the new room state to the resolved state debug!("Forcing new room state"); - let (sstatehash, new, removed) = services() - .rooms + let (sstatehash, new, removed) = self + .services .state_compressor .save_state(room_id, new_room_state)?; - services() - .rooms + self.services .state .force_state(room_id, sstatehash, new, removed, &state_lock) .await?; @@ -667,8 +684,7 @@ impl Service { // if not soft fail it if soft_fail { debug!("Soft failing event"); - services() - .rooms + self.services .timeline .append_incoming_pdu( &incoming_pdu, @@ -682,8 +698,7 @@ impl Service { // Soft fail, we keep the event as an outlier but don't add it to the timeline warn!("Event was soft failed: {:?}", incoming_pdu); - services() - .rooms + self.services .pdu_metadata .mark_event_soft_failed(&incoming_pdu.event_id)?; @@ -696,8 +711,8 @@ impl Service { // Now that the event has passed all auth it is added into the timeline. // We use the `state_at_event` instead of `state_after` so we accurately // represent the state for this event. - let pdu_id = services() - .rooms + let pdu_id = self + .services .timeline .append_incoming_pdu( &incoming_pdu, @@ -723,14 +738,14 @@ impl Service { &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, ) -> Result>> { debug!("Loading current room state ids"); - let current_sstatehash = services() - .rooms + let current_sstatehash = self + .services .state .get_room_shortstatehash(room_id)? .expect("every room has state"); - let current_state_ids = services() - .rooms + let current_state_ids = self + .services .state_accessor .state_full_ids(current_sstatehash) .await?; @@ -740,8 +755,7 @@ impl Service { let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { auth_chain_sets.push( - services() - .rooms + self.services .auth_chain .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) .await? @@ -755,8 +769,7 @@ impl Service { .map(|map| { map.into_iter() .filter_map(|(k, id)| { - services() - .rooms + self.services .short .get_statekey_from_short(k) .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) @@ -766,11 +779,11 @@ impl Service { }) .collect(); - let lock = services().globals.stateres_mutex.lock(); + let lock = self.services.globals.stateres_mutex.lock(); debug!("Resolving state"); let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = services().rooms.timeline.get_pdu(id); + let res = self.services.timeline.get_pdu(id); if let Err(e) = &res { error!("Failed to fetch event: {}", e); } @@ -793,12 +806,11 @@ impl Service { let new_room_state = state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - services() - .rooms + self.services .state_compressor .compress_state_event(shortstatekey, &event_id) }) @@ -814,15 +826,14 @@ impl Service { &self, incoming_pdu: &Arc, ) -> Result>>> { let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = services() - .rooms + let prev_event_sstatehash = self + .services .state_accessor .pdu_shortstatehash(prev_event)?; let state = if let Some(shortstatehash) = prev_event_sstatehash { Some( - services() - .rooms + self.services .state_accessor .state_full_ids(shortstatehash) .await, @@ -833,8 +844,8 @@ impl Service { if let Some(Ok(mut state)) = state { debug!("Using cached state"); - let prev_pdu = services() - .rooms + let prev_pdu = self + .services .timeline .get_pdu(prev_event) .ok() @@ -842,8 +853,8 @@ impl Service { .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; @@ -866,13 +877,13 @@ impl Service { let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let Ok(Some(prev_event)) = services().rooms.timeline.get_pdu(prev_eventid) else { + let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else { okay = false; break; }; - let Ok(Some(sstatehash)) = services() - .rooms + let Ok(Some(sstatehash)) = self + .services .state_accessor .pdu_shortstatehash(prev_eventid) else { @@ -891,15 +902,15 @@ impl Service { let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = services() - .rooms + let mut leaf_state: HashMap<_, _> = self + .services .state_accessor .state_full_ids(sstatehash) .await?; if let Some(state_key) = &prev_event.state_key { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); @@ -910,7 +921,7 @@ impl Service { let mut starting_events = Vec::with_capacity(leaf_state.len()); for (k, id) in leaf_state { - if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) { + if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType state.insert((ty.to_string().into(), st_key), id.clone()); @@ -921,8 +932,7 @@ impl Service { } auth_chain_sets.push( - services() - .rooms + self.services .auth_chain .event_ids_iter(room_id, starting_events) .await? @@ -932,9 +942,9 @@ impl Service { fork_states.push(state); } - let lock = services().globals.stateres_mutex.lock(); + let lock = self.services.globals.stateres_mutex.lock(); let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = services().rooms.timeline.get_pdu(id); + let res = self.services.timeline.get_pdu(id); if let Err(e) = &res { error!("Failed to fetch event: {}", e); } @@ -947,8 +957,8 @@ impl Service { new_state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; Ok((shortstatekey, event_id)) @@ -974,7 +984,8 @@ impl Service { pub_key_map: &RwLock>>, event_id: &EventId, ) -> Result>>> { debug!("Fetching state ids"); - match services() + match self + .services .sending .send_federation_request( origin, @@ -1004,8 +1015,8 @@ impl Service { .clone() .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; @@ -1022,8 +1033,8 @@ impl Service { } // The original create event must still be in the state - let create_shortstatekey = services() - .rooms + let create_shortstatekey = self + .services .short .get_shortstatekey(&StateEventType::RoomCreate, "")? .expect("Room exists"); @@ -1056,7 +1067,8 @@ impl Service { ) -> AsyncRecursiveCanonicalJsonVec<'a> { Box::pin(async move { let back_off = |id| async { - match services() + match self + .services .globals .bad_event_ratelimiter .write() @@ -1075,7 +1087,7 @@ impl Service { // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { + if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) { trace!("Found {} in db", id); events_with_auth_events.push((id, Some(local_pdu), vec![])); continue; @@ -1089,7 +1101,8 @@ impl Service { let mut events_all = HashSet::with_capacity(todo_auth_events.len()); let mut i: u64 = 0; while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = services() + if let Some((time, tries)) = self + .services .globals .bad_event_ratelimiter .read() @@ -1114,13 +1127,14 @@ impl Service { tokio::task::yield_now().await; } - if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { + if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) { trace!("Found {} in db", next_id); continue; } debug!("Fetching {} over federation.", next_id); - match services() + match self + .services .sending .send_federation_request( origin, @@ -1195,7 +1209,8 @@ impl Service { pdus.push((local_pdu, None)); } for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = services() + if let Some((time, tries)) = self + .services .globals .bad_event_ratelimiter .read() @@ -1244,8 +1259,8 @@ impl Service { let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec> = initial_set; - let first_pdu_in_room = services() - .rooms + let first_pdu_in_room = self + .services .timeline .first_pdu_in_room(room_id)? .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; @@ -1267,19 +1282,18 @@ impl Service { { Self::check_room_id(room_id, &pdu)?; - if amount > services().globals.max_fetch_prev_events() { + if amount > self.services.globals.max_fetch_prev_events() { // Max limit reached debug!( "Max prev event limit reached! Limit: {}", - services().globals.max_fetch_prev_events() + self.services.globals.max_fetch_prev_events() ); graph.insert(prev_event_id.clone(), HashSet::new()); continue; } if let Some(json) = json_opt.or_else(|| { - services() - .rooms + self.services .outlier .get_outlier_pdu_json(&prev_event_id) .ok() @@ -1335,8 +1349,7 @@ impl Service { #[tracing::instrument(skip_all)] pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { let acl_event = if let Some(acl) = - services() - .rooms + self.services .state_accessor .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? { diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 8fcd8549..a19862a5 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,29 +1,28 @@ -use conduit::{Err, Error, Result}; +use conduit::{pdu::gen_event_id_canonical_json, warn, Err, Error, Result}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; -use tracing::warn; -use crate::{pdu::gen_event_id_canonical_json, services}; +impl super::Service { + pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {pdu:?}: {e:?}"); + Error::BadServerResponse("Invalid PDU in server response") + })?; -pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {pdu:?}: {e:?}"); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(|id| RoomId::parse(id.as_str()?).ok()) + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; - let room_id: OwnedRoomId = value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; + let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else { + return Err!("Server is not in room {room_id}"); + }; - let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else { - return Err!("Server is not in room {room_id}"); - }; + let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { + // Event could not be converted to canonical json + return Err!(Request(InvalidParam("Could not convert event to canonical json."))); + }; - let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { - // Event could not be converted to canonical json - return Err!(Request(InvalidParam("Could not convert event to canonical json."))); - }; - - Ok((event_id, value, room_id)) + Ok((event_id, value, room_id)) + } } diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs index 2fa5b0df..1ebcbefb 100644 --- a/src/service/rooms/event_handler/signing_keys.rs +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -3,7 +3,7 @@ use std::{ time::{Duration, SystemTime}, }; -use conduit::{debug, error, info, trace, warn}; +use conduit::{debug, error, info, trace, warn, Error, Result}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::{ @@ -21,8 +21,6 @@ use ruma::{ use serde_json::value::RawValue as RawJsonValue; use tokio::sync::{RwLock, RwLockWriteGuard}; -use crate::{services, Error, Result}; - impl super::Service { pub async fn fetch_required_signing_keys<'a, E>( &'a self, events: E, pub_key_map: &RwLock>>, @@ -147,7 +145,8 @@ impl super::Service { debug!("Loading signing keys for {}", origin); - let result: BTreeMap<_, _> = services() + let result: BTreeMap<_, _> = self + .services .globals .signing_keys_for(origin)? .into_iter() @@ -171,9 +170,10 @@ impl super::Service { &self, mut servers: BTreeMap>, pub_key_map: &RwLock>>, ) -> Result<()> { - for server in services().globals.trusted_servers() { + for server in self.services.globals.trusted_servers() { debug!("Asking batch signing keys from trusted server {}", server); - match services() + match self + .services .sending .send_federation_request( server, @@ -199,7 +199,8 @@ impl super::Service { // TODO: Check signature from trusted server? servers.remove(&k.server_name); - let result = services() + let result = self + .services .globals .db .add_signing_key(&k.server_name, k.clone())? @@ -234,7 +235,7 @@ impl super::Service { .into_keys() .map(|server| async move { ( - services() + self.services .sending .send_federation_request(&server, get_server_keys::v2::Request::new()) .await, @@ -248,7 +249,8 @@ impl super::Service { if let (Ok(get_keys_response), origin) = result { debug!("Result is from {origin}"); if let Ok(key) = get_keys_response.server_key.deserialize() { - let result: BTreeMap<_, _> = services() + let result: BTreeMap<_, _> = self + .services .globals .db .add_signing_key(&origin, key)? @@ -297,7 +299,7 @@ impl super::Service { return Ok(()); } - if services().globals.query_trusted_key_servers_first() { + if self.services.globals.query_trusted_key_servers_first() { info!( "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ homeserver signing keys." @@ -349,7 +351,8 @@ impl super::Service { ) -> Result> { let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - let mut result: BTreeMap<_, _> = services() + let mut result: BTreeMap<_, _> = self + .services .globals .signing_keys_for(origin)? .into_iter() @@ -362,15 +365,16 @@ impl super::Service { } // i didnt split this out into their own functions because it's relatively small - if services().globals.query_trusted_key_servers_first() { + if self.services.globals.query_trusted_key_servers_first() { info!( "query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \ keys" ); - for server in services().globals.trusted_servers() { + for server in self.services.globals.trusted_servers() { debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = services() + if let Some(server_keys) = self + .services .sending .send_federation_request( server, @@ -394,7 +398,10 @@ impl super::Service { }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { - services().globals.db.add_signing_key(origin, k.clone())?; + self.services + .globals + .db + .add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() @@ -414,14 +421,15 @@ impl super::Service { } debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = services() + if let Some(server_key) = self + .services .sending .send_federation_request(origin, get_server_keys::v2::Request::new()) .await .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - services() + self.services .globals .db .add_signing_key(origin, server_key.clone())?; @@ -447,14 +455,15 @@ impl super::Service { info!("query_trusted_key_servers_first is set to false, querying {origin} first"); debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = services() + if let Some(server_key) = self + .services .sending .send_federation_request(origin, get_server_keys::v2::Request::new()) .await .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - services() + self.services .globals .db .add_signing_key(origin, server_key.clone())?; @@ -477,9 +486,10 @@ impl super::Service { } } - for server in services().globals.trusted_servers() { + for server in self.services.globals.trusted_servers() { debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = services() + if let Some(server_keys) = self + .services .sending .send_federation_request( server, @@ -503,7 +513,10 @@ impl super::Service { }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { - services().globals.db.add_signing_key(origin, k.clone())?; + self.services + .globals + .db + .add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 96f623f2..64764198 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -6,10 +6,10 @@ use std::{ sync::{Arc, Mutex}, }; -use data::Data; +use conduit::{PduCount, Result}; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::{PduCount, Result}; +use self::data::Data; pub struct Service { db: Data, diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 763dd0e8..efe681b1 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,30 +1,39 @@ use std::sync::Arc; use conduit::{error, utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{OwnedRoomId, RoomId}; -use crate::services; +use crate::{rooms, Dep}; pub(super) struct Data { disabledroomids: Arc, bannedroomids: Arc, roomid_shortroomid: Arc, pduid_pdu: Arc, + services: Services, +} + +struct Services { + short: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { disabledroomids: db["disabledroomids"].clone(), bannedroomids: db["bannedroomids"].clone(), roomid_shortroomid: db["roomid_shortroomid"].clone(), pduid_pdu: db["pduid_pdu"].clone(), + services: Services { + short: args.depend::("rooms::short"), + }, } } pub(super) fn exists(&self, room_id: &RoomId) -> Result { - let prefix = match services().rooms.short.get_shortroomid(room_id)? { + let prefix = match self.services.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), None => return Ok(false), }; diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index ec34a82c..7415c53b 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -3,9 +3,10 @@ mod data; use std::sync::Arc; use conduit::Result; -use data::Data; use ruma::{OwnedRoomId, RoomId}; +use self::data::Data; + pub struct Service { db: Data, } @@ -13,7 +14,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index ef50b094..44a83582 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -33,13 +33,13 @@ pub struct Service { pub read_receipt: Arc, pub search: Arc, pub short: Arc, + pub spaces: Arc, pub state: Arc, pub state_accessor: Arc, pub state_cache: Arc, pub state_compressor: Arc, - pub timeline: Arc, pub threads: Arc, + pub timeline: Arc, pub typing: Arc, - pub spaces: Arc, pub user: Arc, } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 24c756fd..d1649da8 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,26 +1,35 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{utils, Error, PduCount, PduEvent, Result}; +use database::Map; use ruma::{EventId, RoomId, UserId}; -use crate::{services, PduCount, PduEvent}; +use crate::{rooms, Dep}; pub(super) struct Data { tofrom_relation: Arc, referencedevents: Arc, softfailedeventids: Arc, + services: Services, +} + +struct Services { + timeline: Dep, } type PdusIterItem = Result<(PduCount, PduEvent)>; type PdusIterator<'a> = Box + 'a>; impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { tofrom_relation: db["tofrom_relation"].clone(), referencedevents: db["referencedevents"].clone(), softfailedeventids: db["softfailedeventids"].clone(), + services: Services { + timeline: args.depend::("rooms::timeline"), + }, } } @@ -57,8 +66,8 @@ impl Data { let mut pduid = shortroomid.to_be_bytes().to_vec(); pduid.extend_from_slice(&from.to_be_bytes()); - let mut pdu = services() - .rooms + let mut pdu = self + .services .timeline .get_pdu_from_id(&pduid)? .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 05067aa8..7546dcb2 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -2,8 +2,7 @@ mod data; use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{PduCount, PduEvent, Result}; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, @@ -11,12 +10,20 @@ use ruma::{ }; use serde::Deserialize; -use crate::{services, PduCount, PduEvent}; +use self::data::Data; +use crate::{rooms, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + short: Dep, + state_accessor: Dep, + timeline: Dep, +} + #[derive(Clone, Debug, Deserialize)] struct ExtractRelType { rel_type: RelationType, @@ -30,7 +37,12 @@ struct ExtractRelatesToEventId { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + short: args.depend::("rooms::short"), + state_accessor: args.depend::("rooms::state_accessor"), + timeline: args.depend::("rooms::timeline"), + }, + db: Data::new(&args), })) } @@ -101,8 +113,7 @@ impl Service { }) .take(limit) .filter(|(_, pdu)| { - services() - .rooms + self.services .state_accessor .user_can_see_event(sender_user, room_id, &pdu.event_id) .unwrap_or(false) @@ -147,8 +158,7 @@ impl Service { }) .take(limit) .filter(|(_, pdu)| { - services() - .rooms + self.services .state_accessor .user_can_see_event(sender_user, room_id, &pdu.event_id) .unwrap_or(false) @@ -180,10 +190,10 @@ impl Service { pub fn relations_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, ) -> Result> { - let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; + let room_id = self.services.short.get_or_create_shortroomid(room_id)?; #[allow(unknown_lints)] #[allow(clippy::manual_unwrap_or_default)] - let target = match services().rooms.timeline.get_pdu_count(target)? { + let target = match self.services.timeline.get_pdu_count(target)? { Some(PduCount::Normal(c)) => c, // TODO: Support backfilled relations _ => 0, // This will result in an empty iterator diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 06eaf655..0f400ff3 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,14 +1,14 @@ use std::{mem::size_of, sync::Arc}; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{ events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, }; -use crate::services; +use crate::{globals, Dep}; type AnySyncEphemeralRoomEventIter<'a> = Box)>> + 'a>; @@ -16,15 +16,24 @@ type AnySyncEphemeralRoomEventIter<'a> = pub(super) struct Data { roomuserid_privateread: Arc, roomuserid_lastprivatereadupdate: Arc, + services: Services, readreceiptid_readreceipt: Arc, } +struct Services { + globals: Dep, +} + impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { roomuserid_privateread: db["roomuserid_privateread"].clone(), roomuserid_lastprivatereadupdate: db["roomuserid_lastprivatereadupdate"].clone(), readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), + services: Services { + globals: args.depend::("globals"), + }, } } @@ -51,7 +60,7 @@ impl Data { } let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); room_latest_id.push(0xFF); room_latest_id.extend_from_slice(user_id.as_bytes()); @@ -108,7 +117,7 @@ impl Data { .insert(&key, &count.to_be_bytes())?; self.roomuserid_lastprivatereadupdate - .insert(&key, &services().globals.next_count()?.to_be_bytes()) + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()) } pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index 9375276e..d202d893 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -6,16 +6,24 @@ use conduit::Result; use data::Data; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; -use crate::services; +use crate::{sending, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + sending: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + sending: args.depend::("sending"), + }, + db: Data::new(&args), })) } @@ -26,7 +34,7 @@ impl Service { /// Replaces the previous read receipt. pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { self.db.readreceipt_update(user_id, room_id, event)?; - services().sending.flush_room(room_id)?; + self.services.sending.flush_room(room_id)?; Ok(()) } diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 79b23cba..a0086095 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,21 +1,30 @@ use std::sync::Arc; use conduit::{utils, Result}; -use database::{Database, Map}; +use database::Map; use ruma::RoomId; -use crate::services; +use crate::{rooms, Dep}; type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; pub(super) struct Data { tokenids: Arc, + services: Services, +} + +struct Services { + short: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { tokenids: db["tokenids"].clone(), + services: Services { + short: args.depend::("rooms::short"), + }, } } @@ -51,8 +60,8 @@ impl Data { } pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { - let prefix = services() - .rooms + let prefix = self + .services .short .get_shortroomid(room_id)? .expect("room exists") diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 082dd432..8caa0ce3 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -13,7 +13,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 883c3c1d..963bd927 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use conduit::{utils, warn, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { eventid_shorteventid: Arc, @@ -13,10 +13,16 @@ pub(super) struct Data { shortstatekey_statekey: Arc, roomid_shortroomid: Arc, statehash_shortstatehash: Arc, + services: Services, +} + +struct Services { + globals: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { eventid_shorteventid: db["eventid_shorteventid"].clone(), shorteventid_eventid: db["shorteventid_eventid"].clone(), @@ -24,6 +30,9 @@ impl Data { shortstatekey_statekey: db["shortstatekey_statekey"].clone(), roomid_shortroomid: db["roomid_shortroomid"].clone(), statehash_shortstatehash: db["statehash_shortstatehash"].clone(), + services: Services { + globals: args.depend::("globals"), + }, } } @@ -31,7 +40,7 @@ impl Data { let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? } else { - let shorteventid = services().globals.next_count()?; + let shorteventid = self.services.globals.next_count()?; self.eventid_shorteventid .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; self.shorteventid_eventid @@ -59,7 +68,7 @@ impl Data { utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, ), None => { - let short = services().globals.next_count()?; + let short = self.services.globals.next_count()?; self.eventid_shorteventid .insert(keys[i], &short.to_be_bytes())?; self.shorteventid_eventid @@ -98,7 +107,7 @@ impl Data { let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? } else { - let shortstatekey = services().globals.next_count()?; + let shortstatekey = self.services.globals.next_count()?; self.statekey_shortstatekey .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; self.shortstatekey_statekey @@ -158,7 +167,7 @@ impl Data { true, ) } else { - let shortstatehash = services().globals.next_count()?; + let shortstatehash = self.services.globals.next_count()?; self.statehash_shortstatehash .insert(state_hash, &shortstatehash.to_be_bytes())?; (shortstatehash, false) @@ -176,7 +185,7 @@ impl Data { Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? } else { - let short = services().globals.next_count()?; + let short = self.services.globals.next_count()?; self.roomid_shortroomid .insert(room_id.as_bytes(), &short.to_be_bytes())?; short diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 0979fb4f..bfe0e9a0 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -3,9 +3,10 @@ mod data; use std::sync::Arc; use conduit::Result; -use data::Data; use ruma::{events::StateEventType, EventId, RoomId}; +use self::data::Data; + pub struct Service { db: Data, } @@ -13,7 +14,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 19a3ebbb..24d612d8 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -28,7 +28,7 @@ use ruma::{ }; use tokio::sync::Mutex; -use crate::services; +use crate::{rooms, sending, Dep}; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -119,42 +119,18 @@ enum Identifier<'a> { } pub struct Service { + services: Services, pub roomid_spacehierarchy_cache: Mutex>>, } -// Here because cannot implement `From` across ruma-federation-api and -// ruma-client-api types -impl From for SpaceHierarchyRoomsChunk { - fn from(value: CachedSpaceHierarchySummary) -> Self { - let SpaceHierarchyParentSummary { - canonical_alias, - name, - num_joined_members, - room_id, - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - children_state, - .. - } = value.summary; - - Self { - canonical_alias, - name, - num_joined_members, - room_id, - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - children_state, - } - } +struct Services { + state_accessor: Dep, + state_cache: Dep, + state: Dep, + short: Dep, + event_handler: Dep, + timeline: Dep, + sending: Dep, } impl crate::Service for Service { @@ -163,6 +139,15 @@ impl crate::Service for Service { let cache_size = f64::from(config.roomid_spacehierarchy_cache_capacity); let cache_size = cache_size * config.cache_capacity_modifier; Ok(Arc::new(Self { + services: Services { + state_accessor: args.depend::("rooms::state_accessor"), + state_cache: args.depend::("rooms::state_cache"), + state: args.depend::("rooms::state"), + short: args.depend::("rooms::short"), + event_handler: args.depend::("rooms::event_handler"), + timeline: args.depend::("rooms::timeline"), + sending: args.depend::("sending"), + }, roomid_spacehierarchy_cache: Mutex::new(LruCache::new(usize_from_f64(cache_size)?)), })) } @@ -226,7 +211,7 @@ impl Service { .as_ref() { return Ok(if let Some(cached) = cached { - if is_accessable_child( + if self.is_accessible_child( current_room, &cached.summary.join_rule, &identifier, @@ -242,8 +227,8 @@ impl Service { } Ok( - if let Some(children_pdus) = get_stripped_space_child_events(current_room).await? { - let summary = Self::get_room_summary(current_room, children_pdus, &identifier); + if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { + let summary = self.get_room_summary(current_room, children_pdus, &identifier); if let Ok(summary) = summary { self.roomid_spacehierarchy_cache.lock().await.insert( current_room.clone(), @@ -269,7 +254,8 @@ impl Service { ) -> Result> { for server in via { debug_info!("Asking {server} for /hierarchy"); - let Ok(response) = services() + let Ok(response) = self + .services .sending .send_federation_request( server, @@ -325,7 +311,10 @@ impl Service { avatar_url, join_rule, room_type, - children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), + children_state: self + .get_stripped_space_child_events(&room_id) + .await? + .unwrap(), allowed_room_ids, } }, @@ -333,7 +322,7 @@ impl Service { ); } } - if is_accessable_child( + if self.is_accessible_child( current_room, &response.room.join_rule, &Identifier::UserId(user_id), @@ -370,12 +359,13 @@ impl Service { } fn get_room_summary( - current_room: &OwnedRoomId, children_state: Vec>, identifier: &Identifier<'_>, + &self, current_room: &OwnedRoomId, children_state: Vec>, + identifier: &Identifier<'_>, ) -> Result { let room_id: &RoomId = current_room; - let join_rule = services() - .rooms + let join_rule = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? .map(|s| { @@ -386,12 +376,12 @@ impl Service { .transpose()? .unwrap_or(JoinRule::Invite); - let allowed_room_ids = services() - .rooms + let allowed_room_ids = self + .services .state_accessor .allowed_room_ids(join_rule.clone()); - if !is_accessable_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { + if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { debug!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); @@ -400,18 +390,18 @@ impl Service { let join_rule = join_rule.into(); Ok(SpaceHierarchyParentSummary { - canonical_alias: services() - .rooms + canonical_alias: self + .services .state_accessor .get_canonical_alias(room_id) .unwrap_or(None), - name: services() - .rooms + name: self + .services .state_accessor .get_name(room_id) .unwrap_or(None), - num_joined_members: services() - .rooms + num_joined_members: self + .services .state_cache .room_joined_count(room_id) .unwrap_or_default() @@ -422,22 +412,22 @@ impl Service { .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), - topic: services() - .rooms + topic: self + .services .state_accessor .get_room_topic(room_id) .unwrap_or(None), - world_readable: services().rooms.state_accessor.is_world_readable(room_id)?, - guest_can_join: services().rooms.state_accessor.guest_can_join(room_id)?, - avatar_url: services() - .rooms + world_readable: self.services.state_accessor.is_world_readable(room_id)?, + guest_can_join: self.services.state_accessor.guest_can_join(room_id)?, + avatar_url: self + .services .state_accessor .get_avatar(room_id)? .into_option() .unwrap_or_default() .url, join_rule, - room_type: services().rooms.state_accessor.get_room_type(room_id)?, + room_type: self.services.state_accessor.get_room_type(room_id)?, children_state, allowed_room_ids, }) @@ -487,7 +477,7 @@ impl Service { .into_iter() .rev() .skip_while(|(room, _)| { - if let Ok(short) = services().rooms.short.get_shortroomid(room) + if let Ok(short) = self.services.short.get_shortroomid(room) { short.as_ref() != short_room_ids.get(parents.len()) } else { @@ -541,7 +531,7 @@ impl Service { let mut short_room_ids = vec![]; for room in parents { - short_room_ids.push(services().rooms.short.get_or_create_shortroomid(&room)?); + short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?); } Some( @@ -559,128 +549,152 @@ impl Service { rooms: results, }) } -} -fn next_room_to_traverse( - stack: &mut Vec)>>, parents: &mut VecDeque, -) -> Option<(OwnedRoomId, Vec)> { - while stack.last().map_or(false, Vec::is_empty) { - stack.pop(); - parents.pop_back(); + /// Simply returns the stripped m.space.child events of a room + async fn get_stripped_space_child_events( + &self, room_id: &RoomId, + ) -> Result>>, Error> { + let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? else { + return Ok(None); + }; + + let state = self + .services + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + let mut children_pdus = Vec::new(); + for (key, id) in state { + let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?; + if event_type != StateEventType::SpaceChild { + continue; + } + + let pdu = self + .services + .timeline + .get_pdu(&id)? + .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + + if serde_json::from_str::(pdu.content.get()) + .ok() + .map(|c| c.via) + .map_or(true, |v| v.is_empty()) + { + continue; + } + + if OwnedRoomId::try_from(state_key).is_ok() { + children_pdus.push(pdu.to_stripped_spacechild_state_event()); + } + } + + Ok(Some(children_pdus)) } - stack.last_mut().and_then(Vec::pop) -} + /// With the given identifier, checks if a room is accessable + fn is_accessible_child( + &self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, + allowed_room_ids: &Vec, + ) -> bool { + // Note: unwrap_or_default for bool means false + match identifier { + Identifier::ServerName(server_name) => { + let room_id: &RoomId = current_room; -/// Simply returns the stripped m.space.child events of a room -async fn get_stripped_space_child_events( - room_id: &RoomId, -) -> Result>>, Error> { - let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { - return Ok(None); - }; - - let state = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let mut children_pdus = Vec::new(); - for (key, id) in state { - let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; - if event_type != StateEventType::SpaceChild { - continue; - } - - let pdu = services() - .rooms - .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; - - if serde_json::from_str::(pdu.content.get()) - .ok() - .map(|c| c.via) - .map_or(true, |v| v.is_empty()) - { - continue; - } - - if OwnedRoomId::try_from(state_key).is_ok() { - children_pdus.push(pdu.to_stripped_spacechild_state_event()); - } - } - - Ok(Some(children_pdus)) -} - -/// With the given identifier, checks if a room is accessable -fn is_accessable_child( - current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, - allowed_room_ids: &Vec, -) -> bool { - // Note: unwrap_or_default for bool means false - match identifier { - Identifier::ServerName(server_name) => { - let room_id: &RoomId = current_room; - - // Checks if ACLs allow for the server to participate - if services() - .rooms - .event_handler - .acl_check(server_name, room_id) - .is_err() - { - return false; - } - }, - Identifier::UserId(user_id) => { - if services() - .rooms - .state_cache - .is_joined(user_id, current_room) - .unwrap_or_default() - || services() - .rooms - .state_cache - .is_invited(user_id, current_room) - .unwrap_or_default() - { - return true; - } - }, - } // Takes care of join rules - match join_rule { - SpaceRoomJoinRule::Restricted => { - for room in allowed_room_ids { - match identifier { - Identifier::UserId(user) => { - if services() - .rooms - .state_cache - .is_joined(user, room) - .unwrap_or_default() - { - return true; - } - }, - Identifier::ServerName(server) => { - if services() - .rooms - .state_cache - .server_in_room(server, room) - .unwrap_or_default() - { - return true; - } - }, + // Checks if ACLs allow for the server to participate + if self + .services + .event_handler + .acl_check(server_name, room_id) + .is_err() + { + return false; } - } - false - }, - SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, - // Invite only, Private, or Custom join rule - _ => false, + }, + Identifier::UserId(user_id) => { + if self + .services + .state_cache + .is_joined(user_id, current_room) + .unwrap_or_default() + || self + .services + .state_cache + .is_invited(user_id, current_room) + .unwrap_or_default() + { + return true; + } + }, + } // Takes care of join rules + match join_rule { + SpaceRoomJoinRule::Restricted => { + for room in allowed_room_ids { + match identifier { + Identifier::UserId(user) => { + if self + .services + .state_cache + .is_joined(user, room) + .unwrap_or_default() + { + return true; + } + }, + Identifier::ServerName(server) => { + if self + .services + .state_cache + .server_in_room(server, room) + .unwrap_or_default() + { + return true; + } + }, + } + } + false + }, + SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, + // Invite only, Private, or Custom join rule + _ => false, + } + } +} + +// Here because cannot implement `From` across ruma-federation-api and +// ruma-client-api types +impl From for SpaceHierarchyRoomsChunk { + fn from(value: CachedSpaceHierarchySummary) -> Self { + let SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state, + .. + } = value.summary; + + Self { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state, + } } } @@ -736,3 +750,14 @@ fn get_parent_children_via( }) .collect() } + +fn next_room_to_traverse( + stack: &mut Vec)>>, parents: &mut VecDeque, +) -> Option<(OwnedRoomId, Vec)> { + while stack.last().map_or(false, Vec::is_empty) { + stack.pop(); + parents.pop_back(); + } + + stack.last_mut().and_then(Vec::pop) +} diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index a3a317a5..cb219bc0 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -8,7 +8,7 @@ use std::{ use conduit::{ utils::{calculate_hash, MutexMap, MutexMapGuard}, - warn, Error, Result, + warn, Error, PduEvent, Result, }; use data::Data; use ruma::{ @@ -23,19 +23,39 @@ use ruma::{ }; use super::state_compressor::CompressedStateEvent; -use crate::{services, PduEvent}; +use crate::{globals, rooms, Dep}; pub struct Service { + services: Services, db: Data, pub mutex: RoomMutexMap, } +struct Services { + globals: Dep, + short: Dep, + spaces: Dep, + state_cache: Dep, + state_accessor: Dep, + state_compressor: Dep, + timeline: Dep, +} + type RoomMutexMap = MutexMap; pub type RoomMutexGuard = MutexMapGuard; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + services: Services { + globals: args.depend::("globals"), + short: args.depend::("rooms::short"), + spaces: args.depend::("rooms::spaces"), + state_cache: args.depend::("rooms::state_cache"), + state_accessor: args.depend::("rooms::state_accessor"), + state_compressor: args.depend::("rooms::state_compressor"), + timeline: args.depend::("rooms::timeline"), + }, db: Data::new(args.db), mutex: RoomMutexMap::new(), })) @@ -62,14 +82,13 @@ impl Service { state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { - services() - .rooms + self.services .state_compressor .parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { - let Some(pdu) = services().rooms.timeline.get_pdu_json(&event_id)? else { + let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else { continue; }; @@ -94,7 +113,7 @@ impl Service { continue; }; - services().rooms.state_cache.update_membership( + self.services.state_cache.update_membership( room_id, &user_id, membership_event, @@ -105,8 +124,7 @@ impl Service { )?; }, TimelineEventType::SpaceChild => { - services() - .rooms + self.services .spaces .roomid_spacehierarchy_cache .lock() @@ -117,7 +135,7 @@ impl Service { } } - services().rooms.state_cache.update_joined_count(room_id)?; + self.services.state_cache.update_joined_count(room_id)?; self.db .set_room_state(room_id, shortstatehash, state_lock)?; @@ -133,10 +151,7 @@ impl Service { pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { - let shorteventid = services() - .rooms - .short - .get_or_create_shorteventid(event_id)?; + let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; @@ -147,20 +162,15 @@ impl Service { .collect::>(), ); - let (shortstatehash, already_existed) = services() - .rooms + let (shortstatehash, already_existed) = self + .services .short .get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), - |p| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(p) - }, + |p| self.services.state_compressor.load_shortstatehash_info(p), )?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -179,7 +189,7 @@ impl Service { } else { (state_ids_compressed, Arc::new(HashSet::new())) }; - services().rooms.state_compressor.save_state_from_diff( + self.services.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -199,8 +209,8 @@ impl Service { /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu), level = "debug")] pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { - let shorteventid = services() - .rooms + let shorteventid = self + .services .short .get_or_create_shorteventid(&new_pdu.event_id)?; @@ -214,21 +224,16 @@ impl Service { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), #[inline] - |p| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(p) - }, + |p| self.services.state_compressor.load_shortstatehash_info(p), )?; - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; - let new = services() - .rooms + let new = self + .services .state_compressor .compress_state_event(shortstatekey, &new_pdu.event_id)?; @@ -246,7 +251,7 @@ impl Service { } // TODO: statehash with deterministic inputs - let shortstatehash = services().globals.next_count()?; + let shortstatehash = self.services.globals.next_count()?; let mut statediffnew = HashSet::new(); statediffnew.insert(new); @@ -256,7 +261,7 @@ impl Service { statediffremoved.insert(*replaces); } - services().rooms.state_compressor.save_state_from_diff( + self.services.state_compressor.save_state_from_diff( shortstatehash, Arc::new(statediffnew), Arc::new(statediffremoved), @@ -275,22 +280,20 @@ impl Service { let mut state = Vec::new(); // Add recommended events if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? { state.push(e.to_stripped_state_event()); } - if let Some(e) = services().rooms.state_accessor.room_state_get( + if let Some(e) = self.services.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomCanonicalAlias, "", @@ -298,22 +301,20 @@ impl Service { state.push(e.to_stripped_state_event()); } if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? { state.push(e.to_stripped_state_event()); } - if let Some(e) = services().rooms.state_accessor.room_state_get( + if let Some(e) = self.services.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str(), @@ -339,8 +340,8 @@ impl Service { /// Returns the room's version. #[tracing::instrument(skip(self), level = "debug")] pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = services() - .rooms + let create_event = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")?; @@ -393,8 +394,7 @@ impl Service { let mut sauthevents = auth_events .into_iter() .filter_map(|(event_type, state_key)| { - services() - .rooms + self.services .short .get_shortstatekey(&event_type.to_string().into(), &state_key) .ok() @@ -403,8 +403,8 @@ impl Service { }) .collect::>(); - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -414,16 +414,14 @@ impl Service { Ok(full_state .iter() .filter_map(|compressed| { - services() - .rooms + self.services .state_compressor .parse_compressed_state_event(compressed) .ok() }) .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) .filter_map(|(k, event_id)| { - services() - .rooms + self.services .timeline .get_pdu(&event_id) .ok() diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 7e9daeda..4c85148d 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,28 +1,43 @@ use std::{collections::HashMap, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{utils, Error, PduEvent, Result}; +use database::Map; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::{services, PduEvent}; +use crate::{rooms, Dep}; pub(super) struct Data { eventid_shorteventid: Arc, shorteventid_shortstatehash: Arc, + services: Services, +} + +struct Services { + short: Dep, + state: Dep, + state_compressor: Dep, + timeline: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { eventid_shorteventid: db["eventid_shorteventid"].clone(), shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(), + services: Services { + short: args.depend::("rooms::short"), + state: args.depend::("rooms::state"), + state_compressor: args.depend::("rooms::state_compressor"), + timeline: args.depend::("rooms::timeline"), + }, } } #[allow(unused_qualifications)] // async traits pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -31,8 +46,8 @@ impl Data { let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { - let parsed = services() - .rooms + let parsed = self + .services .state_compressor .parse_compressed_state_event(compressed)?; result.insert(parsed.0, parsed.1); @@ -49,8 +64,8 @@ impl Data { pub(super) async fn state_full( &self, shortstatehash: u64, ) -> Result>> { - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -60,11 +75,11 @@ impl Data { let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { - let (_, eventid) = services() - .rooms + let (_, eventid) = self + .services .state_compressor .parse_compressed_state_event(compressed)?; - if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { + if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? { result.insert( ( pdu.kind.to_string().into(), @@ -92,15 +107,15 @@ impl Data { pub(super) fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { - let Some(shortstatekey) = services() - .rooms + let Some(shortstatekey) = self + .services .short .get_shortstatekey(event_type, state_key)? else { return Ok(None); }; - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -110,8 +125,7 @@ impl Data { .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) .and_then(|compressed| { - services() - .rooms + self.services .state_compressor .parse_compressed_state_event(compressed) .ok() @@ -125,7 +139,7 @@ impl Data { &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id)) + .map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id)) } /// Returns the state hash for this pdu. @@ -149,7 +163,7 @@ impl Data { pub(super) async fn room_state_full( &self, room_id: &RoomId, ) -> Result>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { self.state_full(current_shortstatehash).await } else { Ok(HashMap::new()) @@ -161,7 +175,7 @@ impl Data { pub(super) fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { self.state_get_id(current_shortstatehash, event_type, state_key) } else { Ok(None) @@ -173,7 +187,7 @@ impl Data { pub(super) fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { self.state_get(current_shortstatehash, event_type, state_key) } else { Ok(None) diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index bd3eb0a1..2526f1bd 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,7 +6,7 @@ use std::{ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{err, error, utils::math::usize_from_f64, warn, Error, Result}; +use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result}; use data::Data; use lru_cache::LruCache; use ruma::{ @@ -33,14 +33,20 @@ use ruma::{ }; use serde_json::value::to_raw_value; -use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, PduEvent}; +use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; pub struct Service { + services: Services, db: Data, pub server_visibility_cache: Mutex>, pub user_visibility_cache: Mutex>, } +struct Services { + state_cache: Dep, + timeline: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; @@ -50,7 +56,11 @@ impl crate::Service for Service { f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier; Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + state_cache: args.depend::("rooms::state_cache"), + timeline: args.depend::("rooms::timeline"), + }, + db: Data::new(&args), server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)), user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)), })) @@ -164,8 +174,8 @@ impl Service { }) .unwrap_or(HistoryVisibility::Shared); - let mut current_server_members = services() - .rooms + let mut current_server_members = self + .services .state_cache .room_members(room_id) .filter_map(Result::ok) @@ -212,7 +222,7 @@ impl Service { return Ok(*visibility); } - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; let history_visibility = self .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? @@ -258,7 +268,7 @@ impl Service { /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, user_id, room_id))] pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; let history_visibility = self .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? @@ -342,8 +352,8 @@ impl Service { redacts: None, }; - Ok(services() - .rooms + Ok(self + .services .timeline .create_hash_and_sign_event(new_event, sender, room_id, state_lock) .is_ok()) @@ -413,7 +423,7 @@ impl Service { // Falling back on m.room.create to judge power level if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? { Ok(pdu.sender == sender - || if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) { + || if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { pdu.sender == sender } else { false @@ -430,7 +440,7 @@ impl Service { .map(|event: RoomPowerLevels| { event.user_can_redact_event_of_other(sender) || event.user_can_redact_own_event(sender) - && if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) { + && if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { if federation { pdu.sender.server_name() == sender.server_name() } else { diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 2b9fbe94..cbda73cf 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -4,7 +4,7 @@ use std::{ }; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use itertools::Itertools; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, @@ -12,44 +12,55 @@ use ruma::{ OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{appservice::RegistrationInfo, services, user_is_local}; +use crate::{appservice::RegistrationInfo, globals, user_is_local, users, Dep}; type StrippedStateEventIter<'a> = Box>)>> + 'a>; type AnySyncStateEventIter<'a> = Box>)>> + 'a>; type AppServiceInRoomCache = RwLock>>; pub(super) struct Data { - userroomid_joined: Arc, - roomuserid_joined: Arc, - userroomid_invitestate: Arc, - roomuserid_invitecount: Arc, - userroomid_leftstate: Arc, - roomuserid_leftcount: Arc, - roomid_inviteviaservers: Arc, - roomuseroncejoinedids: Arc, - roomid_joinedcount: Arc, - roomid_invitedcount: Arc, - roomserverids: Arc, - serverroomids: Arc, pub(super) appservice_in_room_cache: AppServiceInRoomCache, + roomid_invitedcount: Arc, + roomid_inviteviaservers: Arc, + roomid_joinedcount: Arc, + roomserverids: Arc, + roomuserid_invitecount: Arc, + roomuserid_joined: Arc, + roomuserid_leftcount: Arc, + roomuseroncejoinedids: Arc, + serverroomids: Arc, + userroomid_invitestate: Arc, + userroomid_joined: Arc, + userroomid_leftstate: Arc, + services: Services, +} + +struct Services { + globals: Dep, + users: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { - userroomid_joined: db["userroomid_joined"].clone(), - roomuserid_joined: db["roomuserid_joined"].clone(), - userroomid_invitestate: db["userroomid_invitestate"].clone(), - roomuserid_invitecount: db["roomuserid_invitecount"].clone(), - userroomid_leftstate: db["userroomid_leftstate"].clone(), - roomuserid_leftcount: db["roomuserid_leftcount"].clone(), - roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(), - roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(), - roomid_joinedcount: db["roomid_joinedcount"].clone(), - roomid_invitedcount: db["roomid_invitedcount"].clone(), - roomserverids: db["roomserverids"].clone(), - serverroomids: db["serverroomids"].clone(), appservice_in_room_cache: RwLock::new(HashMap::new()), + roomid_invitedcount: db["roomid_invitedcount"].clone(), + roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(), + roomid_joinedcount: db["roomid_joinedcount"].clone(), + roomserverids: db["roomserverids"].clone(), + roomuserid_invitecount: db["roomuserid_invitecount"].clone(), + roomuserid_joined: db["roomuserid_joined"].clone(), + roomuserid_leftcount: db["roomuserid_leftcount"].clone(), + roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(), + serverroomids: db["serverroomids"].clone(), + userroomid_invitestate: db["userroomid_invitestate"].clone(), + userroomid_joined: db["userroomid_joined"].clone(), + userroomid_leftstate: db["userroomid_leftstate"].clone(), + services: Services { + globals: args.depend::("globals"), + users: args.depend::("users"), + }, } } @@ -100,7 +111,7 @@ impl Data { &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), )?; self.roomuserid_invitecount - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; @@ -144,7 +155,7 @@ impl Data { &serde_json::to_vec(&Vec::>::new()).unwrap(), )?; // TODO self.roomuserid_leftcount - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -228,7 +239,7 @@ impl Data { } else { let bridge_user_id = UserId::parse_with_server_name( appservice.registration.sender_localpart.as_str(), - services().globals.server_name(), + self.services.globals.server_name(), ) .ok(); @@ -356,7 +367,7 @@ impl Data { ) -> Box + 'a> { Box::new( self.local_users_in_room(room_id) - .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)), + .filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)), ) } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 48215817..ac2f688e 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -21,16 +21,28 @@ use ruma::{ OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{appservice::RegistrationInfo, services, user_is_local}; +use crate::{account_data, appservice::RegistrationInfo, rooms, user_is_local, users, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + account_data: Dep, + state_accessor: Dep, + users: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + account_data: args.depend::("account_data"), + state_accessor: args.depend::("rooms::state_accessor"), + users: args.depend::("users"), + }, + db: Data::new(&args), })) } @@ -54,18 +66,18 @@ impl Service { // update #[allow(clippy::collapsible_if)] if !user_is_local(user_id) { - if !services().users.exists(user_id)? { - services().users.create(user_id, None)?; + if !self.services.users.exists(user_id)? { + self.services.users.create(user_id, None)?; } /* // Try to update our local copy of the user if ours does not match - if ((services().users.displayname(user_id)? != membership_event.displayname) - || (services().users.avatar_url(user_id)? != membership_event.avatar_url) - || (services().users.blurhash(user_id)? != membership_event.blurhash)) + if ((self.services.users.displayname(user_id)? != membership_event.displayname) + || (self.services.users.avatar_url(user_id)? != membership_event.avatar_url) + || (self.services.users.blurhash(user_id)? != membership_event.blurhash)) && (membership != MembershipState::Leave) { - let response = services() + let response = self.services .sending .send_federation_request( user_id.server_name(), @@ -76,9 +88,9 @@ impl Service { ) .await; - services().users.set_displayname(user_id, response.displayname.clone()).await?; - services().users.set_avatar_url(user_id, response.avatar_url).await?; - services().users.set_blurhash(user_id, response.blurhash).await?; + self.services.users.set_displayname(user_id, response.displayname.clone()).await?; + self.services.users.set_avatar_url(user_id, response.avatar_url).await?; + self.services.users.set_blurhash(user_id, response.blurhash).await?; }; */ } @@ -91,8 +103,8 @@ impl Service { self.db.mark_as_once_joined(user_id, room_id)?; // Check if the room has a predecessor - if let Some(predecessor) = services() - .rooms + if let Some(predecessor) = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? .and_then(|create| serde_json::from_str(create.content.get()).ok()) @@ -124,21 +136,23 @@ impl Service { // .ok(); // Copy old tags to new room - if let Some(tag_event) = services() + if let Some(tag_event) = self + .services .account_data .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? .map(|event| { serde_json::from_str(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { - services() + self.services .account_data .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) .ok(); }; // Copy direct chat flag - if let Some(direct_event) = services() + if let Some(direct_event) = self + .services .account_data .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? .map(|event| { @@ -156,7 +170,7 @@ impl Service { } if room_ids_updated { - services().account_data.update( + self.services.account_data.update( None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), @@ -171,7 +185,8 @@ impl Service { }, MembershipState::Invite => { // We want to know if the sender is ignored by the receiver - let is_ignored = services() + let is_ignored = self + .services .account_data .get( None, // Ignored users are in global account data @@ -393,8 +408,8 @@ impl Service { /// See #[tracing::instrument(skip(self))] pub fn servers_route_via(&self, room_id: &RoomId) -> Result> { - let most_powerful_user_server = services() - .rooms + let most_powerful_user_server = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? .map(|pdu| { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 422c562b..2550774e 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -13,7 +13,7 @@ use lru_cache::LruCache; use ruma::{EventId, RoomId}; use self::data::StateDiff; -use crate::services; +use crate::{rooms, Dep}; type StateInfoLruCache = Mutex< LruCache< @@ -48,16 +48,25 @@ pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Service { db: Data, - + services: Services, pub stateinfo_cache: StateInfoLruCache, } +struct Services { + short: Dep, + state: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier; Ok(Arc::new(Self { db: Data::new(args.db), + services: Services { + short: args.depend::("rooms::short"), + state: args.depend::("rooms::state"), + }, stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), })) } @@ -124,8 +133,8 @@ impl Service { pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( - &services() - .rooms + &self + .services .short .get_or_create_shorteventid(event_id)? .to_be_bytes(), @@ -138,7 +147,7 @@ impl Service { pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc)> { Ok(( utils::u64_from_bytes(&compressed_event[0..size_of::()]).expect("bytes have right length"), - services().rooms.short.get_eventid_from_short( + self.services.short.get_eventid_from_short( utils::u64_from_bytes(&compressed_event[size_of::()..]).expect("bytes have right length"), )?, )) @@ -282,7 +291,7 @@ impl Service { pub fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc>, ) -> HashSetCompressStateEvent { - let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?; let state_hash = utils::calculate_hash( &new_state_ids_compressed @@ -291,8 +300,8 @@ impl Service { .collect::>(), ); - let (new_shortstatehash, already_existed) = services() - .rooms + let (new_shortstatehash, already_existed) = self + .services .short .get_or_create_shortstatehash(&state_hash)?; diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index c4a1a294..fb279a00 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,29 +1,40 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, Result}; -use database::{Database, Map}; +use conduit::{checked, utils, Error, PduEvent, Result}; +use database::Map; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; -use crate::{services, PduEvent}; +use crate::{rooms, Dep}; type PduEventIterResult<'a> = Result> + 'a>>; pub(super) struct Data { threadid_userids: Arc, + services: Services, +} + +struct Services { + short: Dep, + timeline: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { threadid_userids: db["threadid_userids"].clone(), + services: Services { + short: args.depend::("rooms::short"), + timeline: args.depend::("rooms::timeline"), + }, } } pub(super) fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, ) -> PduEventIterResult<'a> { - let prefix = services() - .rooms + let prefix = self + .services .short .get_shortroomid(room_id)? .expect("room exists") @@ -40,8 +51,8 @@ impl Data { .map(move |(pduid, _users)| { let count = utils::u64_from_bytes(&pduid[(size_of::())..]) .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = services() - .rooms + let mut pdu = self + .services .timeline .get_pdu_from_id(&pduid)? .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index dd2686b0..ae51cd0f 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,7 +2,7 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Error, Result}; +use conduit::{Error, PduEvent, Result}; use data::Data; use ruma::{ api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, @@ -11,16 +11,24 @@ use ruma::{ }; use serde_json::json; -use crate::{services, PduEvent}; +use crate::{rooms, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + timeline: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + timeline: args.depend::("rooms::timeline"), + }, + db: Data::new(&args), })) } @@ -35,22 +43,22 @@ impl Service { } pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { - let root_id = &services() - .rooms + let root_id = self + .services .timeline .get_pdu_id(root_event_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; - let root_pdu = services() - .rooms + let root_pdu = self + .services .timeline - .get_pdu_from_id(root_id)? + .get_pdu_from_id(&root_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; - let mut root_pdu_json = services() - .rooms + let mut root_pdu_json = self + .services .timeline - .get_pdu_json_from_id(root_id)? + .get_pdu_json_from_id(&root_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; if let CanonicalJsonValue::Object(unsigned) = root_pdu_json @@ -93,20 +101,19 @@ impl Service { ); } - services() - .rooms + self.services .timeline - .replace_pdu(root_id, &root_pdu_json, &root_pdu)?; + .replace_pdu(&root_id, &root_pdu_json, &root_pdu)?; } let mut users = Vec::new(); - if let Some(userids) = self.db.get_participants(root_id)? { + if let Some(userids) = self.db.get_participants(&root_id)? { users.extend_from_slice(&userids); } else { users.push(root_pdu.sender); } users.push(pdu.sender.clone()); - self.db.update_participants(root_id, &users) + self.db.update_participants(&root_id, &users) } } diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index ec975b99..5917e96b 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -4,19 +4,25 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{checked, error, utils, Error, Result}; +use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result}; use database::{Database, Map}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::{services, PduCount, PduEvent}; +use crate::{rooms, Dep}; pub(super) struct Data { + eventid_outlierpdu: Arc, eventid_pduid: Arc, pduid_pdu: Arc, - eventid_outlierpdu: Arc, - userroomid_notificationcount: Arc, userroomid_highlightcount: Arc, + userroomid_notificationcount: Arc, pub(super) lasttimelinecount_cache: LastTimelineCountCache, + pub(super) db: Arc, + services: Services, +} + +struct Services { + short: Dep, } type PdusIterItem = Result<(PduCount, PduEvent)>; @@ -24,14 +30,19 @@ type PdusIterator<'a> = Box + 'a>; type LastTimelineCountCache = Mutex>; impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { + eventid_outlierpdu: db["eventid_outlierpdu"].clone(), eventid_pduid: db["eventid_pduid"].clone(), pduid_pdu: db["pduid_pdu"].clone(), - eventid_outlierpdu: db["eventid_outlierpdu"].clone(), - userroomid_notificationcount: db["userroomid_notificationcount"].clone(), userroomid_highlightcount: db["userroomid_highlightcount"].clone(), + userroomid_notificationcount: db["userroomid_notificationcount"].clone(), lasttimelinecount_cache: Mutex::new(HashMap::new()), + db: args.db.clone(), + services: Services { + short: args.depend::("rooms::short"), + }, } } @@ -210,7 +221,7 @@ impl Data { /// happened before the event with id `until` in reverse-chronological /// order. pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result> { - let (prefix, current) = count_to_id(room_id, until, 1, true)?; + let (prefix, current) = self.count_to_id(room_id, until, 1, true)?; let user_id = user_id.to_owned(); @@ -232,7 +243,7 @@ impl Data { } pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result> { - let (prefix, current) = count_to_id(room_id, from, 1, false)?; + let (prefix, current) = self.count_to_id(room_id, from, 1, false)?; let user_id = user_id.to_owned(); @@ -277,6 +288,41 @@ impl Data { .increment_batch(highlights_batch.iter().map(Vec::as_slice))?; Ok(()) } + + pub(super) fn count_to_id( + &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, + ) -> Result<(Vec, Vec)> { + let prefix = self + .services + .short + .get_shortroomid(room_id)? + .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? + .to_be_bytes() + .to_vec(); + let mut pdu_id = prefix.clone(); + // +1 so we don't send the base event + let count_raw = match count { + PduCount::Normal(x) => { + if subtract { + x.saturating_sub(offset) + } else { + x.saturating_add(offset) + } + }, + PduCount::Backfilled(x) => { + pdu_id.extend_from_slice(&0_u64.to_be_bytes()); + let num = u64::MAX.saturating_sub(x); + if subtract { + num.saturating_sub(offset) + } else { + num.saturating_add(offset) + } + }, + }; + pdu_id.extend_from_slice(&count_raw.to_be_bytes()); + + Ok((prefix, pdu_id)) + } } /// Returns the `count` of this pdu's id. @@ -294,38 +340,3 @@ pub(super) fn pdu_count(pdu_id: &[u8]) -> Result { Ok(PduCount::Normal(last_u64)) } } - -pub(super) fn count_to_id( - room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, -) -> Result<(Vec, Vec)> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? - .to_be_bytes() - .to_vec(); - let mut pdu_id = prefix.clone(); - // +1 so we don't send the base event - let count_raw = match count { - PduCount::Normal(x) => { - if subtract { - x.saturating_sub(offset) - } else { - x.saturating_add(offset) - } - }, - PduCount::Backfilled(x) => { - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - let num = u64::MAX.saturating_sub(x); - if subtract { - num.saturating_sub(offset) - } else { - num.saturating_add(offset) - } - }, - }; - pdu_id.extend_from_slice(&count_raw.to_be_bytes()); - - Ok((prefix, pdu_id)) -} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 4c5e407a..50d29475 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -7,11 +7,12 @@ use std::{ }; use conduit::{ - debug, error, info, utils, + debug, error, info, + pdu::{EventHash, PduBuilder, PduCount, PduEvent}, + utils, utils::{MutexMap, MutexMapGuard}, - validated, warn, Error, Result, + validated, warn, Error, Result, Server, }; -use data::Data; use itertools::Itertools; use ruma::{ api::{client::error::ErrorKind, federation}, @@ -37,11 +38,10 @@ use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; +use self::data::Data; use crate::{ - appservice::NamespaceRegex, - pdu::{EventHash, PduBuilder}, - rooms::{event_handler::parse_incoming_pdu, state_compressor::CompressedStateEvent}, - server_is_ours, services, PduCount, PduEvent, + account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, + rooms::state_compressor::CompressedStateEvent, sending, server_is_ours, Dep, }; // Update Relationships @@ -67,17 +67,61 @@ struct ExtractBody { } pub struct Service { + services: Services, db: Data, pub mutex_insert: RoomMutexMap, } +struct Services { + server: Arc, + account_data: Dep, + appservice: Dep, + admin: Dep, + alias: Dep, + globals: Dep, + short: Dep, + state: Dep, + state_cache: Dep, + state_accessor: Dep, + pdu_metadata: Dep, + read_receipt: Dep, + sending: Dep, + user: Dep, + pusher: Dep, + threads: Dep, + search: Dep, + spaces: Dep, + event_handler: Dep, +} + type RoomMutexMap = MutexMap; pub type RoomMutexGuard = MutexMapGuard; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + server: args.server.clone(), + account_data: args.depend::("account_data"), + appservice: args.depend::("appservice"), + admin: args.depend::("admin"), + alias: args.depend::("rooms::alias"), + globals: args.depend::("globals"), + short: args.depend::("rooms::short"), + state: args.depend::("rooms::state"), + state_cache: args.depend::("rooms::state_cache"), + state_accessor: args.depend::("rooms::state_accessor"), + pdu_metadata: args.depend::("rooms::pdu_metadata"), + read_receipt: args.depend::("rooms::read_receipt"), + sending: args.depend::("sending"), + user: args.depend::("rooms::user"), + pusher: args.depend::("pusher"), + threads: args.depend::("rooms::threads"), + search: args.depend::("rooms::search"), + spaces: args.depend::("rooms::spaces"), + event_handler: args.depend::("rooms::event_handler"), + }, + db: Data::new(&args), mutex_insert: RoomMutexMap::new(), })) } @@ -217,10 +261,10 @@ impl Service { state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result> { // Coalesce database writes for the remainder of this scope. - let _cork = services().db.cork_and_flush(); + let _cork = self.db.db.cork_and_flush(); - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(&pdu.room_id)? .expect("room exists"); @@ -233,14 +277,14 @@ impl Service { .entry("unsigned".to_owned()) .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) { - if let Some(shortstatehash) = services() - .rooms + if let Some(shortstatehash) = self + .services .state_accessor .pdu_shortstatehash(&pdu.event_id) .unwrap() { - if let Some(prev_state) = services() - .rooms + if let Some(prev_state) = self + .services .state_accessor .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) .unwrap() @@ -270,30 +314,26 @@ impl Service { } // We must keep track of all events that have been referenced. - services() - .rooms + self.services .pdu_metadata .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services() - .rooms + self.services .state .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; - let count1 = services().globals.next_count()?; + let count1 = self.services.globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if // appending fails - services() - .rooms + self.services .read_receipt .private_read_set(&pdu.room_id, &pdu.sender, count1)?; - services() - .rooms + self.services .user .reset_notification_counts(&pdu.sender, &pdu.room_id)?; - let count2 = services().globals.next_count()?; + let count2 = self.services.globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); @@ -303,8 +343,8 @@ impl Service { drop(insert_lock); // See if the event matches any known pushers - let power_levels: RoomPowerLevelsEventContent = services() - .rooms + let power_levels: RoomPowerLevelsEventContent = self + .services .state_accessor .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { @@ -319,8 +359,8 @@ impl Service { let mut notifies = Vec::new(); let mut highlights = Vec::new(); - let mut push_target = services() - .rooms + let mut push_target = self + .services .state_cache .active_local_users_in_room(&pdu.room_id) .collect_vec(); @@ -341,7 +381,8 @@ impl Service { continue; } - let rules_for_user = services() + let rules_for_user = self + .services .account_data .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? .map(|event| { @@ -357,7 +398,7 @@ impl Service { let mut notify = false; for action in - services() + self.services .pusher .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? { @@ -378,8 +419,10 @@ impl Service { highlights.push(user.clone()); } - for push_key in services().pusher.get_pushkeys(user) { - services().sending.send_pdu_push(&pdu_id, user, push_key?)?; + for push_key in self.services.pusher.get_pushkeys(user) { + self.services + .sending + .send_pdu_push(&pdu_id, user, push_key?)?; } } @@ -390,11 +433,11 @@ impl Service { TimelineEventType::RoomRedaction => { use RoomVersionId::*; - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if services().rooms.state_accessor.user_can_redact( + if self.services.state_accessor.user_can_redact( redact_id, &pdu.sender, &pdu.room_id, @@ -412,7 +455,7 @@ impl Service { })?; if let Some(redact_id) = &content.redacts { - if services().rooms.state_accessor.user_can_redact( + if self.services.state_accessor.user_can_redact( redact_id, &pdu.sender, &pdu.room_id, @@ -433,8 +476,7 @@ impl Service { }, TimelineEventType::SpaceChild => { if let Some(_state_key) = &pdu.state_key { - services() - .rooms + self.services .spaces .roomid_spacehierarchy_cache .lock() @@ -455,7 +497,7 @@ impl Service { let invite_state = match content.membership { MembershipState::Invite => { - let state = services().rooms.state.calculate_invite_state(pdu)?; + let state = self.services.state.calculate_invite_state(pdu)?; Some(state) }, _ => None, @@ -463,7 +505,7 @@ impl Service { // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - services().rooms.state_cache.update_membership( + self.services.state_cache.update_membership( &pdu.room_id, &target_user_id, content, @@ -479,13 +521,12 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - services() - .rooms + self.services .search .index_pdu(shortroomid, &pdu_id, &body)?; - if services().admin.is_admin_command(pdu, &body).await { - services() + if self.services.admin.is_admin_command(pdu, &body).await { + self.services .admin .command(body, Some((*pdu.event_id).into())) .await; @@ -497,8 +538,7 @@ impl Service { if let Ok(content) = serde_json::from_str::(pdu.content.get()) { if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? { - services() - .rooms + self.services .pdu_metadata .add_relation(PduCount::Normal(count2), related_pducount)?; } @@ -512,29 +552,25 @@ impl Service { // We need to do it again here, because replies don't have // event_id as a top level field if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? { - services() - .rooms + self.services .pdu_metadata .add_relation(PduCount::Normal(count2), related_pducount)?; } }, Relation::Thread(thread) => { - services() - .rooms - .threads - .add_to_thread(&thread.event_id, pdu)?; + self.services.threads.add_to_thread(&thread.event_id, pdu)?; }, _ => {}, // TODO: Aggregate other types } } - for appservice in services().appservice.read().await.values() { - if services() - .rooms + for appservice in self.services.appservice.read().await.values() { + if self + .services .state_cache .appservice_in_room(&pdu.room_id, appservice)? { - services() + self.services .sending .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; @@ -550,7 +586,7 @@ impl Service { { let appservice_uid = appservice.registration.sender_localpart.as_str(); if state_key_uid == appservice_uid { - services() + self.services .sending .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; @@ -567,8 +603,7 @@ impl Service { .map_or(false, |state_key| users.is_match(state_key)) }; let matching_aliases = |aliases: &NamespaceRegex| { - services() - .rooms + self.services .alias .local_aliases_for_room(&pdu.room_id) .filter_map(Result::ok) @@ -579,7 +614,7 @@ impl Service { || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { - services() + self.services .sending .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; } @@ -603,8 +638,8 @@ impl Service { redacts, } = pdu_builder; - let prev_events: Vec<_> = services() - .rooms + let prev_events: Vec<_> = self + .services .state .get_forward_extremities(room_id)? .into_iter() @@ -612,28 +647,23 @@ impl Service { .collect(); // If there was no create event yet, assume we are creating a room - let room_version_id = services() - .rooms - .state - .get_room_version(room_id) - .or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::(content.get()) - .expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; + let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + let content = serde_json::from_str::(content.get()) + .expect("Invalid content in RoomCreate pdu."); + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) + } + })?; let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); let auth_events = - services() - .rooms + self.services .state .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; @@ -649,8 +679,7 @@ impl Service { if let Some(state_key) = &state_key { if let Some(prev_pdu) = - services() - .rooms + self.services .state_accessor .room_state_get(room_id, &event_type.to_string().into(), state_key)? { @@ -730,12 +759,12 @@ impl Service { // Add origin because synapse likes that (and it's required in the spec) pdu_json.insert( "origin".to_owned(), - to_canonical_value(services().globals.server_name()).expect("server name is a valid CanonicalJsonValue"), + to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"), ); match ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), + self.services.globals.server_name().as_str(), + self.services.globals.keypair(), &mut pdu_json, &room_version_id, ) { @@ -763,8 +792,8 @@ impl Service { ); // Generate short event id - let _shorteventid = services() - .rooms + let _shorteventid = self + .services .short .get_or_create_shorteventid(&pdu.event_id)?; @@ -783,7 +812,7 @@ impl Service { state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result> { let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = services().admin.get_admin_room()? { + if let Some(admin_room) = self.services.admin.get_admin_room()? { if admin_room == room_id { match pdu.event_type() { TimelineEventType::RoomEncryption => { @@ -798,7 +827,7 @@ impl Service { .state_key() .filter(|v| v.starts_with('@')) .unwrap_or(sender.as_str()); - let server_user = &services().globals.server_user.to_string(); + let server_user = &self.services.globals.server_user.to_string(); let content = serde_json::from_str::(pdu.content.get()) .map_err(|_| Error::bad_database("Invalid content in pdu"))?; @@ -812,8 +841,8 @@ impl Service { )); } - let count = services() - .rooms + let count = self + .services .state_cache .room_members(room_id) .filter_map(Result::ok) @@ -837,8 +866,8 @@ impl Service { )); } - let count = services() - .rooms + let count = self + .services .state_cache .room_members(room_id) .filter_map(Result::ok) @@ -861,15 +890,14 @@ impl Service { // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { use RoomVersionId::*; - match services().rooms.state.get_room_version(&pdu.room_id)? { + match self.services.state.get_room_version(&pdu.room_id)? { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if !services().rooms.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { + if !self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + { return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); } }; @@ -879,12 +907,11 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; if let Some(redact_id) = &content.redacts { - if !services().rooms.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { + if !self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + { return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); } } @@ -895,7 +922,7 @@ impl Service { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehashid = services().rooms.state.append_to_state(&pdu)?; + let statehashid = self.services.state.append_to_state(&pdu)?; let pdu_id = self .append_pdu( @@ -910,13 +937,12 @@ impl Service { // We set the room state after inserting the pdu, so that we never have a moment // in time where events in the current room state do not exist - services() - .rooms + self.services .state .set_room_state(room_id, statehashid, state_lock)?; - let mut servers: HashSet = services() - .rooms + let mut servers: HashSet = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -936,9 +962,9 @@ impl Service { // Remove our server from the server list since it will be added to it by // room_servers() and/or the if statement above - servers.remove(services().globals.server_name()); + servers.remove(self.services.globals.server_name()); - services() + self.services .sending .send_pdu_servers(servers.into_iter(), &pdu_id)?; @@ -960,18 +986,15 @@ impl Service { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - services() - .rooms + self.services .state .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; if soft_fail { - services() - .rooms + self.services .pdu_metadata .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services() - .rooms + self.services .state .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; return Ok(None); @@ -1022,14 +1045,13 @@ impl Service { if let Ok(content) = serde_json::from_str::(pdu.content.get()) { if let Some(body) = content.body { - services() - .rooms + self.services .search .deindex_pdu(shortroomid, &pdu_id, &body)?; } } - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; pdu.redact(room_version_id, reason)?; @@ -1058,8 +1080,8 @@ impl Service { return Ok(()); } - let power_levels: RoomPowerLevelsEventContent = services() - .rooms + let power_levels: RoomPowerLevelsEventContent = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { @@ -1077,8 +1099,8 @@ impl Service { } }); - let room_alias_servers = services() - .rooms + let room_alias_servers = self + .services .alias .local_aliases_for_room(room_id) .filter_map(|alias| { @@ -1090,14 +1112,13 @@ impl Service { let servers = room_mods .chain(room_alias_servers) - .chain(services().globals.config.trusted_servers.clone()) + .chain(self.services.server.config.trusted_servers.clone()) .filter(|server_name| { if server_is_ours(server_name) { return false; } - services() - .rooms + self.services .state_cache .server_in_room(server_name, room_id) .unwrap_or(false) @@ -1105,7 +1126,8 @@ impl Service { for backfill_server in servers { info!("Asking {backfill_server} for backfill"); - let response = services() + let response = self + .services .sending .send_federation_request( &backfill_server, @@ -1141,11 +1163,11 @@ impl Service { &self, origin: &ServerName, pdu: Box, pub_key_map: &RwLock>>, ) -> Result<()> { - let (event_id, value, room_id) = parse_incoming_pdu(&pdu)?; + let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?; // Lock so we cannot backfill the same pdu twice at the same time - let mutex_lock = services() - .rooms + let mutex_lock = self + .services .event_handler .mutex_federation .lock(&room_id) @@ -1158,14 +1180,12 @@ impl Service { return Ok(()); } - services() - .rooms + self.services .event_handler .fetch_required_signing_keys([&value], pub_key_map) .await?; - services() - .rooms + self.services .event_handler .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) .await?; @@ -1173,8 +1193,8 @@ impl Service { let value = self.get_pdu_json(&event_id)?.expect("We just created it"); let pdu = self.get_pdu(&event_id)?.expect("We just created it"); - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(&room_id)? .expect("room exists"); @@ -1182,7 +1202,7 @@ impl Service { let insert_lock = self.mutex_insert.lock(&room_id).await; let max = u64::MAX; - let count = services().globals.next_count()?; + let count = self.services.globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); @@ -1197,8 +1217,7 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - services() - .rooms + self.services .search .index_pdu(shortroomid, &pdu_id, &body)?; } diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 715e3162..d863f217 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduit::{debug_info, trace, utils, Result}; +use conduit::{debug_info, trace, utils, Result, Server}; use ruma::{ api::federation::transactions::edu::{Edu, TypingContent}, events::SyncEphemeralRoomEvent, @@ -8,19 +8,31 @@ use ruma::{ }; use tokio::sync::{broadcast, RwLock}; -use crate::{services, user_is_local}; +use crate::{globals, sending, user_is_local, Dep}; pub struct Service { - pub typing: RwLock>>, // u64 is unix timestamp of timeout - pub last_typing_update: RwLock>, /* timestamp of the last change to - * typing - * users */ + server: Arc, + services: Services, + /// u64 is unix timestamp of timeout + pub typing: RwLock>>, + /// timestamp of the last change to typing users + pub last_typing_update: RwLock>, pub typing_update_sender: broadcast::Sender, } +struct Services { + globals: Dep, + sending: Dep, +} + impl crate::Service for Service { - fn build(_args: crate::Args<'_>) -> Result> { + fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + server: args.server.clone(), + services: Services { + globals: args.depend::("globals"), + sending: args.depend::("sending"), + }, typing: RwLock::new(BTreeMap::new()), last_typing_update: RwLock::new(BTreeMap::new()), typing_update_sender: broadcast::channel(100).0, @@ -45,14 +57,14 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), self.services.globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if user_is_local(user_id) { - Self::federation_send(room_id, user_id, true)?; + self.federation_send(room_id, user_id, true)?; } Ok(()) @@ -71,14 +83,14 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), self.services.globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if user_is_local(user_id) { - Self::federation_send(room_id, user_id, false)?; + self.federation_send(room_id, user_id, false)?; } Ok(()) @@ -126,7 +138,7 @@ impl Service { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), self.services.globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } @@ -134,7 +146,7 @@ impl Service { // update federation for user in removable { if user_is_local(&user) { - Self::federation_send(room_id, &user, false)?; + self.federation_send(room_id, &user, false)?; } } } @@ -171,15 +183,15 @@ impl Service { }) } - fn federation_send(room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { + fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { debug_assert!(user_is_local(user_id), "tried to broadcast typing status of remote user",); - if !services().globals.config.allow_outgoing_typing { + if !self.server.config.allow_outgoing_typing { return Ok(()); } let edu = Edu::Typing(TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing)); - services() + self.services .sending .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?; diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 618caae0..c7131615 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::services; +use crate::{globals, rooms, Dep}; pub(super) struct Data { userroomid_notificationcount: Arc, @@ -12,16 +12,27 @@ pub(super) struct Data { roomuserid_lastnotificationread: Arc, roomsynctoken_shortstatehash: Arc, userroomid_joined: Arc, + services: Services, +} + +struct Services { + globals: Dep, + short: Dep, } impl Data { - pub(super) fn new(db: &Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { userroomid_notificationcount: db["userroomid_notificationcount"].clone(), userroomid_highlightcount: db["userroomid_highlightcount"].clone(), roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), userroomid_joined: db["userroomid_joined"].clone(), + services: Services { + globals: args.depend::("globals"), + short: args.depend::("rooms::short"), + }, } } @@ -39,7 +50,7 @@ impl Data { .insert(&userroom_id, &0_u64.to_be_bytes())?; self.roomuserid_lastnotificationread - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; Ok(()) } @@ -87,8 +98,8 @@ impl Data { pub(super) fn associate_token_shortstatehash( &self, room_id: &RoomId, token: u64, shortstatehash: u64, ) -> Result<()> { - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(room_id)? .expect("room exists"); @@ -101,8 +112,8 @@ impl Data { } pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(room_id)? .expect("room exists"); diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 12124a57..93d38470 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -3,9 +3,10 @@ mod data; use std::sync::Arc; use conduit::Result; -use data::Data; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use self::data::Data; + pub struct Service { db: Data, } @@ -13,7 +14,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/sending/appservice.rs b/src/service/sending/appservice.rs index 9e060e81..5ed40ad9 100644 --- a/src/service/sending/appservice.rs +++ b/src/service/sending/appservice.rs @@ -1,16 +1,17 @@ use std::{fmt::Debug, mem}; use bytes::BytesMut; +use conduit::{debug_error, trace, utils, warn, Error, Result}; +use reqwest::Client; use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; -use tracing::{trace, warn}; - -use crate::{debug_error, services, utils, Error, Result}; /// Sends a request to an appservice /// /// Only returns Ok(None) if there is no url specified in the appservice /// registration file -pub(crate) async fn send_request(registration: Registration, request: T) -> Result> +pub(crate) async fn send_request( + client: &Client, registration: Registration, request: T, +) -> Result> where T: OutgoingRequest + Debug + Send, { @@ -48,15 +49,10 @@ where let reqwest_request = reqwest::Request::try_from(http_request)?; - let mut response = services() - .client - .appservice - .execute(reqwest_request) - .await - .map_err(|e| { - warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id); - e - })?; + let mut response = client.execute(reqwest_request).await.map_err(|e| { + warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id); + e + })?; // reqwest::Response -> http::Response conversion let status = response.status(); diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 9cb1c267..6c8e2544 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -5,7 +5,7 @@ use database::{Database, Map}; use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; -use crate::services; +use crate::{globals, Dep}; type OutgoingSendingIter<'a> = Box, Destination, SendingEvent)>> + 'a>; type SendingEventIter<'a> = Box, SendingEvent)>> + 'a>; @@ -15,15 +15,24 @@ pub struct Data { servernameevent_data: Arc, servername_educount: Arc, pub(super) db: Arc, + services: Services, +} + +struct Services { + globals: Dep, } impl Data { - pub(super) fn new(db: Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { servercurrentevent_data: db["servercurrentevent_data"].clone(), servernameevent_data: db["servernameevent_data"].clone(), servername_educount: db["servername_educount"].clone(), - db, + db: args.db.clone(), + services: Services { + globals: args.depend::("globals"), + }, } } @@ -78,7 +87,7 @@ impl Data { if let SendingEvent::Pdu(value) = &event { key.extend_from_slice(value); } else { - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); } let value = if let SendingEvent::Edu(value) = &event { &**value diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 26f43fd3..6f091b04 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -6,26 +6,39 @@ mod sender; use std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; -use conduit::{err, Result, Server}; +use conduit::{err, warn, Result, Server}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -pub use sender::convert_to_outgoing_federation_event; use tokio::sync::Mutex; -use tracing::warn; -use crate::{server_is_ours, services}; +use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_is_ours, users, Dep}; pub struct Service { - pub db: data::Data, server: Arc, - - /// The state for a given state hash. + services: Services, + pub db: data::Data, sender: loole::Sender, receiver: Mutex>, } +struct Services { + client: Dep, + globals: Dep, + resolver: Dep, + state: Dep, + state_cache: Dep, + user: Dep, + users: Dep, + presence: Dep, + read_receipt: Dep, + timeline: Dep, + account_data: Dep, + appservice: Dep, + pusher: Dep, +} + #[derive(Clone, Debug, PartialEq, Eq)] struct Msg { dest: Destination, @@ -53,8 +66,23 @@ impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let (sender, receiver) = loole::unbounded(); Ok(Arc::new(Self { - db: data::Data::new(args.db.clone()), server: args.server.clone(), + services: Services { + client: args.depend::("client"), + globals: args.depend::("globals"), + resolver: args.depend::("resolver"), + state: args.depend::("rooms::state"), + state_cache: args.depend::("rooms::state_cache"), + user: args.depend::("rooms::user"), + users: args.depend::("users"), + presence: args.depend::("presence"), + read_receipt: args.depend::("rooms::read_receipt"), + timeline: args.depend::("rooms::timeline"), + account_data: args.depend::("account_data"), + appservice: args.depend::("appservice"), + pusher: args.depend::("pusher"), + }, + db: data::Data::new(&args), sender, receiver: Mutex::new(receiver), })) @@ -103,8 +131,8 @@ impl Service { #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { - let servers = services() - .rooms + let servers = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -152,8 +180,8 @@ impl Service { #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { - let servers = services() - .rooms + let servers = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -189,8 +217,8 @@ impl Service { #[tracing::instrument(skip(self, room_id), level = "debug")] pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { - let servers = services() - .rooms + let servers = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -213,13 +241,13 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, request), name = "request")] + #[tracing::instrument(skip_all, name = "request")] pub async fn send_federation_request(&self, dest: &ServerName, request: T) -> Result where T: OutgoingRequest + Debug + Send, { - let client = &services().client.federation; - send::send(client, dest, request).await + let client = &self.services.client.federation; + self.send(client, dest, request).await } /// Sends a request to an appservice @@ -232,7 +260,8 @@ impl Service { where T: OutgoingRequest + Debug + Send, { - appservice::send_request(registration, request).await + let client = &self.services.client.appservice; + appservice::send_request(client, registration, request).await } /// Cleanup event data diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 7901de48..b3a84d62 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,6 +1,8 @@ use std::{fmt::Debug, mem}; -use conduit::Err; +use conduit::{ + debug, debug_error, debug_warn, error::inspect_debug_log, trace, utils::string::EMPTY, Err, Error, Result, +}; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; use reqwest::{Client, Method, Request, Response, Url}; @@ -13,75 +15,91 @@ use ruma::{ server_util::authorization::XMatrix, ServerName, }; -use tracing::{debug, trace}; use crate::{ - debug_error, debug_warn, resolver, + globals, resolver, resolver::{actual::ActualDest, cache::CachedDest}, - services, Error, Result, }; -#[tracing::instrument(skip_all, name = "send")] -pub async fn send(client: &Client, dest: &ServerName, req: T) -> Result -where - T: OutgoingRequest + Debug + Send, -{ - if !services().globals.allow_federation() { - return Err!(Config("allow_federation", "Federation is disabled.")); +impl super::Service { + #[tracing::instrument(skip(self, client, req), name = "send")] + pub async fn send(&self, client: &Client, dest: &ServerName, req: T) -> Result + where + T: OutgoingRequest + Debug + Send, + { + if !self.server.config.allow_federation { + return Err!(Config("allow_federation", "Federation is disabled.")); + } + + let actual = self.services.resolver.get_actual_dest(dest).await?; + let request = self.prepare::(dest, &actual, req).await?; + self.execute::(dest, &actual, request, client).await } - let actual = services().resolver.get_actual_dest(dest).await?; - let request = prepare::(dest, &actual, req).await?; - execute::(client, dest, &actual, request).await -} + async fn execute( + &self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client, + ) -> Result + where + T: OutgoingRequest + Debug + Send, + { + let url = request.url().clone(); + let method = request.method().clone(); -async fn execute( - client: &Client, dest: &ServerName, actual: &ActualDest, request: Request, -) -> Result -where - T: OutgoingRequest + Debug + Send, -{ - let method = request.method().clone(); - let url = request.url().clone(); - debug!( - method = ?method, - url = ?url, - "Sending request", - ); - match client.execute(request).await { - Ok(response) => handle_response::(dest, actual, &method, &url, response).await, - Err(e) => handle_error::(dest, actual, &method, &url, e), + debug!(?method, ?url, "Sending request"); + match client.execute(request).await { + Ok(response) => handle_response::(&self.services.resolver, dest, actual, &method, &url, response).await, + Err(error) => handle_error::(dest, actual, &method, &url, error), + } } -} -async fn prepare(dest: &ServerName, actual: &ActualDest, req: T) -> Result -where - T: OutgoingRequest + Debug + Send, -{ - const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5]; + async fn prepare(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result + where + T: OutgoingRequest + Debug + Send, + { + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5]; + const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); - trace!("Preparing request"); + trace!("Preparing request"); + let mut http_request = req + .try_into_http_request::>(&actual.string, SATIR, &VERSIONS) + .map_err(|_| Error::BadServerResponse("Invalid destination"))?; - let mut http_request = req - .try_into_http_request::>(&actual.string, SendAccessToken::IfRequired(""), &VERSIONS) - .map_err(|_e| Error::BadServerResponse("Invalid destination"))?; + sign_request::(&self.services.globals, dest, &mut http_request); - sign_request::(dest, &mut http_request); + let request = Request::try_from(http_request)?; + self.validate_url(request.url())?; - let request = Request::try_from(http_request)?; - validate_url(request.url())?; + Ok(request) + } - Ok(request) + fn validate_url(&self, url: &Url) -> Result<()> { + if let Some(url_host) = url.host_str() { + if let Ok(ip) = IPAddress::parse(url_host) { + trace!("Checking request URL IP {ip:?}"); + self.services.resolver.validate_ip(&ip)?; + } + } + + Ok(()) + } } async fn handle_response( - dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, + resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, + mut response: Response, ) -> Result where T: OutgoingRequest + Debug + Send, { - trace!("Received response from {} for {} with {}", actual.string, url, response.url()); let status = response.status(); + trace!( + ?status, ?method, + request_url = ?url, + response_url = ?response.url(), + "Received response from {}", + actual.string, + ); + let mut http_response_builder = http::Response::builder() .status(status) .version(response.version()); @@ -92,11 +110,13 @@ where .expect("http::response::Builder is usable"), ); - trace!("Waiting for response body"); - let body = response.bytes().await.unwrap_or_else(|e| { - debug_error!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout + // TODO: handle timeout + trace!("Waiting for response body..."); + let body = response + .bytes() + .await + .inspect_err(inspect_debug_log) + .unwrap_or_else(|_| Vec::new().into()); let http_response = http_response_builder .body(body) @@ -109,7 +129,7 @@ where let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && !actual.cached { - services().resolver.set_cached_destination( + resolver.set_cached_destination( dest.to_owned(), CachedDest { dest: actual.dest.clone(), @@ -120,7 +140,7 @@ where } match response { - Err(_e) => Err(Error::BadServerResponse("Server returned bad 200 response.")), + Err(_) => Err(Error::BadServerResponse("Server returned bad 200 response.")), Ok(response) => Ok(response), } } @@ -150,7 +170,7 @@ where Err(e.into()) } -fn sign_request(dest: &ServerName, http_request: &mut http::Request>) +fn sign_request(globals: &globals::Service, dest: &ServerName, http_request: &mut http::Request>) where T: OutgoingRequest + Debug + Send, { @@ -172,16 +192,12 @@ where .to_string() .into(), ); - req_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); + req_map.insert("origin".to_owned(), globals.server_name().as_str().into()); req_map.insert("destination".to_owned(), dest.as_str().into()); let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap"); - ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut req_json, - ) - .expect("our request json is what ruma expects"); + ruma::signatures::sign_json(globals.server_name().as_str(), globals.keypair(), &mut req_json) + .expect("our request json is what ruma expects"); let req_json: serde_json::Map = serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap(); @@ -207,24 +223,8 @@ where http_request.headers_mut().insert( AUTHORIZATION, - HeaderValue::from(&XMatrix::new( - services().globals.config.server_name.clone(), - dest.to_owned(), - key, - sig, - )), + HeaderValue::from(&XMatrix::new(globals.config.server_name.clone(), dest.to_owned(), key, sig)), ); } } } - -fn validate_url(url: &Url) -> Result<()> { - if let Some(url_host) = url.host_str() { - if let Ok(ip) = IPAddress::parse(url_host) { - trace!("Checking request URL IP {ip:?}"); - resolver::actual::validate_ip(&ip)?; - } - } - - Ok(()) -} diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 774c3d69..0668ce24 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -6,7 +6,11 @@ use std::{ }; use base64::{engine::general_purpose, Engine as _}; -use conduit::{debug, debug_warn, error, trace, utils::math::continue_exponential_backoff_secs, warn}; +use conduit::{ + debug, debug_warn, error, trace, + utils::{calculate_hash, math::continue_exponential_backoff_secs}, + warn, Error, Result, +}; use federation::transactions::send_transaction_message; use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -24,8 +28,8 @@ use ruma::{ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::time::sleep_until; -use super::{appservice, send, Destination, Msg, SendingEvent, Service}; -use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; +use super::{appservice, Destination, Msg, SendingEvent, Service}; +use crate::user_is_local; #[derive(Debug)] enum TransactionStatus { @@ -69,8 +73,8 @@ impl Service { Ok(()) } - fn handle_response( - &self, response: SendingResult, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, + fn handle_response<'a>( + &'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { match response { Ok(dest) => self.handle_response_ok(&dest, futures, statuses), @@ -91,8 +95,8 @@ impl Service { }); } - fn handle_response_ok( - &self, dest: &Destination, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, + fn handle_response_ok<'a>( + &'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { let _cork = self.db.db.cork(); self.db @@ -113,24 +117,24 @@ impl Service { .mark_as_active(&new_events) .expect("marked as active"); let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect(); - futures.push(Box::pin(send_events(dest.clone(), new_events_vec))); + futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec))); } else { statuses.remove(dest); } } - fn handle_request(&self, msg: Msg, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let iv = vec![(msg.event, msg.queue_id)]; if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { if !events.is_empty() { - futures.push(Box::pin(send_events(msg.dest, events))); + futures.push(Box::pin(self.send_events(msg.dest, events))); } else { statuses.remove(&msg.dest); } } } - async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let now = Instant::now(); let timeout = Duration::from_millis(CLEANUP_TIMEOUT_MS); let deadline = now.checked_add(timeout).unwrap_or(now); @@ -148,7 +152,7 @@ impl Service { debug_warn!("Leaving with {} unfinished requests...", futures.len()); } - fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::>::new(); for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { @@ -166,12 +170,12 @@ impl Service { for (dest, events) in txns { if self.server.config.startup_netburst && !events.is_empty() { statuses.insert(dest.clone(), TransactionStatus::Running); - futures.push(Box::pin(send_events(dest.clone(), events))); + futures.push(Box::pin(self.send_events(dest.clone(), events))); } } } - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, level = "debug")] fn select_events( &self, dest: &Destination, @@ -218,7 +222,7 @@ impl Service { Ok(Some(events)) } - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, level = "debug")] fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> { let (mut allow, mut retry) = (true, false); statuses @@ -244,7 +248,7 @@ impl Service { Ok((allow, retry)) } - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, level = "debug")] fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu let since = self.db.get_latest_educount(server_name)?; @@ -252,11 +256,11 @@ impl Service { let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - for room_id in services().rooms.state_cache.server_rooms(server_name) { + for room_id in self.services.state_cache.server_rooms(server_name) { let room_id = room_id?; // Look for device list updates in this room device_list_changes.extend( - services() + self.services .users .keys_changed(room_id.as_ref(), since, None) .filter_map(Result::ok) @@ -264,7 +268,7 @@ impl Service { ); if self.server.config.allow_outgoing_read_receipts - && !select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? + && !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? { break; } @@ -287,381 +291,390 @@ impl Service { } if self.server.config.allow_outgoing_presence { - select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; + self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; } Ok((events, max_edu_count)) } -} -/// Look for presence -fn select_edus_presence( - server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, -) -> Result { - // Look for presence updates for this server - let mut presence_updates = Vec::new(); - for (user_id, count, presence_bytes) in services().presence.presence_since(since) { - *max_edu_count = cmp::max(count, *max_edu_count); + /// Look for presence + fn select_edus_presence( + &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec>, + ) -> Result { + // Look for presence updates for this server + let mut presence_updates = Vec::new(); + for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) { + *max_edu_count = cmp::max(count, *max_edu_count); - if !user_is_local(&user_id) { - continue; - } + if !user_is_local(&user_id) { + continue; + } - if !services() - .rooms - .state_cache - .server_sees_user(server_name, &user_id)? - { - continue; - } + if !self + .services + .state_cache + .server_sees_user(server_name, &user_id)? + { + continue; + } - let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?; - presence_updates.push(PresenceUpdate { - user_id, - presence: presence_event.content.presence, - currently_active: presence_event.content.currently_active.unwrap_or(false), - last_active_ago: presence_event - .content - .last_active_ago - .unwrap_or_else(|| uint!(0)), - status_msg: presence_event.content.status_msg, - }); - - if presence_updates.len() >= SELECT_EDU_LIMIT { - break; - } - } - - if !presence_updates.is_empty() { - let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); - events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); - } - - Ok(true) -} - -/// Look for read receipts in this room -fn select_edus_receipts( - room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec>, -) -> Result { - for r in services() - .rooms - .read_receipt - .readreceipts_since(room_id, since) - { - let (user_id, count, read_receipt) = r?; - *max_edu_count = cmp::max(count, *max_edu_count); - - if !user_is_local(&user_id) { - continue; - } - - let event = serde_json::from_str(read_receipt.json().get()) - .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; - let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { - let mut read = BTreeMap::new(); - - let (event_id, mut receipt) = r - .content - .0 - .into_iter() - .next() - .expect("we only use one event per read receipt"); - let receipt = receipt - .remove(&ReceiptType::Read) - .expect("our read receipts always set this") - .remove(&user_id) - .expect("our read receipts always have the user here"); - - read.insert( + let presence_event = self + .services + .presence + .from_json_bytes_to_event(&presence_bytes, &user_id)?; + presence_updates.push(PresenceUpdate { user_id, - ReceiptData { - data: receipt.clone(), - event_ids: vec![event_id.clone()], - }, - ); + presence: presence_event.content.presence, + currently_active: presence_event.content.currently_active.unwrap_or(false), + last_active_ago: presence_event + .content + .last_active_ago + .unwrap_or_else(|| uint!(0)), + status_msg: presence_event.content.status_msg, + }); - let receipt_map = ReceiptMap { - read, - }; - - let mut receipts = BTreeMap::new(); - receipts.insert(room_id.to_owned(), receipt_map); - - Edu::Receipt(ReceiptContent { - receipts, - }) - } else { - Error::bad_database("Invalid event type in read_receipts"); - continue; - }; - - events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); - - if events.len() >= SELECT_EDU_LIMIT { - return Ok(false); - } - } - - Ok(true) -} - -async fn send_events(dest: Destination, events: Vec) -> SendingResult { - //debug_assert!(!events.is_empty(), "sending empty transaction"); - match dest { - Destination::Normal(ref server) => send_events_dest_normal(&dest, server, events).await, - Destination::Appservice(ref id) => send_events_dest_appservice(&dest, id, events).await, - Destination::Push(ref userid, ref pushkey) => send_events_dest_push(&dest, userid, pushkey, events).await, - } -} - -#[tracing::instrument(skip(dest, events))] -async fn send_events_dest_appservice(dest: &Destination, id: &str, events: Vec) -> SendingResult { - let mut pdu_jsons = Vec::new(); - - for event in &events { - match event { - SendingEvent::Pdu(pdu_id) => { - pdu_jsons.push( - services() - .rooms - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), - ) - })? - .to_room_event(), - ); - }, - SendingEvent::Edu(_) | SendingEvent::Flush => { - // Appservices don't need EDUs (?) and flush only; - // no new content - }, - } - } - - //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); - match appservice::send_request( - services() - .appservice - .get_registration(id) - .await - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Could not load registration from db."), - ) - })?, - ruma::api::appservice::event::push_events::v1::Request { - events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::>(), - ))) - .into(), - }, - ) - .await - { - Ok(_) => Ok(dest.clone()), - Err(e) => Err((dest.clone(), e)), - } -} - -#[tracing::instrument(skip(dest, events))] -async fn send_events_dest_push( - dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec, -) -> SendingResult { - let mut pdus = Vec::new(); - - for event in &events { - match event { - SendingEvent::Pdu(pdu_id) => { - pdus.push( - services() - .rooms - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Push] Event in servernameevent_data not found in db."), - ) - })?, - ); - }, - SendingEvent::Edu(_) | SendingEvent::Flush => { - // Push gateways don't need EDUs (?) and flush only; - // no new content - }, - } - } - - for pdu in pdus { - // Redacted events are not notification targets (we don't send push for them) - if let Some(unsigned) = &pdu.unsigned { - if let Ok(unsigned) = serde_json::from_str::(unsigned.get()) { - if unsigned.get("redacted_because").is_some() { - continue; - } + if presence_updates.len() >= SELECT_EDU_LIMIT { + break; } } - let Some(pusher) = services() - .pusher - .get_pusher(userid, pushkey) - .map_err(|e| (dest.clone(), e))? - else { - continue; - }; + if !presence_updates.is_empty() { + let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); + events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); + } - let rules_for_user = services() - .account_data - .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap_or_default() - .and_then(|event| serde_json::from_str::(event.get()).ok()) - .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); - - let unread: UInt = services() - .rooms - .user - .notification_count(userid, &pdu.room_id) - .map_err(|e| (dest.clone(), e))? - .try_into() - .expect("notification count can't go that high"); - - let _response = services() - .pusher - .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) - .await - .map(|_response| dest.clone()) - .map_err(|e| (dest.clone(), e)); + Ok(true) } - Ok(dest.clone()) -} + /// Look for read receipts in this room + fn select_edus_receipts( + &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec>, + ) -> Result { + for r in self + .services + .read_receipt + .readreceipts_since(room_id, since) + { + let (user_id, count, read_receipt) = r?; + *max_edu_count = cmp::max(count, *max_edu_count); -#[tracing::instrument(skip(dest, events), name = "")] -async fn send_events_dest_normal( - dest: &Destination, server: &OwnedServerName, events: Vec, -) -> SendingResult { - let mut pdu_jsons = Vec::with_capacity( - events - .iter() - .filter(|event| matches!(event, SendingEvent::Pdu(_))) - .count(), - ); - let mut edu_jsons = Vec::with_capacity( - events - .iter() - .filter(|event| matches!(event, SendingEvent::Edu(_))) - .count(), - ); + if !user_is_local(&user_id) { + continue; + } - for event in &events { - match event { - SendingEvent::Pdu(pdu_id) => pdu_jsons.push(convert_to_outgoing_federation_event( - // TODO: check room version and remove event_id if needed - services() - .rooms - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - error!(?dest, ?server, ?pdu_id, "event not found"); - ( - dest.clone(), - Error::bad_database("[Normal] Event in servernameevent_data not found in db."), - ) - })?, - )), - SendingEvent::Edu(edu) => { - if let Ok(raw) = serde_json::from_slice(edu) { - edu_jsons.push(raw); - } - }, - SendingEvent::Flush => { - // flush only; no new content + let event = serde_json::from_str(read_receipt.json().get()) + .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { + let mut read = BTreeMap::new(); + + let (event_id, mut receipt) = r + .content + .0 + .into_iter() + .next() + .expect("we only use one event per read receipt"); + let receipt = receipt + .remove(&ReceiptType::Read) + .expect("our read receipts always set this") + .remove(&user_id) + .expect("our read receipts always have the user here"); + + read.insert( + user_id, + ReceiptData { + data: receipt.clone(), + event_ids: vec![event_id.clone()], + }, + ); + + let receipt_map = ReceiptMap { + read, + }; + + let mut receipts = BTreeMap::new(); + receipts.insert(room_id.to_owned(), receipt_map); + + Edu::Receipt(ReceiptContent { + receipts, + }) + } else { + Error::bad_database("Invalid event type in read_receipts"); + continue; + }; + + events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); + + if events.len() >= SELECT_EDU_LIMIT { + return Ok(false); + } + } + + Ok(true) + } + + async fn send_events(&self, dest: Destination, events: Vec) -> SendingResult { + //debug_assert!(!events.is_empty(), "sending empty transaction"); + match dest { + Destination::Normal(ref server) => self.send_events_dest_normal(&dest, server, events).await, + Destination::Appservice(ref id) => self.send_events_dest_appservice(&dest, id, events).await, + Destination::Push(ref userid, ref pushkey) => { + self.send_events_dest_push(&dest, userid, pushkey, events) + .await }, } } - //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty - // transaction"); - send::send( - &services().client.sender, - server, - send_transaction_message::v1::Request { - origin: services().server.config.server_name.clone(), + #[tracing::instrument(skip(self, dest, events), name = "appservice")] + async fn send_events_dest_appservice( + &self, dest: &Destination, id: &str, events: Vec, + ) -> SendingResult { + let mut pdu_jsons = Vec::new(); + + for event in &events { + match event { + SendingEvent::Pdu(pdu_id) => { + pdu_jsons.push( + self.services + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (dest.clone(), e))? + .ok_or_else(|| { + ( + dest.clone(), + Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), + ) + })? + .to_room_event(), + ); + }, + SendingEvent::Edu(_) | SendingEvent::Flush => { + // Appservices don't need EDUs (?) and flush only; + // no new content + }, + } + } + + //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); + let client = &self.services.client.appservice; + match appservice::send_request( + client, + self.services + .appservice + .get_registration(id) + .await + .ok_or_else(|| { + ( + dest.clone(), + Error::bad_database("[Appservice] Could not load registration from db."), + ) + })?, + ruma::api::appservice::event::push_events::v1::Request { + events: pdu_jsons, + txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Flush => &[], + }) + .collect::>(), + ))) + .into(), + }, + ) + .await + { + Ok(_) => Ok(dest.clone()), + Err(e) => Err((dest.clone(), e)), + } + } + + #[tracing::instrument(skip(self, dest, events), name = "push")] + async fn send_events_dest_push( + &self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec, + ) -> SendingResult { + let mut pdus = Vec::new(); + + for event in &events { + match event { + SendingEvent::Pdu(pdu_id) => { + pdus.push( + self.services + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (dest.clone(), e))? + .ok_or_else(|| { + ( + dest.clone(), + Error::bad_database("[Push] Event in servernameevent_data not found in db."), + ) + })?, + ); + }, + SendingEvent::Edu(_) | SendingEvent::Flush => { + // Push gateways don't need EDUs (?) and flush only; + // no new content + }, + } + } + + for pdu in pdus { + // Redacted events are not notification targets (we don't send push for them) + if let Some(unsigned) = &pdu.unsigned { + if let Ok(unsigned) = serde_json::from_str::(unsigned.get()) { + if unsigned.get("redacted_because").is_some() { + continue; + } + } + } + + let Some(pusher) = self + .services + .pusher + .get_pusher(userid, pushkey) + .map_err(|e| (dest.clone(), e))? + else { + continue; + }; + + let rules_for_user = self + .services + .account_data + .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) + .unwrap_or_default() + .and_then(|event| serde_json::from_str::(event.get()).ok()) + .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); + + let unread: UInt = self + .services + .user + .notification_count(userid, &pdu.room_id) + .map_err(|e| (dest.clone(), e))? + .try_into() + .expect("notification count can't go that high"); + + let _response = self + .services + .pusher + .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) + .await + .map(|_response| dest.clone()) + .map_err(|e| (dest.clone(), e)); + } + + Ok(dest.clone()) + } + + #[tracing::instrument(skip(self, dest, events), name = "", level = "debug")] + async fn send_events_dest_normal( + &self, dest: &Destination, server: &OwnedServerName, events: Vec, + ) -> SendingResult { + let mut pdu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Pdu(_))) + .count(), + ); + let mut edu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Edu(_))) + .count(), + ); + + for event in &events { + match event { + // TODO: check room version and remove event_id if needed + SendingEvent::Pdu(pdu_id) => pdu_jsons.push( + self.convert_to_outgoing_federation_event( + self.services + .timeline + .get_pdu_json_from_id(pdu_id) + .map_err(|e| (dest.clone(), e))? + .ok_or_else(|| { + error!(?dest, ?server, ?pdu_id, "event not found"); + ( + dest.clone(), + Error::bad_database("[Normal] Event in servernameevent_data not found in db."), + ) + })?, + ), + ), + SendingEvent::Edu(edu) => { + if let Ok(raw) = serde_json::from_slice(edu) { + edu_jsons.push(raw); + } + }, + SendingEvent::Flush => {}, // flush only; no new content + } + } + + //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty + // transaction"); + let transaction_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Flush => &[], + }) + .collect::>(), + )); + + let request = send_transaction_message::v1::Request { + origin: self.server.config.server_name.clone(), pdus: pdu_jsons, edus: edu_jsons, origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - transaction_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events + transaction_id: transaction_id.into(), + }; + + let client = &self.services.client.sender; + self.send(client, server, request) + .await + .inspect(|response| { + response + .pdus .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::>(), - ))) - .into(), - }, - ) - .await - .map(|response| { - for pdu in response.pdus { - if pdu.1.is_err() { - warn!("error for {} from remote: {:?}", pdu.0, pdu.1); + .filter(|(_, res)| res.is_err()) + .for_each(|(pdu_id, res)| warn!("error for {pdu_id} from remote: {res:?}")); + }) + .map(|_| dest.clone()) + .map_err(|e| (dest.clone(), e)) + } + + /// This does not return a full `Pdu` it is only to satisfy ruma's types. + pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box { + if let Some(unsigned) = pdu_json + .get_mut("unsigned") + .and_then(|val| val.as_object_mut()) + { + unsigned.remove("transaction_id"); + } + + // room v3 and above removed the "event_id" field from remote PDU format + if let Some(room_id) = pdu_json + .get("room_id") + .and_then(|val| RoomId::parse(val.as_str()?).ok()) + { + match self.services.state.get_room_version(&room_id) { + Ok(room_version_id) => match room_version_id { + RoomVersionId::V1 | RoomVersionId::V2 => {}, + _ => _ = pdu_json.remove("event_id"), + }, + Err(_) => _ = pdu_json.remove("event_id"), } + } else { + pdu_json.remove("event_id"); } - dest.clone() - }) - .map_err(|e| (dest.clone(), e)) -} -/// This does not return a full `Pdu` it is only to satisfy ruma's types. -#[tracing::instrument] -pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box { - if let Some(unsigned) = pdu_json - .get_mut("unsigned") - .and_then(|val| val.as_object_mut()) - { - unsigned.remove("transaction_id"); + // TODO: another option would be to convert it to a canonical string to validate + // size and return a Result> + // serde_json::from_str::>( + // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is + // valid serde_json::Value"), ) + // .expect("Raw::from_value always works") + + to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") } - - // room v3 and above removed the "event_id" field from remote PDU format - if let Some(room_id) = pdu_json - .get("room_id") - .and_then(|val| RoomId::parse(val.as_str()?).ok()) - { - match services().rooms.state.get_room_version(&room_id) { - Ok(room_version_id) => match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - _ => _ = pdu_json.remove("event_id"), - }, - Err(_) => _ = pdu_json.remove("event_id"), - } - } else { - pdu_json.remove("event_id"); - } - - // TODO: another option would be to convert it to a canonical string to validate - // size and return a Result> - // serde_json::from_str::>( - // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is - // valid serde_json::Value"), ) - // .expect("Raw::from_value always works") - - to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") } diff --git a/src/service/service.rs b/src/service/service.rs index ce4f15b2..bf3b891b 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -3,7 +3,7 @@ use std::{ collections::BTreeMap, fmt::Write, ops::Deref, - sync::{Arc, OnceLock}, + sync::{Arc, OnceLock, RwLock}, }; use async_trait::async_trait; @@ -50,20 +50,20 @@ pub(crate) struct Args<'a> { } /// Dep is a reference to a service used within another service. -/// Circular-dependencies between services require this indirection to allow the -/// referenced service construction after the referencing service. +/// Circular-dependencies between services require this indirection. pub(crate) struct Dep { dep: OnceLock>, service: Arc, name: &'static str, } -pub(crate) type Map = BTreeMap; +pub(crate) type Map = RwLock>; pub(crate) type MapVal = (Arc, Arc); -impl Deref for Dep { +impl Deref for Dep { type Target = Arc; + /// Dereference a dependency. The dependency must be ready or panics. fn deref(&self) -> &Self::Target { self.dep .get_or_init(|| require::(&self.service, self.name)) @@ -71,39 +71,61 @@ impl Deref for Dep { } impl Args<'_> { - pub(crate) fn depend_service(&self, name: &'static str) -> Dep { + /// Create a lazy-reference to a service when constructing another Service. + pub(crate) fn depend(&self, name: &'static str) -> Dep { Dep:: { dep: OnceLock::new(), service: self.service.clone(), name, } } + + /// Create a reference immediately to a service when constructing another + /// Service. The other service must be constructed. + pub(crate) fn require(&self, name: &str) -> Arc { require::(self.service, name) } } -pub(crate) fn require(map: &Map, name: &str) -> Arc { +/// Reference a Service by name. Panics if the Service does not exist or was +/// incorrectly cast. +pub(crate) fn require(map: &Map, name: &str) -> Arc { try_get::(map, name) .inspect_err(inspect_log) .expect("Failure to reference service required by another service.") } -pub(crate) fn try_get(map: &Map, name: &str) -> Result> { - map.get(name).map_or_else( - || Err!("Service {name:?} does not exist or has not been built yet."), - |(_, s)| { +/// Reference a Service by name. Returns Err if the Service does not exist or +/// was incorrectly cast. +pub(crate) fn try_get(map: &Map, name: &str) -> Result> { + map.read() + .expect("locked for reading") + .get(name) + .map_or_else( + || Err!("Service {name:?} does not exist or has not been built yet."), + |(_, s)| { + s.clone() + .downcast::() + .map_err(|_| err!("Service {name:?} must be correctly downcast.")) + }, + ) +} + +/// Reference a Service by name. Returns None if the Service does not exist, but +/// panics if incorrectly cast. +/// +/// # Panics +/// Incorrect type is not a silent failure (None) as the type never has a reason +/// to be incorrect. +pub(crate) fn get(map: &Map, name: &str) -> Option> { + map.read() + .expect("locked for reading") + .get(name) + .map(|(_, s)| { s.clone() .downcast::() - .map_err(|_| err!("Service {name:?} must be correctly downcast.")) - }, - ) -} - -pub(crate) fn get(map: &Map, name: &str) -> Option> { - map.get(name).map(|(_, s)| { - s.clone() - .downcast::() - .expect("Service must be correctly downcast.") - }) + .expect("Service must be correctly downcast.") + }) } +/// Utility for service implementations; see Service::name() in the trait. #[inline] pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } diff --git a/src/service/services.rs b/src/service/services.rs index 68205323..59909f8c 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,11 +1,16 @@ -use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc}; +use std::{ + any::Any, + collections::BTreeMap, + fmt::Write, + sync::{Arc, RwLock}, +}; use conduit::{debug, debug_info, info, trace, Result, Server}; use database::Database; use tokio::sync::Mutex; use crate::{ - account_data, admin, appservice, client, globals, key_backups, + account_data, admin, appservice, client, emergency, globals, key_backups, manager::Manager, media, presence, pusher, resolver, rooms, sending, service, service::{Args, Map, Service}, @@ -13,22 +18,23 @@ use crate::{ }; pub struct Services { - pub resolver: Arc, - pub client: Arc, - pub globals: Arc, - pub rooms: rooms::Service, - pub appservice: Arc, - pub pusher: Arc, - pub transaction_ids: Arc, - pub uiaa: Arc, - pub users: Arc, pub account_data: Arc, - pub presence: Arc, pub admin: Arc, + pub appservice: Arc, + pub client: Arc, + pub emergency: Arc, + pub globals: Arc, pub key_backups: Arc, pub media: Arc, + pub presence: Arc, + pub pusher: Arc, + pub resolver: Arc, + pub rooms: rooms::Service, pub sending: Arc, + pub transaction_ids: Arc, + pub uiaa: Arc, pub updates: Arc, + pub users: Arc, manager: Mutex>>, pub(crate) service: Arc, @@ -36,37 +42,34 @@ pub struct Services { pub db: Arc, } -macro_rules! build_service { - ($map:ident, $server:ident, $db:ident, $tyname:ty) => {{ - let built = <$tyname>::build(Args { - server: &$server, - db: &$db, - service: &$map, - })?; - - Arc::get_mut(&mut $map) - .expect("must have mutable reference to services collection") - .insert(built.name().to_owned(), (built.clone(), built.clone())); - - trace!("built service #{}: {:?}", $map.len(), built.name()); - built - }}; -} - impl Services { #[allow(clippy::cognitive_complexity)] pub fn build(server: Arc, db: Arc) -> Result { - let mut service: Arc = Arc::new(BTreeMap::new()); + let service: Arc = Arc::new(RwLock::new(BTreeMap::new())); macro_rules! build { - ($srv:ty) => { - build_service!(service, server, db, $srv) - }; + ($tyname:ty) => {{ + let built = <$tyname>::build(Args { + db: &db, + server: &server, + service: &service, + })?; + add_service(&service, built.clone(), built.clone()); + built + }}; } Ok(Self { - globals: build!(globals::Service), + account_data: build!(account_data::Service), + admin: build!(admin::Service), + appservice: build!(appservice::Service), resolver: build!(resolver::Service), client: build!(client::Service), + emergency: build!(emergency::Service), + globals: build!(globals::Service), + key_backups: build!(key_backups::Service), + media: build!(media::Service), + presence: build!(presence::Service), + pusher: build!(pusher::Service), rooms: rooms::Service { alias: build!(rooms::alias::Service), auth_chain: build!(rooms::auth_chain::Service), @@ -79,28 +82,22 @@ impl Services { read_receipt: build!(rooms::read_receipt::Service), search: build!(rooms::search::Service), short: build!(rooms::short::Service), + spaces: build!(rooms::spaces::Service), state: build!(rooms::state::Service), state_accessor: build!(rooms::state_accessor::Service), state_cache: build!(rooms::state_cache::Service), state_compressor: build!(rooms::state_compressor::Service), - timeline: build!(rooms::timeline::Service), threads: build!(rooms::threads::Service), + timeline: build!(rooms::timeline::Service), typing: build!(rooms::typing::Service), - spaces: build!(rooms::spaces::Service), user: build!(rooms::user::Service), }, - appservice: build!(appservice::Service), - pusher: build!(pusher::Service), + sending: build!(sending::Service), transaction_ids: build!(transaction_ids::Service), uiaa: build!(uiaa::Service), - users: build!(users::Service), - account_data: build!(account_data::Service), - presence: build!(presence::Service), - admin: build!(admin::Service), - key_backups: build!(key_backups::Service), - media: build!(media::Service), - sending: build!(sending::Service), updates: build!(updates::Service), + users: build!(users::Service), + manager: Mutex::new(None), service, server, @@ -111,7 +108,7 @@ impl Services { pub(super) async fn start(&self) -> Result<()> { debug_info!("Starting services..."); - globals::migrations::migrations(&self.db, &self.server.config).await?; + globals::migrations::migrations(self).await?; self.manager .lock() .await @@ -144,7 +141,7 @@ impl Services { } pub async fn clear_cache(&self) { - for (service, ..) in self.service.values() { + for (service, ..) in self.service.read().expect("locked for reading").values() { service.clear_cache(); } @@ -159,7 +156,7 @@ impl Services { pub async fn memory_usage(&self) -> Result { let mut out = String::new(); - for (service, ..) in self.service.values() { + for (service, ..) in self.service.read().expect("locked for reading").values() { service.memory_usage(&mut out)?; } @@ -179,23 +176,26 @@ impl Services { fn interrupt(&self) { debug!("Interrupting services..."); - for (name, (service, ..)) in self.service.iter() { + for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() { trace!("Interrupting {name}"); service.interrupt(); } } - pub fn try_get(&self, name: &str) -> Result> - where - T: Any + Send + Sync, - { + pub fn try_get(&self, name: &str) -> Result> { service::try_get::(&self.service, name) } - pub fn get(&self, name: &str) -> Option> - where - T: Any + Send + Sync, - { - service::get::(&self.service, name) - } + pub fn get(&self, name: &str) -> Option> { service::get::(&self.service, name) } +} + +fn add_service(map: &Arc, s: Arc, a: Arc) { + let name = s.name(); + let len = map.read().expect("locked for reading").len(); + + trace!("built service #{len}: {name:?}"); + + map.write() + .expect("locked for writing") + .insert(name.to_owned(), (s, a)); } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 4b953ffb..6041bbd3 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -2,7 +2,7 @@ mod data; use std::sync::Arc; -use conduit::{utils, utils::hash, Error, Result}; +use conduit::{error, utils, utils::hash, Error, Result, Server}; use data::Data; use ruma::{ api::client::{ @@ -11,19 +11,30 @@ use ruma::{ }, CanonicalJsonValue, DeviceId, UserId, }; -use tracing::error; -use crate::services; +use crate::{globals, users, Dep}; pub const SESSION_ID_LENGTH: usize = 32; pub struct Service { + server: Arc, + services: Services, pub db: Data, } +struct Services { + globals: Dep, + users: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + server: args.server.clone(), + services: Services { + globals: args.depend::("globals"), + users: args.depend::("users"), + }, db: Data::new(args.db), })) } @@ -87,11 +98,11 @@ impl Service { return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); }; - let user_id = UserId::parse_with_server_name(username.clone(), services().globals.server_name()) + let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; // Check if password is correct - if let Some(hash) = services().users.password_hash(&user_id)? { + if let Some(hash) = self.services.users.password_hash(&user_id)? { let hash_matches = hash::verify_password(password, &hash).is_ok(); if !hash_matches { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { @@ -106,7 +117,7 @@ impl Service { uiaainfo.completed.push(AuthType::Password); }, AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == services().globals.config.registration_token.as_deref() { + if Some(t.token.trim()) == self.server.config.registration_token.as_deref() { uiaainfo.completed.push(AuthType::RegistrationToken); } else { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index db69d9b0..f89471cd 100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -7,14 +7,20 @@ use ruma::events::room::message::RoomMessageEventContent; use serde::Deserialize; use tokio::{sync::Notify, time::interval}; -use crate::services; +use crate::{admin, client, Dep}; pub struct Service { + services: Services, db: Arc, interrupt: Notify, interval: Duration, } +struct Services { + admin: Dep, + client: Dep, +} + #[derive(Deserialize)] struct CheckForUpdatesResponse { updates: Vec, @@ -35,6 +41,10 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + services: Services { + admin: args.depend::("admin"), + client: args.depend::("client"), + }, db: args.db["global"].clone(), interrupt: Notify::new(), interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), @@ -63,7 +73,8 @@ impl crate::Service for Service { impl Service { #[tracing::instrument(skip_all)] async fn handle_updates(&self) -> Result<()> { - let response = services() + let response = self + .services .client .default .get(CHECK_FOR_UPDATES_URL) @@ -78,7 +89,7 @@ impl Service { last_update_id = last_update_id.max(update.id); if update.id > self.last_check_for_updates_id()? { info!("{:#}", update.message); - services() + self.services .admin .send_message(RoomMessageEventContent::text_markdown(format!( "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 5546adb1..2dcde7ce 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, mem::size_of, sync::Arc}; -use conduit::{debug_info, err, utils, warn, Err, Error, Result}; -use database::{Database, Map}; +use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server}; +use database::Map; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, @@ -11,52 +11,65 @@ use ruma::{ OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use crate::{services, users::clean_signatures}; +use crate::{globals, rooms, users::clean_signatures, Dep}; pub struct Data { - userid_password: Arc, + keychangeid_userid: Arc, + keyid_key: Arc, + onetimekeyid_onetimekeys: Arc, + openidtoken_expiresatuserid: Arc, + todeviceid_events: Arc, token_userdeviceid: Arc, - userid_displayname: Arc, + userdeviceid_metadata: Arc, + userdeviceid_token: Arc, + userfilterid_filter: Arc, userid_avatarurl: Arc, userid_blurhash: Arc, userid_devicelistversion: Arc, - userdeviceid_token: Arc, - userdeviceid_metadata: Arc, - onetimekeyid_onetimekeys: Arc, + userid_displayname: Arc, userid_lastonetimekeyupdate: Arc, - keyid_key: Arc, userid_masterkeyid: Arc, + userid_password: Arc, userid_selfsigningkeyid: Arc, userid_usersigningkeyid: Arc, - openidtoken_expiresatuserid: Arc, - keychangeid_userid: Arc, - todeviceid_events: Arc, - userfilterid_filter: Arc, - _db: Arc, + services: Services, +} + +struct Services { + server: Arc, + globals: Dep, + state_cache: Dep, + state_accessor: Dep, } impl Data { - pub(super) fn new(db: Arc) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { - userid_password: db["userid_password"].clone(), + keychangeid_userid: db["keychangeid_userid"].clone(), + keyid_key: db["keyid_key"].clone(), + onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(), + openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(), + todeviceid_events: db["todeviceid_events"].clone(), token_userdeviceid: db["token_userdeviceid"].clone(), - userid_displayname: db["userid_displayname"].clone(), + userdeviceid_metadata: db["userdeviceid_metadata"].clone(), + userdeviceid_token: db["userdeviceid_token"].clone(), + userfilterid_filter: db["userfilterid_filter"].clone(), userid_avatarurl: db["userid_avatarurl"].clone(), userid_blurhash: db["userid_blurhash"].clone(), userid_devicelistversion: db["userid_devicelistversion"].clone(), - userdeviceid_token: db["userdeviceid_token"].clone(), - userdeviceid_metadata: db["userdeviceid_metadata"].clone(), - onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(), + userid_displayname: db["userid_displayname"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), - keyid_key: db["keyid_key"].clone(), userid_masterkeyid: db["userid_masterkeyid"].clone(), + userid_password: db["userid_password"].clone(), userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(), userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(), - openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(), - keychangeid_userid: db["keychangeid_userid"].clone(), - todeviceid_events: db["todeviceid_events"].clone(), - userfilterid_filter: db["userfilterid_filter"].clone(), - _db: db, + services: Services { + server: args.server.clone(), + globals: args.depend::("globals"), + state_cache: args.depend::("rooms::state_cache"), + state_accessor: args.depend::("rooms::state_accessor"), + }, } } @@ -377,7 +390,7 @@ impl Data { )?; self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; Ok(()) } @@ -403,7 +416,7 @@ impl Data { prefix.push(b':'); self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; self.onetimekeyid_onetimekeys .scan_prefix(prefix) @@ -631,16 +644,16 @@ impl Data { } pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = services().globals.next_count()?.to_be_bytes(); - for room_id in services() - .rooms + let count = self.services.globals.next_count()?.to_be_bytes(); + for room_id in self + .services .state_cache .rooms_joined(user_id) .filter_map(Result::ok) { // Don't send key updates to unencrypted rooms - if services() - .rooms + if self + .services .state_accessor .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? .is_none() @@ -750,7 +763,7 @@ impl Data { key.push(0xFF); key.extend_from_slice(target_device_id.as_bytes()); key.push(0xFF); - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); let mut json = serde_json::Map::new(); json.insert("type".to_owned(), event_type.to_owned().into()); @@ -916,7 +929,7 @@ impl Data { pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { use std::num::Saturating as Sat; - let expires_in = services().globals.config.openid_token_ttl; + let expires_in = self.services.server.config.openid_token_ttl; let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); let mut value = expires_at.0.to_be_bytes().to_vec(); diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index e0a4dd1c..4c80d0d3 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -7,7 +7,6 @@ use std::{ }; use conduit::{Error, Result}; -use data::Data; use ruma::{ api::client::{ device::Device, @@ -24,7 +23,8 @@ use ruma::{ UInt, UserId, }; -use crate::services; +use self::data::Data; +use crate::{admin, rooms, Dep}; pub struct SlidingSyncCache { lists: BTreeMap, @@ -36,14 +36,24 @@ pub struct SlidingSyncCache { type DbConnections = Mutex>>>; pub struct Service { + services: Services, pub db: Data, pub connections: DbConnections, } +struct Services { + admin: Dep, + state_cache: Dep, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db.clone()), + services: Services { + admin: args.depend::("admin"), + state_cache: args.depend::("rooms::state_cache"), + }, + db: Data::new(&args), connections: StdMutex::new(BTreeMap::new()), })) } @@ -247,11 +257,8 @@ impl Service { /// Check if a user is an admin pub fn is_admin(&self, user_id: &UserId) -> Result { - if let Some(admin_room_id) = services().admin.get_admin_room()? { - services() - .rooms - .state_cache - .is_joined(user_id, &admin_room_id) + if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + self.services.state_cache.is_joined(user_id, &admin_room_id) } else { Ok(false) }