de-global services for services

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-18 06:37:47 +00:00
parent 992c0a1e58
commit 010e4ee35a
85 changed files with 2480 additions and 1887 deletions

View File

@ -15,7 +15,7 @@ use ruma::{
events::room::message::RoomMessageEventContent, events::room::message::RoomMessageEventContent,
CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName,
}; };
use service::{rooms::event_handler::parse_incoming_pdu, services, PduEvent}; use service::services;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
@ -189,7 +189,10 @@ pub(super) async fn get_remote_pdu(
debug!("Attempting to parse PDU: {:?}", &response.pdu); debug!("Attempting to parse PDU: {:?}", &response.pdu);
let parsed_pdu = { let parsed_pdu = {
let parsed_result = parse_incoming_pdu(&response.pdu); let parsed_result = services()
.rooms
.event_handler
.parse_incoming_pdu(&response.pdu);
let (event_id, value, room_id) = match parsed_result { let (event_id, value, room_id) = match parsed_result {
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => {
@ -510,7 +513,7 @@ pub(super) async fn force_set_room_state_from_server(
let mut events = Vec::with_capacity(remote_state_response.pdus.len()); let mut events = Vec::with_capacity(remote_state_response.pdus.len());
for pdu in remote_state_response.pdus.clone() { for pdu in remote_state_response.pdus.clone() {
events.push(match parse_incoming_pdu(&pdu) { events.push(match services().rooms.event_handler.parse_incoming_pdu(&pdu) {
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => {
warn!("Could not parse PDU, ignoring: {e}"); warn!("Could not parse PDU, ignoring: {e}");

View File

@ -1,3 +1,4 @@
#![recursion_limit = "168"]
#![allow(clippy::wildcard_imports)] #![allow(clippy::wildcard_imports)]
pub(crate) mod appservice; pub(crate) mod appservice;

View File

@ -22,7 +22,11 @@ pub(crate) async fn create_alias_route(
) -> Result<create_alias::v3::Response> { ) -> Result<create_alias::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; services
.rooms
.alias
.appservice_checks(&body.room_alias, &body.appservice_info)
.await?;
// this isn't apart of alias_checks or delete alias route because we should // this isn't apart of alias_checks or delete alias route because we should
// allow removing forbidden room aliases // allow removing forbidden room aliases
@ -61,7 +65,11 @@ pub(crate) async fn delete_alias_route(
) -> Result<delete_alias::v3::Response> { ) -> Result<delete_alias::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; services
.rooms
.alias
.appservice_checks(&body.room_alias, &body.appservice_info)
.await?;
if services if services
.rooms .rooms

View File

@ -43,7 +43,6 @@ use crate::{
service::{ service::{
pdu::{gen_event_id_canonical_json, PduBuilder}, pdu::{gen_event_id_canonical_json, PduBuilder},
rooms::state::RoomMutexGuard, rooms::state::RoomMutexGuard,
sending::convert_to_outgoing_federation_event,
server_is_ours, user_is_local, Services, server_is_ours, user_is_local, Services,
}, },
Ruma, Ruma,
@ -791,7 +790,9 @@ async fn join_room_by_id_helper_remote(
federation::membership::create_join_event::v2::Request { federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
event_id: event_id.to_owned(), event_id: event_id.to_owned(),
pdu: convert_to_outgoing_federation_event(join_event.clone()), pdu: services
.sending
.convert_to_outgoing_federation_event(join_event.clone()),
omit_members: false, omit_members: false,
}, },
) )
@ -1203,7 +1204,9 @@ async fn join_room_by_id_helper_local(
federation::membership::create_join_event::v2::Request { federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
event_id: event_id.to_owned(), event_id: event_id.to_owned(),
pdu: convert_to_outgoing_federation_event(join_event.clone()), pdu: services
.sending
.convert_to_outgoing_federation_event(join_event.clone()),
omit_members: false, omit_members: false,
}, },
) )
@ -1431,7 +1434,9 @@ pub(crate) async fn invite_helper(
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
event_id: (*pdu.event_id).to_owned(), event_id: (*pdu.event_id).to_owned(),
room_version: room_version_id.clone(), room_version: room_version_id.clone(),
event: convert_to_outgoing_federation_event(pdu_json.clone()), event: services
.sending
.convert_to_outgoing_federation_event(pdu_json.clone()),
invite_room_state, invite_room_state,
via: services.rooms.state_cache.servers_route_via(room_id).ok(), via: services.rooms.state_cache.servers_route_via(room_id).ok(),
}, },
@ -1763,7 +1768,9 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
federation::membership::create_leave_event::v2::Request { federation::membership::create_leave_event::v2::Request {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
event_id, event_id,
pdu: convert_to_outgoing_federation_event(leave_event.clone()), pdu: services
.sending
.convert_to_outgoing_federation_event(leave_event.clone()),
}, },
) )
.await?; .await?;

View File

@ -475,8 +475,6 @@ async fn handle_left_room(
async fn process_presence_updates( async fn process_presence_updates(
services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId, services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId,
) -> Result<()> { ) -> Result<()> {
use crate::service::presence::Presence;
// Take presence updates // Take presence updates
for (user_id, _, presence_bytes) in services.presence.presence_since(since) { for (user_id, _, presence_bytes) in services.presence.presence_since(since) {
if !services if !services
@ -487,7 +485,9 @@ async fn process_presence_updates(
continue; continue;
} }
let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?; let presence_event = services
.presence
.from_json_bytes_to_event(&presence_bytes, &user_id)?;
match presence_updates.entry(user_id) { match presence_updates.entry(user_id) {
Entry::Vacant(slot) => { Entry::Vacant(slot) => {
slot.insert(presence_event); slot.insert(presence_event);

View File

@ -1,3 +1,5 @@
#![recursion_limit = "160"]
pub mod client; pub mod client;
pub mod router; pub mod router;
pub mod server; pub mod server;

View File

@ -4,7 +4,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::backfill::get_backfill}, api::{client::error::ErrorKind, federation::backfill::get_backfill},
uint, user_id, MilliSecondsSinceUnixEpoch, uint, user_id, MilliSecondsSinceUnixEpoch,
}; };
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -67,7 +66,7 @@ pub(crate) async fn get_backfill_route(
}) })
.map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id))
.filter_map(|r| r.ok().flatten()) .filter_map(|r| r.ok().flatten())
.map(convert_to_outgoing_federation_event) .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(); .collect();
Ok(get_backfill::v1::Response { Ok(get_backfill::v1::Response {

View File

@ -4,7 +4,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::event::get_event}, api::{client::error::ErrorKind, federation::event::get_event},
MilliSecondsSinceUnixEpoch, RoomId, MilliSecondsSinceUnixEpoch, RoomId,
}; };
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -50,6 +49,6 @@ pub(crate) async fn get_event_route(
Ok(get_event::v1::Response { Ok(get_event::v1::Response {
origin: services.globals.server_name().to_owned(), origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdu: convert_to_outgoing_federation_event(event), pdu: services.sending.convert_to_outgoing_federation_event(event),
}) })
} }

View File

@ -6,7 +6,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, api::{client::error::ErrorKind, federation::authorization::get_event_authorization},
RoomId, RoomId,
}; };
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -60,7 +59,7 @@ pub(crate) async fn get_event_authorization_route(
Ok(get_event_authorization::v1::Response { Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?)
.map(convert_to_outgoing_federation_event) .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(), .collect(),
}) })
} }

View File

@ -4,7 +4,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::event::get_missing_events}, api::{client::error::ErrorKind, federation::event::get_missing_events},
OwnedEventId, RoomId, OwnedEventId, RoomId,
}; };
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -82,7 +81,7 @@ pub(crate) async fn get_missing_events_route(
) )
.map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?,
); );
events.push(convert_to_outgoing_federation_event(pdu)); events.push(services.sending.convert_to_outgoing_federation_event(pdu));
} }
i = i.saturating_add(1); i = i.saturating_add(1);
} }

View File

@ -7,7 +7,7 @@ use ruma::{
serde::JsonObject, serde::JsonObject,
CanonicalJsonValue, EventId, OwnedUserId, CanonicalJsonValue, EventId, OwnedUserId,
}; };
use service::{sending::convert_to_outgoing_federation_event, server_is_ours}; use service::server_is_ours;
use crate::Ruma; use crate::Ruma;
@ -174,6 +174,8 @@ pub(crate) async fn create_invite_route(
} }
Ok(create_invite::v2::Response { Ok(create_invite::v2::Response {
event: convert_to_outgoing_federation_event(signed_event), event: services
.sending
.convert_to_outgoing_federation_event(signed_event),
}) })
} }

View File

@ -21,7 +21,6 @@ use ruma::{
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::{ use crate::{
service::rooms::event_handler::parse_incoming_pdu,
services::Services, services::Services,
utils::{self}, utils::{self},
Error, Result, Ruma, Error, Result, Ruma,
@ -89,7 +88,7 @@ async fn handle_pdus(
) -> Result<ResolvedMap> { ) -> Result<ResolvedMap> {
let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); let mut parsed_pdus = Vec::with_capacity(body.pdus.len());
for pdu in &body.pdus { for pdu in &body.pdus {
parsed_pdus.push(match parse_incoming_pdu(pdu) { parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) {
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => {
debug_warn!("Could not parse PDU: {e}"); debug_warn!("Could not parse PDU: {e}");

View File

@ -13,9 +13,7 @@ use ruma::{
CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName,
}; };
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use service::{ use service::{pdu::gen_event_id_canonical_json, user_is_local, Services};
pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, user_is_local, Services,
};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::warn; use tracing::warn;
@ -186,12 +184,12 @@ async fn create_join_event(
Ok(create_join_event::v1::RoomState { Ok(create_join_event::v1::RoomState {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten())
.map(convert_to_outgoing_federation_event) .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(), .collect(),
state: state_ids state: state_ids
.iter() .iter()
.filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten())
.map(convert_to_outgoing_federation_event) .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(), .collect(),
// Event field is required if the room version supports restricted join rules. // Event field is required if the room version supports restricted join rules.
event: Some( event: Some(

View File

@ -3,7 +3,6 @@ use std::sync::Arc;
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state};
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -44,7 +43,11 @@ pub(crate) async fn get_room_state_route(
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await? .await?
.into_values() .into_values()
.map(|id| convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap())) .map(|id| {
services
.sending
.convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap())
})
.collect(); .collect();
let auth_chain_ids = services let auth_chain_ids = services
@ -61,7 +64,7 @@ pub(crate) async fn get_room_state_route(
.timeline .timeline
.get_pdu_json(&id) .get_pdu_json(&id)
.ok()? .ok()?
.map(convert_to_outgoing_federation_event) .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
}) })
.collect(), .collect(),
pdus, pdus,

View File

@ -1,3 +1,5 @@
#![recursion_limit = "160"]
mod layers; mod layers;
mod request; mod request;
mod router; mod router;

View File

@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use conduit::{utils, warn, Error, Result}; use conduit::{utils, warn, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{ use ruma::{
api::client::error::ErrorKind, api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
@ -9,18 +9,27 @@ use ruma::{
RoomId, UserId, RoomId, UserId,
}; };
use crate::services; use crate::{globals, Dep};
pub(super) struct Data { pub(super) struct Data {
roomuserdataid_accountdata: Arc<Map>, roomuserdataid_accountdata: Arc<Map>,
roomusertype_roomuserdataid: Arc<Map>, roomusertype_roomuserdataid: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(), roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(),
roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
} }
} }
@ -40,7 +49,7 @@ impl Data {
prefix.push(0xFF); prefix.push(0xFF);
let mut roomuserdataid = prefix.clone(); let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF); roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());

View File

@ -17,7 +17,7 @@ pub struct Service {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(&args),
})) }))
} }

View File

