Database Refactor

combine service/users data w/ mod unit

split sliding sync related out of service/users

instrument database entry points

remove increment crap from database interface

de-wrap all database get() calls

de-wrap all database insert() calls

de-wrap all database remove() calls

refactor database interface for async streaming

add query key serializer for database

implement Debug for result handle

add query deserializer for database

add deserialization trait for option handle

start a stream utils suite

de-wrap/asyncify/type-query count_one_time_keys()

de-wrap/asyncify users count

add admin query users command suite

de-wrap/asyncify users exists

de-wrap/partially asyncify user filter related

asyncify/de-wrap users device/keys related

asyncify/de-wrap user auth/misc related

asyncify/de-wrap users blurhash

asyncify/de-wrap account_data get; merge Data into Service

partial asyncify/de-wrap uiaa; merge Data into Service

partially asyncify/de-wrap transaction_ids get; merge Data into Service

partially asyncify/de-wrap key_backups; merge Data into Service

asyncify/de-wrap pusher service getters; merge Data into Service

asyncify/de-wrap rooms alias getters/some iterators

asyncify/de-wrap rooms directory getters/iterator

partially asyncify/de-wrap rooms lazy-loading

partially asyncify/de-wrap rooms metadata

asyncify/dewrap rooms outlier

asyncify/dewrap rooms pdu_metadata

dewrap/partially asyncify rooms read receipt

de-wrap rooms search service

de-wrap/partially asyncify rooms user service

partial de-wrap rooms state_compressor

de-wrap rooms state_cache

de-wrap room state et al

de-wrap rooms timeline service

additional users device/keys related

de-wrap/asyncify sender

asyncify services

refactor database to TryFuture/TryStream

refactor services for TryFuture/TryStream

asyncify api handlers

additional asyncification for admin module

abstract stream related; support reverse streams

additional stream conversions

asyncify state-res related

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-08-08 17:18:30 +00:00 committed by strawberry
parent 6001014078
commit 946ca364e0
203 changed files with 12202 additions and 10709 deletions

53
Cargo.lock generated
View File