@ -11,10 +11,11 @@ use rustyline_async::{Readline, ReadlineError, ReadlineEvent};
use termimad::MadSkin; use termimad::MadSkin;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use crate::services; use crate::{admin, Dep};
pub struct Console { pub struct Console {
server: Arc<Server>, server: Arc<Server>,
admin: Dep<admin::Service>,
worker_join: Mutex<Option<JoinHandle<()>>>, worker_join: Mutex<Option<JoinHandle<()>>>,
input_abort: Mutex<Option<AbortHandle>>, input_abort: Mutex<Option<AbortHandle>>,
command_abort: Mutex<Option<AbortHandle>>, command_abort: Mutex<Option<AbortHandle>>,
@ -29,6 +30,7 @@ impl Console {
pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> { pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
server: args.server.clone(), server: args.server.clone(),
admin: args.depend::<admin::Service>("admin"),
worker_join: None.into(), worker_join: None.into(),
input_abort: None.into(), input_abort: None.into(),
command_abort: None.into(), command_abort: None.into(),
@ -116,7 +118,8 @@ impl Console {
let _suppression = log::Suppress::new(&self.server); let _suppression = log::Suppress::new(&self.server);
let (mut readline, _writer) = Readline::new(PROMPT.to_owned())?; let (mut readline, _writer) = Readline::new(PROMPT.to_owned())?;
readline.set_tab_completer(Self::tab_complete); let self_ = Arc::clone(self);
readline.set_tab_completer(move |line| self_.tab_complete(line));
self.set_history(&mut readline); self.set_history(&mut readline);
let future = readline.readline(); let future = readline.readline();
@ -154,7 +157,7 @@ impl Console {
} }
async fn process(self: Arc<Self>, line: String) { async fn process(self: Arc<Self>, line: String) {
match services().admin.command_in_place(line, None).await { match self.admin.command_in_place(line, None).await {
Ok(Some(content)) => self.output(content).await, Ok(Some(content)) => self.output(content).await,
Err(e) => error!("processing command: {e}"), Err(e) => error!("processing command: {e}"),
_ => (), _ => (),
@ -184,9 +187,8 @@ impl Console {
history.truncate(HISTORY_LIMIT); history.truncate(HISTORY_LIMIT);
} }
fn tab_complete(line: &str) -> String { fn tab_complete(&self, line: &str) -> String {
services() self.admin
.admin
.complete_command(line) .complete_command(line)
.unwrap_or_else(|| line.to_owned()) .unwrap_or_else(|| line.to_owned())
} }

View File

@ -25,13 +25,14 @@ impl super::Service {
return Ok(()); return Ok(());
}; };
let state_lock = self.state.mutex.lock(&room_id).await; let state_lock = self.services.state.mutex.lock(&room_id).await;
// Use the server user to grant the new admin's power level // Use the server user to grant the new admin's power level
let server_user = &self.globals.server_user; let server_user = &self.services.globals.server_user;
// Invite and join the real user // Invite and join the real user
self.timeline self.services
.timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
@ -55,7 +56,8 @@ impl super::Service {
&state_lock, &state_lock,
) )
.await?; .await?;
self.timeline self.services
.timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
@ -85,7 +87,8 @@ impl super::Service {
users.insert(server_user.clone(), 100.into()); users.insert(server_user.clone(), 100.into());
users.insert(user_id.to_owned(), 100.into()); users.insert(user_id.to_owned(), 100.into());
self.timeline self.services
.timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomPowerLevels, event_type: TimelineEventType::RoomPowerLevels,
@ -105,12 +108,12 @@ impl super::Service {
.await?; .await?;
// Send welcome message // Send welcome message
self.timeline.build_and_append_pdu( self.services.timeline.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMessage, event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&RoomMessageEventContent::text_html( content: to_raw_value(&RoomMessageEventContent::text_html(
format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.globals.server_name()), format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.services.globals.server_name()),
format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", self.globals.server_name()), format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", self.services.globals.server_name()),
)) ))
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,

View File

@ -22,15 +22,10 @@ use ruma::{
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local}; use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local, Dep};
pub struct Service { pub struct Service {
server: Arc<Server>, services: Services,
globals: Arc<globals::Service>,
alias: Arc<rooms::alias::Service>,
timeline: Arc<rooms::timeline::Service>,
state: Arc<rooms::state::Service>,
state_cache: Arc<rooms::state_cache::Service>,
sender: Sender<Command>, sender: Sender<Command>,
receiver: Mutex<Receiver<Command>>, receiver: Mutex<Receiver<Command>>,
pub handle: RwLock<Option<Handler>>, pub handle: RwLock<Option<Handler>>,
@ -39,6 +34,15 @@ pub struct Service {
pub console: Arc<console::Console>, pub console: Arc<console::Console>,
} }
struct Services {
server: Arc<Server>,
globals: Dep<globals::Service>,
alias: Dep<rooms::alias::Service>,
timeline: Dep<rooms::timeline::Service>,
state: Dep<rooms::state::Service>,
state_cache: Dep<rooms::state_cache::Service>,
}
#[derive(Debug)] #[derive(Debug)]
pub struct Command { pub struct Command {
pub command: String, pub command: String,
@ -58,12 +62,14 @@ impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT); let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT);
Ok(Arc::new(Self { Ok(Arc::new(Self {
server: args.server.clone(), services: Services {
globals: args.require_service::<globals::Service>("globals"), server: args.server.clone(),
alias: args.require_service::<rooms::alias::Service>("rooms::alias"), globals: args.depend::<globals::Service>("globals"),
timeline: args.require_service::<rooms::timeline::Service>("rooms::timeline"), alias: args.depend::<rooms::alias::Service>("rooms::alias"),
state: args.require_service::<rooms::state::Service>("rooms::state"), timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
state_cache: args.require_service::<rooms::state_cache::Service>("rooms::state_cache"), state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
},
sender, sender,
receiver: Mutex::new(receiver), receiver: Mutex::new(receiver),
handle: RwLock::new(None), handle: RwLock::new(None),
@ -75,7 +81,7 @@ impl crate::Service for Service {
async fn worker(self: Arc<Self>) -> Result<()> { async fn worker(self: Arc<Self>) -> Result<()> {
let receiver = self.receiver.lock().await; let receiver = self.receiver.lock().await;
let mut signals = self.server.signal.subscribe(); let mut signals = self.services.server.signal.subscribe();
loop { loop {
tokio::select! { tokio::select! {
command = receiver.recv_async() => match command { command = receiver.recv_async() => match command {
@ -116,7 +122,7 @@ impl Service {
pub async fn send_message(&self, message_content: RoomMessageEventContent) { pub async fn send_message(&self, message_content: RoomMessageEventContent) {
if let Ok(Some(room_id)) = self.get_admin_room() { if let Ok(Some(room_id)) = self.get_admin_room() {
let user_id = &self.globals.server_user; let user_id = &self.services.globals.server_user;
self.respond_to_room(message_content, &room_id, user_id) self.respond_to_room(message_content, &room_id, user_id)
.await; .await;
} }
@ -176,7 +182,7 @@ impl Service {
/// Checks whether a given user is an admin of this server /// Checks whether a given user is an admin of this server
pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> { pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> {
if let Ok(Some(admin_room)) = self.get_admin_room() { if let Ok(Some(admin_room)) = self.get_admin_room() {
self.state_cache.is_joined(user_id, &admin_room) self.services.state_cache.is_joined(user_id, &admin_room)
} else { } else {
Ok(false) Ok(false)
} }
@ -187,10 +193,15 @@ impl Service {
/// Errors are propagated from the database, and will have None if there is /// Errors are propagated from the database, and will have None if there is
/// no admin room /// no admin room
pub fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> { pub fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> {
if let Some(room_id) = self.alias.resolve_local_alias(&self.globals.admin_alias)? { if let Some(room_id) = self
.services
.alias
.resolve_local_alias(&self.services.globals.admin_alias)?
{
if self if self
.services
.state_cache .state_cache
.is_joined(&self.globals.server_user, &room_id)? .is_joined(&self.services.globals.server_user, &room_id)?
{ {
return Ok(Some(room_id)); return Ok(Some(room_id));
} }
@ -207,12 +218,12 @@ impl Service {
return; return;
}; };
let Ok(Some(pdu)) = self.timeline.get_pdu(&in_reply_to.event_id) else { let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else {
return; return;
}; };
let response_sender = if self.is_admin_room(&pdu.room_id) { let response_sender = if self.is_admin_room(&pdu.room_id) {
&self.globals.server_user &self.services.globals.server_user
} else { } else {
&pdu.sender &pdu.sender
}; };
@ -229,7 +240,7 @@ impl Service {
"sender is not admin" "sender is not admin"
); );
let state_lock = self.state.mutex.lock(room_id).await; let state_lock = self.services.state.mutex.lock(room_id).await;
let response_pdu = PduBuilder { let response_pdu = PduBuilder {
event_type: TimelineEventType::RoomMessage, event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&content).expect("event is valid, we just created it"), content: to_raw_value(&content).expect("event is valid, we just created it"),
@ -239,6 +250,7 @@ impl Service {
}; };
if let Err(e) = self if let Err(e) = self
.services
.timeline .timeline
.build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock)
.await .await
@ -266,7 +278,8 @@ impl Service {
redacts: None, redacts: None,
}; };
self.timeline self.services
.timeline
.build_and_append_pdu(response_pdu, user_id, room_id, state_lock) .build_and_append_pdu(response_pdu, user_id, room_id, state_lock)
.await?; .await?;
@ -279,7 +292,7 @@ impl Service {
let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin"); let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin");
// Admin command with public echo (in admin room) // Admin command with public echo (in admin room)
let server_user = &self.globals.server_user; let server_user = &self.services.globals.server_user;
let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str()); let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str());
// Expected backward branch // Expected backward branch
@ -293,7 +306,7 @@ impl Service {
} }
// Check if server-side command-escape is disabled by configuration // Check if server-side command-escape is disabled by configuration
if is_public_escape && !self.globals.config.admin_escape_commands { if is_public_escape && !self.services.globals.config.admin_escape_commands {
return false; return false;
} }
@ -309,7 +322,7 @@ impl Service {
// This will evaluate to false if the emergency password is set up so that // This will evaluate to false if the emergency password is set up so that
// the administrator can execute commands as conduit // the administrator can execute commands as conduit
let emergency_password_set = self.globals.emergency_password().is_some(); let emergency_password_set = self.services.globals.emergency_password().is_some();
let from_server = pdu.sender == *server_user && !emergency_password_set; let from_server = pdu.sender == *server_user && !emergency_password_set;
if from_server && self.is_admin_room(&pdu.room_id) { if from_server && self.is_admin_room(&pdu.room_id) {
return false; return false;

View File

@ -12,7 +12,7 @@ use ruma::{
}; };
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::services; use crate::{sending, Dep};
/// Compiled regular expressions for a namespace /// Compiled regular expressions for a namespace
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -118,9 +118,14 @@ impl TryFrom<Registration> for RegistrationInfo {
pub struct Service { pub struct Service {
pub db: Data, pub db: Data,
services: Services,
registration_info: RwLock<BTreeMap<String, RegistrationInfo>>, registration_info: RwLock<BTreeMap<String, RegistrationInfo>>,
} }
struct Services {
sending: Dep<sending::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let mut registration_info = BTreeMap::new(); let mut registration_info = BTreeMap::new();
@ -138,6 +143,9 @@ impl crate::Service for Service {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db, db,
services: Services {
sending: args.depend::<sending::Service>("sending"),
},
registration_info: RwLock::new(registration_info), registration_info: RwLock::new(registration_info),
})) }))
} }
@ -178,7 +186,9 @@ impl Service {
// deletes all active requests for the appservice if there are any so we stop // deletes all active requests for the appservice if there are any so we stop
// sending to the URL // sending to the URL
services().sending.cleanup_events(service_name.to_owned())?; self.services
.sending
.cleanup_events(service_name.to_owned())?;
Ok(()) Ok(())
} }

View File

@ -18,7 +18,7 @@ pub struct Service {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config; let config = &args.server.config;
let resolver = args.require_service::<resolver::Service>("resolver"); let resolver = args.require::<resolver::Service>("resolver");
Ok(Arc::new(Self { Ok(Arc::new(Self {
default: base(config) default: base(config)

View File

@ -0,0 +1,83 @@
use std::sync::Arc;
use async_trait::async_trait;
use conduit::{error, warn, Result};
use ruma::{
events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType},
push::Ruleset,
};
use crate::{account_data, globals, users, Dep};
pub struct Service {
services: Services,
}
struct Services {
account_data: Dep<account_data::Service>,
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
account_data: args.depend::<account_data::Service>("account_data"),
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
}))
}
async fn worker(self: Arc<Self>) -> Result<()> {
self.set_emergency_access()
.inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?;
Ok(())
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Sets the emergency password and push rules for the @conduit account in
/// case emergency password is set
fn set_emergency_access(&self) -> Result<bool> {
let conduit_user = &self.services.globals.server_user;
self.services
.users
.set_password(conduit_user, self.services.globals.emergency_password().as_deref())?;
let (ruleset, pwd_set) = match self.services.globals.emergency_password() {
Some(_) => (Ruleset::server_default(conduit_user), true),
None => (Ruleset::new(), false),
};
self.services.account_data.update(
None,
conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
if pwd_set {
warn!(
"The server account emergency password is set! Please unset it as soon as you finish admin account \
recovery! You will be logged out of the server service account when you finish."
);
} else {
// logs out any users still in the server service account and removes sessions
self.services.users.deactivate_account(conduit_user)?;
}
Ok(pwd_set)
}
}

View File

@ -3,7 +3,7 @@ use std::{
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
use conduit::{trace, utils, Error, Result}; use conduit::{trace, utils, Error, Result, Server};
use database::{Database, Map}; use database::{Database, Map};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
@ -12,7 +12,7 @@ use ruma::{
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
}; };
use crate::services; use crate::{rooms, Dep};
pub struct Data { pub struct Data {
global: Arc<Map>, global: Arc<Map>,
@ -28,14 +28,23 @@ pub struct Data {
server_signingkeys: Arc<Map>, server_signingkeys: Arc<Map>,
readreceiptid_readreceipt: Arc<Map>, readreceiptid_readreceipt: Arc<Map>,
userid_lastonetimekeyupdate: Arc<Map>, userid_lastonetimekeyupdate: Arc<Map>,
pub(super) db: Arc<Database>,
counter: RwLock<u64>, counter: RwLock<u64>,
pub(super) db: Arc<Database>,
services: Services,
}
struct Services {
server: Arc<Server>,
short: Dep<rooms::short::Service>,
state_cache: Dep<rooms::state_cache::Service>,
typing: Dep<rooms::typing::Service>,
} }
const COUNTER: &[u8] = b"c"; const COUNTER: &[u8] = b"c";
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
global: db["global"].clone(), global: db["global"].clone(),
todeviceid_events: db["todeviceid_events"].clone(), todeviceid_events: db["todeviceid_events"].clone(),
@ -50,8 +59,14 @@ impl Data {
server_signingkeys: db["server_signingkeys"].clone(), server_signingkeys: db["server_signingkeys"].clone(),
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
db: db.clone(),
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")), counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
db: args.db.clone(),
services: Services {
server: args.server.clone(),
short: args.depend::<rooms::short::Service>("rooms::short"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
typing: args.depend::<rooms::typing::Service>("rooms::typing"),
},
} }
} }
@ -118,14 +133,14 @@ impl Data {
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in // Events for rooms we are in
for room_id in services() for room_id in self
.rooms .services
.state_cache .state_cache
.rooms_joined(user_id) .rooms_joined(user_id)
.filter_map(Result::ok) .filter_map(Result::ok)
{ {
let short_roomid = services() let short_roomid = self
.rooms .services
.short .short
.get_shortroomid(&room_id) .get_shortroomid(&room_id)
.ok() .ok()
@ -143,7 +158,7 @@ impl Data {
// EDUs // EDUs
futures.push(Box::pin(async move { futures.push(Box::pin(async move {
let _result = services().rooms.typing.wait_for_update(&room_id).await; let _result = self.services.typing.wait_for_update(&room_id).await;
})); }));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
@ -176,12 +191,12 @@ impl Data {
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(async move { futures.push(Box::pin(async move {
while services().server.running() { while self.services.server.running() {
let _result = services().server.signal.subscribe().recv().await; let _result = self.services.server.signal.subscribe().recv().await;
} }
})); }));
if !services().server.running() { if !self.services.server.running() {
return Ok(()); return Ok(());
} }

View File

@ -1,54 +0,0 @@
use conduit::Result;
use ruma::{
events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType},
push::Ruleset,
};
use tracing::{error, warn};
use crate::services;
/// Set emergency access for the conduit user
pub(crate) fn init_emergency_access() {
if let Err(e) = set_emergency_access() {
error!("Could not set the configured emergency password for the conduit user: {e}");
}
}
/// Sets the emergency password and push rules for the @conduit account in case
/// emergency password is set
fn set_emergency_access() -> Result<bool> {
let conduit_user = &services().globals.server_user;
services()
.users
.set_password(conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, pwd_set) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(conduit_user), true),
None => (Ruleset::new(), false),
};
services().account_data.update(
None,
conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
if pwd_set {
warn!(
"The server account emergency password is set! Please unset it as soon as you finish admin account \
recovery! You will be logged out of the server service account when you finish."
);
} else {
// logs out any users still in the server service account and removes sessions
services().users.deactivate_account(conduit_user)?;
}
Ok(pwd_set)
}

View File

@ -10,7 +10,6 @@ use std::{
}; };
use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Config, Error, Result}; use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Config, Error, Result};
use database::Database;
use itertools::Itertools; use itertools::Itertools;
use ruma::{ use ruma::{
events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType},
@ -18,7 +17,7 @@ use ruma::{
EventId, OwnedRoomId, RoomId, UserId, EventId, OwnedRoomId, RoomId, UserId,
}; };
use crate::services; use crate::Services;
/// The current schema version. /// The current schema version.
/// - If database is opened at greater version we reject with error. The /// - If database is opened at greater version we reject with error. The
@ -28,13 +27,13 @@ use crate::services;
/// equal or lesser version. These are expected to be backward-compatible. /// equal or lesser version. These are expected to be backward-compatible.
const DATABASE_VERSION: u64 = 13; const DATABASE_VERSION: u64 = 13;
pub(crate) async fn migrations(db: &Arc<Database>, config: &Config) -> Result<()> { pub(crate) async fn migrations(services: &Services) -> Result<()> {
// Matrix resource ownership is based on the server name; changing it // Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch. // requires recreating the database from scratch.
if services().users.count()? > 0 { if services.users.count()? > 0 {
let conduit_user = &services().globals.server_user; let conduit_user = &services.globals.server_user;
if !services().users.exists(conduit_user)? { if !services.users.exists(conduit_user)? {
error!("The {} server user does not exist, and the database is not new.", conduit_user); error!("The {} server user does not exist, and the database is not new.", conduit_user);
return Err(Error::bad_database( return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one first.", "Cannot reuse an existing database after changing the server name, please delete the old one first.",
@ -42,15 +41,18 @@ pub(crate) async fn migrations(db: &Arc<Database>, config: &Config) -> Result<()
} }
} }
if services().users.count()? > 0 { if services.users.count()? > 0 {
migrate(db, config).await migrate(services).await
} else { } else {
fresh(db, config).await fresh(services).await
} }
} }
async fn fresh(db: &Arc<Database>, config: &Config) -> Result<()> { async fn fresh(services: &Services) -> Result<()> {
services() let db = &services.db;
let config = &services.server.config;
services
.globals .globals
.db .db
.bump_database_version(DATABASE_VERSION)?; .bump_database_version(DATABASE_VERSION)?;
@ -70,97 +72,100 @@ async fn fresh(db: &Arc<Database>, config: &Config) -> Result<()> {
} }
/// Apply any migrations /// Apply any migrations
async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> { async fn migrate(services: &Services) -> Result<()> {
if services().globals.db.database_version()? < 1 { let db = &services.db;
db_lt_1(db, config).await?; let config = &services.server.config;
if services.globals.db.database_version()? < 1 {
db_lt_1(services).await?;
} }
if services().globals.db.database_version()? < 2 { if services.globals.db.database_version()? < 2 {
db_lt_2(db, config).await?; db_lt_2(services).await?;
} }
if services().globals.db.database_version()? < 3 { if services.globals.db.database_version()? < 3 {
db_lt_3(db, config).await?; db_lt_3(services).await?;
} }
if services().globals.db.database_version()? < 4 { if services.globals.db.database_version()? < 4 {
db_lt_4(db, config).await?; db_lt_4(services).await?;
} }
if services().globals.db.database_version()? < 5 { if services.globals.db.database_version()? < 5 {
db_lt_5(db, config).await?; db_lt_5(services).await?;
} }
if services().globals.db.database_version()? < 6 { if services.globals.db.database_version()? < 6 {
db_lt_6(db, config).await?; db_lt_6(services).await?;
} }
if services().globals.db.database_version()? < 7 { if services.globals.db.database_version()? < 7 {
db_lt_7(db, config).await?; db_lt_7(services).await?;
} }
if services().globals.db.database_version()? < 8 { if services.globals.db.database_version()? < 8 {
db_lt_8(db, config).await?; db_lt_8(services).await?;
} }
if services().globals.db.database_version()? < 9 { if services.globals.db.database_version()? < 9 {
db_lt_9(db, config).await?; db_lt_9(services).await?;
} }
if services().globals.db.database_version()? < 10 { if services.globals.db.database_version()? < 10 {
db_lt_10(db, config).await?; db_lt_10(services).await?;
} }
if services().globals.db.database_version()? < 11 { if services.globals.db.database_version()? < 11 {
db_lt_11(db, config).await?; db_lt_11(services).await?;
} }
if services().globals.db.database_version()? < 12 { if services.globals.db.database_version()? < 12 {
db_lt_12(db, config).await?; db_lt_12(services).await?;
} }
// This migration can be reused as-is anytime the server-default rules are // This migration can be reused as-is anytime the server-default rules are
// updated. // updated.
if services().globals.db.database_version()? < 13 { if services.globals.db.database_version()? < 13 {
db_lt_13(db, config).await?; db_lt_13(services).await?;
} }
if db["global"].get(b"feat_sha256_media")?.is_none() { if db["global"].get(b"feat_sha256_media")?.is_none() {
migrate_sha256_media(db, config).await?; migrate_sha256_media(services).await?;
} else if config.media_startup_check { } else if config.media_startup_check {
checkup_sha256_media(db, config).await?; checkup_sha256_media(services).await?;
} }
if db["global"] if db["global"]
.get(b"fix_bad_double_separator_in_state_cache")? .get(b"fix_bad_double_separator_in_state_cache")?
.is_none() .is_none()
{ {
fix_bad_double_separator_in_state_cache(db, config).await?; fix_bad_double_separator_in_state_cache(services).await?;
} }
if db["global"] if db["global"]
.get(b"retroactively_fix_bad_data_from_roomuserid_joined")? .get(b"retroactively_fix_bad_data_from_roomuserid_joined")?
.is_none() .is_none()
{ {
retroactively_fix_bad_data_from_roomuserid_joined(db, config).await?; retroactively_fix_bad_data_from_roomuserid_joined(services).await?;
} }
assert_eq!( assert_eq!(
services().globals.db.database_version().unwrap(), services.globals.db.database_version().unwrap(),
DATABASE_VERSION, DATABASE_VERSION,
"Failed asserting local database version {} is equal to known latest conduwuit database version {}", "Failed asserting local database version {} is equal to known latest conduwuit database version {}",
services().globals.db.database_version().unwrap(), services.globals.db.database_version().unwrap(),
DATABASE_VERSION, DATABASE_VERSION,
); );
{ {
let patterns = services().globals.forbidden_usernames(); let patterns = services.globals.forbidden_usernames();
if !patterns.is_empty() { if !patterns.is_empty() {
for user_id in services() for user_id in services
.users .users
.iter() .iter()
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|user| !services().users.is_deactivated(user).unwrap_or(true)) .filter(|user| !services.users.is_deactivated(user).unwrap_or(true))
.filter(|user| user.server_name() == config.server_name) .filter(|user| user.server_name() == config.server_name)
{ {
let matches = patterns.matches(user_id.localpart()); let matches = patterns.matches(user_id.localpart());
@ -179,11 +184,11 @@ async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> {
} }
{ {
let patterns = services().globals.forbidden_alias_names(); let patterns = services.globals.forbidden_alias_names();
if !patterns.is_empty() { if !patterns.is_empty() {
for address in services().rooms.metadata.iter_ids() { for address in services.rooms.metadata.iter_ids() {
let room_id = address?; let room_id = address?;
let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id); let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id);
for room_alias_result in room_aliases { for room_alias_result in room_aliases {
let room_alias = room_alias_result?; let room_alias = room_alias_result?;
let matches = patterns.matches(room_alias.alias()); let matches = patterns.matches(room_alias.alias());
@ -211,7 +216,9 @@ async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> {
Ok(()) Ok(())
} }
async fn db_lt_1(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_1(services: &Services) -> Result<()> {
let db = &services.db;
let roomserverids = &db["roomserverids"]; let roomserverids = &db["roomserverids"];
let serverroomids = &db["serverroomids"]; let serverroomids = &db["serverroomids"];
for (roomserverid, _) in roomserverids.iter() { for (roomserverid, _) in roomserverids.iter() {
@ -228,12 +235,14 @@ async fn db_lt_1(db: &Arc<Database>, _config: &Config) -> Result<()> {
serverroomids.insert(&serverroomid, &[])?; serverroomids.insert(&serverroomid, &[])?;
} }
services().globals.db.bump_database_version(1)?; services.globals.db.bump_database_version(1)?;
info!("Migration: 0 -> 1 finished"); info!("Migration: 0 -> 1 finished");
Ok(()) Ok(())
} }
async fn db_lt_2(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_2(services: &Services) -> Result<()> {
let db = &services.db;
// We accidentally inserted hashed versions of "" into the db instead of just "" // We accidentally inserted hashed versions of "" into the db instead of just ""
let userid_password = &db["roomserverids"]; let userid_password = &db["roomserverids"];
for (userid, password) in userid_password.iter() { for (userid, password) in userid_password.iter() {
@ -245,12 +254,14 @@ async fn db_lt_2(db: &Arc<Database>, _config: &Config) -> Result<()> {
} }
} }
services().globals.db.bump_database_version(2)?; services.globals.db.bump_database_version(2)?;
info!("Migration: 1 -> 2 finished"); info!("Migration: 1 -> 2 finished");
Ok(()) Ok(())
} }
async fn db_lt_3(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_3(services: &Services) -> Result<()> {
let db = &services.db;
// Move media to filesystem // Move media to filesystem
let mediaid_file = &db["mediaid_file"]; let mediaid_file = &db["mediaid_file"];
for (key, content) in mediaid_file.iter() { for (key, content) in mediaid_file.iter() {
@ -259,41 +270,45 @@ async fn db_lt_3(db: &Arc<Database>, _config: &Config) -> Result<()> {
} }
#[allow(deprecated)] #[allow(deprecated)]
let path = services().media.get_media_file(&key); let path = services.media.get_media_file(&key);
let mut file = fs::File::create(path)?; let mut file = fs::File::create(path)?;
file.write_all(&content)?; file.write_all(&content)?;
mediaid_file.insert(&key, &[])?; mediaid_file.insert(&key, &[])?;
} }
services().globals.db.bump_database_version(3)?; services.globals.db.bump_database_version(3)?;
info!("Migration: 2 -> 3 finished"); info!("Migration: 2 -> 3 finished");
Ok(()) Ok(())
} }
async fn db_lt_4(_db: &Arc<Database>, config: &Config) -> Result<()> { async fn db_lt_4(services: &Services) -> Result<()> {
// Add federated users to services() as deactivated let config = &services.server.config;
for our_user in services().users.iter() {
// Add federated users to services as deactivated
for our_user in services.users.iter() {
let our_user = our_user?; let our_user = our_user?;
if services().users.is_deactivated(&our_user)? { if services.users.is_deactivated(&our_user)? {
continue; continue;
} }
for room in services().rooms.state_cache.rooms_joined(&our_user) { for room in services.rooms.state_cache.rooms_joined(&our_user) {
for user in services().rooms.state_cache.room_members(&room?) { for user in services.rooms.state_cache.room_members(&room?) {
let user = user?; let user = user?;
if user.server_name() != config.server_name { if user.server_name() != config.server_name {
info!(?user, "Migration: creating user"); info!(?user, "Migration: creating user");
services().users.create(&user, None)?; services.users.create(&user, None)?;
} }
} }
} }
} }
services().globals.db.bump_database_version(4)?; services.globals.db.bump_database_version(4)?;
info!("Migration: 3 -> 4 finished"); info!("Migration: 3 -> 4 finished");
Ok(()) Ok(())
} }
async fn db_lt_5(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_5(services: &Services) -> Result<()> {
let db = &services.db;
// Upgrade user data store // Upgrade user data store
let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"]; let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"];
let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"]; let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"];
@ -312,26 +327,30 @@ async fn db_lt_5(db: &Arc<Database>, _config: &Config) -> Result<()> {
roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
} }
services().globals.db.bump_database_version(5)?; services.globals.db.bump_database_version(5)?;
info!("Migration: 4 -> 5 finished"); info!("Migration: 4 -> 5 finished");
Ok(()) Ok(())
} }
async fn db_lt_6(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_6(services: &Services) -> Result<()> {
let db = &services.db;
// Set room member count // Set room member count
let roomid_shortstatehash = &db["roomid_shortstatehash"]; let roomid_shortstatehash = &db["roomid_shortstatehash"];
for (roomid, _) in roomid_shortstatehash.iter() { for (roomid, _) in roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap(); let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services().rooms.state_cache.update_joined_count(room_id)?; services.rooms.state_cache.update_joined_count(room_id)?;
} }
services().globals.db.bump_database_version(6)?; services.globals.db.bump_database_version(6)?;
info!("Migration: 5 -> 6 finished"); info!("Migration: 5 -> 6 finished");
Ok(()) Ok(())
} }
async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_7(services: &Services) -> Result<()> {
let db = &services.db;
// Upgrade state store // Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new(); let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None; let mut current_sstatehash: Option<u64> = None;
@ -347,7 +366,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
let states_parents = last_roomsstatehash.map_or_else( let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()), || Ok(Vec::new()),
|&last_roomsstatehash| { |&last_roomsstatehash| {
services() services
.rooms .rooms
.state_compressor .state_compressor
.load_shortstatehash_info(last_roomsstatehash) .load_shortstatehash_info(last_roomsstatehash)
@ -371,7 +390,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
(current_state, HashSet::new()) (current_state, HashSet::new())
}; };
services().rooms.state_compressor.save_state_from_diff( services.rooms.state_compressor.save_state_from_diff(
current_sstatehash, current_sstatehash,
Arc::new(statediffnew), Arc::new(statediffnew),
Arc::new(statediffremoved), Arc::new(statediffremoved),
@ -380,7 +399,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
)?; )?;
/* /*
let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?; let mut tmp = services.rooms.load_shortstatehash_info(&current_sstatehash)?;
let state = tmp.pop().unwrap(); let state = tmp.pop().unwrap();
println!( println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}", "{}\t{}{:?}: {:?} + {:?} - {:?}",
@ -425,12 +444,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap(); let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap(); let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap(); let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = services() let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap();
.rooms
.timeline
.get_pdu(event_id)
.unwrap()
.unwrap();
if Some(&pdu.room_id) != current_room.as_ref() { if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone()); current_room = Some(pdu.room_id.clone());
@ -451,12 +465,14 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
)?; )?;
} }
services().globals.db.bump_database_version(7)?; services.globals.db.bump_database_version(7)?;
info!("Migration: 6 -> 7 finished"); info!("Migration: 6 -> 7 finished");
Ok(()) Ok(())
} }
async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_8(services: &Services) -> Result<()> {
let db = &services.db;
let roomid_shortstatehash = &db["roomid_shortstatehash"]; let roomid_shortstatehash = &db["roomid_shortstatehash"];
let roomid_shortroomid = &db["roomid_shortroomid"]; let roomid_shortroomid = &db["roomid_shortroomid"];
let pduid_pdu = &db["pduid_pdu"]; let pduid_pdu = &db["pduid_pdu"];
@ -464,7 +480,7 @@ async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> {
// Generate short room ids for all rooms // Generate short room ids for all rooms
for (room_id, _) in roomid_shortstatehash.iter() { for (room_id, _) in roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes(); let shortroomid = services.globals.next_count()?.to_be_bytes();
roomid_shortroomid.insert(&room_id, &shortroomid)?; roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8"); info!("Migration: 8");
} }
@ -517,12 +533,14 @@ async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> {
eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?;
services().globals.db.bump_database_version(8)?; services.globals.db.bump_database_version(8)?;
info!("Migration: 7 -> 8 finished"); info!("Migration: 7 -> 8 finished");
Ok(()) Ok(())
} }
async fn db_lt_9(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_9(services: &Services) -> Result<()> {
let db = &services.db;
let tokenids = &db["tokenids"]; let tokenids = &db["tokenids"];
let roomid_shortroomid = &db["roomid_shortroomid"]; let roomid_shortroomid = &db["roomid_shortroomid"];
@ -574,12 +592,14 @@ async fn db_lt_9(db: &Arc<Database>, _config: &Config) -> Result<()> {
tokenids.remove(&key)?; tokenids.remove(&key)?;
} }
services().globals.db.bump_database_version(9)?; services.globals.db.bump_database_version(9)?;
info!("Migration: 8 -> 9 finished"); info!("Migration: 8 -> 9 finished");
Ok(()) Ok(())
} }
async fn db_lt_10(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_10(services: &Services) -> Result<()> {
let db = &services.db;
let statekey_shortstatekey = &db["statekey_shortstatekey"]; let statekey_shortstatekey = &db["statekey_shortstatekey"];
let shortstatekey_statekey = &db["shortstatekey_statekey"]; let shortstatekey_statekey = &db["shortstatekey_statekey"];
@ -589,28 +609,30 @@ async fn db_lt_10(db: &Arc<Database>, _config: &Config) -> Result<()> {
} }
// Force E2EE device list updates so we can send them over federation // Force E2EE device list updates so we can send them over federation
for user_id in services().users.iter().filter_map(Result::ok) { for user_id in services.users.iter().filter_map(Result::ok) {
services().users.mark_device_key_update(&user_id)?; services.users.mark_device_key_update(&user_id)?;
} }
services().globals.db.bump_database_version(10)?; services.globals.db.bump_database_version(10)?;
info!("Migration: 9 -> 10 finished"); info!("Migration: 9 -> 10 finished");
Ok(()) Ok(())
} }
#[allow(unreachable_code)] #[allow(unreachable_code)]
async fn db_lt_11(_db: &Arc<Database>, _config: &Config) -> Result<()> { async fn db_lt_11(services: &Services) -> Result<()> {
todo!("Dropping a column to clear data is not implemented yet."); error!("Dropping a column to clear data is not implemented yet.");
//let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"];
//userdevicesessionid_uiaarequest.clear()?; //userdevicesessionid_uiaarequest.clear()?;
services().globals.db.bump_database_version(11)?; services.globals.db.bump_database_version(11)?;
info!("Migration: 10 -> 11 finished"); info!("Migration: 10 -> 11 finished");
Ok(()) Ok(())
} }
async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> { async fn db_lt_12(services: &Services) -> Result<()> {
for username in services().users.list_local_users()? { let config = &services.server.config;
for username in services.users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
@ -619,7 +641,7 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> {
}, },
}; };
let raw_rules_list = services() let raw_rules_list = services
.account_data .account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap() .unwrap()
@ -664,7 +686,7 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> {
} }
} }
services().account_data.update( services.account_data.update(
None, None,
&user, &user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -672,13 +694,15 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> {
)?; )?;
} }
services().globals.db.bump_database_version(12)?; services.globals.db.bump_database_version(12)?;
info!("Migration: 11 -> 12 finished"); info!("Migration: 11 -> 12 finished");
Ok(()) Ok(())
} }
async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> { async fn db_lt_13(services: &Services) -> Result<()> {
for username in services().users.list_local_users()? { let config = &services.server.config;
for username in services.users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
@ -687,7 +711,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
}, },
}; };
let raw_rules_list = services() let raw_rules_list = services
.account_data .account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap() .unwrap()
@ -701,7 +725,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
.global .global
.update_with_server_default(user_default_rules); .update_with_server_default(user_default_rules);
services().account_data.update( services.account_data.update(
None, None,
&user, &user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -709,7 +733,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
)?; )?;
} }
services().globals.db.bump_database_version(13)?; services.globals.db.bump_database_version(13)?;
info!("Migration: 12 -> 13 finished"); info!("Migration: 12 -> 13 finished");
Ok(()) Ok(())
} }
@ -717,15 +741,17 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
/// Migrates a media directory from legacy base64 file names to sha2 file names. /// Migrates a media directory from legacy base64 file names to sha2 file names.
/// All errors are fatal. Upon success the database is keyed to not perform this /// All errors are fatal. Upon success the database is keyed to not perform this
/// again. /// again.
async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn migrate_sha256_media(services: &Services) -> Result<()> {
let db = &services.db;
warn!("Migrating legacy base64 file names to sha256 file names"); warn!("Migrating legacy base64 file names to sha256 file names");
let mediaid_file = &db["mediaid_file"]; let mediaid_file = &db["mediaid_file"];
// Move old media files to new names // Move old media files to new names
let mut changes = Vec::<(PathBuf, PathBuf)>::new(); let mut changes = Vec::<(PathBuf, PathBuf)>::new();
for (key, _) in mediaid_file.iter() { for (key, _) in mediaid_file.iter() {
let old = services().media.get_media_file_b64(&key); let old = services.media.get_media_file_b64(&key);
let new = services().media.get_media_file_sha256(&key); let new = services.media.get_media_file_sha256(&key);
debug!(?key, ?old, ?new, num = changes.len(), "change"); debug!(?key, ?old, ?new, num = changes.len(), "change");
changes.push((old, new)); changes.push((old, new));
} }
@ -739,8 +765,8 @@ async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<()
// Apply fix from when sha256_media was backward-incompat and bumped the schema // Apply fix from when sha256_media was backward-incompat and bumped the schema
// version from 13 to 14. For users satisfying these conditions we can go back. // version from 13 to 14. For users satisfying these conditions we can go back.
if services().globals.db.database_version()? == 14 && DATABASE_VERSION == 13 { if services.globals.db.database_version()? == 14 && DATABASE_VERSION == 13 {
services().globals.db.bump_database_version(13)?; services.globals.db.bump_database_version(13)?;
} }
db["global"].insert(b"feat_sha256_media", &[])?; db["global"].insert(b"feat_sha256_media", &[])?;
@ -752,14 +778,16 @@ async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<()
/// - Going back and forth to non-sha256 legacy binaries (e.g. upstream). /// - Going back and forth to non-sha256 legacy binaries (e.g. upstream).
/// - Deletion of artifacts in the media directory which will then fall out of /// - Deletion of artifacts in the media directory which will then fall out of
/// sync with the database. /// sync with the database.
async fn checkup_sha256_media(db: &Arc<Database>, config: &Config) -> Result<()> { async fn checkup_sha256_media(services: &Services) -> Result<()> {
use crate::media::encode_key; use crate::media::encode_key;
debug!("Checking integrity of media directory"); debug!("Checking integrity of media directory");
let db = &services.db;
let media = &services.media;
let config = &services.server.config;
let mediaid_file = &db["mediaid_file"]; let mediaid_file = &db["mediaid_file"];
let mediaid_user = &db["mediaid_user"]; let mediaid_user = &db["mediaid_user"];
let dbs = (mediaid_file, mediaid_user); let dbs = (mediaid_file, mediaid_user);
let media = &services().media;
let timer = Instant::now(); let timer = Instant::now();
let dir = media.get_media_dir(); let dir = media.get_media_dir();
@ -791,6 +819,7 @@ async fn handle_media_check(
new_path: &OsStr, old_path: &OsStr, new_path: &OsStr, old_path: &OsStr,
) -> Result<()> { ) -> Result<()> {
use crate::media::encode_key; use crate::media::encode_key;
let (mediaid_file, mediaid_user) = dbs; let (mediaid_file, mediaid_user) = dbs;
let old_exists = files.contains(old_path); let old_exists = files.contains(old_path);
@ -827,8 +856,10 @@ async fn handle_media_check(
Ok(()) Ok(())
} }
async fn fix_bad_double_separator_in_state_cache(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<()> {
warn!("Fixing bad double separator in state_cache roomuserid_joined"); warn!("Fixing bad double separator in state_cache roomuserid_joined");
let db = &services.db;
let roomuserid_joined = &db["roomuserid_joined"]; let roomuserid_joined = &db["roomuserid_joined"];
let _cork = db.cork_and_sync(); let _cork = db.cork_and_sync();
@ -864,11 +895,13 @@ async fn fix_bad_double_separator_in_state_cache(db: &Arc<Database>, _config: &C
Ok(()) Ok(())
} }
async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _config: &Config) -> Result<()> { async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result<()> {
warn!("Retroactively fixing bad data from broken roomuserid_joined"); warn!("Retroactively fixing bad data from broken roomuserid_joined");
let db = &services.db;
let _cork = db.cork_and_sync(); let _cork = db.cork_and_sync();
let room_ids = services() let room_ids = services
.rooms .rooms
.metadata .metadata
.iter_ids() .iter_ids()
@ -878,7 +911,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
for room_id in room_ids.clone() { for room_id in room_ids.clone() {
debug_info!("Fixing room {room_id}"); debug_info!("Fixing room {room_id}");
let users_in_room = services() let users_in_room = services
.rooms .rooms
.state_cache .state_cache
.room_members(&room_id) .room_members(&room_id)
@ -888,7 +921,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
let joined_members = users_in_room let joined_members = users_in_room
.iter() .iter()
.filter(|user_id| { .filter(|user_id| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.get_member(&room_id, user_id) .get_member(&room_id, user_id)
@ -900,7 +933,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
let non_joined_members = users_in_room let non_joined_members = users_in_room
.iter() .iter()
.filter(|user_id| { .filter(|user_id| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.get_member(&room_id, user_id) .get_member(&room_id, user_id)
@ -913,7 +946,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
for user_id in joined_members { for user_id in joined_members {
debug_info!("User is joined, marking as joined"); debug_info!("User is joined, marking as joined");
services() services
.rooms .rooms
.state_cache .state_cache
.mark_as_joined(user_id, &room_id)?; .mark_as_joined(user_id, &room_id)?;
@ -921,10 +954,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
for user_id in non_joined_members { for user_id in non_joined_members {
debug_info!("User is left or banned, marking as left"); debug_info!("User is left or banned, marking as left");
services() services.rooms.state_cache.mark_as_left(user_id, &room_id)?;
.rooms
.state_cache
.mark_as_left(user_id, &room_id)?;
} }
} }
@ -933,7 +963,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
"Updating joined count for room {room_id} to fix servers in room after correcting membership states" "Updating joined count for room {room_id} to fix servers in room after correcting membership states"
); );
services().rooms.state_cache.update_joined_count(&room_id)?; services.rooms.state_cache.update_joined_count(&room_id)?;
} }
db.db.cleanup()?; db.db.cleanup()?;

View File

@ -1,5 +1,4 @@
mod data; mod data;
mod emerg_access;
pub(super) mod migrations; pub(super) mod migrations;
use std::{ use std::{
@ -9,7 +8,6 @@ use std::{
time::Instant, time::Instant,
}; };
use async_trait::async_trait;
use conduit::{error, trace, Config, Result}; use conduit::{error, trace, Config, Result};
use data::Data; use data::Data;
use ipaddress::IPAddress; use ipaddress::IPAddress;
@ -43,11 +41,10 @@ pub struct Service {
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
#[async_trait]
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let db = Data::new(&args);
let config = &args.server.config; let config = &args.server.config;
let db = Data::new(args.db);
let keypair = db.load_keypair(); let keypair = db.load_keypair();
let keypair = match keypair { let keypair = match keypair {
@ -104,19 +101,13 @@ impl crate::Service for Service {
.supported_room_versions() .supported_room_versions()
.contains(&s.config.default_room_version) .contains(&s.config.default_room_version)
{ {
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); error!(config=?s.config.default_room_version, fallback=?conduit::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
s.config.default_room_version = crate::config::default_default_room_version(); s.config.default_room_version = conduit::config::default_default_room_version();
}; };
Ok(Arc::new(s)) Ok(Arc::new(s))
} }
async fn worker(self: Arc<Self>) -> Result<()> {
emerg_access::init_emergency_access();
Ok(())
}
fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
let bad_event_ratelimiter = self let bad_event_ratelimiter = self
.bad_event_ratelimiter .bad_event_ratelimiter

View File

@ -1,7 +1,7 @@
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use conduit::{utils, Error, Result}; use conduit::{utils, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{ use ruma::{
api::client::{ api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
@ -11,25 +11,34 @@ use ruma::{
OwnedRoomId, RoomId, UserId, OwnedRoomId, RoomId, UserId,
}; };
use crate::services; use crate::{globals, Dep};
pub(super) struct Data { pub(super) struct Data {
backupid_algorithm: Arc<Map>, backupid_algorithm: Arc<Map>,
backupid_etag: Arc<Map>, backupid_etag: Arc<Map>,
backupkeyid_backup: Arc<Map>, backupkeyid_backup: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
backupid_algorithm: db["backupid_algorithm"].clone(), backupid_algorithm: db["backupid_algorithm"].clone(),
backupid_etag: db["backupid_etag"].clone(), backupid_etag: db["backupid_etag"].clone(),
backupkeyid_backup: db["backupkeyid_backup"].clone(), backupkeyid_backup: db["backupkeyid_backup"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
} }
} }
pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string(); let version = self.services.globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
@ -40,7 +49,7 @@ impl Data {
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?; )?;
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(version) Ok(version)
} }
@ -75,7 +84,7 @@ impl Data {
self.backupid_algorithm self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?; .insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned()) Ok(version.to_owned())
} }
@ -152,7 +161,7 @@ impl Data {
} }
self.backupid_etag self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());

View File

@ -17,7 +17,7 @@ pub struct Service {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(&args),
})) }))
} }

View File

@ -8,13 +8,13 @@ use tokio::{
time::sleep, time::sleep,
}; };
use crate::{service::Service, Services}; use crate::{service, service::Service, Services};
pub(crate) struct Manager { pub(crate) struct Manager {
manager: Mutex<Option<JoinHandle<Result<()>>>>, manager: Mutex<Option<JoinHandle<Result<()>>>>,
workers: Mutex<Workers>, workers: Mutex<Workers>,
server: Arc<Server>, server: Arc<Server>,
services: &'static Services, service: Arc<service::Map>,
} }
type Workers = JoinSet<WorkerResult>; type Workers = JoinSet<WorkerResult>;
@ -29,7 +29,7 @@ impl Manager {
manager: Mutex::new(None), manager: Mutex::new(None),
workers: Mutex::new(JoinSet::new()), workers: Mutex::new(JoinSet::new()),
server: services.server.clone(), server: services.server.clone(),
services: crate::services(), service: services.service.clone(),
}) })
} }
@ -53,9 +53,19 @@ impl Manager {
.spawn(async move { self_.worker().await }), .spawn(async move { self_.worker().await }),
); );
// we can't hold the lock during the iteration with start_worker so the values
// are snapshotted here
let services: Vec<Arc<dyn Service>> = self
.service
.read()
.expect("locked for reading")
.values()
.map(|v| v.0.clone())
.collect();
debug!("Starting service workers..."); debug!("Starting service workers...");
for (service, ..) in self.services.service.values() { for service in services {
self.start_worker(&mut workers, service).await?; self.start_worker(&mut workers, &service).await?;
} }
Ok(()) Ok(())

View File

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::{debug, debug_info, Error, Result}; use conduit::{debug, debug_info, utils::string_from_bytes, Error, Result};
use database::{Database, Map}; use database::{Database, Map};
use ruma::api::client::error::ErrorKind; use ruma::api::client::error::ErrorKind;
use crate::{media::UrlPreviewData, utils::string_from_bytes}; use crate::media::UrlPreviewData;
pub(crate) struct Data { pub(crate) struct Data {
mediaid_file: Arc<Map>, mediaid_file: Arc<Map>,

View File

@ -15,7 +15,7 @@ use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufReader}, io::{AsyncReadExt, AsyncWriteExt, BufReader},
}; };
use crate::services; use crate::{globals, Dep};
#[derive(Debug)] #[derive(Debug)]
pub struct FileMeta { pub struct FileMeta {
@ -41,16 +41,24 @@ pub struct UrlPreviewData {
} }
pub struct Service { pub struct Service {
server: Arc<Server>, services: Services,
pub(crate) db: Data, pub(crate) db: Data,
pub url_preview_mutex: MutexMap<String, ()>, pub url_preview_mutex: MutexMap<String, ()>,
} }
struct Services {
server: Arc<Server>,
globals: Dep<globals::Service>,
}
#[async_trait] #[async_trait]
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
server: args.server.clone(), services: Services {
server: args.server.clone(),
globals: args.depend::<globals::Service>("globals"),
},
db: Data::new(args.db), db: Data::new(args.db),
url_preview_mutex: MutexMap::new(), url_preview_mutex: MutexMap::new(),
})) }))
@ -164,7 +172,7 @@ impl Service {
debug!("Parsed MXC key to URL: {mxc_s}"); debug!("Parsed MXC key to URL: {mxc_s}");
let mxc = OwnedMxcUri::from(mxc_s); let mxc = OwnedMxcUri::from(mxc_s);
if mxc.server_name() == Ok(services().globals.server_name()) { if mxc.server_name() == Ok(self.services.globals.server_name()) {
debug!("Ignoring local media MXC: {mxc}"); debug!("Ignoring local media MXC: {mxc}");
// ignore our own MXC URLs as this would be local media. // ignore our own MXC URLs as this would be local media.
continue; continue;
@ -246,7 +254,7 @@ impl Service {
let legacy_rm = fs::remove_file(&legacy); let legacy_rm = fs::remove_file(&legacy);
let (file_rm, legacy_rm) = tokio::join!(file_rm, legacy_rm); let (file_rm, legacy_rm) = tokio::join!(file_rm, legacy_rm);
if let Err(e) = legacy_rm { if let Err(e) = legacy_rm {
if self.server.config.media_compat_file_link { if self.services.server.config.media_compat_file_link {
debug_error!(?key, ?legacy, "Failed to remove legacy media symlink: {e}"); debug_error!(?key, ?legacy, "Failed to remove legacy media symlink: {e}");
} }
} }
@ -259,7 +267,7 @@ impl Service {
debug!(?key, ?path, "Creating media file"); debug!(?key, ?path, "Creating media file");
let file = fs::File::create(&path).await?; let file = fs::File::create(&path).await?;
if self.server.config.media_compat_file_link { if self.services.server.config.media_compat_file_link {
let legacy = self.get_media_file_b64(key); let legacy = self.get_media_file_b64(key);
if let Err(e) = fs::symlink(&path, &legacy).await { if let Err(e) = fs::symlink(&path, &legacy).await {
debug_error!( debug_error!(
@ -304,7 +312,7 @@ impl Service {
#[must_use] #[must_use]
pub fn get_media_dir(&self) -> PathBuf { pub fn get_media_dir(&self) -> PathBuf {
let mut r = PathBuf::new(); let mut r = PathBuf::new();
r.push(self.server.config.database_path.clone()); r.push(self.services.server.config.database_path.clone());
r.push("media"); r.push("media");
r r
} }

View File

@ -1,3 +1,4 @@
#![recursion_limit = "160"]
#![allow(refining_impl_trait)] #![allow(refining_impl_trait)]
mod manager; mod manager;
@ -8,6 +9,7 @@ pub mod account_data;
pub mod admin; pub mod admin;
pub mod appservice; pub mod appservice;
pub mod client; pub mod client;
pub mod emergency;
pub mod globals; pub mod globals;
pub mod key_backups; pub mod key_backups;
pub mod media; pub mod media;
@ -26,8 +28,8 @@ extern crate conduit_database as database;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
pub(crate) use conduit::{config, debug_error, debug_warn, utils, Error, Result, Server};
pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; pub use conduit::{pdu, PduBuilder, PduCount, PduEvent};
use conduit::{Result, Server};
use database::Database; use database::Database;
pub(crate) use service::{Args, Dep, Service}; pub(crate) use service::{Args, Dep, Service};

View File

@ -1,21 +1,32 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::{debug_warn, utils, Error, Result}; use conduit::{debug_warn, utils, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use crate::{presence::Presence, services}; use crate::{globals, presence::Presence, users, Dep};
pub struct Data { pub struct Data {
presenceid_presence: Arc<Map>, presenceid_presence: Arc<Map>,
userid_presenceid: Arc<Map>, userid_presenceid: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
presenceid_presence: db["presenceid_presence"].clone(), presenceid_presence: db["presenceid_presence"].clone(),
userid_presenceid: db["userid_presenceid"].clone(), userid_presenceid: db["userid_presenceid"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
} }
} }
@ -28,7 +39,10 @@ impl Data {
self.presenceid_presence self.presenceid_presence
.get(&key)? .get(&key)?
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> { .map(|presence_bytes| -> Result<(u64, PresenceEvent)> {
Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?)) Ok((
count,
Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?,
))
}) })
.transpose() .transpose()
} else { } else {
@ -80,7 +94,7 @@ impl Data {
last_active_ts, last_active_ts,
status_msg, status_msg,
); );
let count = services().globals.next_count()?; let count = self.services.globals.next_count()?;
let key = presenceid_key(count, user_id); let key = presenceid_key(count, user_id);
self.presenceid_presence self.presenceid_presence

View File

@ -3,8 +3,7 @@ mod data;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use async_trait::async_trait; use async_trait::async_trait;
use conduit::{checked, debug, error, utils, Error, Result}; use conduit::{checked, debug, error, utils, Error, Result, Server};
use data::Data;
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
events::presence::{PresenceEvent, PresenceEventContent}, events::presence::{PresenceEvent, PresenceEventContent},
@ -14,7 +13,8 @@ use ruma::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::{sync::Mutex, time::sleep}; use tokio::{sync::Mutex, time::sleep};
use crate::{services, user_is_local}; use self::data::Data;
use crate::{user_is_local, users, Dep};
/// Represents data required to be kept in order to implement the presence /// Represents data required to be kept in order to implement the presence
/// specification. /// specification.
@ -37,11 +37,6 @@ impl Presence {
} }
} }
pub fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
let presence = Self::from_json_bytes(bytes)?;
presence.to_presence_event(user_id)
}
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> { pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database")) serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
} }
@ -51,7 +46,7 @@ impl Presence {
} }
/// Creates a PresenceEvent from available data. /// Creates a PresenceEvent from available data.
pub fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> { pub fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> Result<PresenceEvent> {
let now = utils::millis_since_unix_epoch(); let now = utils::millis_since_unix_epoch();
let last_active_ago = if self.currently_active { let last_active_ago = if self.currently_active {
None None
@ -66,14 +61,15 @@ impl Presence {
status_msg: self.status_msg.clone(), status_msg: self.status_msg.clone(),
currently_active: Some(self.currently_active), currently_active: Some(self.currently_active),
last_active_ago, last_active_ago,
displayname: services().users.displayname(user_id)?, displayname: users.displayname(user_id)?,
avatar_url: services().users.avatar_url(user_id)?, avatar_url: users.avatar_url(user_id)?,
}, },
}) })
} }
} }
pub struct Service { pub struct Service {
services: Services,
pub db: Data, pub db: Data,
pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>,
timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>, timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>,
@ -82,6 +78,11 @@ pub struct Service {
offline_timeout: u64, offline_timeout: u64,
} }
struct Services {
server: Arc<Server>,
users: Dep<users::Service>,
}
#[async_trait] #[async_trait]
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@ -90,7 +91,11 @@ impl crate::Service for Service {
let offline_timeout_s = config.presence_offline_timeout_s; let offline_timeout_s = config.presence_offline_timeout_s;
let (timer_sender, timer_receiver) = loole::unbounded(); let (timer_sender, timer_receiver) = loole::unbounded();
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), services: Services {
server: args.server.clone(),
users: args.depend::<users::Service>("users"),
},
db: Data::new(&args),
timer_sender, timer_sender,
timer_receiver: Mutex::new(timer_receiver), timer_receiver: Mutex::new(timer_receiver),
timeout_remote_users: config.presence_timeout_remote_users, timeout_remote_users: config.presence_timeout_remote_users,
@ -182,8 +187,8 @@ impl Service {
if self.timeout_remote_users || user_is_local(user_id) { if self.timeout_remote_users || user_is_local(user_id) {
let timeout = match presence_state { let timeout = match presence_state {
PresenceState::Online => services().globals.config.presence_idle_timeout_s, PresenceState::Online => self.services.server.config.presence_idle_timeout_s,
_ => services().globals.config.presence_offline_timeout_s, _ => self.services.server.config.presence_offline_timeout_s,
}; };
self.timer_sender self.timer_sender
@ -210,6 +215,11 @@ impl Service {
self.db.presence_since(since) self.db.presence_since(since)
} }
pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
let presence = Presence::from_json_bytes(bytes)?;
presence.to_presence_event(user_id, &self.services.users)
}
fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> {
let mut presence_state = PresenceState::Offline; let mut presence_state = PresenceState::Offline;
let mut last_active_ago = None; let mut last_active_ago = None;

View File

@ -3,8 +3,7 @@ mod data;
use std::{fmt::Debug, mem, sync::Arc}; use std::{fmt::Debug, mem, sync::Arc};
use bytes::BytesMut; use bytes::BytesMut;
use conduit::{debug_info, info, trace, warn, Error, Result}; use conduit::{debug_info, info, trace, utils::string_from_bytes, warn, Error, PduEvent, Result};
use data::Data;
use ipaddress::IPAddress; use ipaddress::IPAddress;
use ruma::{ use ruma::{
api::{ api::{
@ -23,15 +22,32 @@ use ruma::{
uint, RoomId, UInt, UserId, uint, RoomId, UInt, UserId,
}; };
use crate::{services, PduEvent}; use self::data::Data;
use crate::{client, globals, rooms, users, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
} }
struct Services {
globals: Dep<globals::Service>,
client: Dep<client::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
users: Dep<users::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
services: Services {
globals: args.depend::<globals::Service>("globals"),
client: args.depend::<client::Service>("client"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(args.db), db: Data::new(args.db),
})) }))
} }
@ -62,7 +78,7 @@ impl Service {
{ {
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0]; const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0];
let dest = dest.replace(services().globals.notification_push_path(), ""); let dest = dest.replace(self.services.globals.notification_push_path(), "");
trace!("Push gateway destination: {dest}"); trace!("Push gateway destination: {dest}");
let http_request = request let http_request = request
@ -78,13 +94,13 @@ impl Service {
if let Some(url_host) = reqwest_request.url().host_str() { if let Some(url_host) = reqwest_request.url().host_str() {
trace!("Checking request URL for IP"); trace!("Checking request URL for IP");
if let Ok(ip) = IPAddress::parse(url_host) { if let Ok(ip) = IPAddress::parse(url_host) {
if !services().globals.valid_cidr_range(&ip) { if !self.services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
} }
} }
} }
let response = services().client.pusher.execute(reqwest_request).await; let response = self.services.client.pusher.execute(reqwest_request).await;
match response { match response {
Ok(mut response) => { Ok(mut response) => {
@ -93,7 +109,7 @@ impl Service {
trace!("Checking response destination's IP"); trace!("Checking response destination's IP");
if let Some(remote_addr) = response.remote_addr() { if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !services().globals.valid_cidr_range(&ip) { if !self.services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
} }
} }
@ -114,7 +130,7 @@ impl Service {
if !status.is_success() { if !status.is_success() {
info!("Push gateway {dest} returned unsuccessful HTTP response ({status})"); info!("Push gateway {dest} returned unsuccessful HTTP response ({status})");
debug_info!("Push gateway response body: {:?}", crate::utils::string_from_bytes(&body)); debug_info!("Push gateway response body: {:?}", string_from_bytes(&body));
return Err(Error::BadServerResponse("Push gateway returned unsuccessful response")); return Err(Error::BadServerResponse("Push gateway returned unsuccessful response"));
} }
@ -143,8 +159,8 @@ impl Service {
let mut notify = None; let mut notify = None;
let mut tweaks = Vec::new(); let mut tweaks = Vec::new();
let power_levels: RoomPowerLevelsEventContent = services() let power_levels: RoomPowerLevelsEventContent = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| { .map(|ev| {
@ -195,15 +211,15 @@ impl Service {
let ctx = PushConditionRoomCtx { let ctx = PushConditionRoomCtx {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
member_count: UInt::try_from( member_count: UInt::try_from(
services() self.services
.rooms
.state_cache .state_cache
.room_joined_count(room_id)? .room_joined_count(room_id)?
.unwrap_or(1), .unwrap_or(1),
) )
.unwrap_or_else(|_| uint!(0)), .unwrap_or_else(|_| uint!(0)),
user_id: user.to_owned(), user_id: user.to_owned(),
user_display_name: services() user_display_name: self
.services
.users .users
.displayname(user)? .displayname(user)?
.unwrap_or_else(|| user.localpart().to_owned()), .unwrap_or_else(|| user.localpart().to_owned()),
@ -263,9 +279,9 @@ impl Service {
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
} }
notifi.sender_display_name = services().users.displayname(&event.sender)?; notifi.sender_display_name = self.services.users.displayname(&event.sender)?;
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?;
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?; .await?;

View File

@ -9,12 +9,9 @@ use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
use ipaddress::IPAddress; use ipaddress::IPAddress;
use ruma::ServerName; use ruma::ServerName;
use crate::{ use crate::resolver::{
resolver::{ cache::{CachedDest, CachedOverride},
cache::{CachedDest, CachedOverride}, fed::{add_port_to_hostname, get_ip_with_port, FedDest},
fed::{add_port_to_hostname, get_ip_with_port, FedDest},
},
services,
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -40,7 +37,7 @@ impl super::Service {
result result
} else { } else {
cached = false; cached = false;
validate_dest(server_name)?; self.validate_dest(server_name)?;
self.resolve_actual_dest(server_name, true).await? self.resolve_actual_dest(server_name, true).await?
}; };
@ -188,7 +185,8 @@ impl super::Service {
self.query_and_cache_override(dest, dest, 8448).await?; self.query_and_cache_override(dest, dest, 8448).await?;
} }
let response = services() let response = self
.services
.client .client
.well_known .well_known
.get(&format!("https://{dest}/.well-known/matrix/server")) .get(&format!("https://{dest}/.well-known/matrix/server"))
@ -245,19 +243,14 @@ impl super::Service {
#[tracing::instrument(skip_all, name = "ip")] #[tracing::instrument(skip_all, name = "ip")]
async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
match services() match self.raw().lookup_ip(hostname.to_owned()).await {
.resolver Err(e) => Self::handle_resolve_error(&e),
.raw()
.lookup_ip(hostname.to_owned())
.await
{
Err(e) => handle_resolve_error(&e),
Ok(override_ip) => { Ok(override_ip) => {
if hostname != overname { if hostname != overname {
debug_info!("{overname:?} overriden by {hostname:?}"); debug_info!("{overname:?} overriden by {hostname:?}");
} }
services().resolver.set_cached_override( self.set_cached_override(
overname.to_owned(), overname.to_owned(),
CachedOverride { CachedOverride {
ips: override_ip.iter().collect(), ips: override_ip.iter().collect(),
@ -295,62 +288,62 @@ impl super::Service {
for hostname in hostnames { for hostname in hostnames {
match lookup_srv(self.raw(), &hostname).await { match lookup_srv(self.raw(), &hostname).await {
Ok(result) => return Ok(handle_successful_srv(&result)), Ok(result) => return Ok(handle_successful_srv(&result)),
Err(e) => handle_resolve_error(&e)?, Err(e) => Self::handle_resolve_error(&e)?,
} }
} }
Ok(None) Ok(None)
} }
}
#[allow(clippy::single_match_else)] #[allow(clippy::single_match_else)]
fn handle_resolve_error(e: &ResolveError) -> Result<()> { fn handle_resolve_error(e: &ResolveError) -> Result<()> {
use hickory_resolver::error::ResolveErrorKind; use hickory_resolver::error::ResolveErrorKind;
match *e.kind() { match *e.kind() {
ResolveErrorKind::NoRecordsFound { ResolveErrorKind::NoRecordsFound {
.. ..
} => { } => {
// Raise to debug_warn if we can find out the result wasn't from cache // Raise to debug_warn if we can find out the result wasn't from cache
debug!("{e}"); debug!("{e}");
Ok(()) Ok(())
}, },
_ => Err!(error!("DNS {e}")), _ => Err!(error!("DNS {e}")),
} }
}
fn validate_dest(dest: &ServerName) -> Result<()> {
if dest == services().globals.server_name() {
return Err!("Won't send federation request to ourselves");
} }
if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { fn validate_dest(&self, dest: &ServerName) -> Result<()> {
validate_dest_ip_literal(dest)?; if dest == self.services.server.config.server_name {
return Err!("Won't send federation request to ourselves");
}
if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) {
self.validate_dest_ip_literal(dest)?;
}
Ok(())
} }
Ok(()) fn validate_dest_ip_literal(&self, dest: &ServerName) -> Result<()> {
} trace!("Destination is an IP literal, checking against IP range denylist.",);
debug_assert!(
dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
"Destination is not an IP literal."
);
let ip = IPAddress::parse(dest.host()).map_err(|e| {
debug_error!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> { self.validate_ip(&ip)?;
trace!("Destination is an IP literal, checking against IP range denylist.",);
debug_assert!(
dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
"Destination is not an IP literal."
);
let ip = IPAddress::parse(dest.host()).map_err(|e| {
debug_error!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
validate_ip(&ip)?; Ok(())
Ok(())
}
pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> {
if !services().globals.valid_cidr_range(ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
} }
Ok(()) pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result<()> {
if !self.services.globals.valid_cidr_range(ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
Ok(())
}
} }

View File

@ -5,11 +5,10 @@ use std::{
time::SystemTime, time::SystemTime,
}; };
use conduit::trace; use conduit::{trace, utils::rand};
use ruma::{OwnedServerName, ServerName}; use ruma::{OwnedServerName, ServerName};
use super::fed::FedDest; use super::fed::FedDest;
use crate::utils::rand;
pub struct Cache { pub struct Cache {
pub destinations: RwLock<WellKnownMap>, // actual_destination, host pub destinations: RwLock<WellKnownMap>, // actual_destination, host

View File

@ -6,14 +6,22 @@ mod tests;
use std::{fmt::Write, sync::Arc}; use std::{fmt::Write, sync::Arc};
use conduit::Result; use conduit::{Result, Server};
use hickory_resolver::TokioAsyncResolver; use hickory_resolver::TokioAsyncResolver;
use self::{cache::Cache, dns::Resolver}; use self::{cache::Cache, dns::Resolver};
use crate::{client, globals, Dep};
pub struct Service { pub struct Service {
pub cache: Arc<Cache>, pub cache: Arc<Cache>,
pub resolver: Arc<Resolver>, pub resolver: Arc<Resolver>,
services: Services,
}
struct Services {
server: Arc<Server>,
client: Dep<client::Service>,
globals: Dep<globals::Service>,
} }
impl crate::Service for Service { impl crate::Service for Service {
@ -23,6 +31,11 @@ impl crate::Service for Service {
Ok(Arc::new(Self { Ok(Arc::new(Self {
cache: cache.clone(), cache: cache.clone(),
resolver: Resolver::build(args.server, cache)?, resolver: Resolver::build(args.server, cache)?,
services: Services {
server: args.server.clone(),
client: args.depend::<client::Service>("client"),
globals: args.depend::<globals::Service>("globals"),
},
})) }))
} }

View File

@ -1,23 +1,32 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::{utils, Error, Result}; use conduit::{utils, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId};
use crate::services; use crate::{globals, Dep};
pub(super) struct Data { pub(super) struct Data {
alias_userid: Arc<Map>, alias_userid: Arc<Map>,
alias_roomid: Arc<Map>, alias_roomid: Arc<Map>,
aliasid_alias: Arc<Map>, aliasid_alias: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
alias_userid: db["alias_userid"].clone(), alias_userid: db["alias_userid"].clone(),
alias_roomid: db["alias_roomid"].clone(), alias_roomid: db["alias_roomid"].clone(),
aliasid_alias: db["aliasid_alias"].clone(), aliasid_alias: db["aliasid_alias"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
} }
} }
@ -31,7 +40,7 @@ impl Data {
let mut aliasid = room_id.as_bytes().to_vec(); let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF); aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(()) Ok(())

View File

@ -4,9 +4,8 @@ mod remote;
use std::sync::Arc; use std::sync::Arc;
use conduit::{err, Error, Result}; use conduit::{err, Error, Result};
use data::Data;
use ruma::{ use ruma::{
api::{appservice, client::error::ErrorKind}, api::client::error::ErrorKind,
events::{ events::{
room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent},
StateEventType, StateEventType,
@ -14,16 +13,33 @@ use ruma::{
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId,
}; };
use crate::{appservice::RegistrationInfo, server_is_ours, services}; use self::data::Data;
use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, server_is_ours, Dep};
pub struct Service { pub struct Service {
db: Data, db: Data,
services: Services,
}
struct Services {
admin: Dep<admin::Service>,
appservice: Dep<appservice::Service>,
globals: Dep<globals::Service>,
sending: Dep<sending::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
} }
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(&args),
services: Services {
admin: args.depend::<admin::Service>("admin"),
appservice: args.depend::<appservice::Service>("appservice"),
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
},
})) }))
} }
@ -33,7 +49,7 @@ impl crate::Service for Service {
impl Service { impl Service {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
if alias == services().globals.admin_alias && user_id != services().globals.server_user { if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user {
Err(Error::BadRequest( Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"Only the server user can set this alias", "Only the server user can set this alias",
@ -72,10 +88,10 @@ impl Service {
if !server_is_ours(room_alias.server_name()) if !server_is_ours(room_alias.server_name())
&& (!servers && (!servers
.as_ref() .as_ref()
.is_some_and(|servers| servers.contains(&services().globals.server_name().to_owned())) .is_some_and(|servers| servers.contains(&self.services.globals.server_name().to_owned()))
|| servers.as_ref().is_none()) || servers.as_ref().is_none())
{ {
return remote::resolve(room_alias, servers).await; return self.remote_resolve(room_alias, servers).await;
} }
let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias)? { let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias)? {
@ -111,7 +127,7 @@ impl Service {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found."));
}; };
let server_user = &services().globals.server_user; let server_user = &self.services.globals.server_user;
// The creator of an alias can remove it // The creator of an alias can remove it
if self if self
@ -119,7 +135,7 @@ impl Service {
.who_created_alias(alias)? .who_created_alias(alias)?
.is_some_and(|user| user == user_id) .is_some_and(|user| user == user_id)
// Server admins can remove any local alias // Server admins can remove any local alias
|| services().admin.user_is_admin(user_id).await? || self.services.admin.user_is_admin(user_id).await?
// Always allow the server service account to remove the alias, since there may not be an admin room // Always allow the server service account to remove the alias, since there may not be an admin room
|| server_user == user_id || server_user == user_id
{ {
@ -127,8 +143,7 @@ impl Service {
// Checking whether the user is able to change canonical aliases of the // Checking whether the user is able to change canonical aliases of the
// room // room
} else if let Some(event) = } else if let Some(event) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")?
{ {
@ -140,8 +155,7 @@ impl Service {
// If there is no power levels event, only the room creator can change // If there is no power levels event, only the room creator can change
// canonical aliases // canonical aliases
} else if let Some(event) = } else if let Some(event) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")? .room_state_get(&room_id, &StateEventType::RoomCreate, "")?
{ {
@ -152,14 +166,16 @@ impl Service {
} }
async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
for appservice in services().appservice.read().await.values() { use ruma::api::appservice::query::query_room_alias;
for appservice in self.services.appservice.read().await.values() {
if appservice.aliases.is_match(room_alias.as_str()) if appservice.aliases.is_match(room_alias.as_str())
&& matches!( && matches!(
services() self.services
.sending .sending
.send_appservice_request( .send_appservice_request(
appservice.registration.clone(), appservice.registration.clone(),
appservice::query::query_room_alias::v1::Request { query_room_alias::v1::Request {
room_alias: room_alias.to_owned(), room_alias: room_alias.to_owned(),
}, },
) )
@ -167,10 +183,7 @@ impl Service {
Ok(Some(_opt_result)) Ok(Some(_opt_result))
) { ) {
return Ok(Some( return Ok(Some(
services() self.resolve_local_alias(room_alias)?
.rooms
.alias
.resolve_local_alias(room_alias)?
.ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?,
)); ));
} }
@ -178,20 +191,27 @@ impl Service {
Ok(None) Ok(None)
} }
}
pub async fn appservice_checks(room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>) -> Result<()> { pub async fn appservice_checks(
if !server_is_ours(room_alias.server_name()) { &self, room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>,
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); ) -> Result<()> {
} if !server_is_ours(room_alias.server_name()) {
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
if let Some(ref info) = appservice_info {
if !info.aliases.is_match(room_alias.as_str()) {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
} }
} else if services().appservice.is_exclusive_alias(room_alias).await {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice."));
}
Ok(()) if let Some(ref info) = appservice_info {
if !info.aliases.is_match(room_alias.as_str()) {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
}
} else if self
.services
.appservice
.is_exclusive_alias(room_alias)
.await
{
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice."));
}
Ok(())
}
} }

View File

@ -1,71 +1,75 @@
use conduit::{debug, debug_info, debug_warn, Error, Result}; use conduit::{debug, debug_warn, Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation}, api::{client::error::ErrorKind, federation},
OwnedRoomId, OwnedServerName, RoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId,
}; };
use crate::services; impl super::Service {
pub(super) async fn remote_resolve(
&self, room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>,
) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> {
debug!(?room_alias, ?servers, "resolve");
pub(super) async fn resolve( let mut response = self
room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>, .services
) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> { .sending
debug!(?room_alias, ?servers, "resolve"); .send_federation_request(
room_alias.server_name(),
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
let mut response = services() debug!("room alias server_name get_alias_helper response: {response:?}");
.sending
.send_federation_request(
room_alias.server_name(),
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
debug!("room alias server_name get_alias_helper response: {response:?}"); if let Err(ref e) = response {
debug_warn!(
"Server {} of the original room alias failed to assist in resolving room alias: {e}",
room_alias.server_name(),
);
}
if let Err(ref e) = response { if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() {
debug_info!( if let Some(servers) = servers {
"Server {} of the original room alias failed to assist in resolving room alias: {e}", for server in servers {
room_alias.server_name() response = self
); .services
} .sending
.send_federation_request(
server,
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
debug!("Got response from server {server} for room aliases: {response:?}");
if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() { if let Ok(ref response) = response {
if let Some(servers) = servers { if !response.servers.is_empty() {
for server in servers { break;
response = services() }
.sending debug_warn!(
.send_federation_request( "Server {server} responded with room aliases, but was empty? Response: {response:?}"
server, );
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
debug!("Got response from server {server} for room aliases: {response:?}");
if let Ok(ref response) = response {
if !response.servers.is_empty() {
break;
} }
debug_warn!("Server {server} responded with room aliases, but was empty? Response: {response:?}");
} }
} }
} }
if let Ok(response) = response {
let room_id = response.room_id;
let mut pre_servers = response.servers;
// since the room alis server responded, insert it into the list
pre_servers.push(room_alias.server_name().into());
return Ok((room_id, Some(pre_servers)));
}
Err(Error::BadRequest(
ErrorKind::NotFound,
"No servers could assist in resolving the room alias",
))
} }
if let Ok(response) = response {
let room_id = response.room_id;
let mut pre_servers = response.servers;
// since the room alis server responded, insert it into the list
pre_servers.push(room_alias.server_name().into());
return Ok((room_id, Some(pre_servers)));
}
Err(Error::BadRequest(
ErrorKind::NotFound,
"No servers could assist in resolving the room alias",
))
} }

View File

@ -3,8 +3,8 @@ use std::{
sync::{Arc, Mutex}, sync::{Arc, Mutex},
}; };
use conduit::{utils, utils::math::usize_from_f64, Result, Server}; use conduit::{utils, utils::math::usize_from_f64, Result};
use database::{Database, Map}; use database::Map;
use lru_cache::LruCache; use lru_cache::LruCache;
pub(super) struct Data { pub(super) struct Data {
@ -13,8 +13,9 @@ pub(super) struct Data {
} }
impl Data { impl Data {
pub(super) fn new(server: &Arc<Server>, db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let config = &server.config; let db = &args.db;
let config = &args.server.config;
let cache_size = f64::from(config.auth_chain_cache_capacity); let cache_size = f64::from(config.auth_chain_cache_capacity);
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size"); let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size");
Self { Self {

View File

@ -6,19 +6,29 @@ use std::{
}; };
use conduit::{debug, error, trace, validated, warn, Err, Result}; use conduit::{debug, error, trace, validated, warn, Err, Result};
use data::Data;
use ruma::{EventId, RoomId}; use ruma::{EventId, RoomId};
use crate::services; use self::data::Data;
use crate::{rooms, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
} }
struct Services {
short: Dep<rooms::short::Service>,
timeline: Dep<rooms::timeline::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.server, args.db), services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
})) }))
} }
@ -27,7 +37,7 @@ impl crate::Service for Service {
impl Service { impl Service {
pub async fn event_ids_iter<'a>( pub async fn event_ids_iter<'a>(
&self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>, &'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
for starting_event in &starting_events_ { for starting_event in &starting_events_ {
@ -38,7 +48,7 @@ impl Service {
.get_auth_chain(room_id, &starting_events) .get_auth_chain(room_id, &starting_events)
.await? .await?
.into_iter() .into_iter()
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) .filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok()))
} }
#[tracing::instrument(skip_all, name = "auth_chain")] #[tracing::instrument(skip_all, name = "auth_chain")]
@ -48,8 +58,8 @@ impl Service {
let started = std::time::Instant::now(); let started = std::time::Instant::now();
let mut buckets = [BUCKET; NUM_BUCKETS]; let mut buckets = [BUCKET; NUM_BUCKETS];
for (i, &short) in services() for (i, &short) in self
.rooms .services
.short .short
.multi_get_or_create_shorteventid(starting_events)? .multi_get_or_create_shorteventid(starting_events)?
.iter() .iter()
@ -140,7 +150,7 @@ impl Service {
while let Some(event_id) = todo.pop() { while let Some(event_id) = todo.pop() {
trace!(?event_id, "processing auth event"); trace!(?event_id, "processing auth event");
match services().rooms.timeline.get_pdu(&event_id) { match self.services.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => { Ok(Some(pdu)) => {
if pdu.room_id != room_id { if pdu.room_id != room_id {
return Err!(Request(Forbidden( return Err!(Request(Forbidden(
@ -150,10 +160,7 @@ impl Service {
))); )));
} }
for auth_event in &pdu.auth_events { for auth_event in &pdu.auth_events {
let sauthevent = services() let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?;
.rooms
.short
.get_or_create_shorteventid(auth_event)?;
if found.insert(sauthevent) { if found.insert(sauthevent) {
trace!(?event_id, ?auth_event, "adding auth event to processing queue"); trace!(?event_id, ?auth_event, "adding auth event to processing queue");

View File

@ -2,10 +2,10 @@ mod data;
use std::sync::Arc; use std::sync::Arc;
use data::Data; use conduit::Result;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use crate::Result; use self::data::Data;
pub struct Service { pub struct Service {
db: Data, db: Data,

View File

@ -10,12 +10,11 @@ use std::{
}; };
use conduit::{ use conduit::{
debug, debug_error, debug_info, err, error, info, trace, debug, debug_error, debug_info, err, error, info, pdu, trace,
utils::{math::continue_exponential_backoff_secs, MutexMap}, utils::{math::continue_exponential_backoff_secs, MutexMap},
warn, Error, Result, warn, Error, PduEvent, Result,
}; };
use futures_util::Future; use futures_util::Future;
pub use parse_incoming_pdu::parse_incoming_pdu;
use ruma::{ use ruma::{
api::{ api::{
client::error::ErrorKind, client::error::ErrorKind,
@ -36,13 +35,28 @@ use ruma::{
use tokio::sync::RwLock; use tokio::sync::RwLock;
use super::state_compressor::CompressedStateEvent; use super::state_compressor::CompressedStateEvent;
use crate::{pdu, services, PduEvent}; use crate::{globals, rooms, sending, Dep};
pub struct Service { pub struct Service {
services: Services,
pub federation_handletime: StdRwLock<HandleTimeMap>, pub federation_handletime: StdRwLock<HandleTimeMap>,
pub mutex_federation: RoomMutexMap, pub mutex_federation: RoomMutexMap,
} }
struct Services {
globals: Dep<globals::Service>,
sending: Dep<sending::Service>,
auth_chain: Dep<rooms::auth_chain::Service>,
metadata: Dep<rooms::metadata::Service>,
outlier: Dep<rooms::outlier::Service>,
pdu_metadata: Dep<rooms::pdu_metadata::Service>,
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
type RoomMutexMap = MutexMap<OwnedRoomId, ()>; type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
type HandleTimeMap = HashMap<OwnedRoomId, (OwnedEventId, Instant)>; type HandleTimeMap = HashMap<OwnedRoomId, (OwnedEventId, Instant)>;
@ -55,8 +69,21 @@ type AsyncRecursiveCanonicalJsonResult<'a> =
AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>; AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>;
impl crate::Service for Service { impl crate::Service for Service {
fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
services: Services {
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
auth_chain: args.depend::<rooms::auth_chain::Service>("rooms::auth_chain"),
metadata: args.depend::<rooms::metadata::Service>("rooms::metadata"),
outlier: args.depend::<rooms::outlier::Service>("rooms::outlier"),
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
federation_handletime: HandleTimeMap::new().into(), federation_handletime: HandleTimeMap::new().into(),
mutex_federation: RoomMutexMap::new(), mutex_federation: RoomMutexMap::new(),
})) }))
@ -114,17 +141,17 @@ impl Service {
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> { ) -> Result<Option<Vec<u8>>> {
// 1. Skip the PDU if we already have it as a timeline event // 1. Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? { if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? {
return Ok(Some(pdu_id.to_vec())); return Ok(Some(pdu_id.to_vec()));
} }
// 1.1 Check the server is in the room // 1.1 Check the server is in the room
if !services().rooms.metadata.exists(room_id)? { if !self.services.metadata.exists(room_id)? {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"));
} }
// 1.2 Check if the room is disabled // 1.2 Check if the room is disabled
if services().rooms.metadata.is_disabled(room_id)? { if self.services.metadata.is_disabled(room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"Federation of this room is currently disabled on this server.", "Federation of this room is currently disabled on this server.",
@ -147,8 +174,8 @@ impl Service {
self.acl_check(sender.server_name(), room_id)?; self.acl_check(sender.server_name(), room_id)?;
// Fetch create event // Fetch create event
let create_event = services() let create_event = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")? .room_state_get(room_id, &StateEventType::RoomCreate, "")?
.ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?;
@ -156,8 +183,8 @@ impl Service {
// Procure the room version // Procure the room version
let room_version_id = Self::get_room_version_id(&create_event)?; let room_version_id = Self::get_room_version_id(&create_event)?;
let first_pdu_in_room = services() let first_pdu_in_room = self
.rooms .services
.timeline .timeline
.first_pdu_in_room(room_id)? .first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
@ -208,7 +235,8 @@ impl Service {
Ok(()) => continue, Ok(()) => continue,
Err(e) => { Err(e) => {
warn!("Prev event {} failed: {}", prev_id, e); warn!("Prev event {} failed: {}", prev_id, e);
match services() match self
.services
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.write() .write()
@ -258,7 +286,7 @@ impl Service {
create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId, create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
) -> Result<()> { ) -> Result<()> {
// Check for disabled again because it might have changed // Check for disabled again because it might have changed
if services().rooms.metadata.is_disabled(room_id)? { if self.services.metadata.is_disabled(room_id)? {
debug!( debug!(
"Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \
event ID {event_id}" event ID {event_id}"
@ -269,7 +297,8 @@ impl Service {
)); ));
} }
if let Some((time, tries)) = services() if let Some((time, tries)) = self
.services
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .read()
@ -349,7 +378,7 @@ impl Service {
}; };
// Skip the PDU if it is redacted and we already have it as an outlier event // Skip the PDU if it is redacted and we already have it as an outlier event
if services().rooms.timeline.get_pdu_json(event_id)?.is_some() { if self.services.timeline.get_pdu_json(event_id)?.is_some() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Event was redacted and we already knew about it", "Event was redacted and we already knew about it",
@ -401,7 +430,7 @@ impl Service {
// Build map of auth events // Build map of auth events
let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len());
for id in &incoming_pdu.auth_events { for id in &incoming_pdu.auth_events {
let Some(auth_event) = services().rooms.timeline.get_pdu(id)? else { let Some(auth_event) = self.services.timeline.get_pdu(id)? else {
warn!("Could not find auth event {}", id); warn!("Could not find auth event {}", id);
continue; continue;
}; };
@ -454,8 +483,7 @@ impl Service {
trace!("Validation successful."); trace!("Validation successful.");
// 7. Persist the event as an outlier. // 7. Persist the event as an outlier.
services() self.services
.rooms
.outlier .outlier
.add_pdu_outlier(&incoming_pdu.event_id, &val)?; .add_pdu_outlier(&incoming_pdu.event_id, &val)?;
@ -470,12 +498,12 @@ impl Service {
origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> { ) -> Result<Option<Vec<u8>>> {
// Skip the PDU if we already have it as a timeline event // Skip the PDU if we already have it as a timeline event
if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) { if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) {
return Ok(Some(pduid.to_vec())); return Ok(Some(pduid.to_vec()));
} }
if services() if self
.rooms .services
.pdu_metadata .pdu_metadata
.is_event_soft_failed(&incoming_pdu.event_id)? .is_event_soft_failed(&incoming_pdu.event_id)?
{ {
@ -521,14 +549,13 @@ impl Service {
&incoming_pdu, &incoming_pdu,
None::<PduEvent>, // TODO: third party invite None::<PduEvent>, // TODO: third party invite
|k, s| { |k, s| {
services() self.services
.rooms
.short .short
.get_shortstatekey(&k.to_string().into(), s) .get_shortstatekey(&k.to_string().into(), s)
.ok() .ok()
.flatten() .flatten()
.and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey))
.and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) .and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten())
}, },
) )
.map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?; .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?;
@ -541,7 +568,7 @@ impl Service {
} }
debug!("Gathering auth events"); debug!("Gathering auth events");
let auth_events = services().rooms.state.get_auth_events( let auth_events = self.services.state.get_auth_events(
room_id, room_id,
&incoming_pdu.kind, &incoming_pdu.kind,
&incoming_pdu.sender, &incoming_pdu.sender,
@ -562,7 +589,7 @@ impl Service {
&& match room_version_id { && match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &incoming_pdu.redacts { if let Some(redact_id) = &incoming_pdu.redacts {
!services().rooms.state_accessor.user_can_redact( !self.services.state_accessor.user_can_redact(
redact_id, redact_id,
&incoming_pdu.sender, &incoming_pdu.sender,
&incoming_pdu.room_id, &incoming_pdu.room_id,
@ -577,7 +604,7 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?;
if let Some(redact_id) = &content.redacts { if let Some(redact_id) = &content.redacts {
!services().rooms.state_accessor.user_can_redact( !self.services.state_accessor.user_can_redact(
redact_id, redact_id,
&incoming_pdu.sender, &incoming_pdu.sender,
&incoming_pdu.room_id, &incoming_pdu.room_id,
@ -594,12 +621,12 @@ impl Service {
// We start looking at current room state now, so lets lock the room // We start looking at current room state now, so lets lock the room
trace!("Locking the room"); trace!("Locking the room");
let state_lock = services().rooms.state.mutex.lock(room_id).await; let state_lock = self.services.state.mutex.lock(room_id).await;
// Now we calculate the set of extremities this room has after the incoming // Now we calculate the set of extremities this room has after the incoming
// event has been applied. We start with the previous extremities (aka leaves) // event has been applied. We start with the previous extremities (aka leaves)
trace!("Calculating extremities"); trace!("Calculating extremities");
let mut extremities = services().rooms.state.get_forward_extremities(room_id)?; let mut extremities = self.services.state.get_forward_extremities(room_id)?;
trace!("Calculated {} extremities", extremities.len()); trace!("Calculated {} extremities", extremities.len());
// Remove any forward extremities that are referenced by this incoming event's // Remove any forward extremities that are referenced by this incoming event's
@ -609,22 +636,13 @@ impl Service {
} }
// Only keep those extremities were not referenced yet // Only keep those extremities were not referenced yet
extremities.retain(|id| { extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true)));
!matches!(
services()
.rooms
.pdu_metadata
.is_event_referenced(room_id, id),
Ok(true)
)
});
debug!("Retained {} extremities. Compressing state", extremities.len()); debug!("Retained {} extremities. Compressing state", extremities.len());
let state_ids_compressed = Arc::new( let state_ids_compressed = Arc::new(
state_at_incoming_event state_at_incoming_event
.iter() .iter()
.map(|(shortstatekey, id)| { .map(|(shortstatekey, id)| {
services() self.services
.rooms
.state_compressor .state_compressor
.compress_state_event(*shortstatekey, id) .compress_state_event(*shortstatekey, id)
}) })
@ -637,8 +655,8 @@ impl Service {
// We also add state after incoming event to the fork states // We also add state after incoming event to the fork states
let mut state_after = state_at_incoming_event.clone(); let mut state_after = state_at_incoming_event.clone();
if let Some(state_key) = &incoming_pdu.state_key { if let Some(state_key) = &incoming_pdu.state_key {
let shortstatekey = services() let shortstatekey = self
.rooms .services
.short .short
.get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?;
@ -651,13 +669,12 @@ impl Service {
// Set the new room state to the resolved state // Set the new room state to the resolved state
debug!("Forcing new room state"); debug!("Forcing new room state");
let (sstatehash, new, removed) = services() let (sstatehash, new, removed) = self
.rooms .services
.state_compressor .state_compressor
.save_state(room_id, new_room_state)?; .save_state(room_id, new_room_state)?;
services() self.services
.rooms
.state .state
.force_state(room_id, sstatehash, new, removed, &state_lock) .force_state(room_id, sstatehash, new, removed, &state_lock)
.await?; .await?;
@ -667,8 +684,7 @@ impl Service {
// if not soft fail it // if not soft fail it
if soft_fail { if soft_fail {
debug!("Soft failing event"); debug!("Soft failing event");
services() self.services
.rooms
.timeline .timeline
.append_incoming_pdu( .append_incoming_pdu(
&incoming_pdu, &incoming_pdu,
@ -682,8 +698,7 @@ impl Service {
// Soft fail, we keep the event as an outlier but don't add it to the timeline // Soft fail, we keep the event as an outlier but don't add it to the timeline
warn!("Event was soft failed: {:?}", incoming_pdu); warn!("Event was soft failed: {:?}", incoming_pdu);
services() self.services
.rooms
.pdu_metadata .pdu_metadata
.mark_event_soft_failed(&incoming_pdu.event_id)?; .mark_event_soft_failed(&incoming_pdu.event_id)?;
@ -696,8 +711,8 @@ impl Service {
// Now that the event has passed all auth it is added into the timeline. // Now that the event has passed all auth it is added into the timeline.
// We use the `state_at_event` instead of `state_after` so we accurately // We use the `state_at_event` instead of `state_after` so we accurately
// represent the state for this event. // represent the state for this event.
let pdu_id = services() let pdu_id = self
.rooms .services
.timeline .timeline
.append_incoming_pdu( .append_incoming_pdu(
&incoming_pdu, &incoming_pdu,
@ -723,14 +738,14 @@ impl Service {
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>, &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> { ) -> Result<Arc<HashSet<CompressedStateEvent>>> {
debug!("Loading current room state ids"); debug!("Loading current room state ids");
let current_sstatehash = services() let current_sstatehash = self
.rooms .services
.state .state
.get_room_shortstatehash(room_id)? .get_room_shortstatehash(room_id)?
.expect("every room has state"); .expect("every room has state");
let current_state_ids = services() let current_state_ids = self
.rooms .services
.state_accessor .state_accessor
.state_full_ids(current_sstatehash) .state_full_ids(current_sstatehash)
.await?; .await?;
@ -740,8 +755,7 @@ impl Service {
let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); let mut auth_chain_sets = Vec::with_capacity(fork_states.len());
for state in &fork_states { for state in &fork_states {
auth_chain_sets.push( auth_chain_sets.push(
services() self.services
.rooms
.auth_chain .auth_chain
.event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect())
.await? .await?
@ -755,8 +769,7 @@ impl Service {
.map(|map| { .map(|map| {
map.into_iter() map.into_iter()
.filter_map(|(k, id)| { .filter_map(|(k, id)| {
services() self.services
.rooms
.short .short
.get_statekey_from_short(k) .get_statekey_from_short(k)
.map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id))
@ -766,11 +779,11 @@ impl Service {
}) })
.collect(); .collect();
let lock = services().globals.stateres_mutex.lock(); let lock = self.services.globals.stateres_mutex.lock();
debug!("Resolving state"); debug!("Resolving state");
let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
let res = services().rooms.timeline.get_pdu(id); let res = self.services.timeline.get_pdu(id);
if let Err(e) = &res { if let Err(e) = &res {
error!("Failed to fetch event: {}", e); error!("Failed to fetch event: {}", e);
} }
@ -793,12 +806,11 @@ impl Service {
let new_room_state = state let new_room_state = state
.into_iter() .into_iter()
.map(|((event_type, state_key), event_id)| { .map(|((event_type, state_key), event_id)| {
let shortstatekey = services() let shortstatekey = self
.rooms .services
.short .short
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
services() self.services
.rooms
.state_compressor .state_compressor
.compress_state_event(shortstatekey, &event_id) .compress_state_event(shortstatekey, &event_id)
}) })
@ -814,15 +826,14 @@ impl Service {
&self, incoming_pdu: &Arc<PduEvent>, &self, incoming_pdu: &Arc<PduEvent>,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> { ) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
let prev_event = &*incoming_pdu.prev_events[0]; let prev_event = &*incoming_pdu.prev_events[0];
let prev_event_sstatehash = services() let prev_event_sstatehash = self
.rooms .services
.state_accessor .state_accessor
.pdu_shortstatehash(prev_event)?; .pdu_shortstatehash(prev_event)?;
let state = if let Some(shortstatehash) = prev_event_sstatehash { let state = if let Some(shortstatehash) = prev_event_sstatehash {
Some( Some(
services() self.services
.rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await, .await,
@ -833,8 +844,8 @@ impl Service {
if let Some(Ok(mut state)) = state { if let Some(Ok(mut state)) = state {
debug!("Using cached state"); debug!("Using cached state");
let prev_pdu = services() let prev_pdu = self
.rooms .services
.timeline .timeline
.get_pdu(prev_event) .get_pdu(prev_event)
.ok() .ok()
@ -842,8 +853,8 @@ impl Service {
.ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?;
if let Some(state_key) = &prev_pdu.state_key { if let Some(state_key) = &prev_pdu.state_key {
let shortstatekey = services() let shortstatekey = self
.rooms .services
.short .short
.get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?;
@ -866,13 +877,13 @@ impl Service {
let mut okay = true; let mut okay = true;
for prev_eventid in &incoming_pdu.prev_events { for prev_eventid in &incoming_pdu.prev_events {
let Ok(Some(prev_event)) = services().rooms.timeline.get_pdu(prev_eventid) else { let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else {
okay = false; okay = false;
break; break;
}; };
let Ok(Some(sstatehash)) = services() let Ok(Some(sstatehash)) = self
.rooms .services
.state_accessor .state_accessor
.pdu_shortstatehash(prev_eventid) .pdu_shortstatehash(prev_eventid)
else { else {
@ -891,15 +902,15 @@ impl Service {
let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len());
for (sstatehash, prev_event) in extremity_sstatehashes { for (sstatehash, prev_event) in extremity_sstatehashes {
let mut leaf_state: HashMap<_, _> = services() let mut leaf_state: HashMap<_, _> = self
.rooms .services
.state_accessor .state_accessor
.state_full_ids(sstatehash) .state_full_ids(sstatehash)
.await?; .await?;
if let Some(state_key) = &prev_event.state_key { if let Some(state_key) = &prev_event.state_key {
let shortstatekey = services() let shortstatekey = self
.rooms .services
.short .short
.get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?;
leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id));
@ -910,7 +921,7 @@ impl Service {
let mut starting_events = Vec::with_capacity(leaf_state.len()); let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in leaf_state { for (k, id) in leaf_state {
if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) { if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) {
// FIXME: Undo .to_string().into() when StateMap // FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType // is updated to use StateEventType
state.insert((ty.to_string().into(), st_key), id.clone()); state.insert((ty.to_string().into(), st_key), id.clone());
@ -921,8 +932,7 @@ impl Service {
} }
auth_chain_sets.push( auth_chain_sets.push(
services() self.services
.rooms
.auth_chain .auth_chain
.event_ids_iter(room_id, starting_events) .event_ids_iter(room_id, starting_events)
.await? .await?
@ -932,9 +942,9 @@ impl Service {
fork_states.push(state); fork_states.push(state);
} }
let lock = services().globals.stateres_mutex.lock(); let lock = self.services.globals.stateres_mutex.lock();
let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
let res = services().rooms.timeline.get_pdu(id); let res = self.services.timeline.get_pdu(id);
if let Err(e) = &res { if let Err(e) = &res {
error!("Failed to fetch event: {}", e); error!("Failed to fetch event: {}", e);
} }
@ -947,8 +957,8 @@ impl Service {
new_state new_state
.into_iter() .into_iter()
.map(|((event_type, state_key), event_id)| { .map(|((event_type, state_key), event_id)| {
let shortstatekey = services() let shortstatekey = self
.rooms .services
.short .short
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
Ok((shortstatekey, event_id)) Ok((shortstatekey, event_id))
@ -974,7 +984,8 @@ impl Service {
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, event_id: &EventId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, event_id: &EventId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> { ) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Fetching state ids"); debug!("Fetching state ids");
match services() match self
.services
.sending .sending
.send_federation_request( .send_federation_request(
origin, origin,
@ -1004,8 +1015,8 @@ impl Service {
.clone() .clone()
.ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?;
let shortstatekey = services() let shortstatekey = self
.rooms .services
.short .short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?;
@ -1022,8 +1033,8 @@ impl Service {
} }
// The original create event must still be in the state // The original create event must still be in the state
let create_shortstatekey = services() let create_shortstatekey = self
.rooms .services
.short .short
.get_shortstatekey(&StateEventType::RoomCreate, "")? .get_shortstatekey(&StateEventType::RoomCreate, "")?
.expect("Room exists"); .expect("Room exists");
@ -1056,7 +1067,8 @@ impl Service {
) -> AsyncRecursiveCanonicalJsonVec<'a> { ) -> AsyncRecursiveCanonicalJsonVec<'a> {
Box::pin(async move { Box::pin(async move {
let back_off = |id| async { let back_off = |id| async {
match services() match self
.services
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.write() .write()
@ -1075,7 +1087,7 @@ impl Service {
// a. Look in the main timeline (pduid_pdu tree) // a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree // b. Look at outlier pdu tree
// (get_pdu_json checks both) // (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) {
trace!("Found {} in db", id); trace!("Found {} in db", id);
events_with_auth_events.push((id, Some(local_pdu), vec![])); events_with_auth_events.push((id, Some(local_pdu), vec![]));
continue; continue;
@ -1089,7 +1101,8 @@ impl Service {
let mut events_all = HashSet::with_capacity(todo_auth_events.len()); let mut events_all = HashSet::with_capacity(todo_auth_events.len());
let mut i: u64 = 0; let mut i: u64 = 0;
while let Some(next_id) = todo_auth_events.pop() { while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services() if let Some((time, tries)) = self
.services
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .read()
@ -1114,13 +1127,14 @@ impl Service {
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id); trace!("Found {} in db", next_id);
continue; continue;
} }
debug!("Fetching {} over federation.", next_id); debug!("Fetching {} over federation.", next_id);
match services() match self
.services
.sending .sending
.send_federation_request( .send_federation_request(
origin, origin,
@ -1195,7 +1209,8 @@ impl Service {
pdus.push((local_pdu, None)); pdus.push((local_pdu, None));
} }
for (next_id, value) in events_in_reverse_order.iter().rev() { for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services() if let Some((time, tries)) = self
.services
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .read()
@ -1244,8 +1259,8 @@ impl Service {
let mut eventid_info = HashMap::new(); let mut eventid_info = HashMap::new();
let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set; let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set;
let first_pdu_in_room = services() let first_pdu_in_room = self
.rooms .services
.timeline .timeline
.first_pdu_in_room(room_id)? .first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
@ -1267,19 +1282,18 @@ impl Service {
{ {
Self::check_room_id(room_id, &pdu)?; Self::check_room_id(room_id, &pdu)?;
if amount > services().globals.max_fetch_prev_events() { if amount > self.services.globals.max_fetch_prev_events() {
// Max limit reached // Max limit reached
debug!( debug!(
"Max prev event limit reached! Limit: {}", "Max prev event limit reached! Limit: {}",
services().globals.max_fetch_prev_events() self.services.globals.max_fetch_prev_events()
); );
graph.insert(prev_event_id.clone(), HashSet::new()); graph.insert(prev_event_id.clone(), HashSet::new());
continue; continue;
} }
if let Some(json) = json_opt.or_else(|| { if let Some(json) = json_opt.or_else(|| {
services() self.services
.rooms
.outlier .outlier
.get_outlier_pdu_json(&prev_event_id) .get_outlier_pdu_json(&prev_event_id)
.ok() .ok()
@ -1335,8 +1349,7 @@ impl Service {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
let acl_event = if let Some(acl) = let acl_event = if let Some(acl) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomServerAcl, "")? .room_state_get(room_id, &StateEventType::RoomServerAcl, "")?
{ {

View File

@ -1,29 +1,28 @@
use conduit::{Err, Error, Result}; use conduit::{pdu::gen_event_id_canonical_json, warn, Err, Error, Result};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId};
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
use tracing::warn;
use crate::{pdu::gen_event_id_canonical_json, services}; impl super::Service {
pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {pdu:?}: {e:?}");
Error::BadServerResponse("Invalid PDU in server response")
})?;
pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { let room_id: OwnedRoomId = value
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { .get("room_id")
warn!("Error parsing incoming event {pdu:?}: {e:?}"); .and_then(|id| RoomId::parse(id.as_str()?).ok())
Error::BadServerResponse("Invalid PDU in server response") .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?;
})?;
let room_id: OwnedRoomId = value let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else {
.get("room_id") return Err!("Server is not in room {room_id}");
.and_then(|id| RoomId::parse(id.as_str()?).ok()) };
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?;
let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else { let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
return Err!("Server is not in room {room_id}"); // Event could not be converted to canonical json
}; return Err!(Request(InvalidParam("Could not convert event to canonical json.")));
};
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { Ok((event_id, value, room_id))
// Event could not be converted to canonical json }
return Err!(Request(InvalidParam("Could not convert event to canonical json.")));
};
Ok((event_id, value, room_id))
} }

View File

@ -3,7 +3,7 @@ use std::{
time::{Duration, SystemTime}, time::{Duration, SystemTime},
}; };
use conduit::{debug, error, info, trace, warn}; use conduit::{debug, error, info, trace, warn, Error, Result};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
api::federation::{ api::federation::{
@ -21,8 +21,6 @@ use ruma::{
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard}; use tokio::sync::{RwLock, RwLockWriteGuard};
use crate::{services, Error, Result};
impl super::Service { impl super::Service {
pub async fn fetch_required_signing_keys<'a, E>( pub async fn fetch_required_signing_keys<'a, E>(
&'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, &'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
@ -147,7 +145,8 @@ impl super::Service {
debug!("Loading signing keys for {}", origin); debug!("Loading signing keys for {}", origin);
let result: BTreeMap<_, _> = services() let result: BTreeMap<_, _> = self
.services
.globals .globals
.signing_keys_for(origin)? .signing_keys_for(origin)?
.into_iter() .into_iter()
@ -171,9 +170,10 @@ impl super::Service {
&self, mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>, &self, mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> { ) -> Result<()> {
for server in services().globals.trusted_servers() { for server in self.services.globals.trusted_servers() {
debug!("Asking batch signing keys from trusted server {}", server); debug!("Asking batch signing keys from trusted server {}", server);
match services() match self
.services
.sending .sending
.send_federation_request( .send_federation_request(
server, server,
@ -199,7 +199,8 @@ impl super::Service {
// TODO: Check signature from trusted server? // TODO: Check signature from trusted server?
servers.remove(&k.server_name); servers.remove(&k.server_name);
let result = services() let result = self
.services
.globals .globals
.db .db
.add_signing_key(&k.server_name, k.clone())? .add_signing_key(&k.server_name, k.clone())?
@ -234,7 +235,7 @@ impl super::Service {
.into_keys() .into_keys()
.map(|server| async move { .map(|server| async move {
( (
services() self.services
.sending .sending
.send_federation_request(&server, get_server_keys::v2::Request::new()) .send_federation_request(&server, get_server_keys::v2::Request::new())
.await, .await,
@ -248,7 +249,8 @@ impl super::Service {
if let (Ok(get_keys_response), origin) = result { if let (Ok(get_keys_response), origin) = result {
debug!("Result is from {origin}"); debug!("Result is from {origin}");
if let Ok(key) = get_keys_response.server_key.deserialize() { if let Ok(key) = get_keys_response.server_key.deserialize() {
let result: BTreeMap<_, _> = services() let result: BTreeMap<_, _> = self
.services
.globals .globals
.db .db
.add_signing_key(&origin, key)? .add_signing_key(&origin, key)?
@ -297,7 +299,7 @@ impl super::Service {
return Ok(()); return Ok(());
} }
if services().globals.query_trusted_key_servers_first() { if self.services.globals.query_trusted_key_servers_first() {
info!( info!(
"query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \
homeserver signing keys." homeserver signing keys."
@ -349,7 +351,8 @@ impl super::Service {
) -> Result<BTreeMap<String, Base64>> { ) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let mut result: BTreeMap<_, _> = services() let mut result: BTreeMap<_, _> = self
.services
.globals .globals
.signing_keys_for(origin)? .signing_keys_for(origin)?
.into_iter() .into_iter()
@ -362,15 +365,16 @@ impl super::Service {
} }
// i didnt split this out into their own functions because it's relatively small // i didnt split this out into their own functions because it's relatively small
if services().globals.query_trusted_key_servers_first() { if self.services.globals.query_trusted_key_servers_first() {
info!( info!(
"query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \ "query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \
keys" keys"
); );
for server in services().globals.trusted_servers() { for server in self.services.globals.trusted_servers() {
debug!("Asking notary server {server} for {origin}'s signing key"); debug!("Asking notary server {server} for {origin}'s signing key");
if let Some(server_keys) = services() if let Some(server_keys) = self
.services
.sending .sending
.send_federation_request( .send_federation_request(
server, server,
@ -394,7 +398,10 @@ impl super::Service {
}) { }) {
debug!("Got signing keys: {:?}", server_keys); debug!("Got signing keys: {:?}", server_keys);
for k in server_keys { for k in server_keys {
services().globals.db.add_signing_key(origin, k.clone())?; self.services
.globals
.db
.add_signing_key(origin, k.clone())?;
result.extend( result.extend(
k.verify_keys k.verify_keys
.into_iter() .into_iter()
@ -414,14 +421,15 @@ impl super::Service {
} }
debug!("Asking {origin} for their signing keys over federation"); debug!("Asking {origin} for their signing keys over federation");
if let Some(server_key) = services() if let Some(server_key) = self
.services
.sending .sending
.send_federation_request(origin, get_server_keys::v2::Request::new()) .send_federation_request(origin, get_server_keys::v2::Request::new())
.await .await
.ok() .ok()
.and_then(|resp| resp.server_key.deserialize().ok()) .and_then(|resp| resp.server_key.deserialize().ok())
{ {
services() self.services
.globals .globals
.db .db
.add_signing_key(origin, server_key.clone())?; .add_signing_key(origin, server_key.clone())?;
@ -447,14 +455,15 @@ impl super::Service {
info!("query_trusted_key_servers_first is set to false, querying {origin} first"); info!("query_trusted_key_servers_first is set to false, querying {origin} first");
debug!("Asking {origin} for their signing keys over federation"); debug!("Asking {origin} for their signing keys over federation");
if let Some(server_key) = services() if let Some(server_key) = self
.services
.sending .sending
.send_federation_request(origin, get_server_keys::v2::Request::new()) .send_federation_request(origin, get_server_keys::v2::Request::new())
.await .await
.ok() .ok()
.and_then(|resp| resp.server_key.deserialize().ok()) .and_then(|resp| resp.server_key.deserialize().ok())
{ {
services() self.services
.globals .globals
.db .db
.add_signing_key(origin, server_key.clone())?; .add_signing_key(origin, server_key.clone())?;
@ -477,9 +486,10 @@ impl super::Service {
} }
} }
for server in services().globals.trusted_servers() { for server in self.services.globals.trusted_servers() {
debug!("Asking notary server {server} for {origin}'s signing key"); debug!("Asking notary server {server} for {origin}'s signing key");
if let Some(server_keys) = services() if let Some(server_keys) = self
.services
.sending .sending
.send_federation_request( .send_federation_request(
server, server,
@ -503,7 +513,10 @@ impl super::Service {
}) { }) {
debug!("Got signing keys: {:?}", server_keys); debug!("Got signing keys: {:?}", server_keys);
for k in server_keys { for k in server_keys {
services().globals.db.add_signing_key(origin, k.clone())?; self.services
.globals
.db
.add_signing_key(origin, k.clone())?;
result.extend( result.extend(
k.verify_keys k.verify_keys
.into_iter() .into_iter()

View File

@ -6,10 +6,10 @@ use std::{
sync::{Arc, Mutex}, sync::{Arc, Mutex},
}; };
use data::Data; use conduit::{PduCount, Result};
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{PduCount, Result}; use self::data::Data;
pub struct Service { pub struct Service {
db: Data, db: Data,

View File

@ -1,30 +1,39 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::{error, utils, Error, Result}; use conduit::{error, utils, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use crate::services; use crate::{rooms, Dep};
pub(super) struct Data { pub(super) struct Data {
disabledroomids: Arc<Map>, disabledroomids: Arc<Map>,
bannedroomids: Arc<Map>, bannedroomids: Arc<Map>,
roomid_shortroomid: Arc<Map>, roomid_shortroomid: Arc<Map>,
pduid_pdu: Arc<Map>, pduid_pdu: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
disabledroomids: db["disabledroomids"].clone(), disabledroomids: db["disabledroomids"].clone(),
bannedroomids: db["bannedroomids"].clone(), bannedroomids: db["bannedroomids"].clone(),
roomid_shortroomid: db["roomid_shortroomid"].clone(), roomid_shortroomid: db["roomid_shortroomid"].clone(),
pduid_pdu: db["pduid_pdu"].clone(), pduid_pdu: db["pduid_pdu"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
} }
} }
pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> { pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? { let prefix = match self.services.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(), Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false), None => return Ok(false),
}; };

View File

@ -3,9 +3,10 @@ mod data;
use std::sync::Arc; use std::sync::Arc;
use conduit::Result; use conduit::Result;
use data::Data;
use ruma::{OwnedRoomId, RoomId}; use ruma::{OwnedRoomId, RoomId};
use self::data::Data;
pub struct Service { pub struct Service {
db: Data, db: Data,
} }
@ -13,7 +14,7 @@ pub struct Service {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(&args),
})) }))
} }

View File

@ -33,13 +33,13 @@ pub struct Service {
pub read_receipt: Arc<read_receipt::Service>, pub read_receipt: Arc<read_receipt::Service>,
pub search: Arc<search::Service>, pub search: Arc<search::Service>,
pub short: Arc<short::Service>, pub short: Arc<short::Service>,
pub spaces: Arc<spaces::Service>,
pub state: Arc<state::Service>, pub state: Arc<state::Service>,
pub state_accessor: Arc<state_accessor::Service>, pub state_accessor: Arc<state_accessor::Service>,
pub state_cache: Arc<state_cache::Service>, pub state_cache: Arc<state_cache::Service>,
pub state_compressor: Arc<state_compressor::Service>, pub state_compressor: Arc<state_compressor::Service>,
pub timeline: Arc<timeline::Service>,
pub threads: Arc<threads::Service>, pub threads: Arc<threads::Service>,
pub timeline: Arc<timeline::Service>,
pub typing: Arc<typing::Service>, pub typing: Arc<typing::Service>,
pub spaces: Arc<spaces::Service>,
pub user: Arc<user::Service>, pub user: Arc<user::Service>,
} }

View File

@ -1,26 +1,35 @@
use std::{mem::size_of, sync::Arc}; use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, Result}; use conduit::{utils, Error, PduCount, PduEvent, Result};
use database::{Database, Map}; use database::Map;
use ruma::{EventId, RoomId, UserId}; use ruma::{EventId, RoomId, UserId};
use crate::{services, PduCount, PduEvent}; use crate::{rooms, Dep};
pub(super) struct Data { pub(super) struct Data {
tofrom_relation: Arc<Map>, tofrom_relation: Arc<Map>,
referencedevents: Arc<Map>, referencedevents: Arc<Map>,
softfailedeventids: Arc<Map>, softfailedeventids: Arc<Map>,
services: Services,
}
struct Services {
timeline: Dep<rooms::timeline::Service>,
} }
type PdusIterItem = Result<(PduCount, PduEvent)>; type PdusIterItem = Result<(PduCount, PduEvent)>;
type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>; type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
tofrom_relation: db["tofrom_relation"].clone(), tofrom_relation: db["tofrom_relation"].clone(),
referencedevents: db["referencedevents"].clone(), referencedevents: db["referencedevents"].clone(),
softfailedeventids: db["softfailedeventids"].clone(), softfailedeventids: db["softfailedeventids"].clone(),
services: Services {
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
} }
} }
@ -57,8 +66,8 @@ impl Data {
let mut pduid = shortroomid.to_be_bytes().to_vec(); let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes()); pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = services() let mut pdu = self
.rooms .services
.timeline .timeline
.get_pdu_from_id(&pduid)? .get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;

View File

@ -2,8 +2,7 @@ mod data;
use std::sync::Arc; use std::sync::Arc;
use conduit::Result; use conduit::{PduCount, PduEvent, Result};
use data::Data;
use ruma::{ use ruma::{
api::{client::relations::get_relating_events, Direction}, api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType}, events::{relation::RelationType, TimelineEventType},
@ -11,12 +10,20 @@ use ruma::{
}; };
use serde::Deserialize; use serde::Deserialize;
use crate::{services, PduCount, PduEvent}; use self::data::Data;
use crate::{rooms, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
} }
struct Services {
short: Dep<rooms::short::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
struct ExtractRelType { struct ExtractRelType {
rel_type: RelationType, rel_type: RelationType,
@ -30,7 +37,12 @@ struct ExtractRelatesToEventId {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
})) }))
} }
@ -101,8 +113,7 @@ impl Service {
}) })
.take(limit) .take(limit)
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() self.services
.rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
@ -147,8 +158,7 @@ impl Service {
}) })
.take(limit) .take(limit)
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() self.services
.rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
@ -180,10 +190,10 @@ impl Service {
pub fn relations_until<'a>( pub fn relations_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8,
) -> Result<Vec<(PduCount, PduEvent)>> { ) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; let room_id = self.services.short.get_or_create_shortroomid(room_id)?;
#[allow(unknown_lints)] #[allow(unknown_lints)]
#[allow(clippy::manual_unwrap_or_default)] #[allow(clippy::manual_unwrap_or_default)]
let target = match services().rooms.timeline.get_pdu_count(target)? { let target = match self.services.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c, Some(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations // TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator _ => 0, // This will result in an empty iterator

View File

@ -1,14 +1,14 @@
use std::{mem::size_of, sync::Arc}; use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, Result}; use conduit::{utils, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{ use ruma::{
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
serde::Raw, serde::Raw,
CanonicalJsonObject, OwnedUserId, RoomId, UserId, CanonicalJsonObject, OwnedUserId, RoomId, UserId,
}; };
use crate::services; use crate::{globals, Dep};
type AnySyncEphemeralRoomEventIter<'a> = type AnySyncEphemeralRoomEventIter<'a> =
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>; Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
@ -16,15 +16,24 @@ type AnySyncEphemeralRoomEventIter<'a> =
pub(super) struct Data { pub(super) struct Data {
roomuserid_privateread: Arc<Map>, roomuserid_privateread: Arc<Map>,
roomuserid_lastprivatereadupdate: Arc<Map>, roomuserid_lastprivatereadupdate: Arc<Map>,
services: Services,
readreceiptid_readreceipt: Arc<Map>, readreceiptid_readreceipt: Arc<Map>,
} }
struct Services {
globals: Dep<globals::Service>,
}
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
roomuserid_privateread: db["roomuserid_privateread"].clone(), roomuserid_privateread: db["roomuserid_privateread"].clone(),
roomuserid_lastprivatereadupdate: db["roomuserid_lastprivatereadupdate"].clone(), roomuserid_lastprivatereadupdate: db["roomuserid_lastprivatereadupdate"].clone(),
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
} }
} }
@ -51,7 +60,7 @@ impl Data {
} }
let mut room_latest_id = prefix; let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF); room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes()); room_latest_id.extend_from_slice(user_id.as_bytes());
@ -108,7 +117,7 @@ impl Data {
.insert(&key, &count.to_be_bytes())?; .insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes()) .insert(&key, &self.services.globals.next_count()?.to_be_bytes())
} }
pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {

View File

@ -6,16 +6,24 @@ use conduit::Result;
use data::Data; use data::Data;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId};
use crate::services; use crate::{sending, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
} }
struct Services {
sending: Dep<sending::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), services: Services {
sending: args.depend::<sending::Service>("sending"),
},
db: Data::new(&args),
})) }))
} }
@ -26,7 +34,7 @@ impl Service {
/// Replaces the previous read receipt. /// Replaces the previous read receipt.
pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> {
self.db.readreceipt_update(user_id, room_id, event)?; self.db.readreceipt_update(user_id, room_id, event)?;
services().sending.flush_room(room_id)?; self.services.sending.flush_room(room_id)?;
Ok(()) Ok(())
} }

View File

@ -1,21 +1,30 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::{utils, Result}; use conduit::{utils, Result};
use database::{Database, Map}; use database::Map;
use ruma::RoomId; use ruma::RoomId;
use crate::services; use crate::{rooms, Dep};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>; type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
pub(super) struct Data { pub(super) struct Data {
tokenids: Arc<Map>, tokenids: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
tokenids: db["tokenids"].clone(), tokenids: db["tokenids"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
} }
} }
@ -51,8 +60,8 @@ impl Data {
} }
pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services() let prefix = self
.rooms .services
.short .short
.get_shortroomid(room_id)? .get_shortroomid(room_id)?
.expect("room exists") .expect("room exists")

View File

@ -13,7 +13,7 @@ pub struct Service {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(&args),
})) }))
} }

View File

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::{utils, warn, Error, Result}; use conduit::{utils, warn, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::services; use crate::{globals, Dep};
pub(super) struct Data { pub(super) struct Data {
eventid_shorteventid: Arc<Map>, eventid_shorteventid: Arc<Map>,
@ -13,10 +13,16 @@ pub(super) struct Data {
shortstatekey_statekey: Arc<Map>, shortstatekey_statekey: Arc<Map>,
roomid_shortroomid: Arc<Map>, roomid_shortroomid: Arc<Map>,
statehash_shortstatehash: Arc<Map>, statehash_shortstatehash: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
eventid_shorteventid: db["eventid_shorteventid"].clone(), eventid_shorteventid: db["eventid_shorteventid"].clone(),
shorteventid_eventid: db["shorteventid_eventid"].clone(), shorteventid_eventid: db["shorteventid_eventid"].clone(),
@ -24,6 +30,9 @@ impl Data {
shortstatekey_statekey: db["shortstatekey_statekey"].clone(), shortstatekey_statekey: db["shortstatekey_statekey"].clone(),
roomid_shortroomid: db["roomid_shortroomid"].clone(), roomid_shortroomid: db["roomid_shortroomid"].clone(),
statehash_shortstatehash: db["statehash_shortstatehash"].clone(), statehash_shortstatehash: db["statehash_shortstatehash"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
} }
} }
@ -31,7 +40,7 @@ impl Data {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else { } else {
let shorteventid = services().globals.next_count()?; let shorteventid = self.services.globals.next_count()?;
self.eventid_shorteventid self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid self.shorteventid_eventid
@ -59,7 +68,7 @@ impl Data {
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
), ),
None => { None => {
let short = services().globals.next_count()?; let short = self.services.globals.next_count()?;
self.eventid_shorteventid self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?; .insert(keys[i], &short.to_be_bytes())?;
self.shorteventid_eventid self.shorteventid_eventid
@ -98,7 +107,7 @@ impl Data {
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else { } else {
let shortstatekey = services().globals.next_count()?; let shortstatekey = self.services.globals.next_count()?;
self.statekey_shortstatekey self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?; .insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey self.shortstatekey_statekey
@ -158,7 +167,7 @@ impl Data {
true, true,
) )
} else { } else {
let shortstatehash = services().globals.next_count()?; let shortstatehash = self.services.globals.next_count()?;
self.statehash_shortstatehash self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?; .insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false) (shortstatehash, false)
@ -176,7 +185,7 @@ impl Data {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else { } else {
let short = services().globals.next_count()?; let short = self.services.globals.next_count()?;
self.roomid_shortroomid self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?; .insert(room_id.as_bytes(), &short.to_be_bytes())?;
short short

View File

@ -3,9 +3,10 @@ mod data;
use std::sync::Arc; use std::sync::Arc;
use conduit::Result; use conduit::Result;
use data::Data;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use self::data::Data;
pub struct Service { pub struct Service {
db: Data, db: Data,
} }
@ -13,7 +14,7 @@ pub struct Service {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(&args),
})) }))
} }

View File

@ -28,7 +28,7 @@ use ruma::{
}; };
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::services; use crate::{rooms, sending, Dep};
pub struct CachedSpaceHierarchySummary { pub struct CachedSpaceHierarchySummary {
summary: SpaceHierarchyParentSummary, summary: SpaceHierarchyParentSummary,
@ -119,42 +119,18 @@ enum Identifier<'a> {
} }
pub struct Service { pub struct Service {
services: Services,
pub roomid_spacehierarchy_cache: Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>, pub roomid_spacehierarchy_cache: Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
} }
// Here because cannot implement `From` across ruma-federation-api and struct Services {
// ruma-client-api types state_accessor: Dep<rooms::state_accessor::Service>,
impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk { state_cache: Dep<rooms::state_cache::Service>,
fn from(value: CachedSpaceHierarchySummary) -> Self { state: Dep<rooms::state::Service>,
let SpaceHierarchyParentSummary { short: Dep<rooms::short::Service>,
canonical_alias, event_handler: Dep<rooms::event_handler::Service>,
name, timeline: Dep<rooms::timeline::Service>,
num_joined_members, sending: Dep<sending::Service>,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
..
} = value.summary;
Self {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
}
}
} }
impl crate::Service for Service { impl crate::Service for Service {
@ -163,6 +139,15 @@ impl crate::Service for Service {
let cache_size = f64::from(config.roomid_spacehierarchy_cache_capacity); let cache_size = f64::from(config.roomid_spacehierarchy_cache_capacity);
let cache_size = cache_size * config.cache_capacity_modifier; let cache_size = cache_size * config.cache_capacity_modifier;
Ok(Arc::new(Self { Ok(Arc::new(Self {
services: Services {
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state: args.depend::<rooms::state::Service>("rooms::state"),
short: args.depend::<rooms::short::Service>("rooms::short"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
sending: args.depend::<sending::Service>("sending"),
},
roomid_spacehierarchy_cache: Mutex::new(LruCache::new(usize_from_f64(cache_size)?)), roomid_spacehierarchy_cache: Mutex::new(LruCache::new(usize_from_f64(cache_size)?)),
})) }))
} }
@ -226,7 +211,7 @@ impl Service {
.as_ref() .as_ref()
{ {
return Ok(if let Some(cached) = cached { return Ok(if let Some(cached) = cached {
if is_accessable_child( if self.is_accessible_child(
current_room, current_room,
&cached.summary.join_rule, &cached.summary.join_rule,
&identifier, &identifier,
@ -242,8 +227,8 @@ impl Service {
} }
Ok( Ok(
if let Some(children_pdus) = get_stripped_space_child_events(current_room).await? { if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? {
let summary = Self::get_room_summary(current_room, children_pdus, &identifier); let summary = self.get_room_summary(current_room, children_pdus, &identifier);
if let Ok(summary) = summary { if let Ok(summary) = summary {
self.roomid_spacehierarchy_cache.lock().await.insert( self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(), current_room.clone(),
@ -269,7 +254,8 @@ impl Service {
) -> Result<Option<SummaryAccessibility>> { ) -> Result<Option<SummaryAccessibility>> {
for server in via { for server in via {
debug_info!("Asking {server} for /hierarchy"); debug_info!("Asking {server} for /hierarchy");
let Ok(response) = services() let Ok(response) = self
.services
.sending .sending
.send_federation_request( .send_federation_request(
server, server,
@ -325,7 +311,10 @@ impl Service {
avatar_url, avatar_url,
join_rule, join_rule,
room_type, room_type,
children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), children_state: self
.get_stripped_space_child_events(&room_id)
.await?
.unwrap(),
allowed_room_ids, allowed_room_ids,
} }
}, },
@ -333,7 +322,7 @@ impl Service {
); );
} }
} }
if is_accessable_child( if self.is_accessible_child(
current_room, current_room,
&response.room.join_rule, &response.room.join_rule,
&Identifier::UserId(user_id), &Identifier::UserId(user_id),
@ -370,12 +359,13 @@ impl Service {
} }
fn get_room_summary( fn get_room_summary(
current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>, identifier: &Identifier<'_>, &self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>,
identifier: &Identifier<'_>,
) -> Result<SpaceHierarchyParentSummary, Error> { ) -> Result<SpaceHierarchyParentSummary, Error> {
let room_id: &RoomId = current_room; let room_id: &RoomId = current_room;
let join_rule = services() let join_rule = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")? .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| { .map(|s| {
@ -386,12 +376,12 @@ impl Service {
.transpose()? .transpose()?
.unwrap_or(JoinRule::Invite); .unwrap_or(JoinRule::Invite);
let allowed_room_ids = services() let allowed_room_ids = self
.rooms .services
.state_accessor .state_accessor
.allowed_room_ids(join_rule.clone()); .allowed_room_ids(join_rule.clone());
if !is_accessable_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) {
debug!("User is not allowed to see room {room_id}"); debug!("User is not allowed to see room {room_id}");
// This error will be caught later // This error will be caught later
return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room"));
@ -400,18 +390,18 @@ impl Service {
let join_rule = join_rule.into(); let join_rule = join_rule.into();
Ok(SpaceHierarchyParentSummary { Ok(SpaceHierarchyParentSummary {
canonical_alias: services() canonical_alias: self
.rooms .services
.state_accessor .state_accessor
.get_canonical_alias(room_id) .get_canonical_alias(room_id)
.unwrap_or(None), .unwrap_or(None),
name: services() name: self
.rooms .services
.state_accessor .state_accessor
.get_name(room_id) .get_name(room_id)
.unwrap_or(None), .unwrap_or(None),
num_joined_members: services() num_joined_members: self
.rooms .services
.state_cache .state_cache
.room_joined_count(room_id) .room_joined_count(room_id)
.unwrap_or_default() .unwrap_or_default()
@ -422,22 +412,22 @@ impl Service {
.try_into() .try_into()
.expect("user count should not be that big"), .expect("user count should not be that big"),
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
topic: services() topic: self
.rooms .services
.state_accessor .state_accessor
.get_room_topic(room_id) .get_room_topic(room_id)
.unwrap_or(None), .unwrap_or(None),
world_readable: services().rooms.state_accessor.is_world_readable(room_id)?, world_readable: self.services.state_accessor.is_world_readable(room_id)?,
guest_can_join: services().rooms.state_accessor.guest_can_join(room_id)?, guest_can_join: self.services.state_accessor.guest_can_join(room_id)?,
avatar_url: services() avatar_url: self
.rooms .services
.state_accessor .state_accessor
.get_avatar(room_id)? .get_avatar(room_id)?
.into_option() .into_option()
.unwrap_or_default() .unwrap_or_default()
.url, .url,
join_rule, join_rule,
room_type: services().rooms.state_accessor.get_room_type(room_id)?, room_type: self.services.state_accessor.get_room_type(room_id)?,
children_state, children_state,
allowed_room_ids, allowed_room_ids,
}) })
@ -487,7 +477,7 @@ impl Service {
.into_iter() .into_iter()
.rev() .rev()
.skip_while(|(room, _)| { .skip_while(|(room, _)| {
if let Ok(short) = services().rooms.short.get_shortroomid(room) if let Ok(short) = self.services.short.get_shortroomid(room)
{ {
short.as_ref() != short_room_ids.get(parents.len()) short.as_ref() != short_room_ids.get(parents.len())
} else { } else {
@ -541,7 +531,7 @@ impl Service {
let mut short_room_ids = vec![]; let mut short_room_ids = vec![];
for room in parents { for room in parents {
short_room_ids.push(services().rooms.short.get_or_create_shortroomid(&room)?); short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?);
} }
Some( Some(
@ -559,128 +549,152 @@ impl Service {
rooms: results, rooms: results,
}) })
} }
}
fn next_room_to_traverse( /// Simply returns the stripped m.space.child events of a room
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>, async fn get_stripped_space_child_events(
) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> { &self, room_id: &RoomId,
while stack.last().map_or(false, Vec::is_empty) { ) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
stack.pop(); let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? else {
parents.pop_back(); return Ok(None);
};
let state = self
.services
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let mut children_pdus = Vec::new();
for (key, id) in state {
let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?;
if event_type != StateEventType::SpaceChild {
continue;
}
let pdu = self
.services
.timeline
.get_pdu(&id)?
.ok_or_else(|| Error::bad_database("Event in space state not found"))?;
if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get())
.ok()
.map(|c| c.via)
.map_or(true, |v| v.is_empty())
{
continue;
}
if OwnedRoomId::try_from(state_key).is_ok() {
children_pdus.push(pdu.to_stripped_spacechild_state_event());
}
}
Ok(Some(children_pdus))
} }
stack.last_mut().and_then(Vec::pop) /// With the given identifier, checks if a room is accessable
} fn is_accessible_child(
&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
// Note: unwrap_or_default for bool means false
match identifier {
Identifier::ServerName(server_name) => {
let room_id: &RoomId = current_room;
/// Simply returns the stripped m.space.child events of a room // Checks if ACLs allow for the server to participate
async fn get_stripped_space_child_events( if self
room_id: &RoomId, .services
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> { .event_handler
let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { .acl_check(server_name, room_id)
return Ok(None); .is_err()
}; {
return false;
let state = services()
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let mut children_pdus = Vec::new();
for (key, id) in state {
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?;
if event_type != StateEventType::SpaceChild {
continue;
}
let pdu = services()
.rooms
.timeline
.get_pdu(&id)?
.ok_or_else(|| Error::bad_database("Event in space state not found"))?;
if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get())
.ok()
.map(|c| c.via)
.map_or(true, |v| v.is_empty())
{
continue;
}
if OwnedRoomId::try_from(state_key).is_ok() {
children_pdus.push(pdu.to_stripped_spacechild_state_event());
}
}
Ok(Some(children_pdus))
}
/// With the given identifier, checks if a room is accessable
fn is_accessable_child(
current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
// Note: unwrap_or_default for bool means false
match identifier {
Identifier::ServerName(server_name) => {
let room_id: &RoomId = current_room;
// Checks if ACLs allow for the server to participate
if services()
.rooms
.event_handler
.acl_check(server_name, room_id)
.is_err()
{
return false;
}
},
Identifier::UserId(user_id) => {
if services()
.rooms
.state_cache
.is_joined(user_id, current_room)
.unwrap_or_default()
|| services()
.rooms
.state_cache
.is_invited(user_id, current_room)
.unwrap_or_default()
{
return true;
}
},
} // Takes care of join rules
match join_rule {
SpaceRoomJoinRule::Restricted => {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
if services()
.rooms
.state_cache
.is_joined(user, room)
.unwrap_or_default()
{
return true;
}
},
Identifier::ServerName(server) => {
if services()
.rooms
.state_cache
.server_in_room(server, room)
.unwrap_or_default()
{
return true;
}
},
} }
} },
false Identifier::UserId(user_id) => {
}, if self
SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, .services
// Invite only, Private, or Custom join rule .state_cache
_ => false, .is_joined(user_id, current_room)
.unwrap_or_default()
|| self
.services
.state_cache
.is_invited(user_id, current_room)
.unwrap_or_default()
{
return true;
}
},
} // Takes care of join rules
match join_rule {
SpaceRoomJoinRule::Restricted => {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
if self
.services
.state_cache
.is_joined(user, room)
.unwrap_or_default()
{
return true;
}
},
Identifier::ServerName(server) => {
if self
.services
.state_cache
.server_in_room(server, room)
.unwrap_or_default()
{
return true;
}
},
}
}
false
},
SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
// Invite only, Private, or Custom join rule
_ => false,
}
}
}
// Here because cannot implement `From` across ruma-federation-api and
// ruma-client-api types
impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk {
fn from(value: CachedSpaceHierarchySummary) -> Self {
let SpaceHierarchyParentSummary {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
..
} = value.summary;
Self {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
}
} }
} }
@ -736,3 +750,14 @@ fn get_parent_children_via(
}) })
.collect() .collect()
} }
fn next_room_to_traverse(
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>,
) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> {
while stack.last().map_or(false, Vec::is_empty) {
stack.pop();
parents.pop_back();
}
stack.last_mut().and_then(Vec::pop)
}

View File

@ -8,7 +8,7 @@ use std::{
use conduit::{ use conduit::{
utils::{calculate_hash, MutexMap, MutexMapGuard}, utils::{calculate_hash, MutexMap, MutexMapGuard},
warn, Error, Result, warn, Error, PduEvent, Result,
}; };
use data::Data; use data::Data;
use ruma::{ use ruma::{
@ -23,19 +23,39 @@ use ruma::{
}; };
use super::state_compressor::CompressedStateEvent; use super::state_compressor::CompressedStateEvent;
use crate::{services, PduEvent}; use crate::{globals, rooms, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
pub mutex: RoomMutexMap, pub mutex: RoomMutexMap,
} }
struct Services {
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
spaces: Dep<rooms::spaces::Service>,
state_cache: Dep<rooms::state_cache::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
type RoomMutexMap = MutexMap<OwnedRoomId, ()>; type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>; pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
services: Services {
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(args.db), db: Data::new(args.db),
mutex: RoomMutexMap::new(), mutex: RoomMutexMap::new(),
})) }))
@ -62,14 +82,13 @@ impl Service {
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
for event_id in statediffnew.iter().filter_map(|new| { for event_id in statediffnew.iter().filter_map(|new| {
services() self.services
.rooms
.state_compressor .state_compressor
.parse_compressed_state_event(new) .parse_compressed_state_event(new)
.ok() .ok()
.map(|(_, id)| id) .map(|(_, id)| id)
}) { }) {
let Some(pdu) = services().rooms.timeline.get_pdu_json(&event_id)? else { let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else {
continue; continue;
}; };
@ -94,7 +113,7 @@ impl Service {
continue; continue;
}; };
services().rooms.state_cache.update_membership( self.services.state_cache.update_membership(
room_id, room_id,
&user_id, &user_id,
membership_event, membership_event,
@ -105,8 +124,7 @@ impl Service {
)?; )?;
}, },
TimelineEventType::SpaceChild => { TimelineEventType::SpaceChild => {
services() self.services
.rooms
.spaces .spaces
.roomid_spacehierarchy_cache .roomid_spacehierarchy_cache
.lock() .lock()
@ -117,7 +135,7 @@ impl Service {
} }
} }
services().rooms.state_cache.update_joined_count(room_id)?; self.services.state_cache.update_joined_count(room_id)?;
self.db self.db
.set_room_state(room_id, shortstatehash, state_lock)?; .set_room_state(room_id, shortstatehash, state_lock)?;
@ -133,10 +151,7 @@ impl Service {
pub fn set_event_state( pub fn set_event_state(
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<u64> { ) -> Result<u64> {
let shorteventid = services() let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?;
.rooms
.short
.get_or_create_shorteventid(event_id)?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
@ -147,20 +162,15 @@ impl Service {
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); );
let (shortstatehash, already_existed) = services() let (shortstatehash, already_existed) = self
.rooms .services
.short .short
.get_or_create_shortstatehash(&state_hash)?; .get_or_create_shortstatehash(&state_hash)?;
if !already_existed { if !already_existed {
let states_parents = previous_shortstatehash.map_or_else( let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()), || Ok(Vec::new()),
|p| { |p| self.services.state_compressor.load_shortstatehash_info(p),
services()
.rooms
.state_compressor
.load_shortstatehash_info(p)
},
)?; )?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
@ -179,7 +189,7 @@ impl Service {
} else { } else {
(state_ids_compressed, Arc::new(HashSet::new())) (state_ids_compressed, Arc::new(HashSet::new()))
}; };
services().rooms.state_compressor.save_state_from_diff( self.services.state_compressor.save_state_from_diff(
shortstatehash, shortstatehash,
statediffnew, statediffnew,
statediffremoved, statediffremoved,
@ -199,8 +209,8 @@ impl Service {
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, new_pdu), level = "debug")] #[tracing::instrument(skip(self, new_pdu), level = "debug")]
pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
let shorteventid = services() let shorteventid = self
.rooms .services
.short .short
.get_or_create_shorteventid(&new_pdu.event_id)?; .get_or_create_shorteventid(&new_pdu.event_id)?;
@ -214,21 +224,16 @@ impl Service {
let states_parents = previous_shortstatehash.map_or_else( let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()), || Ok(Vec::new()),
#[inline] #[inline]
|p| { |p| self.services.state_compressor.load_shortstatehash_info(p),
services()
.rooms
.state_compressor
.load_shortstatehash_info(p)
},
)?; )?;
let shortstatekey = services() let shortstatekey = self
.rooms .services
.short .short
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?;
let new = services() let new = self
.rooms .services
.state_compressor .state_compressor
.compress_state_event(shortstatekey, &new_pdu.event_id)?; .compress_state_event(shortstatekey, &new_pdu.event_id)?;
@ -246,7 +251,7 @@ impl Service {
} }
// TODO: statehash with deterministic inputs // TODO: statehash with deterministic inputs
let shortstatehash = services().globals.next_count()?; let shortstatehash = self.services.globals.next_count()?;
let mut statediffnew = HashSet::new(); let mut statediffnew = HashSet::new();
statediffnew.insert(new); statediffnew.insert(new);
@ -256,7 +261,7 @@ impl Service {
statediffremoved.insert(*replaces); statediffremoved.insert(*replaces);
} }
services().rooms.state_compressor.save_state_from_diff( self.services.state_compressor.save_state_from_diff(
shortstatehash, shortstatehash,
Arc::new(statediffnew), Arc::new(statediffnew),
Arc::new(statediffremoved), Arc::new(statediffremoved),
@ -275,22 +280,20 @@ impl Service {
let mut state = Vec::new(); let mut state = Vec::new();
// Add recommended events // Add recommended events
if let Some(e) = if let Some(e) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = if let Some(e) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = services().rooms.state_accessor.room_state_get( if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id, &invite_event.room_id,
&StateEventType::RoomCanonicalAlias, &StateEventType::RoomCanonicalAlias,
"", "",
@ -298,22 +301,20 @@ impl Service {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = if let Some(e) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = if let Some(e) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = services().rooms.state_accessor.room_state_get( if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id, &invite_event.room_id,
&StateEventType::RoomMember, &StateEventType::RoomMember,
invite_event.sender.as_str(), invite_event.sender.as_str(),
@ -339,8 +340,8 @@ impl Service {
/// Returns the room's version. /// Returns the room's version.
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
let create_event = services() let create_event = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?; .room_state_get(room_id, &StateEventType::RoomCreate, "")?;
@ -393,8 +394,7 @@ impl Service {
let mut sauthevents = auth_events let mut sauthevents = auth_events
.into_iter() .into_iter()
.filter_map(|(event_type, state_key)| { .filter_map(|(event_type, state_key)| {
services() self.services
.rooms
.short .short
.get_shortstatekey(&event_type.to_string().into(), &state_key) .get_shortstatekey(&event_type.to_string().into(), &state_key)
.ok() .ok()
@ -403,8 +403,8 @@ impl Service {
}) })
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let full_state = services() let full_state = self
.rooms .services
.state_compressor .state_compressor
.load_shortstatehash_info(shortstatehash)? .load_shortstatehash_info(shortstatehash)?
.pop() .pop()
@ -414,16 +414,14 @@ impl Service {
Ok(full_state Ok(full_state
.iter() .iter()
.filter_map(|compressed| { .filter_map(|compressed| {
services() self.services
.rooms
.state_compressor .state_compressor
.parse_compressed_state_event(compressed) .parse_compressed_state_event(compressed)
.ok() .ok()
}) })
.filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id)))
.filter_map(|(k, event_id)| { .filter_map(|(k, event_id)| {
services() self.services
.rooms
.timeline .timeline
.get_pdu(&event_id) .get_pdu(&event_id)
.ok() .ok()

View File

@ -1,28 +1,43 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use conduit::{utils, Error, Result}; use conduit::{utils, Error, PduEvent, Result};
use database::{Database, Map}; use database::Map;
use ruma::{events::StateEventType, EventId, RoomId}; use ruma::{events::StateEventType, EventId, RoomId};
use crate::{services, PduEvent}; use crate::{rooms, Dep};
pub(super) struct Data { pub(super) struct Data {
eventid_shorteventid: Arc<Map>, eventid_shorteventid: Arc<Map>,
shorteventid_shortstatehash: Arc<Map>, shorteventid_shortstatehash: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
timeline: Dep<rooms::timeline::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
eventid_shorteventid: db["eventid_shorteventid"].clone(), eventid_shorteventid: db["eventid_shorteventid"].clone(),
shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(), shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
} }
} }
#[allow(unused_qualifications)] // async traits #[allow(unused_qualifications)] // async traits
pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services() let full_state = self
.rooms .services
.state_compressor .state_compressor
.load_shortstatehash_info(shortstatehash)? .load_shortstatehash_info(shortstatehash)?
.pop() .pop()
@ -31,8 +46,8 @@ impl Data {
let mut result = HashMap::new(); let mut result = HashMap::new();
let mut i: u8 = 0; let mut i: u8 = 0;
for compressed in full_state.iter() { for compressed in full_state.iter() {
let parsed = services() let parsed = self
.rooms .services
.state_compressor .state_compressor
.parse_compressed_state_event(compressed)?; .parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1); result.insert(parsed.0, parsed.1);
@ -49,8 +64,8 @@ impl Data {
pub(super) async fn state_full( pub(super) async fn state_full(
&self, shortstatehash: u64, &self, shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services() let full_state = self
.rooms .services
.state_compressor .state_compressor
.load_shortstatehash_info(shortstatehash)? .load_shortstatehash_info(shortstatehash)?
.pop() .pop()
@ -60,11 +75,11 @@ impl Data {
let mut result = HashMap::new(); let mut result = HashMap::new();
let mut i: u8 = 0; let mut i: u8 = 0;
for compressed in full_state.iter() { for compressed in full_state.iter() {
let (_, eventid) = services() let (_, eventid) = self
.rooms .services
.state_compressor .state_compressor
.parse_compressed_state_event(compressed)?; .parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? {
result.insert( result.insert(
( (
pdu.kind.to_string().into(), pdu.kind.to_string().into(),
@ -92,15 +107,15 @@ impl Data {
pub(super) fn state_get_id( pub(super) fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services() let Some(shortstatekey) = self
.rooms .services
.short .short
.get_shortstatekey(event_type, state_key)? .get_shortstatekey(event_type, state_key)?
else { else {
return Ok(None); return Ok(None);
}; };
let full_state = services() let full_state = self
.rooms .services
.state_compressor .state_compressor
.load_shortstatehash_info(shortstatehash)? .load_shortstatehash_info(shortstatehash)?
.pop() .pop()
@ -110,8 +125,7 @@ impl Data {
.iter() .iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| { .and_then(|compressed| {
services() self.services
.rooms
.state_compressor .state_compressor
.parse_compressed_state_event(compressed) .parse_compressed_state_event(compressed)
.ok() .ok()
@ -125,7 +139,7 @@ impl Data {
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)? self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id)) .map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id))
} }
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
@ -149,7 +163,7 @@ impl Data {
pub(super) async fn room_state_full( pub(super) async fn room_state_full(
&self, room_id: &RoomId, &self, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await self.state_full(current_shortstatehash).await
} else { } else {
Ok(HashMap::new()) Ok(HashMap::new())
@ -161,7 +175,7 @@ impl Data {
pub(super) fn room_state_get_id( pub(super) fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key) self.state_get_id(current_shortstatehash, event_type, state_key)
} else { } else {
Ok(None) Ok(None)
@ -173,7 +187,7 @@ impl Data {
pub(super) fn room_state_get( pub(super) fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key) self.state_get(current_shortstatehash, event_type, state_key)
} else { } else {
Ok(None) Ok(None)

View File

@ -6,7 +6,7 @@ use std::{
sync::{Arc, Mutex as StdMutex, Mutex}, sync::{Arc, Mutex as StdMutex, Mutex},
}; };
use conduit::{err, error, utils::math::usize_from_f64, warn, Error, Result}; use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result};
use data::Data; use data::Data;
use lru_cache::LruCache; use lru_cache::LruCache;
use ruma::{ use ruma::{
@ -33,14 +33,20 @@ use ruma::{
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, PduEvent}; use crate::{rooms, rooms::state::RoomMutexGuard, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>, pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>,
pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>, pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>,
} }
struct Services {
state_cache: Dep<rooms::state_cache::Service>,
timeline: Dep<rooms::timeline::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config; let config = &args.server.config;
@ -50,7 +56,11 @@ impl crate::Service for Service {
f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier; f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier;
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), services: Services {
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)), server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)),
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)), user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)),
})) }))
@ -164,8 +174,8 @@ impl Service {
}) })
.unwrap_or(HistoryVisibility::Shared); .unwrap_or(HistoryVisibility::Shared);
let mut current_server_members = services() let mut current_server_members = self
.rooms .services
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -212,7 +222,7 @@ impl Service {
return Ok(*visibility); return Ok(*visibility);
} }
let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
let history_visibility = self let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
@ -258,7 +268,7 @@ impl Service {
/// the room's history_visibility at that event's state. /// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, user_id, room_id))] #[tracing::instrument(skip(self, user_id, room_id))]
pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
let history_visibility = self let history_visibility = self
.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
@ -342,8 +352,8 @@ impl Service {
redacts: None, redacts: None,
}; };
Ok(services() Ok(self
.rooms .services
.timeline .timeline
.create_hash_and_sign_event(new_event, sender, room_id, state_lock) .create_hash_and_sign_event(new_event, sender, room_id, state_lock)
.is_ok()) .is_ok())
@ -413,7 +423,7 @@ impl Service {
// Falling back on m.room.create to judge power level // Falling back on m.room.create to judge power level
if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? { if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? {
Ok(pdu.sender == sender Ok(pdu.sender == sender
|| if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) { || if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
pdu.sender == sender pdu.sender == sender
} else { } else {
false false
@ -430,7 +440,7 @@ impl Service {
.map(|event: RoomPowerLevels| { .map(|event: RoomPowerLevels| {
event.user_can_redact_event_of_other(sender) event.user_can_redact_event_of_other(sender)
|| event.user_can_redact_own_event(sender) || event.user_can_redact_own_event(sender)
&& if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) { && if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
if federation { if federation {
pdu.sender.server_name() == sender.server_name() pdu.sender.server_name() == sender.server_name()
} else { } else {

View File

@ -4,7 +4,7 @@ use std::{
}; };
use conduit::{utils, Error, Result}; use conduit::{utils, Error, Result};
use database::{Database, Map}; use database::Map;
use itertools::Itertools; use itertools::Itertools;
use ruma::{ use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent}, events::{AnyStrippedStateEvent, AnySyncStateEvent},
@ -12,44 +12,55 @@ use ruma::{
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
}; };
use crate::{appservice::RegistrationInfo, services, user_is_local}; use crate::{appservice::RegistrationInfo, globals, user_is_local, users, Dep};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>; type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>; type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
pub(super) struct Data { pub(super) struct Data {
userroomid_joined: Arc<Map>,
roomuserid_joined: Arc<Map>,
userroomid_invitestate: Arc<Map>,
roomuserid_invitecount: Arc<Map>,
userroomid_leftstate: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomid_inviteviaservers: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
roomid_joinedcount: Arc<Map>,
roomid_invitedcount: Arc<Map>,
roomserverids: Arc<Map>,
serverroomids: Arc<Map>,
pub(super) appservice_in_room_cache: AppServiceInRoomCache, pub(super) appservice_in_room_cache: AppServiceInRoomCache,
roomid_invitedcount: Arc<Map>,
roomid_inviteviaservers: Arc<Map>,
roomid_joinedcount: Arc<Map>,
roomserverids: Arc<Map>,
roomuserid_invitecount: Arc<Map>,
roomuserid_joined: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
serverroomids: Arc<Map>,
userroomid_invitestate: Arc<Map>,
userroomid_joined: Arc<Map>,
userroomid_leftstate: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
userroomid_joined: db["userroomid_joined"].clone(),
roomuserid_joined: db["roomuserid_joined"].clone(),
userroomid_invitestate: db["userroomid_invitestate"].clone(),
roomuserid_invitecount: db["roomuserid_invitecount"].clone(),
userroomid_leftstate: db["userroomid_leftstate"].clone(),
roomuserid_leftcount: db["roomuserid_leftcount"].clone(),
roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(),
roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(),
roomid_joinedcount: db["roomid_joinedcount"].clone(),
roomid_invitedcount: db["roomid_invitedcount"].clone(),
roomserverids: db["roomserverids"].clone(),
serverroomids: db["serverroomids"].clone(),
appservice_in_room_cache: RwLock::new(HashMap::new()), appservice_in_room_cache: RwLock::new(HashMap::new()),
roomid_invitedcount: db["roomid_invitedcount"].clone(),
roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(),
roomid_joinedcount: db["roomid_joinedcount"].clone(),
roomserverids: db["roomserverids"].clone(),
roomuserid_invitecount: db["roomuserid_invitecount"].clone(),
roomuserid_joined: db["roomuserid_joined"].clone(),
roomuserid_leftcount: db["roomuserid_leftcount"].clone(),
roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(),
serverroomids: db["serverroomids"].clone(),
userroomid_invitestate: db["userroomid_invitestate"].clone(),
userroomid_joined: db["userroomid_joined"].clone(),
userroomid_leftstate: db["userroomid_leftstate"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
} }
} }
@ -100,7 +111,7 @@ impl Data {
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?; )?;
self.roomuserid_invitecount self.roomuserid_invitecount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?; self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?; self.userroomid_leftstate.remove(&userroom_id)?;
@ -144,7 +155,7 @@ impl Data {
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO )?; // TODO
self.roomuserid_leftcount self.roomuserid_leftcount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?; self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?; self.userroomid_invitestate.remove(&userroom_id)?;
@ -228,7 +239,7 @@ impl Data {
} else { } else {
let bridge_user_id = UserId::parse_with_server_name( let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(), appservice.registration.sender_localpart.as_str(),
services().globals.server_name(), self.services.globals.server_name(),
) )
.ok(); .ok();
@ -356,7 +367,7 @@ impl Data {
) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> { ) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new( Box::new(
self.local_users_in_room(room_id) self.local_users_in_room(room_id)
.filter(|user| !services().users.is_deactivated(user).unwrap_or(true)), .filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)),
) )
} }

View File

@ -21,16 +21,28 @@ use ruma::{
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
}; };
use crate::{appservice::RegistrationInfo, services, user_is_local}; use crate::{account_data, appservice::RegistrationInfo, rooms, user_is_local, users, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
} }
struct Services {
account_data: Dep<account_data::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
users: Dep<users::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), services: Services {
account_data: args.depend::<account_data::Service>("account_data"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(&args),
})) }))
} }
@ -54,18 +66,18 @@ impl Service {
// update // update
#[allow(clippy::collapsible_if)] #[allow(clippy::collapsible_if)]
if !user_is_local(user_id) { if !user_is_local(user_id) {
if !services().users.exists(user_id)? { if !self.services.users.exists(user_id)? {
services().users.create(user_id, None)?; self.services.users.create(user_id, None)?;
} }
/* /*
// Try to update our local copy of the user if ours does not match // Try to update our local copy of the user if ours does not match
if ((services().users.displayname(user_id)? != membership_event.displayname) if ((self.services.users.displayname(user_id)? != membership_event.displayname)
|| (services().users.avatar_url(user_id)? != membership_event.avatar_url) || (self.services.users.avatar_url(user_id)? != membership_event.avatar_url)
|| (services().users.blurhash(user_id)? != membership_event.blurhash)) || (self.services.users.blurhash(user_id)? != membership_event.blurhash))
&& (membership != MembershipState::Leave) && (membership != MembershipState::Leave)
{ {
let response = services() let response = self.services
.sending .sending
.send_federation_request( .send_federation_request(
user_id.server_name(), user_id.server_name(),
@ -76,9 +88,9 @@ impl Service {
) )
.await; .await;
services().users.set_displayname(user_id, response.displayname.clone()).await?; self.services.users.set_displayname(user_id, response.displayname.clone()).await?;
services().users.set_avatar_url(user_id, response.avatar_url).await?; self.services.users.set_avatar_url(user_id, response.avatar_url).await?;
services().users.set_blurhash(user_id, response.blurhash).await?; self.services.users.set_blurhash(user_id, response.blurhash).await?;
}; };
*/ */
} }
@ -91,8 +103,8 @@ impl Service {
self.db.mark_as_once_joined(user_id, room_id)?; self.db.mark_as_once_joined(user_id, room_id)?;
// Check if the room has a predecessor // Check if the room has a predecessor
if let Some(predecessor) = services() if let Some(predecessor) = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")? .room_state_get(room_id, &StateEventType::RoomCreate, "")?
.and_then(|create| serde_json::from_str(create.content.get()).ok()) .and_then(|create| serde_json::from_str(create.content.get()).ok())
@ -124,21 +136,23 @@ impl Service {
// .ok(); // .ok();
// Copy old tags to new room // Copy old tags to new room
if let Some(tag_event) = services() if let Some(tag_event) = self
.services
.account_data .account_data
.get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)?
.map(|event| { .map(|event| {
serde_json::from_str(event.get()) serde_json::from_str(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
}) { }) {
services() self.services
.account_data .account_data
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?)
.ok(); .ok();
}; };
// Copy direct chat flag // Copy direct chat flag
if let Some(direct_event) = services() if let Some(direct_event) = self
.services
.account_data .account_data
.get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())?
.map(|event| { .map(|event| {
@ -156,7 +170,7 @@ impl Service {
} }
if room_ids_updated { if room_ids_updated {
services().account_data.update( self.services.account_data.update(
None, None,
user_id, user_id,
GlobalAccountDataEventType::Direct.to_string().into(), GlobalAccountDataEventType::Direct.to_string().into(),
@ -171,7 +185,8 @@ impl Service {
}, },
MembershipState::Invite => { MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver // We want to know if the sender is ignored by the receiver
let is_ignored = services() let is_ignored = self
.services
.account_data .account_data
.get( .get(
None, // Ignored users are in global account data None, // Ignored users are in global account data
@ -393,8 +408,8 @@ impl Service {
/// See <https://spec.matrix.org/v1.10/appendices/#routing> /// See <https://spec.matrix.org/v1.10/appendices/#routing>
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> { pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
let most_powerful_user_server = services() let most_powerful_user_server = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map(|pdu| { .map(|pdu| {

View File

@ -13,7 +13,7 @@ use lru_cache::LruCache;
use ruma::{EventId, RoomId}; use ruma::{EventId, RoomId};
use self::data::StateDiff; use self::data::StateDiff;
use crate::services; use crate::{rooms, Dep};
type StateInfoLruCache = Mutex< type StateInfoLruCache = Mutex<
LruCache< LruCache<
@ -48,16 +48,25 @@ pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Service { pub struct Service {
db: Data, db: Data,
services: Services,
pub stateinfo_cache: StateInfoLruCache, pub stateinfo_cache: StateInfoLruCache,
} }
struct Services {
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config; let config = &args.server.config;
let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier; let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(args.db),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
},
stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)),
})) }))
} }
@ -124,8 +133,8 @@ impl Service {
pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> { pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> {
let mut v = shortstatekey.to_be_bytes().to_vec(); let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice( v.extend_from_slice(
&services() &self
.rooms .services
.short .short
.get_or_create_shorteventid(event_id)? .get_or_create_shorteventid(event_id)?
.to_be_bytes(), .to_be_bytes(),
@ -138,7 +147,7 @@ impl Service {
pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc<EventId>)> { pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc<EventId>)> {
Ok(( Ok((
utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]).expect("bytes have right length"), utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]).expect("bytes have right length"),
services().rooms.short.get_eventid_from_short( self.services.short.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]).expect("bytes have right length"), utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]).expect("bytes have right length"),
)?, )?,
)) ))
@ -282,7 +291,7 @@ impl Service {
pub fn save_state( pub fn save_state(
&self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, &self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> HashSetCompressStateEvent { ) -> HashSetCompressStateEvent {
let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?;
let state_hash = utils::calculate_hash( let state_hash = utils::calculate_hash(
&new_state_ids_compressed &new_state_ids_compressed
@ -291,8 +300,8 @@ impl Service {
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); );
let (new_shortstatehash, already_existed) = services() let (new_shortstatehash, already_existed) = self
.rooms .services
.short .short
.get_or_create_shortstatehash(&state_hash)?; .get_or_create_shortstatehash(&state_hash)?;

View File

@ -1,29 +1,40 @@
use std::{mem::size_of, sync::Arc}; use std::{mem::size_of, sync::Arc};
use conduit::{checked, utils, Error, Result}; use conduit::{checked, utils, Error, PduEvent, Result};
use database::{Database, Map}; use database::Map;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{services, PduEvent}; use crate::{rooms, Dep};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>; type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
pub(super) struct Data { pub(super) struct Data {
threadid_userids: Arc<Map>, threadid_userids: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
timeline: Dep<rooms::timeline::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
threadid_userids: db["threadid_userids"].clone(), threadid_userids: db["threadid_userids"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
} }
} }
pub(super) fn threads_until<'a>( pub(super) fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> { ) -> PduEventIterResult<'a> {
let prefix = services() let prefix = self
.rooms .services
.short .short
.get_shortroomid(room_id)? .get_shortroomid(room_id)?
.expect("room exists") .expect("room exists")
@ -40,8 +51,8 @@ impl Data {
.map(move |(pduid, _users)| { .map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(size_of::<u64>())..]) let count = utils::u64_from_bytes(&pduid[(size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = services() let mut pdu = self
.rooms .services
.timeline .timeline
.get_pdu_from_id(&pduid)? .get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;

View File

@ -2,7 +2,7 @@ mod data;
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use conduit::{Error, Result}; use conduit::{Error, PduEvent, Result};
use data::Data; use data::Data;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads},
@ -11,16 +11,24 @@ use ruma::{
}; };
use serde_json::json; use serde_json::json;
use crate::{services, PduEvent}; use crate::{rooms, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
} }
struct Services {
timeline: Dep<rooms::timeline::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), services: Services {
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
})) }))
} }
@ -35,22 +43,22 @@ impl Service {
} }
pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
let root_id = &services() let root_id = self
.rooms .services
.timeline .timeline
.get_pdu_id(root_event_id)? .get_pdu_id(root_event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?;
let root_pdu = services() let root_pdu = self
.rooms .services
.timeline .timeline
.get_pdu_from_id(root_id)? .get_pdu_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
let mut root_pdu_json = services() let mut root_pdu_json = self
.rooms .services
.timeline .timeline
.get_pdu_json_from_id(root_id)? .get_pdu_json_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
if let CanonicalJsonValue::Object(unsigned) = root_pdu_json if let CanonicalJsonValue::Object(unsigned) = root_pdu_json
@ -93,20 +101,19 @@ impl Service {
); );
} }
services() self.services
.rooms
.timeline .timeline
.replace_pdu(root_id, &root_pdu_json, &root_pdu)?; .replace_pdu(&root_id, &root_pdu_json, &root_pdu)?;
} }
let mut users = Vec::new(); let mut users = Vec::new();
if let Some(userids) = self.db.get_participants(root_id)? { if let Some(userids) = self.db.get_participants(&root_id)? {
users.extend_from_slice(&userids); users.extend_from_slice(&userids);
} else { } else {
users.push(root_pdu.sender); users.push(root_pdu.sender);
} }
users.push(pdu.sender.clone()); users.push(pdu.sender.clone());
self.db.update_participants(root_id, &users) self.db.update_participants(&root_id, &users)
} }
} }

View File

@ -4,19 +4,25 @@ use std::{
sync::{Arc, Mutex}, sync::{Arc, Mutex},
}; };
use conduit::{checked, error, utils, Error, Result}; use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result};
use database::{Database, Map}; use database::{Database, Map};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{services, PduCount, PduEvent}; use crate::{rooms, Dep};
pub(super) struct Data { pub(super) struct Data {
eventid_outlierpdu: Arc<Map>,
eventid_pduid: Arc<Map>, eventid_pduid: Arc<Map>,
pduid_pdu: Arc<Map>, pduid_pdu: Arc<Map>,
eventid_outlierpdu: Arc<Map>,
userroomid_notificationcount: Arc<Map>,
userroomid_highlightcount: Arc<Map>, userroomid_highlightcount: Arc<Map>,
userroomid_notificationcount: Arc<Map>,
pub(super) lasttimelinecount_cache: LastTimelineCountCache, pub(super) lasttimelinecount_cache: LastTimelineCountCache,
pub(super) db: Arc<Database>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
} }
type PdusIterItem = Result<(PduCount, PduEvent)>; type PdusIterItem = Result<(PduCount, PduEvent)>;
@ -24,14 +30,19 @@ type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>; type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>;
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
eventid_outlierpdu: db["eventid_outlierpdu"].clone(),
eventid_pduid: db["eventid_pduid"].clone(), eventid_pduid: db["eventid_pduid"].clone(),
pduid_pdu: db["pduid_pdu"].clone(), pduid_pdu: db["pduid_pdu"].clone(),
eventid_outlierpdu: db["eventid_outlierpdu"].clone(),
userroomid_notificationcount: db["userroomid_notificationcount"].clone(),
userroomid_highlightcount: db["userroomid_highlightcount"].clone(), userroomid_highlightcount: db["userroomid_highlightcount"].clone(),
userroomid_notificationcount: db["userroomid_notificationcount"].clone(),
lasttimelinecount_cache: Mutex::new(HashMap::new()), lasttimelinecount_cache: Mutex::new(HashMap::new()),
db: args.db.clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
} }
} }
@ -210,7 +221,7 @@ impl Data {
/// happened before the event with id `until` in reverse-chronological /// happened before the event with id `until` in reverse-chronological
/// order. /// order.
pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result<PdusIterator<'_>> { pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?; let (prefix, current) = self.count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
@ -232,7 +243,7 @@ impl Data {
} }
pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result<PdusIterator<'_>> { pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?; let (prefix, current) = self.count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
@ -277,6 +288,41 @@ impl Data {
.increment_batch(highlights_batch.iter().map(Vec::as_slice))?; .increment_batch(highlights_batch.iter().map(Vec::as_slice))?;
Ok(()) Ok(())
} }
pub(super) fn count_to_id(
&self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}
} }
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
@ -294,38 +340,3 @@ pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
Ok(PduCount::Normal(last_u64)) Ok(PduCount::Normal(last_u64))
} }
} }
pub(super) fn count_to_id(
room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}

View File

@ -7,11 +7,12 @@ use std::{
}; };
use conduit::{ use conduit::{
debug, error, info, utils, debug, error, info,
pdu::{EventHash, PduBuilder, PduCount, PduEvent},
utils,
utils::{MutexMap, MutexMapGuard}, utils::{MutexMap, MutexMapGuard},
validated, warn, Error, Result, validated, warn, Error, Result, Server,
}; };
use data::Data;
use itertools::Itertools; use itertools::Itertools;
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation}, api::{client::error::ErrorKind, federation},
@ -37,11 +38,10 @@ use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use self::data::Data;
use crate::{ use crate::{
appservice::NamespaceRegex, account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms,
pdu::{EventHash, PduBuilder}, rooms::state_compressor::CompressedStateEvent, sending, server_is_ours, Dep,
rooms::{event_handler::parse_incoming_pdu, state_compressor::CompressedStateEvent},
server_is_ours, services, PduCount, PduEvent,
}; };
// Update Relationships // Update Relationships
@ -67,17 +67,61 @@ struct ExtractBody {
} }
pub struct Service { pub struct Service {
services: Services,
db: Data, db: Data,
pub mutex_insert: RoomMutexMap, pub mutex_insert: RoomMutexMap,
} }
struct Services {
server: Arc<Server>,
account_data: Dep<account_data::Service>,
appservice: Dep<appservice::Service>,
admin: Dep<admin::Service>,
alias: Dep<rooms::alias::Service>,
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_cache: Dep<rooms::state_cache::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
pdu_metadata: Dep<rooms::pdu_metadata::Service>,
read_receipt: Dep<rooms::read_receipt::Service>,
sending: Dep<sending::Service>,
user: Dep<rooms::user::Service>,
pusher: Dep<pusher::Service>,
threads: Dep<rooms::threads::Service>,
search: Dep<rooms::search::Service>,
spaces: Dep<rooms::spaces::Service>,
event_handler: Dep<rooms::event_handler::Service>,
}
type RoomMutexMap = MutexMap<OwnedRoomId, ()>; type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>; pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), services: Services {
server: args.server.clone(),
account_data: args.depend::<account_data::Service>("account_data"),
appservice: args.depend::<appservice::Service>("appservice"),
admin: args.depend::<admin::Service>("admin"),
alias: args.depend::<rooms::alias::Service>("rooms::alias"),
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"),
sending: args.depend::<sending::Service>("sending"),
user: args.depend::<rooms::user::Service>("rooms::user"),
pusher: args.depend::<pusher::Service>("pusher"),
threads: args.depend::<rooms::threads::Service>("rooms::threads"),
search: args.depend::<rooms::search::Service>("rooms::search"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
},
db: Data::new(&args),
mutex_insert: RoomMutexMap::new(), mutex_insert: RoomMutexMap::new(),
})) }))
} }
@ -217,10 +261,10 @@ impl Service {
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
// Coalesce database writes for the remainder of this scope. // Coalesce database writes for the remainder of this scope.
let _cork = services().db.cork_and_flush(); let _cork = self.db.db.cork_and_flush();
let shortroomid = services() let shortroomid = self
.rooms .services
.short .short
.get_shortroomid(&pdu.room_id)? .get_shortroomid(&pdu.room_id)?
.expect("room exists"); .expect("room exists");
@ -233,14 +277,14 @@ impl Service {
.entry("unsigned".to_owned()) .entry("unsigned".to_owned())
.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()))
{ {
if let Some(shortstatehash) = services() if let Some(shortstatehash) = self
.rooms .services
.state_accessor .state_accessor
.pdu_shortstatehash(&pdu.event_id) .pdu_shortstatehash(&pdu.event_id)
.unwrap() .unwrap()
{ {
if let Some(prev_state) = services() if let Some(prev_state) = self
.rooms .services
.state_accessor .state_accessor
.state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key)
.unwrap() .unwrap()
@ -270,30 +314,26 @@ impl Service {
} }
// We must keep track of all events that have been referenced. // We must keep track of all events that have been referenced.
services() self.services
.rooms
.pdu_metadata .pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services() self.services
.rooms
.state .state
.set_forward_extremities(&pdu.room_id, leaves, state_lock)?; .set_forward_extremities(&pdu.room_id, leaves, state_lock)?;
let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; let insert_lock = self.mutex_insert.lock(&pdu.room_id).await;
let count1 = services().globals.next_count()?; let count1 = self.services.globals.next_count()?;
// Mark as read first so the sending client doesn't get a notification even if // Mark as read first so the sending client doesn't get a notification even if
// appending fails // appending fails
services() self.services
.rooms
.read_receipt .read_receipt
.private_read_set(&pdu.room_id, &pdu.sender, count1)?; .private_read_set(&pdu.room_id, &pdu.sender, count1)?;
services() self.services
.rooms
.user .user
.reset_notification_counts(&pdu.sender, &pdu.room_id)?; .reset_notification_counts(&pdu.sender, &pdu.room_id)?;
let count2 = services().globals.next_count()?; let count2 = self.services.globals.next_count()?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec(); let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&count2.to_be_bytes()); pdu_id.extend_from_slice(&count2.to_be_bytes());
@ -303,8 +343,8 @@ impl Service {
drop(insert_lock); drop(insert_lock);
// See if the event matches any known pushers // See if the event matches any known pushers
let power_levels: RoomPowerLevelsEventContent = services() let power_levels: RoomPowerLevelsEventContent = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| { .map(|ev| {
@ -319,8 +359,8 @@ impl Service {
let mut notifies = Vec::new(); let mut notifies = Vec::new();
let mut highlights = Vec::new(); let mut highlights = Vec::new();
let mut push_target = services() let mut push_target = self
.rooms .services
.state_cache .state_cache
.active_local_users_in_room(&pdu.room_id) .active_local_users_in_room(&pdu.room_id)
.collect_vec(); .collect_vec();
@ -341,7 +381,8 @@ impl Service {
continue; continue;
} }
let rules_for_user = services() let rules_for_user = self
.services
.account_data .account_data
.get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())?
.map(|event| { .map(|event| {
@ -357,7 +398,7 @@ impl Service {
let mut notify = false; let mut notify = false;
for action in for action in
services() self.services
.pusher .pusher
.get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)?
{ {
@ -378,8 +419,10 @@ impl Service {
highlights.push(user.clone()); highlights.push(user.clone());
} }
for push_key in services().pusher.get_pushkeys(user) { for push_key in self.services.pusher.get_pushkeys(user) {
services().sending.send_pdu_push(&pdu_id, user, push_key?)?; self.services
.sending
.send_pdu_push(&pdu_id, user, push_key?)?;
} }
} }
@ -390,11 +433,11 @@ impl Service {
TimelineEventType::RoomRedaction => { TimelineEventType::RoomRedaction => {
use RoomVersionId::*; use RoomVersionId::*;
let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; let room_version_id = self.services.state.get_room_version(&pdu.room_id)?;
match room_version_id { match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts { if let Some(redact_id) = &pdu.redacts {
if services().rooms.state_accessor.user_can_redact( if self.services.state_accessor.user_can_redact(
redact_id, redact_id,
&pdu.sender, &pdu.sender,
&pdu.room_id, &pdu.room_id,
@ -412,7 +455,7 @@ impl Service {
})?; })?;
if let Some(redact_id) = &content.redacts { if let Some(redact_id) = &content.redacts {
if services().rooms.state_accessor.user_can_redact( if self.services.state_accessor.user_can_redact(
redact_id, redact_id,
&pdu.sender, &pdu.sender,
&pdu.room_id, &pdu.room_id,
@ -433,8 +476,7 @@ impl Service {
}, },
TimelineEventType::SpaceChild => { TimelineEventType::SpaceChild => {
if let Some(_state_key) = &pdu.state_key { if let Some(_state_key) = &pdu.state_key {
services() self.services
.rooms
.spaces .spaces
.roomid_spacehierarchy_cache .roomid_spacehierarchy_cache
.lock() .lock()
@ -455,7 +497,7 @@ impl Service {
let invite_state = match content.membership { let invite_state = match content.membership {
MembershipState::Invite => { MembershipState::Invite => {
let state = services().rooms.state.calculate_invite_state(pdu)?; let state = self.services.state.calculate_invite_state(pdu)?;
Some(state) Some(state)
}, },
_ => None, _ => None,
@ -463,7 +505,7 @@ impl Service {
// Update our membership info, we do this here incase a user is invited // Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth // and immediately leaves we need the DB to record the invite event for auth
services().rooms.state_cache.update_membership( self.services.state_cache.update_membership(
&pdu.room_id, &pdu.room_id,
&target_user_id, &target_user_id,
content, content,
@ -479,13 +521,12 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in pdu."))?; .map_err(|_| Error::bad_database("Invalid content in pdu."))?;
if let Some(body) = content.body { if let Some(body) = content.body {
services() self.services
.rooms
.search .search
.index_pdu(shortroomid, &pdu_id, &body)?; .index_pdu(shortroomid, &pdu_id, &body)?;
if services().admin.is_admin_command(pdu, &body).await { if self.services.admin.is_admin_command(pdu, &body).await {
services() self.services
.admin .admin
.command(body, Some((*pdu.event_id).into())) .command(body, Some((*pdu.event_id).into()))
.await; .await;
@ -497,8 +538,7 @@ impl Service {
if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) { if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) {
if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? { if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? {
services() self.services
.rooms
.pdu_metadata .pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount)?; .add_relation(PduCount::Normal(count2), related_pducount)?;
} }
@ -512,29 +552,25 @@ impl Service {
// We need to do it again here, because replies don't have // We need to do it again here, because replies don't have
// event_id as a top level field // event_id as a top level field
if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? { if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? {
services() self.services
.rooms
.pdu_metadata .pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount)?; .add_relation(PduCount::Normal(count2), related_pducount)?;
} }
}, },
Relation::Thread(thread) => { Relation::Thread(thread) => {
services() self.services.threads.add_to_thread(&thread.event_id, pdu)?;
.rooms
.threads
.add_to_thread(&thread.event_id, pdu)?;
}, },
_ => {}, // TODO: Aggregate other types _ => {}, // TODO: Aggregate other types
} }
} }
for appservice in services().appservice.read().await.values() { for appservice in self.services.appservice.read().await.values() {
if services() if self
.rooms .services
.state_cache .state_cache
.appservice_in_room(&pdu.room_id, appservice)? .appservice_in_room(&pdu.room_id, appservice)?
{ {
services() self.services
.sending .sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
continue; continue;
@ -550,7 +586,7 @@ impl Service {
{ {
let appservice_uid = appservice.registration.sender_localpart.as_str(); let appservice_uid = appservice.registration.sender_localpart.as_str();
if state_key_uid == appservice_uid { if state_key_uid == appservice_uid {
services() self.services
.sending .sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
continue; continue;
@ -567,8 +603,7 @@ impl Service {
.map_or(false, |state_key| users.is_match(state_key)) .map_or(false, |state_key| users.is_match(state_key))
}; };
let matching_aliases = |aliases: &NamespaceRegex| { let matching_aliases = |aliases: &NamespaceRegex| {
services() self.services
.rooms
.alias .alias
.local_aliases_for_room(&pdu.room_id) .local_aliases_for_room(&pdu.room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -579,7 +614,7 @@ impl Service {
|| appservice.rooms.is_match(pdu.room_id.as_str()) || appservice.rooms.is_match(pdu.room_id.as_str())
|| matching_users(&appservice.users) || matching_users(&appservice.users)
{ {
services() self.services
.sending .sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
} }
@ -603,8 +638,8 @@ impl Service {
redacts, redacts,
} = pdu_builder; } = pdu_builder;
let prev_events: Vec<_> = services() let prev_events: Vec<_> = self
.rooms .services
.state .state
.get_forward_extremities(room_id)? .get_forward_extremities(room_id)?
.into_iter() .into_iter()
@ -612,28 +647,23 @@ impl Service {
.collect(); .collect();
// If there was no create event yet, assume we are creating a room // If there was no create event yet, assume we are creating a room
let room_version_id = services() let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| {
.rooms if event_type == TimelineEventType::RoomCreate {
.state let content = serde_json::from_str::<RoomCreateEventContent>(content.get())
.get_room_version(room_id) .expect("Invalid content in RoomCreate pdu.");
.or_else(|_| { Ok(content.room_version)
if event_type == TimelineEventType::RoomCreate { } else {
let content = serde_json::from_str::<RoomCreateEventContent>(content.get()) Err(Error::InconsistentRoomState(
.expect("Invalid content in RoomCreate pdu."); "non-create event for room of unknown version",
Ok(content.room_version) room_id.to_owned(),
} else { ))
Err(Error::InconsistentRoomState( }
"non-create event for room of unknown version", })?;
room_id.to_owned(),
))
}
})?;
let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); let room_version = RoomVersion::new(&room_version_id).expect("room version is supported");
let auth_events = let auth_events =
services() self.services
.rooms
.state .state
.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?;
@ -649,8 +679,7 @@ impl Service {
if let Some(state_key) = &state_key { if let Some(state_key) = &state_key {
if let Some(prev_pdu) = if let Some(prev_pdu) =
services() self.services
.rooms
.state_accessor .state_accessor
.room_state_get(room_id, &event_type.to_string().into(), state_key)? .room_state_get(room_id, &event_type.to_string().into(), state_key)?
{ {
@ -730,12 +759,12 @@ impl Service {
// Add origin because synapse likes that (and it's required in the spec) // Add origin because synapse likes that (and it's required in the spec)
pdu_json.insert( pdu_json.insert(
"origin".to_owned(), "origin".to_owned(),
to_canonical_value(services().globals.server_name()).expect("server name is a valid CanonicalJsonValue"), to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"),
); );
match ruma::signatures::hash_and_sign_event( match ruma::signatures::hash_and_sign_event(
services().globals.server_name().as_str(), self.services.globals.server_name().as_str(),
services().globals.keypair(), self.services.globals.keypair(),
&mut pdu_json, &mut pdu_json,
&room_version_id, &room_version_id,
) { ) {
@ -763,8 +792,8 @@ impl Service {
); );
// Generate short event id // Generate short event id
let _shorteventid = services() let _shorteventid = self
.rooms .services
.short .short
.get_or_create_shorteventid(&pdu.event_id)?; .get_or_create_shorteventid(&pdu.event_id)?;
@ -783,7 +812,7 @@ impl Service {
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<Arc<EventId>> { ) -> Result<Arc<EventId>> {
let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?;
if let Some(admin_room) = services().admin.get_admin_room()? { if let Some(admin_room) = self.services.admin.get_admin_room()? {
if admin_room == room_id { if admin_room == room_id {
match pdu.event_type() { match pdu.event_type() {
TimelineEventType::RoomEncryption => { TimelineEventType::RoomEncryption => {
@ -798,7 +827,7 @@ impl Service {
.state_key() .state_key()
.filter(|v| v.starts_with('@')) .filter(|v| v.starts_with('@'))
.unwrap_or(sender.as_str()); .unwrap_or(sender.as_str());
let server_user = &services().globals.server_user.to_string(); let server_user = &self.services.globals.server_user.to_string();
let content = serde_json::from_str::<RoomMemberEventContent>(pdu.content.get()) let content = serde_json::from_str::<RoomMemberEventContent>(pdu.content.get())
.map_err(|_| Error::bad_database("Invalid content in pdu"))?; .map_err(|_| Error::bad_database("Invalid content in pdu"))?;
@ -812,8 +841,8 @@ impl Service {
)); ));
} }
let count = services() let count = self
.rooms .services
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -837,8 +866,8 @@ impl Service {
)); ));
} }
let count = services() let count = self
.rooms .services
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -861,15 +890,14 @@ impl Service {
// If redaction event is not authorized, do not append it to the timeline // If redaction event is not authorized, do not append it to the timeline
if pdu.kind == TimelineEventType::RoomRedaction { if pdu.kind == TimelineEventType::RoomRedaction {
use RoomVersionId::*; use RoomVersionId::*;
match services().rooms.state.get_room_version(&pdu.room_id)? { match self.services.state.get_room_version(&pdu.room_id)? {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts { if let Some(redact_id) = &pdu.redacts {
if !services().rooms.state_accessor.user_can_redact( if !self
redact_id, .services
&pdu.sender, .state_accessor
&pdu.room_id, .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)?
false, {
)? {
return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event."));
} }
}; };
@ -879,12 +907,11 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?;
if let Some(redact_id) = &content.redacts { if let Some(redact_id) = &content.redacts {
if !services().rooms.state_accessor.user_can_redact( if !self
redact_id, .services
&pdu.sender, .state_accessor
&pdu.room_id, .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)?
false, {
)? {
return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event."));
} }
} }
@ -895,7 +922,7 @@ impl Service {
// We append to state before appending the pdu, so we don't have a moment in // We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't // time with the pdu without it's state. This is okay because append_pdu can't
// fail. // fail.
let statehashid = services().rooms.state.append_to_state(&pdu)?; let statehashid = self.services.state.append_to_state(&pdu)?;
let pdu_id = self let pdu_id = self
.append_pdu( .append_pdu(
@ -910,13 +937,12 @@ impl Service {
// We set the room state after inserting the pdu, so that we never have a moment // We set the room state after inserting the pdu, so that we never have a moment
// in time where events in the current room state do not exist // in time where events in the current room state do not exist
services() self.services
.rooms
.state .state
.set_room_state(room_id, statehashid, state_lock)?; .set_room_state(room_id, statehashid, state_lock)?;
let mut servers: HashSet<OwnedServerName> = services() let mut servers: HashSet<OwnedServerName> = self
.rooms .services
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -936,9 +962,9 @@ impl Service {
// Remove our server from the server list since it will be added to it by // Remove our server from the server list since it will be added to it by
// room_servers() and/or the if statement above // room_servers() and/or the if statement above
servers.remove(services().globals.server_name()); servers.remove(self.services.globals.server_name());
services() self.services
.sending .sending
.send_pdu_servers(servers.into_iter(), &pdu_id)?; .send_pdu_servers(servers.into_iter(), &pdu_id)?;
@ -960,18 +986,15 @@ impl Service {
// We append to state before appending the pdu, so we don't have a moment in // We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't // time with the pdu without it's state. This is okay because append_pdu can't
// fail. // fail.
services() self.services
.rooms
.state .state
.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?;
if soft_fail { if soft_fail {
services() self.services
.rooms
.pdu_metadata .pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services() self.services
.rooms
.state .state
.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?;
return Ok(None); return Ok(None);
@ -1022,14 +1045,13 @@ impl Service {
if let Ok(content) = serde_json::from_str::<ExtractBody>(pdu.content.get()) { if let Ok(content) = serde_json::from_str::<ExtractBody>(pdu.content.get()) {
if let Some(body) = content.body { if let Some(body) = content.body {
services() self.services
.rooms
.search .search
.deindex_pdu(shortroomid, &pdu_id, &body)?; .deindex_pdu(shortroomid, &pdu_id, &body)?;
} }
} }
let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; let room_version_id = self.services.state.get_room_version(&pdu.room_id)?;
pdu.redact(room_version_id, reason)?; pdu.redact(room_version_id, reason)?;
@ -1058,8 +1080,8 @@ impl Service {
return Ok(()); return Ok(());
} }
let power_levels: RoomPowerLevelsEventContent = services() let power_levels: RoomPowerLevelsEventContent = self
.rooms .services
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| { .map(|ev| {
@ -1077,8 +1099,8 @@ impl Service {
} }
}); });
let room_alias_servers = services() let room_alias_servers = self
.rooms .services
.alias .alias
.local_aliases_for_room(room_id) .local_aliases_for_room(room_id)
.filter_map(|alias| { .filter_map(|alias| {
@ -1090,14 +1112,13 @@ impl Service {
let servers = room_mods let servers = room_mods
.chain(room_alias_servers) .chain(room_alias_servers)
.chain(services().globals.config.trusted_servers.clone()) .chain(self.services.server.config.trusted_servers.clone())
.filter(|server_name| { .filter(|server_name| {
if server_is_ours(server_name) { if server_is_ours(server_name) {
return false; return false;
} }
services() self.services
.rooms
.state_cache .state_cache
.server_in_room(server_name, room_id) .server_in_room(server_name, room_id)
.unwrap_or(false) .unwrap_or(false)
@ -1105,7 +1126,8 @@ impl Service {
for backfill_server in servers { for backfill_server in servers {
info!("Asking {backfill_server} for backfill"); info!("Asking {backfill_server} for backfill");
let response = services() let response = self
.services
.sending .sending
.send_federation_request( .send_federation_request(
&backfill_server, &backfill_server,
@ -1141,11 +1163,11 @@ impl Service {
&self, origin: &ServerName, pdu: Box<RawJsonValue>, &self, origin: &ServerName, pdu: Box<RawJsonValue>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> { ) -> Result<()> {
let (event_id, value, room_id) = parse_incoming_pdu(&pdu)?; let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?;
// Lock so we cannot backfill the same pdu twice at the same time // Lock so we cannot backfill the same pdu twice at the same time
let mutex_lock = services() let mutex_lock = self
.rooms .services
.event_handler .event_handler
.mutex_federation .mutex_federation
.lock(&room_id) .lock(&room_id)
@ -1158,14 +1180,12 @@ impl Service {
return Ok(()); return Ok(());
} }
services() self.services
.rooms
.event_handler .event_handler
.fetch_required_signing_keys([&value], pub_key_map) .fetch_required_signing_keys([&value], pub_key_map)
.await?; .await?;
services() self.services
.rooms
.event_handler .event_handler
.handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map)
.await?; .await?;
@ -1173,8 +1193,8 @@ impl Service {
let value = self.get_pdu_json(&event_id)?.expect("We just created it"); let value = self.get_pdu_json(&event_id)?.expect("We just created it");
let pdu = self.get_pdu(&event_id)?.expect("We just created it"); let pdu = self.get_pdu(&event_id)?.expect("We just created it");
let shortroomid = services() let shortroomid = self
.rooms .services
.short .short
.get_shortroomid(&room_id)? .get_shortroomid(&room_id)?
.expect("room exists"); .expect("room exists");
@ -1182,7 +1202,7 @@ impl Service {
let insert_lock = self.mutex_insert.lock(&room_id).await; let insert_lock = self.mutex_insert.lock(&room_id).await;
let max = u64::MAX; let max = u64::MAX;
let count = services().globals.next_count()?; let count = self.services.globals.next_count()?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec(); let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&0_u64.to_be_bytes());
pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes());
@ -1197,8 +1217,7 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in pdu."))?; .map_err(|_| Error::bad_database("Invalid content in pdu."))?;
if let Some(body) = content.body { if let Some(body) = content.body {
services() self.services
.rooms
.search .search
.index_pdu(shortroomid, &pdu_id, &body)?; .index_pdu(shortroomid, &pdu_id, &body)?;
} }

View File

@ -1,6 +1,6 @@
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use conduit::{debug_info, trace, utils, Result}; use conduit::{debug_info, trace, utils, Result, Server};
use ruma::{ use ruma::{
api::federation::transactions::edu::{Edu, TypingContent}, api::federation::transactions::edu::{Edu, TypingContent},
events::SyncEphemeralRoomEvent, events::SyncEphemeralRoomEvent,
@ -8,19 +8,31 @@ use ruma::{
}; };
use tokio::sync::{broadcast, RwLock}; use tokio::sync::{broadcast, RwLock};
use crate::{services, user_is_local}; use crate::{globals, sending, user_is_local, Dep};
pub struct Service { pub struct Service {
pub typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>, // u64 is unix timestamp of timeout server: Arc<Server>,
pub last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>, /* timestamp of the last change to services: Services,
* typing /// u64 is unix timestamp of timeout
* users */ pub typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>,
/// timestamp of the last change to typing users
pub last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>,
pub typing_update_sender: broadcast::Sender<OwnedRoomId>, pub typing_update_sender: broadcast::Sender<OwnedRoomId>,
} }
struct Services {
globals: Dep<globals::Service>,
sending: Dep<sending::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
server: args.server.clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
},
typing: RwLock::new(BTreeMap::new()), typing: RwLock::new(BTreeMap::new()),
last_typing_update: RwLock::new(BTreeMap::new()), last_typing_update: RwLock::new(BTreeMap::new()),
typing_update_sender: broadcast::channel(100).0, typing_update_sender: broadcast::channel(100).0,
@ -45,14 +57,14 @@ impl Service {
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() { if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested"); trace!("receiver found what it was looking for and is no longer interested");
} }
// update federation // update federation
if user_is_local(user_id) { if user_is_local(user_id) {
Self::federation_send(room_id, user_id, true)?; self.federation_send(room_id, user_id, true)?;
} }
Ok(()) Ok(())
@ -71,14 +83,14 @@ impl Service {
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() { if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested"); trace!("receiver found what it was looking for and is no longer interested");
} }
// update federation // update federation
if user_is_local(user_id) { if user_is_local(user_id) {
Self::federation_send(room_id, user_id, false)?; self.federation_send(room_id, user_id, false)?;
} }
Ok(()) Ok(())
@ -126,7 +138,7 @@ impl Service {
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), services().globals.next_count()?); .insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() { if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested"); trace!("receiver found what it was looking for and is no longer interested");
} }
@ -134,7 +146,7 @@ impl Service {
// update federation // update federation
for user in removable { for user in removable {
if user_is_local(&user) { if user_is_local(&user) {
Self::federation_send(room_id, &user, false)?; self.federation_send(room_id, &user, false)?;
} }
} }
} }
@ -171,15 +183,15 @@ impl Service {
}) })
} }
fn federation_send(room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
debug_assert!(user_is_local(user_id), "tried to broadcast typing status of remote user",); debug_assert!(user_is_local(user_id), "tried to broadcast typing status of remote user",);
if !services().globals.config.allow_outgoing_typing { if !self.server.config.allow_outgoing_typing {
return Ok(()); return Ok(());
} }
let edu = Edu::Typing(TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing)); let edu = Edu::Typing(TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing));
services() self.services
.sending .sending
.send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?; .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?;

View File

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use conduit::{utils, Error, Result}; use conduit::{utils, Error, Result};
use database::{Database, Map}; use database::Map;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::services; use crate::{globals, rooms, Dep};
pub(super) struct Data { pub(super) struct Data {
userroomid_notificationcount: Arc<Map>, userroomid_notificationcount: Arc<Map>,
@ -12,16 +12,27 @@ pub(super) struct Data {
roomuserid_lastnotificationread: Arc<Map>, roomuserid_lastnotificationread: Arc<Map>,
roomsynctoken_shortstatehash: Arc<Map>, roomsynctoken_shortstatehash: Arc<Map>,
userroomid_joined: Arc<Map>, userroomid_joined: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
userroomid_notificationcount: db["userroomid_notificationcount"].clone(), userroomid_notificationcount: db["userroomid_notificationcount"].clone(),
userroomid_highlightcount: db["userroomid_highlightcount"].clone(), userroomid_highlightcount: db["userroomid_highlightcount"].clone(),
roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit
roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(),
userroomid_joined: db["userroomid_joined"].clone(), userroomid_joined: db["userroomid_joined"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
},
} }
} }
@ -39,7 +50,7 @@ impl Data {
.insert(&userroom_id, &0_u64.to_be_bytes())?; .insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread self.roomuserid_lastnotificationread
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(()) Ok(())
} }
@ -87,8 +98,8 @@ impl Data {
pub(super) fn associate_token_shortstatehash( pub(super) fn associate_token_shortstatehash(
&self, room_id: &RoomId, token: u64, shortstatehash: u64, &self, room_id: &RoomId, token: u64, shortstatehash: u64,
) -> Result<()> { ) -> Result<()> {
let shortroomid = services() let shortroomid = self
.rooms .services
.short .short
.get_shortroomid(room_id)? .get_shortroomid(room_id)?
.expect("room exists"); .expect("room exists");
@ -101,8 +112,8 @@ impl Data {
} }
pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services() let shortroomid = self
.rooms .services
.short .short
.get_shortroomid(room_id)? .get_shortroomid(room_id)?
.expect("room exists"); .expect("room exists");

View File

@ -3,9 +3,10 @@ mod data;
use std::sync::Arc; use std::sync::Arc;
use conduit::Result; use conduit::Result;
use data::Data;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use self::data::Data;
pub struct Service { pub struct Service {
db: Data, db: Data,
} }
@ -13,7 +14,7 @@ pub struct Service {
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db), db: Data::new(&args),
})) }))
} }

View File

@ -1,16 +1,17 @@
use std::{fmt::Debug, mem}; use std::{fmt::Debug, mem};
use bytes::BytesMut; use bytes::BytesMut;
use conduit::{debug_error, trace, utils, warn, Error, Result};
use reqwest::Client;
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use tracing::{trace, warn};
use crate::{debug_error, services, utils, Error, Result};
/// Sends a request to an appservice /// Sends a request to an appservice
/// ///
/// Only returns Ok(None) if there is no url specified in the appservice /// Only returns Ok(None) if there is no url specified in the appservice
/// registration file /// registration file
pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Result<Option<T::IncomingResponse>> pub(crate) async fn send_request<T>(
client: &Client, registration: Registration, request: T,
) -> Result<Option<T::IncomingResponse>>
where where
T: OutgoingRequest + Debug + Send, T: OutgoingRequest + Debug + Send,
{ {
@ -48,15 +49,10 @@ where
let reqwest_request = reqwest::Request::try_from(http_request)?; let reqwest_request = reqwest::Request::try_from(http_request)?;
let mut response = services() let mut response = client.execute(reqwest_request).await.map_err(|e| {
.client warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id);
.appservice e
.execute(reqwest_request) })?;
.await
.map_err(|e| {
warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id);
e
})?;
// reqwest::Response -> http::Response conversion // reqwest::Response -> http::Response conversion
let status = response.status(); let status = response.status();

View File

@ -5,7 +5,7 @@ use database::{Database, Map};
use ruma::{ServerName, UserId}; use ruma::{ServerName, UserId};
use super::{Destination, SendingEvent}; use super::{Destination, SendingEvent};
use crate::services; use crate::{globals, Dep};
type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>; type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>;
type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>; type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>;
@ -15,15 +15,24 @@ pub struct Data {
servernameevent_data: Arc<Map>, servernameevent_data: Arc<Map>,
servername_educount: Arc<Map>, servername_educount: Arc<Map>,
pub(super) db: Arc<Database>, pub(super) db: Arc<Database>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
servercurrentevent_data: db["servercurrentevent_data"].clone(), servercurrentevent_data: db["servercurrentevent_data"].clone(),
servernameevent_data: db["servernameevent_data"].clone(), servernameevent_data: db["servernameevent_data"].clone(),
servername_educount: db["servername_educount"].clone(), servername_educount: db["servername_educount"].clone(),
db, db: args.db.clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
} }
} }
@ -78,7 +87,7 @@ impl Data {
if let SendingEvent::Pdu(value) = &event { if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value); key.extend_from_slice(value);
} else { } else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
} }
let value = if let SendingEvent::Edu(value) = &event { let value = if let SendingEvent::Edu(value) = &event {
&**value &**value

View File

@ -6,26 +6,39 @@ mod sender;
use std::{fmt::Debug, sync::Arc}; use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use conduit::{err, Result, Server}; use conduit::{err, warn, Result, Server};
use ruma::{ use ruma::{
api::{appservice::Registration, OutgoingRequest}, api::{appservice::Registration, OutgoingRequest},
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
}; };
pub use sender::convert_to_outgoing_federation_event;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::warn;
use crate::{server_is_ours, services}; use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_is_ours, users, Dep};
pub struct Service { pub struct Service {
pub db: data::Data,
server: Arc<Server>, server: Arc<Server>,
services: Services,
/// The state for a given state hash. pub db: data::Data,
sender: loole::Sender<Msg>, sender: loole::Sender<Msg>,
receiver: Mutex<loole::Receiver<Msg>>, receiver: Mutex<loole::Receiver<Msg>>,
} }
struct Services {
client: Dep<client::Service>,
globals: Dep<globals::Service>,
resolver: Dep<resolver::Service>,
state: Dep<rooms::state::Service>,
state_cache: Dep<rooms::state_cache::Service>,
user: Dep<rooms::user::Service>,
users: Dep<users::Service>,
presence: Dep<presence::Service>,
read_receipt: Dep<rooms::read_receipt::Service>,
timeline: Dep<rooms::timeline::Service>,
account_data: Dep<account_data::Service>,
appservice: Dep<crate::appservice::Service>,
pusher: Dep<pusher::Service>,
}
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
struct Msg { struct Msg {
dest: Destination, dest: Destination,
@ -53,8 +66,23 @@ impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let (sender, receiver) = loole::unbounded(); let (sender, receiver) = loole::unbounded();
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: data::Data::new(args.db.clone()),
server: args.server.clone(), server: args.server.clone(),
services: Services {
client: args.depend::<client::Service>("client"),
globals: args.depend::<globals::Service>("globals"),
resolver: args.depend::<resolver::Service>("resolver"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
user: args.depend::<rooms::user::Service>("rooms::user"),
users: args.depend::<users::Service>("users"),
presence: args.depend::<presence::Service>("presence"),
read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
account_data: args.depend::<account_data::Service>("account_data"),
appservice: args.depend::<crate::appservice::Service>("appservice"),
pusher: args.depend::<pusher::Service>("pusher"),
},
db: data::Data::new(&args),
sender, sender,
receiver: Mutex::new(receiver), receiver: Mutex::new(receiver),
})) }))
@ -103,8 +131,8 @@ impl Service {
#[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")]
pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> {
let servers = services() let servers = self
.rooms .services
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -152,8 +180,8 @@ impl Service {
#[tracing::instrument(skip(self, room_id, serialized), level = "debug")] #[tracing::instrument(skip(self, room_id, serialized), level = "debug")]
pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> { pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> {
let servers = services() let servers = self
.rooms .services
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -189,8 +217,8 @@ impl Service {
#[tracing::instrument(skip(self, room_id), level = "debug")] #[tracing::instrument(skip(self, room_id), level = "debug")]
pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { pub fn flush_room(&self, room_id: &RoomId) -> Result<()> {
let servers = services() let servers = self
.rooms .services
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -213,13 +241,13 @@ impl Service {
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, request), name = "request")] #[tracing::instrument(skip_all, name = "request")]
pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse> pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
where where
T: OutgoingRequest + Debug + Send, T: OutgoingRequest + Debug + Send,
{ {
let client = &services().client.federation; let client = &self.services.client.federation;
send::send(client, dest, request).await self.send(client, dest, request).await
} }
/// Sends a request to an appservice /// Sends a request to an appservice
@ -232,7 +260,8 @@ impl Service {
where where
T: OutgoingRequest + Debug + Send, T: OutgoingRequest + Debug + Send,
{ {
appservice::send_request(registration, request).await let client = &self.services.client.appservice;
appservice::send_request(client, registration, request).await
} }
/// Cleanup event data /// Cleanup event data

View File

@ -1,6 +1,8 @@
use std::{fmt::Debug, mem}; use std::{fmt::Debug, mem};
use conduit::Err; use conduit::{
debug, debug_error, debug_warn, error::inspect_debug_log, trace, utils::string::EMPTY, Err, Error, Result,
};
use http::{header::AUTHORIZATION, HeaderValue}; use http::{header::AUTHORIZATION, HeaderValue};
use ipaddress::IPAddress; use ipaddress::IPAddress;
use reqwest::{Client, Method, Request, Response, Url}; use reqwest::{Client, Method, Request, Response, Url};
@ -13,75 +15,91 @@ use ruma::{
server_util::authorization::XMatrix, server_util::authorization::XMatrix,
ServerName, ServerName,
}; };
use tracing::{debug, trace};
use crate::{ use crate::{
debug_error, debug_warn, resolver, globals, resolver,
resolver::{actual::ActualDest, cache::CachedDest}, resolver::{actual::ActualDest, cache::CachedDest},
services, Error, Result,
}; };
#[tracing::instrument(skip_all, name = "send")] impl super::Service {
pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse> #[tracing::instrument(skip(self, client, req), name = "send")]
where pub async fn send<T>(&self, client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
T: OutgoingRequest + Debug + Send, where
{ T: OutgoingRequest + Debug + Send,
if !services().globals.allow_federation() { {
return Err!(Config("allow_federation", "Federation is disabled.")); if !self.server.config.allow_federation {
return Err!(Config("allow_federation", "Federation is disabled."));
}
let actual = self.services.resolver.get_actual_dest(dest).await?;
let request = self.prepare::<T>(dest, &actual, req).await?;
self.execute::<T>(dest, &actual, request, client).await
} }
let actual = services().resolver.get_actual_dest(dest).await?; async fn execute<T>(
let request = prepare::<T>(dest, &actual, req).await?; &self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client,
execute::<T>(client, dest, &actual, request).await ) -> Result<T::IncomingResponse>
} where
T: OutgoingRequest + Debug + Send,
{
let url = request.url().clone();
let method = request.method().clone();
async fn execute<T>( debug!(?method, ?url, "Sending request");
client: &Client, dest: &ServerName, actual: &ActualDest, request: Request, match client.execute(request).await {
) -> Result<T::IncomingResponse> Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await,
where Err(error) => handle_error::<T>(dest, actual, &method, &url, error),
T: OutgoingRequest + Debug + Send, }
{
let method = request.method().clone();
let url = request.url().clone();
debug!(
method = ?method,
url = ?url,
"Sending request",
);
match client.execute(request).await {
Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await,
Err(e) => handle_error::<T>(dest, actual, &method, &url, e),
} }
}
async fn prepare<T>(dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request> async fn prepare<T>(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request>
where where
T: OutgoingRequest + Debug + Send, T: OutgoingRequest + Debug + Send,
{ {
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5]; const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5];
const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY);
trace!("Preparing request"); trace!("Preparing request");
let mut http_request = req
.try_into_http_request::<Vec<u8>>(&actual.string, SATIR, &VERSIONS)
.map_err(|_| Error::BadServerResponse("Invalid destination"))?;
let mut http_request = req sign_request::<T>(&self.services.globals, dest, &mut http_request);
.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &VERSIONS)
.map_err(|_e| Error::BadServerResponse("Invalid destination"))?;
sign_request::<T>(dest, &mut http_request); let request = Request::try_from(http_request)?;
self.validate_url(request.url())?;
let request = Request::try_from(http_request)?; Ok(request)
validate_url(request.url())?; }
Ok(request) fn validate_url(&self, url: &Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
self.services.resolver.validate_ip(&ip)?;
}
}
Ok(())
}
} }
async fn handle_response<T>( async fn handle_response<T>(
dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url,
mut response: Response,
) -> Result<T::IncomingResponse> ) -> Result<T::IncomingResponse>
where where
T: OutgoingRequest + Debug + Send, T: OutgoingRequest + Debug + Send,
{ {
trace!("Received response from {} for {} with {}", actual.string, url, response.url());
let status = response.status(); let status = response.status();
trace!(
?status, ?method,
request_url = ?url,
response_url = ?response.url(),
"Received response from {}",
actual.string,
);
let mut http_response_builder = http::Response::builder() let mut http_response_builder = http::Response::builder()
.status(status) .status(status)
.version(response.version()); .version(response.version());
@ -92,11 +110,13 @@ where
.expect("http::response::Builder is usable"), .expect("http::response::Builder is usable"),
); );
trace!("Waiting for response body"); // TODO: handle timeout
let body = response.bytes().await.unwrap_or_else(|e| { trace!("Waiting for response body...");
debug_error!("server error {}", e); let body = response
Vec::new().into() .bytes()
}); // TODO: handle timeout .await
.inspect_err(inspect_debug_log)
.unwrap_or_else(|_| Vec::new().into());
let http_response = http_response_builder let http_response = http_response_builder
.body(body) .body(body)
@ -109,7 +129,7 @@ where
let response = T::IncomingResponse::try_from_http_response(http_response); let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && !actual.cached { if response.is_ok() && !actual.cached {
services().resolver.set_cached_destination( resolver.set_cached_destination(
dest.to_owned(), dest.to_owned(),
CachedDest { CachedDest {
dest: actual.dest.clone(), dest: actual.dest.clone(),
@ -120,7 +140,7 @@ where
} }
match response { match response {
Err(_e) => Err(Error::BadServerResponse("Server returned bad 200 response.")), Err(_) => Err(Error::BadServerResponse("Server returned bad 200 response.")),
Ok(response) => Ok(response), Ok(response) => Ok(response),
} }
} }
@ -150,7 +170,7 @@ where
Err(e.into()) Err(e.into())
} }
fn sign_request<T>(dest: &ServerName, http_request: &mut http::Request<Vec<u8>>) fn sign_request<T>(globals: &globals::Service, dest: &ServerName, http_request: &mut http::Request<Vec<u8>>)
where where
T: OutgoingRequest + Debug + Send, T: OutgoingRequest + Debug + Send,
{ {
@ -172,16 +192,12 @@ where
.to_string() .to_string()
.into(), .into(),
); );
req_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); req_map.insert("origin".to_owned(), globals.server_name().as_str().into());
req_map.insert("destination".to_owned(), dest.as_str().into()); req_map.insert("destination".to_owned(), dest.as_str().into());
let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap"); let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap");
ruma::signatures::sign_json( ruma::signatures::sign_json(globals.server_name().as_str(), globals.keypair(), &mut req_json)
services().globals.server_name().as_str(), .expect("our request json is what ruma expects");
services().globals.keypair(),
&mut req_json,
)
.expect("our request json is what ruma expects");
let req_json: serde_json::Map<String, serde_json::Value> = let req_json: serde_json::Map<String, serde_json::Value> =
serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap(); serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap();
@ -207,24 +223,8 @@ where
http_request.headers_mut().insert( http_request.headers_mut().insert(
AUTHORIZATION, AUTHORIZATION,
HeaderValue::from(&XMatrix::new( HeaderValue::from(&XMatrix::new(globals.config.server_name.clone(), dest.to_owned(), key, sig)),
services().globals.config.server_name.clone(),
dest.to_owned(),
key,
sig,
)),
); );
} }
} }
} }
fn validate_url(url: &Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
resolver::actual::validate_ip(&ip)?;
}
}
Ok(())
}

View File

@ -6,7 +6,11 @@ use std::{
}; };
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
use conduit::{debug, debug_warn, error, trace, utils::math::continue_exponential_backoff_secs, warn}; use conduit::{
debug, debug_warn, error, trace,
utils::{calculate_hash, math::continue_exponential_backoff_secs},
warn, Error, Result,
};
use federation::transactions::send_transaction_message; use federation::transactions::send_transaction_message;
use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
@ -24,8 +28,8 @@ use ruma::{
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::time::sleep_until; use tokio::time::sleep_until;
use super::{appservice, send, Destination, Msg, SendingEvent, Service}; use super::{appservice, Destination, Msg, SendingEvent, Service};
use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; use crate::user_is_local;
#[derive(Debug)] #[derive(Debug)]
enum TransactionStatus { enum TransactionStatus {
@ -69,8 +73,8 @@ impl Service {
Ok(()) Ok(())
} }
fn handle_response( fn handle_response<'a>(
&self, response: SendingResult, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, &'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus,
) { ) {
match response { match response {
Ok(dest) => self.handle_response_ok(&dest, futures, statuses), Ok(dest) => self.handle_response_ok(&dest, futures, statuses),
@ -91,8 +95,8 @@ impl Service {
}); });
} }
fn handle_response_ok( fn handle_response_ok<'a>(
&self, dest: &Destination, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, &'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus,
) { ) {
let _cork = self.db.db.cork(); let _cork = self.db.db.cork();
self.db self.db
@ -113,24 +117,24 @@ impl Service {
.mark_as_active(&new_events) .mark_as_active(&new_events)
.expect("marked as active"); .expect("marked as active");
let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect(); let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect();
futures.push(Box::pin(send_events(dest.clone(), new_events_vec))); futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec)));
} else { } else {
statuses.remove(dest); statuses.remove(dest);
} }
} }
fn handle_request(&self, msg: Msg, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let iv = vec![(msg.event, msg.queue_id)]; let iv = vec![(msg.event, msg.queue_id)];
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) {
if !events.is_empty() { if !events.is_empty() {
futures.push(Box::pin(send_events(msg.dest, events))); futures.push(Box::pin(self.send_events(msg.dest, events)));
} else { } else {
statuses.remove(&msg.dest); statuses.remove(&msg.dest);
} }
} }
} }
async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) { async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let now = Instant::now(); let now = Instant::now();
let timeout = Duration::from_millis(CLEANUP_TIMEOUT_MS); let timeout = Duration::from_millis(CLEANUP_TIMEOUT_MS);
let deadline = now.checked_add(timeout).unwrap_or(now); let deadline = now.checked_add(timeout).unwrap_or(now);
@ -148,7 +152,7 @@ impl Service {
debug_warn!("Leaving with {} unfinished requests...", futures.len()); debug_warn!("Leaving with {} unfinished requests...", futures.len());
} }
fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new(); let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new();
for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) {
@ -166,12 +170,12 @@ impl Service {
for (dest, events) in txns { for (dest, events) in txns {
if self.server.config.startup_netburst && !events.is_empty() { if self.server.config.startup_netburst && !events.is_empty() {
statuses.insert(dest.clone(), TransactionStatus::Running); statuses.insert(dest.clone(), TransactionStatus::Running);
futures.push(Box::pin(send_events(dest.clone(), events))); futures.push(Box::pin(self.send_events(dest.clone(), events)));
} }
} }
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all, level = "debug")]
fn select_events( fn select_events(
&self, &self,
dest: &Destination, dest: &Destination,
@ -218,7 +222,7 @@ impl Service {
Ok(Some(events)) Ok(Some(events))
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all, level = "debug")]
fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> { fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> {
let (mut allow, mut retry) = (true, false); let (mut allow, mut retry) = (true, false);
statuses statuses
@ -244,7 +248,7 @@ impl Service {
Ok((allow, retry)) Ok((allow, retry))
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all, level = "debug")]
fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
// u64: count of last edu // u64: count of last edu
let since = self.db.get_latest_educount(server_name)?; let since = self.db.get_latest_educount(server_name)?;
@ -252,11 +256,11 @@ impl Service {
let mut max_edu_count = since; let mut max_edu_count = since;
let mut device_list_changes = HashSet::new(); let mut device_list_changes = HashSet::new();
for room_id in services().rooms.state_cache.server_rooms(server_name) { for room_id in self.services.state_cache.server_rooms(server_name) {
let room_id = room_id?; let room_id = room_id?;
// Look for device list updates in this room // Look for device list updates in this room
device_list_changes.extend( device_list_changes.extend(
services() self.services
.users .users
.keys_changed(room_id.as_ref(), since, None) .keys_changed(room_id.as_ref(), since, None)
.filter_map(Result::ok) .filter_map(Result::ok)
@ -264,7 +268,7 @@ impl Service {
); );
if self.server.config.allow_outgoing_read_receipts if self.server.config.allow_outgoing_read_receipts
&& !select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? && !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)?
{ {
break; break;
} }
@ -287,381 +291,390 @@ impl Service {
} }
if self.server.config.allow_outgoing_presence { if self.server.config.allow_outgoing_presence {
select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?;
} }
Ok((events, max_edu_count)) Ok((events, max_edu_count))
} }
}
/// Look for presence /// Look for presence
fn select_edus_presence( fn select_edus_presence(
server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> { ) -> Result<bool> {
// Look for presence updates for this server // Look for presence updates for this server
let mut presence_updates = Vec::new(); let mut presence_updates = Vec::new();
for (user_id, count, presence_bytes) in services().presence.presence_since(since) { for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) {
*max_edu_count = cmp::max(count, *max_edu_count); *max_edu_count = cmp::max(count, *max_edu_count);
if !user_is_local(&user_id) { if !user_is_local(&user_id) {
continue; continue;
} }
if !services() if !self
.rooms .services
.state_cache .state_cache
.server_sees_user(server_name, &user_id)? .server_sees_user(server_name, &user_id)?
{ {
continue; continue;
} }
let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?; let presence_event = self
presence_updates.push(PresenceUpdate { .services
user_id, .presence
presence: presence_event.content.presence, .from_json_bytes_to_event(&presence_bytes, &user_id)?;
currently_active: presence_event.content.currently_active.unwrap_or(false), presence_updates.push(PresenceUpdate {
last_active_ago: presence_event
.content
.last_active_ago
.unwrap_or_else(|| uint!(0)),
status_msg: presence_event.content.status_msg,
});
if presence_updates.len() >= SELECT_EDU_LIMIT {
break;
}
}
if !presence_updates.is_empty() {
let presence_content = Edu::Presence(PresenceContent::new(presence_updates));
events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized"));
}
Ok(true)
}
/// Look for read receipts in this room
fn select_edus_receipts(
room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
for r in services()
.rooms
.read_receipt
.readreceipts_since(room_id, since)
{
let (user_id, count, read_receipt) = r?;
*max_edu_count = cmp::max(count, *max_edu_count);
if !user_is_local(&user_id) {
continue;
}
let event = serde_json::from_str(read_receipt.json().get())
.map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?;
let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event {
let mut read = BTreeMap::new();
let (event_id, mut receipt) = r
.content
.0
.into_iter()
.next()
.expect("we only use one event per read receipt");
let receipt = receipt
.remove(&ReceiptType::Read)
.expect("our read receipts always set this")
.remove(&user_id)
.expect("our read receipts always have the user here");
read.insert(
user_id, user_id,
ReceiptData { presence: presence_event.content.presence,
data: receipt.clone(), currently_active: presence_event.content.currently_active.unwrap_or(false),
event_ids: vec![event_id.clone()], last_active_ago: presence_event
}, .content
); .last_active_ago
.unwrap_or_else(|| uint!(0)),
status_msg: presence_event.content.status_msg,
});
let receipt_map = ReceiptMap { if presence_updates.len() >= SELECT_EDU_LIMIT {
read, break;
};
let mut receipts = BTreeMap::new();
receipts.insert(room_id.to_owned(), receipt_map);
Edu::Receipt(ReceiptContent {
receipts,
})
} else {
Error::bad_database("Invalid event type in read_receipts");
continue;
};
events.push(serde_json::to_vec(&federation_event).expect("json can be serialized"));
if events.len() >= SELECT_EDU_LIMIT {
return Ok(false);
}
}
Ok(true)
}
async fn send_events(dest: Destination, events: Vec<SendingEvent>) -> SendingResult {
//debug_assert!(!events.is_empty(), "sending empty transaction");
match dest {
Destination::Normal(ref server) => send_events_dest_normal(&dest, server, events).await,
Destination::Appservice(ref id) => send_events_dest_appservice(&dest, id, events).await,
Destination::Push(ref userid, ref pushkey) => send_events_dest_push(&dest, userid, pushkey, events).await,
}
}
#[tracing::instrument(skip(dest, events))]
async fn send_events_dest_appservice(dest: &Destination, id: &str, events: Vec<SendingEvent>) -> SendingResult {
let mut pdu_jsons = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdu_jsons.push(
services()
.rooms
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Event in servernameevent_data not found in db."),
)
})?
.to_room_event(),
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Appservices don't need EDUs (?) and flush only;
// no new content
},
}
}
//debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction");
match appservice::send_request(
services()
.appservice
.get_registration(id)
.await
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Could not load registration from db."),
)
})?,
ruma::api::appservice::event::push_events::v1::Request {
events: pdu_jsons,
txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
)))
.into(),
},
)
.await
{
Ok(_) => Ok(dest.clone()),
Err(e) => Err((dest.clone(), e)),
}
}
#[tracing::instrument(skip(dest, events))]
async fn send_events_dest_push(
dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdus = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdus.push(
services()
.rooms
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Push] Event in servernameevent_data not found in db."),
)
})?,
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Push gateways don't need EDUs (?) and flush only;
// no new content
},
}
}
for pdu in pdus {
// Redacted events are not notification targets (we don't send push for them)
if let Some(unsigned) = &pdu.unsigned {
if let Ok(unsigned) = serde_json::from_str::<serde_json::Value>(unsigned.get()) {
if unsigned.get("redacted_because").is_some() {
continue;
}
} }
} }
let Some(pusher) = services() if !presence_updates.is_empty() {
.pusher let presence_content = Edu::Presence(PresenceContent::new(presence_updates));
.get_pusher(userid, pushkey) events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized"));
.map_err(|e| (dest.clone(), e))? }
else {
continue;
};
let rules_for_user = services() Ok(true)
.account_data
.get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap_or_default()
.and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok())
.map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global);
let unread: UInt = services()
.rooms
.user
.notification_count(userid, &pdu.room_id)
.map_err(|e| (dest.clone(), e))?
.try_into()
.expect("notification count can't go that high");
let _response = services()
.pusher
.send_push_notice(userid, unread, &pusher, rules_for_user, &pdu)
.await
.map(|_response| dest.clone())
.map_err(|e| (dest.clone(), e));
} }
Ok(dest.clone()) /// Look for read receipts in this room
} fn select_edus_receipts(
&self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
for r in self
.services
.read_receipt
.readreceipts_since(room_id, since)
{
let (user_id, count, read_receipt) = r?;
*max_edu_count = cmp::max(count, *max_edu_count);
#[tracing::instrument(skip(dest, events), name = "")] if !user_is_local(&user_id) {
async fn send_events_dest_normal( continue;
dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>, }
) -> SendingResult {
let mut pdu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Pdu(_)))
.count(),
);
let mut edu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Edu(_)))
.count(),
);
for event in &events { let event = serde_json::from_str(read_receipt.json().get())
match event { .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?;
SendingEvent::Pdu(pdu_id) => pdu_jsons.push(convert_to_outgoing_federation_event( let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event {
// TODO: check room version and remove event_id if needed let mut read = BTreeMap::new();
services()
.rooms let (event_id, mut receipt) = r
.timeline .content
.get_pdu_json_from_id(pdu_id) .0
.map_err(|e| (dest.clone(), e))? .into_iter()
.ok_or_else(|| { .next()
error!(?dest, ?server, ?pdu_id, "event not found"); .expect("we only use one event per read receipt");
( let receipt = receipt
dest.clone(), .remove(&ReceiptType::Read)
Error::bad_database("[Normal] Event in servernameevent_data not found in db."), .expect("our read receipts always set this")
) .remove(&user_id)
})?, .expect("our read receipts always have the user here");
)),
SendingEvent::Edu(edu) => { read.insert(
if let Ok(raw) = serde_json::from_slice(edu) { user_id,
edu_jsons.push(raw); ReceiptData {
} data: receipt.clone(),
}, event_ids: vec![event_id.clone()],
SendingEvent::Flush => { },
// flush only; no new content );
let receipt_map = ReceiptMap {
read,
};
let mut receipts = BTreeMap::new();
receipts.insert(room_id.to_owned(), receipt_map);
Edu::Receipt(ReceiptContent {
receipts,
})
} else {
Error::bad_database("Invalid event type in read_receipts");
continue;
};
events.push(serde_json::to_vec(&federation_event).expect("json can be serialized"));
if events.len() >= SELECT_EDU_LIMIT {
return Ok(false);
}
}
Ok(true)
}
async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult {
//debug_assert!(!events.is_empty(), "sending empty transaction");
match dest {
Destination::Normal(ref server) => self.send_events_dest_normal(&dest, server, events).await,
Destination::Appservice(ref id) => self.send_events_dest_appservice(&dest, id, events).await,
Destination::Push(ref userid, ref pushkey) => {
self.send_events_dest_push(&dest, userid, pushkey, events)
.await
}, },
} }
} }
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty #[tracing::instrument(skip(self, dest, events), name = "appservice")]
// transaction"); async fn send_events_dest_appservice(
send::send( &self, dest: &Destination, id: &str, events: Vec<SendingEvent>,
&services().client.sender, ) -> SendingResult {
server, let mut pdu_jsons = Vec::new();
send_transaction_message::v1::Request {
origin: services().server.config.server_name.clone(), for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdu_jsons.push(
self.services
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Event in servernameevent_data not found in db."),
)
})?
.to_room_event(),
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Appservices don't need EDUs (?) and flush only;
// no new content
},
}
}
//debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction");
let client = &self.services.client.appservice;
match appservice::send_request(
client,
self.services
.appservice
.get_registration(id)
.await
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Could not load registration from db."),
)
})?,
ruma::api::appservice::event::push_events::v1::Request {
events: pdu_jsons,
txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
)))
.into(),
},
)
.await
{
Ok(_) => Ok(dest.clone()),
Err(e) => Err((dest.clone(), e)),
}
}
#[tracing::instrument(skip(self, dest, events), name = "push")]
async fn send_events_dest_push(
&self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdus = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdus.push(
self.services
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Push] Event in servernameevent_data not found in db."),
)
})?,
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Push gateways don't need EDUs (?) and flush only;
// no new content
},
}
}
for pdu in pdus {
// Redacted events are not notification targets (we don't send push for them)
if let Some(unsigned) = &pdu.unsigned {
if let Ok(unsigned) = serde_json::from_str::<serde_json::Value>(unsigned.get()) {
if unsigned.get("redacted_because").is_some() {
continue;
}
}
}
let Some(pusher) = self
.services
.pusher
.get_pusher(userid, pushkey)
.map_err(|e| (dest.clone(), e))?
else {
continue;
};
let rules_for_user = self
.services
.account_data
.get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap_or_default()
.and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok())
.map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global);
let unread: UInt = self
.services
.user
.notification_count(userid, &pdu.room_id)
.map_err(|e| (dest.clone(), e))?
.try_into()
.expect("notification count can't go that high");
let _response = self
.services
.pusher
.send_push_notice(userid, unread, &pusher, rules_for_user, &pdu)
.await
.map(|_response| dest.clone())
.map_err(|e| (dest.clone(), e));
}
Ok(dest.clone())
}
#[tracing::instrument(skip(self, dest, events), name = "", level = "debug")]
async fn send_events_dest_normal(
&self, dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Pdu(_)))
.count(),
);
let mut edu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Edu(_)))
.count(),
);
for event in &events {
match event {
// TODO: check room version and remove event_id if needed
SendingEvent::Pdu(pdu_id) => pdu_jsons.push(
self.convert_to_outgoing_federation_event(
self.services
.timeline
.get_pdu_json_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
error!(?dest, ?server, ?pdu_id, "event not found");
(
dest.clone(),
Error::bad_database("[Normal] Event in servernameevent_data not found in db."),
)
})?,
),
),
SendingEvent::Edu(edu) => {
if let Ok(raw) = serde_json::from_slice(edu) {
edu_jsons.push(raw);
}
},
SendingEvent::Flush => {}, // flush only; no new content
}
}
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty
// transaction");
let transaction_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
));
let request = send_transaction_message::v1::Request {
origin: self.server.config.server_name.clone(),
pdus: pdu_jsons, pdus: pdu_jsons,
edus: edu_jsons, edus: edu_jsons,
origin_server_ts: MilliSecondsSinceUnixEpoch::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
transaction_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( transaction_id: transaction_id.into(),
&events };
let client = &self.services.client.sender;
self.send(client, server, request)
.await
.inspect(|response| {
response
.pdus
.iter() .iter()
.map(|e| match e { .filter(|(_, res)| res.is_err())
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, .for_each(|(pdu_id, res)| warn!("error for {pdu_id} from remote: {res:?}"));
SendingEvent::Flush => &[], })
}) .map(|_| dest.clone())
.collect::<Vec<_>>(), .map_err(|e| (dest.clone(), e))
))) }
.into(),
}, /// This does not return a full `Pdu` it is only to satisfy ruma's types.
) pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
.await if let Some(unsigned) = pdu_json
.map(|response| { .get_mut("unsigned")
for pdu in response.pdus { .and_then(|val| val.as_object_mut())
if pdu.1.is_err() { {
warn!("error for {} from remote: {:?}", pdu.0, pdu.1); unsigned.remove("transaction_id");
}
// room v3 and above removed the "event_id" field from remote PDU format
if let Some(room_id) = pdu_json
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match self.services.state.get_room_version(&room_id) {
Ok(room_version_id) => match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => _ = pdu_json.remove("event_id"),
},
Err(_) => _ = pdu_json.remove("event_id"),
} }
} else {
pdu_json.remove("event_id");
} }
dest.clone()
})
.map_err(|e| (dest.clone(), e))
}
/// This does not return a full `Pdu` it is only to satisfy ruma's types. // TODO: another option would be to convert it to a canonical string to validate
#[tracing::instrument] // size and return a Result<Raw<...>>
pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> { // serde_json::from_str::<Raw<_>>(
if let Some(unsigned) = pdu_json // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
.get_mut("unsigned") // valid serde_json::Value"), )
.and_then(|val| val.as_object_mut()) // .expect("Raw::from_value always works")
{
unsigned.remove("transaction_id"); to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
} }
// room v3 and above removed the "event_id" field from remote PDU format
if let Some(room_id) = pdu_json
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match services().rooms.state.get_room_version(&room_id) {
Ok(room_version_id) => match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => _ = pdu_json.remove("event_id"),
},
Err(_) => _ = pdu_json.remove("event_id"),
}
} else {
pdu_json.remove("event_id");
}
// TODO: another option would be to convert it to a canonical string to validate
// size and return a Result<Raw<...>>
// serde_json::from_str::<Raw<_>>(
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
// valid serde_json::Value"), )
// .expect("Raw::from_value always works")
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
} }

View File

@ -3,7 +3,7 @@ use std::{
collections::BTreeMap, collections::BTreeMap,
fmt::Write, fmt::Write,
ops::Deref, ops::Deref,
sync::{Arc, OnceLock}, sync::{Arc, OnceLock, RwLock},
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -50,20 +50,20 @@ pub(crate) struct Args<'a> {
} }
/// Dep is a reference to a service used within another service. /// Dep is a reference to a service used within another service.
/// Circular-dependencies between services require this indirection to allow the /// Circular-dependencies between services require this indirection.
/// referenced service construction after the referencing service.
pub(crate) struct Dep<T> { pub(crate) struct Dep<T> {
dep: OnceLock<Arc<T>>, dep: OnceLock<Arc<T>>,
service: Arc<Map>, service: Arc<Map>,
name: &'static str, name: &'static str,
} }
pub(crate) type Map = BTreeMap<String, MapVal>; pub(crate) type Map = RwLock<BTreeMap<String, MapVal>>;
pub(crate) type MapVal = (Arc<dyn Service>, Arc<dyn Any + Send + Sync>); pub(crate) type MapVal = (Arc<dyn Service>, Arc<dyn Any + Send + Sync>);
impl<T: Any + Send + Sync> Deref for Dep<T> { impl<T: Send + Sync + 'static> Deref for Dep<T> {
type Target = Arc<T>; type Target = Arc<T>;
/// Dereference a dependency. The dependency must be ready or panics.
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.dep self.dep
.get_or_init(|| require::<T>(&self.service, self.name)) .get_or_init(|| require::<T>(&self.service, self.name))
@ -71,39 +71,61 @@ impl<T: Any + Send + Sync> Deref for Dep<T> {
} }
impl Args<'_> { impl Args<'_> {
pub(crate) fn depend_service<T: Any + Send + Sync>(&self, name: &'static str) -> Dep<T> { /// Create a lazy-reference to a service when constructing another Service.
pub(crate) fn depend<T: Send + Sync + 'static>(&self, name: &'static str) -> Dep<T> {
Dep::<T> { Dep::<T> {
dep: OnceLock::new(), dep: OnceLock::new(),
service: self.service.clone(), service: self.service.clone(),
name, name,
} }
} }
/// Create a reference immediately to a service when constructing another
/// Service. The other service must be constructed.
pub(crate) fn require<T: Send + Sync + 'static>(&self, name: &str) -> Arc<T> { require::<T>(self.service, name) }
} }
pub(crate) fn require<T: Any + Send + Sync>(map: &Map, name: &str) -> Arc<T> { /// Reference a Service by name. Panics if the Service does not exist or was
/// incorrectly cast.
pub(crate) fn require<T: Send + Sync + 'static>(map: &Map, name: &str) -> Arc<T> {
try_get::<T>(map, name) try_get::<T>(map, name)
.inspect_err(inspect_log) .inspect_err(inspect_log)
.expect("Failure to reference service required by another service.") .expect("Failure to reference service required by another service.")
} }
pub(crate) fn try_get<T: Any + Send + Sync>(map: &Map, name: &str) -> Result<Arc<T>> { /// Reference a Service by name. Returns Err if the Service does not exist or
map.get(name).map_or_else( /// was incorrectly cast.
|| Err!("Service {name:?} does not exist or has not been built yet."), pub(crate) fn try_get<T: Send + Sync + 'static>(map: &Map, name: &str) -> Result<Arc<T>> {
|(_, s)| { map.read()
.expect("locked for reading")
.get(name)
.map_or_else(
|| Err!("Service {name:?} does not exist or has not been built yet."),
|(_, s)| {
s.clone()
.downcast::<T>()
.map_err(|_| err!("Service {name:?} must be correctly downcast."))
},
)
}
/// Reference a Service by name. Returns None if the Service does not exist, but
/// panics if incorrectly cast.
///
/// # Panics
/// Incorrect type is not a silent failure (None) as the type never has a reason
/// to be incorrect.
pub(crate) fn get<T: Send + Sync + 'static>(map: &Map, name: &str) -> Option<Arc<T>> {
map.read()
.expect("locked for reading")
.get(name)
.map(|(_, s)| {
s.clone() s.clone()
.downcast::<T>() .downcast::<T>()
.map_err(|_| err!("Service {name:?} must be correctly downcast.")) .expect("Service must be correctly downcast.")
}, })
)
}
pub(crate) fn get<T: Any + Send + Sync>(map: &Map, name: &str) -> Option<Arc<T>> {
map.get(name).map(|(_, s)| {
s.clone()
.downcast::<T>()
.expect("Service must be correctly downcast.")
})
} }
/// Utility for service implementations; see Service::name() in the trait.
#[inline] #[inline]
pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 }