@ -626,10 +626,11 @@ dependencies = [
"clap",
"conduit_api",
"conduit_core",
"conduit_database",
"conduit_macros",
"conduit_service",
"const-str",
"futures-util",
"futures",
"log",
"ruma",
"serde_json",
@ -652,7 +653,7 @@ dependencies = [
"conduit_database",
"conduit_service",
"const-str",
"futures-util",
"futures",
"hmac",
"http",
"http-body-util",
@ -689,6 +690,7 @@ dependencies = [
"cyborgtime",
"either",
"figment",
"futures",
"hardened_malloc-rs",
"http",
"http-body-util",
@ -726,8 +728,11 @@ version = "0.4.7"
dependencies = [
"conduit_core",
"const-str",
"futures",
"log",
"rust-rocksdb-uwu",
"serde",
"serde_json",
"tokio",
"tracing",
]
@ -784,7 +789,7 @@ dependencies = [
"conduit_core",
"conduit_database",
"const-str",
"futures-util",
"futures",
"hickory-resolver",
"http",
"image",
@ -1283,6 +1288,20 @@ dependencies = [
"new_debug_unreachable",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -1345,6 +1364,7 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
@ -2953,7 +2973,7 @@ dependencies = [
[[package]]
name = "ruma"
version = "0.10.1"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"assign",
"js_int",
@ -2975,7 +2995,7 @@ dependencies = [
[[package]]
name = "ruma-appservice-api"
version = "0.10.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"js_int",
"ruma-common",
@ -2987,7 +3007,7 @@ dependencies = [
[[package]]
name = "ruma-client-api"
version = "0.18.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"as_variant",
"assign",
@ -3010,7 +3030,7 @@ dependencies = [
[[package]]
name = "ruma-common"
version = "0.13.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"as_variant",
"base64 0.22.1",
@ -3040,7 +3060,7 @@ dependencies = [
[[package]]
name = "ruma-events"
version = "0.28.1"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"as_variant",
"indexmap 2.6.0",
@ -3064,7 +3084,7 @@ dependencies = [
[[package]]
name = "ruma-federation-api"
version = "0.9.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"bytes",
"http",
@ -3082,7 +3102,7 @@ dependencies = [
[[package]]
name = "ruma-identifiers-validation"
version = "0.9.5"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"js_int",
"thiserror",
@ -3091,7 +3111,7 @@ dependencies = [
[[package]]
name = "ruma-identity-service-api"
version = "0.9.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"js_int",
"ruma-common",
@ -3101,7 +3121,7 @@ dependencies = [
[[package]]
name = "ruma-macros"
version = "0.13.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"cfg-if",
"once_cell",
@ -3117,7 +3137,7 @@ dependencies = [
[[package]]
name = "ruma-push-gateway-api"
version = "0.9.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"js_int",
"ruma-common",
@ -3129,7 +3149,7 @@ dependencies = [
[[package]]
name = "ruma-server-util"
version = "0.3.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"headers",
"http",
@ -3142,7 +3162,7 @@ dependencies = [
[[package]]
name = "ruma-signatures"
version = "0.15.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"base64 0.22.1",
"ed25519-dalek",
@ -3158,8 +3178,9 @@ dependencies = [
[[package]]
name = "ruma-state-res"
version = "0.11.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd"
source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [
"futures-util",
"itertools 0.12.1",
"js_int",
"ruma-common",

View File

@ -210,9 +210,10 @@ features = [
"string",
]
[workspace.dependencies.futures-util]
[workspace.dependencies.futures]
version = "0.3.30"
default-features = false
features = ["std"]
[workspace.dependencies.tokio]
version = "1.40.0"
@ -314,7 +315,7 @@ version = "0.1.2"
[workspace.dependencies.ruma]
git = "https://github.com/girlbossceo/ruwuma"
#branch = "conduwuit-changes"
rev = "9900d0676564883cfade556d6e8da2a2c9061efd"
rev = "e7db44989d68406393270d3a91815597385d3acb"
features = [
"compat",
"rand",
@ -463,7 +464,6 @@ version = "1.0.36"
[workspace.dependencies.proc-macro2]
version = "1.0.89"
#
# Patches
#
@ -828,6 +828,7 @@ missing_panics_doc = { level = "allow", priority = 1 }
module_name_repetitions = { level = "allow", priority = 1 }
no_effect_underscore_binding = { level = "allow", priority = 1 }
similar_names = { level = "allow", priority = 1 }
single_match_else = { level = "allow", priority = 1 }
struct_field_names = { level = "allow", priority = 1 }
unnecessary_wraps = { level = "allow", priority = 1 }
unused_async = { level = "allow", priority = 1 }

View File

@ -2,6 +2,6 @@ array-size-threshold = 4096
cognitive-complexity-threshold = 94 # TODO reduce me ALARA
excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5
future-size-threshold = 7745 # TODO reduce me ALARA
stack-size-threshold = 144000 # reduce me ALARA
stack-size-threshold = 196608 # reduce me ALARA
too-many-lines-threshold = 700 # TODO reduce me to <= 100
type-complexity-threshold = 250 # reduce me to ~200

View File

@ -29,10 +29,11 @@ release_max_log_level = [
clap.workspace = true
conduit-api.workspace = true
conduit-core.workspace = true
conduit-database.workspace = true
conduit-macros.workspace = true
conduit-service.workspace = true
const-str.workspace = true
futures-util.workspace = true
futures.workspace = true
log.workspace = true
ruma.workspace = true
serde_json.workspace = true

View File

@ -1,5 +1,6 @@
use conduit::Result;
use conduit_macros::implement;
use futures::StreamExt;
use ruma::events::room::message::RoomMessageEventContent;
use crate::Command;
@ -10,14 +11,12 @@ use crate::Command;
#[implement(Command, params = "<'_>")]
pub(super) async fn check_all_users(&self) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let results = self.services.users.db.iter();
let users = self.services.users.iter().collect::<Vec<_>>().await;
let query_time = timer.elapsed();
let users = results.collect::<Vec<_>>();
let total = users.len();
let err_count = users.iter().filter(|user| user.is_err()).count();
let ok_count = users.iter().filter(|user| user.is_ok()).count();
let err_count = users.iter().filter(|_user| false).count();
let ok_count = users.iter().filter(|_user| true).count();
let message = format!(
"Database query completed in {query_time:?}:\n\n```\nTotal entries: {total:?}\nFailure/Invalid user count: \

View File

@ -7,6 +7,7 @@ use std::{
use api::client::validate_and_add_event_id;
use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result};
use futures::StreamExt;
use ruma::{
api::{client::error::ErrorKind, federation::event::get_room_state},
events::room::message::RoomMessageEventContent,
@ -27,7 +28,7 @@ pub(super) async fn echo(&self, message: Vec<String>) -> Result<RoomMessageEvent
#[admin_command]
pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let event_id = Arc::<EventId>::from(event_id);
if let Some(event) = self.services.rooms.timeline.get_pdu_json(&event_id)? {
if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await {
let room_id_str = event
.get("room_id")
.and_then(|val| val.as_str())
@ -43,7 +44,8 @@ pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<Room
.auth_chain
.event_ids_iter(room_id, vec![event_id])
.await?
.count();
.count()
.await;
let elapsed = start.elapsed();
Ok(RoomMessageEventContent::text_plain(format!(
@ -91,13 +93,16 @@ pub(super) async fn get_pdu(&self, event_id: Box<EventId>) -> Result<RoomMessage
.services
.rooms
.timeline
.get_non_outlier_pdu_json(&event_id)?;
if pdu_json.is_none() {
.get_non_outlier_pdu_json(&event_id)
.await;
if pdu_json.is_err() {
outlier = true;
pdu_json = self.services.rooms.timeline.get_pdu_json(&event_id)?;
pdu_json = self.services.rooms.timeline.get_pdu_json(&event_id).await;
}
match pdu_json {
Some(json) => {
Ok(json) => {
let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json");
Ok(RoomMessageEventContent::notice_markdown(format!(
"{}\n```json\n{}\n```",
@ -109,7 +114,7 @@ pub(super) async fn get_pdu(&self, event_id: Box<EventId>) -> Result<RoomMessage
json_text
)))
},
None => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")),
Err(_) => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")),
}
}
@ -157,7 +162,8 @@ pub(super) async fn get_remote_pdu_list(
.send_message(RoomMessageEventContent::text_plain(format!(
"Failed to get remote PDU, ignoring error: {e}"
)))
.await;
.await
.ok();
warn!("Failed to get remote PDU, ignoring error: {e}");
} else {
success_count = success_count.saturating_add(1);
@ -215,7 +221,9 @@ pub(super) async fn get_remote_pdu(
.services
.rooms
.event_handler
.parse_incoming_pdu(&response.pdu);
.parse_incoming_pdu(&response.pdu)
.await;
let (event_id, value, room_id) = match parsed_result {
Ok(t) => t,
Err(e) => {
@ -333,9 +341,12 @@ pub(super) async fn ping(&self, server: Box<ServerName>) -> Result<RoomMessageEv
#[admin_command]
pub(super) async fn force_device_list_updates(&self) -> Result<RoomMessageEventContent> {
// Force E2EE device list updates for all users
for user_id in self.services.users.iter().filter_map(Result::ok) {
self.services.users.mark_device_key_update(&user_id)?;
}
self.services
.users
.stream()
.for_each(|user_id| self.services.users.mark_device_key_update(user_id))
.await;
Ok(RoomMessageEventContent::text_plain(
"Marked all devices for all users as having new keys to update",
))
@ -470,7 +481,8 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Roo
.services
.rooms
.state_cache
.server_in_room(&self.services.globals.config.server_name, &room_id)?
.server_in_room(&self.services.globals.config.server_name, &room_id)
.await
{
return Ok(RoomMessageEventContent::text_plain(
"We are not participating in the room / we don't know about the room ID.",
@ -481,8 +493,9 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Roo
.services
.rooms
.timeline
.first_pdu_in_room(&room_id)?
.ok_or_else(|| Error::bad_database("Failed to find the first PDU in database"))?;
.first_pdu_in_room(&room_id)
.await
.map_err(|_| Error::bad_database("Failed to find the first PDU in database"))?;
Ok(RoomMessageEventContent::text_plain(format!("{first_pdu:?}")))
}
@ -494,7 +507,8 @@ pub(super) async fn latest_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Ro
.services
.rooms
.state_cache
.server_in_room(&self.services.globals.config.server_name, &room_id)?
.server_in_room(&self.services.globals.config.server_name, &room_id)
.await
{
return Ok(RoomMessageEventContent::text_plain(
"We are not participating in the room / we don't know about the room ID.",
@ -505,8 +519,9 @@ pub(super) async fn latest_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Ro
.services
.rooms
.timeline
.latest_pdu_in_room(&room_id)?
.ok_or_else(|| Error::bad_database("Failed to find the latest PDU in database"))?;
.latest_pdu_in_room(&room_id)
.await
.map_err(|_| Error::bad_database("Failed to find the latest PDU in database"))?;
Ok(RoomMessageEventContent::text_plain(format!("{latest_pdu:?}")))
}
@ -520,7 +535,8 @@ pub(super) async fn force_set_room_state_from_server(
.services
.rooms
.state_cache
.server_in_room(&self.services.globals.config.server_name, &room_id)?
.server_in_room(&self.services.globals.config.server_name, &room_id)
.await
{
return Ok(RoomMessageEventContent::text_plain(
"We are not participating in the room / we don't know about the room ID.",
@ -531,10 +547,11 @@ pub(super) async fn force_set_room_state_from_server(
.services
.rooms
.timeline
.latest_pdu_in_room(&room_id)?
.ok_or_else(|| Error::bad_database("Failed to find the latest PDU in database"))?;
.latest_pdu_in_room(&room_id)
.await
.map_err(|_| Error::bad_database("Failed to find the latest PDU in database"))?;
let room_version = self.services.rooms.state.get_room_version(&room_id)?;
let room_version = self.services.rooms.state.get_room_version(&room_id).await?;
let mut state: HashMap<u64, Arc<EventId>> = HashMap::new();
let pub_key_map = RwLock::new(BTreeMap::new());
@ -554,13 +571,21 @@ pub(super) async fn force_set_room_state_from_server(
let mut events = Vec::with_capacity(remote_state_response.pdus.len());
for pdu in remote_state_response.pdus.clone() {
events.push(match self.services.rooms.event_handler.parse_incoming_pdu(&pdu) {
Ok(t) => t,
Err(e) => {
warn!("Could not parse PDU, ignoring: {e}");
continue;
events.push(
match self
.services
.rooms
.event_handler
.parse_incoming_pdu(&pdu)
.await
{
Ok(t) => t,
Err(e) => {
warn!("Could not parse PDU, ignoring: {e}");
continue;
},
},
});
);
}
info!("Fetching required signing keys for all the state events we got");
@ -587,13 +612,16 @@ pub(super) async fn force_set_room_state_from_server(
self.services
.rooms
.outlier
.add_pdu_outlier(&event_id, &value)?;
.add_pdu_outlier(&event_id, &value);
if let Some(state_key) = &pdu.state_key {
let shortstatekey = self
.services
.rooms
.short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?;
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)
.await;
state.insert(shortstatekey, pdu.event_id.clone());
}
}
@ -611,7 +639,7 @@ pub(super) async fn force_set_room_state_from_server(
self.services
.rooms
.outlier
.add_pdu_outlier(&event_id, &value)?;
.add_pdu_outlier(&event_id, &value);
}
let new_room_state = self
@ -626,7 +654,8 @@ pub(super) async fn force_set_room_state_from_server(
.services
.rooms
.state_compressor
.save_state(room_id.clone().as_ref(), new_room_state)?;
.save_state(room_id.clone().as_ref(), new_room_state)
.await?;
let state_lock = self.services.rooms.state.mutex.lock(&room_id).await;
self.services
@ -642,7 +671,8 @@ pub(super) async fn force_set_room_state_from_server(
self.services
.rooms
.state_cache
.update_joined_count(&room_id)?;
.update_joined_count(&room_id)
.await;
drop(state_lock);
@ -656,7 +686,7 @@ pub(super) async fn get_signing_keys(
&self, server_name: Option<Box<ServerName>>, _cached: bool,
) -> Result<RoomMessageEventContent> {
let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into());
let signing_keys = self.services.globals.signing_keys_for(&server_name)?;
let signing_keys = self.services.globals.signing_keys_for(&server_name).await?;
Ok(RoomMessageEventContent::notice_markdown(format!(
"```rs\n{signing_keys:#?}\n```"
@ -674,7 +704,7 @@ pub(super) async fn get_verify_keys(
if cached {
writeln!(out, "| Key ID | VerifyKey |")?;
writeln!(out, "| --- | --- |")?;
for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name)? {
for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name).await? {
writeln!(out, "| {key_id} | {verify_key:?} |")?;
}

View File

@ -1,19 +1,20 @@
use std::fmt::Write;
use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId};
use crate::{admin_command, escape_html, get_room_info};
#[admin_command]
pub(super) async fn disable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
self.services.rooms.metadata.disable_room(&room_id, true)?;
self.services.rooms.metadata.disable_room(&room_id, true);
Ok(RoomMessageEventContent::text_plain("Room disabled."))
}
#[admin_command]
pub(super) async fn enable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
self.services.rooms.metadata.disable_room(&room_id, false)?;
self.services.rooms.metadata.disable_room(&room_id, false);
Ok(RoomMessageEventContent::text_plain("Room enabled."))
}
@ -85,7 +86,7 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result<
));
}
if !self.services.users.exists(&user_id)? {
if !self.services.users.exists(&user_id).await {
return Ok(RoomMessageEventContent::text_plain(
"Remote user does not exist in our database.",
));
@ -96,9 +97,9 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result<
.rooms
.state_cache
.rooms_joined(&user_id)
.filter_map(Result::ok)
.map(|room_id| get_room_info(self.services, &room_id))
.collect();
.then(|room_id| get_room_info(self.services, room_id))
.collect()
.await;
if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("User is not in any rooms."));

View File

@ -36,7 +36,7 @@ pub(super) async fn delete(
let mut mxc_urls = Vec::with_capacity(4);
// parsing the PDU for any MXC URLs begins here
if let Some(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id)? {
if let Ok(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id).await {
if let Some(content_key) = event_json.get("content") {
debug!("Event ID has \"content\".");
let content_obj = content_key.as_object();
@ -300,7 +300,7 @@ pub(super) async fn delete_all_from_server(
#[admin_command]
pub(super) async fn get_file_info(&self, mxc: OwnedMxcUri) -> Result<RoomMessageEventContent> {
let mxc: Mxc<'_> = mxc.as_str().try_into()?;
let metadata = self.services.media.get_metadata(&mxc);
let metadata = self.services.media.get_metadata(&mxc).await;
Ok(RoomMessageEventContent::notice_markdown(format!("```\n{metadata:#?}\n```")))
}

View File

@ -17,7 +17,7 @@ use conduit::{
utils::string::{collect_stream, common_prefix},
warn, Error, Result,
};
use futures_util::future::FutureExt;
use futures::future::FutureExt;
use ruma::{
events::{
relation::InReplyTo,

View File

@ -44,7 +44,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_
let timer = tokio::time::Instant::now();
let results = services
.account_data
.changes_since(room_id.as_deref(), &user_id, since)?;
.changes_since(room_id.as_deref(), &user_id, since)
.await?;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -59,7 +60,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_
let timer = tokio::time::Instant::now();
let results = services
.account_data
.get(room_id.as_deref(), &user_id, kind)?;
.get(room_id.as_deref(), &user_id, kind)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -29,7 +29,9 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_>
let results = services
.appservice
.db
.get_registration(appservice_id.as_ref());
.get_registration(appservice_id.as_ref())
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -38,7 +40,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_>
},
AppserviceCommand::All => {
let timer = tokio::time::Instant::now();
let results = services.appservice.all();
let results = services.appservice.all().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -29,7 +29,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) -
match subcommand {
GlobalsCommand::DatabaseVersion => {
let timer = tokio::time::Instant::now();
let results = services.globals.db.database_version();
let results = services.globals.db.database_version().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -47,7 +47,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) -
},
GlobalsCommand::LastCheckForUpdatesId => {
let timer = tokio::time::Instant::now();
let results = services.updates.last_check_for_updates_id();
let results = services.updates.last_check_for_updates_id().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -67,7 +67,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) -
origin,
} => {
let timer = tokio::time::Instant::now();
let results = services.globals.db.verify_keys_for(&origin);
let results = services.globals.db.verify_keys_for(&origin).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand;
use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, UserId};
use crate::Command;
@ -30,7 +31,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>)
user_id,
} => {
let timer = tokio::time::Instant::now();
let results = services.presence.db.get_presence(&user_id)?;
let results = services.presence.db.get_presence(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -42,7 +43,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>)
} => {
let timer = tokio::time::Instant::now();
let results = services.presence.db.presence_since(since);
let presence_since: Vec<(_, _, _)> = results.collect();
let presence_since: Vec<(_, _, _)> = results.collect().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -21,7 +21,7 @@ pub(super) async fn process(subcommand: PusherCommand, context: &Command<'_>) ->
user_id,
} => {
let timer = tokio::time::Instant::now();
let results = services.pusher.get_pushers(&user_id)?;
let results = services.pusher.get_pushers(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand;
use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId};
use crate::Command;
@ -31,7 +32,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>)
alias,
} => {
let timer = tokio::time::Instant::now();
let results = services.rooms.alias.resolve_local_alias(&alias);
let results = services.rooms.alias.resolve_local_alias(&alias).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -43,7 +44,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>)
} => {
let timer = tokio::time::Instant::now();
let results = services.rooms.alias.local_aliases_for_room(&room_id);
let aliases: Vec<_> = results.collect();
let aliases: Vec<_> = results.collect().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -52,8 +53,12 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>)
},
RoomAliasCommand::AllLocalAliases => {
let timer = tokio::time::Instant::now();
let results = services.rooms.alias.all_local_aliases();
let aliases: Vec<_> = results.collect();
let aliases = services
.rooms
.alias
.all_local_aliases()
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand;
use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId};
use crate::Command;
@ -86,7 +87,11 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let result = services.rooms.state_cache.server_in_room(&server, &room_id);
let result = services
.rooms
.state_cache
.server_in_room(&server, &room_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -97,7 +102,13 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.room_servers(&room_id).collect();
let results: Vec<_> = services
.rooms
.state_cache
.room_servers(&room_id)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -108,7 +119,13 @@ pub(super) async fn process(
server,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.server_rooms(&server).collect();
let results: Vec<_> = services
.rooms
.state_cache
.server_rooms(&server)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -119,7 +136,13 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.room_members(&room_id).collect();
let results: Vec<_> = services
.rooms
.state_cache
.room_members(&room_id)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -134,7 +157,9 @@ pub(super) async fn process(
.rooms
.state_cache
.local_users_in_room(&room_id)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -149,7 +174,9 @@ pub(super) async fn process(
.rooms
.state_cache
.active_local_users_in_room(&room_id)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -160,7 +187,7 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let results = services.rooms.state_cache.room_joined_count(&room_id);
let results = services.rooms.state_cache.room_joined_count(&room_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -171,7 +198,11 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let results = services.rooms.state_cache.room_invited_count(&room_id);
let results = services
.rooms
.state_cache
.room_invited_count(&room_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -182,11 +213,13 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services
let results: Vec<_> = services
.rooms
.state_cache
.room_useroncejoined(&room_id)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -197,11 +230,13 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services
let results: Vec<_> = services
.rooms
.state_cache
.room_members_invited(&room_id)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -216,7 +251,8 @@ pub(super) async fn process(
let results = services
.rooms
.state_cache
.get_invite_count(&room_id, &user_id);
.get_invite_count(&room_id, &user_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -231,7 +267,8 @@ pub(super) async fn process(
let results = services
.rooms
.state_cache
.get_left_count(&room_id, &user_id);
.get_left_count(&room_id, &user_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -242,7 +279,13 @@ pub(super) async fn process(
user_id,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.rooms_joined(&user_id).collect();
let results: Vec<_> = services
.rooms
.state_cache
.rooms_joined(&user_id)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -253,7 +296,12 @@ pub(super) async fn process(
user_id,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.rooms_invited(&user_id).collect();
let results: Vec<_> = services
.rooms
.state_cache
.rooms_invited(&user_id)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -264,7 +312,12 @@ pub(super) async fn process(
user_id,
} => {
let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.rooms_left(&user_id).collect();
let results: Vec<_> = services
.rooms
.state_cache
.rooms_left(&user_id)
.collect()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -276,7 +329,11 @@ pub(super) async fn process(
room_id,
} => {
let timer = tokio::time::Instant::now();
let results = services.rooms.state_cache.invite_state(&user_id, &room_id);
let results = services
.rooms
.state_cache
.invite_state(&user_id, &room_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand;
use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId};
use service::sending::Destination;
@ -68,7 +69,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
SendingCommand::ActiveRequests => {
let timer = tokio::time::Instant::now();
let results = services.sending.db.active_requests();
let active_requests: Result<Vec<(_, _, _)>> = results.collect();
let active_requests = results.collect::<Vec<_>>().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -133,7 +134,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
},
};
let queued_requests = results.collect::<Result<Vec<(_, _)>>>();
let queued_requests = results.collect::<Vec<_>>().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -199,7 +200,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
},
};
let active_requests = results.collect::<Result<Vec<(_, _)>>>();
let active_requests = results.collect::<Vec<_>>().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
@ -210,7 +211,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
server_name,
} => {
let timer = tokio::time::Instant::now();
let results = services.sending.db.get_latest_educount(&server_name);
let results = services.sending.db.get_latest_educount(&server_name).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,29 +1,344 @@
use clap::Subcommand;
use conduit::Result;
use ruma::events::room::message::RoomMessageEventContent;
use futures::stream::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedDeviceId, OwnedRoomId, OwnedUserId};
use crate::Command;
use crate::{admin_command, admin_command_dispatch};
#[admin_command_dispatch]
#[derive(Debug, Subcommand)]
/// All the getters and iterators from src/database/key_value/users.rs
pub(crate) enum UsersCommand {
Iter,
CountUsers,
IterUsers,
PasswordHash {
user_id: OwnedUserId,
},
ListDevices {
user_id: OwnedUserId,
},
ListDevicesMetadata {
user_id: OwnedUserId,
},
GetDeviceMetadata {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetDevicesVersion {
user_id: OwnedUserId,
},
CountOneTimeKeys {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetDeviceKeys {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetUserSigningKey {
user_id: OwnedUserId,
},
GetMasterKey {
user_id: OwnedUserId,
},
GetToDeviceEvents {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetLatestBackup {
user_id: OwnedUserId,
},
GetLatestBackupVersion {
user_id: OwnedUserId,
},
GetBackupAlgorithm {
user_id: OwnedUserId,
version: String,
},
GetAllBackups {
user_id: OwnedUserId,
version: String,
},
GetRoomBackups {
user_id: OwnedUserId,
version: String,
room_id: OwnedRoomId,
},
GetBackupSession {
user_id: OwnedUserId,
version: String,
room_id: OwnedRoomId,
session_id: String,
},
}
/// All the getters and iterators in key_value/users.rs
pub(super) async fn process(subcommand: UsersCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> {
let services = context.services;
#[admin_command]
async fn get_backup_session(
&self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, session_id: String,
) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.key_backups
.get_session(&user_id, &version, &room_id, &session_id)
.await;
let query_time = timer.elapsed();
match subcommand {
UsersCommand::Iter => {
let timer = tokio::time::Instant::now();
let results = services.users.db.iter();
let users = results.collect::<Vec<_>>();
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{users:#?}\n```"
)))
},
}
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_room_backups(
&self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId,
) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.key_backups
.get_room(&user_id, &version, &room_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_all_backups(&self, user_id: OwnedUserId, version: String) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.key_backups.get_all(&user_id, &version).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_backup_algorithm(&self, user_id: OwnedUserId, version: String) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.key_backups
.get_backup(&user_id, &version)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_latest_backup_version(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.key_backups
.get_latest_backup_version(&user_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_latest_backup(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.key_backups.get_latest_backup(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn iter_users(&self) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result: Vec<OwnedUserId> = self.services.users.stream().map(Into::into).collect().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn count_users(&self) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.users.count().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn password_hash(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.users.password_hash(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn list_devices(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let devices = self
.services
.users
.all_device_ids(&user_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```"
)))
}
#[admin_command]
async fn list_devices_metadata(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let devices = self
.services
.users
.all_devices_metadata(&user_id)
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```"
)))
}
#[admin_command]
async fn get_device_metadata(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let device = self
.services
.users
.get_device_metadata(&user_id, &device_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```"
)))
}
#[admin_command]
async fn get_devices_version(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let device = self.services.users.get_devicelist_version(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```"
)))
}
#[admin_command]
async fn count_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.count_one_time_keys(&user_id, &device_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_device_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.get_device_keys(&user_id, &device_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_user_signing_key(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.users.get_user_signing_key(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_master_key(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.get_master_key(None, &user_id, &|_| true)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_to_device_events(
&self, user_id: OwnedUserId, device_id: OwnedDeviceId,
) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.get_to_device_events(&user_id, &device_id)
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}

View File

@ -2,7 +2,8 @@ use std::fmt::Write;
use clap::Subcommand;
use conduit::Result;
use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId};
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::{escape_html, Command};
@ -66,8 +67,8 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
force,
room_id,
..
} => match (force, services.rooms.alias.resolve_local_alias(&room_alias)) {
(true, Ok(Some(id))) => match services
} => match (force, services.rooms.alias.resolve_local_alias(&room_alias).await) {
(true, Ok(id)) => match services
.rooms
.alias
.set_alias(&room_alias, &room_id, server_user)
@ -77,10 +78,10 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
))),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))),
},
(false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!(
(false, Ok(id)) => Ok(RoomMessageEventContent::text_plain(format!(
"Refusing to overwrite in use alias for {id}, use -f or --force to overwrite"
))),
(_, Ok(None)) => match services
(_, Err(_)) => match services
.rooms
.alias
.set_alias(&room_alias, &room_id, server_user)
@ -88,12 +89,11 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))),
},
(_, Err(err)) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))),
},
RoomAliasCommand::Remove {
..
} => match services.rooms.alias.resolve_local_alias(&room_alias) {
Ok(Some(id)) => match services
} => match services.rooms.alias.resolve_local_alias(&room_alias).await {
Ok(id) => match services
.rooms
.alias
.remove_alias(&room_alias, server_user)
@ -102,15 +102,13 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {id}"))),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))),
},
Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))),
Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
},
RoomAliasCommand::Which {
..
} => match services.rooms.alias.resolve_local_alias(&room_alias) {
Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))),
Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))),
} => match services.rooms.alias.resolve_local_alias(&room_alias).await {
Ok(id) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))),
Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
},
RoomAliasCommand::List {
..
@ -125,63 +123,59 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
.rooms
.alias
.local_aliases_for_room(&room_id)
.collect::<Result<Vec<_>, _>>();
match aliases {
Ok(aliases) => {
let plain_list = aliases.iter().fold(String::new(), |mut output, alias| {
writeln!(output, "- {alias}").expect("should be able to write to string buffer");
output
});
.map(Into::into)
.collect::<Vec<OwnedRoomAliasId>>()
.await;
let html_list = aliases.iter().fold(String::new(), |mut output, alias| {
writeln!(output, "<li>{}</li>", escape_html(alias.as_ref()))
.expect("should be able to write to string buffer");
output
});
let plain_list = aliases.iter().fold(String::new(), |mut output, alias| {
writeln!(output, "- {alias}").expect("should be able to write to string buffer");
output
});
let plain = format!("Aliases for {room_id}:\n{plain_list}");
let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html))
},
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {err}"))),
}
let html_list = aliases.iter().fold(String::new(), |mut output, alias| {
writeln!(output, "<li>{}</li>", escape_html(alias.as_ref()))
.expect("should be able to write to string buffer");
output
});
let plain = format!("Aliases for {room_id}:\n{plain_list}");
let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html))
} else {
let aliases = services
.rooms
.alias
.all_local_aliases()
.collect::<Result<Vec<_>, _>>();
match aliases {
Ok(aliases) => {
let server_name = services.globals.server_name();
let plain_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!(output, "- `{alias}` -> #{id}:{server_name}")
.expect("should be able to write to string buffer");
output
});
.map(|(room_id, localpart)| (room_id.into(), localpart.into()))
.collect::<Vec<(OwnedRoomId, String)>>()
.await;
let html_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!(
output,
"<li><code>{}</code> -> #{}:{}</li>",
escape_html(alias.as_ref()),
escape_html(id.as_ref()),
server_name
)
.expect("should be able to write to string buffer");
output
});
let server_name = services.globals.server_name();
let plain_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!(output, "- `{alias}` -> #{id}:{server_name}")
.expect("should be able to write to string buffer");
output
});
let plain = format!("Aliases:\n{plain_list}");
let html = format!("Aliases:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))),
}
let html_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!(
output,
"<li><code>{}</code> -> #{}:{}</li>",
escape_html(alias.as_ref()),
escape_html(id),
server_name
)
.expect("should be able to write to string buffer");
output
});
let plain = format!("Aliases:\n{plain_list}");
let html = format!("Aliases:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html))
}
},
}

View File

@ -1,11 +1,12 @@
use conduit::Result;
use ruma::events::room::message::RoomMessageEventContent;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId};
use crate::{admin_command, get_room_info, PAGE_SIZE};
#[admin_command]
pub(super) async fn list_rooms(
&self, page: Option<usize>, exclude_disabled: bool, exclude_banned: bool, no_details: bool,
&self, page: Option<usize>, _exclude_disabled: bool, _exclude_banned: bool, no_details: bool,
) -> Result<RoomMessageEventContent> {
// TODO: i know there's a way to do this with clap, but i can't seem to find it
let page = page.unwrap_or(1);
@ -14,37 +15,12 @@ pub(super) async fn list_rooms(
.rooms
.metadata
.iter_ids()
.filter_map(|room_id| {
room_id
.ok()
.filter(|room_id| {
if exclude_disabled
&& self
.services
.rooms
.metadata
.is_disabled(room_id)
.unwrap_or(false)
{
return false;
}
//.filter(|room_id| async { !exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await })
//.filter(|room_id| async { !exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await })
.then(|room_id| get_room_info(self.services, room_id))
.collect::<Vec<_>>()
.await;
if exclude_banned
&& self
.services
.rooms
.metadata
.is_banned(room_id)
.unwrap_or(false)
{
return false;
}
true
})
.map(|room_id| get_room_info(self.services, &room_id))
})
.collect::<Vec<_>>();
rooms.sort_by_key(|r| r.1);
rooms.reverse();
@ -74,3 +50,10 @@ pub(super) async fn list_rooms(
Ok(RoomMessageEventContent::notice_markdown(output_plain))
}
#[admin_command]
pub(super) async fn exists(&self, room_id: OwnedRoomId) -> Result<RoomMessageEventContent> {
let result = self.services.rooms.metadata.exists(&room_id).await;
Ok(RoomMessageEventContent::notice_markdown(format!("{result}")))
}

View File

@ -2,7 +2,8 @@ use std::fmt::Write;
use clap::Subcommand;
use conduit::Result;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId};
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomId};
use crate::{escape_html, get_room_info, Command, PAGE_SIZE};
@ -31,15 +32,15 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_>
match command {
RoomDirectoryCommand::Publish {
room_id,
} => match services.rooms.directory.set_public(&room_id) {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))),
} => {
services.rooms.directory.set_public(&room_id);
Ok(RoomMessageEventContent::notice_plain("Room published"))
},
RoomDirectoryCommand::Unpublish {
room_id,
} => match services.rooms.directory.set_not_public(&room_id) {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))),
} => {
services.rooms.directory.set_not_public(&room_id);
Ok(RoomMessageEventContent::notice_plain("Room unpublished"))
},
RoomDirectoryCommand::List {
page,
@ -50,9 +51,10 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_>
.rooms
.directory
.public_rooms()
.filter_map(Result::ok)
.map(|id: OwnedRoomId| get_room_info(services, &id))
.collect::<Vec<_>>();
.then(|room_id| get_room_info(services, room_id))
.collect::<Vec<_>>()
.await;
rooms.sort_by_key(|r| r.1);
rooms.reverse();

View File

@ -1,5 +1,6 @@
use clap::Subcommand;
use conduit::Result;
use conduit::{utils::ReadyExt, Result};
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomId};
use crate::{admin_command, admin_command_dispatch};
@ -32,46 +33,42 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R
.rooms
.state_accessor
.get_name(&room_id)
.ok()
.flatten()
.unwrap_or_else(|| room_id.to_string());
.await
.unwrap_or_else(|_| room_id.to_string());
let members = self
let member_info: Vec<_> = self
.services
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|member| {
.ready_filter(|user_id| {
if local_only {
member
.ok()
.filter(|user| self.services.globals.user_is_local(user))
self.services.globals.user_is_local(user_id)
} else {
member.ok()
true
}
});
let member_info = members
.into_iter()
.map(|user_id| {
(
user_id.clone(),
})
.filter_map(|user_id| async move {
let user_id = user_id.to_owned();
Some((
self.services
.users
.displayname(&user_id)
.unwrap_or(None)
.unwrap_or_else(|| user_id.to_string()),
)
.await
.unwrap_or_else(|_| user_id.to_string()),
user_id,
))
})
.collect::<Vec<_>>();
.collect()
.await;
let output_plain = format!(
"{} Members in Room \"{}\":\n```\n{}\n```",
member_info.len(),
room_name,
member_info
.iter()
.map(|(mxid, displayname)| format!("{mxid} | {displayname}"))
.into_iter()
.map(|(displayname, mxid)| format!("{mxid} | {displayname}"))
.collect::<Vec<_>>()
.join("\n")
);
@ -81,11 +78,12 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R
#[admin_command]
async fn view_room_topic(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
let Some(room_topic) = self
let Ok(room_topic) = self
.services
.rooms
.state_accessor
.get_room_topic(&room_id)?
.get_room_topic(&room_id)
.await
else {
return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set."));
};

View File

@ -6,6 +6,7 @@ mod moderation;
use clap::Subcommand;
use conduit::Result;
use ruma::OwnedRoomId;
use self::{
alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand,
@ -49,4 +50,9 @@ pub(super) enum RoomCommand {
#[command(subcommand)]
/// - Manage the room directory
Directory(RoomDirectoryCommand),
/// - Check if we know about a room
Exists {
room_id: OwnedRoomId,
},
}

View File

@ -1,6 +1,11 @@
use api::client::leave_room;
use clap::Subcommand;
use conduit::{debug, error, info, warn, Result};
use conduit::{
debug, error, info,
utils::{IterStream, ReadyExt},
warn, Result,
};
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId};
use crate::{admin_command, admin_command_dispatch, get_room_info};
@ -76,7 +81,7 @@ async fn ban_room(
let admin_room_alias = &self.services.globals.admin_alias;
if let Some(admin_room_id) = self.services.admin.get_admin_room()? {
if let Ok(admin_room_id) = self.services.admin.get_admin_room().await {
if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) {
return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room."));
}
@ -95,7 +100,7 @@ async fn ban_room(
debug!("Room specified is a room ID, banning room ID");
self.services.rooms.metadata.ban_room(&room_id, true)?;
self.services.rooms.metadata.ban_room(&room_id, true);
room_id
} else if room.is_room_alias_id() {
@ -114,7 +119,13 @@ async fn ban_room(
get_alias_helper to fetch room ID remotely"
);
let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? {
let room_id = if let Ok(room_id) = self
.services
.rooms
.alias
.resolve_local_alias(&room_alias)
.await
{
room_id
} else {
debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation");
@ -138,7 +149,7 @@ async fn ban_room(
}
};
self.services.rooms.metadata.ban_room(&room_id, true)?;
self.services.rooms.metadata.ban_room(&room_id, true);
room_id
} else {
@ -150,56 +161,40 @@ async fn ban_room(
debug!("Making all users leave the room {}", &room);
if force {
for local_user in self
let mut users = self
.services
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
self.services.globals.user_is_local(local_user)
// additional wrapped check here is to avoid adding remote users
// who are in the admin room to the list of local users (would
// fail auth check)
&& (self.services.globals.user_is_local(local_user)
// since this is a force operation, assume user is an admin
// if somehow this fails
&& self.services
.users
.is_admin(local_user)
.unwrap_or(true))
})
}) {
.ready_filter(|user| self.services.globals.user_is_local(user))
.boxed();
while let Some(local_user) = users.next().await {
debug!(
"Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)",
&local_user, &room_id
"Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \
admins too)",
);
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await {
if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
warn!(%e, "Failed to leave room");
}
}
} else {
for local_user in self
let mut users = self
.services
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
local_user.server_name() == self.services.globals.server_name()
// additional wrapped check here is to avoid adding remote users
// who are in the admin room to the list of local users (would fail auth check)
&& (local_user.server_name()
== self.services.globals.server_name()
&& !self.services
.users
.is_admin(local_user)
.unwrap_or(false))
})
}) {
.ready_filter(|user| self.services.globals.user_is_local(user))
.boxed();
while let Some(local_user) = users.next().await {
if self.services.users.is_admin(local_user).await {
continue;
}
debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await {
if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
error!(
"Error attempting to make local user {} leave room {} during room banning: {}",
&local_user, &room_id, e
@ -214,12 +209,14 @@ async fn ban_room(
}
// remove any local aliases, ignore errors
for ref local_alias in self
for local_alias in &self
.services
.rooms
.alias
.local_aliases_for_room(&room_id)
.filter_map(Result::ok)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await
{
_ = self
.services
@ -230,10 +227,10 @@ async fn ban_room(
}
// unpublish from room directory, ignore errors
_ = self.services.rooms.directory.set_not_public(&room_id);
self.services.rooms.directory.set_not_public(&room_id);
if disable_federation {
self.services.rooms.metadata.disable_room(&room_id, true)?;
self.services.rooms.metadata.disable_room(&room_id, true);
return Ok(RoomMessageEventContent::text_plain(
"Room banned, removed all our local users, and disabled incoming federation with room.",
));
@ -268,7 +265,7 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
for &room in &rooms_s {
match <&RoomOrAliasId>::try_from(room) {
Ok(room_alias_or_id) => {
if let Some(admin_room_id) = self.services.admin.get_admin_room()? {
if let Ok(admin_room_id) = self.services.admin.get_admin_room().await {
if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) {
info!("User specified admin room in bulk ban list, ignoring");
continue;
@ -300,43 +297,48 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
if room_alias_or_id.is_room_alias_id() {
match RoomAliasId::parse(room_alias_or_id) {
Ok(room_alias) => {
let room_id =
if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? {
room_id
} else {
debug!(
"We don't have this room alias to a room ID locally, attempting to fetch room \
ID over federation"
);
let room_id = if let Ok(room_id) = self
.services
.rooms
.alias
.resolve_local_alias(&room_alias)
.await
{
room_id
} else {
debug!(
"We don't have this room alias to a room ID locally, attempting to fetch room ID \
over federation"
);
match self
.services
.rooms
.alias
.resolve_alias(&room_alias, None)
.await
{
Ok((room_id, servers)) => {
debug!(
?room_id,
?servers,
"Got federation response fetching room ID for {room}",
);
room_id
},
Err(e) => {
// don't fail if force blocking
if force {
warn!("Failed to resolve room alias {room} to a room ID: {e}");
continue;
}
match self
.services
.rooms
.alias
.resolve_alias(&room_alias, None)
.await
{
Ok((room_id, servers)) => {
debug!(
?room_id,
?servers,
"Got federation response fetching room ID for {room}",
);
room_id
},
Err(e) => {
// don't fail if force blocking
if force {
warn!("Failed to resolve room alias {room} to a room ID: {e}");
continue;
}
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to resolve room alias {room} to a room ID: {e}"
)));
},
}
};
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to resolve room alias {room} to a room ID: {e}"
)));
},
}
};
room_ids.push(room_id);
},
@ -374,74 +376,52 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
}
for room_id in room_ids {
if self
.services
.rooms
.metadata
.ban_room(&room_id, true)
.is_ok()
{
debug!("Banned {room_id} successfully");
room_ban_count = room_ban_count.saturating_add(1);
}
self.services.rooms.metadata.ban_room(&room_id, true);
debug!("Banned {room_id} successfully");
room_ban_count = room_ban_count.saturating_add(1);
debug!("Making all users leave the room {}", &room_id);
if force {
for local_user in self
let mut users = self
.services
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
local_user.server_name() == self.services.globals.server_name()
// additional wrapped check here is to avoid adding remote
// users who are in the admin room to the list of local
// users (would fail auth check)
&& (local_user.server_name()
== self.services.globals.server_name()
// since this is a force operation, assume user is an
// admin if somehow this fails
&& self.services
.users
.is_admin(local_user)
.unwrap_or(true))
})
}) {
.ready_filter(|user| self.services.globals.user_is_local(user))
.boxed();
while let Some(local_user) = users.next().await {
debug!(
"Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)",
&local_user, room_id
"Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \
admins too)",
);
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await {
if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
warn!(%e, "Failed to leave room");
}
}
} else {
for local_user in self
let mut users = self
.services
.rooms
.state_cache
.room_members(&room_id)
.filter_map(|user| {
user.ok().filter(|local_user| {
local_user.server_name() == self.services.globals.server_name()
// additional wrapped check here is to avoid adding remote
// users who are in the admin room to the list of local
// users (would fail auth check)
&& (local_user.server_name()
== self.services.globals.server_name()
&& !self.services
.users
.is_admin(local_user)
.unwrap_or(false))
})
}) {
debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await {
.ready_filter(|user| self.services.globals.user_is_local(user))
.boxed();
while let Some(local_user) = users.next().await {
if self.services.users.is_admin(local_user).await {
continue;
}
debug!("Attempting leave for user {local_user} in room {room_id}");
if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
error!(
"Error attempting to make local user {} leave room {} during bulk room banning: {}",
&local_user, &room_id, e
"Error attempting to make local user {local_user} leave room {room_id} during bulk room \
banning: {e}",
);
return Ok(RoomMessageEventContent::text_plain(format!(
"Error attempting to make local user {} leave room {} during room banning (room is still \
banned but not removing any more users and not banning any more rooms): {}\nIf you would \
@ -453,26 +433,26 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
}
// remove any local aliases, ignore errors
for ref local_alias in self
.services
self.services
.rooms
.alias
.local_aliases_for_room(&room_id)
.filter_map(Result::ok)
{
_ = self
.services
.rooms
.alias
.remove_alias(local_alias, &self.services.globals.server_user)
.await;
}
.map(ToOwned::to_owned)
.for_each(|local_alias| async move {
self.services
.rooms
.alias
.remove_alias(&local_alias, &self.services.globals.server_user)
.await
.ok();
})
.await;
// unpublish from room directory, ignore errors
_ = self.services.rooms.directory.set_not_public(&room_id);
self.services.rooms.directory.set_not_public(&room_id);
if disable_federation {
self.services.rooms.metadata.disable_room(&room_id, true)?;
self.services.rooms.metadata.disable_room(&room_id, true);
}
}
@ -503,7 +483,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
debug!("Room specified is a room ID, unbanning room ID");
self.services.rooms.metadata.ban_room(&room_id, false)?;
self.services.rooms.metadata.ban_room(&room_id, false);
room_id
} else if room.is_room_alias_id() {
@ -522,7 +502,13 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
get_alias_helper to fetch room ID remotely"
);
let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? {
let room_id = if let Ok(room_id) = self
.services
.rooms
.alias
.resolve_local_alias(&room_alias)
.await
{
room_id
} else {
debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation");
@ -546,7 +532,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
}
};
self.services.rooms.metadata.ban_room(&room_id, false)?;
self.services.rooms.metadata.ban_room(&room_id, false);
room_id
} else {
@ -557,7 +543,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
};
if enable_federation {
self.services.rooms.metadata.disable_room(&room_id, false)?;
self.services.rooms.metadata.disable_room(&room_id, false);
return Ok(RoomMessageEventContent::text_plain("Room unbanned."));
}
@ -569,45 +555,42 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
#[admin_command]
async fn list_banned_rooms(&self, no_details: bool) -> Result<RoomMessageEventContent> {
let rooms = self
let room_ids = self
.services
.rooms
.metadata
.list_banned_rooms()
.collect::<Result<Vec<_>, _>>();
.map(Into::into)
.collect::<Vec<OwnedRoomId>>()
.await;
match rooms {
Ok(room_ids) => {
if room_ids.is_empty() {
return Ok(RoomMessageEventContent::text_plain("No rooms are banned."));
}
let mut rooms = room_ids
.into_iter()
.map(|room_id| get_room_info(self.services, &room_id))
.collect::<Vec<_>>();
rooms.sort_by_key(|r| r.1);
rooms.reverse();
let output_plain = format!(
"Rooms Banned ({}):\n```\n{}\n```",
rooms.len(),
rooms
.iter()
.map(|(id, members, name)| if no_details {
format!("{id}")
} else {
format!("{id}\tMembers: {members}\tName: {name}")
})
.collect::<Vec<_>>()
.join("\n")
);
Ok(RoomMessageEventContent::notice_markdown(output_plain))
},
Err(e) => {
error!("Failed to list banned rooms: {e}");
Ok(RoomMessageEventContent::text_plain(format!("Unable to list banned rooms: {e}")))
},
if room_ids.is_empty() {
return Ok(RoomMessageEventContent::text_plain("No rooms are banned."));
}
let mut rooms = room_ids
.iter()
.stream()
.then(|room_id| get_room_info(self.services, room_id))
.collect::<Vec<_>>()
.await;
rooms.sort_by_key(|r| r.1);
rooms.reverse();
let output_plain = format!(
"Rooms Banned ({}):\n```\n{}\n```",
rooms.len(),
rooms
.iter()
.map(|(id, members, name)| if no_details {
format!("{id}")
} else {
format!("{id}\tMembers: {members}\tName: {name}")
})
.collect::<Vec<_>>()
.join("\n")
);
Ok(RoomMessageEventContent::notice_markdown(output_plain))
}

View File

@ -1,7 +1,9 @@
use std::{collections::BTreeMap, fmt::Write as _};
use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room};
use conduit::{error, info, utils, warn, PduBuilder, Result};
use conduit::{error, info, is_equal_to, utils, warn, PduBuilder, Result};
use conduit_api::client::{leave_all_rooms, update_avatar_url, update_displayname};
use futures::StreamExt;
use ruma::{
events::{
room::{
@ -25,16 +27,19 @@ const AUTO_GEN_PASSWORD_LENGTH: usize = 25;
#[admin_command]
pub(super) async fn list_users(&self) -> Result<RoomMessageEventContent> {
match self.services.users.list_local_users() {
Ok(users) => {
let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len());
plain_msg += users.join("\n").as_str();
plain_msg += "\n```";
let users = self
.services
.users
.list_local_users()
.map(ToString::to_string)
.collect::<Vec<_>>()
.await;
Ok(RoomMessageEventContent::notice_markdown(plain_msg))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(e.to_string())),
}
let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len());
plain_msg += users.join("\n").as_str();
plain_msg += "\n```";
Ok(RoomMessageEventContent::notice_markdown(plain_msg))
}
#[admin_command]
@ -42,7 +47,7 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
// Validate user id
let user_id = parse_local_user_id(self.services, &username)?;
if self.services.users.exists(&user_id)? {
if self.services.users.exists(&user_id).await {
return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists")));
}
@ -77,23 +82,25 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
self.services
.users
.set_displayname(&user_id, Some(displayname))
.await?;
.set_displayname(&user_id, Some(displayname));
// Initial account data
self.services.account_data.update(
None,
&user_id,
ruma::events::GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: ruma::push::Ruleset::server_default(&user_id),
},
})
.expect("to json value always works"),
)?;
self.services
.account_data
.update(
None,
&user_id,
ruma::events::GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: ruma::push::Ruleset::server_default(&user_id),
},
})
.expect("to json value always works"),
)
.await?;
if !self.services.globals.config.auto_join_rooms.is_empty() {
for room in &self.services.globals.config.auto_join_rooms {
@ -101,7 +108,8 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
.services
.rooms
.state_cache
.server_in_room(self.services.globals.server_name(), room)?
.server_in_room(self.services.globals.server_name(), room)
.await
{
warn!("Skipping room {room} to automatically join as we have never joined before.");
continue;
@ -135,13 +143,14 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
// if this account creation is from the CLI / --execute, invite the first user
// to admin room
if let Some(admin_room) = self.services.admin.get_admin_room()? {
if let Ok(admin_room) = self.services.admin.get_admin_room().await {
if self
.services
.rooms
.state_cache
.room_joined_count(&admin_room)?
== Some(1)
.room_joined_count(&admin_room)
.await
.is_ok_and(is_equal_to!(1))
{
self.services.admin.make_user_admin(&user_id).await?;
@ -167,7 +176,7 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) ->
));
}
self.services.users.deactivate_account(&user_id)?;
self.services.users.deactivate_account(&user_id).await?;
if !no_leave_rooms {
self.services
@ -175,17 +184,22 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) ->
.send_message(RoomMessageEventContent::text_plain(format!(
"Making {user_id} leave all rooms after deactivation..."
)))
.await;
.await
.ok();
let all_joined_rooms: Vec<OwnedRoomId> = self
.services
.rooms
.state_cache
.rooms_joined(&user_id)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
full_user_deactivate(self.services, &user_id, all_joined_rooms).await?;
full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?;
update_displayname(self.services, &user_id, None, &all_joined_rooms).await?;
update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms).await?;
leave_all_rooms(self.services, &user_id).await;
}
Ok(RoomMessageEventContent::text_plain(format!(
@ -238,15 +252,16 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
let mut admins = Vec::new();
for username in usernames {
match parse_active_local_user_id(self.services, username) {
match parse_active_local_user_id(self.services, username).await {
Ok(user_id) => {
if self.services.users.is_admin(&user_id)? && !force {
if self.services.users.is_admin(&user_id).await && !force {
self.services
.admin
.send_message(RoomMessageEventContent::text_plain(format!(
"{username} is an admin and --force is not set, skipping over"
)))
.await;
.await
.ok();
admins.push(username);
continue;
}
@ -258,7 +273,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
.send_message(RoomMessageEventContent::text_plain(format!(
"{username} is the server service account, skipping over"
)))
.await;
.await
.ok();
continue;
}
@ -270,7 +286,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
.send_message(RoomMessageEventContent::text_plain(format!(
"{username} is not a valid username, skipping over: {e}"
)))
.await;
.await
.ok();
continue;
},
}
@ -279,7 +296,7 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
let mut deactivation_count: usize = 0;
for user_id in user_ids {
match self.services.users.deactivate_account(&user_id) {
match self.services.users.deactivate_account(&user_id).await {
Ok(()) => {
deactivation_count = deactivation_count.saturating_add(1);
if !no_leave_rooms {
@ -289,16 +306,26 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
.rooms
.state_cache
.rooms_joined(&user_id)
.filter_map(Result::ok)
.collect();
full_user_deactivate(self.services, &user_id, all_joined_rooms).await?;
.map(Into::into)
.collect()
.await;
full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?;
update_displayname(self.services, &user_id, None, &all_joined_rooms)
.await
.ok();
update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms)
.await
.ok();
leave_all_rooms(self.services, &user_id).await;
}
},
Err(e) => {
self.services
.admin
.send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}")))
.await;
.await
.ok();
},
}
}
@ -326,9 +353,9 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result<RoomMess
.rooms
.state_cache
.rooms_joined(&user_id)
.filter_map(Result::ok)
.map(|room_id| get_room_info(self.services, &room_id))
.collect();
.then(|room_id| get_room_info(self.services, room_id))
.collect()
.await;
if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("User is not in any rooms."));
@ -404,10 +431,9 @@ pub(super) async fn force_demote(
.services
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")?
.as_ref()
.and_then(|event| serde_json::from_str(event.content.get()).ok()?)
.and_then(|content: RoomPowerLevelsEventContent| content.into());
.room_state_get_content::<RoomPowerLevelsEventContent>(&room_id, &StateEventType::RoomPowerLevels, "")
.await
.ok();
let user_can_demote_self = room_power_levels
.as_ref()
@ -417,9 +443,9 @@ pub(super) async fn force_demote(
.services
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
.as_ref()
.is_some_and(|event| event.sender == user_id);
.room_state_get(&room_id, &StateEventType::RoomCreate, "")
.await
.is_ok_and(|event| event.sender == user_id);
if !user_can_demote_self {
return Ok(RoomMessageEventContent::notice_markdown(
@ -473,15 +499,16 @@ pub(super) async fn make_user_admin(&self, user_id: String) -> Result<RoomMessag
pub(super) async fn put_room_tag(
&self, user_id: String, room_id: Box<RoomId>, tag: String,
) -> Result<RoomMessageEventContent> {
let user_id = parse_active_local_user_id(self.services, &user_id)?;
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
let event = self
.services
.account_data
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?;
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else(
|| TagEvent {
|_| TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),
},
@ -494,12 +521,15 @@ pub(super) async fn put_room_tag(
.tags
.insert(tag.clone().into(), TagInfo::new());
self.services.account_data.update(
Some(&room_id),
&user_id,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)?;
self.services
.account_data
.update(
Some(&room_id),
&user_id,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)
.await?;
Ok(RoomMessageEventContent::text_plain(format!(
"Successfully updated room account data for {user_id} and room {room_id} with tag {tag}"
@ -510,15 +540,16 @@ pub(super) async fn put_room_tag(
pub(super) async fn delete_room_tag(
&self, user_id: String, room_id: Box<RoomId>, tag: String,
) -> Result<RoomMessageEventContent> {
let user_id = parse_active_local_user_id(self.services, &user_id)?;
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
let event = self
.services
.account_data
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?;
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else(
|| TagEvent {
|_| TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),
},
@ -528,12 +559,15 @@ pub(super) async fn delete_room_tag(
tags_event.content.tags.remove(&tag.clone().into());
self.services.account_data.update(
Some(&room_id),
&user_id,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)?;
self.services
.account_data
.update(
Some(&room_id),
&user_id,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)
.await?;
Ok(RoomMessageEventContent::text_plain(format!(
"Successfully updated room account data for {user_id} and room {room_id}, deleting room tag {tag}"
@ -542,15 +576,16 @@ pub(super) async fn delete_room_tag(
#[admin_command]
pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
let user_id = parse_active_local_user_id(self.services, &user_id)?;
let user_id = parse_active_local_user_id(self.services, &user_id).await?;
let event = self
.services
.account_data
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?;
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)
.await;
let tags_event = event.map_or_else(
|| TagEvent {
|_| TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),
},
@ -566,11 +601,12 @@ pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>)
#[admin_command]
pub(super) async fn redact_event(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let Some(event) = self
let Ok(event) = self
.services
.rooms
.timeline
.get_non_outlier_pdu(&event_id)?
.get_non_outlier_pdu(&event_id)
.await
else {
return Ok(RoomMessageEventContent::text_plain("Event does not exist in our database."));
};

View File

@ -8,23 +8,21 @@ pub(crate) fn escape_html(s: &str) -> String {
.replace('>', "&gt;")
}
pub(crate) fn get_room_info(services: &Services, id: &RoomId) -> (OwnedRoomId, u64, String) {
pub(crate) async fn get_room_info(services: &Services, room_id: &RoomId) -> (OwnedRoomId, u64, String) {
(
id.into(),
room_id.into(),
services
.rooms
.state_cache
.room_joined_count(id)
.ok()
.flatten()
.room_joined_count(room_id)
.await
.unwrap_or(0),
services
.rooms
.state_accessor
.get_name(id)
.ok()
.flatten()
.unwrap_or_else(|| id.to_string()),
.get_name(room_id)
.await
.unwrap_or_else(|_| room_id.to_string()),
)
}
@ -46,14 +44,14 @@ pub(crate) fn parse_local_user_id(services: &Services, user_id: &str) -> Result<
}
/// Parses user ID that is an active (not guest or deactivated) local user
pub(crate) fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> {
pub(crate) async fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> {
let user_id = parse_local_user_id(services, user_id)?;
if !services.users.exists(&user_id)? {
if !services.users.exists(&user_id).await {
return Err!("User {user_id:?} does not exist on this server.");
}
if services.users.is_deactivated(&user_id)? {
if services.users.is_deactivated(&user_id).await? {
return Err!("User {user_id:?} is deactivated.");
}

View File

@ -45,7 +45,7 @@ conduit-core.workspace = true
conduit-database.workspace = true
conduit-service.workspace = true
const-str.workspace = true
futures-util.workspace = true
futures.workspace = true
hmac.workspace = true
http.workspace = true
http-body-util.workspace = true

View File

@ -2,7 +2,8 @@ use std::fmt::Write;
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduit::{debug_info, error, info, utils, warn, Error, PduBuilder, Result};
use conduit::{debug_info, error, info, is_equal_to, utils, utils::ReadyExt, warn, Error, PduBuilder, Result};
use futures::{FutureExt, StreamExt};
use register::RegistrationKind;
use ruma::{
api::client::{
@ -55,7 +56,7 @@ pub(crate) async fn get_register_available_route(
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
// Check if username is creative enough
if services.users.exists(&user_id)? {
if services.users.exists(&user_id).await {
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
}
@ -125,7 +126,7 @@ pub(crate) async fn register_route(
// forbid guests from registering if there is not a real admin user yet. give
// generic user error.
if is_guest && services.users.count()? < 2 {
if is_guest && services.users.count().await < 2 {
warn!(
"Guest account attempted to register before a real admin user has been registered, rejecting \
registration. Guest's initial device name: {:?}",
@ -142,7 +143,7 @@ pub(crate) async fn register_route(
.filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id))
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
if services.users.exists(&proposed_user_id)? {
if services.users.exists(&proposed_user_id).await {
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
}
@ -162,7 +163,7 @@ pub(crate) async fn register_route(
services.globals.server_name(),
)
.unwrap();
if !services.users.exists(&proposed_user_id)? {
if !services.users.exists(&proposed_user_id).await {
break proposed_user_id;
}
},
@ -210,12 +211,15 @@ pub(crate) async fn register_route(
if !skip_auth {
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services.uiaa.try_auth(
&UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"),
"".into(),
auth,
&uiaainfo,
)?;
let (worked, uiaainfo) = services
.uiaa
.try_auth(
&UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"),
"".into(),
auth,
&uiaainfo,
)
.await?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
@ -227,7 +231,7 @@ pub(crate) async fn register_route(
"".into(),
&uiaainfo,
&json,
)?;
);
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -255,21 +259,23 @@ pub(crate) async fn register_route(
services
.users
.set_displayname(&user_id, Some(displayname.clone()))
.await?;
.set_displayname(&user_id, Some(displayname.clone()));
// Initial account data
services.account_data.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: push::Ruleset::server_default(&user_id),
},
})
.expect("to json always works"),
)?;
services
.account_data
.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: push::Ruleset::server_default(&user_id),
},
})
.expect("to json always works"),
)
.await?;
// Inhibit login does not work for guests
if !is_guest && body.inhibit_login {
@ -294,13 +300,16 @@ pub(crate) async fn register_route(
let token = utils::random_string(TOKEN_LENGTH);
// Create device for this account
services.users.create_device(
&user_id,
&device_id,
&token,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)?;
services
.users
.create_device(
&user_id,
&device_id,
&token,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
.await?;
debug_info!(%user_id, %device_id, "User account was created");
@ -318,7 +327,8 @@ pub(crate) async fn register_route(
"New user \"{user_id}\" registered on this server from IP {client} and device display name \
\"{device_display_name}\""
)))
.await;
.await
.ok();
}
} else {
info!("New user \"{user_id}\" registered on this server.");
@ -329,7 +339,8 @@ pub(crate) async fn register_route(
.send_message(RoomMessageEventContent::notice_plain(format!(
"New user \"{user_id}\" registered on this server from IP {client}"
)))
.await;
.await
.ok();
}
}
}
@ -346,7 +357,8 @@ pub(crate) async fn register_route(
"Guest user \"{user_id}\" with device display name \"{device_display_name}\" registered on \
this server from IP {client}"
)))
.await;
.await
.ok();
}
} else {
#[allow(clippy::collapsible_else_if)]
@ -357,7 +369,8 @@ pub(crate) async fn register_route(
"Guest user \"{user_id}\" with no device display name registered on this server from IP \
{client}",
)))
.await;
.await
.ok();
}
}
}
@ -365,10 +378,15 @@ pub(crate) async fn register_route(
// If this is the first real user, grant them admin privileges except for guest
// users Note: the server user, @conduit:servername, is generated first
if !is_guest {
if let Some(admin_room) = services.admin.get_admin_room()? {
if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) {
if let Ok(admin_room) = services.admin.get_admin_room().await {
if services
.rooms
.state_cache
.room_joined_count(&admin_room)
.await
.is_ok_and(is_equal_to!(1))
{
services.admin.make_user_admin(&user_id).await?;
warn!("Granting {user_id} admin privileges as the first user");
}
}
@ -382,7 +400,8 @@ pub(crate) async fn register_route(
if !services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), room)?
.server_in_room(services.globals.server_name(), room)
.await
{
warn!("Skipping room {room} to automatically join as we have never joined before.");
continue;
@ -398,6 +417,7 @@ pub(crate) async fn register_route(
None,
&body.appservice_info,
)
.boxed()
.await
{
// don't return this error so we don't fail registrations
@ -461,16 +481,20 @@ pub(crate) async fn change_password_route(
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
.try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
// Success!
// Success!
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
.create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -482,14 +506,12 @@ pub(crate) async fn change_password_route(
if body.logout_devices {
// Logout all devices except the current one
for id in services
services
.users
.all_device_ids(sender_user)
.filter_map(Result::ok)
.filter(|id| id != sender_device)
{
services.users.remove_device(sender_user, &id)?;
}
.ready_filter(|id| id != sender_device)
.for_each(|id| services.users.remove_device(sender_user, id))
.await;
}
info!("User {sender_user} changed their password.");
@ -500,7 +522,8 @@ pub(crate) async fn change_password_route(
.send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} changed their password."
)))
.await;
.await
.ok();
}
Ok(change_password::v3::Response {})
@ -520,7 +543,7 @@ pub(crate) async fn whoami_route(
Ok(whoami::v3::Response {
user_id: sender_user.clone(),
device_id,
is_guest: services.users.is_deactivated(sender_user)? && body.appservice_info.is_none(),
is_guest: services.users.is_deactivated(sender_user).await? && body.appservice_info.is_none(),
})
}
@ -561,7 +584,9 @@ pub(crate) async fn deactivate_route(
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
.try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
@ -570,7 +595,8 @@ pub(crate) async fn deactivate_route(
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
.create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -581,10 +607,14 @@ pub(crate) async fn deactivate_route(
.rooms
.state_cache
.rooms_joined(sender_user)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
full_user_deactivate(&services, sender_user, all_joined_rooms).await?;
super::update_displayname(&services, sender_user, None, &all_joined_rooms).await?;
super::update_avatar_url(&services, sender_user, None, None, &all_joined_rooms).await?;
full_user_deactivate(&services, sender_user, &all_joined_rooms).await?;
info!("User {sender_user} deactivated their account.");
@ -594,7 +624,8 @@ pub(crate) async fn deactivate_route(
.send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} deactivated their account."
)))
.await;
.await
.ok();
}
Ok(deactivate::v3::Response {
@ -674,34 +705,27 @@ pub(crate) async fn check_registration_token_validity(
/// - Removing all profile data
/// - Leaving all rooms (and forgets all of them)
pub async fn full_user_deactivate(
services: &Services, user_id: &UserId, all_joined_rooms: Vec<OwnedRoomId>,
services: &Services, user_id: &UserId, all_joined_rooms: &[OwnedRoomId],
) -> Result<()> {
services.users.deactivate_account(user_id)?;
services.users.deactivate_account(user_id).await?;
super::update_displayname(services, user_id, None, all_joined_rooms).await?;
super::update_avatar_url(services, user_id, None, None, all_joined_rooms).await?;
super::update_displayname(services, user_id, None, all_joined_rooms.clone()).await?;
super::update_avatar_url(services, user_id, None, None, all_joined_rooms.clone()).await?;
let all_profile_keys = services
services
.users
.all_profile_keys(user_id)
.filter_map(Result::ok);
for (profile_key, _profile_value) in all_profile_keys {
if let Err(e) = services.users.set_profile_key(user_id, &profile_key, None) {
warn!("Failed removing {user_id} profile key {profile_key}: {e}");
}
}
.ready_for_each(|(profile_key, _)| services.users.set_profile_key(user_id, &profile_key, None))
.await;
for room_id in all_joined_rooms {
let state_lock = services.rooms.state.mutex.lock(&room_id).await;
let state_lock = services.rooms.state.mutex.lock(room_id).await;
let room_power_levels = services
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")?
.as_ref()
.and_then(|event| serde_json::from_str(event.content.get()).ok()?)
.and_then(|content: RoomPowerLevelsEventContent| content.into());
.room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "")
.await
.ok();
let user_can_demote_self = room_power_levels
.as_ref()
@ -710,9 +734,9 @@ pub async fn full_user_deactivate(
}) || services
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
.as_ref()
.is_some_and(|event| event.sender == user_id);
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.await
.is_ok_and(|event| event.sender == user_id);
if user_can_demote_self {
let mut power_levels_content = room_power_levels.unwrap_or_default();
@ -732,7 +756,7 @@ pub async fn full_user_deactivate(
timestamp: None,
},
user_id,
&room_id,
room_id,
&state_lock,
)
.await

View File

@ -1,11 +1,9 @@
use axum::extract::State;
use conduit::{debug, Error, Result};
use conduit::{debug, Err, Result};
use futures::StreamExt;
use rand::seq::SliceRandom;
use ruma::{
api::client::{
alias::{create_alias, delete_alias, get_alias},
error::ErrorKind,
},
api::client::alias::{create_alias, delete_alias, get_alias},
OwnedServerName, RoomAliasId, RoomId,
};
use service::Services;
@ -33,16 +31,17 @@ pub(crate) async fn create_alias_route(
.forbidden_alias_names()
.is_match(body.room_alias.alias())
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden."));
return Err!(Request(Forbidden("Room alias is forbidden.")));
}
if services
.rooms
.alias
.resolve_local_alias(&body.room_alias)?
.is_some()
.resolve_local_alias(&body.room_alias)
.await
.is_ok()
{
return Err(Error::Conflict("Alias already exists."));
return Err!(Conflict("Alias already exists."));
}
services
@ -95,16 +94,16 @@ pub(crate) async fn get_alias_route(
.resolve_alias(&room_alias, servers.as_ref())
.await
else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found."));
return Err!(Request(NotFound("Room with alias not found.")));
};
let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers);
let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers).await;
debug!(?room_alias, ?room_id, "available servers: {servers:?}");
Ok(get_alias::v3::Response::new(room_id, servers))
}
fn room_available_servers(
async fn room_available_servers(
services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>,
) -> Vec<OwnedServerName> {
// find active servers in room state cache to suggest
@ -112,8 +111,9 @@ fn room_available_servers(
.rooms
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
// push any servers we want in the list already (e.g. responded remote alias
// servers, room alias server itself)

View File

@ -1,18 +1,16 @@
use axum::extract::State;
use conduit::{err, Err};
use ruma::{
api::client::{
backup::{
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version,
delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version,
get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
get_latest_backup_info, update_backup_version,
},
error::ErrorKind,
api::client::backup::{
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version,
delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version,
get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
get_latest_backup_info, update_backup_version,
},
UInt,
};
use crate::{Error, Result, Ruma};
use crate::{Result, Ruma};
/// # `POST /_matrix/client/r0/room_keys/version`
///
@ -40,7 +38,8 @@ pub(crate) async fn update_backup_version_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services
.key_backups
.update_backup(sender_user, &body.version, &body.algorithm)?;
.update_backup(sender_user, &body.version, &body.algorithm)
.await?;
Ok(update_backup_version::v3::Response {})
}
@ -55,14 +54,15 @@ pub(crate) async fn get_latest_backup_info_route(
let (version, algorithm) = services
.key_backups
.get_latest_backup(sender_user)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
.get_latest_backup(sender_user)
.await
.map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?;
Ok(get_latest_backup_info::v3::Response {
algorithm,
count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?)
count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version).await)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &version)?,
etag: services.key_backups.get_etag(sender_user, &version).await,
version,
})
}
@ -76,18 +76,21 @@ pub(crate) async fn get_backup_info_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let algorithm = services
.key_backups
.get_backup(sender_user, &body.version)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
.get_backup(sender_user, &body.version)
.await
.map_err(|_| err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))))?;
Ok(get_backup_info::v3::Response {
algorithm,
count: (UInt::try_from(
services
.key_backups
.count_keys(sender_user, &body.version)?,
)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &body.version)?,
count: services
.key_backups
.count_keys(sender_user, &body.version)
.await
.try_into()?,
etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
version: body.version.clone(),
})
}
@ -105,7 +108,8 @@ pub(crate) async fn delete_backup_version_route(
services
.key_backups
.delete_backup(sender_user, &body.version)?;
.delete_backup(sender_user, &body.version)
.await;
Ok(delete_backup_version::v3::Response {})
}
@ -123,34 +127,36 @@ pub(crate) async fn add_backup_keys_route(
) -> Result<add_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version)
!= services
.key_backups
.get_latest_backup_version(sender_user)?
.as_ref()
if services
.key_backups
.get_latest_backup_version(sender_user)
.await
.is_ok_and(|version| version != body.version)
{
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.",
));
return Err!(Request(InvalidParam(
"You may only manipulate the most recently created version of the backup."
)));
}
for (room_id, room) in &body.rooms {
for (session_id, key_data) in &room.sessions {
services
.key_backups
.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
.add_key(sender_user, &body.version, room_id, session_id, key_data)
.await?;
}
}
Ok(add_backup_keys::v3::Response {
count: (UInt::try_from(
services
.key_backups
.count_keys(sender_user, &body.version)?,
)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &body.version)?,
count: services
.key_backups
.count_keys(sender_user, &body.version)
.await
.try_into()?,
etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
})
}
@ -167,32 +173,34 @@ pub(crate) async fn add_backup_keys_for_room_route(
) -> Result<add_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version)
!= services
.key_backups
.get_latest_backup_version(sender_user)?
.as_ref()
if services
.key_backups
.get_latest_backup_version(sender_user)
.await
.is_ok_and(|version| version != body.version)
{
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.",
));
return Err!(Request(InvalidParam(
"You may only manipulate the most recently created version of the backup."
)));
}
for (session_id, key_data) in &body.sessions {
services
.key_backups
.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)
.await?;
}
Ok(add_backup_keys_for_room::v3::Response {
count: (UInt::try_from(
services
.key_backups
.count_keys(sender_user, &body.version)?,
)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &body.version)?,
count: services
.key_backups
.count_keys(sender_user, &body.version)
.await
.try_into()?,
etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
})
}
@ -209,30 +217,32 @@ pub(crate) async fn add_backup_keys_for_session_route(
) -> Result<add_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version)
!= services
.key_backups
.get_latest_backup_version(sender_user)?
.as_ref()
if services
.key_backups
.get_latest_backup_version(sender_user)
.await
.is_ok_and(|version| version != body.version)
{
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.",
));
return Err!(Request(InvalidParam(
"You may only manipulate the most recently created version of the backup."
)));
}
services
.key_backups
.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;
.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)
.await?;
Ok(add_backup_keys_for_session::v3::Response {
count: (UInt::try_from(
services
.key_backups
.count_keys(sender_user, &body.version)?,
)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &body.version)?,
count: services
.key_backups
.count_keys(sender_user, &body.version)
.await
.try_into()?,
etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
})
}
@ -244,7 +254,10 @@ pub(crate) async fn get_backup_keys_route(
) -> Result<get_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let rooms = services.key_backups.get_all(sender_user, &body.version)?;
let rooms = services
.key_backups
.get_all(sender_user, &body.version)
.await;
Ok(get_backup_keys::v3::Response {
rooms,
@ -261,7 +274,8 @@ pub(crate) async fn get_backup_keys_for_room_route(
let sessions = services
.key_backups
.get_room(sender_user, &body.version, &body.room_id)?;
.get_room(sender_user, &body.version, &body.room_id)
.await;
Ok(get_backup_keys_for_room::v3::Response {
sessions,
@ -278,8 +292,9 @@ pub(crate) async fn get_backup_keys_for_session_route(
let key_data = services
.key_backups
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?;
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)
.await
.map_err(|_| err!(Request(NotFound(debug_error!("Backup key not found for this user's session.")))))?;
Ok(get_backup_keys_for_session::v3::Response {
key_data,
@ -296,16 +311,19 @@ pub(crate) async fn delete_backup_keys_route(
services
.key_backups
.delete_all_keys(sender_user, &body.version)?;
.delete_all_keys(sender_user, &body.version)
.await;
Ok(delete_backup_keys::v3::Response {
count: (UInt::try_from(
services
.key_backups
.count_keys(sender_user, &body.version)?,
)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &body.version)?,
count: services
.key_backups
.count_keys(sender_user, &body.version)
.await
.try_into()?,
etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
})
}
@ -319,16 +337,19 @@ pub(crate) async fn delete_backup_keys_for_room_route(
services
.key_backups
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
.delete_room_keys(sender_user, &body.version, &body.room_id)
.await;
Ok(delete_backup_keys_for_room::v3::Response {
count: (UInt::try_from(
services
.key_backups
.count_keys(sender_user, &body.version)?,
)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &body.version)?,
count: services
.key_backups
.count_keys(sender_user, &body.version)
.await
.try_into()?,
etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
})
}
@ -342,15 +363,18 @@ pub(crate) async fn delete_backup_keys_for_session_route(
services
.key_backups
.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)
.await;
Ok(delete_backup_keys_for_session::v3::Response {
count: (UInt::try_from(
services
.key_backups
.count_keys(sender_user, &body.version)?,
)
.expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &body.version)?,
count: services
.key_backups
.count_keys(sender_user, &body.version)
.await
.try_into()?,
etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
})
}

View File

@ -1,4 +1,5 @@
use axum::extract::State;
use conduit::err;
use ruma::{
api::client::{
config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
@ -25,7 +26,8 @@ pub(crate) async fn set_global_account_data_route(
&body.sender_user,
&body.event_type.to_string(),
body.data.json(),
)?;
)
.await?;
Ok(set_global_account_data::v3::Response {})
}
@ -42,7 +44,8 @@ pub(crate) async fn set_room_account_data_route(
&body.sender_user,
&body.event_type.to_string(),
body.data.json(),
)?;
)
.await?;
Ok(set_room_account_data::v3::Response {})
}
@ -57,8 +60,9 @@ pub(crate) async fn get_global_account_data_route(
let event: Box<RawJsonValue> = services
.account_data
.get(None, sender_user, body.event_type.to_string().into())?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
.get(None, sender_user, body.event_type.to_string().into())
.await
.map_err(|_| err!(Request(NotFound("Data not found."))))?;
let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -79,8 +83,9 @@ pub(crate) async fn get_room_account_data_route(
let event: Box<RawJsonValue> = services
.account_data
.get(Some(&body.room_id), sender_user, body.event_type.clone())?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
.get(Some(&body.room_id), sender_user, body.event_type.clone())
.await
.map_err(|_| err!(Request(NotFound("Data not found."))))?;
let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -91,7 +96,7 @@ pub(crate) async fn get_room_account_data_route(
})
}
fn set_account_data(
async fn set_account_data(
services: &Services, room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str,
data: &RawJsonValue,
) -> Result<()> {
@ -100,15 +105,18 @@ fn set_account_data(
let data: serde_json::Value =
serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
services.account_data.update(
room_id,
sender_user,
event_type.into(),
&json!({
"type": event_type,
"content": data,
}),
)?;
services
.account_data
.update(
room_id,
sender_user,
event_type.into(),
&json!({
"type": event_type,
"content": data,
}),
)
.await?;
Ok(())
}

View File

@ -1,13 +1,14 @@
use std::collections::HashSet;
use axum::extract::State;
use conduit::{err, error, Err};
use futures::StreamExt;
use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
api::client::{context::get_context, filter::LazyLoadOptions},
events::StateEventType,
};
use tracing::error;
use crate::{Error, Result, Ruma};
use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/context`
///
@ -35,34 +36,33 @@ pub(crate) async fn get_context_route(
let base_token = services
.rooms
.timeline
.get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?;
.get_pdu_count(&body.event_id)
.await
.map_err(|_| err!(Request(NotFound("Base event id not found."))))?;
let base_event = services
.rooms
.timeline
.get_pdu(&body.event_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?;
.get_pdu(&body.event_id)
.await
.map_err(|_| err!(Request(NotFound("Base event not found."))))?;
let room_id = base_event.room_id.clone();
let room_id = &base_event.room_id;
if !services
.rooms
.state_accessor
.user_can_see_event(sender_user, &room_id, &body.event_id)?
.user_can_see_event(sender_user, room_id, &body.event_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this event.",
));
return Err!(Request(Forbidden("You don't have permission to view this event.")));
}
if !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user,
sender_device,
&room_id,
&base_event.sender,
)? || lazy_load_send_redundant
if !services
.rooms
.lazy_loading
.lazy_load_was_sent_before(sender_user, sender_device, room_id, &base_event.sender)
.await || lazy_load_send_redundant
{
lazy_loaded.insert(base_event.sender.as_str().to_owned());
}
@ -75,25 +75,26 @@ pub(crate) async fn get_context_route(
let events_before: Vec<_> = services
.rooms
.timeline
.pdus_until(sender_user, &room_id, base_token)?
.pdus_until(sender_user, room_id, base_token)
.await?
.take(limit / 2)
.filter_map(Result::ok) // Remove buggy events
.filter(|(_, pdu)| {
.filter_map(|(count, pdu)| async move {
services
.rooms
.state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id)
.unwrap_or(false)
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.await
.then_some((count, pdu))
})
.collect();
.collect()
.await;
for (_, event) in &events_before {
if !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user,
sender_device,
&room_id,
&event.sender,
)? || lazy_load_send_redundant
if !services
.rooms
.lazy_loading
.lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
.await || lazy_load_send_redundant
{
lazy_loaded.insert(event.sender.as_str().to_owned());
}
@ -111,25 +112,26 @@ pub(crate) async fn get_context_route(
let events_after: Vec<_> = services
.rooms
.timeline
.pdus_after(sender_user, &room_id, base_token)?
.pdus_after(sender_user, room_id, base_token)
.await?
.take(limit / 2)
.filter_map(Result::ok) // Remove buggy events
.filter(|(_, pdu)| {
.filter_map(|(count, pdu)| async move {
services
.rooms
.state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id)
.unwrap_or(false)
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.await
.then_some((count, pdu))
})
.collect();
.collect()
.await;
for (_, event) in &events_after {
if !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user,
sender_device,
&room_id,
&event.sender,
)? || lazy_load_send_redundant
if !services
.rooms
.lazy_loading
.lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
.await || lazy_load_send_redundant
{
lazy_loaded.insert(event.sender.as_str().to_owned());
}
@ -142,12 +144,14 @@ pub(crate) async fn get_context_route(
events_after
.last()
.map_or(&*body.event_id, |(_, e)| &*e.event_id),
)?
)
.await
.map_or(
services
.rooms
.state
.get_room_shortstatehash(&room_id)?
.get_room_shortstatehash(room_id)
.await
.expect("All rooms have state"),
|hash| hash,
);
@ -156,7 +160,8 @@ pub(crate) async fn get_context_route(
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await?;
.await
.map_err(|e| err!(Database("State not found: {e}")))?;
let end_token = events_after
.last()
@ -173,18 +178,19 @@ pub(crate) async fn get_context_route(
let (event_type, state_key) = services
.rooms
.short
.get_statekey_from_short(shortstatekey)?;
.get_statekey_from_short(shortstatekey)
.await?;
if event_type != StateEventType::RoomMember {
let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id);
let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else {
error!("Pdu in state not found: {id}");
continue;
};
state.push(pdu.to_state_event());
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id);
let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else {
error!("Pdu in state not found: {id}");
continue;
};

View File

@ -1,4 +1,6 @@
use axum::extract::State;
use conduit::{err, Err};
use futures::StreamExt;
use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
error::ErrorKind,
@ -19,8 +21,8 @@ pub(crate) async fn get_devices_route(
let devices: Vec<device::Device> = services
.users
.all_devices_metadata(sender_user)
.filter_map(Result::ok) // Filter out buggy devices
.collect();
.collect()
.await;
Ok(get_devices::v3::Response {
devices,
@ -37,8 +39,9 @@ pub(crate) async fn get_device_route(
let device = services
.users
.get_device_metadata(sender_user, &body.body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
.get_device_metadata(sender_user, &body.body.device_id)
.await
.map_err(|_| err!(Request(NotFound("Device not found."))))?;
Ok(get_device::v3::Response {
device,
@ -55,14 +58,16 @@ pub(crate) async fn update_device_route(
let mut device = services
.users
.get_device_metadata(sender_user, &body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
.get_device_metadata(sender_user, &body.device_id)
.await
.map_err(|_| err!(Request(NotFound("Device not found."))))?;
device.display_name.clone_from(&body.display_name);
services
.users
.update_device_metadata(sender_user, &body.device_id, &device)?;
.update_device_metadata(sender_user, &body.device_id, &device)
.await?;
Ok(update_device::v3::Response {})
}
@ -97,22 +102,28 @@ pub(crate) async fn delete_device_route(
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
.try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
return Err!(Uiaa(uiaainfo));
}
// Success!
} else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo));
.create(sender_user, sender_device, &uiaainfo, &json);
return Err!(Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
return Err!(Request(NotJson("Not json.")));
}
services.users.remove_device(sender_user, &body.device_id)?;
services
.users
.remove_device(sender_user, &body.device_id)
.await;
Ok(delete_device::v3::Response {})
}
@ -149,7 +160,9 @@ pub(crate) async fn delete_devices_route(
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
.try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
@ -158,14 +171,15 @@ pub(crate) async fn delete_devices_route(
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
.create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
for device_id in &body.devices {
services.users.remove_device(sender_user, device_id)?;
services.users.remove_device(sender_user, device_id).await;
}
Ok(delete_devices::v3::Response {})

View File

@ -1,6 +1,7 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduit::{err, info, warn, Err, Error, Result};
use conduit::{info, warn, Err, Error, Result};
use futures::{StreamExt, TryFutureExt};
use ruma::{
api::{
client::{
@ -18,7 +19,7 @@ use ruma::{
},
StateEventType,
},
uint, RoomId, ServerName, UInt, UserId,
uint, OwnedRoomId, RoomId, ServerName, UInt, UserId,
};
use service::Services;
@ -119,16 +120,22 @@ pub(crate) async fn set_room_visibility_route(
) -> Result<set_room_visibility::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services.rooms.metadata.exists(&body.room_id)? {
if !services.rooms.metadata.exists(&body.room_id).await {
// Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
}
if services.users.is_deactivated(sender_user).unwrap_or(false) && body.appservice_info.is_none() {
if services
.users
.is_deactivated(sender_user)
.await
.unwrap_or(false)
&& body.appservice_info.is_none()
{
return Err!(Request(Forbidden("Guests cannot publish to room directories")));
}
if !user_can_publish_room(&services, sender_user, &body.room_id)? {
if !user_can_publish_room(&services, sender_user, &body.room_id).await? {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"User is not allowed to publish this room",
@ -138,7 +145,7 @@ pub(crate) async fn set_room_visibility_route(
match &body.visibility {
room::Visibility::Public => {
if services.globals.config.lockdown_public_room_directory
&& !services.users.is_admin(sender_user)?
&& !services.users.is_admin(sender_user).await
&& body.appservice_info.is_none()
{
info!(
@ -164,7 +171,7 @@ pub(crate) async fn set_room_visibility_route(
));
}
services.rooms.directory.set_public(&body.room_id)?;
services.rooms.directory.set_public(&body.room_id);
if services.globals.config.admin_room_notices {
services
@ -174,7 +181,7 @@ pub(crate) async fn set_room_visibility_route(
}
info!("{sender_user} made {0} public to the room directory", body.room_id);
},
room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?,
room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id),
_ => {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
@ -192,13 +199,13 @@ pub(crate) async fn set_room_visibility_route(
pub(crate) async fn get_room_visibility_route(
State(services): State<crate::State>, body: Ruma<get_room_visibility::v3::Request>,
) -> Result<get_room_visibility::v3::Response> {
if !services.rooms.metadata.exists(&body.room_id)? {
if !services.rooms.metadata.exists(&body.room_id).await {
// Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
}
Ok(get_room_visibility::v3::Response {
visibility: if services.rooms.directory.is_public_room(&body.room_id)? {
visibility: if services.rooms.directory.is_public_room(&body.room_id).await {
room::Visibility::Public
} else {
room::Visibility::Private
@ -257,101 +264,41 @@ pub(crate) async fn get_public_rooms_filtered_helper(
}
}
let mut all_rooms: Vec<_> = services
let mut all_rooms: Vec<PublicRoomsChunk> = services
.rooms
.directory
.public_rooms()
.map(|room_id| {
let room_id = room_id?;
let chunk = PublicRoomsChunk {
canonical_alias: services
.rooms
.state_accessor
.get_canonical_alias(&room_id)?,
name: services.rooms.state_accessor.get_name(&room_id)?,
num_joined_members: services
.rooms
.state_cache
.room_joined_count(&room_id)?
.unwrap_or_else(|| {
warn!("Room {} has no member count", room_id);
0
})
.try_into()
.expect("user count should not be that big"),
topic: services
.rooms
.state_accessor
.get_room_topic(&room_id)
.unwrap_or(None),
world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?,
guest_can_join: services
.rooms
.state_accessor
.guest_can_join(&room_id)?,
avatar_url: services
.rooms
.state_accessor
.get_avatar(&room_id)?
.into_option()
.unwrap_or_default()
.url,
join_rule: services
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| match c.join_rule {
JoinRule::Public => Some(PublicRoomJoinRule::Public),
JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
_ => None,
})
.map_err(|e| {
err!(Database(error!("Invalid room join rule event in database: {e}")))
})
})
.transpose()?
.flatten()
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
room_type: services
.rooms
.state_accessor
.get_room_type(&room_id)?,
room_id,
};
Ok(chunk)
})
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
.filter(|chunk| {
.map(ToOwned::to_owned)
.then(|room_id| public_rooms_chunk(services, room_id))
.filter_map(|chunk| async move {
if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) {
if let Some(name) = &chunk.name {
if name.as_str().to_lowercase().contains(&query) {
return true;
return Some(chunk);
}
}
if let Some(topic) = &chunk.topic {
if topic.to_lowercase().contains(&query) {
return true;
return Some(chunk);
}
}
if let Some(canonical_alias) = &chunk.canonical_alias {
if canonical_alias.as_str().to_lowercase().contains(&query) {
return true;
return Some(chunk);
}
}
false
} else {
// No search term
true
return None;
}
// No search term
Some(chunk)
})
// We need to collect all, so we can sort by member count
.collect();
.collect()
.await;
all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
@ -394,22 +341,23 @@ pub(crate) async fn get_public_rooms_filtered_helper(
/// Check whether the user can publish to the room directory via power levels of
/// room history visibility event or room creator
fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
if let Some(event) = services
async fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
if let Ok(event) = services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")
.await
{
serde_json::from_str(event.content.get())
.map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels"))
.map(|content: RoomPowerLevelsEventContent| {
RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility)
})
} else if let Some(event) =
services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
} else if let Ok(event) = services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.await
{
Ok(event.sender == user_id)
} else {
@ -419,3 +367,61 @@ fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId
));
}
}
async fn public_rooms_chunk(services: &Services, room_id: OwnedRoomId) -> PublicRoomsChunk {
PublicRoomsChunk {
canonical_alias: services
.rooms
.state_accessor
.get_canonical_alias(&room_id)
.await
.ok(),
name: services.rooms.state_accessor.get_name(&room_id).await.ok(),
num_joined_members: services
.rooms
.state_cache
.room_joined_count(&room_id)
.await
.unwrap_or(0)
.try_into()
.expect("joined count overflows ruma UInt"),
topic: services
.rooms
.state_accessor
.get_room_topic(&room_id)
.await
.ok(),
world_readable: services
.rooms
.state_accessor
.is_world_readable(&room_id)
.await,
guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await,
avatar_url: services
.rooms
.state_accessor
.get_avatar(&room_id)
.await
.into_option()
.unwrap_or_default()
.url,
join_rule: services
.rooms
.state_accessor
.room_state_get_content(&room_id, &StateEventType::RoomJoinRules, "")
.map_ok(|c: RoomJoinRulesEventContent| match c.join_rule {
JoinRule::Public => PublicRoomJoinRule::Public,
JoinRule::Knock => PublicRoomJoinRule::Knock,
_ => "invite".into(),
})
.await
.unwrap_or_default(),
room_type: services
.rooms
.state_accessor
.get_room_type(&room_id)
.await
.ok(),
room_id,
}
}

View File

@ -1,10 +1,8 @@
use axum::extract::State;
use ruma::api::client::{
error::ErrorKind,
filter::{create_filter, get_filter},
};
use conduit::err;
use ruma::api::client::filter::{create_filter, get_filter};
use crate::{Error, Result, Ruma};
use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
///
@ -15,11 +13,13 @@ pub(crate) async fn get_filter_route(
State(services): State<crate::State>, body: Ruma<get_filter::v3::Request>,
) -> Result<get_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found."));
};
Ok(get_filter::v3::Response::new(filter))
services
.users
.get_filter(sender_user, &body.filter_id)
.await
.map(get_filter::v3::Response::new)
.map_err(|_| err!(Request(NotFound("Filter not found."))))
}
/// # `PUT /_matrix/client/r0/user/{userId}/filter`
@ -29,7 +29,8 @@ pub(crate) async fn create_filter_route(
State(services): State<crate::State>, body: Ruma<create_filter::v3::Request>,
) -> Result<create_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(create_filter::v3::Response::new(
services.users.create_filter(sender_user, &body.filter)?,
))
let filter_id = services.users.create_filter(sender_user, &body.filter);
Ok(create_filter::v3::Response::new(filter_id))
}

View File

@ -4,8 +4,8 @@ use std::{
};
use axum::extract::State;
use conduit::{utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result};
use futures_util::{stream::FuturesUnordered, StreamExt};
use conduit::{err, utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result};
use futures::{stream::FuturesUnordered, StreamExt};
use ruma::{
api::{
client::{
@ -21,7 +21,10 @@ use ruma::{
use serde_json::json;
use super::SESSION_ID_LENGTH;
use crate::{service::Services, Ruma};
use crate::{
service::{users::parse_master_key, Services},
Ruma,
};
/// # `POST /_matrix/client/r0/keys/upload`
///
@ -39,7 +42,8 @@ pub(crate) async fn upload_keys_route(
for (key_key, key_value) in &body.one_time_keys {
services
.users
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
.add_one_time_key(sender_user, sender_device, key_key, key_value)
.await?;
}
if let Some(device_keys) = &body.device_keys {
@ -47,19 +51,22 @@ pub(crate) async fn upload_keys_route(
// This check is needed to assure that signatures are kept
if services
.users
.get_device_keys(sender_user, sender_device)?
.is_none()
.get_device_keys(sender_user, sender_device)
.await
.is_err()
{
services
.users
.add_device_keys(sender_user, sender_device, device_keys)?;
.add_device_keys(sender_user, sender_device, device_keys)
.await;
}
}
Ok(upload_keys::v3::Response {
one_time_key_counts: services
.users
.count_one_time_keys(sender_user, sender_device)?,
.count_one_time_keys(sender_user, sender_device)
.await,
})
}
@ -120,7 +127,9 @@ pub(crate) async fn upload_signing_keys_route(
if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
.try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked {
return Err(Error::Uiaa(uiaainfo));
}
@ -129,20 +138,24 @@ pub(crate) async fn upload_signing_keys_route(
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
.create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo));
} else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
if let Some(master_key) = &body.master_key {
services.users.add_cross_signing_keys(
sender_user,
master_key,
&body.self_signing_key,
&body.user_signing_key,
true, // notify so that other users see the new keys
)?;
services
.users
.add_cross_signing_keys(
sender_user,
master_key,
&body.self_signing_key,
&body.user_signing_key,
true, // notify so that other users see the new keys
)
.await?;
}
Ok(upload_signing_keys::v3::Response {})
@ -179,9 +192,11 @@ pub(crate) async fn upload_signatures_route(
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
.to_owned(),
);
services
.users
.sign_key(user_id, key_id, signature, sender_user)?;
.sign_key(user_id, key_id, signature, sender_user)
.await?;
}
}
}
@ -204,56 +219,51 @@ pub(crate) async fn get_key_changes_route(
let mut device_list_updates = HashSet::new();
let from = body
.from
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?;
let to = body
.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?;
device_list_updates.extend(
services
.users
.keys_changed(
sender_user.as_str(),
body.from
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
Some(
body.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
)
.filter_map(Result::ok),
.keys_changed(sender_user.as_str(), from, Some(to))
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
for room_id in services
.rooms
.state_cache
.rooms_joined(sender_user)
.filter_map(Result::ok)
{
let mut rooms_joined = services.rooms.state_cache.rooms_joined(sender_user).boxed();
while let Some(room_id) = rooms_joined.next().await {
device_list_updates.extend(
services
.users
.keys_changed(
room_id.as_ref(),
body.from
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
Some(
body.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
)
.filter_map(Result::ok),
.keys_changed(room_id.as_ref(), from, Some(to))
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
}
Ok(get_key_changes::v3::Response {
changed: device_list_updates.into_iter().collect(),
left: Vec::new(), // TODO
})
}
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
pub(crate) async fn get_keys_helper<F>(
services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
allowed_signatures: F, include_display_names: bool,
) -> Result<get_keys::v3::Response> {
) -> Result<get_keys::v3::Response>
where
F: Fn(&UserId) -> bool + Send + Sync,
{
let mut master_keys = BTreeMap::new();
let mut self_signing_keys = BTreeMap::new();
let mut user_signing_keys = BTreeMap::new();
@ -274,56 +284,60 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
if device_ids.is_empty() {
let mut container = BTreeMap::new();
for device_id in services.users.all_device_ids(user_id) {
let device_id = device_id?;
if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? {
let mut devices = services.users.all_device_ids(user_id).boxed();
while let Some(device_id) = devices.next().await {
if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await {
let metadata = services
.users
.get_device_metadata(user_id, &device_id)?
.ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?;
.get_device_metadata(user_id, device_id)
.await
.map_err(|_| err!(Database("all_device_keys contained nonexistent device.")))?;
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
.map_err(|_| err!(Database("invalid device keys in database")))?;
container.insert(device_id, keys);
container.insert(device_id.to_owned(), keys);
}
}
device_keys.insert(user_id.to_owned(), container);
} else {
for device_id in device_ids {
let mut container = BTreeMap::new();
if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? {
if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await {
let metadata = services
.users
.get_device_metadata(user_id, device_id)?
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Tried to get keys for nonexistent device.",
))?;
.get_device_metadata(user_id, device_id)
.await
.map_err(|_| err!(Request(InvalidParam("Tried to get keys for nonexistent device."))))?;
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
.map_err(|_| err!(Database("invalid device keys in database")))?;
container.insert(device_id.to_owned(), keys);
}
device_keys.insert(user_id.to_owned(), container);
}
}
if let Some(master_key) = services
if let Ok(master_key) = services
.users
.get_master_key(sender_user, user_id, &allowed_signatures)?
.get_master_key(sender_user, user_id, &allowed_signatures)
.await
{
master_keys.insert(user_id.to_owned(), master_key);
}
if let Some(self_signing_key) =
services
.users
.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
if let Ok(self_signing_key) = services
.users
.get_self_signing_key(sender_user, user_id, &allowed_signatures)
.await
{
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
}
if Some(user_id) == sender_user {
if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? {
if let Ok(user_signing_key) = services.users.get_user_signing_key(user_id).await {
user_signing_keys.insert(user_id.to_owned(), user_signing_key);
}
}
@ -386,23 +400,26 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
while let Some((server, response)) = futures.next().await {
if let Ok(Ok(response)) = response {
for (user, masterkey) in response.master_keys {
let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?;
let (master_key_id, mut master_key) = parse_master_key(&user, &masterkey)?;
if let Some(our_master_key) =
services
.users
.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
if let Ok(our_master_key) = services
.users
.get_key(&master_key_id, sender_user, &user, &allowed_signatures)
.await
{
let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?;
let (_, our_master_key) = parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures);
}
let json = serde_json::to_value(master_key).expect("to_value always works");
let raw = serde_json::from_value(json).expect("Raw::from_value always works");
services.users.add_cross_signing_keys(
&user, &raw, &None, &None,
false, /* Dont notify. A notification would trigger another key request resulting in an
* endless loop */
)?;
services
.users
.add_cross_signing_keys(
&user, &raw, &None, &None,
false, /* Dont notify. A notification would trigger another key request resulting in an
* endless loop */
)
.await?;
master_keys.insert(user.clone(), raw);
}
@ -465,9 +482,10 @@ pub(crate) async fn claim_keys_helper(
let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = services
if let Ok(one_time_keys) = services
.users
.take_one_time_key(user_id, device_id, key_algorithm)?
.take_one_time_key(user_id, device_id, key_algorithm)
.await
{
let mut c = BTreeMap::new();
c.insert(one_time_keys.0, one_time_keys.1);

View File

@ -11,9 +11,10 @@ use conduit::{
debug, debug_error, debug_warn, err, error, info,
pdu::{gen_event_id_canonical_json, PduBuilder},
trace, utils,
utils::math::continue_exponential_backoff_secs,
utils::{math::continue_exponential_backoff_secs, IterStream, ReadyExt},
warn, Err, Error, PduEvent, Result,
};
use futures::{FutureExt, StreamExt};
use ruma::{
api::{
client::{
@ -55,9 +56,9 @@ async fn banned_room_check(
services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>,
client_ip: IpAddr,
) -> Result<()> {
if !services.users.is_admin(user_id)? {
if !services.users.is_admin(user_id).await {
if let Some(room_id) = room_id {
if services.rooms.metadata.is_banned(room_id)?
if services.rooms.metadata.is_banned(room_id).await
|| services
.globals
.config
@ -79,23 +80,22 @@ async fn banned_room_check(
"Automatically deactivating user {user_id} due to attempted banned room join from IP \
{client_ip}"
)))
.await;
.await
.ok();
}
let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
full_user_deactivate(services, user_id, all_joined_rooms).await?;
full_user_deactivate(services, user_id, &all_joined_rooms).await?;
}
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"This room is banned on this homeserver.",
));
return Err!(Request(Forbidden("This room is banned on this homeserver.")));
}
} else if let Some(server_name) = server_name {
if services
@ -119,23 +119,22 @@ async fn banned_room_check(
"Automatically deactivating user {user_id} due to attempted banned room join from IP \
{client_ip}"
)))
.await;
.await
.ok();
}
let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
full_user_deactivate(services, user_id, all_joined_rooms).await?;
full_user_deactivate(services, user_id, &all_joined_rooms).await?;
}
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"This remote server is banned on this homeserver.",
));
return Err!(Request(Forbidden("This remote server is banned on this homeserver.")));
}
}
}
@ -172,14 +171,16 @@ pub(crate) async fn join_room_by_id_route(
.rooms
.state_cache
.servers_invite_via(&body.room_id)
.filter_map(Result::ok)
.collect::<Vec<_>>();
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
servers.extend(
services
.rooms
.state_cache
.invite_state(sender_user, &body.room_id)?
.invite_state(sender_user, &body.room_id)
.await
.unwrap_or_default()
.iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -202,6 +203,7 @@ pub(crate) async fn join_room_by_id_route(
body.third_party_signed.as_ref(),
&body.appservice_info,
)
.boxed()
.await
}
@ -233,14 +235,17 @@ pub(crate) async fn join_room_by_id_or_alias_route(
.rooms
.state_cache
.servers_invite_via(&room_id)
.filter_map(Result::ok),
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
servers.extend(
services
.rooms
.state_cache
.invite_state(sender_user, &room_id)?
.invite_state(sender_user, &room_id)
.await
.unwrap_or_default()
.iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -270,19 +275,23 @@ pub(crate) async fn join_room_by_id_or_alias_route(
if let Some(pre_servers) = &mut pre_servers {
servers.append(pre_servers);
}
servers.extend(
services
.rooms
.state_cache
.servers_invite_via(&room_id)
.filter_map(Result::ok),
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
);
servers.extend(
services
.rooms
.state_cache
.invite_state(sender_user, &room_id)?
.invite_state(sender_user, &room_id)
.await
.unwrap_or_default()
.iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -305,6 +314,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
body.third_party_signed.as_ref(),
appservice_info,
)
.boxed()
.await?;
Ok(join_room_by_id_or_alias::v3::Response {
@ -337,7 +347,7 @@ pub(crate) async fn invite_user_route(
) -> Result<invite_user::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services.users.is_admin(sender_user)? && services.globals.block_non_admin_invites() {
if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() {
info!(
"User {sender_user} is not an admin and attempted to send an invite to room {}",
&body.room_id
@ -375,15 +385,13 @@ pub(crate) async fn kick_user_route(
services
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())?
.ok_or(Error::BadRequest(
ErrorKind::BadState,
"Cannot kick member that's not in the room.",
))?
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())
.await
.map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))?
.content
.get(),
)
.map_err(|_| Error::bad_database("Invalid member event in database."))?;
.map_err(|_| err!(Database("Invalid member event in database.")))?;
event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason);
@ -421,10 +429,13 @@ pub(crate) async fn ban_user_route(
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let blurhash = services.users.blurhash(&body.user_id).await.ok();
let event = services
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())?
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())
.await
.map_or(
Ok(RoomMemberEventContent {
membership: MembershipState::Ban,
@ -432,7 +443,7 @@ pub(crate) async fn ban_user_route(
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(),
blurhash: blurhash.clone(),
reason: body.reason.clone(),
join_authorized_via_users_server: None,
}),
@ -442,12 +453,12 @@ pub(crate) async fn ban_user_route(
membership: MembershipState::Ban,
displayname: None,
avatar_url: None,
blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(),
blurhash: blurhash.clone(),
reason: body.reason.clone(),
join_authorized_via_users_server: None,
..event
})
.map_err(|_| Error::bad_database("Invalid member event in database."))
.map_err(|e| err!(Database("Invalid member event in database: {e:?}")))
},
)?;
@ -488,12 +499,13 @@ pub(crate) async fn unban_user_route(
services
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())?
.ok_or(Error::BadRequest(ErrorKind::BadState, "Cannot unban a user who is not banned."))?
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())
.await
.map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))?
.content
.get(),
)
.map_err(|_| Error::bad_database("Invalid member event in database."))?;
.map_err(|e| err!(Database("Invalid member event in database: {e:?}")))?;
event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason);
@ -539,18 +551,16 @@ pub(crate) async fn forget_room_route(
if services
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
.is_joined(sender_user, &body.room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::Unknown,
"You must leave the room before forgetting it",
));
return Err!(Request(Unknown("You must leave the room before forgetting it")));
}
services
.rooms
.state_cache
.forget(&body.room_id, sender_user)?;
.forget(&body.room_id, sender_user);
Ok(forget_room::v3::Response::new())
}
@ -568,8 +578,9 @@ pub(crate) async fn joined_rooms_route(
.rooms
.state_cache
.rooms_joined(sender_user)
.filter_map(Result::ok)
.collect(),
.map(ToOwned::to_owned)
.collect()
.await,
})
}
@ -587,12 +598,10 @@ pub(crate) async fn get_member_events_route(
if !services
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
.user_can_see_state_events(sender_user, &body.room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this room.",
));
return Err!(Request(Forbidden("You don't have permission to view this room.")));
}
Ok(get_member_events::v3::Response {
@ -622,30 +631,27 @@ pub(crate) async fn joined_members_route(
if !services
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
.user_can_see_state_events(sender_user, &body.room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this room.",
));
return Err!(Request(Forbidden("You don't have permission to view this room.")));
}
let joined: BTreeMap<OwnedUserId, RoomMember> = services
.rooms
.state_cache
.room_members(&body.room_id)
.filter_map(|user| {
let user = user.ok()?;
Some((
user.clone(),
.then(|user| async move {
(
user.to_owned(),
RoomMember {
display_name: services.users.displayname(&user).unwrap_or_default(),
avatar_url: services.users.avatar_url(&user).unwrap_or_default(),
display_name: services.users.displayname(user).await.ok(),
avatar_url: services.users.avatar_url(user).await.ok(),
},
))
)
})
.collect();
.collect()
.await;
Ok(joined_members::v3::Response {
joined,
@ -658,13 +664,23 @@ pub async fn join_room_by_id_helper(
) -> Result<join_room_by_id::v3::Response> {
let state_lock = services.rooms.state.mutex.lock(room_id).await;
let user_is_guest = services.users.is_deactivated(sender_user).unwrap_or(false) && appservice_info.is_none();
let user_is_guest = services
.users
.is_deactivated(sender_user)
.await
.unwrap_or(false)
&& appservice_info.is_none();
if matches!(services.rooms.state_accessor.guest_can_join(room_id), Ok(false)) && user_is_guest {
if user_is_guest && !services.rooms.state_accessor.guest_can_join(room_id).await {
return Err!(Request(Forbidden("Guests are not allowed to join this room")));
}
if matches!(services.rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) {
if services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
debug_warn!("{sender_user} is already joined in {room_id}");
return Ok(join_room_by_id::v3::Response {
room_id: room_id.into(),
@ -674,15 +690,17 @@ pub async fn join_room_by_id_helper(
if services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), room_id)?
|| servers.is_empty()
.server_in_room(services.globals.server_name(), room_id)
.await || servers.is_empty()
|| (servers.len() == 1 && services.globals.server_is_ours(&servers[0]))
{
join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock)
.boxed()
.await
} else {
// Ask a remote server if we are not participating in this room
join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock)
.boxed()
.await
}
}
@ -739,11 +757,11 @@ async fn join_room_by_id_helper_remote(
"content".to_owned(),
to_canonical_value(RoomMemberEventContent {
membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?,
avatar_url: services.users.avatar_url(sender_user)?,
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None,
third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?,
blurhash: services.users.blurhash(sender_user).await.ok(),
reason,
join_authorized_via_users_server: join_authorized_via_users_server.clone(),
})
@ -791,10 +809,11 @@ async fn join_room_by_id_helper_remote(
federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(),
event_id: event_id.to_owned(),
omit_members: false,
pdu: services
.sending
.convert_to_outgoing_federation_event(join_event.clone()),
omit_members: false,
.convert_to_outgoing_federation_event(join_event.clone())
.await,
},
)
.await?;
@ -864,7 +883,11 @@ async fn join_room_by_id_helper_remote(
}
}
services.rooms.short.get_or_create_shortroomid(room_id)?;
services
.rooms
.short
.get_or_create_shortroomid(room_id)
.await;
info!("Parsing join event");
let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone())
@ -895,12 +918,13 @@ async fn join_room_by_id_helper_remote(
err!(BadServerResponse("Invalid PDU in send_join response: {e:?}"))
})?;
services.rooms.outlier.add_pdu_outlier(&event_id, &value)?;
services.rooms.outlier.add_pdu_outlier(&event_id, &value);
if let Some(state_key) = &pdu.state_key {
let shortstatekey = services
.rooms
.short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?;
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)
.await;
state.insert(shortstatekey, pdu.event_id.clone());
}
}
@ -916,50 +940,53 @@ async fn join_room_by_id_helper_remote(
continue;
};
services.rooms.outlier.add_pdu_outlier(&event_id, &value)?;
services.rooms.outlier.add_pdu_outlier(&event_id, &value);
}
debug!("Running send_join auth check");
let fetch_state = &state;
let state_fetch = |k: &'static StateEventType, s: String| async move {
let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?;
let event_id = fetch_state.get(&shortstatekey)?;
services.rooms.timeline.get_pdu(event_id).await.ok()
};
let auth_check = state_res::event_auth::auth_check(
&state_res::RoomVersion::new(&room_version_id).expect("room version is supported"),
&parsed_join_pdu,
None::<PduEvent>, // TODO: third party invite
|k, s| {
services
.rooms
.timeline
.get_pdu(
state.get(
&services
.rooms
.short
.get_or_create_shortstatekey(&k.to_string().into(), s)
.ok()?,
)?,
)
.ok()?
},
None, // TODO: third party invite
|k, s| state_fetch(k, s.to_owned()),
)
.map_err(|e| {
warn!("Auth check failed: {e}");
Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")
})?;
.await
.map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;
if !auth_check {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed"));
return Err!(Request(Forbidden("Auth check failed")));
}
info!("Saving state from send_join");
let (statehash_before_join, new, removed) = services.rooms.state_compressor.save_state(
room_id,
Arc::new(
state
.into_iter()
.map(|(k, id)| services.rooms.state_compressor.compress_state_event(k, &id))
.collect::<Result<_>>()?,
),
)?;
let (statehash_before_join, new, removed) = services
.rooms
.state_compressor
.save_state(
room_id,
Arc::new(
state
.into_iter()
.stream()
.then(|(k, id)| async move {
services
.rooms
.state_compressor
.compress_state_event(k, &id)
.await
})
.collect()
.await,
),
)
.await?;
services
.rooms
@ -968,12 +995,20 @@ async fn join_room_by_id_helper_remote(
.await?;
info!("Updating joined counts for new room");
services.rooms.state_cache.update_joined_count(room_id)?;
services
.rooms
.state_cache
.update_joined_count(room_id)
.await;
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
// fail.
let statehash_after_join = services.rooms.state.append_to_state(&parsed_join_pdu)?;
let statehash_after_join = services
.rooms
.state
.append_to_state(&parsed_join_pdu)
.await?;
info!("Appending new room join event");
services
@ -993,7 +1028,7 @@ async fn join_room_by_id_helper_remote(
services
.rooms
.state
.set_room_state(room_id, statehash_after_join, &state_lock)?;
.set_room_state(room_id, statehash_after_join, &state_lock);
Ok(join_room_by_id::v3::Response::new(room_id.to_owned()))
}
@ -1005,23 +1040,15 @@ async fn join_room_by_id_helper_local(
) -> Result<join_room_by_id::v3::Response> {
debug!("We can join locally");
let join_rules_event = services
let join_rules_event_content = services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
.as_ref()
.map(|join_rules_event| {
serde_json::from_str(join_rules_event.content.get()).map_err(|e| {
warn!("Invalid join rules event: {}", e);
Error::bad_database("Invalid join rules event in db.")
})
})
.transpose()?;
.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "")
.await
.map(|content: RoomJoinRulesEventContent| content);
let restriction_rooms = match join_rules_event_content {
Some(RoomJoinRulesEventContent {
Ok(RoomJoinRulesEventContent {
join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted),
}) => restricted
.allow
@ -1034,29 +1061,34 @@ async fn join_room_by_id_helper_local(
_ => Vec::new(),
};
let local_members = services
let local_members: Vec<_> = services
.rooms
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
.filter(|user| services.globals.user_is_local(user))
.collect::<Vec<OwnedUserId>>();
.ready_filter(|user| services.globals.user_is_local(user))
.map(ToOwned::to_owned)
.collect()
.await;
let mut join_authorized_via_users_server: Option<OwnedUserId> = None;
if restriction_rooms.iter().any(|restriction_room_id| {
services
.rooms
.state_cache
.is_joined(sender_user, restriction_room_id)
.unwrap_or(false)
}) {
if restriction_rooms
.iter()
.stream()
.any(|restriction_room_id| {
services
.rooms
.state_cache
.is_joined(sender_user, restriction_room_id)
})
.await
{
for user in local_members {
if services
.rooms
.state_accessor
.user_can_invite(room_id, &user, sender_user, &state_lock)
.unwrap_or(false)
.await
{
join_authorized_via_users_server = Some(user);
break;
@ -1066,11 +1098,11 @@ async fn join_room_by_id_helper_local(
let event = RoomMemberEventContent {
membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?,
avatar_url: services.users.avatar_url(sender_user)?,
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None,
third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?,
blurhash: services.users.blurhash(sender_user).await.ok(),
reason: reason.clone(),
join_authorized_via_users_server,
};
@ -1144,11 +1176,11 @@ async fn join_room_by_id_helper_local(
"content".to_owned(),
to_canonical_value(RoomMemberEventContent {
membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?,
avatar_url: services.users.avatar_url(sender_user)?,
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None,
third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?,
blurhash: services.users.blurhash(sender_user).await.ok(),
reason,
join_authorized_via_users_server,
})
@ -1195,10 +1227,11 @@ async fn join_room_by_id_helper_local(
federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(),
event_id: event_id.to_owned(),
omit_members: false,
pdu: services
.sending
.convert_to_outgoing_federation_event(join_event.clone()),
omit_members: false,
.convert_to_outgoing_federation_event(join_event.clone())
.await,
},
)
.await?;
@ -1369,7 +1402,7 @@ pub(crate) async fn invite_helper(
services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option<String>,
is_direct: bool,
) -> Result<()> {
if !services.users.is_admin(user_id)? && services.globals.block_non_admin_invites() {
if !services.users.is_admin(user_id).await && services.globals.block_non_admin_invites() {
info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}");
return Err(Error::BadRequest(
ErrorKind::forbidden(),
@ -1381,7 +1414,7 @@ pub(crate) async fn invite_helper(
let (pdu, pdu_json, invite_room_state) = {
let state_lock = services.rooms.state.mutex.lock(room_id).await;
let content = to_raw_value(&RoomMemberEventContent {
avatar_url: services.users.avatar_url(user_id)?,
avatar_url: services.users.avatar_url(user_id).await.ok(),
displayname: None,
is_direct: Some(is_direct),
membership: MembershipState::Invite,
@ -1392,28 +1425,32 @@ pub(crate) async fn invite_helper(
})
.expect("member event is valid value");
let (pdu, pdu_json) = services.rooms.timeline.create_hash_and_sign_event(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content,
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
},
sender_user,
room_id,
&state_lock,
)?;
let (pdu, pdu_json) = services
.rooms
.timeline
.create_hash_and_sign_event(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content,
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
},
sender_user,
room_id,
&state_lock,
)
.await?;
let invite_room_state = services.rooms.state.calculate_invite_state(&pdu)?;
let invite_room_state = services.rooms.state.calculate_invite_state(&pdu).await?;
drop(state_lock);
(pdu, pdu_json, invite_room_state)
};
let room_version_id = services.rooms.state.get_room_version(room_id)?;
let room_version_id = services.rooms.state.get_room_version(room_id).await?;
let response = services
.sending
@ -1425,9 +1462,15 @@ pub(crate) async fn invite_helper(
room_version: room_version_id.clone(),
event: services
.sending
.convert_to_outgoing_federation_event(pdu_json.clone()),
.convert_to_outgoing_federation_event(pdu_json.clone())
.await,
invite_room_state,
via: services.rooms.state_cache.servers_route_via(room_id).ok(),
via: services
.rooms
.state_cache
.servers_route_via(room_id)
.await
.ok(),
},
)
.await?;
@ -1478,11 +1521,16 @@ pub(crate) async fn invite_helper(
"Could not accept incoming PDU as timeline event.",
))?;
services.sending.send_pdu_room(room_id, &pdu_id)?;
services.sending.send_pdu_room(room_id, &pdu_id).await?;
return Ok(());
}
if !services.rooms.state_cache.is_joined(sender_user, room_id)? {
if !services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this room.",
@ -1499,11 +1547,11 @@ pub(crate) async fn invite_helper(
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Invite,
displayname: services.users.displayname(user_id)?,
avatar_url: services.users.avatar_url(user_id)?,
displayname: services.users.displayname(user_id).await.ok(),
avatar_url: services.users.avatar_url(user_id).await.ok(),
is_direct: Some(is_direct),
third_party_invite: None,
blurhash: services.users.blurhash(user_id)?,
blurhash: services.users.blurhash(user_id).await.ok(),
reason,
join_authorized_via_users_server: None,
})
@ -1531,36 +1579,37 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) {
.rooms
.state_cache
.rooms_joined(user_id)
.map(ToOwned::to_owned)
.chain(
services
.rooms
.state_cache
.rooms_invited(user_id)
.map(|t| t.map(|(r, _)| r)),
.map(|(r, _)| r),
)
.collect::<Vec<_>>();
.collect::<Vec<_>>()
.await;
for room_id in all_rooms {
let Ok(room_id) = room_id else {
continue;
};
// ignore errors
if let Err(e) = leave_room(services, user_id, &room_id, None).await {
warn!(%room_id, %user_id, %e, "Failed to leave room");
}
if let Err(e) = services.rooms.state_cache.forget(&room_id, user_id) {
warn!(%room_id, %user_id, %e, "Failed to forget room");
}
services.rooms.state_cache.forget(&room_id, user_id);
}
}
pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> {
//use conduit::utils::stream::OptionStream;
use futures::TryFutureExt;
// Ask a remote server if we don't have this room
if !services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), room_id)?
.server_in_room(services.globals.server_name(), room_id)
.await
{
if let Err(e) = remote_leave_room(services, user_id, room_id).await {
warn!("Failed to leave room {} remotely: {}", user_id, e);
@ -1570,34 +1619,42 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId,
let last_state = services
.rooms
.state_cache
.invite_state(user_id, room_id)?
.map_or_else(|| services.rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?;
.invite_state(user_id, room_id)
.map_err(|_| services.rooms.state_cache.left_state(user_id, room_id))
.await
.ok();
// We always drop the invite, we can't rely on other servers
services.rooms.state_cache.update_membership(
room_id,
user_id,
RoomMemberEventContent::new(MembershipState::Leave),
user_id,
last_state,
None,
true,
)?;
services
.rooms
.state_cache
.update_membership(
room_id,
user_id,
RoomMemberEventContent::new(MembershipState::Leave),
user_id,
last_state,
None,
true,
)
.await?;
} else {
let state_lock = services.rooms.state.mutex.lock(room_id).await;
let member_event =
services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?;
let member_event = services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
.await;
// Fix for broken rooms
let member_event = match member_event {
None => {
error!("Trying to leave a room you are not a member of.");
let Ok(member_event) = member_event else {
error!("Trying to leave a room you are not a member of.");
services.rooms.state_cache.update_membership(
services
.rooms
.state_cache
.update_membership(
room_id,
user_id,
RoomMemberEventContent::new(MembershipState::Leave),
@ -1605,16 +1662,14 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId,
None,
None,
true,
)?;
return Ok(());
},
Some(e) => e,
)
.await?;
return Ok(());
};
let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()).map_err(|e| {
error!("Invalid room member event in database: {}", e);
Error::bad_database("Invalid member event in database.")
})?;
let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get())
.map_err(|e| err!(Database(error!("Invalid room member event in database: {e}"))))?;
event.membership = MembershipState::Leave;
event.reason = reason;
@ -1647,15 +1702,17 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
let invite_state = services
.rooms
.state_cache
.invite_state(user_id, room_id)?
.ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?;
.invite_state(user_id, room_id)
.await
.map_err(|_| err!(Request(BadState("User is not invited."))))?;
let mut servers: HashSet<OwnedServerName> = services
.rooms
.state_cache
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
servers.extend(
invite_state
@ -1760,7 +1817,8 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
event_id,
pdu: services
.sending
.convert_to_outgoing_federation_event(leave_event.clone()),
.convert_to_outgoing_federation_event(leave_event.clone())
.await,
},
)
.await?;

View File

@ -1,7 +1,8 @@
use std::collections::{BTreeMap, HashSet};
use axum::extract::State;
use conduit::PduCount;
use conduit::{err, utils::ReadyExt, Err, PduCount};
use futures::{FutureExt, StreamExt};
use ruma::{
api::client::{
error::ErrorKind,
@ -9,13 +10,14 @@ use ruma::{
message::{get_message_events, send_message_event},
},
events::{MessageLikeEventType, StateEventType},
RoomId, UserId,
UserId,
};
use serde_json::{from_str, Value};
use service::rooms::timeline::PdusIterItem;
use crate::{
service::{pdu::PduBuilder, Services},
utils, Error, PduEvent, Result, Ruma,
utils, Error, Result, Ruma,
};
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
@ -30,79 +32,78 @@ use crate::{
pub(crate) async fn send_message_event_route(
State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>,
) -> Result<send_message_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_deref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref();
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let appservice_info = body.appservice_info.as_ref();
// Forbid m.room.encrypted if encryption is disabled
if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled"));
return Err!(Request(Forbidden("Encryption has been disabled")));
}
if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Room call invites are not allowed in public rooms",
));
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
if body.event_type == MessageLikeEventType::CallInvite
&& services.rooms.directory.is_public_room(&body.room_id).await
{
return Err!(Request(Forbidden("Room call invites are not allowed in public rooms")));
}
// Check if this is a new transaction id
if let Some(response) = services
if let Ok(response) = services
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)?
.existing_txnid(sender_user, sender_device, &body.txn_id)
.await
{
// The client might have sent a txnid of the /sendToDevice endpoint
// This txnid has no response associated with it
if response.is_empty() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Tried to use txn id already used for an incompatible endpoint.",
));
return Err!(Request(InvalidParam(
"Tried to use txn id already used for an incompatible endpoint."
)));
}
let event_id = utils::string_from_bytes(&response)
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
return Ok(send_message_event::v3::Response {
event_id,
event_id: utils::string_from_bytes(&response)
.map(TryInto::try_into)
.map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??,
});
}
let mut unsigned = BTreeMap::new();
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
let content = from_str(body.body.body.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?;
let event_id = services
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: body.event_type.to_string().into(),
content: from_str(body.body.body.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
content,
unsigned: Some(unsigned),
state_key: None,
redacts: None,
timestamp: if body.appservice_info.is_some() {
body.timestamp
} else {
None
},
timestamp: appservice_info.and(body.timestamp),
},
sender_user,
&body.room_id,
&state_lock,
)
.await?;
.await
.map(|event_id| (*event_id).to_owned())?;
services
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes());
drop(state_lock);
Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
Ok(send_message_event::v3::Response {
event_id,
})
}
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
@ -117,8 +118,12 @@ pub(crate) async fn get_message_events_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
let room_id = &body.room_id;
let filter = &body.filter;
let limit = usize::try_from(body.limit).unwrap_or(10).min(100);
let from = match body.from.as_ref() {
Some(from) => PduCount::try_from_string(from)?,
None => match body.dir {
ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(),
@ -133,30 +138,25 @@ pub(crate) async fn get_message_events_route(
services
.rooms
.lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)
.await?;
let limit = usize::try_from(body.limit).unwrap_or(10).min(100);
let next_token;
.lazy_load_confirm_delivery(sender_user, sender_device, room_id, from);
let mut resp = get_message_events::v3::Response::new();
let mut lazy_loaded = HashSet::new();
let next_token;
match body.dir {
ruma::api::Direction::Forward => {
let events_after: Vec<_> = services
let events_after: Vec<PdusIterItem> = services
.rooms
.timeline
.pdus_after(sender_user, &body.room_id, from)?
.filter_map(Result::ok) // Filter out buggy events
.filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)
})
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.pdus_after(sender_user, room_id, from)
.await?
.ready_filter_map(|item| contains_url_filter(item, filter))
.filter_map(|item| visibility_filter(&services, item, sender_user))
.ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to`
.take(limit)
.collect();
.collect()
.boxed()
.await;
for (_, event) in &events_after {
/* TODO: Remove the not "element_hacks" check when these are resolved:
@ -164,16 +164,18 @@ pub(crate) async fn get_message_events_route(
* https://github.com/vector-im/element-web/issues/21034
*/
if !cfg!(feature = "element_hacks")
&& !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user,
sender_device,
&body.room_id,
&event.sender,
)? {
&& !services
.rooms
.lazy_loading
.lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
.await
{
lazy_loaded.insert(event.sender.clone());
}
lazy_loaded.insert(event.sender.clone());
if cfg!(features = "element_hacks") {
lazy_loaded.insert(event.sender.clone());
}
}
next_token = events_after.last().map(|(count, _)| count).copied();
@ -191,17 +193,22 @@ pub(crate) async fn get_message_events_route(
services
.rooms
.timeline
.backfill_if_required(&body.room_id, from)
.backfill_if_required(room_id, from)
.boxed()
.await?;
let events_before: Vec<_> = services
let events_before: Vec<PdusIterItem> = services
.rooms
.timeline
.pdus_until(sender_user, &body.room_id, from)?
.filter_map(Result::ok) // Filter out buggy events
.filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)})
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.pdus_until(sender_user, room_id, from)
.await?
.ready_filter_map(|item| contains_url_filter(item, filter))
.filter_map(|item| visibility_filter(&services, item, sender_user))
.ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to`
.take(limit)
.collect();
.collect()
.boxed()
.await;
for (_, event) in &events_before {
/* TODO: Remove the not "element_hacks" check when these are resolved:
@ -209,16 +216,18 @@ pub(crate) async fn get_message_events_route(
* https://github.com/vector-im/element-web/issues/21034
*/
if !cfg!(feature = "element_hacks")
&& !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user,
sender_device,
&body.room_id,
&event.sender,
)? {
&& !services
.rooms
.lazy_loading
.lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
.await
{
lazy_loaded.insert(event.sender.clone());
}
lazy_loaded.insert(event.sender.clone());
if cfg!(features = "element_hacks") {
lazy_loaded.insert(event.sender.clone());
}
}
next_token = events_before.last().map(|(count, _)| count).copied();
@ -236,11 +245,11 @@ pub(crate) async fn get_message_events_route(
resp.state = Vec::new();
for ll_id in &lazy_loaded {
if let Some(member_event) =
services
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())?
if let Ok(member_event) = services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, ll_id.as_str())
.await
{
resp.state.push(member_event.to_state_event());
}
@ -249,34 +258,43 @@ pub(crate) async fn get_message_events_route(
// remove the feature check when we are sure clients like element can handle it
if !cfg!(feature = "element_hacks") {
if let Some(next_token) = next_token {
services
.rooms
.lazy_loading
.lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token)
.await;
services.rooms.lazy_loading.lazy_load_mark_sent(
sender_user,
sender_device,
room_id,
lazy_loaded,
next_token,
);
}
}
Ok(resp)
}
fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool {
async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option<PdusIterItem> {
let (_, pdu) = &item;
services
.rooms
.state_accessor
.user_can_see_event(user_id, room_id, &pdu.event_id)
.unwrap_or(false)
.user_can_see_event(user_id, &pdu.room_id, &pdu.event_id)
.await
.then_some(item)
}
fn contains_url_filter(pdu: &PduEvent, filter: &RoomEventFilter) -> bool {
fn contains_url_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option<PdusIterItem> {
let (_, pdu) = &item;
if filter.url_filter.is_none() {
return true;
return Some(item);
}
let content: Value = from_str(pdu.content.get()).unwrap();
match filter.url_filter {
let res = match filter.url_filter {
Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(),
Some(UrlFilter::EventsWithUrl) => content["url"].is_string(),
None => true,
}
};
res.then_some(item)
}

View File

@ -28,7 +28,8 @@ pub(crate) async fn set_presence_route(
services
.presence
.set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?;
.set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())
.await?;
Ok(set_presence::v3::Response {})
}
@ -49,14 +50,15 @@ pub(crate) async fn get_presence_route(
let mut presence_event = None;
for _room_id in services
let has_shared_rooms = services
.rooms
.user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
{
if let Some(presence) = services.presence.get_presence(&body.user_id)? {
.has_shared_rooms(sender_user, &body.user_id)
.await;
if has_shared_rooms {
if let Ok(presence) = services.presence.get_presence(&body.user_id).await {
presence_event = Some(presence);
break;
}
}

View File

@ -1,5 +1,10 @@
use axum::extract::State;
use conduit::{pdu::PduBuilder, warn, Err, Error, Result};
use conduit::{
pdu::PduBuilder,
utils::{stream::TryIgnore, IterStream},
warn, Err, Error, Result,
};
use futures::{StreamExt, TryStreamExt};
use ruma::{
api::{
client::{
@ -35,16 +40,18 @@ pub(crate) async fn set_displayname_route(
.rooms
.state_cache
.rooms_joined(&body.user_id)
.filter_map(Result::ok)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
update_displayname(&services, &body.user_id, body.displayname.clone(), all_joined_rooms).await?;
update_displayname(&services, &body.user_id, body.displayname.clone(), &all_joined_rooms).await?;
if services.globals.allow_local_presence() {
// Presence update
services
.presence
.ping_presence(&body.user_id, &PresenceState::Online)?;
.ping_presence(&body.user_id, &PresenceState::Online)
.await?;
}
Ok(set_display_name::v3::Response {})
@ -72,22 +79,19 @@ pub(crate) async fn get_displayname_route(
)
.await
{
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?;
}
services
.users
.set_displayname(&body.user_id, response.displayname.clone())
.await?;
.set_displayname(&body.user_id, response.displayname.clone());
services
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
.set_avatar_url(&body.user_id, response.avatar_url.clone());
services
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
.set_blurhash(&body.user_id, response.blurhash.clone());
return Ok(get_display_name::v3::Response {
displayname: response.displayname,
@ -95,14 +99,14 @@ pub(crate) async fn get_displayname_route(
}
}
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over
// federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
}
Ok(get_display_name::v3::Response {
displayname: services.users.displayname(&body.user_id)?,
displayname: services.users.displayname(&body.user_id).await.ok(),
})
}
@ -124,15 +128,16 @@ pub(crate) async fn set_avatar_url_route(
.rooms
.state_cache
.rooms_joined(&body.user_id)
.filter_map(Result::ok)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
update_avatar_url(
&services,
&body.user_id,
body.avatar_url.clone(),
body.blurhash.clone(),
all_joined_rooms,
&all_joined_rooms,
)
.await?;
@ -140,7 +145,9 @@ pub(crate) async fn set_avatar_url_route(
// Presence update
services
.presence
.ping_presence(&body.user_id, &PresenceState::Online)?;
.ping_presence(&body.user_id, &PresenceState::Online)
.await
.ok();
}
Ok(set_avatar_url::v3::Response {})
@ -168,22 +175,21 @@ pub(crate) async fn get_avatar_url_route(
)
.await
{
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?;
}
services
.users
.set_displayname(&body.user_id, response.displayname.clone())
.await?;
.set_displayname(&body.user_id, response.displayname.clone());
services
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
.set_avatar_url(&body.user_id, response.avatar_url.clone());
services
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
.set_blurhash(&body.user_id, response.blurhash.clone());
return Ok(get_avatar_url::v3::Response {
avatar_url: response.avatar_url,
@ -192,15 +198,15 @@ pub(crate) async fn get_avatar_url_route(
}
}
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over
// federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
}
Ok(get_avatar_url::v3::Response {
avatar_url: services.users.avatar_url(&body.user_id)?,
blurhash: services.users.blurhash(&body.user_id)?,
avatar_url: services.users.avatar_url(&body.user_id).await.ok(),
blurhash: services.users.blurhash(&body.user_id).await.ok(),
})
}
@ -226,31 +232,30 @@ pub(crate) async fn get_profile_route(
)
.await
{
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?;
}
services
.users
.set_displayname(&body.user_id, response.displayname.clone())
.await?;
.set_displayname(&body.user_id, response.displayname.clone());
services
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
.set_avatar_url(&body.user_id, response.avatar_url.clone());
services
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
.set_blurhash(&body.user_id, response.blurhash.clone());
services
.users
.set_timezone(&body.user_id, response.tz.clone())
.await?;
.set_timezone(&body.user_id, response.tz.clone());
for (profile_key, profile_key_value) in &response.custom_profile_fields {
services
.users
.set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()))?;
.set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()));
}
return Ok(get_profile::v3::Response {
@ -263,104 +268,93 @@ pub(crate) async fn get_profile_route(
}
}
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over
// federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
}
Ok(get_profile::v3::Response {
avatar_url: services.users.avatar_url(&body.user_id)?,
blurhash: services.users.blurhash(&body.user_id)?,
displayname: services.users.displayname(&body.user_id)?,
tz: services.users.timezone(&body.user_id)?,
avatar_url: services.users.avatar_url(&body.user_id).await.ok(),
blurhash: services.users.blurhash(&body.user_id).await.ok(),
displayname: services.users.displayname(&body.user_id).await.ok(),
tz: services.users.timezone(&body.user_id).await.ok(),
custom_profile_fields: services
.users
.all_profile_keys(&body.user_id)
.filter_map(Result::ok)
.collect(),
.collect()
.await,
})
}
pub async fn update_displayname(
services: &Services, user_id: &UserId, displayname: Option<String>, all_joined_rooms: Vec<OwnedRoomId>,
services: &Services, user_id: &UserId, displayname: Option<String>, all_joined_rooms: &[OwnedRoomId],
) -> Result<()> {
let current_display_name = services.users.displayname(user_id).unwrap_or_default();
let current_display_name = services.users.displayname(user_id).await.ok();
if displayname == current_display_name {
return Ok(());
}
services
.users
.set_displayname(user_id, displayname.clone())
.await?;
services.users.set_displayname(user_id, displayname.clone());
// Send a new join membership event into all joined rooms
let all_joined_rooms: Vec<_> = all_joined_rooms
.iter()
.map(|room_id| {
Ok::<_, Error>((
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
displayname: displayname.clone(),
join_authorized_via_users_server: None,
..serde_json::from_str(
services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
.ok_or_else(|| {
Error::bad_database("Tried to send display name update for user not in the room.")
})?
.content
.get(),
)
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
},
room_id,
))
})
.filter_map(Result::ok)
.collect();
let mut joined_rooms = Vec::new();
for room_id in all_joined_rooms {
let Ok(event) = services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
.await
else {
continue;
};
update_all_rooms(services, all_joined_rooms, user_id).await;
let pdu = PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
displayname: displayname.clone(),
join_authorized_via_users_server: None,
..serde_json::from_str(event.content.get()).expect("Database contains invalid PDU.")
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
timestamp: None,
};
joined_rooms.push((pdu, room_id));
}
update_all_rooms(services, joined_rooms, user_id).await;
Ok(())
}
pub async fn update_avatar_url(
services: &Services, user_id: &UserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>,
all_joined_rooms: Vec<OwnedRoomId>,
all_joined_rooms: &[OwnedRoomId],
) -> Result<()> {
let current_avatar_url = services.users.avatar_url(user_id).unwrap_or_default();
let current_blurhash = services.users.blurhash(user_id).unwrap_or_default();
let current_avatar_url = services.users.avatar_url(user_id).await.ok();
let current_blurhash = services.users.blurhash(user_id).await.ok();
if current_avatar_url == avatar_url && current_blurhash == blurhash {
return Ok(());
}
services
.users
.set_avatar_url(user_id, avatar_url.clone())
.await?;
services
.users
.set_blurhash(user_id, blurhash.clone())
.await?;
services.users.set_avatar_url(user_id, avatar_url.clone());
services.users.set_blurhash(user_id, blurhash.clone());
// Send a new join membership event into all joined rooms
let avatar_url = &avatar_url;
let blurhash = &blurhash;
let all_joined_rooms: Vec<_> = all_joined_rooms
.iter()
.map(|room_id| {
Ok::<_, Error>((
.try_stream()
.and_then(|room_id: &OwnedRoomId| async move {
Ok((
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
@ -371,8 +365,9 @@ pub async fn update_avatar_url(
services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
.ok_or_else(|| {
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
.await
.map_err(|_| {
Error::bad_database("Tried to send avatar URL update for user not in the room.")
})?
.content
@ -389,8 +384,9 @@ pub async fn update_avatar_url(
room_id,
))
})
.filter_map(Result::ok)
.collect();
.ignore_err()
.collect()
.await;
update_all_rooms(services, all_joined_rooms, user_id).await;

View File

@ -29,41 +29,37 @@ pub(crate) async fn get_pushrules_all_route(
let global_ruleset: Ruleset;
let Ok(event) =
services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
else {
// push rules event doesn't exist, create it and return default
return recreate_push_rules_and_return(&services, sender_user);
};
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await;
if let Some(event) = event {
let value = serde_json::from_str::<CanonicalJsonObject>(event.get())
.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
let Some(content_value) = value.get("content") else {
// user somehow has a push rule event with no content key, recreate it and
// return server default silently
return recreate_push_rules_and_return(&services, sender_user);
};
if content_value.to_string().is_empty() {
// user somehow has a push rule event with empty content, recreate it and return
// server default silently
return recreate_push_rules_and_return(&services, sender_user);
}
let account_data_content = serde_json::from_value::<PushRulesEventContent>(content_value.clone().into())
.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
global_ruleset = account_data_content.global;
} else {
let Ok(event) = event else {
// user somehow has non-existent push rule event. recreate it and return server
// default silently
return recreate_push_rules_and_return(&services, sender_user);
return recreate_push_rules_and_return(&services, sender_user).await;
};
let value = serde_json::from_str::<CanonicalJsonObject>(event.get())
.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
let Some(content_value) = value.get("content") else {
// user somehow has a push rule event with no content key, recreate it and
// return server default silently
return recreate_push_rules_and_return(&services, sender_user).await;
};
if content_value.to_string().is_empty() {
// user somehow has a push rule event with empty content, recreate it and return
// server default silently
return recreate_push_rules_and_return(&services, sender_user).await;
}
let account_data_content = serde_json::from_value::<PushRulesEventContent>(content_value.clone().into())
.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
global_ruleset = account_data_content.global;
Ok(get_pushrules_all::v3::Response {
global: global_ruleset,
})
@ -79,8 +75,9 @@ pub(crate) async fn get_pushrule_route(
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -118,8 +115,9 @@ pub(crate) async fn set_pushrule_route(
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -155,12 +153,15 @@ pub(crate) async fn set_pushrule_route(
return Err(err);
}
services.account_data.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
services
.account_data
.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(set_pushrule::v3::Response {})
}
@ -182,8 +183,9 @@ pub(crate) async fn get_pushrule_actions_route(
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -217,8 +219,9 @@ pub(crate) async fn set_pushrule_actions_route(
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -232,12 +235,15 @@ pub(crate) async fn set_pushrule_actions_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
}
services.account_data.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
services
.account_data
.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(set_pushrule_actions::v3::Response {})
}
@ -259,8 +265,9 @@ pub(crate) async fn get_pushrule_enabled_route(
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -293,8 +300,9 @@ pub(crate) async fn set_pushrule_enabled_route(
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -308,12 +316,15 @@ pub(crate) async fn set_pushrule_enabled_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
}
services.account_data.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
services
.account_data
.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(set_pushrule_enabled::v3::Response {})
}
@ -335,8 +346,9 @@ pub(crate) async fn delete_pushrule_route(
let event = services
.account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -357,12 +369,15 @@ pub(crate) async fn delete_pushrule_route(
return Err(err);
}
services.account_data.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)?;
services
.account_data
.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
Ok(delete_pushrule::v3::Response {})
}
@ -376,7 +391,7 @@ pub(crate) async fn get_pushers_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(get_pushers::v3::Response {
pushers: services.pusher.get_pushers(sender_user)?,
pushers: services.pusher.get_pushers(sender_user).await,
})
}
@ -390,27 +405,30 @@ pub(crate) async fn set_pushers_route(
) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services.pusher.set_pusher(sender_user, &body.action)?;
services.pusher.set_pusher(sender_user, &body.action);
Ok(set_pusher::v3::Response::default())
}
/// user somehow has bad push rules, these must always exist per spec.
/// so recreate it and return server default silently
fn recreate_push_rules_and_return(
async fn recreate_push_rules_and_return(
services: &Services, sender_user: &ruma::UserId,
) -> Result<get_pushrules_all::v3::Response> {
services.account_data.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(PushRulesEvent {
content: PushRulesEventContent {
global: Ruleset::server_default(sender_user),
},
})
.expect("to json always works"),
)?;
services
.account_data
.update(
None,
sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(PushRulesEvent {
content: PushRulesEventContent {
global: Ruleset::server_default(sender_user),
},
})
.expect("to json always works"),
)
.await?;
Ok(get_pushrules_all::v3::Response {
global: Ruleset::server_default(sender_user),

View File

@ -31,27 +31,32 @@ pub(crate) async fn set_read_marker_route(
event_id: fully_read.clone(),
},
};
services.account_data.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
)?;
services
.account_data
.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
)
.await?;
}
if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
services
.rooms
.user
.reset_notification_counts(sender_user, &body.room_id)?;
.reset_notification_counts(sender_user, &body.room_id);
}
if let Some(event) = &body.private_read_receipt {
let count = services
.rooms
.timeline
.get_pdu_count(event)?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
.get_pdu_count(event)
.await
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
let count = match count {
PduCount::Backfilled(_) => {
return Err(Error::BadRequest(
@ -64,7 +69,7 @@ pub(crate) async fn set_read_marker_route(
services
.rooms
.read_receipt
.private_read_set(&body.room_id, sender_user, count)?;
.private_read_set(&body.room_id, sender_user, count);
}
if let Some(event) = &body.read_receipt {
@ -83,14 +88,18 @@ pub(crate) async fn set_read_marker_route(
let mut receipt_content = BTreeMap::new();
receipt_content.insert(event.to_owned(), receipts);
services.rooms.read_receipt.readreceipt_update(
sender_user,
&body.room_id,
&ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(),
},
)?;
services
.rooms
.read_receipt
.readreceipt_update(
sender_user,
&body.room_id,
&ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(),
},
)
.await;
}
Ok(set_read_marker::v3::Response {})
@ -111,7 +120,7 @@ pub(crate) async fn create_receipt_route(
services
.rooms
.user
.reset_notification_counts(sender_user, &body.room_id)?;
.reset_notification_counts(sender_user, &body.room_id);
}
match body.receipt_type {
@ -121,12 +130,15 @@ pub(crate) async fn create_receipt_route(
event_id: body.event_id.clone(),
},
};
services.account_data.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
)?;
services
.account_data
.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"),
)
.await?;
},
create_receipt::v3::ReceiptType::Read => {
let mut user_receipts = BTreeMap::new();
@ -143,21 +155,27 @@ pub(crate) async fn create_receipt_route(
let mut receipt_content = BTreeMap::new();
receipt_content.insert(body.event_id.clone(), receipts);
services.rooms.read_receipt.readreceipt_update(
sender_user,
&body.room_id,
&ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(),
},
)?;
services
.rooms
.read_receipt
.readreceipt_update(
sender_user,
&body.room_id,
&ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(),
},
)
.await;
},
create_receipt::v3::ReceiptType::ReadPrivate => {
let count = services
.rooms
.timeline
.get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
.get_pdu_count(&body.event_id)
.await
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
let count = match count {
PduCount::Backfilled(_) => {
return Err(Error::BadRequest(
@ -170,7 +188,7 @@ pub(crate) async fn create_receipt_route(
services
.rooms
.read_receipt
.private_read_set(&body.room_id, sender_user, count)?;
.private_read_set(&body.room_id, sender_user, count);
},
_ => return Err(Error::bad_database("Unsupported receipt type")),
}

View File

@ -9,20 +9,24 @@ use crate::{Result, Ruma};
pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_deref().expect("user is authenticated");
let res = services.rooms.pdu_metadata.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
&Some(body.event_type.clone()),
&Some(body.rel_type.clone()),
&body.from,
&body.to,
&body.limit,
body.recurse,
body.dir,
)?;
let res = services
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
body.event_type.clone().into(),
body.rel_type.clone().into(),
body.from.as_ref(),
body.to.as_ref(),
body.limit,
body.recurse,
body.dir,
)
.await?;
Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
chunk: res.chunk,
@ -36,20 +40,24 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
pub(crate) async fn get_relating_events_with_rel_type_route(
State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_deref().expect("user is authenticated");
let res = services.rooms.pdu_metadata.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
&None,
&Some(body.rel_type.clone()),
&body.from,
&body.to,
&body.limit,
body.recurse,
body.dir,
)?;
let res = services
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
None,
body.rel_type.clone().into(),
body.from.as_ref(),
body.to.as_ref(),
body.limit,
body.recurse,
body.dir,
)
.await?;
Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk,
@ -63,18 +71,22 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
pub(crate) async fn get_relating_events_route(
State(services): State<crate::State>, body: Ruma<get_relating_events::v1::Request>,
) -> Result<get_relating_events::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_user = body.sender_user.as_deref().expect("user is authenticated");
services.rooms.pdu_metadata.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
&None,
&None,
&body.from,
&body.to,
&body.limit,
body.recurse,
body.dir,
)
services
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user,
&body.room_id,
&body.event_id,
None,
None,
body.from.as_ref(),
body.to.as_ref(),
body.limit,
body.recurse,
body.dir,
)
.await
}

View File

@ -1,6 +1,7 @@
use std::time::Duration;
use axum::extract::State;
use conduit::{utils::ReadyExt, Err};
use rand::Rng;
use ruma::{
api::client::{error::ErrorKind, room::report_content},
@ -34,11 +35,8 @@ pub(crate) async fn report_event_route(
delay_response().await;
// check if we know about the reported event ID or if it's invalid
let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else {
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Event ID is not known to us or Event ID is invalid",
));
let Ok(pdu) = services.rooms.timeline.get_pdu(&body.event_id).await else {
return Err!(Request(NotFound("Event ID is not known to us or Event ID is invalid")));
};
is_report_valid(
@ -49,7 +47,8 @@ pub(crate) async fn report_event_route(
&body.reason,
body.score,
&pdu,
)?;
)
.await?;
// send admin room message that we received the report with an @room ping for
// urgency
@ -81,7 +80,8 @@ pub(crate) async fn report_event_route(
HtmlEscape(body.reason.as_deref().unwrap_or(""))
),
))
.await;
.await
.ok();
Ok(report_content::v3::Response {})
}
@ -92,7 +92,7 @@ pub(crate) async fn report_event_route(
/// check if score is in valid range
/// check if report reasoning is less than or equal to 750 characters
/// check if reporting user is in the reporting room
fn is_report_valid(
async fn is_report_valid(
services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>,
score: Option<ruma::Int>, pdu: &std::sync::Arc<PduEvent>,
) -> Result<()> {
@ -123,8 +123,8 @@ fn is_report_valid(
.rooms
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
.any(|user_id| user_id == *sender_user)
.ready_any(|user_id| user_id == sender_user)
.await
{
return Err(Error::BadRequest(
ErrorKind::NotFound,

View File

@ -2,6 +2,7 @@ use std::{cmp::max, collections::BTreeMap};
use axum::extract::State;
use conduit::{debug_info, debug_warn, err, Err};
use futures::{FutureExt, StreamExt};
use ruma::{
api::client::{
error::ErrorKind,
@ -74,7 +75,7 @@ pub(crate) async fn create_room_route(
if !services.globals.allow_room_creation()
&& body.appservice_info.is_none()
&& !services.users.is_admin(sender_user)?
&& !services.users.is_admin(sender_user).await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled."));
}
@ -86,7 +87,7 @@ pub(crate) async fn create_room_route(
};
// check if room ID doesn't already exist instead of erroring on auth check
if services.rooms.short.get_shortroomid(&room_id)?.is_some() {
if services.rooms.short.get_shortroomid(&room_id).await.is_ok() {
return Err(Error::BadRequest(
ErrorKind::RoomInUse,
"Room with that custom room ID already exists",
@ -95,7 +96,7 @@ pub(crate) async fn create_room_route(
if body.visibility == room::Visibility::Public
&& services.globals.config.lockdown_public_room_directory
&& !services.users.is_admin(sender_user)?
&& !services.users.is_admin(sender_user).await
&& body.appservice_info.is_none()
{
info!(
@ -118,7 +119,11 @@ pub(crate) async fn create_room_route(
return Err!(Request(Forbidden("Publishing rooms to the room directory is not allowed")));
}
let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?;
let _short_id = services
.rooms
.short
.get_or_create_shortroomid(&room_id)
.await;
let state_lock = services.rooms.state.mutex.lock(&room_id).await;
let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name {
@ -218,6 +223,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
// 2. Let the room creator join
@ -229,11 +235,11 @@ pub(crate) async fn create_room_route(
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?,
avatar_url: services.users.avatar_url(sender_user)?,
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: Some(body.is_direct),
third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?,
blurhash: services.users.blurhash(sender_user).await.ok(),
reason: None,
join_authorized_via_users_server: None,
})
@ -247,6 +253,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
// 3. Power levels
@ -284,6 +291,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
// 4. Canonical room alias
@ -308,6 +316,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
}
@ -335,6 +344,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
// 5.2 History Visibility
@ -355,6 +365,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
// 5.3 Guest Access
@ -378,6 +389,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
// 6. Events listed in initial_state
@ -410,6 +422,7 @@ pub(crate) async fn create_room_route(
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.boxed()
.await?;
}
@ -432,6 +445,7 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
}
@ -455,13 +469,17 @@ pub(crate) async fn create_room_route(
&room_id,
&state_lock,
)
.boxed()
.await?;
}
// 8. Events implied by invite (and TODO: invite_3pid)
drop(state_lock);
for user_id in &body.invite {
if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct).await {
if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct)
.boxed()
.await
{
warn!(%e, "Failed to send invite");
}
}
@ -475,7 +493,7 @@ pub(crate) async fn create_room_route(
}
if body.visibility == room::Visibility::Public {
services.rooms.directory.set_public(&room_id)?;
services.rooms.directory.set_public(&room_id);
if services.globals.config.admin_room_notices {
services
@ -505,13 +523,15 @@ pub(crate) async fn get_room_event_route(
let event = services
.rooms
.timeline
.get_pdu(&body.event_id)?
.ok_or_else(|| err!(Request(NotFound("Event {} not found.", &body.event_id))))?;
.get_pdu(&body.event_id)
.await
.map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id))))?;
if !services
.rooms
.state_accessor
.user_can_see_event(sender_user, &event.room_id, &body.event_id)?
.user_can_see_event(sender_user, &event.room_id, &body.event_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
@ -541,7 +561,8 @@ pub(crate) async fn get_room_aliases_route(
if !services
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
.user_can_see_state_events(sender_user, &body.room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
@ -554,8 +575,9 @@ pub(crate) async fn get_room_aliases_route(
.rooms
.alias
.local_aliases_for_room(&body.room_id)
.filter_map(Result::ok)
.collect(),
.map(ToOwned::to_owned)
.collect()
.await,
})
}
@ -591,7 +613,8 @@ pub(crate) async fn upgrade_room_route(
let _short_id = services
.rooms
.short
.get_or_create_shortroomid(&replacement_room)?;
.get_or_create_shortroomid(&replacement_room)
.await;
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
@ -629,12 +652,12 @@ pub(crate) async fn upgrade_room_route(
services
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")?
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")
.await
.map_err(|_| err!(Database("Found room without m.room.create event.")))?
.content
.get(),
)
.map_err(|_| Error::bad_database("Invalid room event in database."))?;
)?;
// Use the m.room.tombstone event as the predecessor
let predecessor = Some(ruma::events::room::create::PreviousRoom::new(
@ -714,11 +737,11 @@ pub(crate) async fn upgrade_room_route(
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?,
avatar_url: services.users.avatar_url(sender_user)?,
displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None,
third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?,
blurhash: services.users.blurhash(sender_user).await.ok(),
reason: None,
join_authorized_via_users_server: None,
})
@ -739,10 +762,11 @@ pub(crate) async fn upgrade_room_route(
let event_content = match services
.rooms
.state_accessor
.room_state_get(&body.room_id, event_type, "")?
.room_state_get(&body.room_id, event_type, "")
.await
{
Some(v) => v.content.clone(),
None => continue, // Skipping missing events.
Ok(v) => v.content.clone(),
Err(_) => continue, // Skipping missing events.
};
services
@ -765,21 +789,23 @@ pub(crate) async fn upgrade_room_route(
}
// Moves any local aliases to the new room
for alias in services
let mut local_aliases = services
.rooms
.alias
.local_aliases_for_room(&body.room_id)
.filter_map(Result::ok)
{
.boxed();
while let Some(alias) = local_aliases.next().await {
services
.rooms
.alias
.remove_alias(&alias, sender_user)
.remove_alias(alias, sender_user)
.await?;
services
.rooms
.alias
.set_alias(&alias, &replacement_room, sender_user)?;
.set_alias(alias, &replacement_room, sender_user)?;
}
// Get the old room power levels
@ -787,12 +813,12 @@ pub(crate) async fn upgrade_room_route(
services
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")?
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")
.await
.map_err(|_| err!(Database("Found room without m.room.create event.")))?
.content
.get(),
)
.map_err(|_| Error::bad_database("Invalid room event in database."))?;
)?;
// Setting events_default and invite to the greater of 50 and users_default + 1
let new_level = max(
@ -800,9 +826,7 @@ pub(crate) async fn upgrade_room_route(
power_levels_event_content
.users_default
.checked_add(int!(1))
.ok_or_else(|| {
Error::BadRequest(ErrorKind::BadJson, "users_default power levels event content is not valid")
})?,
.ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?,
);
power_levels_event_content.events_default = new_level;
power_levels_event_content.invite = new_level;
@ -921,8 +945,9 @@ async fn room_alias_check(
if services
.rooms
.alias
.resolve_local_alias(&full_room_alias)?
.is_some()
.resolve_local_alias(&full_room_alias)
.await
.is_ok()
{
return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists."));
}

View File

@ -1,6 +1,12 @@
use std::collections::BTreeMap;
use axum::extract::State;
use conduit::{
debug,
utils::{IterStream, ReadyExt},
Err,
};
use futures::{FutureExt, StreamExt};
use ruma::{
api::client::{
error::ErrorKind,
@ -13,7 +19,6 @@ use ruma::{
serde::Raw,
uint, OwnedRoomId,
};
use tracing::debug;
use crate::{Error, Result, Ruma};
@ -32,14 +37,17 @@ pub(crate) async fn search_events_route(
let filter = &search_criteria.filter;
let include_state = &search_criteria.include_state;
let room_ids = filter.rooms.clone().unwrap_or_else(|| {
let room_ids = if let Some(room_ids) = &filter.rooms {
room_ids.clone()
} else {
services
.rooms
.state_cache
.rooms_joined(sender_user)
.filter_map(Result::ok)
.map(ToOwned::to_owned)
.collect()
});
.await
};
// Use limit or else 10, with maximum 100
let limit: usize = filter
@ -53,18 +61,21 @@ pub(crate) async fn search_events_route(
if include_state.is_some_and(|include_state| include_state) {
for room_id in &room_ids {
if !services.rooms.state_cache.is_joined(sender_user, room_id)? {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this room.",
));
if !services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
return Err!(Request(Forbidden("You don't have permission to view this room.")));
}
// check if sender_user can see state events
if services
.rooms
.state_accessor
.user_can_see_state_events(sender_user, room_id)?
.user_can_see_state_events(sender_user, room_id)
.await
{
let room_state = services
.rooms
@ -87,10 +98,15 @@ pub(crate) async fn search_events_route(
}
}
let mut searches = Vec::new();
let mut search_vecs = Vec::new();
for room_id in &room_ids {
if !services.rooms.state_cache.is_joined(sender_user, room_id)? {
if !services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this room.",
@ -100,12 +116,18 @@ pub(crate) async fn search_events_route(
if let Some(search) = services
.rooms
.search
.search_pdus(room_id, &search_criteria.search_term)?
.search_pdus(room_id, &search_criteria.search_term)
.await
{
searches.push(search.0.peekable());
search_vecs.push(search.0);
}
}
let mut searches: Vec<_> = search_vecs
.iter()
.map(|vec| vec.iter().peekable())
.collect();
let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) {
Some(Ok(s)) => s,
Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")),
@ -118,8 +140,8 @@ pub(crate) async fn search_events_route(
for _ in 0..next_batch {
if let Some(s) = searches
.iter_mut()
.map(|s| (s.peek().cloned(), s))
.max_by_key(|(peek, _)| peek.clone())
.map(|s| (s.peek().copied(), s))
.max_by_key(|(peek, _)| *peek)
.and_then(|(_, i)| i.next())
{
results.push(s);
@ -127,42 +149,38 @@ pub(crate) async fn search_events_route(
}
let results: Vec<_> = results
.iter()
.into_iter()
.skip(skip)
.filter_map(|result| {
.stream()
.filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok))
.ready_filter(|pdu| !pdu.is_redacted())
.filter_map(|pdu| async move {
services
.rooms
.timeline
.get_pdu_from_id(result)
.ok()?
.filter(|pdu| {
!pdu.is_redacted()
&& services
.rooms
.state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.unwrap_or(false)
})
.map(|pdu| pdu.to_room_event())
.state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.await
.then_some(pdu)
})
.map(|result| {
Ok::<_, Error>(SearchResult {
context: EventContextResult {
end: None,
events_after: Vec::new(),
events_before: Vec::new(),
profile_info: BTreeMap::new(),
start: None,
},
rank: None,
result: Some(result),
})
})
.filter_map(Result::ok)
.take(limit)
.collect();
.map(|pdu| pdu.to_room_event())
.map(|result| SearchResult {
context: EventContextResult {
end: None,
events_after: Vec::new(),
events_before: Vec::new(),
profile_info: BTreeMap::new(),
start: None,
},
rank: None,
result: Some(result),
})
.collect()
.boxed()
.await;
let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some());
let next_batch = more_unloaded_results.then(|| next_batch.to_string());
Ok(search_events::v3::Response::new(ResultCategories {

View File

@ -1,5 +1,7 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduit::{debug, err, info, utils::ReadyExt, warn, Err};
use futures::StreamExt;
use ruma::{
api::client::{
error::ErrorKind,
@ -19,7 +21,6 @@ use ruma::{
UserId,
};
use serde::Deserialize;
use tracing::{debug, info, warn};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{utils, utils::hash, Error, Result, Ruma};
@ -79,21 +80,22 @@ pub(crate) async fn login_route(
UserId::parse(user)
} else {
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(ErrorKind::forbidden(), "Bad login type."));
return Err!(Request(Forbidden("Bad login type.")));
}
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
let hash = services
.users
.password_hash(&user_id)?
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?;
.password_hash(&user_id)
.await
.map_err(|_| err!(Request(Forbidden("Wrong username or password."))))?;
if hash.is_empty() {
return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated"));
return Err!(Request(UserDeactivated("The user has been deactivated")));
}
if hash::verify_password(password, &hash).is_err() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."));
return Err!(Request(Forbidden("Wrong username or password.")));
}
user_id
@ -112,15 +114,12 @@ pub(crate) async fn login_route(
let username = token.claims.sub.to_lowercase();
UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| {
warn!("Failed to parse username from user logging in: {e}");
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?
UserId::parse_with_server_name(username, services.globals.server_name())
.map_err(|e| err!(Request(InvalidUsername(debug_error!(?e, "Failed to parse login username")))))?
} else {
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Token login is not supported (server has no jwt decoding key).",
));
return Err!(Request(Unknown(
"Token login is not supported (server has no jwt decoding key)."
)));
}
},
#[allow(deprecated)]
@ -169,23 +168,32 @@ pub(crate) async fn login_route(
let token = utils::random_string(TOKEN_LENGTH);
// Determine if device_id was provided and exists in the db for this user
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
let device_exists = if body.device_id.is_some() {
services
.users
.all_device_ids(&user_id)
.any(|x| x.as_ref().map_or(false, |v| v == device_id))
});
.ready_any(|v| v == device_id)
.await
} else {
false
};
if device_exists {
services.users.set_token(&user_id, &device_id, &token)?;
services
.users
.set_token(&user_id, &device_id, &token)
.await?;
} else {
services.users.create_device(
&user_id,
&device_id,
&token,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)?;
services
.users
.create_device(
&user_id,
&device_id,
&token,
body.initial_device_display_name.clone(),
Some(client.to_string()),
)
.await?;
}
// send client well-known if specified so the client knows to reconfigure itself
@ -228,10 +236,13 @@ pub(crate) async fn logout_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
services.users.remove_device(sender_user, sender_device)?;
services
.users
.remove_device(sender_user, sender_device)
.await;
// send device list update for user after logout
services.users.mark_device_key_update(sender_user)?;
services.users.mark_device_key_update(sender_user).await;
Ok(logout::v3::Response::new())
}
@ -256,12 +267,14 @@ pub(crate) async fn logout_all_route(
) -> Result<logout_all::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for device_id in services.users.all_device_ids(sender_user).flatten() {
services.users.remove_device(sender_user, &device_id)?;
}
services
.users
.all_device_ids(sender_user)
.for_each(|device_id| services.users.remove_device(sender_user, device_id))
.await;
// send device list update for user after logout
services.users.mark_device_key_update(sender_user)?;
services.users.mark_device_key_update(sender_user).await;
Ok(logout_all::v3::Response::new())
}

View File

@ -1,7 +1,7 @@
use std::sync::Arc;
use axum::extract::State;
use conduit::{debug_info, error, pdu::PduBuilder, Error, Result};
use conduit::{err, error, pdu::PduBuilder, Err, Error, Result};
use ruma::{
api::client::{
error::ErrorKind,
@ -84,12 +84,10 @@ pub(crate) async fn get_state_events_route(
if !services
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
.user_can_see_state_events(sender_user, &body.room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view the room state.",
));
return Err!(Request(Forbidden("You don't have permission to view the room state.")));
}
Ok(get_state_events::v3::Response {
@ -120,22 +118,25 @@ pub(crate) async fn get_state_events_for_key_route(
if !services
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
.user_can_see_state_events(sender_user, &body.room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view the room state.",
));
return Err!(Request(Forbidden("You don't have permission to view the room state.")));
}
let event = services
.rooms
.state_accessor
.room_state_get(&body.room_id, &body.event_type, &body.state_key)?
.ok_or_else(|| {
debug_info!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
.room_state_get(&body.room_id, &body.event_type, &body.state_key)
.await
.map_err(|_| {
err!(Request(NotFound(error!(
room_id = ?body.room_id,
event_type = ?body.event_type,
"State event not found in room.",
))))
})?;
if body
.format
.as_ref()
@ -204,7 +205,7 @@ async fn send_state_event_for_key_helper(
async fn allowed_to_send_state_event(
services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>,
) -> Result<()> {
) -> Result {
match event_type {
// Forbid m.room.encryption if encryption is disabled
StateEventType::RoomEncryption => {
@ -214,7 +215,7 @@ async fn allowed_to_send_state_event(
},
// admin room is a sensitive room, it should not ever be made public
StateEventType::RoomJoinRules => {
if let Some(admin_room_id) = services.admin.get_admin_room()? {
if let Ok(admin_room_id) = services.admin.get_admin_room().await {
if admin_room_id == room_id {
if let Ok(join_rule) = serde_json::from_str::<RoomJoinRulesEventContent>(json.json().get()) {
if join_rule.join_rule == JoinRule::Public {
@ -229,7 +230,7 @@ async fn allowed_to_send_state_event(
},
// admin room is a sensitive room, it should not ever be made world readable
StateEventType::RoomHistoryVisibility => {
if let Some(admin_room_id) = services.admin.get_admin_room()? {
if let Ok(admin_room_id) = services.admin.get_admin_room().await {
if admin_room_id == room_id {
if let Ok(visibility_content) =
serde_json::from_str::<RoomHistoryVisibilityEventContent>(json.json().get())
@ -254,23 +255,27 @@ async fn allowed_to_send_state_event(
}
for alias in aliases {
if !services.globals.server_is_ours(alias.server_name())
|| services
.rooms
.alias
.resolve_local_alias(&alias)?
.filter(|room| room == room_id) // Make sure it's the right room
.is_none()
if !services.globals.server_is_ours(alias.server_name()) {
return Err!(Request(Forbidden("canonical_alias must be for this server")));
}
if !services
.rooms
.alias
.resolve_local_alias(&alias)
.await
.is_ok_and(|room| room == room_id)
// Make sure it's the right room
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You are only allowed to send canonical_alias events when its aliases already exist",
));
return Err!(Request(Forbidden(
"You are only allowed to send canonical_alias events when its aliases already exist"
)));
}
}
}
},
_ => (),
}
Ok(())
}

File diff suppressed because it is too large Load Diff

View File

@ -23,10 +23,11 @@ pub(crate) async fn update_tag_route(
let event = services
.account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else(
|| {
|_| {
Ok(TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),
@ -41,12 +42,15 @@ pub(crate) async fn update_tag_route(
.tags
.insert(body.tag.clone().into(), body.tag_info.clone());
services.account_data.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)?;
services
.account_data
.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)
.await?;
Ok(create_tag::v3::Response {})
}
@ -63,10 +67,11 @@ pub(crate) async fn delete_tag_route(
let event = services
.account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else(
|| {
|_| {
Ok(TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),
@ -78,12 +83,15 @@ pub(crate) async fn delete_tag_route(
tags_event.content.tags.remove(&body.tag.clone().into());
services.account_data.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)?;
services
.account_data
.update(
Some(&body.room_id),
sender_user,
RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"),
)
.await?;
Ok(delete_tag::v3::Response {})
}
@ -100,10 +108,11 @@ pub(crate) async fn get_tags_route(
let event = services
.account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)
.await;
let tags_event = event.map_or_else(
|| {
|_| {
Ok(TagEvent {
content: TagEventContent {
tags: BTreeMap::new(),

View File

@ -1,4 +1,6 @@
use axum::extract::State;
use conduit::PduEvent;
use futures::StreamExt;
use ruma::{
api::client::{error::ErrorKind, threads::get_threads},
uint,
@ -27,20 +29,23 @@ pub(crate) async fn get_threads_route(
u64::MAX
};
let threads = services
let room_id = &body.room_id;
let threads: Vec<(u64, PduEvent)> = services
.rooms
.threads
.threads_until(sender_user, &body.room_id, from, &body.include)?
.threads_until(sender_user, &body.room_id, from, &body.include)
.await?
.take(limit)
.filter_map(Result::ok)
.filter(|(_, pdu)| {
.filter_map(|(count, pdu)| async move {
services
.rooms
.state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id)
.unwrap_or(false)
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.await
.then_some((count, pdu))
})
.collect::<Vec<_>>();
.collect()
.await;
let next_batch = threads.last().map(|(count, _)| count.to_string());

View File

@ -2,6 +2,7 @@ use std::collections::BTreeMap;
use axum::extract::State;
use conduit::{Error, Result};
use futures::StreamExt;
use ruma::{
api::{
client::{error::ErrorKind, to_device::send_event_to_device},
@ -24,8 +25,9 @@ pub(crate) async fn send_event_to_device_route(
// Check if this is a new transaction id
if services
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)?
.is_some()
.existing_txnid(sender_user, sender_device, &body.txn_id)
.await
.is_ok()
{
return Ok(send_event_to_device::v3::Response {});
}
@ -53,31 +55,35 @@ pub(crate) async fn send_event_to_device_route(
continue;
}
let event_type = &body.event_type.to_string();
let event = event
.deserialize_as()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?;
match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services.users.add_to_device_event(
sender_user,
target_user_id,
target_device_id,
&body.event_type.to_string(),
event
.deserialize_as()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
)?;
services
.users
.add_to_device_event(sender_user, target_user_id, target_device_id, event_type, event)
.await;
},
DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services.users.all_device_ids(target_user_id) {
services.users.add_to_device_event(
sender_user,
target_user_id,
&target_device_id?,
&body.event_type.to_string(),
event
.deserialize_as()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
)?;
}
let (event_type, event) = (&event_type, &event);
services
.users
.all_device_ids(target_user_id)
.for_each(|target_device_id| {
services.users.add_to_device_event(
sender_user,
target_user_id,
target_device_id,
event_type,
event.clone(),
)
})
.await;
},
}
}
@ -86,7 +92,7 @@ pub(crate) async fn send_event_to_device_route(
// Save transaction id with empty data
services
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
.add_txnid(sender_user, sender_device, &body.txn_id, &[]);
Ok(send_event_to_device::v3::Response {})
}

View File

@ -16,7 +16,8 @@ pub(crate) async fn create_typing_event_route(
if !services
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
.is_joined(sender_user, &body.room_id)
.await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "You are not in this room."));
}

View File

@ -2,7 +2,8 @@ use std::collections::BTreeMap;
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduit::{warn, Err};
use conduit::Err;
use futures::StreamExt;
use ruma::{
api::{
client::{
@ -45,7 +46,7 @@ pub(crate) async fn get_mutual_rooms_route(
));
}
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
return Ok(mutual_rooms::unstable::Response {
joined: vec![],
next_batch_token: None,
@ -55,9 +56,10 @@ pub(crate) async fn get_mutual_rooms_route(
let mutual_rooms: Vec<OwnedRoomId> = services
.rooms
.user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
.filter_map(Result::ok)
.collect();
.get_shared_rooms(sender_user, &body.user_id)
.map(ToOwned::to_owned)
.collect()
.await;
Ok(mutual_rooms::unstable::Response {
joined: mutual_rooms,
@ -99,7 +101,7 @@ pub(crate) async fn get_room_summary(
let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?;
if !services.rooms.metadata.exists(&room_id)? {
if !services.rooms.metadata.exists(&room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"));
}
@ -108,7 +110,7 @@ pub(crate) async fn get_room_summary(
.rooms
.state_accessor
.is_world_readable(&room_id)
.unwrap_or(false)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
@ -122,50 +124,58 @@ pub(crate) async fn get_room_summary(
.rooms
.state_accessor
.get_canonical_alias(&room_id)
.unwrap_or(None),
.await
.ok(),
avatar_url: services
.rooms
.state_accessor
.get_avatar(&room_id)?
.get_avatar(&room_id)
.await
.into_option()
.unwrap_or_default()
.url,
guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?,
name: services
.rooms
.state_accessor
.get_name(&room_id)
.unwrap_or(None),
guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await,
name: services.rooms.state_accessor.get_name(&room_id).await.ok(),
num_joined_members: services
.rooms
.state_cache
.room_joined_count(&room_id)
.unwrap_or_default()
.unwrap_or_else(|| {
warn!("Room {room_id} has no member count");
0
})
.try_into()
.expect("user count should not be that big"),
.await
.unwrap_or(0)
.try_into()?,
topic: services
.rooms
.state_accessor
.get_room_topic(&room_id)
.unwrap_or(None),
.await
.ok(),
world_readable: services
.rooms
.state_accessor
.is_world_readable(&room_id)
.unwrap_or(false),
join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0,
room_type: services.rooms.state_accessor.get_room_type(&room_id)?,
room_version: Some(services.rooms.state.get_room_version(&room_id)?),
.await,
join_rule: services
.rooms
.state_accessor
.get_join_rule(&room_id)
.await
.unwrap_or_default()
.0,
room_type: services
.rooms
.state_accessor
.get_room_type(&room_id)
.await
.ok(),
room_version: services.rooms.state.get_room_version(&room_id).await.ok(),
membership: if let Some(sender_user) = sender_user {
services
.rooms
.state_accessor
.get_member(&room_id, sender_user)?
.map_or_else(|| Some(MembershipState::Leave), |content| Some(content.membership))
.get_member(&room_id, sender_user)
.await
.map_or_else(|_| MembershipState::Leave, |content| content.membership)
.into()
} else {
None
},
@ -173,7 +183,8 @@ pub(crate) async fn get_room_summary(
.rooms
.state_accessor
.get_room_encryption(&room_id)
.unwrap_or_else(|_e| None),
.await
.ok(),
})
}
@ -191,13 +202,14 @@ pub(crate) async fn delete_timezone_key_route(
return Err!(Request(Forbidden("You cannot update the profile of another user")));
}
services.users.set_timezone(&body.user_id, None).await?;
services.users.set_timezone(&body.user_id, None);
if services.globals.allow_local_presence() {
// Presence update
services
.presence
.ping_presence(&body.user_id, &PresenceState::Online)?;
.ping_presence(&body.user_id, &PresenceState::Online)
.await?;
}
Ok(delete_timezone_key::unstable::Response {})
@ -217,16 +229,14 @@ pub(crate) async fn set_timezone_key_route(
return Err!(Request(Forbidden("You cannot update the profile of another user")));
}
services
.users
.set_timezone(&body.user_id, body.tz.clone())
.await?;
services.users.set_timezone(&body.user_id, body.tz.clone());
if services.globals.allow_local_presence() {
// Presence update
services
.presence
.ping_presence(&body.user_id, &PresenceState::Online)?;
.ping_presence(&body.user_id, &PresenceState::Online)
.await?;
}
Ok(set_timezone_key::unstable::Response {})
@ -280,10 +290,11 @@ pub(crate) async fn set_profile_key_route(
.rooms
.state_cache
.rooms_joined(&body.user_id)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), all_joined_rooms).await?;
update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), &all_joined_rooms).await?;
} else if body.key == "avatar_url" {
let mxc = ruma::OwnedMxcUri::from(profile_key_value.to_string());
@ -291,21 +302,23 @@ pub(crate) async fn set_profile_key_route(
.rooms
.state_cache
.rooms_joined(&body.user_id)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
update_avatar_url(&services, &body.user_id, Some(mxc), None, all_joined_rooms).await?;
update_avatar_url(&services, &body.user_id, Some(mxc), None, &all_joined_rooms).await?;
} else {
services
.users
.set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()))?;
.set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()));
}
if services.globals.allow_local_presence() {
// Presence update
services
.presence
.ping_presence(&body.user_id, &PresenceState::Online)?;
.ping_presence(&body.user_id, &PresenceState::Online)
.await?;
}
Ok(set_profile_key::unstable::Response {})
@ -335,30 +348,33 @@ pub(crate) async fn delete_profile_key_route(
.rooms
.state_cache
.rooms_joined(&body.user_id)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
update_displayname(&services, &body.user_id, None, all_joined_rooms).await?;
update_displayname(&services, &body.user_id, None, &all_joined_rooms).await?;
} else if body.key == "avatar_url" {
let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms
.state_cache
.rooms_joined(&body.user_id)
.filter_map(Result::ok)
.collect();
.map(Into::into)
.collect()
.await;
update_avatar_url(&services, &body.user_id, None, None, all_joined_rooms).await?;
update_avatar_url(&services, &body.user_id, None, None, &all_joined_rooms).await?;
} else {
services
.users
.set_profile_key(&body.user_id, &body.key, None)?;
.set_profile_key(&body.user_id, &body.key, None);
}
if services.globals.allow_local_presence() {
// Presence update
services
.presence
.ping_presence(&body.user_id, &PresenceState::Online)?;
.ping_presence(&body.user_id, &PresenceState::Online)
.await?;
}
Ok(delete_profile_key::unstable::Response {})
@ -386,26 +402,25 @@ pub(crate) async fn get_timezone_key_route(
)
.await
{
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?;
}
services
.users
.set_displayname(&body.user_id, response.displayname.clone())
.await?;
.set_displayname(&body.user_id, response.displayname.clone());
services
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
.set_avatar_url(&body.user_id, response.avatar_url.clone());
services
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
.set_blurhash(&body.user_id, response.blurhash.clone());
services
.users
.set_timezone(&body.user_id, response.tz.clone())
.await?;
.set_timezone(&body.user_id, response.tz.clone());
return Ok(get_timezone_key::unstable::Response {
tz: response.tz,
@ -413,14 +428,14 @@ pub(crate) async fn get_timezone_key_route(
}
}
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over
// federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
}
Ok(get_timezone_key::unstable::Response {
tz: services.users.timezone(&body.user_id)?,
tz: services.users.timezone(&body.user_id).await.ok(),
})
}
@ -448,32 +463,31 @@ pub(crate) async fn get_profile_key_route(
)
.await
{
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?;
}
services
.users
.set_displayname(&body.user_id, response.displayname.clone())
.await?;
.set_displayname(&body.user_id, response.displayname.clone());
services
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
.set_avatar_url(&body.user_id, response.avatar_url.clone());
services
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
.set_blurhash(&body.user_id, response.blurhash.clone());
services
.users
.set_timezone(&body.user_id, response.tz.clone())
.await?;
.set_timezone(&body.user_id, response.tz.clone());
if let Some(value) = response.custom_profile_fields.get(&body.key) {
profile_key_value.insert(body.key.clone(), value.clone());
services
.users
.set_profile_key(&body.user_id, &body.key, Some(value.clone()))?;
.set_profile_key(&body.user_id, &body.key, Some(value.clone()));
} else {
return Err!(Request(NotFound("The requested profile key does not exist.")));
}
@ -484,13 +498,13 @@ pub(crate) async fn get_profile_key_route(
}
}
if !services.users.exists(&body.user_id)? {
if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over
// federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
return Err!(Request(NotFound("Profile was not found.")));
}
if let Some(value) = services.users.profile_key(&body.user_id, &body.key)? {
if let Ok(value) = services.users.profile_key(&body.user_id, &body.key).await {
profile_key_value.insert(body.key.clone(), value);
} else {
return Err!(Request(NotFound("The requested profile key does not exist.")));

View File

@ -1,6 +1,7 @@
use std::collections::BTreeMap;
use axum::{extract::State, response::IntoResponse, Json};
use futures::StreamExt;
use ruma::api::client::{
discovery::{
discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo},
@ -173,7 +174,7 @@ pub(crate) async fn conduwuit_server_version() -> Result<impl IntoResponse> {
/// homeserver. Endpoint is disabled if federation is disabled for privacy. This
/// only includes active users (not deactivated, no guests, etc)
pub(crate) async fn conduwuit_local_user_count(State(services): State<crate::State>) -> Result<impl IntoResponse> {
let user_count = services.users.list_local_users()?.len();
let user_count = services.users.list_local_users().count().await;
Ok(Json(serde_json::json!({
"count": user_count

View File

@ -1,4 +1,5 @@
use axum::extract::State;
use futures::{pin_mut, StreamExt};
use ruma::{
api::client::user_directory::search_users,
events::{
@ -21,14 +22,12 @@ pub(crate) async fn search_users_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10
let mut users = services.users.iter().filter_map(|user_id| {
let users = services.users.stream().filter_map(|user_id| async {
// Filter out buggy users (they should not exist, but you never know...)
let user_id = user_id.ok()?;
let user = search_users::v3::User {
user_id: user_id.clone(),
display_name: services.users.displayname(&user_id).ok()?,
avatar_url: services.users.avatar_url(&user_id).ok()?,
user_id: user_id.to_owned(),
display_name: services.users.displayname(user_id).await.ok(),
avatar_url: services.users.avatar_url(user_id).await.ok(),
};
let user_id_matches = user
@ -56,20 +55,19 @@ pub(crate) async fn search_users_route(
let user_is_in_public_rooms = services
.rooms
.state_cache
.rooms_joined(&user_id)
.filter_map(Result::ok)
.any(|room| {
.rooms_joined(&user.user_id)
.any(|room| async move {
services
.rooms
.state_accessor
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
.room_state_get(room, &StateEventType::RoomJoinRules, "")
.await
.map_or(false, |event| {
event.map_or(false, |event| {
serde_json::from_str(event.content.get())
.map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
})
serde_json::from_str(event.content.get())
.map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
})
});
})
.await;
if user_is_in_public_rooms {
user_visible = true;
@ -77,25 +75,22 @@ pub(crate) async fn search_users_route(
let user_is_in_shared_rooms = services
.rooms
.user
.get_shared_rooms(vec![sender_user.clone(), user_id])
.ok()?
.next()
.is_some();
.has_shared_rooms(sender_user, &user.user_id)
.await;
if user_is_in_shared_rooms {
user_visible = true;
}
}
if !user_visible {
return None;
}
Some(user)
user_visible.then_some(user)
});
let results = users.by_ref().take(limit).collect();
let limited = users.next().is_some();
pin_mut!(users);
let limited = users.by_ref().next().await.is_some();
let results = users.take(limit).collect().await;
Ok(search_users::v3::Response {
results,

View File

@ -22,101 +22,101 @@ use crate::{client, server};
pub fn build(router: Router<State>, server: &Server) -> Router<State> {
let config = &server.config;
let mut router = router
.ruma_route(client::get_timezone_key_route)
.ruma_route(client::get_profile_key_route)
.ruma_route(client::set_profile_key_route)
.ruma_route(client::delete_profile_key_route)
.ruma_route(client::set_timezone_key_route)
.ruma_route(client::delete_timezone_key_route)
.ruma_route(client::appservice_ping)
.ruma_route(client::get_supported_versions_route)
.ruma_route(client::get_register_available_route)
.ruma_route(client::register_route)
.ruma_route(client::get_login_types_route)
.ruma_route(client::login_route)
.ruma_route(client::whoami_route)
.ruma_route(client::logout_route)
.ruma_route(client::logout_all_route)
.ruma_route(client::change_password_route)
.ruma_route(client::deactivate_route)
.ruma_route(client::third_party_route)
.ruma_route(client::request_3pid_management_token_via_email_route)
.ruma_route(client::request_3pid_management_token_via_msisdn_route)
.ruma_route(client::check_registration_token_validity)
.ruma_route(client::get_capabilities_route)
.ruma_route(client::get_pushrules_all_route)
.ruma_route(client::set_pushrule_route)
.ruma_route(client::get_pushrule_route)
.ruma_route(client::set_pushrule_enabled_route)
.ruma_route(client::get_pushrule_enabled_route)
.ruma_route(client::get_pushrule_actions_route)
.ruma_route(client::set_pushrule_actions_route)
.ruma_route(client::delete_pushrule_route)
.ruma_route(client::get_room_event_route)
.ruma_route(client::get_room_aliases_route)
.ruma_route(client::get_filter_route)
.ruma_route(client::create_filter_route)
.ruma_route(client::create_openid_token_route)
.ruma_route(client::set_global_account_data_route)
.ruma_route(client::set_room_account_data_route)
.ruma_route(client::get_global_account_data_route)
.ruma_route(client::get_room_account_data_route)
.ruma_route(client::set_displayname_route)
.ruma_route(client::get_displayname_route)
.ruma_route(client::set_avatar_url_route)
.ruma_route(client::get_avatar_url_route)
.ruma_route(client::get_profile_route)
.ruma_route(client::set_presence_route)
.ruma_route(client::get_presence_route)
.ruma_route(client::upload_keys_route)
.ruma_route(client::get_keys_route)
.ruma_route(client::claim_keys_route)
.ruma_route(client::create_backup_version_route)
.ruma_route(client::update_backup_version_route)
.ruma_route(client::delete_backup_version_route)
.ruma_route(client::get_latest_backup_info_route)
.ruma_route(client::get_backup_info_route)
.ruma_route(client::add_backup_keys_route)
.ruma_route(client::add_backup_keys_for_room_route)
.ruma_route(client::add_backup_keys_for_session_route)
.ruma_route(client::delete_backup_keys_for_room_route)
.ruma_route(client::delete_backup_keys_for_session_route)
.ruma_route(client::delete_backup_keys_route)
.ruma_route(client::get_backup_keys_for_room_route)
.ruma_route(client::get_backup_keys_for_session_route)
.ruma_route(client::get_backup_keys_route)
.ruma_route(client::set_read_marker_route)
.ruma_route(client::create_receipt_route)
.ruma_route(client::create_typing_event_route)
.ruma_route(client::create_room_route)
.ruma_route(client::redact_event_route)
.ruma_route(client::report_event_route)
.ruma_route(client::create_alias_route)
.ruma_route(client::delete_alias_route)
.ruma_route(client::get_alias_route)
.ruma_route(client::join_room_by_id_route)
.ruma_route(client::join_room_by_id_or_alias_route)
.ruma_route(client::joined_members_route)
.ruma_route(client::leave_room_route)
.ruma_route(client::forget_room_route)
.ruma_route(client::joined_rooms_route)
.ruma_route(client::kick_user_route)
.ruma_route(client::ban_user_route)
.ruma_route(client::unban_user_route)
.ruma_route(client::invite_user_route)
.ruma_route(client::set_room_visibility_route)
.ruma_route(client::get_room_visibility_route)
.ruma_route(client::get_public_rooms_route)
.ruma_route(client::get_public_rooms_filtered_route)
.ruma_route(client::search_users_route)
.ruma_route(client::get_member_events_route)
.ruma_route(client::get_protocols_route)
.ruma_route(&client::get_timezone_key_route)
.ruma_route(&client::get_profile_key_route)
.ruma_route(&client::set_profile_key_route)
.ruma_route(&client::delete_profile_key_route)
.ruma_route(&client::set_timezone_key_route)
.ruma_route(&client::delete_timezone_key_route)
.ruma_route(&client::appservice_ping)
.ruma_route(&client::get_supported_versions_route)
.ruma_route(&client::get_register_available_route)
.ruma_route(&client::register_route)
.ruma_route(&client::get_login_types_route)
.ruma_route(&client::login_route)
.ruma_route(&client::whoami_route)
.ruma_route(&client::logout_route)
.ruma_route(&client::logout_all_route)
.ruma_route(&client::change_password_route)
.ruma_route(&client::deactivate_route)
.ruma_route(&client::third_party_route)
.ruma_route(&client::request_3pid_management_token_via_email_route)
.ruma_route(&client::request_3pid_management_token_via_msisdn_route)
.ruma_route(&client::check_registration_token_validity)
.ruma_route(&client::get_capabilities_route)
.ruma_route(&client::get_pushrules_all_route)
.ruma_route(&client::set_pushrule_route)
.ruma_route(&client::get_pushrule_route)
.ruma_route(&client::set_pushrule_enabled_route)
.ruma_route(&client::get_pushrule_enabled_route)
.ruma_route(&client::get_pushrule_actions_route)
.ruma_route(&client::set_pushrule_actions_route)
.ruma_route(&client::delete_pushrule_route)
.ruma_route(&client::get_room_event_route)
.ruma_route(&client::get_room_aliases_route)
.ruma_route(&client::get_filter_route)
.ruma_route(&client::create_filter_route)
.ruma_route(&client::create_openid_token_route)
.ruma_route(&client::set_global_account_data_route)
.ruma_route(&client::set_room_account_data_route)
.ruma_route(&client::get_global_account_data_route)
.ruma_route(&client::get_room_account_data_route)
.ruma_route(&client::set_displayname_route)
.ruma_route(&client::get_displayname_route)
.ruma_route(&client::set_avatar_url_route)
.ruma_route(&client::get_avatar_url_route)
.ruma_route(&client::get_profile_route)
.ruma_route(&client::set_presence_route)
.ruma_route(&client::get_presence_route)
.ruma_route(&client::upload_keys_route)
.ruma_route(&client::get_keys_route)
.ruma_route(&client::claim_keys_route)
.ruma_route(&client::create_backup_version_route)
.ruma_route(&client::update_backup_version_route)
.ruma_route(&client::delete_backup_version_route)
.ruma_route(&client::get_latest_backup_info_route)
.ruma_route(&client::get_backup_info_route)
.ruma_route(&client::add_backup_keys_route)
.ruma_route(&client::add_backup_keys_for_room_route)
.ruma_route(&client::add_backup_keys_for_session_route)
.ruma_route(&client::delete_backup_keys_for_room_route)
.ruma_route(&client::delete_backup_keys_for_session_route)
.ruma_route(&client::delete_backup_keys_route)
.ruma_route(&client::get_backup_keys_for_room_route)
.ruma_route(&client::get_backup_keys_for_session_route)
.ruma_route(&client::get_backup_keys_route)
.ruma_route(&client::set_read_marker_route)
.ruma_route(&client::create_receipt_route)
.ruma_route(&client::create_typing_event_route)
.ruma_route(&client::create_room_route)
.ruma_route(&client::redact_event_route)
.ruma_route(&client::report_event_route)
.ruma_route(&client::create_alias_route)
.ruma_route(&client::delete_alias_route)
.ruma_route(&client::get_alias_route)
.ruma_route(&client::join_room_by_id_route)
.ruma_route(&client::join_room_by_id_or_alias_route)
.ruma_route(&client::joined_members_route)
.ruma_route(&client::leave_room_route)
.ruma_route(&client::forget_room_route)
.ruma_route(&client::joined_rooms_route)
.ruma_route(&client::kick_user_route)
.ruma_route(&client::ban_user_route)
.ruma_route(&client::unban_user_route)
.ruma_route(&client::invite_user_route)
.ruma_route(&client::set_room_visibility_route)
.ruma_route(&client::get_room_visibility_route)
.ruma_route(&client::get_public_rooms_route)
.ruma_route(&client::get_public_rooms_filtered_route)
.ruma_route(&client::search_users_route)
.ruma_route(&client::get_member_events_route)
.ruma_route(&client::get_protocols_route)
.route("/_matrix/client/unstable/thirdparty/protocols",
get(client::get_protocols_route_unstable))
.ruma_route(client::send_message_event_route)
.ruma_route(client::send_state_event_for_key_route)
.ruma_route(client::get_state_events_route)
.ruma_route(client::get_state_events_for_key_route)
.ruma_route(&client::send_message_event_route)
.ruma_route(&client::send_state_event_for_key_route)
.ruma_route(&client::get_state_events_route)
.ruma_route(&client::get_state_events_for_key_route)
// Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes
// share one Ruma request / response type pair with {get,send}_state_event_for_key_route
.route(
@ -140,46 +140,46 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
get(client::get_state_events_for_empty_key_route)
.put(client::send_state_event_for_empty_key_route),
)
.ruma_route(client::sync_events_route)
.ruma_route(client::sync_events_v4_route)
.ruma_route(client::get_context_route)
.ruma_route(client::get_message_events_route)
.ruma_route(client::search_events_route)
.ruma_route(client::turn_server_route)
.ruma_route(client::send_event_to_device_route)
.ruma_route(client::create_content_route)
.ruma_route(client::get_content_thumbnail_route)
.ruma_route(client::get_content_route)
.ruma_route(client::get_content_as_filename_route)
.ruma_route(client::get_media_preview_route)
.ruma_route(client::get_media_config_route)
.ruma_route(client::get_devices_route)
.ruma_route(client::get_device_route)
.ruma_route(client::update_device_route)
.ruma_route(client::delete_device_route)
.ruma_route(client::delete_devices_route)
.ruma_route(client::get_tags_route)
.ruma_route(client::update_tag_route)
.ruma_route(client::delete_tag_route)
.ruma_route(client::upload_signing_keys_route)
.ruma_route(client::upload_signatures_route)
.ruma_route(client::get_key_changes_route)
.ruma_route(client::get_pushers_route)
.ruma_route(client::set_pushers_route)
.ruma_route(client::upgrade_room_route)
.ruma_route(client::get_threads_route)
.ruma_route(client::get_relating_events_with_rel_type_and_event_type_route)
.ruma_route(client::get_relating_events_with_rel_type_route)
.ruma_route(client::get_relating_events_route)
.ruma_route(client::get_hierarchy_route)
.ruma_route(client::get_mutual_rooms_route)
.ruma_route(client::get_room_summary)
.ruma_route(&client::sync_events_route)
.ruma_route(&client::sync_events_v4_route)
.ruma_route(&client::get_context_route)
.ruma_route(&client::get_message_events_route)
.ruma_route(&client::search_events_route)
.ruma_route(&client::turn_server_route)
.ruma_route(&client::send_event_to_device_route)
.ruma_route(&client::create_content_route)
.ruma_route(&client::get_content_thumbnail_route)
.ruma_route(&client::get_content_route)
.ruma_route(&client::get_content_as_filename_route)
.ruma_route(&client::get_media_preview_route)
.ruma_route(&client::get_media_config_route)
.ruma_route(&client::get_devices_route)
.ruma_route(&client::get_device_route)
.ruma_route(&client::update_device_route)
.ruma_route(&client::delete_device_route)
.ruma_route(&client::delete_devices_route)
.ruma_route(&client::get_tags_route)
.ruma_route(&client::update_tag_route)
.ruma_route(&client::delete_tag_route)
.ruma_route(&client::upload_signing_keys_route)
.ruma_route(&client::upload_signatures_route)
.ruma_route(&client::get_key_changes_route)
.ruma_route(&client::get_pushers_route)
.ruma_route(&client::set_pushers_route)
.ruma_route(&client::upgrade_room_route)
.ruma_route(&client::get_threads_route)
.ruma_route(&client::get_relating_events_with_rel_type_and_event_type_route)
.ruma_route(&client::get_relating_events_with_rel_type_route)
.ruma_route(&client::get_relating_events_route)
.ruma_route(&client::get_hierarchy_route)
.ruma_route(&client::get_mutual_rooms_route)
.ruma_route(&client::get_room_summary)
.route(
"/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary",
get(client::get_room_summary_legacy)
)
.ruma_route(client::well_known_support)
.ruma_route(client::well_known_client)
.ruma_route(&client::well_known_support)
.ruma_route(&client::well_known_client)
.route("/_conduwuit/server_version", get(client::conduwuit_server_version))
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync))
.route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync))
@ -187,35 +187,35 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
if config.allow_federation {
router = router
.ruma_route(server::get_server_version_route)
.ruma_route(&server::get_server_version_route)
.route("/_matrix/key/v2/server", get(server::get_server_keys_route))
.route("/_matrix/key/v2/server/:key_id", get(server::get_server_keys_deprecated_route))
.ruma_route(server::get_public_rooms_route)
.ruma_route(server::get_public_rooms_filtered_route)
.ruma_route(server::send_transaction_message_route)
.ruma_route(server::get_event_route)
.ruma_route(server::get_backfill_route)
.ruma_route(server::get_missing_events_route)
.ruma_route(server::get_event_authorization_route)
.ruma_route(server::get_room_state_route)
.ruma_route(server::get_room_state_ids_route)
.ruma_route(server::create_leave_event_template_route)
.ruma_route(server::create_leave_event_v1_route)
.ruma_route(server::create_leave_event_v2_route)
.ruma_route(server::create_join_event_template_route)
.ruma_route(server::create_join_event_v1_route)
.ruma_route(server::create_join_event_v2_route)
.ruma_route(server::create_invite_route)
.ruma_route(server::get_devices_route)
.ruma_route(server::get_room_information_route)
.ruma_route(server::get_profile_information_route)
.ruma_route(server::get_keys_route)
.ruma_route(server::claim_keys_route)
.ruma_route(server::get_openid_userinfo_route)
.ruma_route(server::get_hierarchy_route)
.ruma_route(server::well_known_server)
.ruma_route(server::get_content_route)
.ruma_route(server::get_content_thumbnail_route)
.ruma_route(&server::get_public_rooms_route)
.ruma_route(&server::get_public_rooms_filtered_route)
.ruma_route(&server::send_transaction_message_route)
.ruma_route(&server::get_event_route)
.ruma_route(&server::get_backfill_route)
.ruma_route(&server::get_missing_events_route)
.ruma_route(&server::get_event_authorization_route)
.ruma_route(&server::get_room_state_route)
.ruma_route(&server::get_room_state_ids_route)
.ruma_route(&server::create_leave_event_template_route)
.ruma_route(&server::create_leave_event_v1_route)
.ruma_route(&server::create_leave_event_v2_route)
.ruma_route(&server::create_join_event_template_route)
.ruma_route(&server::create_join_event_v1_route)
.ruma_route(&server::create_join_event_v2_route)
.ruma_route(&server::create_invite_route)
.ruma_route(&server::get_devices_route)
.ruma_route(&server::get_room_information_route)
.ruma_route(&server::get_profile_information_route)
.ruma_route(&server::get_keys_route)
.ruma_route(&server::claim_keys_route)
.ruma_route(&server::get_openid_userinfo_route)
.ruma_route(&server::get_hierarchy_route)
.ruma_route(&server::well_known_server)
.ruma_route(&server::get_content_route)
.ruma_route(&server::get_content_thumbnail_route)
.route("/_conduwuit/local_user_count", get(client::conduwuit_local_user_count));
} else {
router = router
@ -227,11 +227,11 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
if config.allow_legacy_media {
router = router
.ruma_route(client::get_media_config_legacy_route)
.ruma_route(client::get_media_preview_legacy_route)
.ruma_route(client::get_content_legacy_route)
.ruma_route(client::get_content_as_filename_legacy_route)
.ruma_route(client::get_content_thumbnail_legacy_route)
.ruma_route(&client::get_media_config_legacy_route)
.ruma_route(&client::get_media_preview_legacy_route)
.ruma_route(&client::get_content_legacy_route)
.ruma_route(&client::get_content_as_filename_legacy_route)
.ruma_route(&client::get_content_thumbnail_legacy_route)
.route("/_matrix/media/v1/config", get(client::get_media_config_legacy_legacy_route))
.route("/_matrix/media/v1/upload", post(client::create_content_legacy_route))
.route(

View File

@ -10,7 +10,10 @@ use super::{auth, auth::Auth, request, request::Request};
use crate::{service::appservice::RegistrationInfo, State};
/// Extractor for Ruma request structs
pub(crate) struct Args<T> {
pub(crate) struct Args<T>
where
T: IncomingRequest + Send + Sync + 'static,
{
/// Request struct body
pub(crate) body: T,
@ -38,7 +41,7 @@ pub(crate) struct Args<T> {
#[async_trait]
impl<T> FromRequest<State, Body> for Args<T>
where
T: IncomingRequest,
T: IncomingRequest + Send + Sync + 'static,
{
type Rejection = Error;
@ -57,7 +60,10 @@ where
}
}
impl<T> Deref for Args<T> {
impl<T> Deref for Args<T>
where
T: IncomingRequest + Send + Sync + 'static,
{
type Target = T;
fn deref(&self) -> &Self::Target { &self.body }
@ -67,7 +73,7 @@ fn make_body<T>(
services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
) -> Result<T>
where
T: IncomingRequest,
T: IncomingRequest + Send + Sync + 'static,
{
let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body {
let user_id = auth.sender_user.clone().unwrap_or_else(|| {
@ -77,15 +83,13 @@ where
let uiaa_request = json_body
.get("auth")
.and_then(|auth| auth.as_object())
.and_then(CanonicalJsonValue::as_object)
.and_then(|auth| auth.get("session"))
.and_then(|session| session.as_str())
.and_then(CanonicalJsonValue::as_str)
.and_then(|session| {
services.uiaa.get_uiaa_request(
&user_id,
&auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()),
session,
)
services
.uiaa
.get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
});
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {

View File

@ -44,8 +44,8 @@ pub(super) async fn auth(
let token = if let Some(token) = token {
if let Some(reg_info) = services.appservice.find_from_token(token).await {
Token::Appservice(Box::new(reg_info))
} else if let Some((user_id, device_id)) = services.users.find_from_token(token)? {
Token::User((user_id, OwnedDeviceId::from(device_id)))
} else if let Ok((user_id, device_id)) = services.users.find_from_token(token).await {
Token::User((user_id, device_id))
} else {
Token::Invalid
}
@ -98,7 +98,7 @@ pub(super) async fn auth(
))
}
},
(AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?),
(AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info).await?),
(AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => {
Ok(Auth {
origin: None,
@ -150,7 +150,7 @@ pub(super) async fn auth(
}
}
fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> {
async fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> {
let user_id = request
.query
.user_id
@ -170,7 +170,7 @@ fn auth_appservice(services: &Services, request: &Request, info: Box<Registratio
return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace."));
}
if !services.users.exists(&user_id)? {
if !services.users.exists(&user_id).await {
return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist."));
}

View File

@ -1,5 +1,3 @@
use std::future::Future;
use axum::{
extract::FromRequestParts,
response::IntoResponse,
@ -7,19 +5,25 @@ use axum::{
Router,
};
use conduit::Result;
use futures::{Future, TryFutureExt};
use http::Method;
use ruma::api::IncomingRequest;
use super::{Ruma, RumaResponse, State};
pub(in super::super) trait RumaHandler<T> {
fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State>;
fn add_routes(&'static self, router: Router<State>) -> Router<State>;
}
pub(in super::super) trait RouterExt {
fn ruma_route<H, T>(self, handler: H) -> Self
fn ruma_route<H, T>(self, handler: &'static H) -> Self
where
H: RumaHandler<T>;
}
impl RouterExt for Router<State> {
fn ruma_route<H, T>(self, handler: H) -> Self
fn ruma_route<H, T>(self, handler: &'static H) -> Self
where
H: RumaHandler<T>,
{
@ -27,34 +31,28 @@ impl RouterExt for Router<State> {
}
}
pub(in super::super) trait RumaHandler<T> {
fn add_routes(&self, router: Router<State>) -> Router<State>;
fn add_route(&self, router: Router<State>, path: &str) -> Router<State>;
}
macro_rules! ruma_handler {
( $($tx:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<Req, Ret, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun
impl<Err, Req, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun
where
Req: IncomingRequest + Send + 'static,
Ret: IntoResponse,
Fut: Future<Output = Result<Req::OutgoingResponse, Ret>> + Send,
Fun: FnOnce($($tx,)* Ruma<Req>,) -> Fut + Clone + Send + Sync + 'static,
$( $tx: FromRequestParts<State> + Send + 'static, )*
Fun: Fn($($tx,)* Ruma<Req>,) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Req::OutgoingResponse, Err>> + Send,
Req: IncomingRequest + Send + Sync,
Err: IntoResponse + Send,
<Req as IncomingRequest>::OutgoingResponse: Send,
$( $tx: FromRequestParts<State> + Send + Sync + 'static, )*
{
fn add_routes(&self, router: Router<State>) -> Router<State> {
fn add_routes(&'static self, router: Router<State>) -> Router<State> {
Req::METADATA
.history
.all_paths()
.fold(router, |router, path| self.add_route(router, path))
}
fn add_route(&self, router: Router<State>, path: &str) -> Router<State> {
let handle = self.clone();
fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State> {
let action = |$($tx,)* req| self($($tx,)* req).map_ok(RumaResponse);
let method = method_to_filter(&Req::METADATA.method);
let action = |$($tx,)* req| async { handle($($tx,)* req).await.map(RumaResponse) };
router.route(path, on(method, action))
}
}

View File

@ -5,13 +5,18 @@ use http::StatusCode;
use http_body_util::Full;
use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse};
pub(crate) struct RumaResponse<T>(pub(crate) T);
pub(crate) struct RumaResponse<T>(pub(crate) T)
where
T: OutgoingResponse;
impl From<Error> for RumaResponse<UiaaResponse> {
fn from(t: Error) -> Self { Self(t.into()) }
}
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
impl<T> IntoResponse for RumaResponse<T>
where
T: OutgoingResponse,
{
fn into_response(self) -> Response {
self.0
.try_into_http_response::<BytesMut>()

View File

@ -1,9 +1,13 @@
use std::cmp;
use axum::extract::State;
use conduit::{Error, Result};
use ruma::{
api::{client::error::ErrorKind, federation::backfill::get_backfill},
uint, user_id, MilliSecondsSinceUnixEpoch,
use conduit::{
is_equal_to,
utils::{IterStream, ReadyExt},
Err, PduCount, Result,
};
use futures::{FutureExt, StreamExt};
use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch};
use crate::Ruma;
@ -19,27 +23,35 @@ pub(crate) async fn get_backfill_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
if !services
.rooms
.state_accessor
.is_world_readable(&body.room_id)?
&& !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)?
.is_world_readable(&body.room_id)
.await && !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)
.await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
return Err!(Request(Forbidden("Server is not in room.")));
}
let until = body
.v
.iter()
.map(|event_id| services.rooms.timeline.get_pdu_count(event_id))
.filter_map(|r| r.ok().flatten())
.max()
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?;
.stream()
.filter_map(|event_id| {
services
.rooms
.timeline
.get_pdu_count(event_id)
.map(Result::ok)
})
.ready_fold(PduCount::Backfilled(0), cmp::max)
.await;
let limit = body
.limit
@ -47,31 +59,37 @@ pub(crate) async fn get_backfill_route(
.try_into()
.expect("UInt could not be converted to usize");
let all_events = services
let pdus = services
.rooms
.timeline
.pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)?
.take(limit);
.pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)
.await?
.take(limit)
.filter_map(|(_, pdu)| async move {
if !services
.rooms
.state_accessor
.server_can_see_event(origin, &pdu.room_id, &pdu.event_id)
.await
.is_ok_and(is_equal_to!(true))
{
return None;
}
let events = all_events
.filter_map(Result::ok)
.filter(|(_, e)| {
matches!(
services
.rooms
.state_accessor
.server_can_see_event(origin, &e.room_id, &e.event_id,),
Ok(true),
)
services
.rooms
.timeline
.get_pdu_json(&pdu.event_id)
.await
.ok()
})
.map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id))
.filter_map(|r| r.ok().flatten())
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect();
.then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect()
.await;
Ok(get_backfill::v1::Response {
origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdus: events,
pdus,
})
}

View File

@ -1,9 +1,6 @@
use axum::extract::State;
use conduit::{Error, Result};
use ruma::{
api::{client::error::ErrorKind, federation::event::get_event},
MilliSecondsSinceUnixEpoch, RoomId,
};
use conduit::{err, Err, Result};
use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId};
use crate::Ruma;
@ -21,34 +18,46 @@ pub(crate) async fn get_event_route(
let event = services
.rooms
.timeline
.get_pdu_json(&body.event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?;
.get_pdu_json(&body.event_id)
.await
.map_err(|_| err!(Request(NotFound("Event not found."))))?;
let room_id_str = event
.get("room_id")
.and_then(|val| val.as_str())
.ok_or_else(|| Error::bad_database("Invalid event in database."))?;
.ok_or_else(|| err!(Database("Invalid event in database.")))?;
let room_id =
<&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
<&RoomId>::try_from(room_id_str).map_err(|_| err!(Database("Invalid room_id in event in database.")))?;
if !services.rooms.state_accessor.is_world_readable(room_id)?
&& !services.rooms.state_cache.server_in_room(origin, room_id)?
if !services
.rooms
.state_accessor
.is_world_readable(room_id)
.await && !services
.rooms
.state_cache
.server_in_room(origin, room_id)
.await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
return Err!(Request(Forbidden("Server is not in room.")));
}
if !services
.rooms
.state_accessor
.server_can_see_event(origin, room_id, &body.event_id)?
.server_can_see_event(origin, room_id, &body.event_id)
.await?
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not allowed to see event."));
return Err!(Request(Forbidden("Server is not allowed to see event.")));
}
Ok(get_event::v1::Response {
origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdu: services.sending.convert_to_outgoing_federation_event(event),
pdu: services
.sending
.convert_to_outgoing_federation_event(event)
.await,
})
}

View File

@ -2,6 +2,7 @@ use std::sync::Arc;
use axum::extract::State;
use conduit::{Error, Result};
use futures::StreamExt;
use ruma::{
api::{client::error::ErrorKind, federation::authorization::get_event_authorization},
RoomId,
@ -22,16 +23,18 @@ pub(crate) async fn get_event_authorization_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
if !services
.rooms
.state_accessor
.is_world_readable(&body.room_id)?
&& !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)?
.is_world_readable(&body.room_id)
.await && !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)
.await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
}
@ -39,8 +42,9 @@ pub(crate) async fn get_event_authorization_route(
let event = services
.rooms
.timeline
.get_pdu_json(&body.event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?;
.get_pdu_json(&body.event_id)
.await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?;
let room_id_str = event
.get("room_id")
@ -50,16 +54,17 @@ pub(crate) async fn get_event_authorization_route(
let room_id =
<&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
let auth_chain_ids = services
let auth_chain = services
.rooms
.auth_chain
.event_ids_iter(room_id, vec![Arc::from(&*body.event_id)])
.await?;
.await?
.filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() })
.then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect()
.await;
Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?)
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
auth_chain,
})
}

View File

@ -18,16 +18,18 @@ pub(crate) async fn get_missing_events_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
if !services
.rooms
.state_accessor
.is_world_readable(&body.room_id)?
&& !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)?
.is_world_readable(&body.room_id)
.await && !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)
.await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room"));
}
@ -43,7 +45,12 @@ pub(crate) async fn get_missing_events_route(
let mut i: usize = 0;
while i < queued_events.len() && events.len() < limit {
if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? {
if let Ok(pdu) = services
.rooms
.timeline
.get_pdu_json(&queued_events[i])
.await
{
let room_id_str = pdu
.get("room_id")
.and_then(|val| val.as_str())
@ -64,7 +71,8 @@ pub(crate) async fn get_missing_events_route(
if !services
.rooms
.state_accessor
.server_can_see_event(origin, &body.room_id, &queued_events[i])?
.server_can_see_event(origin, &body.room_id, &queued_events[i])
.await?
{
i = i.saturating_add(1);
continue;
@ -81,7 +89,12 @@ pub(crate) async fn get_missing_events_route(
)
.map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?,
);
events.push(services.sending.convert_to_outgoing_federation_event(pdu));
events.push(
services
.sending
.convert_to_outgoing_federation_event(pdu)
.await,
);
}
i = i.saturating_add(1);
}

View File

@ -12,7 +12,7 @@ pub(crate) async fn get_hierarchy_route(
) -> Result<get_hierarchy::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated");
if services.rooms.metadata.exists(&body.room_id)? {
if services.rooms.metadata.exists(&body.room_id).await {
services
.rooms
.spaces

View File

@ -24,7 +24,8 @@ pub(crate) async fn create_invite_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
if !services
.globals
@ -98,7 +99,8 @@ pub(crate) async fn create_invite_route(
services
.rooms
.event_handler
.acl_check(invited_user.server_name(), &body.room_id)?;
.acl_check(invited_user.server_name(), &body.room_id)
.await?;
ruma::signatures::hash_and_sign_event(
services.globals.server_name().as_str(),
@ -128,14 +130,14 @@ pub(crate) async fn create_invite_route(
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?;
if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? {
if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"This room is banned on this homeserver.",
));
}
if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? {
if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user).await {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"This server does not allow room invites.",
@ -159,22 +161,28 @@ pub(crate) async fn create_invite_route(
if !services
.rooms
.state_cache
.server_in_room(services.globals.server_name(), &body.room_id)?
.server_in_room(services.globals.server_name(), &body.room_id)
.await
{
services.rooms.state_cache.update_membership(
&body.room_id,
&invited_user,
RoomMemberEventContent::new(MembershipState::Invite),
&sender,
Some(invite_state),
body.via.clone(),
true,
)?;
services
.rooms
.state_cache
.update_membership(
&body.room_id,
&invited_user,
RoomMemberEventContent::new(MembershipState::Invite),
&sender,
Some(invite_state),
body.via.clone(),
true,
)
.await?;
}
Ok(create_invite::v2::Response {
event: services
.sending
.convert_to_outgoing_federation_event(signed_event),
.convert_to_outgoing_federation_event(signed_event)
.await,
})
}

View File

@ -1,4 +1,6 @@
use axum::extract::State;
use conduit::utils::{IterStream, ReadyExt};
use futures::StreamExt;
use ruma::{
api::{client::error::ErrorKind, federation::membership::prepare_join_event},
events::{
@ -24,7 +26,7 @@ use crate::{
pub(crate) async fn create_join_event_template_route(
State(services): State<crate::State>, body: Ruma<prepare_join_event::v1::Request>,
) -> Result<prepare_join_event::v1::Response> {
if !services.rooms.metadata.exists(&body.room_id)? {
if !services.rooms.metadata.exists(&body.room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
}
@ -40,7 +42,8 @@ pub(crate) async fn create_join_event_template_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
if services
.globals
@ -73,7 +76,7 @@ pub(crate) async fn create_join_event_template_route(
}
}
let room_version_id = services.rooms.state.get_room_version(&body.room_id)?;
let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
@ -81,22 +84,24 @@ pub(crate) async fn create_join_event_template_route(
.rooms
.state_cache
.is_left(&body.user_id, &body.room_id)
.unwrap_or(true))
&& user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id)?
.await)
&& user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id).await?
{
let auth_user = services
.rooms
.state_cache
.room_members(&body.room_id)
.filter_map(Result::ok)
.filter(|user| user.server_name() == services.globals.server_name())
.find(|user| {
.ready_filter(|user| user.server_name() == services.globals.server_name())
.filter(|user| {
services
.rooms
.state_accessor
.user_can_invite(&body.room_id, user, &body.user_id, &state_lock)
.unwrap_or(false)
});
})
.boxed()
.next()
.await
.map(ToOwned::to_owned);
if auth_user.is_some() {
auth_user
@ -110,7 +115,7 @@ pub(crate) async fn create_join_event_template_route(
None
};
let room_version_id = services.rooms.state.get_room_version(&body.room_id)?;
let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
if !body.ver.contains(&room_version_id) {
return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion {
@ -132,19 +137,23 @@ pub(crate) async fn create_join_event_template_route(
})
.expect("member event is valid value");
let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content,
unsigned: None,
state_key: Some(body.user_id.to_string()),
redacts: None,
timestamp: None,
},
&body.user_id,
&body.room_id,
&state_lock,
)?;
let (_pdu, mut pdu_json) = services
.rooms
.timeline
.create_hash_and_sign_event(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content,
unsigned: None,
state_key: Some(body.user_id.to_string()),
redacts: None,
timestamp: None,
},
&body.user_id,
&body.room_id,
&state_lock,
)
.await?;
drop(state_lock);
@ -161,7 +170,7 @@ pub(crate) async fn create_join_event_template_route(
/// This doesn't check the current user's membership. This should be done
/// externally, either by using the state cache or attempting to authorize the
/// event.
pub(crate) fn user_can_perform_restricted_join(
pub(crate) async fn user_can_perform_restricted_join(
services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId,
) -> Result<bool> {
use RoomVersionId::*;
@ -169,18 +178,15 @@ pub(crate) fn user_can_perform_restricted_join(
let join_rules_event = services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")
.await;
let Some(join_rules_event_content) = join_rules_event
.as_ref()
.map(|join_rules_event| {
serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get()).map_err(|e| {
warn!("Invalid join rules event in database: {e}");
Error::bad_database("Invalid join rules event in database")
})
let Ok(Ok(join_rules_event_content)) = join_rules_event.as_ref().map(|join_rules_event| {
serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get()).map_err(|e| {
warn!("Invalid join rules event in database: {e}");
Error::bad_database("Invalid join rules event in database")
})
.transpose()?
else {
}) else {
return Ok(false);
};
@ -201,13 +207,10 @@ pub(crate) fn user_can_perform_restricted_join(
None
}
})
.any(|m| {
services
.rooms
.state_cache
.is_joined(user_id, &m.room_id)
.unwrap_or(false)
}) {
.stream()
.any(|m| services.rooms.state_cache.is_joined(user_id, &m.room_id))
.await
{
Ok(true)
} else {
Err(Error::BadRequest(

View File

@ -18,7 +18,7 @@ use crate::{service::pdu::PduBuilder, Ruma};
pub(crate) async fn create_leave_event_template_route(
State(services): State<crate::State>, body: Ruma<prepare_leave_event::v1::Request>,
) -> Result<prepare_leave_event::v1::Response> {
if !services.rooms.metadata.exists(&body.room_id)? {
if !services.rooms.metadata.exists(&body.room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
}
@ -34,9 +34,10 @@ pub(crate) async fn create_leave_event_template_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
let room_version_id = services.rooms.state.get_room_version(&body.room_id)?;
let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let content = to_raw_value(&RoomMemberEventContent {
avatar_url: None,
@ -50,19 +51,23 @@ pub(crate) async fn create_leave_event_template_route(
})
.expect("member event is valid value");
let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content,
unsigned: None,
state_key: Some(body.user_id.to_string()),
redacts: None,
timestamp: None,
},
&body.user_id,
&body.room_id,
&state_lock,
)?;
let (_pdu, mut pdu_json) = services
.rooms
.timeline
.create_hash_and_sign_event(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content,
unsigned: None,
state_key: Some(body.user_id.to_string()),
redacts: None,
timestamp: None,
},
&body.user_id,
&body.room_id,
&state_lock,
)
.await?;
drop(state_lock);

View File

@ -10,6 +10,9 @@ pub(crate) async fn get_openid_userinfo_route(
State(services): State<crate::State>, body: Ruma<get_openid_userinfo::v1::Request>,
) -> Result<get_openid_userinfo::v1::Response> {
Ok(get_openid_userinfo::v1::Response::new(
services.users.find_from_openid_token(&body.access_token)?,
services
.users
.find_from_openid_token(&body.access_token)
.await?,
))
}

View File

@ -1,7 +1,8 @@
use std::collections::BTreeMap;
use axum::extract::State;
use conduit::{Error, Result};
use conduit::{err, Error, Result};
use futures::StreamExt;
use get_profile_information::v1::ProfileField;
use rand::seq::SliceRandom;
use ruma::{
@ -23,15 +24,17 @@ pub(crate) async fn get_room_information_route(
let room_id = services
.rooms
.alias
.resolve_local_alias(&body.room_alias)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?;
.resolve_local_alias(&body.room_alias)
.await
.map_err(|_| err!(Request(NotFound("Room alias not found."))))?;
let mut servers: Vec<OwnedServerName> = services
.rooms
.state_cache
.room_servers(&room_id)
.filter_map(Result::ok)
.collect();
.map(ToOwned::to_owned)
.collect()
.await;
servers.sort_unstable();
servers.dedup();
@ -82,30 +85,31 @@ pub(crate) async fn get_profile_information_route(
match &body.field {
Some(ProfileField::DisplayName) => {
displayname = services.users.displayname(&body.user_id)?;
displayname = services.users.displayname(&body.user_id).await.ok();
},
Some(ProfileField::AvatarUrl) => {
avatar_url = services.users.avatar_url(&body.user_id)?;
blurhash = services.users.blurhash(&body.user_id)?;
avatar_url = services.users.avatar_url(&body.user_id).await.ok();
blurhash = services.users.blurhash(&body.user_id).await.ok();
},
Some(custom_field) => {
if let Some(value) = services
if let Ok(value) = services
.users
.profile_key(&body.user_id, custom_field.as_str())?
.profile_key(&body.user_id, custom_field.as_str())
.await
{
custom_profile_fields.insert(custom_field.to_string(), value);
}
},
None => {
displayname = services.users.displayname(&body.user_id)?;
avatar_url = services.users.avatar_url(&body.user_id)?;
blurhash = services.users.blurhash(&body.user_id)?;
tz = services.users.timezone(&body.user_id)?;
displayname = services.users.displayname(&body.user_id).await.ok();
avatar_url = services.users.avatar_url(&body.user_id).await.ok();
blurhash = services.users.blurhash(&body.user_id).await.ok();
tz = services.users.timezone(&body.user_id).await.ok();
custom_profile_fields = services
.users
.all_profile_keys(&body.user_id)
.filter_map(Result::ok)
.collect();
.collect()
.await;
},
}

View File

@ -2,7 +2,8 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant};
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use conduit::{debug, debug_warn, err, trace, warn, Err};
use conduit::{debug, debug_warn, err, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result};
use futures::StreamExt;
use ruma::{
api::{
client::error::ErrorKind,
@ -23,10 +24,13 @@ use tokio::sync::RwLock;
use crate::{
services::Services,
utils::{self},
Error, Result, Ruma,
Ruma,
};
type ResolvedMap = BTreeMap<OwnedEventId, Result<(), Error>>;
const PDU_LIMIT: usize = 50;
const EDU_LIMIT: usize = 100;
type ResolvedMap = BTreeMap<OwnedEventId, Result<()>>;
/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
@ -44,12 +48,16 @@ pub(crate) async fn send_transaction_message_route(
)));
}
if body.pdus.len() > 50_usize {
return Err!(Request(Forbidden("Not allowed to send more than 50 PDUs in one transaction")));
if body.pdus.len() > PDU_LIMIT {
return Err!(Request(Forbidden(
"Not allowed to send more than {PDU_LIMIT} PDUs in one transaction"
)));
}
if body.edus.len() > 100_usize {
return Err!(Request(Forbidden("Not allowed to send more than 100 EDUs in one transaction")));
if body.edus.len() > EDU_LIMIT {
return Err!(Request(Forbidden(
"Not allowed to send more than {EDU_LIMIT} EDUs in one transaction"
)));
}
let txn_start_time = Instant::now();
@ -62,8 +70,8 @@ pub(crate) async fn send_transaction_message_route(
"Starting txn",
);
let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await?;
handle_edus(&services, &client, &body, origin).await?;
let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await;
handle_edus(&services, &client, &body, origin).await;
debug!(
pdus = ?body.pdus.len(),
@ -85,10 +93,10 @@ pub(crate) async fn send_transaction_message_route(
async fn handle_pdus(
services: &Services, _client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName,
txn_start_time: &Instant,
) -> Result<ResolvedMap> {
) -> ResolvedMap {
let mut parsed_pdus = Vec::with_capacity(body.pdus.len());
for pdu in &body.pdus {
parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) {
parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await {
Ok(t) => t,
Err(e) => {
debug_warn!("Could not parse PDU: {e}");
@ -151,38 +159,34 @@ async fn handle_pdus(
}
}
Ok(resolved_map)
resolved_map
}
async fn handle_edus(
services: &Services, client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName,
) -> Result<()> {
) {
for edu in body
.edus
.iter()
.filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok())
{
match edu {
Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await?,
Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await?,
Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await?,
Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await?,
Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await?,
Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await?,
Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await,
Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await,
Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await,
Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await,
Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await,
Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await,
Edu::_Custom(ref _custom) => {
debug_warn!(?body.edus, "received custom/unknown EDU");
},
}
}
Ok(())
}
async fn handle_edu_presence(
services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent,
) -> Result<()> {
async fn handle_edu_presence(services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent) {
if !services.globals.allow_incoming_presence() {
return Ok(());
return;
}
for update in presence.push {
@ -194,23 +198,24 @@ async fn handle_edu_presence(
continue;
}
services.presence.set_presence(
&update.user_id,
&update.presence,
Some(update.currently_active),
Some(update.last_active_ago),
update.status_msg.clone(),
)?;
services
.presence
.set_presence(
&update.user_id,
&update.presence,
Some(update.currently_active),
Some(update.last_active_ago),
update.status_msg.clone(),
)
.await
.log_err()
.ok();
}
Ok(())
}
async fn handle_edu_receipt(
services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent,
) -> Result<()> {
async fn handle_edu_receipt(services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) {
if !services.globals.allow_incoming_read_receipts() {
return Ok(());
return;
}
for (room_id, room_updates) in receipt.receipts {
@ -218,6 +223,7 @@ async fn handle_edu_receipt(
.rooms
.event_handler
.acl_check(origin, &room_id)
.await
.is_err()
{
debug_warn!(
@ -240,8 +246,8 @@ async fn handle_edu_receipt(
.rooms
.state_cache
.room_members(&room_id)
.filter_map(Result::ok)
.any(|member| member.server_name() == user_id.server_name())
.ready_any(|member| member.server_name() == user_id.server_name())
.await
{
for event_id in &user_updates.event_ids {
let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]);
@ -255,7 +261,8 @@ async fn handle_edu_receipt(
services
.rooms
.read_receipt
.readreceipt_update(&user_id, &room_id, &event)?;
.readreceipt_update(&user_id, &room_id, &event)
.await;
}
} else {
debug_warn!(
@ -266,15 +273,11 @@ async fn handle_edu_receipt(
}
}
}
Ok(())
}
async fn handle_edu_typing(
services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent,
) -> Result<()> {
async fn handle_edu_typing(services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent) {
if !services.globals.config.allow_incoming_typing {
return Ok(());
return;
}
if typing.user_id.server_name() != origin {
@ -282,26 +285,28 @@ async fn handle_edu_typing(
%typing.user_id, %origin,
"received typing EDU for user not belonging to origin"
);
return Ok(());
return;
}
if services
.rooms
.event_handler
.acl_check(typing.user_id.server_name(), &typing.room_id)
.await
.is_err()
{
debug_warn!(
%typing.user_id, %typing.room_id, %origin,
"received typing EDU for ACL'd user's server"
);
return Ok(());
return;
}
if services
.rooms
.state_cache
.is_joined(&typing.user_id, &typing.room_id)?
.is_joined(&typing.user_id, &typing.room_id)
.await
{
if typing.typing {
let timeout = utils::millis_since_unix_epoch().saturating_add(
@ -315,28 +320,29 @@ async fn handle_edu_typing(
.rooms
.typing
.typing_add(&typing.user_id, &typing.room_id, timeout)
.await?;
.await
.log_err()
.ok();
} else {
services
.rooms
.typing
.typing_remove(&typing.user_id, &typing.room_id)
.await?;
.await
.log_err()
.ok();
}
} else {
debug_warn!(
%typing.user_id, %typing.room_id, %origin,
"received typing EDU for user not in room"
);
return Ok(());
}
Ok(())
}
async fn handle_edu_device_list_update(
services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent,
) -> Result<()> {
) {
let DeviceListUpdateContent {
user_id,
..
@ -347,17 +353,15 @@ async fn handle_edu_device_list_update(
%user_id, %origin,
"received device list update EDU for user not belonging to origin"
);
return Ok(());
return;
}
services.users.mark_device_key_update(&user_id)?;
Ok(())
services.users.mark_device_key_update(&user_id).await;
}
async fn handle_edu_direct_to_device(
services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent,
) -> Result<()> {
) {
let DirectDeviceContent {
sender,
ev_type,
@ -370,45 +374,52 @@ async fn handle_edu_direct_to_device(
%sender, %origin,
"received direct to device EDU for user not belonging to origin"
);
return Ok(());
return;
}
// Check if this is a new transaction id
if services
.transaction_ids
.existing_txnid(&sender, None, &message_id)?
.is_some()
.existing_txnid(&sender, None, &message_id)
.await
.is_ok()
{
return Ok(());
return;
}
for (target_user_id, map) in &messages {
for (target_device_id_maybe, event) in map {
let Ok(event) = event
.deserialize_as()
.map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))
else {
continue;
};
let ev_type = ev_type.to_string();
match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services.users.add_to_device_event(
&sender,
target_user_id,
target_device_id,
&ev_type.to_string(),
event
.deserialize_as()
.map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))?,
)?;
services
.users
.add_to_device_event(&sender, target_user_id, target_device_id, &ev_type, event)
.await;
},
DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services.users.all_device_ids(target_user_id) {
services.users.add_to_device_event(
&sender,
target_user_id,
&target_device_id?,
&ev_type.to_string(),
event
.deserialize_as()
.map_err(|e| err!(Request(InvalidParam("Event is invalid: {e}"))))?,
)?;
}
let (sender, ev_type, event) = (&sender, &ev_type, &event);
services
.users
.all_device_ids(target_user_id)
.for_each(|target_device_id| {
services.users.add_to_device_event(
sender,
target_user_id,
target_device_id,
ev_type,
event.clone(),
)
})
.await;
},
}
}
@ -417,14 +428,12 @@ async fn handle_edu_direct_to_device(
// Save transaction id with empty data
services
.transaction_ids
.add_txnid(&sender, None, &message_id, &[])?;
Ok(())
.add_txnid(&sender, None, &message_id, &[]);
}
async fn handle_edu_signing_key_update(
services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent,
) -> Result<()> {
) {
let SigningKeyUpdateContent {
user_id,
master_key,
@ -436,14 +445,15 @@ async fn handle_edu_signing_key_update(
%user_id, %origin,
"received signing key update EDU from server that does not belong to user's server"
);
return Ok(());
return;
}
if let Some(master_key) = master_key {
services
.users
.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?;
.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)
.await
.log_err()
.ok();
}
Ok(())
}

View File

@ -3,7 +3,8 @@
use std::collections::BTreeMap;
use axum::extract::State;
use conduit::{pdu::gen_event_id_canonical_json, warn, Error, Result};
use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::{
api::{client::error::ErrorKind, federation::membership::create_join_event},
events::{
@ -22,27 +23,32 @@ use crate::Ruma;
async fn create_join_event(
services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue,
) -> Result<create_join_event::v1::RoomState> {
if !services.rooms.metadata.exists(room_id)? {
if !services.rooms.metadata.exists(room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
}
// ACL check origin server
services.rooms.event_handler.acl_check(origin, room_id)?;
services
.rooms
.event_handler
.acl_check(origin, room_id)
.await?;
// We need to return the state prior to joining, let's keep a reference to that
// here
let shortstatehash = services
.rooms
.state
.get_room_shortstatehash(room_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event state not found."))?;
.get_room_shortstatehash(room_id)
.await
.map_err(|_| err!(Request(NotFound("Event state not found."))))?;
let pub_key_map = RwLock::new(BTreeMap::new());
// let mut auth_cache = EventMap::new();
// We do not add the event_id field to the pdu here because of signature and
// hashes checks
let room_version_id = services.rooms.state.get_room_version(room_id)?;
let room_version_id = services.rooms.state.get_room_version(room_id).await?;
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json
@ -97,7 +103,8 @@ async fn create_join_event(
services
.rooms
.event_handler
.acl_check(sender.server_name(), room_id)?;
.acl_check(sender.server_name(), room_id)
.await?;
// check if origin server is trying to send for another server
if sender.server_name() != origin {
@ -126,7 +133,9 @@ async fn create_join_event(
if content
.join_authorized_via_users_server
.is_some_and(|user| services.globals.user_is_local(&user))
&& super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default()
&& super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id)
.await
.unwrap_or_default()
{
ruma::signatures::hash_and_sign_event(
services.globals.server_name().as_str(),
@ -158,12 +167,14 @@ async fn create_join_event(
.mutex_federation
.lock(room_id)
.await;
let pdu_id: Vec<u8> = services
.rooms
.event_handler
.handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map)
.await?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?;
drop(mutex_lock);
let state_ids = services
@ -171,29 +182,43 @@ async fn create_join_event(
.state_accessor
.state_full_ids(shortstatehash)
.await?;
let auth_chain_ids = services
let state = state_ids
.iter()
.try_stream()
.and_then(|(_, event_id)| services.rooms.timeline.get_pdu_json(event_id))
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.try_collect()
.await?;
let auth_chain = services
.rooms
.auth_chain
.event_ids_iter(room_id, state_ids.values().cloned().collect())
.await?
.map(Ok)
.and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await })
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.try_collect()
.await?;
services.sending.send_pdu_room(room_id, &pdu_id)?;
services.sending.send_pdu_room(room_id, &pdu_id).await?;
Ok(create_join_event::v1::RoomState {
auth_chain: auth_chain_ids
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten())
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
state: state_ids
.iter()
.filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten())
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
auth_chain,
state,
// Event field is required if the room version supports restricted join rules.
event: Some(
to_raw_value(&CanonicalJsonValue::Object(value))
.expect("To raw json should not fail since only change was adding signature"),
),
event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(),
})
}

View File

@ -3,7 +3,7 @@
use std::collections::BTreeMap;
use axum::extract::State;
use conduit::{Error, Result};
use conduit::{utils::ReadyExt, Error, Result};
use ruma::{
api::{client::error::ErrorKind, federation::membership::create_leave_event},
events::{
@ -49,18 +49,22 @@ pub(crate) async fn create_leave_event_v2_route(
async fn create_leave_event(
services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue,
) -> Result<()> {
if !services.rooms.metadata.exists(room_id)? {
if !services.rooms.metadata.exists(room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
}
// ACL check origin
services.rooms.event_handler.acl_check(origin, room_id)?;
services
.rooms
.event_handler
.acl_check(origin, room_id)
.await?;
let pub_key_map = RwLock::new(BTreeMap::new());
// We do not add the event_id field to the pdu here because of signature and
// hashes checks
let room_version_id = services.rooms.state.get_room_version(room_id)?;
let room_version_id = services.rooms.state.get_room_version(room_id).await?;
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json
return Err(Error::BadRequest(
@ -114,7 +118,8 @@ async fn create_leave_event(
services
.rooms
.event_handler
.acl_check(sender.server_name(), room_id)?;
.acl_check(sender.server_name(), room_id)
.await?;
if sender.server_name() != origin {
return Err(Error::BadRequest(
@ -173,10 +178,9 @@ async fn create_leave_event(
.rooms
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
.filter(|server| !services.globals.server_is_ours(server));
.ready_filter(|server| !services.globals.server_is_ours(server));
services.sending.send_pdu_servers(servers, &pdu_id)?;
services.sending.send_pdu_servers(servers, &pdu_id).await?;
Ok(())
}

View File

@ -1,8 +1,9 @@
use std::sync::Arc;
use axum::extract::State;
use conduit::{Error, Result};
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state};
use conduit::{err, result::LogErr, utils::IterStream, Err, Result};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::api::federation::event::get_room_state;
use crate::Ruma;
@ -17,56 +18,66 @@ pub(crate) async fn get_room_state_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
if !services
.rooms
.state_accessor
.is_world_readable(&body.room_id)?
&& !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)?
.is_world_readable(&body.room_id)
.await && !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)
.await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
return Err!(Request(Forbidden("Server is not in room.")));
}
let shortstatehash = services
.rooms
.state_accessor
.pdu_shortstatehash(&body.event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?;
.pdu_shortstatehash(&body.event_id)
.await
.map_err(|_| err!(Request(NotFound("PDU state not found."))))?;
let pdus = services
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await?
.into_values()
.map(|id| {
.await
.log_err()
.map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))?
.values()
.try_stream()
.and_then(|id| services.rooms.timeline.get_pdu_json(id))
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap())
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.collect();
.try_collect()
.await?;
let auth_chain_ids = services
let auth_chain = services
.rooms
.auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
.await?
.map(Ok)
.and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await })
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.try_collect()
.await?;
Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids
.filter_map(|id| {
services
.rooms
.timeline
.get_pdu_json(&id)
.ok()?
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
})
.collect(),
auth_chain,
pdus,
})
}

View File

@ -1,9 +1,11 @@
use std::sync::Arc;
use axum::extract::State;
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids};
use conduit::{err, Err};
use futures::StreamExt;
use ruma::api::federation::event::get_room_state_ids;
use crate::{Error, Result, Ruma};
use crate::{Result, Ruma};
/// # `GET /_matrix/federation/v1/state_ids/{roomId}`
///
@ -17,31 +19,35 @@ pub(crate) async fn get_room_state_ids_route(
services
.rooms
.event_handler
.acl_check(origin, &body.room_id)?;
.acl_check(origin, &body.room_id)
.await?;
if !services
.rooms
.state_accessor
.is_world_readable(&body.room_id)?
&& !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)?
.is_world_readable(&body.room_id)
.await && !services
.rooms
.state_cache
.server_in_room(origin, &body.room_id)
.await
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
return Err!(Request(Forbidden("Server is not in room.")));
}
let shortstatehash = services
.rooms
.state_accessor
.pdu_shortstatehash(&body.event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?;
.pdu_shortstatehash(&body.event_id)
.await
.map_err(|_| err!(Request(NotFound("Pdu state not found."))))?;
let pdu_ids = services
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await?
.await
.map_err(|_| err!(Request(NotFound("State ids not found"))))?
.into_values()
.map(|id| (*id).to_owned())
.collect();
@ -50,10 +56,13 @@ pub(crate) async fn get_room_state_ids_route(
.rooms
.auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
.await?;
.await?
.map(|id| (*id).to_owned())
.collect()
.await;
Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(),
auth_chain_ids,
pdu_ids,
})
}

View File

@ -1,5 +1,6 @@
use axum::extract::State;
use conduit::{Error, Result};
use futures::{FutureExt, StreamExt, TryFutureExt};
use ruma::api::{
client::error::ErrorKind,
federation::{
@ -28,41 +29,51 @@ pub(crate) async fn get_devices_route(
let origin = body.origin.as_ref().expect("server is authenticated");
let user_id = &body.user_id;
Ok(get_devices::v1::Response {
user_id: body.user_id.clone(),
user_id: user_id.clone(),
stream_id: services
.users
.get_devicelist_version(&body.user_id)?
.get_devicelist_version(user_id)
.await
.unwrap_or(0)
.try_into()
.expect("version will not grow that large"),
.try_into()?,
devices: services
.users
.all_devices_metadata(&body.user_id)
.filter_map(Result::ok)
.filter_map(|metadata| {
let device_id_string = metadata.device_id.as_str().to_owned();
.all_devices_metadata(user_id)
.filter_map(|metadata| async move {
let device_id = metadata.device_id.clone();
let device_id_clone = device_id.clone();
let device_id_string = device_id.as_str().to_owned();
let device_display_name = if services.globals.allow_device_name_federation() {
metadata.display_name
metadata.display_name.clone()
} else {
Some(device_id_string)
};
Some(UserDevice {
keys: services
.users
.get_device_keys(&body.user_id, &metadata.device_id)
.ok()??,
device_id: metadata.device_id,
device_display_name,
})
services
.users
.get_device_keys(user_id, &device_id_clone)
.map_ok(|keys| UserDevice {
device_id,
keys,
device_display_name,
})
.map(Result::ok)
.await
})
.collect(),
.collect()
.await,
master_key: services
.users
.get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?,
.get_master_key(None, &body.user_id, &|u| u.server_name() == origin)
.await
.ok(),
self_signing_key: services
.users
.get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?,
.get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)
.await
.ok(),
})
}

View File

@ -67,6 +67,7 @@ ctor.workspace = true
cyborgtime.workspace = true
either.workspace = true
figment.workspace = true
futures.workspace = true
http-body-util.workspace = true
http.workspace = true
image.workspace = true

View File

@ -86,7 +86,7 @@ pub enum Error {
#[error("There was a problem with the '{0}' directive in your configuration: {1}")]
Config(&'static str, Cow<'static, str>),
#[error("{0}")]
Conflict(&'static str), // This is only needed for when a room alias already exists
Conflict(Cow<'static, str>), // This is only needed for when a room alias already exists
#[error(transparent)]
ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError),
#[error("{0}")]
@ -107,6 +107,8 @@ pub enum Error {
Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode),
#[error(transparent)]
Ruma(#[from] ruma::api::client::error::Error),
#[error(transparent)]
StateRes(#[from] ruma::state_res::Error),
#[error("uiaa")]
Uiaa(ruma::api::client::uiaa::UiaaInfo),

View File

@ -3,8 +3,6 @@ mod count;
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
pub use builder::PduBuilder;
pub use count::PduCount;
use ruma::{
canonical_json::redact_content_in_place,
events::{
@ -23,7 +21,8 @@ use serde_json::{
value::{to_raw_value, RawValue as RawJsonValue},
};
use crate::{err, warn, Error};
pub use self::{builder::PduBuilder, count::PduCount};
use crate::{err, warn, Error, Result};
#[derive(Deserialize)]
struct ExtractRedactedBecause {
@ -65,11 +64,12 @@ pub struct PduEvent {
impl PduEvent {
#[tracing::instrument(skip(self), level = "debug")]
pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> {
pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> {
self.unsigned = None;
let mut content = serde_json::from_str(self.content.get())
.map_err(|_| Error::bad_database("PDU in db has invalid content."))?;
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
.map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?;
@ -98,31 +98,38 @@ impl PduEvent {
unsigned.redacted_because.is_some()
}
pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
if let Some(unsigned) = &self.unsigned {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get())
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
unsigned.remove("transaction_id");
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
}
pub fn remove_transaction_id(&mut self) -> Result<()> {
let Some(unsigned) = &self.unsigned else {
return Ok(());
};
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> =
serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?;
unsigned.remove("transaction_id");
self.unsigned = to_raw_value(&unsigned)
.map(Some)
.expect("unsigned is valid");
Ok(())
}
pub fn add_age(&mut self) -> crate::Result<()> {
pub fn add_age(&mut self) -> Result<()> {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
.unsigned
.as_ref()
.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get()))
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
.map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?;
// deliberately allowing for the possibility of negative age
let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into();
let then: i128 = self.origin_server_ts.into();
let this_age = now.saturating_sub(then);
unsigned.insert("age".to_owned(), to_raw_value(&this_age).unwrap());
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid"));
self.unsigned = to_raw_value(&unsigned)
.map(Some)
.expect("unsigned is valid");
Ok(())
}
@ -369,9 +376,9 @@ impl state_res::Event for PduEvent {
fn state_key(&self) -> Option<&str> { self.state_key.as_deref() }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.prev_events.iter()) }
fn prev_events(&self) -> impl DoubleEndedIterator<Item = &Self::Id> + Send + '_ { self.prev_events.iter() }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.auth_events.iter()) }
fn auth_events(&self) -> impl DoubleEndedIterator<Item = &Self::Id> + Send + '_ { self.auth_events.iter() }
fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() }
}
@ -395,7 +402,7 @@ impl Ord for PduEvent {
/// CanonicalJsonValue>`.
pub fn gen_event_id_canonical_json(
pdu: &RawJsonValue, room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get())
.map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?;

View File

@ -1,18 +1,14 @@
use std::fmt;
use std::fmt::Debug;
use tracing::Level;
use super::{DebugInspect, Result};
use crate::error;
pub trait LogDebugErr<T, E>
where
E: fmt::Debug,
{
pub trait LogDebugErr<T, E: Debug> {
#[must_use]
fn err_debug_log(self, level: Level) -> Self;
#[inline]
#[must_use]
fn log_debug_err(self) -> Self
where
@ -22,15 +18,9 @@ where
}
}
impl<T, E> LogDebugErr<T, E> for Result<T, E>
where
E: fmt::Debug,
{
impl<T, E: Debug> LogDebugErr<T, E> for Result<T, E> {
#[inline]
fn err_debug_log(self, level: Level) -> Self
where
Self: Sized,
{
fn err_debug_log(self, level: Level) -> Self {
self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level))
}
}

View File

@ -1,18 +1,14 @@
use std::fmt;
use std::fmt::Display;
use tracing::Level;
use super::Result;
use crate::error;
pub trait LogErr<T, E>
where
E: fmt::Display,
{
pub trait LogErr<T, E: Display> {
#[must_use]
fn err_log(self, level: Level) -> Self;
#[inline]
#[must_use]
fn log_err(self) -> Self
where
@ -22,15 +18,7 @@ where
}
}
impl<T, E> LogErr<T, E> for Result<T, E>
where
E: fmt::Display,
{
impl<T, E: Display> LogErr<T, E> for Result<T, E> {
#[inline]
fn err_log(self, level: Level) -> Self
where
Self: Sized,
{
self.inspect_err(|error| error::inspect_log_level(&error, level))
}
fn err_log(self, level: Level) -> Self { self.inspect_err(|error| error::inspect_log_level(&error, level)) }
}

View File

@ -1,25 +0,0 @@
use std::cmp::Ordering;
#[allow(clippy::impl_trait_in_params)]
pub fn common_elements(
mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering,
) -> Option<impl Iterator<Item = Vec<u8>>> {
let first_iterator = iterators.next()?;
let mut other_iterators = iterators.map(Iterator::peekable).collect::<Vec<_>>();
Some(first_iterator.filter(move |target| {
other_iterators.iter_mut().all(|it| {
while let Some(element) = it.peek() {
match check_order(element, target) {
Ordering::Greater => return false, // We went too far
Ordering::Equal => return true, // Element is in both iters
Ordering::Less => {
// Keep searching
it.next();
},
}
}
false
})
}))
}

View File

@ -1,4 +1,3 @@
pub mod algorithm;
pub mod bytes;
pub mod content_disposition;
pub mod debug;
@ -9,25 +8,30 @@ pub mod json;
pub mod math;
pub mod mutex_map;
pub mod rand;
pub mod set;
pub mod stream;
pub mod string;
pub mod sys;
mod tests;
pub mod time;
pub use ::conduit_macros::implement;
pub use ::ctor::{ctor, dtor};
pub use algorithm::common_elements;
pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8};
pub use conduit_macros::implement;
pub use debug::slice_truncated as debug_slice_truncated;
pub use hash::calculate_hash;
pub use html::Escape as HtmlEscape;
pub use json::{deserialize_from_str, to_canonical_object};
pub use math::clamp;
pub use mutex_map::{Guard as MutexMapGuard, MutexMap};
pub use rand::string as random_string;
pub use string::{str_from_bytes, string_from_bytes};
pub use sys::available_parallelism;
pub use time::now_millis as millis_since_unix_epoch;
pub use self::{
bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8},
debug::slice_truncated as debug_slice_truncated,
hash::calculate_hash,
html::Escape as HtmlEscape,
json::{deserialize_from_str, to_canonical_object},
math::clamp,
mutex_map::{Guard as MutexMapGuard, MutexMap},
rand::string as random_string,
stream::{IterStream, ReadyExt, TryReadyExt},
string::{str_from_bytes, string_from_bytes},
sys::available_parallelism,
time::now_millis as millis_since_unix_epoch,
};
#[inline]
pub fn exchange<T>(state: &mut T, source: T) -> T { std::mem::replace(state, source) }

47
src/core/utils/set.rs Normal file
View File

@ -0,0 +1,47 @@
use std::cmp::{Eq, Ord};
use crate::{is_equal_to, is_less_than};
/// Intersection of sets
///
/// Outputs the set of elements common to all input sets. Inputs do not have to
/// be sorted. If inputs are sorted a more optimized function is available in
/// this suite and should be used.
pub fn intersection<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send
where
Iters: Iterator<Item = Iter> + Clone + Send,
Iter: Iterator<Item = Item> + Send,
Item: Eq + Send,
{
input.next().into_iter().flat_map(move |first| {
let input = input.clone();
first.filter(move |targ| {
input
.clone()
.all(|mut other| other.any(is_equal_to!(*targ)))
})
})
}
/// Intersection of sets
///
/// Outputs the set of elements common to all input sets. Inputs must be sorted.
pub fn intersection_sorted<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send
where
Iters: Iterator<Item = Iter> + Clone + Send,
Iter: Iterator<Item = Item> + Send,
Item: Eq + Ord + Send,
{
input.next().into_iter().flat_map(move |first| {
let mut input = input.clone().collect::<Vec<_>>();
first.filter(move |targ| {
input.iter_mut().all(|it| {
it.by_ref()
.skip_while(is_less_than!(targ))
.peekable()
.peek()
.is_some_and(is_equal_to!(targ))
})
})
})
}

View File

@ -0,0 +1,20 @@
use std::clone::Clone;
use futures::{stream::Map, Stream, StreamExt};
pub trait Cloned<'a, T, S>
where
S: Stream<Item = &'a T>,
T: Clone + 'a,
{
fn cloned(self) -> Map<S, fn(&T) -> T>;
}
impl<'a, T, S> Cloned<'a, T, S> for S
where
S: Stream<Item = &'a T>,
T: Clone + 'a,
{
#[inline]
fn cloned(self) -> Map<S, fn(&T) -> T> { self.map(Clone::clone) }
}

View File

@ -0,0 +1,17 @@
use futures::{Stream, StreamExt, TryStream};
use crate::Result;
pub trait TryExpect<'a, Item> {
fn expect_ok(self) -> impl Stream<Item = Item> + Send + 'a;
}
impl<'a, T, Item> TryExpect<'a, Item> for T
where
T: Stream<Item = Result<Item>> + TryStream + Send + 'a,
{
#[inline]
fn expect_ok(self: T) -> impl Stream<Item = Item> + Send + 'a {
self.map(|res| res.expect("stream expectation failure"))
}
}

View File

@ -0,0 +1,21 @@
use futures::{future::ready, Stream, StreamExt, TryStream};
use crate::{Error, Result};
pub trait TryIgnore<'a, Item> {
fn ignore_err(self) -> impl Stream<Item = Item> + Send + 'a;
fn ignore_ok(self) -> impl Stream<Item = Error> + Send + 'a;
}
impl<'a, T, Item> TryIgnore<'a, Item> for T
where
T: Stream<Item = Result<Item>> + TryStream + Send + 'a,
Item: Send + 'a,
{
#[inline]
fn ignore_err(self: T) -> impl Stream<Item = Item> + Send + 'a { self.filter_map(|res| ready(res.ok())) }
#[inline]
fn ignore_ok(self: T) -> impl Stream<Item = Error> + Send + 'a { self.filter_map(|res| ready(res.err())) }
}

View File

@ -0,0 +1,27 @@
use futures::{
stream,
stream::{Stream, TryStream},
StreamExt,
};
pub trait IterStream<I: IntoIterator + Send> {
/// Convert an Iterator into a Stream
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send;
/// Convert an Iterator into a TryStream
fn try_stream(self) -> impl TryStream<Ok = <I as IntoIterator>::Item, Error = crate::Error> + Send;
}
impl<I> IterStream<I> for I
where
I: IntoIterator + Send,
<I as IntoIterator>::IntoIter: Send,
{
#[inline]
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send { stream::iter(self) }
#[inline]
fn try_stream(self) -> impl TryStream<Ok = <I as IntoIterator>::Item, Error = crate::Error> + Send {
self.stream().map(Ok)
}
}

View File

@ -0,0 +1,13 @@
mod cloned;
mod expect;
mod ignore;
mod iter_stream;
mod ready;
mod try_ready;
pub use cloned::Cloned;
pub use expect::TryExpect;
pub use ignore::TryIgnore;
pub use iter_stream::IterStream;
pub use ready::ReadyExt;
pub use try_ready::TryReadyExt;

View File

@ -0,0 +1,109 @@
//! Synchronous combinator extensions to futures::Stream
use futures::{
future::{ready, Ready},
stream::{Any, Filter, FilterMap, Fold, ForEach, SkipWhile, Stream, StreamExt, TakeWhile},
};
/// Synchronous combinators to augment futures::StreamExt. Most Stream
/// combinators take asynchronous arguments, but often only simple predicates
/// are required to steer a Stream like an Iterator. This suite provides a
/// convenience to reduce boilerplate by de-cluttering non-async predicates.
///
/// This interface is not necessarily complete; feel free to add as-needed.
pub trait ReadyExt<Item, S>
where
S: Stream<Item = Item> + Send + ?Sized,
Self: Stream + Send + Sized,
{
fn ready_any<F>(self, f: F) -> Any<Self, Ready<bool>, impl FnMut(S::Item) -> Ready<bool>>
where
F: Fn(S::Item) -> bool;
fn ready_filter<'a, F>(self, f: F) -> Filter<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a;
fn ready_filter_map<F, U>(self, f: F) -> FilterMap<Self, Ready<Option<U>>, impl FnMut(S::Item) -> Ready<Option<U>>>
where
F: Fn(S::Item) -> Option<U>;
fn ready_fold<T, F>(self, init: T, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, S::Item) -> Ready<T>>
where
F: Fn(T, S::Item) -> T;
fn ready_for_each<F>(self, f: F) -> ForEach<Self, Ready<()>, impl FnMut(S::Item) -> Ready<()>>
where
F: FnMut(S::Item);
fn ready_take_while<'a, F>(self, f: F) -> TakeWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a;
fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a;
}
impl<Item, S> ReadyExt<Item, S> for S
where
S: Stream<Item = Item> + Send + ?Sized,
Self: Stream + Send + Sized,
{
#[inline]
fn ready_any<F>(self, f: F) -> Any<Self, Ready<bool>, impl FnMut(S::Item) -> Ready<bool>>
where
F: Fn(S::Item) -> bool,
{
self.any(move |t| ready(f(t)))
}
#[inline]
fn ready_filter<'a, F>(self, f: F) -> Filter<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a,
{
self.filter(move |t| ready(f(t)))
}
#[inline]
fn ready_filter_map<F, U>(self, f: F) -> FilterMap<Self, Ready<Option<U>>, impl FnMut(S::Item) -> Ready<Option<U>>>
where
F: Fn(S::Item) -> Option<U>,
{
self.filter_map(move |t| ready(f(t)))
}
#[inline]
fn ready_fold<T, F>(self, init: T, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, S::Item) -> Ready<T>>
where
F: Fn(T, S::Item) -> T,
{
self.fold(init, move |a, t| ready(f(a, t)))
}
#[inline]
#[allow(clippy::unit_arg)]
fn ready_for_each<F>(self, mut f: F) -> ForEach<Self, Ready<()>, impl FnMut(S::Item) -> Ready<()>>
where
F: FnMut(S::Item),
{
self.for_each(move |t| ready(f(t)))
}
#[inline]
fn ready_take_while<'a, F>(self, f: F) -> TakeWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a,
{
self.take_while(move |t| ready(f(t)))
}
#[inline]
fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a,
{
self.skip_while(move |t| ready(f(t)))
}
}

View File

@ -0,0 +1,35 @@
//! Synchronous combinator extensions to futures::TryStream
use futures::{
future::{ready, Ready},
stream::{AndThen, TryStream, TryStreamExt},
};
use crate::Result;
/// Synchronous combinators to augment futures::TryStreamExt.
///
/// This interface is not necessarily complete; feel free to add as-needed.
pub trait TryReadyExt<T, E, S>
where
S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
Self: TryStream + Send + Sized,
{
fn ready_and_then<U, F>(self, f: F) -> AndThen<Self, Ready<Result<U, E>>, impl FnMut(S::Ok) -> Ready<Result<U, E>>>
where
F: Fn(S::Ok) -> Result<U, E>;
}
impl<T, E, S> TryReadyExt<T, E, S> for S
where
S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
Self: TryStream + Send + Sized,
{
#[inline]
fn ready_and_then<U, F>(self, f: F) -> AndThen<Self, Ready<Result<U, E>>, impl FnMut(S::Ok) -> Ready<Result<U, E>>>
where
F: Fn(S::Ok) -> Result<U, E>,
{
self.and_then(move |t| ready(f(t)))
}
}

View File

@ -107,3 +107,133 @@ async fn mutex_map_contend() {
tokio::try_join!(join_b, join_a).expect("joined");
assert!(map.is_empty(), "Must be empty");
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_none() {
use utils::set::intersection;
let a: [&str; 0] = [];
let b: [&str; 0] = [];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let a: [&str; 0] = [];
let b = ["abc", "def"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let i = [b.iter(), a.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let i = [a.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let a = ["foo", "bar", "baz"];
let b = ["def", "hij", "klm", "nop"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_all() {
use utils::set::intersection;
let a = ["foo"];
let b = ["foo"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo"].iter()));
let a = ["foo", "bar"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo", "bar"].iter()));
let i = [b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["bar", "foo"].iter()));
let a = ["foo", "bar", "baz"];
let b = ["baz", "foo", "bar"];
let c = ["bar", "baz", "foo"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo", "bar", "baz"].iter()));
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_some() {
use utils::set::intersection;
let a = ["foo"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo"].iter()));
let i = [b.iter(), a.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo"].iter()));
let a = ["abcdef", "foo", "hijkl", "abc"];
let b = ["hij", "bar", "baz", "abc", "foo"];
let c = ["abc", "xyz", "foo", "ghi"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo", "abc"].iter()));
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_sorted_some() {
use utils::set::intersection_sorted;
let a = ["bar"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar"].iter()));
let i = [b.iter(), a.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar"].iter()));
let a = ["aaa", "ccc", "eee", "ggg"];
let b = ["aaa", "bbb", "ccc", "ddd", "eee"];
let c = ["bbb", "ccc", "eee", "fff"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["ccc", "eee"].iter()));
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_sorted_all() {
use utils::set::intersection_sorted;
let a = ["foo"];
let b = ["foo"];
let i = [a.iter(), b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["foo"].iter()));
let a = ["bar", "foo"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar", "foo"].iter()));
let i = [b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar", "foo"].iter()));
let a = ["bar", "baz", "foo"];
let b = ["bar", "baz", "foo"];
let c = ["bar", "baz", "foo"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar", "baz", "foo"].iter()));
}

View File

@ -37,8 +37,11 @@ zstd_compression = [
[dependencies]
conduit-core.workspace = true
const-str.workspace = true
futures.workspace = true
log.workspace = true
rust-rocksdb.workspace = true
serde.workspace = true
serde_json.workspace = true
tokio.workspace = true
tracing.workspace = true

View File

@ -37,7 +37,7 @@ impl Database {
pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) }
#[inline]
pub fn iter_maps(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + '_ { self.map.iter() }
pub fn iter_maps(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + Send + '_ { self.map.iter() }
}
impl Index<&str> for Database {

261
src/database/de.rs Normal file
View File

@ -0,0 +1,261 @@
use conduit::{checked, debug::DebugInspect, err, utils::string, Error, Result};
use serde::{
de,
de::{DeserializeSeed, Visitor},
Deserialize,
};
pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result<T>
where
T: Deserialize<'a>,
{
let mut deserializer = Deserializer {
buf,
pos: 0,
};
T::deserialize(&mut deserializer).debug_inspect(|_| {
deserializer
.finished()
.expect("deserialization failed to consume trailing bytes");
})
}
pub(crate) struct Deserializer<'de> {
buf: &'de [u8],
pos: usize,
}
/// Directive to ignore a record. This type can be used to skip deserialization
/// until the next separator is found.
#[derive(Debug, Deserialize)]
pub struct Ignore;
impl<'de> Deserializer<'de> {
const SEP: u8 = b'\xFF';
fn finished(&self) -> Result<()> {
let pos = self.pos;
let len = self.buf.len();
let parsed = &self.buf[0..pos];
let unparsed = &self.buf[pos..];
let remain = checked!(len - pos)?;
let trailing_sep = remain == 1 && unparsed[0] == Self::SEP;
(remain == 0 || trailing_sep)
.then_some(())
.ok_or(err!(SerdeDe(
"{remain} trailing of {len} bytes not deserialized.\n{parsed:?}\n{unparsed:?}",
)))
}
#[inline]
fn record_next(&mut self) -> &'de [u8] {
self.buf[self.pos..]
.split(|b| *b == Deserializer::SEP)
.inspect(|record| self.inc_pos(record.len()))
.next()
.expect("remainder of buf even if SEP was not found")
}
#[inline]
fn record_trail(&mut self) -> &'de [u8] {
let record = &self.buf[self.pos..];
self.inc_pos(record.len());
record
}
#[inline]
fn record_start(&mut self) {
let started = self.pos != 0;
debug_assert!(
!started || self.buf[self.pos] == Self::SEP,
"Missing expected record separator at current position"
);
self.inc_pos(started.into());
}
#[inline]
fn inc_pos(&mut self, n: usize) {
self.pos = self.pos.saturating_add(n);
debug_assert!(self.pos <= self.buf.len(), "pos out of range");
}
}
impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
type Error = Error;
fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Map not implemented")
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_struct<V>(
self, _name: &'static str, _fields: &'static [&'static str], _visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Struct not implemented")
}
fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
match name {
"Ignore" => self.record_next(),
_ => unimplemented!("Unrecognized deserialization Directive {name:?}"),
};
visitor.visit_unit()
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Newtype Struct not implemented")
}
fn deserialize_enum<V>(
self, _name: &'static str, _variants: &'static [&'static str], _visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Enum not implemented")
}
fn deserialize_option<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Option not implemented")
}
fn deserialize_bool<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize bool not implemented")
}
fn deserialize_i8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize i8 not implemented")
}
fn deserialize_i16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize i16 not implemented")
}
fn deserialize_i32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize i32 not implemented")
}
fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let bytes: [u8; size_of::<i64>()] = self.buf[self.pos..].try_into()?;
self.pos = self.pos.saturating_add(size_of::<i64>());
visitor.visit_i64(i64::from_be_bytes(bytes))
}
fn deserialize_u8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize u8 not implemented")
}
fn deserialize_u16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize u16 not implemented")
}
fn deserialize_u32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize u32 not implemented")
}
fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let bytes: [u8; size_of::<u64>()] = self.buf[self.pos..].try_into()?;
self.pos = self.pos.saturating_add(size_of::<u64>());
visitor.visit_u64(u64::from_be_bytes(bytes))
}
fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize f32 not implemented")
}
fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize f64 not implemented")
}
fn deserialize_char<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize char not implemented")
}
fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_next();
let out = string::str_from_bytes(input)?;
visitor.visit_borrowed_str(out)
}
fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_next();
let out = string::string_from_bytes(input)?;
visitor.visit_string(out)
}
fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_trail();
visitor.visit_borrowed_bytes(input)
}
fn deserialize_byte_buf<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Byte Buf not implemented")
}
fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Unit Struct not implemented")
}
fn deserialize_identifier<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Identifier not implemented")
}
fn deserialize_ignored_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Ignored Any not implemented")
}
fn deserialize_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize any not implemented")
}
}
impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: DeserializeSeed<'de>,
{
if self.pos >= self.buf.len() {
return Ok(None);
}
self.record_start();
seed.deserialize(&mut **self).map(Some)
}
}

View File

@ -0,0 +1,34 @@
use std::convert::identity;
use conduit::Result;
use serde::Deserialize;
pub trait Deserialized {
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>;
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>;
#[inline]
fn deserialized<T>(self) -> Result<T>
where
T: for<'de> Deserialize<'de>,
Self: Sized,
{
self.map_de(identity::<T>)
}
#[inline]
fn deserialized_json<T>(self) -> Result<T>
where
T: for<'de> Deserialize<'de>,
Self: Sized,
{
self.map_json(identity::<T>)
}
}

View File

@ -106,7 +106,7 @@ impl Engine {
}))
}
#[tracing::instrument(skip(self))]
#[tracing::instrument(skip(self), level = "trace")]
pub(crate) fn open_cf(&self, name: &str) -> Result<Arc<BoundColumnFamily<'_>>> {
let mut cfs = self.cfs.lock().expect("locked");
if !cfs.contains(name) {

View File

@ -1,6 +1,10 @@
use std::ops::Deref;
use std::{fmt, fmt::Debug, ops::Deref};
use conduit::Result;
use rocksdb::DBPinnableSlice;
use serde::{Deserialize, Serialize, Serializer};
use crate::{keyval::deserialize_val, Deserialized, Slice};
pub struct Handle<'a> {
val: DBPinnableSlice<'a>,
@ -14,14 +18,91 @@ impl<'a> From<DBPinnableSlice<'a>> for Handle<'a> {
}
}
impl Debug for Handle<'_> {
fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
let val: &Slice = self;
let ptr = val.as_ptr();
let len = val.len();
write!(out, "Handle {{val: {{ptr: {ptr:?}, len: {len}}}}}")
}
}
impl Serialize for Handle<'_> {
#[inline]
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let bytes: &Slice = self;
serializer.serialize_bytes(bytes)
}
}
impl Deref for Handle<'_> {
type Target = [u8];
type Target = Slice;
#[inline]
fn deref(&self) -> &Self::Target { &self.val }
}
impl AsRef<[u8]> for Handle<'_> {
impl AsRef<Slice> for Handle<'_> {
#[inline]
fn as_ref(&self) -> &[u8] { &self.val }
fn as_ref(&self) -> &Slice { &self.val }
}
impl Deserialized for Result<Handle<'_>> {
#[inline]
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self?.map_json(f)
}
#[inline]
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self?.map_de(f)
}
}
impl<'a> Deserialized for Result<&'a Handle<'a>> {
#[inline]
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self.and_then(|handle| handle.map_json(f))
}
#[inline]
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self.and_then(|handle| handle.map_de(f))
}
}
impl<'a> Deserialized for &'a Handle<'a> {
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
serde_json::from_slice::<T>(self.as_ref())
.map_err(Into::into)
.map(f)
}
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
deserialize_val(self.as_ref()).map(f)
}
}