View File

@ -1,11 +1,16 @@
use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc}; use std::{
any::Any,
collections::BTreeMap,
fmt::Write,
sync::{Arc, RwLock},
};
use conduit::{debug, debug_info, info, trace, Result, Server}; use conduit::{debug, debug_info, info, trace, Result, Server};
use database::Database; use database::Database;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::{ use crate::{
account_data, admin, appservice, client, globals, key_backups, account_data, admin, appservice, client, emergency, globals, key_backups,
manager::Manager, manager::Manager,
media, presence, pusher, resolver, rooms, sending, service, media, presence, pusher, resolver, rooms, sending, service,
service::{Args, Map, Service}, service::{Args, Map, Service},
@ -13,22 +18,23 @@ use crate::{
}; };
pub struct Services { pub struct Services {
pub resolver: Arc<resolver::Service>,
pub client: Arc<client::Service>,
pub globals: Arc<globals::Service>,
pub rooms: rooms::Service,
pub appservice: Arc<appservice::Service>,
pub pusher: Arc<pusher::Service>,
pub transaction_ids: Arc<transaction_ids::Service>,
pub uiaa: Arc<uiaa::Service>,
pub users: Arc<users::Service>,
pub account_data: Arc<account_data::Service>, pub account_data: Arc<account_data::Service>,
pub presence: Arc<presence::Service>,
pub admin: Arc<admin::Service>, pub admin: Arc<admin::Service>,
pub appservice: Arc<appservice::Service>,
pub client: Arc<client::Service>,
pub emergency: Arc<emergency::Service>,
pub globals: Arc<globals::Service>,
pub key_backups: Arc<key_backups::Service>, pub key_backups: Arc<key_backups::Service>,
pub media: Arc<media::Service>, pub media: Arc<media::Service>,
pub presence: Arc<presence::Service>,
pub pusher: Arc<pusher::Service>,
pub resolver: Arc<resolver::Service>,
pub rooms: rooms::Service,
pub sending: Arc<sending::Service>, pub sending: Arc<sending::Service>,
pub transaction_ids: Arc<transaction_ids::Service>,
pub uiaa: Arc<uiaa::Service>,
pub updates: Arc<updates::Service>, pub updates: Arc<updates::Service>,
pub users: Arc<users::Service>,
manager: Mutex<Option<Arc<Manager>>>, manager: Mutex<Option<Arc<Manager>>>,
pub(crate) service: Arc<Map>, pub(crate) service: Arc<Map>,
@ -36,37 +42,34 @@ pub struct Services {
pub db: Arc<Database>, pub db: Arc<Database>,
} }
macro_rules! build_service {
($map:ident, $server:ident, $db:ident, $tyname:ty) => {{
let built = <$tyname>::build(Args {
server: &$server,
db: &$db,
service: &$map,
})?;
Arc::get_mut(&mut $map)
.expect("must have mutable reference to services collection")
.insert(built.name().to_owned(), (built.clone(), built.clone()));
trace!("built service #{}: {:?}", $map.len(), built.name());
built
}};
}
impl Services { impl Services {
#[allow(clippy::cognitive_complexity)] #[allow(clippy::cognitive_complexity)]
pub fn build(server: Arc<Server>, db: Arc<Database>) -> Result<Self> { pub fn build(server: Arc<Server>, db: Arc<Database>) -> Result<Self> {
let mut service: Arc<Map> = Arc::new(BTreeMap::new()); let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new()));
macro_rules! build { macro_rules! build {
($srv:ty) => { ($tyname:ty) => {{
build_service!(service, server, db, $srv) let built = <$tyname>::build(Args {
}; db: &db,
server: &server,
service: &service,
})?;
add_service(&service, built.clone(), built.clone());
built
}};
} }
Ok(Self { Ok(Self {
globals: build!(globals::Service), account_data: build!(account_data::Service),
admin: build!(admin::Service),
appservice: build!(appservice::Service),
resolver: build!(resolver::Service), resolver: build!(resolver::Service),
client: build!(client::Service), client: build!(client::Service),
emergency: build!(emergency::Service),
globals: build!(globals::Service),
key_backups: build!(key_backups::Service),
media: build!(media::Service),
presence: build!(presence::Service),
pusher: build!(pusher::Service),
rooms: rooms::Service { rooms: rooms::Service {
alias: build!(rooms::alias::Service), alias: build!(rooms::alias::Service),
auth_chain: build!(rooms::auth_chain::Service), auth_chain: build!(rooms::auth_chain::Service),
@ -79,28 +82,22 @@ impl Services {
read_receipt: build!(rooms::read_receipt::Service), read_receipt: build!(rooms::read_receipt::Service),
search: build!(rooms::search::Service), search: build!(rooms::search::Service),
short: build!(rooms::short::Service), short: build!(rooms::short::Service),
spaces: build!(rooms::spaces::Service),
state: build!(rooms::state::Service), state: build!(rooms::state::Service),
state_accessor: build!(rooms::state_accessor::Service), state_accessor: build!(rooms::state_accessor::Service),
state_cache: build!(rooms::state_cache::Service), state_cache: build!(rooms::state_cache::Service),
state_compressor: build!(rooms::state_compressor::Service), state_compressor: build!(rooms::state_compressor::Service),
timeline: build!(rooms::timeline::Service),
threads: build!(rooms::threads::Service), threads: build!(rooms::threads::Service),
timeline: build!(rooms::timeline::Service),
typing: build!(rooms::typing::Service), typing: build!(rooms::typing::Service),
spaces: build!(rooms::spaces::Service),
user: build!(rooms::user::Service), user: build!(rooms::user::Service),
}, },
appservice: build!(appservice::Service), sending: build!(sending::Service),
pusher: build!(pusher::Service),
transaction_ids: build!(transaction_ids::Service), transaction_ids: build!(transaction_ids::Service),
uiaa: build!(uiaa::Service), uiaa: build!(uiaa::Service),
users: build!(users::Service),
account_data: build!(account_data::Service),
presence: build!(presence::Service),
admin: build!(admin::Service),
key_backups: build!(key_backups::Service),
media: build!(media::Service),
sending: build!(sending::Service),
updates: build!(updates::Service), updates: build!(updates::Service),
users: build!(users::Service),
manager: Mutex::new(None), manager: Mutex::new(None),
service, service,
server, server,
@ -111,7 +108,7 @@ impl Services {
pub(super) async fn start(&self) -> Result<()> { pub(super) async fn start(&self) -> Result<()> {
debug_info!("Starting services..."); debug_info!("Starting services...");
globals::migrations::migrations(&self.db, &self.server.config).await?; globals::migrations::migrations(self).await?;
self.manager self.manager
.lock() .lock()
.await .await
@ -144,7 +141,7 @@ impl Services {
} }
pub async fn clear_cache(&self) { pub async fn clear_cache(&self) {
for (service, ..) in self.service.values() { for (service, ..) in self.service.read().expect("locked for reading").values() {
service.clear_cache(); service.clear_cache();
} }
@ -159,7 +156,7 @@ impl Services {
pub async fn memory_usage(&self) -> Result<String> { pub async fn memory_usage(&self) -> Result<String> {
let mut out = String::new(); let mut out = String::new();
for (service, ..) in self.service.values() { for (service, ..) in self.service.read().expect("locked for reading").values() {
service.memory_usage(&mut out)?; service.memory_usage(&mut out)?;
} }
@ -179,23 +176,26 @@ impl Services {
fn interrupt(&self) { fn interrupt(&self) {
debug!("Interrupting services..."); debug!("Interrupting services...");
for (name, (service, ..)) in self.service.iter() { for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() {
trace!("Interrupting {name}"); trace!("Interrupting {name}");
service.interrupt(); service.interrupt();
} }
} }
pub fn try_get<T>(&self, name: &str) -> Result<Arc<T>> pub fn try_get<T: Send + Sync + 'static>(&self, name: &str) -> Result<Arc<T>> {
where
T: Any + Send + Sync,
{
service::try_get::<T>(&self.service, name) service::try_get::<T>(&self.service, name)
} }
pub fn get<T>(&self, name: &str) -> Option<Arc<T>> pub fn get<T: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<T>> { service::get::<T>(&self.service, name) }
where }
T: Any + Send + Sync,
{ fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) {
service::get::<T>(&self.service, name) let name = s.name();
} let len = map.read().expect("locked for reading").len();
trace!("built service #{len}: {name:?}");
map.write()
.expect("locked for writing")
.insert(name.to_owned(), (s, a));
} }

View File

@ -2,7 +2,7 @@ mod data;
use std::sync::Arc; use std::sync::Arc;
use conduit::{utils, utils::hash, Error, Result}; use conduit::{error, utils, utils::hash, Error, Result, Server};
use data::Data; use data::Data;
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -11,19 +11,30 @@ use ruma::{
}, },
CanonicalJsonValue, DeviceId, UserId, CanonicalJsonValue, DeviceId, UserId,
}; };
use tracing::error;
use crate::services; use crate::{globals, users, Dep};
pub const SESSION_ID_LENGTH: usize = 32; pub const SESSION_ID_LENGTH: usize = 32;
pub struct Service { pub struct Service {
server: Arc<Server>,
services: Services,
pub db: Data, pub db: Data,
} }
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
server: args.server.clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(args.db), db: Data::new(args.db),
})) }))
} }
@ -87,11 +98,11 @@ impl Service {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
}; };
let user_id = UserId::parse_with_server_name(username.clone(), services().globals.server_name()) let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
// Check if password is correct // Check if password is correct
if let Some(hash) = services().users.password_hash(&user_id)? { if let Some(hash) = self.services.users.password_hash(&user_id)? {
let hash_matches = hash::verify_password(password, &hash).is_ok(); let hash_matches = hash::verify_password(password, &hash).is_ok();
if !hash_matches { if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {
@ -106,7 +117,7 @@ impl Service {
uiaainfo.completed.push(AuthType::Password); uiaainfo.completed.push(AuthType::Password);
}, },
AuthData::RegistrationToken(t) => { AuthData::RegistrationToken(t) => {
if Some(t.token.trim()) == services().globals.config.registration_token.as_deref() { if Some(t.token.trim()) == self.server.config.registration_token.as_deref() {
uiaainfo.completed.push(AuthType::RegistrationToken); uiaainfo.completed.push(AuthType::RegistrationToken);
} else { } else {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {

View File

@ -7,14 +7,20 @@ use ruma::events::room::message::RoomMessageEventContent;
use serde::Deserialize; use serde::Deserialize;
use tokio::{sync::Notify, time::interval}; use tokio::{sync::Notify, time::interval};
use crate::services; use crate::{admin, client, Dep};
pub struct Service { pub struct Service {
services: Services,
db: Arc<Map>, db: Arc<Map>,
interrupt: Notify, interrupt: Notify,
interval: Duration, interval: Duration,
} }
struct Services {
admin: Dep<admin::Service>,
client: Dep<client::Service>,
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct CheckForUpdatesResponse { struct CheckForUpdatesResponse {
updates: Vec<CheckForUpdatesResponseEntry>, updates: Vec<CheckForUpdatesResponseEntry>,
@ -35,6 +41,10 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
services: Services {
admin: args.depend::<admin::Service>("admin"),
client: args.depend::<client::Service>("client"),
},
db: args.db["global"].clone(), db: args.db["global"].clone(),
interrupt: Notify::new(), interrupt: Notify::new(),
interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL),
@ -63,7 +73,8 @@ impl crate::Service for Service {
impl Service { impl Service {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn handle_updates(&self) -> Result<()> { async fn handle_updates(&self) -> Result<()> {
let response = services() let response = self
.services
.client .client
.default .default
.get(CHECK_FOR_UPDATES_URL) .get(CHECK_FOR_UPDATES_URL)
@ -78,7 +89,7 @@ impl Service {
last_update_id = last_update_id.max(update.id); last_update_id = last_update_id.max(update.id);
if update.id > self.last_check_for_updates_id()? { if update.id > self.last_check_for_updates_id()? {
info!("{:#}", update.message); info!("{:#}", update.message);
services() self.services
.admin .admin
.send_message(RoomMessageEventContent::text_markdown(format!( .send_message(RoomMessageEventContent::text_markdown(format!(
"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",

View File

@ -1,7 +1,7 @@
use std::{collections::BTreeMap, mem::size_of, sync::Arc}; use std::{collections::BTreeMap, mem::size_of, sync::Arc};
use conduit::{debug_info, err, utils, warn, Err, Error, Result}; use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server};
use database::{Database, Map}; use database::Map;
use ruma::{ use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
@ -11,52 +11,65 @@ use ruma::{
OwnedMxcUri, OwnedUserId, UInt, UserId, OwnedMxcUri, OwnedUserId, UInt, UserId,
}; };
use crate::{services, users::clean_signatures}; use crate::{globals, rooms, users::clean_signatures, Dep};
pub struct Data { pub struct Data {
userid_password: Arc<Map>, keychangeid_userid: Arc<Map>,
keyid_key: Arc<Map>,
onetimekeyid_onetimekeys: Arc<Map>,
openidtoken_expiresatuserid: Arc<Map>,
todeviceid_events: Arc<Map>,
token_userdeviceid: Arc<Map>, token_userdeviceid: Arc<Map>,
userid_displayname: Arc<Map>, userdeviceid_metadata: Arc<Map>,
userdeviceid_token: Arc<Map>,
userfilterid_filter: Arc<Map>,
userid_avatarurl: Arc<Map>, userid_avatarurl: Arc<Map>,
userid_blurhash: Arc<Map>, userid_blurhash: Arc<Map>,
userid_devicelistversion: Arc<Map>, userid_devicelistversion: Arc<Map>,
userdeviceid_token: Arc<Map>, userid_displayname: Arc<Map>,
userdeviceid_metadata: Arc<Map>,
onetimekeyid_onetimekeys: Arc<Map>,
userid_lastonetimekeyupdate: Arc<Map>, userid_lastonetimekeyupdate: Arc<Map>,
keyid_key: Arc<Map>,
userid_masterkeyid: Arc<Map>, userid_masterkeyid: Arc<Map>,
userid_password: Arc<Map>,
userid_selfsigningkeyid: Arc<Map>, userid_selfsigningkeyid: Arc<Map>,
userid_usersigningkeyid: Arc<Map>, userid_usersigningkeyid: Arc<Map>,
openidtoken_expiresatuserid: Arc<Map>, services: Services,
keychangeid_userid: Arc<Map>, }
todeviceid_events: Arc<Map>,
userfilterid_filter: Arc<Map>, struct Services {
_db: Arc<Database>, server: Arc<Server>,
globals: Dep<globals::Service>,
state_cache: Dep<rooms::state_cache::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
} }
impl Data { impl Data {
pub(super) fn new(db: Arc<Database>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self { Self {
userid_password: db["userid_password"].clone(), keychangeid_userid: db["keychangeid_userid"].clone(),
keyid_key: db["keyid_key"].clone(),
onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(),
openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(),
todeviceid_events: db["todeviceid_events"].clone(),
token_userdeviceid: db["token_userdeviceid"].clone(), token_userdeviceid: db["token_userdeviceid"].clone(),
userid_displayname: db["userid_displayname"].clone(), userdeviceid_metadata: db["userdeviceid_metadata"].clone(),
userdeviceid_token: db["userdeviceid_token"].clone(),
userfilterid_filter: db["userfilterid_filter"].clone(),
userid_avatarurl: db["userid_avatarurl"].clone(), userid_avatarurl: db["userid_avatarurl"].clone(),
userid_blurhash: db["userid_blurhash"].clone(), userid_blurhash: db["userid_blurhash"].clone(),
userid_devicelistversion: db["userid_devicelistversion"].clone(), userid_devicelistversion: db["userid_devicelistversion"].clone(),
userdeviceid_token: db["userdeviceid_token"].clone(), userid_displayname: db["userid_displayname"].clone(),
userdeviceid_metadata: db["userdeviceid_metadata"].clone(),
onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(),
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
keyid_key: db["keyid_key"].clone(),
userid_masterkeyid: db["userid_masterkeyid"].clone(), userid_masterkeyid: db["userid_masterkeyid"].clone(),
userid_password: db["userid_password"].clone(),
userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(), userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(),
userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(), userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(),
openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(), services: Services {
keychangeid_userid: db["keychangeid_userid"].clone(), server: args.server.clone(),
todeviceid_events: db["todeviceid_events"].clone(), globals: args.depend::<globals::Service>("globals"),
userfilterid_filter: db["userfilterid_filter"].clone(), state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
_db: db, state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
},
} }
} }
@ -377,7 +390,7 @@ impl Data {
)?; )?;
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?;
Ok(()) Ok(())
} }
@ -403,7 +416,7 @@ impl Data {
prefix.push(b':'); prefix.push(b':');
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys self.onetimekeyid_onetimekeys
.scan_prefix(prefix) .scan_prefix(prefix)
@ -631,16 +644,16 @@ impl Data {
} }
pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes(); let count = self.services.globals.next_count()?.to_be_bytes();
for room_id in services() for room_id in self
.rooms .services
.state_cache .state_cache
.rooms_joined(user_id) .rooms_joined(user_id)
.filter_map(Result::ok) .filter_map(Result::ok)
{ {
// Don't send key updates to unencrypted rooms // Don't send key updates to unencrypted rooms
if services() if self
.rooms .services
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")? .room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none() .is_none()
@ -750,7 +763,7 @@ impl Data {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes()); key.extend_from_slice(target_device_id.as_bytes());
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new(); let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into()); json.insert("type".to_owned(), event_type.to_owned().into());
@ -916,7 +929,7 @@ impl Data {
pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> { pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
use std::num::Saturating as Sat; use std::num::Saturating as Sat;
let expires_in = services().globals.config.openid_token_ttl; let expires_in = self.services.server.config.openid_token_ttl;
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
let mut value = expires_at.0.to_be_bytes().to_vec(); let mut value = expires_at.0.to_be_bytes().to_vec();

View File

@ -7,7 +7,6 @@ use std::{
}; };
use conduit::{Error, Result}; use conduit::{Error, Result};
use data::Data;
use ruma::{ use ruma::{
api::client::{ api::client::{
device::Device, device::Device,
@ -24,7 +23,8 @@ use ruma::{
UInt, UserId, UInt, UserId,
}; };
use crate::services; use self::data::Data;
use crate::{admin, rooms, Dep};
pub struct SlidingSyncCache { pub struct SlidingSyncCache {
lists: BTreeMap<String, SyncRequestList>, lists: BTreeMap<String, SyncRequestList>,
@ -36,14 +36,24 @@ pub struct SlidingSyncCache {
type DbConnections = Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>; type DbConnections = Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>;
pub struct Service { pub struct Service {
services: Services,
pub db: Data, pub db: Data,
pub connections: DbConnections, pub connections: DbConnections,
} }
struct Services {
admin: Dep<admin::Service>,
state_cache: Dep<rooms::state_cache::Service>,
}
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(args.db.clone()), services: Services {
admin: args.depend::<admin::Service>("admin"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
},
db: Data::new(&args),
connections: StdMutex::new(BTreeMap::new()), connections: StdMutex::new(BTreeMap::new()),
})) }))
} }
@ -247,11 +257,8 @@ impl Service {
/// Check if a user is an admin /// Check if a user is an admin
pub fn is_admin(&self, user_id: &UserId) -> Result<bool> { pub fn is_admin(&self, user_id: &UserId) -> Result<bool> {
if let Some(admin_room_id) = services().admin.get_admin_room()? { if let Some(admin_room_id) = self.services.admin.get_admin_room()? {
services() self.services.state_cache.is_joined(user_id, &admin_room_id)
.rooms
.state_cache
.is_joined(user_id, &admin_room_id)
} else { } else {
Ok(false) Ok(false)
} }