View File

@ -1,110 +0,0 @@
use std::{iter::FusedIterator, sync::Arc};
use conduit::Result;
use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, Direction, IteratorMode, ReadOptions};
use crate::{
engine::Db,
result,
slice::{OwnedKeyVal, OwnedKeyValPair},
Engine,
};
type Cursor<'cursor> = DBRawIteratorWithThreadMode<'cursor, Db>;
struct State<'cursor> {
cursor: Cursor<'cursor>,
direction: Direction,
valid: bool,
init: bool,
}
impl<'cursor> State<'cursor> {
pub(crate) fn new(
db: &'cursor Arc<Engine>, cf: &'cursor Arc<ColumnFamily>, opts: ReadOptions, mode: &IteratorMode<'_>,
) -> Self {
let mut cursor = db.db.raw_iterator_cf_opt(&**cf, opts);
let direction = into_direction(mode);
let valid = seek_init(&mut cursor, mode);
Self {
cursor,
direction,
valid,
init: true,
}
}
}
pub struct Iter<'cursor> {
state: State<'cursor>,
}
impl<'cursor> Iter<'cursor> {
pub(crate) fn new(
db: &'cursor Arc<Engine>, cf: &'cursor Arc<ColumnFamily>, opts: ReadOptions, mode: &IteratorMode<'_>,
) -> Self {
Self {
state: State::new(db, cf, opts, mode),
}
}
}
impl Iterator for Iter<'_> {
type Item = OwnedKeyValPair;
fn next(&mut self) -> Option<Self::Item> {
if !self.state.init && self.state.valid {
seek_next(&mut self.state.cursor, self.state.direction);
} else if self.state.init {
self.state.init = false;
}
self.state
.cursor
.item()
.map(OwnedKeyVal::from)
.map(OwnedKeyVal::to_tuple)
.or_else(|| {
when_invalid(&mut self.state).expect("iterator invalidated due to error");
None
})
}
}
impl FusedIterator for Iter<'_> {}
fn when_invalid(state: &mut State<'_>) -> Result<()> {
state.valid = false;
result(state.cursor.status())
}
fn seek_next(cursor: &mut Cursor<'_>, direction: Direction) {
match direction {
Direction::Forward => cursor.next(),
Direction::Reverse => cursor.prev(),
}
}
fn seek_init(cursor: &mut Cursor<'_>, mode: &IteratorMode<'_>) -> bool {
use Direction::{Forward, Reverse};
use IteratorMode::{End, From, Start};
match mode {
Start => cursor.seek_to_first(),
End => cursor.seek_to_last(),
From(key, Forward) => cursor.seek(key),
From(key, Reverse) => cursor.seek_for_prev(key),
};
cursor.valid()
}
fn into_direction(mode: &IteratorMode<'_>) -> Direction {
use Direction::{Forward, Reverse};
use IteratorMode::{End, From, Start};
match mode {
Start | From(_, Forward) => Forward,
End | From(_, Reverse) => Reverse,
}
}

Some files were not shown because too many files have changed in this diff Show More