Database Refactor

combine service/users data w/ mod unit

split sliding sync related out of service/users

instrument database entry points

remove increment crap from database interface

de-wrap all database get() calls

de-wrap all database insert() calls

de-wrap all database remove() calls

refactor database interface for async streaming

add query key serializer for database

implement Debug for result handle

add query deserializer for database

add deserialization trait for option handle

start a stream utils suite

de-wrap/asyncify/type-query count_one_time_keys()

de-wrap/asyncify users count

add admin query users command suite

de-wrap/asyncify users exists

de-wrap/partially asyncify user filter related

asyncify/de-wrap users device/keys related

asyncify/de-wrap user auth/misc related

asyncify/de-wrap users blurhash

asyncify/de-wrap account_data get; merge Data into Service

partial asyncify/de-wrap uiaa; merge Data into Service

partially asyncify/de-wrap transaction_ids get; merge Data into Service

partially asyncify/de-wrap key_backups; merge Data into Service

asyncify/de-wrap pusher service getters; merge Data into Service

asyncify/de-wrap rooms alias getters/some iterators

asyncify/de-wrap rooms directory getters/iterator

partially asyncify/de-wrap rooms lazy-loading

partially asyncify/de-wrap rooms metadata

asyncify/dewrap rooms outlier

asyncify/dewrap rooms pdu_metadata

dewrap/partially asyncify rooms read receipt

de-wrap rooms search service

de-wrap/partially asyncify rooms user service

partial de-wrap rooms state_compressor

de-wrap rooms state_cache

de-wrap room state et al

de-wrap rooms timeline service

additional users device/keys related

de-wrap/asyncify sender

asyncify services

refactor database to TryFuture/TryStream

refactor services for TryFuture/TryStream

asyncify api handlers

additional asyncification for admin module

abstract stream related; support reverse streams

additional stream conversions

asyncify state-res related

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-08-08 17:18:30 +00:00 committed by strawberry
parent 6001014078
commit 946ca364e0
203 changed files with 12202 additions and 10709 deletions

53
Cargo.lock generated
View File

@ -626,10 +626,11 @@ dependencies = [
"clap", "clap",
"conduit_api", "conduit_api",
"conduit_core", "conduit_core",
"conduit_database",
"conduit_macros", "conduit_macros",
"conduit_service", "conduit_service",
"const-str", "const-str",
"futures-util", "futures",
"log", "log",
"ruma", "ruma",
"serde_json", "serde_json",
@ -652,7 +653,7 @@ dependencies = [
"conduit_database", "conduit_database",
"conduit_service", "conduit_service",
"const-str", "const-str",
"futures-util", "futures",
"hmac", "hmac",
"http", "http",
"http-body-util", "http-body-util",
@ -689,6 +690,7 @@ dependencies = [
"cyborgtime", "cyborgtime",
"either", "either",
"figment", "figment",
"futures",
"hardened_malloc-rs", "hardened_malloc-rs",
"http", "http",
"http-body-util", "http-body-util",
@ -726,8 +728,11 @@ version = "0.4.7"
dependencies = [ dependencies = [
"conduit_core", "conduit_core",
"const-str", "const-str",
"futures",
"log", "log",
"rust-rocksdb-uwu", "rust-rocksdb-uwu",
"serde",
"serde_json",
"tokio", "tokio",
"tracing", "tracing",
] ]
@ -784,7 +789,7 @@ dependencies = [
"conduit_core", "conduit_core",
"conduit_database", "conduit_database",
"const-str", "const-str",
"futures-util", "futures",
"hickory-resolver", "hickory-resolver",
"http", "http",
"image", "image",
@ -1283,6 +1288,20 @@ dependencies = [
"new_debug_unreachable", "new_debug_unreachable",
] ]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.31" version = "0.3.31"
@ -1345,6 +1364,7 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [ dependencies = [
"futures-channel",
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro", "futures-macro",
@ -2953,7 +2973,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma" name = "ruma"
version = "0.10.1" version = "0.10.1"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"assign", "assign",
"js_int", "js_int",
@ -2975,7 +2995,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-appservice-api" name = "ruma-appservice-api"
version = "0.10.0" version = "0.10.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -2987,7 +3007,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-client-api" name = "ruma-client-api"
version = "0.18.0" version = "0.18.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"assign", "assign",
@ -3010,7 +3030,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-common" name = "ruma-common"
version = "0.13.0" version = "0.13.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"base64 0.22.1", "base64 0.22.1",
@ -3040,7 +3060,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events" name = "ruma-events"
version = "0.28.1" version = "0.28.1"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"indexmap 2.6.0", "indexmap 2.6.0",
@ -3064,7 +3084,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-federation-api" name = "ruma-federation-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"bytes", "bytes",
"http", "http",
@ -3082,7 +3102,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-validation" name = "ruma-identifiers-validation"
version = "0.9.5" version = "0.9.5"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"js_int", "js_int",
"thiserror", "thiserror",
@ -3091,7 +3111,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identity-service-api" name = "ruma-identity-service-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -3101,7 +3121,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-macros" name = "ruma-macros"
version = "0.13.0" version = "0.13.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"once_cell", "once_cell",
@ -3117,7 +3137,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-push-gateway-api" name = "ruma-push-gateway-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -3129,7 +3149,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-server-util" name = "ruma-server-util"
version = "0.3.0" version = "0.3.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"headers", "headers",
"http", "http",
@ -3142,7 +3162,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-signatures" name = "ruma-signatures"
version = "0.15.0" version = "0.15.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"ed25519-dalek", "ed25519-dalek",
@ -3158,8 +3178,9 @@ dependencies = [
[[package]] [[package]]
name = "ruma-state-res" name = "ruma-state-res"
version = "0.11.0" version = "0.11.0"
source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" source = "git+https://github.com/girlbossceo/ruwuma?rev=e7db44989d68406393270d3a91815597385d3acb#e7db44989d68406393270d3a91815597385d3acb"
dependencies = [ dependencies = [
"futures-util",
"itertools 0.12.1", "itertools 0.12.1",
"js_int", "js_int",
"ruma-common", "ruma-common",

View File

@ -210,9 +210,10 @@ features = [
"string", "string",
] ]
[workspace.dependencies.futures-util] [workspace.dependencies.futures]
version = "0.3.30" version = "0.3.30"
default-features = false default-features = false
features = ["std"]
[workspace.dependencies.tokio] [workspace.dependencies.tokio]
version = "1.40.0" version = "1.40.0"
@ -314,7 +315,7 @@ version = "0.1.2"
[workspace.dependencies.ruma] [workspace.dependencies.ruma]
git = "https://github.com/girlbossceo/ruwuma" git = "https://github.com/girlbossceo/ruwuma"
#branch = "conduwuit-changes" #branch = "conduwuit-changes"
rev = "9900d0676564883cfade556d6e8da2a2c9061efd" rev = "e7db44989d68406393270d3a91815597385d3acb"
features = [ features = [
"compat", "compat",
"rand", "rand",
@ -463,7 +464,6 @@ version = "1.0.36"
[workspace.dependencies.proc-macro2] [workspace.dependencies.proc-macro2]
version = "1.0.89" version = "1.0.89"
# #
# Patches # Patches
# #
@ -828,6 +828,7 @@ missing_panics_doc = { level = "allow", priority = 1 }
module_name_repetitions = { level = "allow", priority = 1 } module_name_repetitions = { level = "allow", priority = 1 }
no_effect_underscore_binding = { level = "allow", priority = 1 } no_effect_underscore_binding = { level = "allow", priority = 1 }
similar_names = { level = "allow", priority = 1 } similar_names = { level = "allow", priority = 1 }
single_match_else = { level = "allow", priority = 1 }
struct_field_names = { level = "allow", priority = 1 } struct_field_names = { level = "allow", priority = 1 }
unnecessary_wraps = { level = "allow", priority = 1 } unnecessary_wraps = { level = "allow", priority = 1 }
unused_async = { level = "allow", priority = 1 } unused_async = { level = "allow", priority = 1 }

View File

@ -2,6 +2,6 @@ array-size-threshold = 4096
cognitive-complexity-threshold = 94 # TODO reduce me ALARA cognitive-complexity-threshold = 94 # TODO reduce me ALARA
excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5 excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5
future-size-threshold = 7745 # TODO reduce me ALARA future-size-threshold = 7745 # TODO reduce me ALARA
stack-size-threshold = 144000 # reduce me ALARA stack-size-threshold = 196608 # reduce me ALARA
too-many-lines-threshold = 700 # TODO reduce me to <= 100 too-many-lines-threshold = 700 # TODO reduce me to <= 100
type-complexity-threshold = 250 # reduce me to ~200 type-complexity-threshold = 250 # reduce me to ~200

View File

@ -29,10 +29,11 @@ release_max_log_level = [
clap.workspace = true clap.workspace = true
conduit-api.workspace = true conduit-api.workspace = true
conduit-core.workspace = true conduit-core.workspace = true
conduit-database.workspace = true
conduit-macros.workspace = true conduit-macros.workspace = true
conduit-service.workspace = true conduit-service.workspace = true
const-str.workspace = true const-str.workspace = true
futures-util.workspace = true futures.workspace = true
log.workspace = true log.workspace = true
ruma.workspace = true ruma.workspace = true
serde_json.workspace = true serde_json.workspace = true

View File

@ -1,5 +1,6 @@
use conduit::Result; use conduit::Result;
use conduit_macros::implement; use conduit_macros::implement;
use futures::StreamExt;
use ruma::events::room::message::RoomMessageEventContent; use ruma::events::room::message::RoomMessageEventContent;
use crate::Command; use crate::Command;
@ -10,14 +11,12 @@ use crate::Command;
#[implement(Command, params = "<'_>")] #[implement(Command, params = "<'_>")]
pub(super) async fn check_all_users(&self) -> Result<RoomMessageEventContent> { pub(super) async fn check_all_users(&self) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = self.services.users.db.iter(); let users = self.services.users.iter().collect::<Vec<_>>().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
let users = results.collect::<Vec<_>>();
let total = users.len(); let total = users.len();
let err_count = users.iter().filter(|user| user.is_err()).count(); let err_count = users.iter().filter(|_user| false).count();
let ok_count = users.iter().filter(|user| user.is_ok()).count(); let ok_count = users.iter().filter(|_user| true).count();
let message = format!( let message = format!(
"Database query completed in {query_time:?}:\n\n```\nTotal entries: {total:?}\nFailure/Invalid user count: \ "Database query completed in {query_time:?}:\n\n```\nTotal entries: {total:?}\nFailure/Invalid user count: \

View File

@ -7,6 +7,7 @@ use std::{
use api::client::validate_and_add_event_id; use api::client::validate_and_add_event_id;
use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result}; use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result};
use futures::StreamExt;
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::event::get_room_state}, api::{client::error::ErrorKind, federation::event::get_room_state},
events::room::message::RoomMessageEventContent, events::room::message::RoomMessageEventContent,
@ -27,7 +28,7 @@ pub(super) async fn echo(&self, message: Vec<String>) -> Result<RoomMessageEvent
#[admin_command] #[admin_command]
pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let event_id = Arc::<EventId>::from(event_id); let event_id = Arc::<EventId>::from(event_id);
if let Some(event) = self.services.rooms.timeline.get_pdu_json(&event_id)? { if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await {
let room_id_str = event let room_id_str = event
.get("room_id") .get("room_id")
.and_then(|val| val.as_str()) .and_then(|val| val.as_str())
@ -43,7 +44,8 @@ pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<Room
.auth_chain .auth_chain
.event_ids_iter(room_id, vec![event_id]) .event_ids_iter(room_id, vec![event_id])
.await? .await?
.count(); .count()
.await;
let elapsed = start.elapsed(); let elapsed = start.elapsed();
Ok(RoomMessageEventContent::text_plain(format!( Ok(RoomMessageEventContent::text_plain(format!(
@ -91,13 +93,16 @@ pub(super) async fn get_pdu(&self, event_id: Box<EventId>) -> Result<RoomMessage
.services .services
.rooms .rooms
.timeline .timeline
.get_non_outlier_pdu_json(&event_id)?; .get_non_outlier_pdu_json(&event_id)
if pdu_json.is_none() { .await;
if pdu_json.is_err() {
outlier = true; outlier = true;
pdu_json = self.services.rooms.timeline.get_pdu_json(&event_id)?; pdu_json = self.services.rooms.timeline.get_pdu_json(&event_id).await;
} }
match pdu_json { match pdu_json {
Some(json) => { Ok(json) => {
let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json");
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
"{}\n```json\n{}\n```", "{}\n```json\n{}\n```",
@ -109,7 +114,7 @@ pub(super) async fn get_pdu(&self, event_id: Box<EventId>) -> Result<RoomMessage
json_text json_text
))) )))
}, },
None => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), Err(_) => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")),
} }
} }
@ -157,7 +162,8 @@ pub(super) async fn get_remote_pdu_list(
.send_message(RoomMessageEventContent::text_plain(format!( .send_message(RoomMessageEventContent::text_plain(format!(
"Failed to get remote PDU, ignoring error: {e}" "Failed to get remote PDU, ignoring error: {e}"
))) )))
.await; .await
.ok();
warn!("Failed to get remote PDU, ignoring error: {e}"); warn!("Failed to get remote PDU, ignoring error: {e}");
} else { } else {
success_count = success_count.saturating_add(1); success_count = success_count.saturating_add(1);
@ -215,7 +221,9 @@ pub(super) async fn get_remote_pdu(
.services .services
.rooms .rooms
.event_handler .event_handler
.parse_incoming_pdu(&response.pdu); .parse_incoming_pdu(&response.pdu)
.await;
let (event_id, value, room_id) = match parsed_result { let (event_id, value, room_id) = match parsed_result {
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => {
@ -333,9 +341,12 @@ pub(super) async fn ping(&self, server: Box<ServerName>) -> Result<RoomMessageEv
#[admin_command] #[admin_command]
pub(super) async fn force_device_list_updates(&self) -> Result<RoomMessageEventContent> { pub(super) async fn force_device_list_updates(&self) -> Result<RoomMessageEventContent> {
// Force E2EE device list updates for all users // Force E2EE device list updates for all users
for user_id in self.services.users.iter().filter_map(Result::ok) { self.services
self.services.users.mark_device_key_update(&user_id)?; .users
} .stream()
.for_each(|user_id| self.services.users.mark_device_key_update(user_id))
.await;
Ok(RoomMessageEventContent::text_plain( Ok(RoomMessageEventContent::text_plain(
"Marked all devices for all users as having new keys to update", "Marked all devices for all users as having new keys to update",
)) ))
@ -470,7 +481,8 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Roo
.services .services
.rooms .rooms
.state_cache .state_cache
.server_in_room(&self.services.globals.config.server_name, &room_id)? .server_in_room(&self.services.globals.config.server_name, &room_id)
.await
{ {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(
"We are not participating in the room / we don't know about the room ID.", "We are not participating in the room / we don't know about the room ID.",
@ -481,8 +493,9 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Roo
.services .services
.rooms .rooms
.timeline .timeline
.first_pdu_in_room(&room_id)? .first_pdu_in_room(&room_id)
.ok_or_else(|| Error::bad_database("Failed to find the first PDU in database"))?; .await
.map_err(|_| Error::bad_database("Failed to find the first PDU in database"))?;
Ok(RoomMessageEventContent::text_plain(format!("{first_pdu:?}"))) Ok(RoomMessageEventContent::text_plain(format!("{first_pdu:?}")))
} }
@ -494,7 +507,8 @@ pub(super) async fn latest_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Ro
.services .services
.rooms .rooms
.state_cache .state_cache
.server_in_room(&self.services.globals.config.server_name, &room_id)? .server_in_room(&self.services.globals.config.server_name, &room_id)
.await
{ {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(
"We are not participating in the room / we don't know about the room ID.", "We are not participating in the room / we don't know about the room ID.",
@ -505,8 +519,9 @@ pub(super) async fn latest_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Ro
.services .services
.rooms .rooms
.timeline .timeline
.latest_pdu_in_room(&room_id)? .latest_pdu_in_room(&room_id)
.ok_or_else(|| Error::bad_database("Failed to find the latest PDU in database"))?; .await
.map_err(|_| Error::bad_database("Failed to find the latest PDU in database"))?;
Ok(RoomMessageEventContent::text_plain(format!("{latest_pdu:?}"))) Ok(RoomMessageEventContent::text_plain(format!("{latest_pdu:?}")))
} }
@ -520,7 +535,8 @@ pub(super) async fn force_set_room_state_from_server(
.services .services
.rooms .rooms
.state_cache .state_cache
.server_in_room(&self.services.globals.config.server_name, &room_id)? .server_in_room(&self.services.globals.config.server_name, &room_id)
.await
{ {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(
"We are not participating in the room / we don't know about the room ID.", "We are not participating in the room / we don't know about the room ID.",
@ -531,10 +547,11 @@ pub(super) async fn force_set_room_state_from_server(
.services .services
.rooms .rooms
.timeline .timeline
.latest_pdu_in_room(&room_id)? .latest_pdu_in_room(&room_id)
.ok_or_else(|| Error::bad_database("Failed to find the latest PDU in database"))?; .await
.map_err(|_| Error::bad_database("Failed to find the latest PDU in database"))?;
let room_version = self.services.rooms.state.get_room_version(&room_id)?; let room_version = self.services.rooms.state.get_room_version(&room_id).await?;
let mut state: HashMap<u64, Arc<EventId>> = HashMap::new(); let mut state: HashMap<u64, Arc<EventId>> = HashMap::new();
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
@ -554,13 +571,21 @@ pub(super) async fn force_set_room_state_from_server(
let mut events = Vec::with_capacity(remote_state_response.pdus.len()); let mut events = Vec::with_capacity(remote_state_response.pdus.len());
for pdu in remote_state_response.pdus.clone() { for pdu in remote_state_response.pdus.clone() {
events.push(match self.services.rooms.event_handler.parse_incoming_pdu(&pdu) { events.push(
match self
.services
.rooms
.event_handler
.parse_incoming_pdu(&pdu)
.await
{
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => {
warn!("Could not parse PDU, ignoring: {e}"); warn!("Could not parse PDU, ignoring: {e}");
continue; continue;
}, },
}); },
);
} }
info!("Fetching required signing keys for all the state events we got"); info!("Fetching required signing keys for all the state events we got");
@ -587,13 +612,16 @@ pub(super) async fn force_set_room_state_from_server(
self.services self.services
.rooms .rooms
.outlier .outlier
.add_pdu_outlier(&event_id, &value)?; .add_pdu_outlier(&event_id, &value);
if let Some(state_key) = &pdu.state_key { if let Some(state_key) = &pdu.state_key {
let shortstatekey = self let shortstatekey = self
.services .services
.rooms .rooms
.short .short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)
.await;
state.insert(shortstatekey, pdu.event_id.clone()); state.insert(shortstatekey, pdu.event_id.clone());
} }
} }
@ -611,7 +639,7 @@ pub(super) async fn force_set_room_state_from_server(
self.services self.services
.rooms .rooms
.outlier .outlier
.add_pdu_outlier(&event_id, &value)?; .add_pdu_outlier(&event_id, &value);
} }
let new_room_state = self let new_room_state = self
@ -626,7 +654,8 @@ pub(super) async fn force_set_room_state_from_server(
.services .services
.rooms .rooms
.state_compressor .state_compressor
.save_state(room_id.clone().as_ref(), new_room_state)?; .save_state(room_id.clone().as_ref(), new_room_state)
.await?;
let state_lock = self.services.rooms.state.mutex.lock(&room_id).await; let state_lock = self.services.rooms.state.mutex.lock(&room_id).await;
self.services self.services
@ -642,7 +671,8 @@ pub(super) async fn force_set_room_state_from_server(
self.services self.services
.rooms .rooms
.state_cache .state_cache
.update_joined_count(&room_id)?; .update_joined_count(&room_id)
.await;
drop(state_lock); drop(state_lock);
@ -656,7 +686,7 @@ pub(super) async fn get_signing_keys(
&self, server_name: Option<Box<ServerName>>, _cached: bool, &self, server_name: Option<Box<ServerName>>, _cached: bool,
) -> Result<RoomMessageEventContent> { ) -> Result<RoomMessageEventContent> {
let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into());
let signing_keys = self.services.globals.signing_keys_for(&server_name)?; let signing_keys = self.services.globals.signing_keys_for(&server_name).await?;
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
"```rs\n{signing_keys:#?}\n```" "```rs\n{signing_keys:#?}\n```"
@ -674,7 +704,7 @@ pub(super) async fn get_verify_keys(
if cached { if cached {
writeln!(out, "| Key ID | VerifyKey |")?; writeln!(out, "| Key ID | VerifyKey |")?;
writeln!(out, "| --- | --- |")?; writeln!(out, "| --- | --- |")?;
for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name)? { for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name).await? {
writeln!(out, "| {key_id} | {verify_key:?} |")?; writeln!(out, "| {key_id} | {verify_key:?} |")?;
} }

View File

@ -1,19 +1,20 @@
use std::fmt::Write; use std::fmt::Write;
use conduit::Result; use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId}; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId};
use crate::{admin_command, escape_html, get_room_info}; use crate::{admin_command, escape_html, get_room_info};
#[admin_command] #[admin_command]
pub(super) async fn disable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { pub(super) async fn disable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
self.services.rooms.metadata.disable_room(&room_id, true)?; self.services.rooms.metadata.disable_room(&room_id, true);
Ok(RoomMessageEventContent::text_plain("Room disabled.")) Ok(RoomMessageEventContent::text_plain("Room disabled."))
} }
#[admin_command] #[admin_command]
pub(super) async fn enable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { pub(super) async fn enable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
self.services.rooms.metadata.disable_room(&room_id, false)?; self.services.rooms.metadata.disable_room(&room_id, false);
Ok(RoomMessageEventContent::text_plain("Room enabled.")) Ok(RoomMessageEventContent::text_plain("Room enabled."))
} }
@ -85,7 +86,7 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result<
)); ));
} }
if !self.services.users.exists(&user_id)? { if !self.services.users.exists(&user_id).await {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(
"Remote user does not exist in our database.", "Remote user does not exist in our database.",
)); ));
@ -96,9 +97,9 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result<
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(Result::ok) .then(|room_id| get_room_info(self.services, room_id))
.map(|room_id| get_room_info(self.services, &room_id)) .collect()
.collect(); .await;
if rooms.is_empty() { if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); return Ok(RoomMessageEventContent::text_plain("User is not in any rooms."));

View File

@ -36,7 +36,7 @@ pub(super) async fn delete(
let mut mxc_urls = Vec::with_capacity(4); let mut mxc_urls = Vec::with_capacity(4);
// parsing the PDU for any MXC URLs begins here // parsing the PDU for any MXC URLs begins here
if let Some(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id)? { if let Ok(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id).await {
if let Some(content_key) = event_json.get("content") { if let Some(content_key) = event_json.get("content") {
debug!("Event ID has \"content\"."); debug!("Event ID has \"content\".");
let content_obj = content_key.as_object(); let content_obj = content_key.as_object();
@ -300,7 +300,7 @@ pub(super) async fn delete_all_from_server(
#[admin_command] #[admin_command]
pub(super) async fn get_file_info(&self, mxc: OwnedMxcUri) -> Result<RoomMessageEventContent> { pub(super) async fn get_file_info(&self, mxc: OwnedMxcUri) -> Result<RoomMessageEventContent> {
let mxc: Mxc<'_> = mxc.as_str().try_into()?; let mxc: Mxc<'_> = mxc.as_str().try_into()?;
let metadata = self.services.media.get_metadata(&mxc); let metadata = self.services.media.get_metadata(&mxc).await;
Ok(RoomMessageEventContent::notice_markdown(format!("```\n{metadata:#?}\n```"))) Ok(RoomMessageEventContent::notice_markdown(format!("```\n{metadata:#?}\n```")))
} }

View File

@ -17,7 +17,7 @@ use conduit::{
utils::string::{collect_stream, common_prefix}, utils::string::{collect_stream, common_prefix},
warn, Error, Result, warn, Error, Result,
}; };
use futures_util::future::FutureExt; use futures::future::FutureExt;
use ruma::{ use ruma::{
events::{ events::{
relation::InReplyTo, relation::InReplyTo,

View File

@ -44,7 +44,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services let results = services
.account_data .account_data
.changes_since(room_id.as_deref(), &user_id, since)?; .changes_since(room_id.as_deref(), &user_id, since)
.await?;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -59,7 +60,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services let results = services
.account_data .account_data
.get(room_id.as_deref(), &user_id, kind)?; .get(room_id.as_deref(), &user_id, kind)
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -29,7 +29,9 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_>
let results = services let results = services
.appservice .appservice
.db .db
.get_registration(appservice_id.as_ref()); .get_registration(appservice_id.as_ref())
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -38,7 +40,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_>
}, },
AppserviceCommand::All => { AppserviceCommand::All => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.appservice.all(); let results = services.appservice.all().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -29,7 +29,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) -
match subcommand { match subcommand {
GlobalsCommand::DatabaseVersion => { GlobalsCommand::DatabaseVersion => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.globals.db.database_version(); let results = services.globals.db.database_version().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -47,7 +47,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) -
}, },
GlobalsCommand::LastCheckForUpdatesId => { GlobalsCommand::LastCheckForUpdatesId => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.updates.last_check_for_updates_id(); let results = services.updates.last_check_for_updates_id().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -67,7 +67,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) -
origin, origin,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.globals.db.verify_keys_for(&origin); let results = services.globals.db.verify_keys_for(&origin).await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, UserId}; use ruma::{events::room::message::RoomMessageEventContent, UserId};
use crate::Command; use crate::Command;
@ -30,7 +31,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>)
user_id, user_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.presence.db.get_presence(&user_id)?; let results = services.presence.db.get_presence(&user_id).await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -42,7 +43,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>)
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.presence.db.presence_since(since); let results = services.presence.db.presence_since(since);
let presence_since: Vec<(_, _, _)> = results.collect(); let presence_since: Vec<(_, _, _)> = results.collect().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -21,7 +21,7 @@ pub(super) async fn process(subcommand: PusherCommand, context: &Command<'_>) ->
user_id, user_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.pusher.get_pushers(&user_id)?; let results = services.pusher.get_pushers(&user_id).await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId};
use crate::Command; use crate::Command;
@ -31,7 +32,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>)
alias, alias,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.rooms.alias.resolve_local_alias(&alias); let results = services.rooms.alias.resolve_local_alias(&alias).await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -43,7 +44,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>)
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.rooms.alias.local_aliases_for_room(&room_id); let results = services.rooms.alias.local_aliases_for_room(&room_id);
let aliases: Vec<_> = results.collect(); let aliases: Vec<_> = results.collect().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -52,8 +53,12 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>)
}, },
RoomAliasCommand::AllLocalAliases => { RoomAliasCommand::AllLocalAliases => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.rooms.alias.all_local_aliases(); let aliases = services
let aliases: Vec<_> = results.collect(); .rooms
.alias
.all_local_aliases()
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId}; use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId};
use crate::Command; use crate::Command;
@ -86,7 +87,11 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let result = services.rooms.state_cache.server_in_room(&server, &room_id); let result = services
.rooms
.state_cache
.server_in_room(&server, &room_id)
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -97,7 +102,13 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.room_servers(&room_id).collect(); let results: Vec<_> = services
.rooms
.state_cache
.room_servers(&room_id)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -108,7 +119,13 @@ pub(super) async fn process(
server, server,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.server_rooms(&server).collect(); let results: Vec<_> = services
.rooms
.state_cache
.server_rooms(&server)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -119,7 +136,13 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.room_members(&room_id).collect(); let results: Vec<_> = services
.rooms
.state_cache
.room_members(&room_id)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -134,7 +157,9 @@ pub(super) async fn process(
.rooms .rooms
.state_cache .state_cache
.local_users_in_room(&room_id) .local_users_in_room(&room_id)
.collect(); .map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -149,7 +174,9 @@ pub(super) async fn process(
.rooms .rooms
.state_cache .state_cache
.active_local_users_in_room(&room_id) .active_local_users_in_room(&room_id)
.collect(); .map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -160,7 +187,7 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.rooms.state_cache.room_joined_count(&room_id); let results = services.rooms.state_cache.room_joined_count(&room_id).await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -171,7 +198,11 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.rooms.state_cache.room_invited_count(&room_id); let results = services
.rooms
.state_cache
.room_invited_count(&room_id)
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -182,11 +213,13 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services let results: Vec<_> = services
.rooms .rooms
.state_cache .state_cache
.room_useroncejoined(&room_id) .room_useroncejoined(&room_id)
.collect(); .map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -197,11 +230,13 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services let results: Vec<_> = services
.rooms .rooms
.state_cache .state_cache
.room_members_invited(&room_id) .room_members_invited(&room_id)
.collect(); .map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -216,7 +251,8 @@ pub(super) async fn process(
let results = services let results = services
.rooms .rooms
.state_cache .state_cache
.get_invite_count(&room_id, &user_id); .get_invite_count(&room_id, &user_id)
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -231,7 +267,8 @@ pub(super) async fn process(
let results = services let results = services
.rooms .rooms
.state_cache .state_cache
.get_left_count(&room_id, &user_id); .get_left_count(&room_id, &user_id)
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -242,7 +279,13 @@ pub(super) async fn process(
user_id, user_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.rooms_joined(&user_id).collect(); let results: Vec<_> = services
.rooms
.state_cache
.rooms_joined(&user_id)
.map(ToOwned::to_owned)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -253,7 +296,12 @@ pub(super) async fn process(
user_id, user_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.rooms_invited(&user_id).collect(); let results: Vec<_> = services
.rooms
.state_cache
.rooms_invited(&user_id)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -264,7 +312,12 @@ pub(super) async fn process(
user_id, user_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results: Result<Vec<_>> = services.rooms.state_cache.rooms_left(&user_id).collect(); let results: Vec<_> = services
.rooms
.state_cache
.rooms_left(&user_id)
.collect()
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -276,7 +329,11 @@ pub(super) async fn process(
room_id, room_id,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.rooms.state_cache.invite_state(&user_id, &room_id); let results = services
.rooms
.state_cache
.invite_state(&user_id, &room_id)
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,5 +1,6 @@
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId}; use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId};
use service::sending::Destination; use service::sending::Destination;
@ -68,7 +69,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
SendingCommand::ActiveRequests => { SendingCommand::ActiveRequests => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.sending.db.active_requests(); let results = services.sending.db.active_requests();
let active_requests: Result<Vec<(_, _, _)>> = results.collect(); let active_requests = results.collect::<Vec<_>>().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -133,7 +134,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
}, },
}; };
let queued_requests = results.collect::<Result<Vec<(_, _)>>>(); let queued_requests = results.collect::<Vec<_>>().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -199,7 +200,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
}, },
}; };
let active_requests = results.collect::<Result<Vec<(_, _)>>>(); let active_requests = results.collect::<Vec<_>>().await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
@ -210,7 +211,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -
server_name, server_name,
} => { } => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.sending.db.get_latest_educount(&server_name); let results = services.sending.db.get_latest_educount(&server_name).await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(

View File

@ -1,29 +1,344 @@
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use ruma::events::room::message::RoomMessageEventContent; use futures::stream::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedDeviceId, OwnedRoomId, OwnedUserId};
use crate::Command; use crate::{admin_command, admin_command_dispatch};
#[admin_command_dispatch]
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
/// All the getters and iterators from src/database/key_value/users.rs /// All the getters and iterators from src/database/key_value/users.rs
pub(crate) enum UsersCommand { pub(crate) enum UsersCommand {
Iter, CountUsers,
IterUsers,
PasswordHash {
user_id: OwnedUserId,
},
ListDevices {
user_id: OwnedUserId,
},
ListDevicesMetadata {
user_id: OwnedUserId,
},
GetDeviceMetadata {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetDevicesVersion {
user_id: OwnedUserId,
},
CountOneTimeKeys {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetDeviceKeys {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetUserSigningKey {
user_id: OwnedUserId,
},
GetMasterKey {
user_id: OwnedUserId,
},
GetToDeviceEvents {
user_id: OwnedUserId,
device_id: OwnedDeviceId,
},
GetLatestBackup {
user_id: OwnedUserId,
},
GetLatestBackupVersion {
user_id: OwnedUserId,
},
GetBackupAlgorithm {
user_id: OwnedUserId,
version: String,
},
GetAllBackups {
user_id: OwnedUserId,
version: String,
},
GetRoomBackups {
user_id: OwnedUserId,
version: String,
room_id: OwnedRoomId,
},
GetBackupSession {
user_id: OwnedUserId,
version: String,
room_id: OwnedRoomId,
session_id: String,
},
} }
/// All the getters and iterators in key_value/users.rs #[admin_command]
pub(super) async fn process(subcommand: UsersCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { async fn get_backup_session(
let services = context.services; &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, session_id: String,
) -> Result<RoomMessageEventContent> {
match subcommand {
UsersCommand::Iter => {
let timer = tokio::time::Instant::now(); let timer = tokio::time::Instant::now();
let results = services.users.db.iter(); let result = self
let users = results.collect::<Vec<_>>(); .services
.key_backups
.get_session(&user_id, &version, &room_id, &session_id)
.await;
let query_time = timer.elapsed(); let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{users:#?}\n```" "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_room_backups(
&self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId,
) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.key_backups
.get_room(&user_id, &version, &room_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_all_backups(&self, user_id: OwnedUserId, version: String) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.key_backups.get_all(&user_id, &version).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_backup_algorithm(&self, user_id: OwnedUserId, version: String) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.key_backups
.get_backup(&user_id, &version)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_latest_backup_version(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.key_backups
.get_latest_backup_version(&user_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_latest_backup(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.key_backups.get_latest_backup(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn iter_users(&self) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result: Vec<OwnedUserId> = self.services.users.stream().map(Into::into).collect().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn count_users(&self) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.users.count().await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn password_hash(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.users.password_hash(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn list_devices(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let devices = self
.services
.users
.all_device_ids(&user_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```"
)))
}
#[admin_command]
async fn list_devices_metadata(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let devices = self
.services
.users
.all_devices_metadata(&user_id)
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```"
)))
}
#[admin_command]
async fn get_device_metadata(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let device = self
.services
.users
.get_device_metadata(&user_id, &device_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```"
)))
}
#[admin_command]
async fn get_devices_version(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let device = self.services.users.get_devicelist_version(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```"
)))
}
#[admin_command]
async fn count_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.count_one_time_keys(&user_id, &device_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_device_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.get_device_keys(&user_id, &device_id)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_user_signing_key(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self.services.users.get_user_signing_key(&user_id).await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_master_key(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.get_master_key(None, &user_id, &|_| true)
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
)))
}
#[admin_command]
async fn get_to_device_events(
&self, user_id: OwnedUserId, device_id: OwnedDeviceId,
) -> Result<RoomMessageEventContent> {
let timer = tokio::time::Instant::now();
let result = self
.services
.users
.get_to_device_events(&user_id, &device_id)
.collect::<Vec<_>>()
.await;
let query_time = timer.elapsed();
Ok(RoomMessageEventContent::notice_markdown(format!(
"Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```"
))) )))
},
}
} }

View File

@ -2,7 +2,8 @@ use std::fmt::Write;
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
use crate::{escape_html, Command}; use crate::{escape_html, Command};
@ -66,8 +67,8 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
force, force,
room_id, room_id,
.. ..
} => match (force, services.rooms.alias.resolve_local_alias(&room_alias)) { } => match (force, services.rooms.alias.resolve_local_alias(&room_alias).await) {
(true, Ok(Some(id))) => match services (true, Ok(id)) => match services
.rooms .rooms
.alias .alias
.set_alias(&room_alias, &room_id, server_user) .set_alias(&room_alias, &room_id, server_user)
@ -77,10 +78,10 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
))), ))),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))),
}, },
(false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!( (false, Ok(id)) => Ok(RoomMessageEventContent::text_plain(format!(
"Refusing to overwrite in use alias for {id}, use -f or --force to overwrite" "Refusing to overwrite in use alias for {id}, use -f or --force to overwrite"
))), ))),
(_, Ok(None)) => match services (_, Err(_)) => match services
.rooms .rooms
.alias .alias
.set_alias(&room_alias, &room_id, server_user) .set_alias(&room_alias, &room_id, server_user)
@ -88,12 +89,11 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")), Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))),
}, },
(_, Err(err)) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))),
}, },
RoomAliasCommand::Remove { RoomAliasCommand::Remove {
.. ..
} => match services.rooms.alias.resolve_local_alias(&room_alias) { } => match services.rooms.alias.resolve_local_alias(&room_alias).await {
Ok(Some(id)) => match services Ok(id) => match services
.rooms .rooms
.alias .alias
.remove_alias(&room_alias, server_user) .remove_alias(&room_alias, server_user)
@ -102,15 +102,13 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {id}"))), Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {id}"))),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))),
}, },
Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))),
}, },
RoomAliasCommand::Which { RoomAliasCommand::Which {
.. ..
} => match services.rooms.alias.resolve_local_alias(&room_alias) { } => match services.rooms.alias.resolve_local_alias(&room_alias).await {
Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), Ok(id) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))),
Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")),
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))),
}, },
RoomAliasCommand::List { RoomAliasCommand::List {
.. ..
@ -125,9 +123,10 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
.rooms .rooms
.alias .alias
.local_aliases_for_room(&room_id) .local_aliases_for_room(&room_id)
.collect::<Result<Vec<_>, _>>(); .map(Into::into)
match aliases { .collect::<Vec<OwnedRoomAliasId>>()
Ok(aliases) => { .await;
let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { let plain_list = aliases.iter().fold(String::new(), |mut output, alias| {
writeln!(output, "- {alias}").expect("should be able to write to string buffer"); writeln!(output, "- {alias}").expect("should be able to write to string buffer");
output output
@ -142,17 +141,15 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
let plain = format!("Aliases for {room_id}:\n{plain_list}"); let plain = format!("Aliases for {room_id}:\n{plain_list}");
let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>"); let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html)) Ok(RoomMessageEventContent::text_html(plain, html))
},
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {err}"))),
}
} else { } else {
let aliases = services let aliases = services
.rooms .rooms
.alias .alias
.all_local_aliases() .all_local_aliases()
.collect::<Result<Vec<_>, _>>(); .map(|(room_id, localpart)| (room_id.into(), localpart.into()))
match aliases { .collect::<Vec<(OwnedRoomId, String)>>()
Ok(aliases) => { .await;
let server_name = services.globals.server_name(); let server_name = services.globals.server_name();
let plain_list = aliases let plain_list = aliases
.iter() .iter()
@ -169,7 +166,7 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
output, output,
"<li><code>{}</code> -> #{}:{}</li>", "<li><code>{}</code> -> #{}:{}</li>",
escape_html(alias.as_ref()), escape_html(alias.as_ref()),
escape_html(id.as_ref()), escape_html(id),
server_name server_name
) )
.expect("should be able to write to string buffer"); .expect("should be able to write to string buffer");
@ -179,9 +176,6 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) ->
let plain = format!("Aliases:\n{plain_list}"); let plain = format!("Aliases:\n{plain_list}");
let html = format!("Aliases:\n<ul>{html_list}</ul>"); let html = format!("Aliases:\n<ul>{html_list}</ul>");
Ok(RoomMessageEventContent::text_html(plain, html)) Ok(RoomMessageEventContent::text_html(plain, html))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))),
}
} }
}, },
} }

View File

@ -1,11 +1,12 @@
use conduit::Result; use conduit::Result;
use ruma::events::room::message::RoomMessageEventContent; use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId};
use crate::{admin_command, get_room_info, PAGE_SIZE}; use crate::{admin_command, get_room_info, PAGE_SIZE};
#[admin_command] #[admin_command]
pub(super) async fn list_rooms( pub(super) async fn list_rooms(
&self, page: Option<usize>, exclude_disabled: bool, exclude_banned: bool, no_details: bool, &self, page: Option<usize>, _exclude_disabled: bool, _exclude_banned: bool, no_details: bool,
) -> Result<RoomMessageEventContent> { ) -> Result<RoomMessageEventContent> {
// TODO: i know there's a way to do this with clap, but i can't seem to find it // TODO: i know there's a way to do this with clap, but i can't seem to find it
let page = page.unwrap_or(1); let page = page.unwrap_or(1);
@ -14,37 +15,12 @@ pub(super) async fn list_rooms(
.rooms .rooms
.metadata .metadata
.iter_ids() .iter_ids()
.filter_map(|room_id| { //.filter(|room_id| async { !exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await })
room_id //.filter(|room_id| async { !exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await })
.ok() .then(|room_id| get_room_info(self.services, room_id))
.filter(|room_id| { .collect::<Vec<_>>()
if exclude_disabled .await;
&& self
.services
.rooms
.metadata
.is_disabled(room_id)
.unwrap_or(false)
{
return false;
}
if exclude_banned
&& self
.services
.rooms
.metadata
.is_banned(room_id)
.unwrap_or(false)
{
return false;
}
true
})
.map(|room_id| get_room_info(self.services, &room_id))
})
.collect::<Vec<_>>();
rooms.sort_by_key(|r| r.1); rooms.sort_by_key(|r| r.1);
rooms.reverse(); rooms.reverse();
@ -74,3 +50,10 @@ pub(super) async fn list_rooms(
Ok(RoomMessageEventContent::notice_markdown(output_plain)) Ok(RoomMessageEventContent::notice_markdown(output_plain))
} }
#[admin_command]
pub(super) async fn exists(&self, room_id: OwnedRoomId) -> Result<RoomMessageEventContent> {
let result = self.services.rooms.metadata.exists(&room_id).await;
Ok(RoomMessageEventContent::notice_markdown(format!("{result}")))
}

View File

@ -2,7 +2,8 @@ use std::fmt::Write;
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId}; use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomId};
use crate::{escape_html, get_room_info, Command, PAGE_SIZE}; use crate::{escape_html, get_room_info, Command, PAGE_SIZE};
@ -31,15 +32,15 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_>
match command { match command {
RoomDirectoryCommand::Publish { RoomDirectoryCommand::Publish {
room_id, room_id,
} => match services.rooms.directory.set_public(&room_id) { } => {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")), services.rooms.directory.set_public(&room_id);
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), Ok(RoomMessageEventContent::notice_plain("Room published"))
}, },
RoomDirectoryCommand::Unpublish { RoomDirectoryCommand::Unpublish {
room_id, room_id,
} => match services.rooms.directory.set_not_public(&room_id) { } => {
Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")), services.rooms.directory.set_not_public(&room_id);
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), Ok(RoomMessageEventContent::notice_plain("Room unpublished"))
}, },
RoomDirectoryCommand::List { RoomDirectoryCommand::List {
page, page,
@ -50,9 +51,10 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_>
.rooms .rooms
.directory .directory
.public_rooms() .public_rooms()
.filter_map(Result::ok) .then(|room_id| get_room_info(services, room_id))
.map(|id: OwnedRoomId| get_room_info(services, &id)) .collect::<Vec<_>>()
.collect::<Vec<_>>(); .await;
rooms.sort_by_key(|r| r.1); rooms.sort_by_key(|r| r.1);
rooms.reverse(); rooms.reverse();

View File

@ -1,5 +1,6 @@
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::{utils::ReadyExt, Result};
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use ruma::{events::room::message::RoomMessageEventContent, RoomId};
use crate::{admin_command, admin_command_dispatch}; use crate::{admin_command, admin_command_dispatch};
@ -32,46 +33,42 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R
.rooms .rooms
.state_accessor .state_accessor
.get_name(&room_id) .get_name(&room_id)
.ok() .await
.flatten() .unwrap_or_else(|_| room_id.to_string());
.unwrap_or_else(|| room_id.to_string());
let members = self let member_info: Vec<_> = self
.services .services
.rooms .rooms
.state_cache .state_cache
.room_members(&room_id) .room_members(&room_id)
.filter_map(|member| { .ready_filter(|user_id| {
if local_only { if local_only {
member self.services.globals.user_is_local(user_id)
.ok()
.filter(|user| self.services.globals.user_is_local(user))
} else { } else {
member.ok() true
} }
}); })
.filter_map(|user_id| async move {
let member_info = members let user_id = user_id.to_owned();
.into_iter() Some((
.map(|user_id| {
(
user_id.clone(),
self.services self.services
.users .users
.displayname(&user_id) .displayname(&user_id)
.unwrap_or(None) .await
.unwrap_or_else(|| user_id.to_string()), .unwrap_or_else(|_| user_id.to_string()),
) user_id,
))
}) })
.collect::<Vec<_>>(); .collect()
.await;
let output_plain = format!( let output_plain = format!(
"{} Members in Room \"{}\":\n```\n{}\n```", "{} Members in Room \"{}\":\n```\n{}\n```",
member_info.len(), member_info.len(),
room_name, room_name,
member_info member_info
.iter() .into_iter()
.map(|(mxid, displayname)| format!("{mxid} | {displayname}")) .map(|(displayname, mxid)| format!("{mxid} | {displayname}"))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n") .join("\n")
); );
@ -81,11 +78,12 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R
#[admin_command] #[admin_command]
async fn view_room_topic(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { async fn view_room_topic(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
let Some(room_topic) = self let Ok(room_topic) = self
.services .services
.rooms .rooms
.state_accessor .state_accessor
.get_room_topic(&room_id)? .get_room_topic(&room_id)
.await
else { else {
return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set.")); return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set."));
}; };

View File

@ -6,6 +6,7 @@ mod moderation;
use clap::Subcommand; use clap::Subcommand;
use conduit::Result; use conduit::Result;
use ruma::OwnedRoomId;
use self::{ use self::{
alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand, alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand,
@ -49,4 +50,9 @@ pub(super) enum RoomCommand {
#[command(subcommand)] #[command(subcommand)]
/// - Manage the room directory /// - Manage the room directory
Directory(RoomDirectoryCommand), Directory(RoomDirectoryCommand),
/// - Check if we know about a room
Exists {
room_id: OwnedRoomId,
},
} }

View File

@ -1,6 +1,11 @@
use api::client::leave_room; use api::client::leave_room;
use clap::Subcommand; use clap::Subcommand;
use conduit::{debug, error, info, warn, Result}; use conduit::{
debug, error, info,
utils::{IterStream, ReadyExt},
warn, Result,
};
use futures::StreamExt;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId};
use crate::{admin_command, admin_command_dispatch, get_room_info}; use crate::{admin_command, admin_command_dispatch, get_room_info};
@ -76,7 +81,7 @@ async fn ban_room(
let admin_room_alias = &self.services.globals.admin_alias; let admin_room_alias = &self.services.globals.admin_alias;
if let Some(admin_room_id) = self.services.admin.get_admin_room()? { if let Ok(admin_room_id) = self.services.admin.get_admin_room().await {
if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) { if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) {
return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room."));
} }
@ -95,7 +100,7 @@ async fn ban_room(
debug!("Room specified is a room ID, banning room ID"); debug!("Room specified is a room ID, banning room ID");
self.services.rooms.metadata.ban_room(&room_id, true)?; self.services.rooms.metadata.ban_room(&room_id, true);
room_id room_id
} else if room.is_room_alias_id() { } else if room.is_room_alias_id() {
@ -114,7 +119,13 @@ async fn ban_room(
get_alias_helper to fetch room ID remotely" get_alias_helper to fetch room ID remotely"
); );
let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { let room_id = if let Ok(room_id) = self
.services
.rooms
.alias
.resolve_local_alias(&room_alias)
.await
{
room_id room_id
} else { } else {
debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation");
@ -138,7 +149,7 @@ async fn ban_room(
} }
}; };
self.services.rooms.metadata.ban_room(&room_id, true)?; self.services.rooms.metadata.ban_room(&room_id, true);
room_id room_id
} else { } else {
@ -150,56 +161,40 @@ async fn ban_room(
debug!("Making all users leave the room {}", &room); debug!("Making all users leave the room {}", &room);
if force { if force {
for local_user in self let mut users = self
.services .services
.rooms .rooms
.state_cache .state_cache
.room_members(&room_id) .room_members(&room_id)
.filter_map(|user| { .ready_filter(|user| self.services.globals.user_is_local(user))
user.ok().filter(|local_user| { .boxed();
self.services.globals.user_is_local(local_user)
// additional wrapped check here is to avoid adding remote users while let Some(local_user) = users.next().await {
// who are in the admin room to the list of local users (would
// fail auth check)
&& (self.services.globals.user_is_local(local_user)
// since this is a force operation, assume user is an admin
// if somehow this fails
&& self.services
.users
.is_admin(local_user)
.unwrap_or(true))
})
}) {
debug!( debug!(
"Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \
&local_user, &room_id admins too)",
); );
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
warn!(%e, "Failed to leave room"); warn!(%e, "Failed to leave room");
} }
} }
} else { } else {
for local_user in self let mut users = self
.services .services
.rooms .rooms
.state_cache .state_cache
.room_members(&room_id) .room_members(&room_id)
.filter_map(|user| { .ready_filter(|user| self.services.globals.user_is_local(user))
user.ok().filter(|local_user| { .boxed();
local_user.server_name() == self.services.globals.server_name()
// additional wrapped check here is to avoid adding remote users while let Some(local_user) = users.next().await {
// who are in the admin room to the list of local users (would fail auth check) if self.services.users.is_admin(local_user).await {
&& (local_user.server_name() continue;
== self.services.globals.server_name() }
&& !self.services
.users
.is_admin(local_user)
.unwrap_or(false))
})
}) {
debug!("Attempting leave for user {} in room {}", &local_user, &room_id); debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
error!( error!(
"Error attempting to make local user {} leave room {} during room banning: {}", "Error attempting to make local user {} leave room {} during room banning: {}",
&local_user, &room_id, e &local_user, &room_id, e
@ -214,12 +209,14 @@ async fn ban_room(
} }
// remove any local aliases, ignore errors // remove any local aliases, ignore errors
for ref local_alias in self for local_alias in &self
.services .services
.rooms .rooms
.alias .alias
.local_aliases_for_room(&room_id) .local_aliases_for_room(&room_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await
{ {
_ = self _ = self
.services .services
@ -230,10 +227,10 @@ async fn ban_room(
} }
// unpublish from room directory, ignore errors // unpublish from room directory, ignore errors
_ = self.services.rooms.directory.set_not_public(&room_id); self.services.rooms.directory.set_not_public(&room_id);
if disable_federation { if disable_federation {
self.services.rooms.metadata.disable_room(&room_id, true)?; self.services.rooms.metadata.disable_room(&room_id, true);
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(
"Room banned, removed all our local users, and disabled incoming federation with room.", "Room banned, removed all our local users, and disabled incoming federation with room.",
)); ));
@ -268,7 +265,7 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
for &room in &rooms_s { for &room in &rooms_s {
match <&RoomOrAliasId>::try_from(room) { match <&RoomOrAliasId>::try_from(room) {
Ok(room_alias_or_id) => { Ok(room_alias_or_id) => {
if let Some(admin_room_id) = self.services.admin.get_admin_room()? { if let Ok(admin_room_id) = self.services.admin.get_admin_room().await {
if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) { if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) {
info!("User specified admin room in bulk ban list, ignoring"); info!("User specified admin room in bulk ban list, ignoring");
continue; continue;
@ -300,13 +297,18 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
if room_alias_or_id.is_room_alias_id() { if room_alias_or_id.is_room_alias_id() {
match RoomAliasId::parse(room_alias_or_id) { match RoomAliasId::parse(room_alias_or_id) {
Ok(room_alias) => { Ok(room_alias) => {
let room_id = let room_id = if let Ok(room_id) = self
if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { .services
.rooms
.alias
.resolve_local_alias(&room_alias)
.await
{
room_id room_id
} else { } else {
debug!( debug!(
"We don't have this room alias to a room ID locally, attempting to fetch room \ "We don't have this room alias to a room ID locally, attempting to fetch room ID \
ID over federation" over federation"
); );
match self match self
@ -374,74 +376,52 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
} }
for room_id in room_ids { for room_id in room_ids {
if self self.services.rooms.metadata.ban_room(&room_id, true);
.services
.rooms
.metadata
.ban_room(&room_id, true)
.is_ok()
{
debug!("Banned {room_id} successfully"); debug!("Banned {room_id} successfully");
room_ban_count = room_ban_count.saturating_add(1); room_ban_count = room_ban_count.saturating_add(1);
}
debug!("Making all users leave the room {}", &room_id); debug!("Making all users leave the room {}", &room_id);
if force { if force {
for local_user in self let mut users = self
.services .services
.rooms .rooms
.state_cache .state_cache
.room_members(&room_id) .room_members(&room_id)
.filter_map(|user| { .ready_filter(|user| self.services.globals.user_is_local(user))
user.ok().filter(|local_user| { .boxed();
local_user.server_name() == self.services.globals.server_name()
// additional wrapped check here is to avoid adding remote while let Some(local_user) = users.next().await {
// users who are in the admin room to the list of local
// users (would fail auth check)
&& (local_user.server_name()
== self.services.globals.server_name()
// since this is a force operation, assume user is an
// admin if somehow this fails
&& self.services
.users
.is_admin(local_user)
.unwrap_or(true))
})
}) {
debug!( debug!(
"Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \
&local_user, room_id admins too)",
); );
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await {
if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
warn!(%e, "Failed to leave room"); warn!(%e, "Failed to leave room");
} }
} }
} else { } else {
for local_user in self let mut users = self
.services .services
.rooms .rooms
.state_cache .state_cache
.room_members(&room_id) .room_members(&room_id)
.filter_map(|user| { .ready_filter(|user| self.services.globals.user_is_local(user))
user.ok().filter(|local_user| { .boxed();
local_user.server_name() == self.services.globals.server_name()
// additional wrapped check here is to avoid adding remote while let Some(local_user) = users.next().await {
// users who are in the admin room to the list of local if self.services.users.is_admin(local_user).await {
// users (would fail auth check) continue;
&& (local_user.server_name() }
== self.services.globals.server_name()
&& !self.services debug!("Attempting leave for user {local_user} in room {room_id}");
.users if let Err(e) = leave_room(self.services, local_user, &room_id, None).await {
.is_admin(local_user)
.unwrap_or(false))
})
}) {
debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await {
error!( error!(
"Error attempting to make local user {} leave room {} during bulk room banning: {}", "Error attempting to make local user {local_user} leave room {room_id} during bulk room \
&local_user, &room_id, e banning: {e}",
); );
return Ok(RoomMessageEventContent::text_plain(format!( return Ok(RoomMessageEventContent::text_plain(format!(
"Error attempting to make local user {} leave room {} during room banning (room is still \ "Error attempting to make local user {} leave room {} during room banning (room is still \
banned but not removing any more users and not banning any more rooms): {}\nIf you would \ banned but not removing any more users and not banning any more rooms): {}\nIf you would \
@ -453,26 +433,26 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu
} }
// remove any local aliases, ignore errors // remove any local aliases, ignore errors
for ref local_alias in self self.services
.services
.rooms .rooms
.alias .alias
.local_aliases_for_room(&room_id) .local_aliases_for_room(&room_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
{ .for_each(|local_alias| async move {
_ = self self.services
.services
.rooms .rooms
.alias .alias
.remove_alias(local_alias, &self.services.globals.server_user) .remove_alias(&local_alias, &self.services.globals.server_user)
.await
.ok();
})
.await; .await;
}
// unpublish from room directory, ignore errors // unpublish from room directory, ignore errors
_ = self.services.rooms.directory.set_not_public(&room_id); self.services.rooms.directory.set_not_public(&room_id);
if disable_federation { if disable_federation {
self.services.rooms.metadata.disable_room(&room_id, true)?; self.services.rooms.metadata.disable_room(&room_id, true);
} }
} }
@ -503,7 +483,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
debug!("Room specified is a room ID, unbanning room ID"); debug!("Room specified is a room ID, unbanning room ID");
self.services.rooms.metadata.ban_room(&room_id, false)?; self.services.rooms.metadata.ban_room(&room_id, false);
room_id room_id
} else if room.is_room_alias_id() { } else if room.is_room_alias_id() {
@ -522,7 +502,13 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
get_alias_helper to fetch room ID remotely" get_alias_helper to fetch room ID remotely"
); );
let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { let room_id = if let Ok(room_id) = self
.services
.rooms
.alias
.resolve_local_alias(&room_alias)
.await
{
room_id room_id
} else { } else {
debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation");
@ -546,7 +532,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
} }
}; };
self.services.rooms.metadata.ban_room(&room_id, false)?; self.services.rooms.metadata.ban_room(&room_id, false);
room_id room_id
} else { } else {
@ -557,7 +543,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
}; };
if enable_federation { if enable_federation {
self.services.rooms.metadata.disable_room(&room_id, false)?; self.services.rooms.metadata.disable_room(&room_id, false);
return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); return Ok(RoomMessageEventContent::text_plain("Room unbanned."));
} }
@ -569,23 +555,26 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) ->
#[admin_command] #[admin_command]
async fn list_banned_rooms(&self, no_details: bool) -> Result<RoomMessageEventContent> { async fn list_banned_rooms(&self, no_details: bool) -> Result<RoomMessageEventContent> {
let rooms = self let room_ids = self
.services .services
.rooms .rooms
.metadata .metadata
.list_banned_rooms() .list_banned_rooms()
.collect::<Result<Vec<_>, _>>(); .map(Into::into)
.collect::<Vec<OwnedRoomId>>()
.await;
match rooms {
Ok(room_ids) => {
if room_ids.is_empty() { if room_ids.is_empty() {
return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); return Ok(RoomMessageEventContent::text_plain("No rooms are banned."));
} }
let mut rooms = room_ids let mut rooms = room_ids
.into_iter() .iter()
.map(|room_id| get_room_info(self.services, &room_id)) .stream()
.collect::<Vec<_>>(); .then(|room_id| get_room_info(self.services, room_id))
.collect::<Vec<_>>()
.await;
rooms.sort_by_key(|r| r.1); rooms.sort_by_key(|r| r.1);
rooms.reverse(); rooms.reverse();
@ -604,10 +593,4 @@ async fn list_banned_rooms(&self, no_details: bool) -> Result<RoomMessageEventCo
); );
Ok(RoomMessageEventContent::notice_markdown(output_plain)) Ok(RoomMessageEventContent::notice_markdown(output_plain))
},
Err(e) => {
error!("Failed to list banned rooms: {e}");
Ok(RoomMessageEventContent::text_plain(format!("Unable to list banned rooms: {e}")))
},
}
} }

View File

@ -1,7 +1,9 @@
use std::{collections::BTreeMap, fmt::Write as _}; use std::{collections::BTreeMap, fmt::Write as _};
use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room}; use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room};
use conduit::{error, info, utils, warn, PduBuilder, Result}; use conduit::{error, info, is_equal_to, utils, warn, PduBuilder, Result};
use conduit_api::client::{leave_all_rooms, update_avatar_url, update_displayname};
use futures::StreamExt;
use ruma::{ use ruma::{
events::{ events::{
room::{ room::{
@ -25,16 +27,19 @@ const AUTO_GEN_PASSWORD_LENGTH: usize = 25;
#[admin_command] #[admin_command]
pub(super) async fn list_users(&self) -> Result<RoomMessageEventContent> { pub(super) async fn list_users(&self) -> Result<RoomMessageEventContent> {
match self.services.users.list_local_users() { let users = self
Ok(users) => { .services
.users
.list_local_users()
.map(ToString::to_string)
.collect::<Vec<_>>()
.await;
let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len());
plain_msg += users.join("\n").as_str(); plain_msg += users.join("\n").as_str();
plain_msg += "\n```"; plain_msg += "\n```";
Ok(RoomMessageEventContent::notice_markdown(plain_msg)) Ok(RoomMessageEventContent::notice_markdown(plain_msg))
},
Err(e) => Ok(RoomMessageEventContent::text_plain(e.to_string())),
}
} }
#[admin_command] #[admin_command]
@ -42,7 +47,7 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
// Validate user id // Validate user id
let user_id = parse_local_user_id(self.services, &username)?; let user_id = parse_local_user_id(self.services, &username)?;
if self.services.users.exists(&user_id)? { if self.services.users.exists(&user_id).await {
return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists"))); return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists")));
} }
@ -77,11 +82,12 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
self.services self.services
.users .users
.set_displayname(&user_id, Some(displayname)) .set_displayname(&user_id, Some(displayname));
.await?;
// Initial account data // Initial account data
self.services.account_data.update( self.services
.account_data
.update(
None, None,
&user_id, &user_id,
ruma::events::GlobalAccountDataEventType::PushRules ruma::events::GlobalAccountDataEventType::PushRules
@ -93,7 +99,8 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
}, },
}) })
.expect("to json value always works"), .expect("to json value always works"),
)?; )
.await?;
if !self.services.globals.config.auto_join_rooms.is_empty() { if !self.services.globals.config.auto_join_rooms.is_empty() {
for room in &self.services.globals.config.auto_join_rooms { for room in &self.services.globals.config.auto_join_rooms {
@ -101,7 +108,8 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
.services .services
.rooms .rooms
.state_cache .state_cache
.server_in_room(self.services.globals.server_name(), room)? .server_in_room(self.services.globals.server_name(), room)
.await
{ {
warn!("Skipping room {room} to automatically join as we have never joined before."); warn!("Skipping room {room} to automatically join as we have never joined before.");
continue; continue;
@ -135,13 +143,14 @@ pub(super) async fn create_user(&self, username: String, password: Option<String
// if this account creation is from the CLI / --execute, invite the first user // if this account creation is from the CLI / --execute, invite the first user
// to admin room // to admin room
if let Some(admin_room) = self.services.admin.get_admin_room()? { if let Ok(admin_room) = self.services.admin.get_admin_room().await {
if self if self
.services .services
.rooms .rooms
.state_cache .state_cache
.room_joined_count(&admin_room)? .room_joined_count(&admin_room)
== Some(1) .await
.is_ok_and(is_equal_to!(1))
{ {
self.services.admin.make_user_admin(&user_id).await?; self.services.admin.make_user_admin(&user_id).await?;
@ -167,7 +176,7 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) ->
)); ));
} }
self.services.users.deactivate_account(&user_id)?; self.services.users.deactivate_account(&user_id).await?;
if !no_leave_rooms { if !no_leave_rooms {
self.services self.services
@ -175,17 +184,22 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) ->
.send_message(RoomMessageEventContent::text_plain(format!( .send_message(RoomMessageEventContent::text_plain(format!(
"Making {user_id} leave all rooms after deactivation..." "Making {user_id} leave all rooms after deactivation..."
))) )))
.await; .await
.ok();
let all_joined_rooms: Vec<OwnedRoomId> = self let all_joined_rooms: Vec<OwnedRoomId> = self
.services .services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?;
update_displayname(self.services, &user_id, None, &all_joined_rooms).await?;
update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms).await?;
leave_all_rooms(self.services, &user_id).await;
} }
Ok(RoomMessageEventContent::text_plain(format!( Ok(RoomMessageEventContent::text_plain(format!(
@ -238,15 +252,16 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
let mut admins = Vec::new(); let mut admins = Vec::new();
for username in usernames { for username in usernames {
match parse_active_local_user_id(self.services, username) { match parse_active_local_user_id(self.services, username).await {
Ok(user_id) => { Ok(user_id) => {
if self.services.users.is_admin(&user_id)? && !force { if self.services.users.is_admin(&user_id).await && !force {
self.services self.services
.admin .admin
.send_message(RoomMessageEventContent::text_plain(format!( .send_message(RoomMessageEventContent::text_plain(format!(
"{username} is an admin and --force is not set, skipping over" "{username} is an admin and --force is not set, skipping over"
))) )))
.await; .await
.ok();
admins.push(username); admins.push(username);
continue; continue;
} }
@ -258,7 +273,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
.send_message(RoomMessageEventContent::text_plain(format!( .send_message(RoomMessageEventContent::text_plain(format!(
"{username} is the server service account, skipping over" "{username} is the server service account, skipping over"
))) )))
.await; .await
.ok();
continue; continue;
} }
@ -270,7 +286,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
.send_message(RoomMessageEventContent::text_plain(format!( .send_message(RoomMessageEventContent::text_plain(format!(
"{username} is not a valid username, skipping over: {e}" "{username} is not a valid username, skipping over: {e}"
))) )))
.await; .await
.ok();
continue; continue;
}, },
} }
@ -279,7 +296,7 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
let mut deactivation_count: usize = 0; let mut deactivation_count: usize = 0;
for user_id in user_ids { for user_id in user_ids {
match self.services.users.deactivate_account(&user_id) { match self.services.users.deactivate_account(&user_id).await {
Ok(()) => { Ok(()) => {
deactivation_count = deactivation_count.saturating_add(1); deactivation_count = deactivation_count.saturating_add(1);
if !no_leave_rooms { if !no_leave_rooms {
@ -289,16 +306,26 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) ->
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; .await;
full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?;
update_displayname(self.services, &user_id, None, &all_joined_rooms)
.await
.ok();
update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms)
.await
.ok();
leave_all_rooms(self.services, &user_id).await;
} }
}, },
Err(e) => { Err(e) => {
self.services self.services
.admin .admin
.send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}"))) .send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}")))
.await; .await
.ok();
}, },
} }
} }
@ -326,9 +353,9 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result<RoomMess
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(Result::ok) .then(|room_id| get_room_info(self.services, room_id))
.map(|room_id| get_room_info(self.services, &room_id)) .collect()
.collect(); .await;
if rooms.is_empty() { if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); return Ok(RoomMessageEventContent::text_plain("User is not in any rooms."));
@ -404,10 +431,9 @@ pub(super) async fn force_demote(
.services .services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get_content::<RoomPowerLevelsEventContent>(&room_id, &StateEventType::RoomPowerLevels, "")
.as_ref() .await
.and_then(|event| serde_json::from_str(event.content.get()).ok()?) .ok();
.and_then(|content: RoomPowerLevelsEventContent| content.into());
let user_can_demote_self = room_power_levels let user_can_demote_self = room_power_levels
.as_ref() .as_ref()
@ -417,9 +443,9 @@ pub(super) async fn force_demote(
.services .services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")? .room_state_get(&room_id, &StateEventType::RoomCreate, "")
.as_ref() .await
.is_some_and(|event| event.sender == user_id); .is_ok_and(|event| event.sender == user_id);
if !user_can_demote_self { if !user_can_demote_self {
return Ok(RoomMessageEventContent::notice_markdown( return Ok(RoomMessageEventContent::notice_markdown(
@ -473,15 +499,16 @@ pub(super) async fn make_user_admin(&self, user_id: String) -> Result<RoomMessag
pub(super) async fn put_room_tag( pub(super) async fn put_room_tag(
&self, user_id: String, room_id: Box<RoomId>, tag: String, &self, user_id: String, room_id: Box<RoomId>, tag: String,
) -> Result<RoomMessageEventContent> { ) -> Result<RoomMessageEventContent> {
let user_id = parse_active_local_user_id(self.services, &user_id)?; let user_id = parse_active_local_user_id(self.services, &user_id).await?;
let event = self let event = self
.services .services
.account_data .account_data
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else( let mut tags_event = event.map_or_else(
|| TagEvent { |_| TagEvent {
content: TagEventContent { content: TagEventContent {
tags: BTreeMap::new(), tags: BTreeMap::new(),
}, },
@ -494,12 +521,15 @@ pub(super) async fn put_room_tag(
.tags .tags
.insert(tag.clone().into(), TagInfo::new()); .insert(tag.clone().into(), TagInfo::new());
self.services.account_data.update( self.services
.account_data
.update(
Some(&room_id), Some(&room_id),
&user_id, &user_id,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"), &serde_json::to_value(tags_event).expect("to json value always works"),
)?; )
.await?;
Ok(RoomMessageEventContent::text_plain(format!( Ok(RoomMessageEventContent::text_plain(format!(
"Successfully updated room account data for {user_id} and room {room_id} with tag {tag}" "Successfully updated room account data for {user_id} and room {room_id} with tag {tag}"
@ -510,15 +540,16 @@ pub(super) async fn put_room_tag(
pub(super) async fn delete_room_tag( pub(super) async fn delete_room_tag(
&self, user_id: String, room_id: Box<RoomId>, tag: String, &self, user_id: String, room_id: Box<RoomId>, tag: String,
) -> Result<RoomMessageEventContent> { ) -> Result<RoomMessageEventContent> {
let user_id = parse_active_local_user_id(self.services, &user_id)?; let user_id = parse_active_local_user_id(self.services, &user_id).await?;
let event = self let event = self
.services .services
.account_data .account_data
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else( let mut tags_event = event.map_or_else(
|| TagEvent { |_| TagEvent {
content: TagEventContent { content: TagEventContent {
tags: BTreeMap::new(), tags: BTreeMap::new(),
}, },
@ -528,12 +559,15 @@ pub(super) async fn delete_room_tag(
tags_event.content.tags.remove(&tag.clone().into()); tags_event.content.tags.remove(&tag.clone().into());
self.services.account_data.update( self.services
.account_data
.update(
Some(&room_id), Some(&room_id),
&user_id, &user_id,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"), &serde_json::to_value(tags_event).expect("to json value always works"),
)?; )
.await?;
Ok(RoomMessageEventContent::text_plain(format!( Ok(RoomMessageEventContent::text_plain(format!(
"Successfully updated room account data for {user_id} and room {room_id}, deleting room tag {tag}" "Successfully updated room account data for {user_id} and room {room_id}, deleting room tag {tag}"
@ -542,15 +576,16 @@ pub(super) async fn delete_room_tag(
#[admin_command] #[admin_command]
pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
let user_id = parse_active_local_user_id(self.services, &user_id)?; let user_id = parse_active_local_user_id(self.services, &user_id).await?;
let event = self let event = self
.services .services
.account_data .account_data
.get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)
.await;
let tags_event = event.map_or_else( let tags_event = event.map_or_else(
|| TagEvent { |_| TagEvent {
content: TagEventContent { content: TagEventContent {
tags: BTreeMap::new(), tags: BTreeMap::new(),
}, },
@ -566,11 +601,12 @@ pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>)
#[admin_command] #[admin_command]
pub(super) async fn redact_event(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { pub(super) async fn redact_event(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let Some(event) = self let Ok(event) = self
.services .services
.rooms .rooms
.timeline .timeline
.get_non_outlier_pdu(&event_id)? .get_non_outlier_pdu(&event_id)
.await
else { else {
return Ok(RoomMessageEventContent::text_plain("Event does not exist in our database.")); return Ok(RoomMessageEventContent::text_plain("Event does not exist in our database."));
}; };

View File

@ -8,23 +8,21 @@ pub(crate) fn escape_html(s: &str) -> String {
.replace('>', "&gt;") .replace('>', "&gt;")
} }
pub(crate) fn get_room_info(services: &Services, id: &RoomId) -> (OwnedRoomId, u64, String) { pub(crate) async fn get_room_info(services: &Services, room_id: &RoomId) -> (OwnedRoomId, u64, String) {
( (
id.into(), room_id.into(),
services services
.rooms .rooms
.state_cache .state_cache
.room_joined_count(id) .room_joined_count(room_id)
.ok() .await
.flatten()
.unwrap_or(0), .unwrap_or(0),
services services
.rooms .rooms
.state_accessor .state_accessor
.get_name(id) .get_name(room_id)
.ok() .await
.flatten() .unwrap_or_else(|_| room_id.to_string()),
.unwrap_or_else(|| id.to_string()),
) )
} }
@ -46,14 +44,14 @@ pub(crate) fn parse_local_user_id(services: &Services, user_id: &str) -> Result<
} }
/// Parses user ID that is an active (not guest or deactivated) local user /// Parses user ID that is an active (not guest or deactivated) local user
pub(crate) fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { pub(crate) async fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> {
let user_id = parse_local_user_id(services, user_id)?; let user_id = parse_local_user_id(services, user_id)?;
if !services.users.exists(&user_id)? { if !services.users.exists(&user_id).await {
return Err!("User {user_id:?} does not exist on this server."); return Err!("User {user_id:?} does not exist on this server.");
} }
if services.users.is_deactivated(&user_id)? { if services.users.is_deactivated(&user_id).await? {
return Err!("User {user_id:?} is deactivated."); return Err!("User {user_id:?} is deactivated.");
} }

View File

@ -45,7 +45,7 @@ conduit-core.workspace = true
conduit-database.workspace = true conduit-database.workspace = true
conduit-service.workspace = true conduit-service.workspace = true
const-str.workspace = true const-str.workspace = true
futures-util.workspace = true futures.workspace = true
hmac.workspace = true hmac.workspace = true
http.workspace = true http.workspace = true
http-body-util.workspace = true http-body-util.workspace = true

View File

@ -2,7 +2,8 @@ use std::fmt::Write;
use axum::extract::State; use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{debug_info, error, info, utils, warn, Error, PduBuilder, Result}; use conduit::{debug_info, error, info, is_equal_to, utils, utils::ReadyExt, warn, Error, PduBuilder, Result};
use futures::{FutureExt, StreamExt};
use register::RegistrationKind; use register::RegistrationKind;
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -55,7 +56,7 @@ pub(crate) async fn get_register_available_route(
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
// Check if username is creative enough // Check if username is creative enough
if services.users.exists(&user_id)? { if services.users.exists(&user_id).await {
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
} }
@ -125,7 +126,7 @@ pub(crate) async fn register_route(
// forbid guests from registering if there is not a real admin user yet. give // forbid guests from registering if there is not a real admin user yet. give
// generic user error. // generic user error.
if is_guest && services.users.count()? < 2 { if is_guest && services.users.count().await < 2 {
warn!( warn!(
"Guest account attempted to register before a real admin user has been registered, rejecting \ "Guest account attempted to register before a real admin user has been registered, rejecting \
registration. Guest's initial device name: {:?}", registration. Guest's initial device name: {:?}",
@ -142,7 +143,7 @@ pub(crate) async fn register_route(
.filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id))
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
if services.users.exists(&proposed_user_id)? { if services.users.exists(&proposed_user_id).await {
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
} }
@ -162,7 +163,7 @@ pub(crate) async fn register_route(
services.globals.server_name(), services.globals.server_name(),
) )
.unwrap(); .unwrap();
if !services.users.exists(&proposed_user_id)? { if !services.users.exists(&proposed_user_id).await {
break proposed_user_id; break proposed_user_id;
} }
}, },
@ -210,12 +211,15 @@ pub(crate) async fn register_route(
if !skip_auth { if !skip_auth {
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services.uiaa.try_auth( let (worked, uiaainfo) = services
.uiaa
.try_auth(
&UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"),
"".into(), "".into(),
auth, auth,
&uiaainfo, &uiaainfo,
)?; )
.await?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
@ -227,7 +231,7 @@ pub(crate) async fn register_route(
"".into(), "".into(),
&uiaainfo, &uiaainfo,
&json, &json,
)?; );
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -255,11 +259,12 @@ pub(crate) async fn register_route(
services services
.users .users
.set_displayname(&user_id, Some(displayname.clone())) .set_displayname(&user_id, Some(displayname.clone()));
.await?;
// Initial account data // Initial account data
services.account_data.update( services
.account_data
.update(
None, None,
&user_id, &user_id,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -269,7 +274,8 @@ pub(crate) async fn register_route(
}, },
}) })
.expect("to json always works"), .expect("to json always works"),
)?; )
.await?;
// Inhibit login does not work for guests // Inhibit login does not work for guests
if !is_guest && body.inhibit_login { if !is_guest && body.inhibit_login {
@ -294,13 +300,16 @@ pub(crate) async fn register_route(
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Create device for this account // Create device for this account
services.users.create_device( services
.users
.create_device(
&user_id, &user_id,
&device_id, &device_id,
&token, &token,
body.initial_device_display_name.clone(), body.initial_device_display_name.clone(),
Some(client.to_string()), Some(client.to_string()),
)?; )
.await?;
debug_info!(%user_id, %device_id, "User account was created"); debug_info!(%user_id, %device_id, "User account was created");
@ -318,7 +327,8 @@ pub(crate) async fn register_route(
"New user \"{user_id}\" registered on this server from IP {client} and device display name \ "New user \"{user_id}\" registered on this server from IP {client} and device display name \
\"{device_display_name}\"" \"{device_display_name}\""
))) )))
.await; .await
.ok();
} }
} else { } else {
info!("New user \"{user_id}\" registered on this server."); info!("New user \"{user_id}\" registered on this server.");
@ -329,7 +339,8 @@ pub(crate) async fn register_route(
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"New user \"{user_id}\" registered on this server from IP {client}" "New user \"{user_id}\" registered on this server from IP {client}"
))) )))
.await; .await
.ok();
} }
} }
} }
@ -346,7 +357,8 @@ pub(crate) async fn register_route(
"Guest user \"{user_id}\" with device display name \"{device_display_name}\" registered on \ "Guest user \"{user_id}\" with device display name \"{device_display_name}\" registered on \
this server from IP {client}" this server from IP {client}"
))) )))
.await; .await
.ok();
} }
} else { } else {
#[allow(clippy::collapsible_else_if)] #[allow(clippy::collapsible_else_if)]
@ -357,7 +369,8 @@ pub(crate) async fn register_route(
"Guest user \"{user_id}\" with no device display name registered on this server from IP \ "Guest user \"{user_id}\" with no device display name registered on this server from IP \
{client}", {client}",
))) )))
.await; .await
.ok();
} }
} }
} }
@ -365,10 +378,15 @@ pub(crate) async fn register_route(
// If this is the first real user, grant them admin privileges except for guest // If this is the first real user, grant them admin privileges except for guest
// users Note: the server user, @conduit:servername, is generated first // users Note: the server user, @conduit:servername, is generated first
if !is_guest { if !is_guest {
if let Some(admin_room) = services.admin.get_admin_room()? { if let Ok(admin_room) = services.admin.get_admin_room().await {
if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { if services
.rooms
.state_cache
.room_joined_count(&admin_room)
.await
.is_ok_and(is_equal_to!(1))
{
services.admin.make_user_admin(&user_id).await?; services.admin.make_user_admin(&user_id).await?;
warn!("Granting {user_id} admin privileges as the first user"); warn!("Granting {user_id} admin privileges as the first user");
} }
} }
@ -382,7 +400,8 @@ pub(crate) async fn register_route(
if !services if !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(services.globals.server_name(), room)? .server_in_room(services.globals.server_name(), room)
.await
{ {
warn!("Skipping room {room} to automatically join as we have never joined before."); warn!("Skipping room {room} to automatically join as we have never joined before.");
continue; continue;
@ -398,6 +417,7 @@ pub(crate) async fn register_route(
None, None,
&body.appservice_info, &body.appservice_info,
) )
.boxed()
.await .await
{ {
// don't return this error so we don't fail registrations // don't return this error so we don't fail registrations
@ -461,16 +481,20 @@ pub(crate) async fn change_password_route(
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -482,14 +506,12 @@ pub(crate) async fn change_password_route(
if body.logout_devices { if body.logout_devices {
// Logout all devices except the current one // Logout all devices except the current one
for id in services services
.users .users
.all_device_ids(sender_user) .all_device_ids(sender_user)
.filter_map(Result::ok) .ready_filter(|id| id != sender_device)
.filter(|id| id != sender_device) .for_each(|id| services.users.remove_device(sender_user, id))
{ .await;
services.users.remove_device(sender_user, &id)?;
}
} }
info!("User {sender_user} changed their password."); info!("User {sender_user} changed their password.");
@ -500,7 +522,8 @@ pub(crate) async fn change_password_route(
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} changed their password." "User {sender_user} changed their password."
))) )))
.await; .await
.ok();
} }
Ok(change_password::v3::Response {}) Ok(change_password::v3::Response {})
@ -520,7 +543,7 @@ pub(crate) async fn whoami_route(
Ok(whoami::v3::Response { Ok(whoami::v3::Response {
user_id: sender_user.clone(), user_id: sender_user.clone(),
device_id, device_id,
is_guest: services.users.is_deactivated(sender_user)? && body.appservice_info.is_none(), is_guest: services.users.is_deactivated(sender_user).await? && body.appservice_info.is_none(),
}) })
} }
@ -561,7 +584,9 @@ pub(crate) async fn deactivate_route(
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
@ -570,7 +595,8 @@ pub(crate) async fn deactivate_route(
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -581,10 +607,14 @@ pub(crate) async fn deactivate_route(
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
full_user_deactivate(&services, sender_user, all_joined_rooms).await?; super::update_displayname(&services, sender_user, None, &all_joined_rooms).await?;
super::update_avatar_url(&services, sender_user, None, None, &all_joined_rooms).await?;
full_user_deactivate(&services, sender_user, &all_joined_rooms).await?;
info!("User {sender_user} deactivated their account."); info!("User {sender_user} deactivated their account.");
@ -594,7 +624,8 @@ pub(crate) async fn deactivate_route(
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} deactivated their account." "User {sender_user} deactivated their account."
))) )))
.await; .await
.ok();
} }
Ok(deactivate::v3::Response { Ok(deactivate::v3::Response {
@ -674,34 +705,27 @@ pub(crate) async fn check_registration_token_validity(
/// - Removing all profile data /// - Removing all profile data
/// - Leaving all rooms (and forgets all of them) /// - Leaving all rooms (and forgets all of them)
pub async fn full_user_deactivate( pub async fn full_user_deactivate(
services: &Services, user_id: &UserId, all_joined_rooms: Vec<OwnedRoomId>, services: &Services, user_id: &UserId, all_joined_rooms: &[OwnedRoomId],
) -> Result<()> { ) -> Result<()> {
services.users.deactivate_account(user_id)?; services.users.deactivate_account(user_id).await?;
super::update_displayname(services, user_id, None, all_joined_rooms).await?;
super::update_avatar_url(services, user_id, None, None, all_joined_rooms).await?;
super::update_displayname(services, user_id, None, all_joined_rooms.clone()).await?; services
super::update_avatar_url(services, user_id, None, None, all_joined_rooms.clone()).await?;
let all_profile_keys = services
.users .users
.all_profile_keys(user_id) .all_profile_keys(user_id)
.filter_map(Result::ok); .ready_for_each(|(profile_key, _)| services.users.set_profile_key(user_id, &profile_key, None))
.await;
for (profile_key, _profile_value) in all_profile_keys {
if let Err(e) = services.users.set_profile_key(user_id, &profile_key, None) {
warn!("Failed removing {user_id} profile key {profile_key}: {e}");
}
}
for room_id in all_joined_rooms { for room_id in all_joined_rooms {
let state_lock = services.rooms.state.mutex.lock(&room_id).await; let state_lock = services.rooms.state.mutex.lock(room_id).await;
let room_power_levels = services let room_power_levels = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "")
.as_ref() .await
.and_then(|event| serde_json::from_str(event.content.get()).ok()?) .ok();
.and_then(|content: RoomPowerLevelsEventContent| content.into());
let user_can_demote_self = room_power_levels let user_can_demote_self = room_power_levels
.as_ref() .as_ref()
@ -710,9 +734,9 @@ pub async fn full_user_deactivate(
}) || services }) || services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")? .room_state_get(room_id, &StateEventType::RoomCreate, "")
.as_ref() .await
.is_some_and(|event| event.sender == user_id); .is_ok_and(|event| event.sender == user_id);
if user_can_demote_self { if user_can_demote_self {
let mut power_levels_content = room_power_levels.unwrap_or_default(); let mut power_levels_content = room_power_levels.unwrap_or_default();
@ -732,7 +756,7 @@ pub async fn full_user_deactivate(
timestamp: None, timestamp: None,
}, },
user_id, user_id,
&room_id, room_id,
&state_lock, &state_lock,
) )
.await .await

View File

@ -1,11 +1,9 @@
use axum::extract::State; use axum::extract::State;
use conduit::{debug, Error, Result}; use conduit::{debug, Err, Result};
use futures::StreamExt;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use ruma::{ use ruma::{
api::client::{ api::client::alias::{create_alias, delete_alias, get_alias},
alias::{create_alias, delete_alias, get_alias},
error::ErrorKind,
},
OwnedServerName, RoomAliasId, RoomId, OwnedServerName, RoomAliasId, RoomId,
}; };
use service::Services; use service::Services;
@ -33,16 +31,17 @@ pub(crate) async fn create_alias_route(
.forbidden_alias_names() .forbidden_alias_names()
.is_match(body.room_alias.alias()) .is_match(body.room_alias.alias())
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden.")); return Err!(Request(Forbidden("Room alias is forbidden.")));
} }
if services if services
.rooms .rooms
.alias .alias
.resolve_local_alias(&body.room_alias)? .resolve_local_alias(&body.room_alias)
.is_some() .await
.is_ok()
{ {
return Err(Error::Conflict("Alias already exists.")); return Err!(Conflict("Alias already exists."));
} }
services services
@ -95,16 +94,16 @@ pub(crate) async fn get_alias_route(
.resolve_alias(&room_alias, servers.as_ref()) .resolve_alias(&room_alias, servers.as_ref())
.await .await
else { else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); return Err!(Request(NotFound("Room with alias not found.")));
}; };
let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers); let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers).await;
debug!(?room_alias, ?room_id, "available servers: {servers:?}"); debug!(?room_alias, ?room_id, "available servers: {servers:?}");
Ok(get_alias::v3::Response::new(room_id, servers)) Ok(get_alias::v3::Response::new(room_id, servers))
} }
fn room_available_servers( async fn room_available_servers(
services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>, services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>,
) -> Vec<OwnedServerName> { ) -> Vec<OwnedServerName> {
// find active servers in room state cache to suggest // find active servers in room state cache to suggest
@ -112,8 +111,9 @@ fn room_available_servers(
.rooms .rooms
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(); .collect()
.await;
// push any servers we want in the list already (e.g. responded remote alias // push any servers we want in the list already (e.g. responded remote alias
// servers, room alias server itself) // servers, room alias server itself)

View File

@ -1,18 +1,16 @@
use axum::extract::State; use axum::extract::State;
use conduit::{err, Err};
use ruma::{ use ruma::{
api::client::{ api::client::backup::{
backup::{
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version,
delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version,
get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
get_latest_backup_info, update_backup_version, get_latest_backup_info, update_backup_version,
}, },
error::ErrorKind,
},
UInt, UInt,
}; };
use crate::{Error, Result, Ruma}; use crate::{Result, Ruma};
/// # `POST /_matrix/client/r0/room_keys/version` /// # `POST /_matrix/client/r0/room_keys/version`
/// ///
@ -40,7 +38,8 @@ pub(crate) async fn update_backup_version_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services services
.key_backups .key_backups
.update_backup(sender_user, &body.version, &body.algorithm)?; .update_backup(sender_user, &body.version, &body.algorithm)
.await?;
Ok(update_backup_version::v3::Response {}) Ok(update_backup_version::v3::Response {})
} }
@ -55,14 +54,15 @@ pub(crate) async fn get_latest_backup_info_route(
let (version, algorithm) = services let (version, algorithm) = services
.key_backups .key_backups
.get_latest_backup(sender_user)? .get_latest_backup(sender_user)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; .await
.map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?;
Ok(get_latest_backup_info::v3::Response { Ok(get_latest_backup_info::v3::Response {
algorithm, algorithm,
count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?) count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version).await)
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services.key_backups.get_etag(sender_user, &version)?, etag: services.key_backups.get_etag(sender_user, &version).await,
version, version,
}) })
} }
@ -76,18 +76,21 @@ pub(crate) async fn get_backup_info_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let algorithm = services let algorithm = services
.key_backups .key_backups
.get_backup(sender_user, &body.version)? .get_backup(sender_user, &body.version)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; .await
.map_err(|_| err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))))?;
Ok(get_backup_info::v3::Response { Ok(get_backup_info::v3::Response {
algorithm, algorithm,
count: (UInt::try_from( count: services
services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)
) .await
.expect("user backup keys count should not be that high")), .try_into()?,
etag: services.key_backups.get_etag(sender_user, &body.version)?, etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
version: body.version.clone(), version: body.version.clone(),
}) })
} }
@ -105,7 +108,8 @@ pub(crate) async fn delete_backup_version_route(
services services
.key_backups .key_backups
.delete_backup(sender_user, &body.version)?; .delete_backup(sender_user, &body.version)
.await;
Ok(delete_backup_version::v3::Response {}) Ok(delete_backup_version::v3::Response {})
} }
@ -123,34 +127,36 @@ pub(crate) async fn add_backup_keys_route(
) -> Result<add_backup_keys::v3::Response> { ) -> Result<add_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if services
!= services
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)
.as_ref() .await
.is_ok_and(|version| version != body.version)
{ {
return Err(Error::BadRequest( return Err!(Request(InvalidParam(
ErrorKind::InvalidParam, "You may only manipulate the most recently created version of the backup."
"You may only manipulate the most recently created version of the backup.", )));
));
} }
for (room_id, room) in &body.rooms { for (room_id, room) in &body.rooms {
for (session_id, key_data) in &room.sessions { for (session_id, key_data) in &room.sessions {
services services
.key_backups .key_backups
.add_key(sender_user, &body.version, room_id, session_id, key_data)?; .add_key(sender_user, &body.version, room_id, session_id, key_data)
.await?;
} }
} }
Ok(add_backup_keys::v3::Response { Ok(add_backup_keys::v3::Response {
count: (UInt::try_from( count: services
services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)
) .await
.expect("user backup keys count should not be that high")), .try_into()?,
etag: services.key_backups.get_etag(sender_user, &body.version)?, etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
}) })
} }
@ -167,32 +173,34 @@ pub(crate) async fn add_backup_keys_for_room_route(
) -> Result<add_backup_keys_for_room::v3::Response> { ) -> Result<add_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if services
!= services
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)
.as_ref() .await
.is_ok_and(|version| version != body.version)
{ {
return Err(Error::BadRequest( return Err!(Request(InvalidParam(
ErrorKind::InvalidParam, "You may only manipulate the most recently created version of the backup."
"You may only manipulate the most recently created version of the backup.", )));
));
} }
for (session_id, key_data) in &body.sessions { for (session_id, key_data) in &body.sessions {
services services
.key_backups .key_backups
.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; .add_key(sender_user, &body.version, &body.room_id, session_id, key_data)
.await?;
} }
Ok(add_backup_keys_for_room::v3::Response { Ok(add_backup_keys_for_room::v3::Response {
count: (UInt::try_from( count: services
services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)
) .await
.expect("user backup keys count should not be that high")), .try_into()?,
etag: services.key_backups.get_etag(sender_user, &body.version)?, etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
}) })
} }
@ -209,30 +217,32 @@ pub(crate) async fn add_backup_keys_for_session_route(
) -> Result<add_backup_keys_for_session::v3::Response> { ) -> Result<add_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if services
!= services
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)
.as_ref() .await
.is_ok_and(|version| version != body.version)
{ {
return Err(Error::BadRequest( return Err!(Request(InvalidParam(
ErrorKind::InvalidParam, "You may only manipulate the most recently created version of the backup."
"You may only manipulate the most recently created version of the backup.", )));
));
} }
services services
.key_backups .key_backups
.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)
.await?;
Ok(add_backup_keys_for_session::v3::Response { Ok(add_backup_keys_for_session::v3::Response {
count: (UInt::try_from( count: services
services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)
) .await
.expect("user backup keys count should not be that high")), .try_into()?,
etag: services.key_backups.get_etag(sender_user, &body.version)?, etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
}) })
} }
@ -244,7 +254,10 @@ pub(crate) async fn get_backup_keys_route(
) -> Result<get_backup_keys::v3::Response> { ) -> Result<get_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let rooms = services.key_backups.get_all(sender_user, &body.version)?; let rooms = services
.key_backups
.get_all(sender_user, &body.version)
.await;
Ok(get_backup_keys::v3::Response { Ok(get_backup_keys::v3::Response {
rooms, rooms,
@ -261,7 +274,8 @@ pub(crate) async fn get_backup_keys_for_room_route(
let sessions = services let sessions = services
.key_backups .key_backups
.get_room(sender_user, &body.version, &body.room_id)?; .get_room(sender_user, &body.version, &body.room_id)
.await;
Ok(get_backup_keys_for_room::v3::Response { Ok(get_backup_keys_for_room::v3::Response {
sessions, sessions,
@ -278,8 +292,9 @@ pub(crate) async fn get_backup_keys_for_session_route(
let key_data = services let key_data = services
.key_backups .key_backups
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .get_session(sender_user, &body.version, &body.room_id, &body.session_id)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?; .await
.map_err(|_| err!(Request(NotFound(debug_error!("Backup key not found for this user's session.")))))?;
Ok(get_backup_keys_for_session::v3::Response { Ok(get_backup_keys_for_session::v3::Response {
key_data, key_data,
@ -296,16 +311,19 @@ pub(crate) async fn delete_backup_keys_route(
services services
.key_backups .key_backups
.delete_all_keys(sender_user, &body.version)?; .delete_all_keys(sender_user, &body.version)
.await;
Ok(delete_backup_keys::v3::Response { Ok(delete_backup_keys::v3::Response {
count: (UInt::try_from( count: services
services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)
) .await
.expect("user backup keys count should not be that high")), .try_into()?,
etag: services.key_backups.get_etag(sender_user, &body.version)?, etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
}) })
} }
@ -319,16 +337,19 @@ pub(crate) async fn delete_backup_keys_for_room_route(
services services
.key_backups .key_backups
.delete_room_keys(sender_user, &body.version, &body.room_id)?; .delete_room_keys(sender_user, &body.version, &body.room_id)
.await;
Ok(delete_backup_keys_for_room::v3::Response { Ok(delete_backup_keys_for_room::v3::Response {
count: (UInt::try_from( count: services
services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)
) .await
.expect("user backup keys count should not be that high")), .try_into()?,
etag: services.key_backups.get_etag(sender_user, &body.version)?, etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
}) })
} }
@ -342,15 +363,18 @@ pub(crate) async fn delete_backup_keys_for_session_route(
services services
.key_backups .key_backups
.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)
.await;
Ok(delete_backup_keys_for_session::v3::Response { Ok(delete_backup_keys_for_session::v3::Response {
count: (UInt::try_from( count: services
services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)
) .await
.expect("user backup keys count should not be that high")), .try_into()?,
etag: services.key_backups.get_etag(sender_user, &body.version)?, etag: services
.key_backups
.get_etag(sender_user, &body.version)
.await,
}) })
} }

View File

@ -1,4 +1,5 @@
use axum::extract::State; use axum::extract::State;
use conduit::err;
use ruma::{ use ruma::{
api::client::{ api::client::{
config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data}, config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
@ -25,7 +26,8 @@ pub(crate) async fn set_global_account_data_route(
&body.sender_user, &body.sender_user,
&body.event_type.to_string(), &body.event_type.to_string(),
body.data.json(), body.data.json(),
)?; )
.await?;
Ok(set_global_account_data::v3::Response {}) Ok(set_global_account_data::v3::Response {})
} }
@ -42,7 +44,8 @@ pub(crate) async fn set_room_account_data_route(
&body.sender_user, &body.sender_user,
&body.event_type.to_string(), &body.event_type.to_string(),
body.data.json(), body.data.json(),
)?; )
.await?;
Ok(set_room_account_data::v3::Response {}) Ok(set_room_account_data::v3::Response {})
} }
@ -57,8 +60,9 @@ pub(crate) async fn get_global_account_data_route(
let event: Box<RawJsonValue> = services let event: Box<RawJsonValue> = services
.account_data .account_data
.get(None, sender_user, body.event_type.to_string().into())? .get(None, sender_user, body.event_type.to_string().into())
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .await
.map_err(|_| err!(Request(NotFound("Data not found."))))?;
let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get()) let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -79,8 +83,9 @@ pub(crate) async fn get_room_account_data_route(
let event: Box<RawJsonValue> = services let event: Box<RawJsonValue> = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, body.event_type.clone())? .get(Some(&body.room_id), sender_user, body.event_type.clone())
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .await
.map_err(|_| err!(Request(NotFound("Data not found."))))?;
let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get()) let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -91,7 +96,7 @@ pub(crate) async fn get_room_account_data_route(
}) })
} }
fn set_account_data( async fn set_account_data(
services: &Services, room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str, services: &Services, room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str,
data: &RawJsonValue, data: &RawJsonValue,
) -> Result<()> { ) -> Result<()> {
@ -100,7 +105,9 @@ fn set_account_data(
let data: serde_json::Value = let data: serde_json::Value =
serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
services.account_data.update( services
.account_data
.update(
room_id, room_id,
sender_user, sender_user,
event_type.into(), event_type.into(),
@ -108,7 +115,8 @@ fn set_account_data(
"type": event_type, "type": event_type,
"content": data, "content": data,
}), }),
)?; )
.await?;
Ok(()) Ok(())
} }

View File

@ -1,13 +1,14 @@
use std::collections::HashSet; use std::collections::HashSet;
use axum::extract::State; use axum::extract::State;
use conduit::{err, error, Err};
use futures::StreamExt;
use ruma::{ use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, api::client::{context::get_context, filter::LazyLoadOptions},
events::StateEventType, events::StateEventType,
}; };
use tracing::error;
use crate::{Error, Result, Ruma}; use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// # `GET /_matrix/client/r0/rooms/{roomId}/context`
/// ///
@ -35,34 +36,33 @@ pub(crate) async fn get_context_route(
let base_token = services let base_token = services
.rooms .rooms
.timeline .timeline
.get_pdu_count(&body.event_id)? .get_pdu_count(&body.event_id)
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; .await
.map_err(|_| err!(Request(NotFound("Base event id not found."))))?;
let base_event = services let base_event = services
.rooms .rooms
.timeline .timeline
.get_pdu(&body.event_id)? .get_pdu(&body.event_id)
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?; .await
.map_err(|_| err!(Request(NotFound("Base event not found."))))?;
let room_id = base_event.room_id.clone(); let room_id = &base_event.room_id;
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &body.event_id)? .user_can_see_event(sender_user, room_id, &body.event_id)
.await
{ {
return Err(Error::BadRequest( return Err!(Request(Forbidden("You don't have permission to view this event.")));
ErrorKind::forbidden(),
"You don't have permission to view this event.",
));
} }
if !services.rooms.lazy_loading.lazy_load_was_sent_before( if !services
sender_user, .rooms
sender_device, .lazy_loading
&room_id, .lazy_load_was_sent_before(sender_user, sender_device, room_id, &base_event.sender)
&base_event.sender, .await || lazy_load_send_redundant
)? || lazy_load_send_redundant
{ {
lazy_loaded.insert(base_event.sender.as_str().to_owned()); lazy_loaded.insert(base_event.sender.as_str().to_owned());
} }
@ -75,25 +75,26 @@ pub(crate) async fn get_context_route(
let events_before: Vec<_> = services let events_before: Vec<_> = services
.rooms .rooms
.timeline .timeline
.pdus_until(sender_user, &room_id, base_token)? .pdus_until(sender_user, room_id, base_token)
.await?
.take(limit / 2) .take(limit / 2)
.filter_map(Result::ok) // Remove buggy events .filter_map(|(count, pdu)| async move {
.filter(|(_, pdu)| {
services services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .await
.then_some((count, pdu))
}) })
.collect(); .collect()
.await;
for (_, event) in &events_before { for (_, event) in &events_before {
if !services.rooms.lazy_loading.lazy_load_was_sent_before( if !services
sender_user, .rooms
sender_device, .lazy_loading
&room_id, .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
&event.sender, .await || lazy_load_send_redundant
)? || lazy_load_send_redundant
{ {
lazy_loaded.insert(event.sender.as_str().to_owned()); lazy_loaded.insert(event.sender.as_str().to_owned());
} }
@ -111,25 +112,26 @@ pub(crate) async fn get_context_route(
let events_after: Vec<_> = services let events_after: Vec<_> = services
.rooms .rooms
.timeline .timeline
.pdus_after(sender_user, &room_id, base_token)? .pdus_after(sender_user, room_id, base_token)
.await?
.take(limit / 2) .take(limit / 2)
.filter_map(Result::ok) // Remove buggy events .filter_map(|(count, pdu)| async move {
.filter(|(_, pdu)| {
services services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .await
.then_some((count, pdu))
}) })
.collect(); .collect()
.await;
for (_, event) in &events_after { for (_, event) in &events_after {
if !services.rooms.lazy_loading.lazy_load_was_sent_before( if !services
sender_user, .rooms
sender_device, .lazy_loading
&room_id, .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
&event.sender, .await || lazy_load_send_redundant
)? || lazy_load_send_redundant
{ {
lazy_loaded.insert(event.sender.as_str().to_owned()); lazy_loaded.insert(event.sender.as_str().to_owned());
} }
@ -142,12 +144,14 @@ pub(crate) async fn get_context_route(
events_after events_after
.last() .last()
.map_or(&*body.event_id, |(_, e)| &*e.event_id), .map_or(&*body.event_id, |(_, e)| &*e.event_id),
)? )
.await
.map_or( .map_or(
services services
.rooms .rooms
.state .state
.get_room_shortstatehash(&room_id)? .get_room_shortstatehash(room_id)
.await
.expect("All rooms have state"), .expect("All rooms have state"),
|hash| hash, |hash| hash,
); );
@ -156,7 +160,8 @@ pub(crate) async fn get_context_route(
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await?; .await
.map_err(|e| err!(Database("State not found: {e}")))?;
let end_token = events_after let end_token = events_after
.last() .last()
@ -173,18 +178,19 @@ pub(crate) async fn get_context_route(
let (event_type, state_key) = services let (event_type, state_key) = services
.rooms .rooms
.short .short
.get_statekey_from_short(shortstatekey)?; .get_statekey_from_short(shortstatekey)
.await?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {id}");
continue; continue;
}; };
state.push(pdu.to_state_event()); state.push(pdu.to_state_event());
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {id}");
continue; continue;
}; };

View File

@ -1,4 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduit::{err, Err};
use futures::StreamExt;
use ruma::api::client::{ use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
error::ErrorKind, error::ErrorKind,
@ -19,8 +21,8 @@ pub(crate) async fn get_devices_route(
let devices: Vec<device::Device> = services let devices: Vec<device::Device> = services
.users .users
.all_devices_metadata(sender_user) .all_devices_metadata(sender_user)
.filter_map(Result::ok) // Filter out buggy devices .collect()
.collect(); .await;
Ok(get_devices::v3::Response { Ok(get_devices::v3::Response {
devices, devices,
@ -37,8 +39,9 @@ pub(crate) async fn get_device_route(
let device = services let device = services
.users .users
.get_device_metadata(sender_user, &body.body.device_id)? .get_device_metadata(sender_user, &body.body.device_id)
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .await
.map_err(|_| err!(Request(NotFound("Device not found."))))?;
Ok(get_device::v3::Response { Ok(get_device::v3::Response {
device, device,
@ -55,14 +58,16 @@ pub(crate) async fn update_device_route(
let mut device = services let mut device = services
.users .users
.get_device_metadata(sender_user, &body.device_id)? .get_device_metadata(sender_user, &body.device_id)
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .await
.map_err(|_| err!(Request(NotFound("Device not found."))))?;
device.display_name.clone_from(&body.display_name); device.display_name.clone_from(&body.display_name);
services services
.users .users
.update_device_metadata(sender_user, &body.device_id, &device)?; .update_device_metadata(sender_user, &body.device_id, &device)
.await?;
Ok(update_device::v3::Response {}) Ok(update_device::v3::Response {})
} }
@ -97,22 +102,28 @@ pub(crate) async fn delete_device_route(
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err!(Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo));
return Err!(Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err!(Request(NotJson("Not json.")));
} }
services.users.remove_device(sender_user, &body.device_id)?; services
.users
.remove_device(sender_user, &body.device_id)
.await;
Ok(delete_device::v3::Response {}) Ok(delete_device::v3::Response {})
} }
@ -149,7 +160,9 @@ pub(crate) async fn delete_devices_route(
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
@ -158,14 +171,15 @@ pub(crate) async fn delete_devices_route(
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
for device_id in &body.devices { for device_id in &body.devices {
services.users.remove_device(sender_user, device_id)?; services.users.remove_device(sender_user, device_id).await;
} }
Ok(delete_devices::v3::Response {}) Ok(delete_devices::v3::Response {})

View File

@ -1,6 +1,7 @@
use axum::extract::State; use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{err, info, warn, Err, Error, Result}; use conduit::{info, warn, Err, Error, Result};
use futures::{StreamExt, TryFutureExt};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -18,7 +19,7 @@ use ruma::{
}, },
StateEventType, StateEventType,
}, },
uint, RoomId, ServerName, UInt, UserId, uint, OwnedRoomId, RoomId, ServerName, UInt, UserId,
}; };
use service::Services; use service::Services;
@ -119,16 +120,22 @@ pub(crate) async fn set_room_visibility_route(
) -> Result<set_room_visibility::v3::Response> { ) -> Result<set_room_visibility::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services.rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id).await {
// Return 404 if the room doesn't exist // Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
} }
if services.users.is_deactivated(sender_user).unwrap_or(false) && body.appservice_info.is_none() { if services
.users
.is_deactivated(sender_user)
.await
.unwrap_or(false)
&& body.appservice_info.is_none()
{
return Err!(Request(Forbidden("Guests cannot publish to room directories"))); return Err!(Request(Forbidden("Guests cannot publish to room directories")));
} }
if !user_can_publish_room(&services, sender_user, &body.room_id)? { if !user_can_publish_room(&services, sender_user, &body.room_id).await? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"User is not allowed to publish this room", "User is not allowed to publish this room",
@ -138,7 +145,7 @@ pub(crate) async fn set_room_visibility_route(
match &body.visibility { match &body.visibility {
room::Visibility::Public => { room::Visibility::Public => {
if services.globals.config.lockdown_public_room_directory if services.globals.config.lockdown_public_room_directory
&& !services.users.is_admin(sender_user)? && !services.users.is_admin(sender_user).await
&& body.appservice_info.is_none() && body.appservice_info.is_none()
{ {
info!( info!(
@ -164,7 +171,7 @@ pub(crate) async fn set_room_visibility_route(
)); ));
} }
services.rooms.directory.set_public(&body.room_id)?; services.rooms.directory.set_public(&body.room_id);
if services.globals.config.admin_room_notices { if services.globals.config.admin_room_notices {
services services
@ -174,7 +181,7 @@ pub(crate) async fn set_room_visibility_route(
} }
info!("{sender_user} made {0} public to the room directory", body.room_id); info!("{sender_user} made {0} public to the room directory", body.room_id);
}, },
room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?, room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id),
_ => { _ => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -192,13 +199,13 @@ pub(crate) async fn set_room_visibility_route(
pub(crate) async fn get_room_visibility_route( pub(crate) async fn get_room_visibility_route(
State(services): State<crate::State>, body: Ruma<get_room_visibility::v3::Request>, State(services): State<crate::State>, body: Ruma<get_room_visibility::v3::Request>,
) -> Result<get_room_visibility::v3::Response> { ) -> Result<get_room_visibility::v3::Response> {
if !services.rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id).await {
// Return 404 if the room doesn't exist // Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
} }
Ok(get_room_visibility::v3::Response { Ok(get_room_visibility::v3::Response {
visibility: if services.rooms.directory.is_public_room(&body.room_id)? { visibility: if services.rooms.directory.is_public_room(&body.room_id).await {
room::Visibility::Public room::Visibility::Public
} else { } else {
room::Visibility::Private room::Visibility::Private
@ -257,101 +264,41 @@ pub(crate) async fn get_public_rooms_filtered_helper(
} }
} }
let mut all_rooms: Vec<_> = services let mut all_rooms: Vec<PublicRoomsChunk> = services
.rooms .rooms
.directory .directory
.public_rooms() .public_rooms()
.map(|room_id| { .map(ToOwned::to_owned)
let room_id = room_id?; .then(|room_id| public_rooms_chunk(services, room_id))
.filter_map(|chunk| async move {
let chunk = PublicRoomsChunk {
canonical_alias: services
.rooms
.state_accessor
.get_canonical_alias(&room_id)?,
name: services.rooms.state_accessor.get_name(&room_id)?,
num_joined_members: services
.rooms
.state_cache
.room_joined_count(&room_id)?
.unwrap_or_else(|| {
warn!("Room {} has no member count", room_id);
0
})
.try_into()
.expect("user count should not be that big"),
topic: services
.rooms
.state_accessor
.get_room_topic(&room_id)
.unwrap_or(None),
world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?,
guest_can_join: services
.rooms
.state_accessor
.guest_can_join(&room_id)?,
avatar_url: services
.rooms
.state_accessor
.get_avatar(&room_id)?
.into_option()
.unwrap_or_default()
.url,
join_rule: services
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| match c.join_rule {
JoinRule::Public => Some(PublicRoomJoinRule::Public),
JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
_ => None,
})
.map_err(|e| {
err!(Database(error!("Invalid room join rule event in database: {e}")))
})
})
.transpose()?
.flatten()
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
room_type: services
.rooms
.state_accessor
.get_room_type(&room_id)?,
room_id,
};
Ok(chunk)
})
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
.filter(|chunk| {
if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) { if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) {
if let Some(name) = &chunk.name { if let Some(name) = &chunk.name {
if name.as_str().to_lowercase().contains(&query) { if name.as_str().to_lowercase().contains(&query) {
return true; return Some(chunk);
} }
} }
if let Some(topic) = &chunk.topic { if let Some(topic) = &chunk.topic {
if topic.to_lowercase().contains(&query) { if topic.to_lowercase().contains(&query) {
return true; return Some(chunk);
} }
} }
if let Some(canonical_alias) = &chunk.canonical_alias { if let Some(canonical_alias) = &chunk.canonical_alias {
if canonical_alias.as_str().to_lowercase().contains(&query) { if canonical_alias.as_str().to_lowercase().contains(&query) {
return true; return Some(chunk);
} }
} }
false return None;
} else {
// No search term
true
} }
// No search term
Some(chunk)
}) })
// We need to collect all, so we can sort by member count // We need to collect all, so we can sort by member count
.collect(); .collect()
.await;
all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members));
@ -394,22 +341,23 @@ pub(crate) async fn get_public_rooms_filtered_helper(
/// Check whether the user can publish to the room directory via power levels of /// Check whether the user can publish to the room directory via power levels of
/// room history visibility event or room creator /// room history visibility event or room creator
fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> { async fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
if let Some(event) = services if let Ok(event) = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")
.await
{ {
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get())
.map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels"))
.map(|content: RoomPowerLevelsEventContent| { .map(|content: RoomPowerLevelsEventContent| {
RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility)
}) })
} else if let Some(event) = } else if let Ok(event) = services
services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")? .room_state_get(room_id, &StateEventType::RoomCreate, "")
.await
{ {
Ok(event.sender == user_id) Ok(event.sender == user_id)
} else { } else {
@ -419,3 +367,61 @@ fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId
)); ));
} }
} }
async fn public_rooms_chunk(services: &Services, room_id: OwnedRoomId) -> PublicRoomsChunk {
PublicRoomsChunk {
canonical_alias: services
.rooms
.state_accessor
.get_canonical_alias(&room_id)
.await
.ok(),
name: services.rooms.state_accessor.get_name(&room_id).await.ok(),
num_joined_members: services
.rooms
.state_cache
.room_joined_count(&room_id)
.await
.unwrap_or(0)
.try_into()
.expect("joined count overflows ruma UInt"),
topic: services
.rooms
.state_accessor
.get_room_topic(&room_id)
.await
.ok(),
world_readable: services
.rooms
.state_accessor
.is_world_readable(&room_id)
.await,
guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await,
avatar_url: services
.rooms
.state_accessor
.get_avatar(&room_id)
.await
.into_option()
.unwrap_or_default()
.url,
join_rule: services
.rooms
.state_accessor
.room_state_get_content(&room_id, &StateEventType::RoomJoinRules, "")
.map_ok(|c: RoomJoinRulesEventContent| match c.join_rule {
JoinRule::Public => PublicRoomJoinRule::Public,
JoinRule::Knock => PublicRoomJoinRule::Knock,
_ => "invite".into(),
})
.await
.unwrap_or_default(),
room_type: services
.rooms
.state_accessor
.get_room_type(&room_id)
.await
.ok(),
room_id,
}
}

View File

@ -1,10 +1,8 @@
use axum::extract::State; use axum::extract::State;
use ruma::api::client::{ use conduit::err;
error::ErrorKind, use ruma::api::client::filter::{create_filter, get_filter};
filter::{create_filter, get_filter},
};
use crate::{Error, Result, Ruma}; use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
/// ///
@ -15,11 +13,13 @@ pub(crate) async fn get_filter_route(
State(services): State<crate::State>, body: Ruma<get_filter::v3::Request>, State(services): State<crate::State>, body: Ruma<get_filter::v3::Request>,
) -> Result<get_filter::v3::Response> { ) -> Result<get_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found."));
};
Ok(get_filter::v3::Response::new(filter)) services
.users
.get_filter(sender_user, &body.filter_id)
.await
.map(get_filter::v3::Response::new)
.map_err(|_| err!(Request(NotFound("Filter not found."))))
} }
/// # `PUT /_matrix/client/r0/user/{userId}/filter` /// # `PUT /_matrix/client/r0/user/{userId}/filter`
@ -29,7 +29,8 @@ pub(crate) async fn create_filter_route(
State(services): State<crate::State>, body: Ruma<create_filter::v3::Request>, State(services): State<crate::State>, body: Ruma<create_filter::v3::Request>,
) -> Result<create_filter::v3::Response> { ) -> Result<create_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(create_filter::v3::Response::new(
services.users.create_filter(sender_user, &body.filter)?, let filter_id = services.users.create_filter(sender_user, &body.filter);
))
Ok(create_filter::v3::Response::new(filter_id))
} }

View File

@ -4,8 +4,8 @@ use std::{
}; };
use axum::extract::State; use axum::extract::State;
use conduit::{utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; use conduit::{err, utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -21,7 +21,10 @@ use ruma::{
use serde_json::json; use serde_json::json;
use super::SESSION_ID_LENGTH; use super::SESSION_ID_LENGTH;
use crate::{service::Services, Ruma}; use crate::{
service::{users::parse_master_key, Services},
Ruma,
};
/// # `POST /_matrix/client/r0/keys/upload` /// # `POST /_matrix/client/r0/keys/upload`
/// ///
@ -39,7 +42,8 @@ pub(crate) async fn upload_keys_route(
for (key_key, key_value) in &body.one_time_keys { for (key_key, key_value) in &body.one_time_keys {
services services
.users .users
.add_one_time_key(sender_user, sender_device, key_key, key_value)?; .add_one_time_key(sender_user, sender_device, key_key, key_value)
.await?;
} }
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
@ -47,19 +51,22 @@ pub(crate) async fn upload_keys_route(
// This check is needed to assure that signatures are kept // This check is needed to assure that signatures are kept
if services if services
.users .users
.get_device_keys(sender_user, sender_device)? .get_device_keys(sender_user, sender_device)
.is_none() .await
.is_err()
{ {
services services
.users .users
.add_device_keys(sender_user, sender_device, device_keys)?; .add_device_keys(sender_user, sender_device, device_keys)
.await;
} }
} }
Ok(upload_keys::v3::Response { Ok(upload_keys::v3::Response {
one_time_key_counts: services one_time_key_counts: services
.users .users
.count_one_time_keys(sender_user, sender_device)?, .count_one_time_keys(sender_user, sender_device)
.await,
}) })
} }
@ -120,7 +127,9 @@ pub(crate) async fn upload_signing_keys_route(
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
@ -129,20 +138,24 @@ pub(crate) async fn upload_signing_keys_route(
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json);
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
if let Some(master_key) = &body.master_key { if let Some(master_key) = &body.master_key {
services.users.add_cross_signing_keys( services
.users
.add_cross_signing_keys(
sender_user, sender_user,
master_key, master_key,
&body.self_signing_key, &body.self_signing_key,
&body.user_signing_key, &body.user_signing_key,
true, // notify so that other users see the new keys true, // notify so that other users see the new keys
)?; )
.await?;
} }
Ok(upload_signing_keys::v3::Response {}) Ok(upload_signing_keys::v3::Response {})
@ -179,9 +192,11 @@ pub(crate) async fn upload_signatures_route(
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
.to_owned(), .to_owned(),
); );
services services
.users .users
.sign_key(user_id, key_id, signature, sender_user)?; .sign_key(user_id, key_id, signature, sender_user)
.await?;
} }
} }
} }
@ -204,56 +219,51 @@ pub(crate) async fn get_key_changes_route(
let mut device_list_updates = HashSet::new(); let mut device_list_updates = HashSet::new();
let from = body
.from
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?;
let to = body
.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?;
device_list_updates.extend( device_list_updates.extend(
services services
.users .users
.keys_changed( .keys_changed(sender_user.as_str(), from, Some(to))
sender_user.as_str(), .map(ToOwned::to_owned)
body.from .collect::<Vec<_>>()
.parse() .await,
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
Some(
body.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
)
.filter_map(Result::ok),
); );
for room_id in services let mut rooms_joined = services.rooms.state_cache.rooms_joined(sender_user).boxed();
.rooms
.state_cache while let Some(room_id) = rooms_joined.next().await {
.rooms_joined(sender_user)
.filter_map(Result::ok)
{
device_list_updates.extend( device_list_updates.extend(
services services
.users .users
.keys_changed( .keys_changed(room_id.as_ref(), from, Some(to))
room_id.as_ref(), .map(ToOwned::to_owned)
body.from .collect::<Vec<_>>()
.parse() .await,
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
Some(
body.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
)
.filter_map(Result::ok),
); );
} }
Ok(get_key_changes::v3::Response { Ok(get_key_changes::v3::Response {
changed: device_list_updates.into_iter().collect(), changed: device_list_updates.into_iter().collect(),
left: Vec::new(), // TODO left: Vec::new(), // TODO
}) })
} }
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( pub(crate) async fn get_keys_helper<F>(
services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
allowed_signatures: F, include_display_names: bool, allowed_signatures: F, include_display_names: bool,
) -> Result<get_keys::v3::Response> { ) -> Result<get_keys::v3::Response>
where
F: Fn(&UserId) -> bool + Send + Sync,
{
let mut master_keys = BTreeMap::new(); let mut master_keys = BTreeMap::new();
let mut self_signing_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new();
let mut user_signing_keys = BTreeMap::new(); let mut user_signing_keys = BTreeMap::new();
@ -274,56 +284,60 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
if device_ids.is_empty() { if device_ids.is_empty() {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for device_id in services.users.all_device_ids(user_id) { let mut devices = services.users.all_device_ids(user_id).boxed();
let device_id = device_id?;
if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? { while let Some(device_id) = devices.next().await {
if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await {
let metadata = services let metadata = services
.users .users
.get_device_metadata(user_id, &device_id)? .get_device_metadata(user_id, device_id)
.ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; .await
.map_err(|_| err!(Database("all_device_keys contained nonexistent device.")))?;
add_unsigned_device_display_name(&mut keys, metadata, include_display_names) add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
.map_err(|_| Error::bad_database("invalid device keys in database"))?; .map_err(|_| err!(Database("invalid device keys in database")))?;
container.insert(device_id, keys); container.insert(device_id.to_owned(), keys);
} }
} }
device_keys.insert(user_id.to_owned(), container); device_keys.insert(user_id.to_owned(), container);
} else { } else {
for device_id in device_ids { for device_id in device_ids {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? { if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await {
let metadata = services let metadata = services
.users .users
.get_device_metadata(user_id, device_id)? .get_device_metadata(user_id, device_id)
.ok_or(Error::BadRequest( .await
ErrorKind::InvalidParam, .map_err(|_| err!(Request(InvalidParam("Tried to get keys for nonexistent device."))))?;
"Tried to get keys for nonexistent device.",
))?;
add_unsigned_device_display_name(&mut keys, metadata, include_display_names) add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
.map_err(|_| Error::bad_database("invalid device keys in database"))?; .map_err(|_| err!(Database("invalid device keys in database")))?;
container.insert(device_id.to_owned(), keys); container.insert(device_id.to_owned(), keys);
} }
device_keys.insert(user_id.to_owned(), container); device_keys.insert(user_id.to_owned(), container);
} }
} }
if let Some(master_key) = services if let Ok(master_key) = services
.users .users
.get_master_key(sender_user, user_id, &allowed_signatures)? .get_master_key(sender_user, user_id, &allowed_signatures)
.await
{ {
master_keys.insert(user_id.to_owned(), master_key); master_keys.insert(user_id.to_owned(), master_key);
} }
if let Some(self_signing_key) = if let Ok(self_signing_key) = services
services
.users .users
.get_self_signing_key(sender_user, user_id, &allowed_signatures)? .get_self_signing_key(sender_user, user_id, &allowed_signatures)
.await
{ {
self_signing_keys.insert(user_id.to_owned(), self_signing_key); self_signing_keys.insert(user_id.to_owned(), self_signing_key);
} }
if Some(user_id) == sender_user { if Some(user_id) == sender_user {
if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? { if let Ok(user_signing_key) = services.users.get_user_signing_key(user_id).await {
user_signing_keys.insert(user_id.to_owned(), user_signing_key); user_signing_keys.insert(user_id.to_owned(), user_signing_key);
} }
} }
@ -386,23 +400,26 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
while let Some((server, response)) = futures.next().await { while let Some((server, response)) = futures.next().await {
if let Ok(Ok(response)) = response { if let Ok(Ok(response)) = response {
for (user, masterkey) in response.master_keys { for (user, masterkey) in response.master_keys {
let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?; let (master_key_id, mut master_key) = parse_master_key(&user, &masterkey)?;
if let Some(our_master_key) = if let Ok(our_master_key) = services
services
.users .users
.get_key(&master_key_id, sender_user, &user, &allowed_signatures)? .get_key(&master_key_id, sender_user, &user, &allowed_signatures)
.await
{ {
let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?; let (_, our_master_key) = parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures); master_key.signatures.extend(our_master_key.signatures);
} }
let json = serde_json::to_value(master_key).expect("to_value always works"); let json = serde_json::to_value(master_key).expect("to_value always works");
let raw = serde_json::from_value(json).expect("Raw::from_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works");
services.users.add_cross_signing_keys( services
.users
.add_cross_signing_keys(
&user, &raw, &None, &None, &user, &raw, &None, &None,
false, /* Dont notify. A notification would trigger another key request resulting in an false, /* Dont notify. A notification would trigger another key request resulting in an
* endless loop */ * endless loop */
)?; )
.await?;
master_keys.insert(user.clone(), raw); master_keys.insert(user.clone(), raw);
} }
@ -465,9 +482,10 @@ pub(crate) async fn claim_keys_helper(
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map { for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = services if let Ok(one_time_keys) = services
.users .users
.take_one_time_key(user_id, device_id, key_algorithm)? .take_one_time_key(user_id, device_id, key_algorithm)
.await
{ {
let mut c = BTreeMap::new(); let mut c = BTreeMap::new();
c.insert(one_time_keys.0, one_time_keys.1); c.insert(one_time_keys.0, one_time_keys.1);

View File

@ -11,9 +11,10 @@ use conduit::{
debug, debug_error, debug_warn, err, error, info, debug, debug_error, debug_warn, err, error, info,
pdu::{gen_event_id_canonical_json, PduBuilder}, pdu::{gen_event_id_canonical_json, PduBuilder},
trace, utils, trace, utils,
utils::math::continue_exponential_backoff_secs, utils::{math::continue_exponential_backoff_secs, IterStream, ReadyExt},
warn, Err, Error, PduEvent, Result, warn, Err, Error, PduEvent, Result,
}; };
use futures::{FutureExt, StreamExt};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -55,9 +56,9 @@ async fn banned_room_check(
services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>, services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>,
client_ip: IpAddr, client_ip: IpAddr,
) -> Result<()> { ) -> Result<()> {
if !services.users.is_admin(user_id)? { if !services.users.is_admin(user_id).await {
if let Some(room_id) = room_id { if let Some(room_id) = room_id {
if services.rooms.metadata.is_banned(room_id)? if services.rooms.metadata.is_banned(room_id).await
|| services || services
.globals .globals
.config .config
@ -79,23 +80,22 @@ async fn banned_room_check(
"Automatically deactivating user {user_id} due to attempted banned room join from IP \ "Automatically deactivating user {user_id} due to attempted banned room join from IP \
{client_ip}" {client_ip}"
))) )))
.await; .await
.ok();
} }
let all_joined_rooms: Vec<OwnedRoomId> = services let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(user_id) .rooms_joined(user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
full_user_deactivate(services, user_id, all_joined_rooms).await?; full_user_deactivate(services, user_id, &all_joined_rooms).await?;
} }
return Err(Error::BadRequest( return Err!(Request(Forbidden("This room is banned on this homeserver.")));
ErrorKind::forbidden(),
"This room is banned on this homeserver.",
));
} }
} else if let Some(server_name) = server_name { } else if let Some(server_name) = server_name {
if services if services
@ -119,23 +119,22 @@ async fn banned_room_check(
"Automatically deactivating user {user_id} due to attempted banned room join from IP \ "Automatically deactivating user {user_id} due to attempted banned room join from IP \
{client_ip}" {client_ip}"
))) )))
.await; .await
.ok();
} }
let all_joined_rooms: Vec<OwnedRoomId> = services let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(user_id) .rooms_joined(user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
full_user_deactivate(services, user_id, all_joined_rooms).await?; full_user_deactivate(services, user_id, &all_joined_rooms).await?;
} }
return Err(Error::BadRequest( return Err!(Request(Forbidden("This remote server is banned on this homeserver.")));
ErrorKind::forbidden(),
"This remote server is banned on this homeserver.",
));
} }
} }
} }
@ -172,14 +171,16 @@ pub(crate) async fn join_room_by_id_route(
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(&body.room_id) .servers_invite_via(&body.room_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect::<Vec<_>>(); .collect::<Vec<_>>()
.await;
servers.extend( servers.extend(
services services
.rooms .rooms
.state_cache .state_cache
.invite_state(sender_user, &body.room_id)? .invite_state(sender_user, &body.room_id)
.await
.unwrap_or_default() .unwrap_or_default()
.iter() .iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok()) .filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -202,6 +203,7 @@ pub(crate) async fn join_room_by_id_route(
body.third_party_signed.as_ref(), body.third_party_signed.as_ref(),
&body.appservice_info, &body.appservice_info,
) )
.boxed()
.await .await
} }
@ -233,14 +235,17 @@ pub(crate) async fn join_room_by_id_or_alias_route(
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(&room_id) .servers_invite_via(&room_id)
.filter_map(Result::ok), .map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
); );
servers.extend( servers.extend(
services services
.rooms .rooms
.state_cache .state_cache
.invite_state(sender_user, &room_id)? .invite_state(sender_user, &room_id)
.await
.unwrap_or_default() .unwrap_or_default()
.iter() .iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok()) .filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -270,19 +275,23 @@ pub(crate) async fn join_room_by_id_or_alias_route(
if let Some(pre_servers) = &mut pre_servers { if let Some(pre_servers) = &mut pre_servers {
servers.append(pre_servers); servers.append(pre_servers);
} }
servers.extend( servers.extend(
services services
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(&room_id) .servers_invite_via(&room_id)
.filter_map(Result::ok), .map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await,
); );
servers.extend( servers.extend(
services services
.rooms .rooms
.state_cache .state_cache
.invite_state(sender_user, &room_id)? .invite_state(sender_user, &room_id)
.await
.unwrap_or_default() .unwrap_or_default()
.iter() .iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok()) .filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -305,6 +314,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
body.third_party_signed.as_ref(), body.third_party_signed.as_ref(),
appservice_info, appservice_info,
) )
.boxed()
.await?; .await?;
Ok(join_room_by_id_or_alias::v3::Response { Ok(join_room_by_id_or_alias::v3::Response {
@ -337,7 +347,7 @@ pub(crate) async fn invite_user_route(
) -> Result<invite_user::v3::Response> { ) -> Result<invite_user::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services.users.is_admin(sender_user)? && services.globals.block_non_admin_invites() { if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() {
info!( info!(
"User {sender_user} is not an admin and attempted to send an invite to room {}", "User {sender_user} is not an admin and attempted to send an invite to room {}",
&body.room_id &body.room_id
@ -375,15 +385,13 @@ pub(crate) async fn kick_user_route(
services services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())
.ok_or(Error::BadRequest( .await
ErrorKind::BadState, .map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))?
"Cannot kick member that's not in the room.",
))?
.content .content
.get(), .get(),
) )
.map_err(|_| Error::bad_database("Invalid member event in database."))?; .map_err(|_| err!(Database("Invalid member event in database.")))?;
event.membership = MembershipState::Leave; event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason); event.reason.clone_from(&body.reason);
@ -421,10 +429,13 @@ pub(crate) async fn ban_user_route(
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let blurhash = services.users.blurhash(&body.user_id).await.ok();
let event = services let event = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())
.await
.map_or( .map_or(
Ok(RoomMemberEventContent { Ok(RoomMemberEventContent {
membership: MembershipState::Ban, membership: MembershipState::Ban,
@ -432,7 +443,7 @@ pub(crate) async fn ban_user_route(
avatar_url: None, avatar_url: None,
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), blurhash: blurhash.clone(),
reason: body.reason.clone(), reason: body.reason.clone(),
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}), }),
@ -442,12 +453,12 @@ pub(crate) async fn ban_user_route(
membership: MembershipState::Ban, membership: MembershipState::Ban,
displayname: None, displayname: None,
avatar_url: None, avatar_url: None,
blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), blurhash: blurhash.clone(),
reason: body.reason.clone(), reason: body.reason.clone(),
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
..event ..event
}) })
.map_err(|_| Error::bad_database("Invalid member event in database.")) .map_err(|e| err!(Database("Invalid member event in database: {e:?}")))
}, },
)?; )?;
@ -488,12 +499,13 @@ pub(crate) async fn unban_user_route(
services services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())
.ok_or(Error::BadRequest(ErrorKind::BadState, "Cannot unban a user who is not banned."))? .await
.map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))?
.content .content
.get(), .get(),
) )
.map_err(|_| Error::bad_database("Invalid member event in database."))?; .map_err(|e| err!(Database("Invalid member event in database: {e:?}")))?;
event.membership = MembershipState::Leave; event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason); event.reason.clone_from(&body.reason);
@ -539,18 +551,16 @@ pub(crate) async fn forget_room_route(
if services if services
.rooms .rooms
.state_cache .state_cache
.is_joined(sender_user, &body.room_id)? .is_joined(sender_user, &body.room_id)
.await
{ {
return Err(Error::BadRequest( return Err!(Request(Unknown("You must leave the room before forgetting it")));
ErrorKind::Unknown,
"You must leave the room before forgetting it",
));
} }
services services
.rooms .rooms
.state_cache .state_cache
.forget(&body.room_id, sender_user)?; .forget(&body.room_id, sender_user);
Ok(forget_room::v3::Response::new()) Ok(forget_room::v3::Response::new())
} }
@ -568,8 +578,9 @@ pub(crate) async fn joined_rooms_route(
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(), .collect()
.await,
}) })
} }
@ -587,12 +598,10 @@ pub(crate) async fn get_member_events_route(
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)
.await
{ {
return Err(Error::BadRequest( return Err!(Request(Forbidden("You don't have permission to view this room.")));
ErrorKind::forbidden(),
"You don't have permission to view this room.",
));
} }
Ok(get_member_events::v3::Response { Ok(get_member_events::v3::Response {
@ -622,30 +631,27 @@ pub(crate) async fn joined_members_route(
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)
.await
{ {
return Err(Error::BadRequest( return Err!(Request(Forbidden("You don't have permission to view this room.")));
ErrorKind::forbidden(),
"You don't have permission to view this room.",
));
} }
let joined: BTreeMap<OwnedUserId, RoomMember> = services let joined: BTreeMap<OwnedUserId, RoomMember> = services
.rooms .rooms
.state_cache .state_cache
.room_members(&body.room_id) .room_members(&body.room_id)
.filter_map(|user| { .then(|user| async move {
let user = user.ok()?; (
user.to_owned(),
Some((
user.clone(),
RoomMember { RoomMember {
display_name: services.users.displayname(&user).unwrap_or_default(), display_name: services.users.displayname(user).await.ok(),
avatar_url: services.users.avatar_url(&user).unwrap_or_default(), avatar_url: services.users.avatar_url(user).await.ok(),
}, },
)) )
}) })
.collect(); .collect()
.await;
Ok(joined_members::v3::Response { Ok(joined_members::v3::Response {
joined, joined,
@ -658,13 +664,23 @@ pub async fn join_room_by_id_helper(
) -> Result<join_room_by_id::v3::Response> { ) -> Result<join_room_by_id::v3::Response> {
let state_lock = services.rooms.state.mutex.lock(room_id).await; let state_lock = services.rooms.state.mutex.lock(room_id).await;
let user_is_guest = services.users.is_deactivated(sender_user).unwrap_or(false) && appservice_info.is_none(); let user_is_guest = services
.users
.is_deactivated(sender_user)
.await
.unwrap_or(false)
&& appservice_info.is_none();
if matches!(services.rooms.state_accessor.guest_can_join(room_id), Ok(false)) && user_is_guest { if user_is_guest && !services.rooms.state_accessor.guest_can_join(room_id).await {
return Err!(Request(Forbidden("Guests are not allowed to join this room"))); return Err!(Request(Forbidden("Guests are not allowed to join this room")));
} }
if matches!(services.rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { if services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
debug_warn!("{sender_user} is already joined in {room_id}"); debug_warn!("{sender_user} is already joined in {room_id}");
return Ok(join_room_by_id::v3::Response { return Ok(join_room_by_id::v3::Response {
room_id: room_id.into(), room_id: room_id.into(),
@ -674,15 +690,17 @@ pub async fn join_room_by_id_helper(
if services if services
.rooms .rooms
.state_cache .state_cache
.server_in_room(services.globals.server_name(), room_id)? .server_in_room(services.globals.server_name(), room_id)
|| servers.is_empty() .await || servers.is_empty()
|| (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) || (servers.len() == 1 && services.globals.server_is_ours(&servers[0]))
{ {
join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock)
.boxed()
.await .await
} else { } else {
// Ask a remote server if we are not participating in this room // Ask a remote server if we are not participating in this room
join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock)
.boxed()
.await .await
} }
} }
@ -739,11 +757,11 @@ async fn join_room_by_id_helper_remote(
"content".to_owned(), "content".to_owned(),
to_canonical_value(RoomMemberEventContent { to_canonical_value(RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?, displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user)?, avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?, blurhash: services.users.blurhash(sender_user).await.ok(),
reason, reason,
join_authorized_via_users_server: join_authorized_via_users_server.clone(), join_authorized_via_users_server: join_authorized_via_users_server.clone(),
}) })
@ -791,10 +809,11 @@ async fn join_room_by_id_helper_remote(
federation::membership::create_join_event::v2::Request { federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
event_id: event_id.to_owned(), event_id: event_id.to_owned(),
omit_members: false,
pdu: services pdu: services
.sending .sending
.convert_to_outgoing_federation_event(join_event.clone()), .convert_to_outgoing_federation_event(join_event.clone())
omit_members: false, .await,
}, },
) )
.await?; .await?;
@ -864,7 +883,11 @@ async fn join_room_by_id_helper_remote(
} }
} }
services.rooms.short.get_or_create_shortroomid(room_id)?; services
.rooms
.short
.get_or_create_shortroomid(room_id)
.await;
info!("Parsing join event"); info!("Parsing join event");
let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone())
@ -895,12 +918,13 @@ async fn join_room_by_id_helper_remote(
err!(BadServerResponse("Invalid PDU in send_join response: {e:?}")) err!(BadServerResponse("Invalid PDU in send_join response: {e:?}"))
})?; })?;
services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; services.rooms.outlier.add_pdu_outlier(&event_id, &value);
if let Some(state_key) = &pdu.state_key { if let Some(state_key) = &pdu.state_key {
let shortstatekey = services let shortstatekey = services
.rooms .rooms
.short .short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)
.await;
state.insert(shortstatekey, pdu.event_id.clone()); state.insert(shortstatekey, pdu.event_id.clone());
} }
} }
@ -916,50 +940,53 @@ async fn join_room_by_id_helper_remote(
continue; continue;
}; };
services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; services.rooms.outlier.add_pdu_outlier(&event_id, &value);
} }
debug!("Running send_join auth check"); debug!("Running send_join auth check");
let fetch_state = &state;
let state_fetch = |k: &'static StateEventType, s: String| async move {
let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?;
let event_id = fetch_state.get(&shortstatekey)?;
services.rooms.timeline.get_pdu(event_id).await.ok()
};
let auth_check = state_res::event_auth::auth_check( let auth_check = state_res::event_auth::auth_check(
&state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"),
&parsed_join_pdu, &parsed_join_pdu,
None::<PduEvent>, // TODO: third party invite None, // TODO: third party invite
|k, s| { |k, s| state_fetch(k, s.to_owned()),
services
.rooms
.timeline
.get_pdu(
state.get(
&services
.rooms
.short
.get_or_create_shortstatekey(&k.to_string().into(), s)
.ok()?,
)?,
) )
.ok()? .await
}, .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;
)
.map_err(|e| {
warn!("Auth check failed: {e}");
Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")
})?;
if !auth_check { if !auth_check {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); return Err!(Request(Forbidden("Auth check failed")));
} }
info!("Saving state from send_join"); info!("Saving state from send_join");
let (statehash_before_join, new, removed) = services.rooms.state_compressor.save_state( let (statehash_before_join, new, removed) = services
.rooms
.state_compressor
.save_state(
room_id, room_id,
Arc::new( Arc::new(
state state
.into_iter() .into_iter()
.map(|(k, id)| services.rooms.state_compressor.compress_state_event(k, &id)) .stream()
.collect::<Result<_>>()?, .then(|(k, id)| async move {
services
.rooms
.state_compressor
.compress_state_event(k, &id)
.await
})
.collect()
.await,
), ),
)?; )
.await?;
services services
.rooms .rooms
@ -968,12 +995,20 @@ async fn join_room_by_id_helper_remote(
.await?; .await?;
info!("Updating joined counts for new room"); info!("Updating joined counts for new room");
services.rooms.state_cache.update_joined_count(room_id)?; services
.rooms
.state_cache
.update_joined_count(room_id)
.await;
// We append to state before appending the pdu, so we don't have a moment in // We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't // time with the pdu without it's state. This is okay because append_pdu can't
// fail. // fail.
let statehash_after_join = services.rooms.state.append_to_state(&parsed_join_pdu)?; let statehash_after_join = services
.rooms
.state
.append_to_state(&parsed_join_pdu)
.await?;
info!("Appending new room join event"); info!("Appending new room join event");
services services
@ -993,7 +1028,7 @@ async fn join_room_by_id_helper_remote(
services services
.rooms .rooms
.state .state
.set_room_state(room_id, statehash_after_join, &state_lock)?; .set_room_state(room_id, statehash_after_join, &state_lock);
Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) Ok(join_room_by_id::v3::Response::new(room_id.to_owned()))
} }
@ -1005,23 +1040,15 @@ async fn join_room_by_id_helper_local(
) -> Result<join_room_by_id::v3::Response> { ) -> Result<join_room_by_id::v3::Response> {
debug!("We can join locally"); debug!("We can join locally");
let join_rules_event = services let join_rules_event_content = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "")
.await
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event .map(|content: RoomJoinRulesEventContent| content);
.as_ref()
.map(|join_rules_event| {
serde_json::from_str(join_rules_event.content.get()).map_err(|e| {
warn!("Invalid join rules event: {}", e);
Error::bad_database("Invalid join rules event in db.")
})
})
.transpose()?;
let restriction_rooms = match join_rules_event_content { let restriction_rooms = match join_rules_event_content {
Some(RoomJoinRulesEventContent { Ok(RoomJoinRulesEventContent {
join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted),
}) => restricted }) => restricted
.allow .allow
@ -1034,29 +1061,34 @@ async fn join_room_by_id_helper_local(
_ => Vec::new(), _ => Vec::new(),
}; };
let local_members = services let local_members: Vec<_> = services
.rooms .rooms
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
.filter_map(Result::ok) .ready_filter(|user| services.globals.user_is_local(user))
.filter(|user| services.globals.user_is_local(user)) .map(ToOwned::to_owned)
.collect::<Vec<OwnedUserId>>(); .collect()
.await;
let mut join_authorized_via_users_server: Option<OwnedUserId> = None; let mut join_authorized_via_users_server: Option<OwnedUserId> = None;
if restriction_rooms.iter().any(|restriction_room_id| { if restriction_rooms
.iter()
.stream()
.any(|restriction_room_id| {
services services
.rooms .rooms
.state_cache .state_cache
.is_joined(sender_user, restriction_room_id) .is_joined(sender_user, restriction_room_id)
.unwrap_or(false) })
}) { .await
{
for user in local_members { for user in local_members {
if services if services
.rooms .rooms
.state_accessor .state_accessor
.user_can_invite(room_id, &user, sender_user, &state_lock) .user_can_invite(room_id, &user, sender_user, &state_lock)
.unwrap_or(false) .await
{ {
join_authorized_via_users_server = Some(user); join_authorized_via_users_server = Some(user);
break; break;
@ -1066,11 +1098,11 @@ async fn join_room_by_id_helper_local(
let event = RoomMemberEventContent { let event = RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?, displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user)?, avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?, blurhash: services.users.blurhash(sender_user).await.ok(),
reason: reason.clone(), reason: reason.clone(),
join_authorized_via_users_server, join_authorized_via_users_server,
}; };
@ -1144,11 +1176,11 @@ async fn join_room_by_id_helper_local(
"content".to_owned(), "content".to_owned(),
to_canonical_value(RoomMemberEventContent { to_canonical_value(RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?, displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user)?, avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?, blurhash: services.users.blurhash(sender_user).await.ok(),
reason, reason,
join_authorized_via_users_server, join_authorized_via_users_server,
}) })
@ -1195,10 +1227,11 @@ async fn join_room_by_id_helper_local(
federation::membership::create_join_event::v2::Request { federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
event_id: event_id.to_owned(), event_id: event_id.to_owned(),
omit_members: false,
pdu: services pdu: services
.sending .sending
.convert_to_outgoing_federation_event(join_event.clone()), .convert_to_outgoing_federation_event(join_event.clone())
omit_members: false, .await,
}, },
) )
.await?; .await?;
@ -1369,7 +1402,7 @@ pub(crate) async fn invite_helper(
services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option<String>, services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option<String>,
is_direct: bool, is_direct: bool,
) -> Result<()> { ) -> Result<()> {
if !services.users.is_admin(user_id)? && services.globals.block_non_admin_invites() { if !services.users.is_admin(user_id).await && services.globals.block_non_admin_invites() {
info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}");
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
@ -1381,7 +1414,7 @@ pub(crate) async fn invite_helper(
let (pdu, pdu_json, invite_room_state) = { let (pdu, pdu_json, invite_room_state) = {
let state_lock = services.rooms.state.mutex.lock(room_id).await; let state_lock = services.rooms.state.mutex.lock(room_id).await;
let content = to_raw_value(&RoomMemberEventContent { let content = to_raw_value(&RoomMemberEventContent {
avatar_url: services.users.avatar_url(user_id)?, avatar_url: services.users.avatar_url(user_id).await.ok(),
displayname: None, displayname: None,
is_direct: Some(is_direct), is_direct: Some(is_direct),
membership: MembershipState::Invite, membership: MembershipState::Invite,
@ -1392,7 +1425,10 @@ pub(crate) async fn invite_helper(
}) })
.expect("member event is valid value"); .expect("member event is valid value");
let (pdu, pdu_json) = services.rooms.timeline.create_hash_and_sign_event( let (pdu, pdu_json) = services
.rooms
.timeline
.create_hash_and_sign_event(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content, content,
@ -1404,16 +1440,17 @@ pub(crate) async fn invite_helper(
sender_user, sender_user,
room_id, room_id,
&state_lock, &state_lock,
)?; )
.await?;
let invite_room_state = services.rooms.state.calculate_invite_state(&pdu)?; let invite_room_state = services.rooms.state.calculate_invite_state(&pdu).await?;
drop(state_lock); drop(state_lock);
(pdu, pdu_json, invite_room_state) (pdu, pdu_json, invite_room_state)
}; };
let room_version_id = services.rooms.state.get_room_version(room_id)?; let room_version_id = services.rooms.state.get_room_version(room_id).await?;
let response = services let response = services
.sending .sending
@ -1425,9 +1462,15 @@ pub(crate) async fn invite_helper(
room_version: room_version_id.clone(), room_version: room_version_id.clone(),
event: services event: services
.sending .sending
.convert_to_outgoing_federation_event(pdu_json.clone()), .convert_to_outgoing_federation_event(pdu_json.clone())
.await,
invite_room_state, invite_room_state,
via: services.rooms.state_cache.servers_route_via(room_id).ok(), via: services
.rooms
.state_cache
.servers_route_via(room_id)
.await
.ok(),
}, },
) )
.await?; .await?;
@ -1478,11 +1521,16 @@ pub(crate) async fn invite_helper(
"Could not accept incoming PDU as timeline event.", "Could not accept incoming PDU as timeline event.",
))?; ))?;
services.sending.send_pdu_room(room_id, &pdu_id)?; services.sending.send_pdu_room(room_id, &pdu_id).await?;
return Ok(()); return Ok(());
} }
if !services.rooms.state_cache.is_joined(sender_user, room_id)? { if !services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -1499,11 +1547,11 @@ pub(crate) async fn invite_helper(
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Invite, membership: MembershipState::Invite,
displayname: services.users.displayname(user_id)?, displayname: services.users.displayname(user_id).await.ok(),
avatar_url: services.users.avatar_url(user_id)?, avatar_url: services.users.avatar_url(user_id).await.ok(),
is_direct: Some(is_direct), is_direct: Some(is_direct),
third_party_invite: None, third_party_invite: None,
blurhash: services.users.blurhash(user_id)?, blurhash: services.users.blurhash(user_id).await.ok(),
reason, reason,
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}) })
@ -1531,36 +1579,37 @@ pub async fn leave_all_rooms(services: &Services, user_id: &UserId) {
.rooms .rooms
.state_cache .state_cache
.rooms_joined(user_id) .rooms_joined(user_id)
.map(ToOwned::to_owned)
.chain( .chain(
services services
.rooms .rooms
.state_cache .state_cache
.rooms_invited(user_id) .rooms_invited(user_id)
.map(|t| t.map(|(r, _)| r)), .map(|(r, _)| r),
) )
.collect::<Vec<_>>(); .collect::<Vec<_>>()
.await;
for room_id in all_rooms { for room_id in all_rooms {
let Ok(room_id) = room_id else {
continue;
};
// ignore errors // ignore errors
if let Err(e) = leave_room(services, user_id, &room_id, None).await { if let Err(e) = leave_room(services, user_id, &room_id, None).await {
warn!(%room_id, %user_id, %e, "Failed to leave room"); warn!(%room_id, %user_id, %e, "Failed to leave room");
} }
if let Err(e) = services.rooms.state_cache.forget(&room_id, user_id) {
warn!(%room_id, %user_id, %e, "Failed to forget room"); services.rooms.state_cache.forget(&room_id, user_id);
}
} }
} }
pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> { pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> {
//use conduit::utils::stream::OptionStream;
use futures::TryFutureExt;
// Ask a remote server if we don't have this room // Ask a remote server if we don't have this room
if !services if !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(services.globals.server_name(), room_id)? .server_in_room(services.globals.server_name(), room_id)
.await
{ {
if let Err(e) = remote_leave_room(services, user_id, room_id).await { if let Err(e) = remote_leave_room(services, user_id, room_id).await {
warn!("Failed to leave room {} remotely: {}", user_id, e); warn!("Failed to leave room {} remotely: {}", user_id, e);
@ -1570,11 +1619,16 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId,
let last_state = services let last_state = services
.rooms .rooms
.state_cache .state_cache
.invite_state(user_id, room_id)? .invite_state(user_id, room_id)
.map_or_else(|| services.rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; .map_err(|_| services.rooms.state_cache.left_state(user_id, room_id))
.await
.ok();
// We always drop the invite, we can't rely on other servers // We always drop the invite, we can't rely on other servers
services.rooms.state_cache.update_membership( services
.rooms
.state_cache
.update_membership(
room_id, room_id,
user_id, user_id,
RoomMemberEventContent::new(MembershipState::Leave), RoomMemberEventContent::new(MembershipState::Leave),
@ -1582,22 +1636,25 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId,
last_state, last_state,
None, None,
true, true,
)?; )
.await?;
} else { } else {
let state_lock = services.rooms.state.mutex.lock(room_id).await; let state_lock = services.rooms.state.mutex.lock(room_id).await;
let member_event = let member_event = services
services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
.await;
// Fix for broken rooms // Fix for broken rooms
let member_event = match member_event { let Ok(member_event) = member_event else {
None => {
error!("Trying to leave a room you are not a member of."); error!("Trying to leave a room you are not a member of.");
services.rooms.state_cache.update_membership( services
.rooms
.state_cache
.update_membership(
room_id, room_id,
user_id, user_id,
RoomMemberEventContent::new(MembershipState::Leave), RoomMemberEventContent::new(MembershipState::Leave),
@ -1605,16 +1662,14 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId,
None, None,
None, None,
true, true,
)?; )
.await?;
return Ok(()); return Ok(());
},
Some(e) => e,
}; };
let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()).map_err(|e| { let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get())
error!("Invalid room member event in database: {}", e); .map_err(|e| err!(Database(error!("Invalid room member event in database: {e}"))))?;
Error::bad_database("Invalid member event in database.")
})?;
event.membership = MembershipState::Leave; event.membership = MembershipState::Leave;
event.reason = reason; event.reason = reason;
@ -1647,15 +1702,17 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
let invite_state = services let invite_state = services
.rooms .rooms
.state_cache .state_cache
.invite_state(user_id, room_id)? .invite_state(user_id, room_id)
.ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; .await
.map_err(|_| err!(Request(BadState("User is not invited."))))?;
let mut servers: HashSet<OwnedServerName> = services let mut servers: HashSet<OwnedServerName> = services
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(room_id) .servers_invite_via(room_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(); .collect()
.await;
servers.extend( servers.extend(
invite_state invite_state
@ -1760,7 +1817,8 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
event_id, event_id,
pdu: services pdu: services
.sending .sending
.convert_to_outgoing_federation_event(leave_event.clone()), .convert_to_outgoing_federation_event(leave_event.clone())
.await,
}, },
) )
.await?; .await?;

View File

@ -1,7 +1,8 @@
use std::collections::{BTreeMap, HashSet}; use std::collections::{BTreeMap, HashSet};
use axum::extract::State; use axum::extract::State;
use conduit::PduCount; use conduit::{err, utils::ReadyExt, Err, PduCount};
use futures::{FutureExt, StreamExt};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -9,13 +10,14 @@ use ruma::{
message::{get_message_events, send_message_event}, message::{get_message_events, send_message_event},
}, },
events::{MessageLikeEventType, StateEventType}, events::{MessageLikeEventType, StateEventType},
RoomId, UserId, UserId,
}; };
use serde_json::{from_str, Value}; use serde_json::{from_str, Value};
use service::rooms::timeline::PdusIterItem;
use crate::{ use crate::{
service::{pdu::PduBuilder, Services}, service::{pdu::PduBuilder, Services},
utils, Error, PduEvent, Result, Ruma, utils, Error, Result, Ruma,
}; };
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
@ -30,79 +32,78 @@ use crate::{
pub(crate) async fn send_message_event_route( pub(crate) async fn send_message_event_route(
State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>, State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>,
) -> Result<send_message_event::v3::Response> { ) -> Result<send_message_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_deref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
let appservice_info = body.appservice_info.as_ref();
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
// Forbid m.room.encrypted if encryption is disabled // Forbid m.room.encrypted if encryption is disabled
if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() { if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled")); return Err!(Request(Forbidden("Encryption has been disabled")));
} }
if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? { let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
return Err(Error::BadRequest(
ErrorKind::forbidden(), if body.event_type == MessageLikeEventType::CallInvite
"Room call invites are not allowed in public rooms", && services.rooms.directory.is_public_room(&body.room_id).await
)); {
return Err!(Request(Forbidden("Room call invites are not allowed in public rooms")));
} }
// Check if this is a new transaction id // Check if this is a new transaction id
if let Some(response) = services if let Ok(response) = services
.transaction_ids .transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)? .existing_txnid(sender_user, sender_device, &body.txn_id)
.await
{ {
// The client might have sent a txnid of the /sendToDevice endpoint // The client might have sent a txnid of the /sendToDevice endpoint
// This txnid has no response associated with it // This txnid has no response associated with it
if response.is_empty() { if response.is_empty() {
return Err(Error::BadRequest( return Err!(Request(InvalidParam(
ErrorKind::InvalidParam, "Tried to use txn id already used for an incompatible endpoint."
"Tried to use txn id already used for an incompatible endpoint.", )));
));
} }
let event_id = utils::string_from_bytes(&response)
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
return Ok(send_message_event::v3::Response { return Ok(send_message_event::v3::Response {
event_id, event_id: utils::string_from_bytes(&response)
.map(TryInto::try_into)
.map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??,
}); });
} }
let mut unsigned = BTreeMap::new(); let mut unsigned = BTreeMap::new();
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
let content = from_str(body.body.body.json().get())
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?;
let event_id = services let event_id = services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: body.event_type.to_string().into(), event_type: body.event_type.to_string().into(),
content: from_str(body.body.body.json().get()) content,
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
unsigned: Some(unsigned), unsigned: Some(unsigned),
state_key: None, state_key: None,
redacts: None, redacts: None,
timestamp: if body.appservice_info.is_some() { timestamp: appservice_info.and(body.timestamp),
body.timestamp
} else {
None
},
}, },
sender_user, sender_user,
&body.room_id, &body.room_id,
&state_lock, &state_lock,
) )
.await?; .await
.map(|event_id| (*event_id).to_owned())?;
services services
.transaction_ids .transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?; .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes());
drop(state_lock); drop(state_lock);
Ok(send_message_event::v3::Response::new((*event_id).to_owned())) Ok(send_message_event::v3::Response {
event_id,
})
} }
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
@ -117,8 +118,12 @@ pub(crate) async fn get_message_events_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let from = match body.from.clone() { let room_id = &body.room_id;
Some(from) => PduCount::try_from_string(&from)?, let filter = &body.filter;
let limit = usize::try_from(body.limit).unwrap_or(10).min(100);
let from = match body.from.as_ref() {
Some(from) => PduCount::try_from_string(from)?,
None => match body.dir { None => match body.dir {
ruma::api::Direction::Forward => PduCount::min(), ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(), ruma::api::Direction::Backward => PduCount::max(),
@ -133,30 +138,25 @@ pub(crate) async fn get_message_events_route(
services services
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from) .lazy_load_confirm_delivery(sender_user, sender_device, room_id, from);
.await?;
let limit = usize::try_from(body.limit).unwrap_or(10).min(100);
let next_token;
let mut resp = get_message_events::v3::Response::new(); let mut resp = get_message_events::v3::Response::new();
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
let next_token;
match body.dir { match body.dir {
ruma::api::Direction::Forward => { ruma::api::Direction::Forward => {
let events_after: Vec<_> = services let events_after: Vec<PdusIterItem> = services
.rooms .rooms
.timeline .timeline
.pdus_after(sender_user, &body.room_id, from)? .pdus_after(sender_user, room_id, from)
.filter_map(Result::ok) // Filter out buggy events .await?
.filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id) .ready_filter_map(|item| contains_url_filter(item, filter))
.filter_map(|item| visibility_filter(&services, item, sender_user))
}) .ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to`
.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.take(limit) .take(limit)
.collect(); .collect()
.boxed()
.await;
for (_, event) in &events_after { for (_, event) in &events_after {
/* TODO: Remove the not "element_hacks" check when these are resolved: /* TODO: Remove the not "element_hacks" check when these are resolved:
@ -164,17 +164,19 @@ pub(crate) async fn get_message_events_route(
* https://github.com/vector-im/element-web/issues/21034 * https://github.com/vector-im/element-web/issues/21034
*/ */
if !cfg!(feature = "element_hacks") if !cfg!(feature = "element_hacks")
&& !services.rooms.lazy_loading.lazy_load_was_sent_before( && !services
sender_user, .rooms
sender_device, .lazy_loading
&body.room_id, .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
&event.sender, .await
)? { {
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
if cfg!(features = "element_hacks") {
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
}
next_token = events_after.last().map(|(count, _)| count).copied(); next_token = events_after.last().map(|(count, _)| count).copied();
@ -191,17 +193,22 @@ pub(crate) async fn get_message_events_route(
services services
.rooms .rooms
.timeline .timeline
.backfill_if_required(&body.room_id, from) .backfill_if_required(room_id, from)
.boxed()
.await?; .await?;
let events_before: Vec<_> = services
let events_before: Vec<PdusIterItem> = services
.rooms .rooms
.timeline .timeline
.pdus_until(sender_user, &body.room_id, from)? .pdus_until(sender_user, room_id, from)
.filter_map(Result::ok) // Filter out buggy events .await?
.filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)}) .ready_filter_map(|item| contains_url_filter(item, filter))
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .filter_map(|item| visibility_filter(&services, item, sender_user))
.ready_take_while(|(count, _)| Some(*count) != to) // Stop at `to`
.take(limit) .take(limit)
.collect(); .collect()
.boxed()
.await;
for (_, event) in &events_before { for (_, event) in &events_before {
/* TODO: Remove the not "element_hacks" check when these are resolved: /* TODO: Remove the not "element_hacks" check when these are resolved:
@ -209,17 +216,19 @@ pub(crate) async fn get_message_events_route(
* https://github.com/vector-im/element-web/issues/21034 * https://github.com/vector-im/element-web/issues/21034
*/ */
if !cfg!(feature = "element_hacks") if !cfg!(feature = "element_hacks")
&& !services.rooms.lazy_loading.lazy_load_was_sent_before( && !services
sender_user, .rooms
sender_device, .lazy_loading
&body.room_id, .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender)
&event.sender, .await
)? { {
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
if cfg!(features = "element_hacks") {
lazy_loaded.insert(event.sender.clone()); lazy_loaded.insert(event.sender.clone());
} }
}
next_token = events_before.last().map(|(count, _)| count).copied(); next_token = events_before.last().map(|(count, _)| count).copied();
@ -236,11 +245,11 @@ pub(crate) async fn get_message_events_route(
resp.state = Vec::new(); resp.state = Vec::new();
for ll_id in &lazy_loaded { for ll_id in &lazy_loaded {
if let Some(member_event) = if let Ok(member_event) = services
services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? .room_state_get(room_id, &StateEventType::RoomMember, ll_id.as_str())
.await
{ {
resp.state.push(member_event.to_state_event()); resp.state.push(member_event.to_state_event());
} }
@ -249,34 +258,43 @@ pub(crate) async fn get_message_events_route(
// remove the feature check when we are sure clients like element can handle it // remove the feature check when we are sure clients like element can handle it
if !cfg!(feature = "element_hacks") { if !cfg!(feature = "element_hacks") {
if let Some(next_token) = next_token { if let Some(next_token) = next_token {
services services.rooms.lazy_loading.lazy_load_mark_sent(
.rooms sender_user,
.lazy_loading sender_device,
.lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token) room_id,
.await; lazy_loaded,
next_token,
);
} }
} }
Ok(resp) Ok(resp)
} }
fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool { async fn visibility_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option<PdusIterItem> {
let (_, pdu) = &item;
services services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(user_id, room_id, &pdu.event_id) .user_can_see_event(user_id, &pdu.room_id, &pdu.event_id)
.unwrap_or(false) .await
.then_some(item)
} }
fn contains_url_filter(pdu: &PduEvent, filter: &RoomEventFilter) -> bool { fn contains_url_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option<PdusIterItem> {
let (_, pdu) = &item;
if filter.url_filter.is_none() { if filter.url_filter.is_none() {
return true; return Some(item);
} }
let content: Value = from_str(pdu.content.get()).unwrap(); let content: Value = from_str(pdu.content.get()).unwrap();
match filter.url_filter { let res = match filter.url_filter {
Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(), Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(),
Some(UrlFilter::EventsWithUrl) => content["url"].is_string(), Some(UrlFilter::EventsWithUrl) => content["url"].is_string(),
None => true, None => true,
} };
res.then_some(item)
} }

View File

@ -28,7 +28,8 @@ pub(crate) async fn set_presence_route(
services services
.presence .presence
.set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?; .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())
.await?;
Ok(set_presence::v3::Response {}) Ok(set_presence::v3::Response {})
} }
@ -49,14 +50,15 @@ pub(crate) async fn get_presence_route(
let mut presence_event = None; let mut presence_event = None;
for _room_id in services let has_shared_rooms = services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? .has_shared_rooms(sender_user, &body.user_id)
{ .await;
if let Some(presence) = services.presence.get_presence(&body.user_id)? {
if has_shared_rooms {
if let Ok(presence) = services.presence.get_presence(&body.user_id).await {
presence_event = Some(presence); presence_event = Some(presence);
break;
} }
} }

View File

@ -1,5 +1,10 @@
use axum::extract::State; use axum::extract::State;
use conduit::{pdu::PduBuilder, warn, Err, Error, Result}; use conduit::{
pdu::PduBuilder,
utils::{stream::TryIgnore, IterStream},
warn, Err, Error, Result,
};
use futures::{StreamExt, TryStreamExt};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -35,16 +40,18 @@ pub(crate) async fn set_displayname_route(
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&body.user_id) .rooms_joined(&body.user_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(); .collect()
.await;
update_displayname(&services, &body.user_id, body.displayname.clone(), all_joined_rooms).await?; update_displayname(&services, &body.user_id, body.displayname.clone(), &all_joined_rooms).await?;
if services.globals.allow_local_presence() { if services.globals.allow_local_presence() {
// Presence update // Presence update
services services
.presence .presence
.ping_presence(&body.user_id, &PresenceState::Online)?; .ping_presence(&body.user_id, &PresenceState::Online)
.await?;
} }
Ok(set_display_name::v3::Response {}) Ok(set_display_name::v3::Response {})
@ -72,22 +79,19 @@ pub(crate) async fn get_displayname_route(
) )
.await .await
{ {
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone());
.await?;
services services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone());
.await?;
services services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone());
.await?;
return Ok(get_display_name::v3::Response { return Ok(get_display_name::v3::Response {
displayname: response.displayname, displayname: response.displayname,
@ -95,14 +99,14 @@ pub(crate) async fn get_displayname_route(
} }
} }
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
} }
Ok(get_display_name::v3::Response { Ok(get_display_name::v3::Response {
displayname: services.users.displayname(&body.user_id)?, displayname: services.users.displayname(&body.user_id).await.ok(),
}) })
} }
@ -124,15 +128,16 @@ pub(crate) async fn set_avatar_url_route(
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&body.user_id) .rooms_joined(&body.user_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(); .collect()
.await;
update_avatar_url( update_avatar_url(
&services, &services,
&body.user_id, &body.user_id,
body.avatar_url.clone(), body.avatar_url.clone(),
body.blurhash.clone(), body.blurhash.clone(),
all_joined_rooms, &all_joined_rooms,
) )
.await?; .await?;
@ -140,7 +145,9 @@ pub(crate) async fn set_avatar_url_route(
// Presence update // Presence update
services services
.presence .presence
.ping_presence(&body.user_id, &PresenceState::Online)?; .ping_presence(&body.user_id, &PresenceState::Online)
.await
.ok();
} }
Ok(set_avatar_url::v3::Response {}) Ok(set_avatar_url::v3::Response {})
@ -168,22 +175,21 @@ pub(crate) async fn get_avatar_url_route(
) )
.await .await
{ {
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone());
.await?;
services services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone());
.await?;
services services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone());
.await?;
return Ok(get_avatar_url::v3::Response { return Ok(get_avatar_url::v3::Response {
avatar_url: response.avatar_url, avatar_url: response.avatar_url,
@ -192,15 +198,15 @@ pub(crate) async fn get_avatar_url_route(
} }
} }
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
} }
Ok(get_avatar_url::v3::Response { Ok(get_avatar_url::v3::Response {
avatar_url: services.users.avatar_url(&body.user_id)?, avatar_url: services.users.avatar_url(&body.user_id).await.ok(),
blurhash: services.users.blurhash(&body.user_id)?, blurhash: services.users.blurhash(&body.user_id).await.ok(),
}) })
} }
@ -226,31 +232,30 @@ pub(crate) async fn get_profile_route(
) )
.await .await
{ {
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone());
.await?;
services services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone());
.await?;
services services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone());
.await?;
services services
.users .users
.set_timezone(&body.user_id, response.tz.clone()) .set_timezone(&body.user_id, response.tz.clone());
.await?;
for (profile_key, profile_key_value) in &response.custom_profile_fields { for (profile_key, profile_key_value) in &response.custom_profile_fields {
services services
.users .users
.set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()))?; .set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()));
} }
return Ok(get_profile::v3::Response { return Ok(get_profile::v3::Response {
@ -263,104 +268,93 @@ pub(crate) async fn get_profile_route(
} }
} }
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
} }
Ok(get_profile::v3::Response { Ok(get_profile::v3::Response {
avatar_url: services.users.avatar_url(&body.user_id)?, avatar_url: services.users.avatar_url(&body.user_id).await.ok(),
blurhash: services.users.blurhash(&body.user_id)?, blurhash: services.users.blurhash(&body.user_id).await.ok(),
displayname: services.users.displayname(&body.user_id)?, displayname: services.users.displayname(&body.user_id).await.ok(),
tz: services.users.timezone(&body.user_id)?, tz: services.users.timezone(&body.user_id).await.ok(),
custom_profile_fields: services custom_profile_fields: services
.users .users
.all_profile_keys(&body.user_id) .all_profile_keys(&body.user_id)
.filter_map(Result::ok) .collect()
.collect(), .await,
}) })
} }
pub async fn update_displayname( pub async fn update_displayname(
services: &Services, user_id: &UserId, displayname: Option<String>, all_joined_rooms: Vec<OwnedRoomId>, services: &Services, user_id: &UserId, displayname: Option<String>, all_joined_rooms: &[OwnedRoomId],
) -> Result<()> { ) -> Result<()> {
let current_display_name = services.users.displayname(user_id).unwrap_or_default(); let current_display_name = services.users.displayname(user_id).await.ok();
if displayname == current_display_name { if displayname == current_display_name {
return Ok(()); return Ok(());
} }
services services.users.set_displayname(user_id, displayname.clone());
.users
.set_displayname(user_id, displayname.clone())
.await?;
// Send a new join membership event into all joined rooms // Send a new join membership event into all joined rooms
let all_joined_rooms: Vec<_> = all_joined_rooms let mut joined_rooms = Vec::new();
.iter() for room_id in all_joined_rooms {
.map(|room_id| { let Ok(event) = services
Ok::<_, Error>(( .rooms
PduBuilder { .state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
.await
else {
continue;
};
let pdu = PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
displayname: displayname.clone(), displayname: displayname.clone(),
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
..serde_json::from_str( ..serde_json::from_str(event.content.get()).expect("Database contains invalid PDU.")
services
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
.ok_or_else(|| {
Error::bad_database("Tried to send display name update for user not in the room.")
})?
.content
.get(),
)
.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
state_key: Some(user_id.to_string()), state_key: Some(user_id.to_string()),
redacts: None, redacts: None,
timestamp: None, timestamp: None,
}, };
room_id,
))
})
.filter_map(Result::ok)
.collect();
update_all_rooms(services, all_joined_rooms, user_id).await; joined_rooms.push((pdu, room_id));
}
update_all_rooms(services, joined_rooms, user_id).await;
Ok(()) Ok(())
} }
pub async fn update_avatar_url( pub async fn update_avatar_url(
services: &Services, user_id: &UserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>, services: &Services, user_id: &UserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>,
all_joined_rooms: Vec<OwnedRoomId>, all_joined_rooms: &[OwnedRoomId],
) -> Result<()> { ) -> Result<()> {
let current_avatar_url = services.users.avatar_url(user_id).unwrap_or_default(); let current_avatar_url = services.users.avatar_url(user_id).await.ok();
let current_blurhash = services.users.blurhash(user_id).unwrap_or_default(); let current_blurhash = services.users.blurhash(user_id).await.ok();
if current_avatar_url == avatar_url && current_blurhash == blurhash { if current_avatar_url == avatar_url && current_blurhash == blurhash {
return Ok(()); return Ok(());
} }
services services.users.set_avatar_url(user_id, avatar_url.clone());
.users
.set_avatar_url(user_id, avatar_url.clone()) services.users.set_blurhash(user_id, blurhash.clone());
.await?;
services
.users
.set_blurhash(user_id, blurhash.clone())
.await?;
// Send a new join membership event into all joined rooms // Send a new join membership event into all joined rooms
let avatar_url = &avatar_url;
let blurhash = &blurhash;
let all_joined_rooms: Vec<_> = all_joined_rooms let all_joined_rooms: Vec<_> = all_joined_rooms
.iter() .iter()
.map(|room_id| { .try_stream()
Ok::<_, Error>(( .and_then(|room_id: &OwnedRoomId| async move {
Ok((
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
@ -371,8 +365,9 @@ pub async fn update_avatar_url(
services services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
.ok_or_else(|| { .await
.map_err(|_| {
Error::bad_database("Tried to send avatar URL update for user not in the room.") Error::bad_database("Tried to send avatar URL update for user not in the room.")
})? })?
.content .content
@ -389,8 +384,9 @@ pub async fn update_avatar_url(
room_id, room_id,
)) ))
}) })
.filter_map(Result::ok) .ignore_err()
.collect(); .collect()
.await;
update_all_rooms(services, all_joined_rooms, user_id).await; update_all_rooms(services, all_joined_rooms, user_id).await;

View File

@ -29,40 +29,36 @@ pub(crate) async fn get_pushrules_all_route(
let global_ruleset: Ruleset; let global_ruleset: Ruleset;
let Ok(event) = let event = services
services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into()) .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
else { .await;
// push rules event doesn't exist, create it and return default
return recreate_push_rules_and_return(&services, sender_user); let Ok(event) = event else {
// user somehow has non-existent push rule event. recreate it and return server
// default silently
return recreate_push_rules_and_return(&services, sender_user).await;
}; };
if let Some(event) = event {
let value = serde_json::from_str::<CanonicalJsonObject>(event.get()) let value = serde_json::from_str::<CanonicalJsonObject>(event.get())
.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
let Some(content_value) = value.get("content") else { let Some(content_value) = value.get("content") else {
// user somehow has a push rule event with no content key, recreate it and // user somehow has a push rule event with no content key, recreate it and
// return server default silently // return server default silently
return recreate_push_rules_and_return(&services, sender_user); return recreate_push_rules_and_return(&services, sender_user).await;
}; };
if content_value.to_string().is_empty() { if content_value.to_string().is_empty() {
// user somehow has a push rule event with empty content, recreate it and return // user somehow has a push rule event with empty content, recreate it and return
// server default silently // server default silently
return recreate_push_rules_and_return(&services, sender_user); return recreate_push_rules_and_return(&services, sender_user).await;
} }
let account_data_content = serde_json::from_value::<PushRulesEventContent>(content_value.clone().into()) let account_data_content = serde_json::from_value::<PushRulesEventContent>(content_value.clone().into())
.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?; .map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
global_ruleset = account_data_content.global; global_ruleset = account_data_content.global;
} else {
// user somehow has non-existent push rule event. recreate it and return server
// default silently
return recreate_push_rules_and_return(&services, sender_user);
}
Ok(get_pushrules_all::v3::Response { Ok(get_pushrules_all::v3::Response {
global: global_ruleset, global: global_ruleset,
@ -79,8 +75,9 @@ pub(crate) async fn get_pushrule_route(
let event = services let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -118,8 +115,9 @@ pub(crate) async fn set_pushrule_route(
let event = services let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -155,12 +153,15 @@ pub(crate) async fn set_pushrule_route(
return Err(err); return Err(err);
} }
services.account_data.update( services
.account_data
.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )
.await?;
Ok(set_pushrule::v3::Response {}) Ok(set_pushrule::v3::Response {})
} }
@ -182,8 +183,9 @@ pub(crate) async fn get_pushrule_actions_route(
let event = services let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
@ -217,8 +219,9 @@ pub(crate) async fn set_pushrule_actions_route(
let event = services let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -232,12 +235,15 @@ pub(crate) async fn set_pushrule_actions_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
} }
services.account_data.update( services
.account_data
.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )
.await?;
Ok(set_pushrule_actions::v3::Response {}) Ok(set_pushrule_actions::v3::Response {})
} }
@ -259,8 +265,9 @@ pub(crate) async fn get_pushrule_enabled_route(
let event = services let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -293,8 +300,9 @@ pub(crate) async fn set_pushrule_enabled_route(
let event = services let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -308,12 +316,15 @@ pub(crate) async fn set_pushrule_enabled_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
} }
services.account_data.update( services
.account_data
.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )
.await?;
Ok(set_pushrule_enabled::v3::Response {}) Ok(set_pushrule_enabled::v3::Response {})
} }
@ -335,8 +346,9 @@ pub(crate) async fn delete_pushrule_route(
let event = services let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
@ -357,12 +369,15 @@ pub(crate) async fn delete_pushrule_route(
return Err(err); return Err(err);
} }
services.account_data.update( services
.account_data
.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(account_data).expect("to json value always works"), &serde_json::to_value(account_data).expect("to json value always works"),
)?; )
.await?;
Ok(delete_pushrule::v3::Response {}) Ok(delete_pushrule::v3::Response {})
} }
@ -376,7 +391,7 @@ pub(crate) async fn get_pushers_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(get_pushers::v3::Response { Ok(get_pushers::v3::Response {
pushers: services.pusher.get_pushers(sender_user)?, pushers: services.pusher.get_pushers(sender_user).await,
}) })
} }
@ -390,17 +405,19 @@ pub(crate) async fn set_pushers_route(
) -> Result<set_pusher::v3::Response> { ) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services.pusher.set_pusher(sender_user, &body.action)?; services.pusher.set_pusher(sender_user, &body.action);
Ok(set_pusher::v3::Response::default()) Ok(set_pusher::v3::Response::default())
} }
/// user somehow has bad push rules, these must always exist per spec. /// user somehow has bad push rules, these must always exist per spec.
/// so recreate it and return server default silently /// so recreate it and return server default silently
fn recreate_push_rules_and_return( async fn recreate_push_rules_and_return(
services: &Services, sender_user: &ruma::UserId, services: &Services, sender_user: &ruma::UserId,
) -> Result<get_pushrules_all::v3::Response> { ) -> Result<get_pushrules_all::v3::Response> {
services.account_data.update( services
.account_data
.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -410,7 +427,8 @@ fn recreate_push_rules_and_return(
}, },
}) })
.expect("to json always works"), .expect("to json always works"),
)?; )
.await?;
Ok(get_pushrules_all::v3::Response { Ok(get_pushrules_all::v3::Response {
global: Ruleset::server_default(sender_user), global: Ruleset::server_default(sender_user),

View File

@ -31,27 +31,32 @@ pub(crate) async fn set_read_marker_route(
event_id: fully_read.clone(), event_id: fully_read.clone(),
}, },
}; };
services.account_data.update( services
.account_data
.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"), &serde_json::to_value(fully_read_event).expect("to json value always works"),
)?; )
.await?;
} }
if body.private_read_receipt.is_some() || body.read_receipt.is_some() { if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
services services
.rooms .rooms
.user .user
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id);
} }
if let Some(event) = &body.private_read_receipt { if let Some(event) = &body.private_read_receipt {
let count = services let count = services
.rooms .rooms
.timeline .timeline
.get_pdu_count(event)? .get_pdu_count(event)
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
let count = match count { let count = match count {
PduCount::Backfilled(_) => { PduCount::Backfilled(_) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -64,7 +69,7 @@ pub(crate) async fn set_read_marker_route(
services services
.rooms .rooms
.read_receipt .read_receipt
.private_read_set(&body.room_id, sender_user, count)?; .private_read_set(&body.room_id, sender_user, count);
} }
if let Some(event) = &body.read_receipt { if let Some(event) = &body.read_receipt {
@ -83,14 +88,18 @@ pub(crate) async fn set_read_marker_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(event.to_owned(), receipts); receipt_content.insert(event.to_owned(), receipts);
services.rooms.read_receipt.readreceipt_update( services
.rooms
.read_receipt
.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
&ruma::events::receipt::ReceiptEvent { &ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
)?; )
.await;
} }
Ok(set_read_marker::v3::Response {}) Ok(set_read_marker::v3::Response {})
@ -111,7 +120,7 @@ pub(crate) async fn create_receipt_route(
services services
.rooms .rooms
.user .user
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id);
} }
match body.receipt_type { match body.receipt_type {
@ -121,12 +130,15 @@ pub(crate) async fn create_receipt_route(
event_id: body.event_id.clone(), event_id: body.event_id.clone(),
}, },
}; };
services.account_data.update( services
.account_data
.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
&serde_json::to_value(fully_read_event).expect("to json value always works"), &serde_json::to_value(fully_read_event).expect("to json value always works"),
)?; )
.await?;
}, },
create_receipt::v3::ReceiptType::Read => { create_receipt::v3::ReceiptType::Read => {
let mut user_receipts = BTreeMap::new(); let mut user_receipts = BTreeMap::new();
@ -143,21 +155,27 @@ pub(crate) async fn create_receipt_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(body.event_id.clone(), receipts); receipt_content.insert(body.event_id.clone(), receipts);
services.rooms.read_receipt.readreceipt_update( services
.rooms
.read_receipt
.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
&ruma::events::receipt::ReceiptEvent { &ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
)?; )
.await;
}, },
create_receipt::v3::ReceiptType::ReadPrivate => { create_receipt::v3::ReceiptType::ReadPrivate => {
let count = services let count = services
.rooms .rooms
.timeline .timeline
.get_pdu_count(&body.event_id)? .get_pdu_count(&body.event_id)
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
let count = match count { let count = match count {
PduCount::Backfilled(_) => { PduCount::Backfilled(_) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -170,7 +188,7 @@ pub(crate) async fn create_receipt_route(
services services
.rooms .rooms
.read_receipt .read_receipt
.private_read_set(&body.room_id, sender_user, count)?; .private_read_set(&body.room_id, sender_user, count);
}, },
_ => return Err(Error::bad_database("Unsupported receipt type")), _ => return Err(Error::bad_database("Unsupported receipt type")),
} }

View File

@ -9,20 +9,24 @@ use crate::{Result, Ruma};
pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>, State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_deref().expect("user is authenticated");
let res = services.rooms.pdu_metadata.paginate_relations_with_filter( let res = services
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,
&Some(body.event_type.clone()), body.event_type.clone().into(),
&Some(body.rel_type.clone()), body.rel_type.clone().into(),
&body.from, body.from.as_ref(),
&body.to, body.to.as_ref(),
&body.limit, body.limit,
body.recurse, body.recurse,
body.dir, body.dir,
)?; )
.await?;
Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
@ -36,20 +40,24 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
pub(crate) async fn get_relating_events_with_rel_type_route( pub(crate) async fn get_relating_events_with_rel_type_route(
State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type::v1::Request>, State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_deref().expect("user is authenticated");
let res = services.rooms.pdu_metadata.paginate_relations_with_filter( let res = services
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,
&None, None,
&Some(body.rel_type.clone()), body.rel_type.clone().into(),
&body.from, body.from.as_ref(),
&body.to, body.to.as_ref(),
&body.limit, body.limit,
body.recurse, body.recurse,
body.dir, body.dir,
)?; )
.await?;
Ok(get_relating_events_with_rel_type::v1::Response { Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
@ -63,18 +71,22 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
pub(crate) async fn get_relating_events_route( pub(crate) async fn get_relating_events_route(
State(services): State<crate::State>, body: Ruma<get_relating_events::v1::Request>, State(services): State<crate::State>, body: Ruma<get_relating_events::v1::Request>,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_deref().expect("user is authenticated");
services.rooms.pdu_metadata.paginate_relations_with_filter( services
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,
&None, None,
&None, None,
&body.from, body.from.as_ref(),
&body.to, body.to.as_ref(),
&body.limit, body.limit,
body.recurse, body.recurse,
body.dir, body.dir,
) )
.await
} }

View File

@ -1,6 +1,7 @@
use std::time::Duration; use std::time::Duration;
use axum::extract::State; use axum::extract::State;
use conduit::{utils::ReadyExt, Err};
use rand::Rng; use rand::Rng;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, room::report_content}, api::client::{error::ErrorKind, room::report_content},
@ -34,11 +35,8 @@ pub(crate) async fn report_event_route(
delay_response().await; delay_response().await;
// check if we know about the reported event ID or if it's invalid // check if we know about the reported event ID or if it's invalid
let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else { let Ok(pdu) = services.rooms.timeline.get_pdu(&body.event_id).await else {
return Err(Error::BadRequest( return Err!(Request(NotFound("Event ID is not known to us or Event ID is invalid")));
ErrorKind::NotFound,
"Event ID is not known to us or Event ID is invalid",
));
}; };
is_report_valid( is_report_valid(
@ -49,7 +47,8 @@ pub(crate) async fn report_event_route(
&body.reason, &body.reason,
body.score, body.score,
&pdu, &pdu,
)?; )
.await?;
// send admin room message that we received the report with an @room ping for // send admin room message that we received the report with an @room ping for
// urgency // urgency
@ -81,7 +80,8 @@ pub(crate) async fn report_event_route(
HtmlEscape(body.reason.as_deref().unwrap_or("")) HtmlEscape(body.reason.as_deref().unwrap_or(""))
), ),
)) ))
.await; .await
.ok();
Ok(report_content::v3::Response {}) Ok(report_content::v3::Response {})
} }
@ -92,7 +92,7 @@ pub(crate) async fn report_event_route(
/// check if score is in valid range /// check if score is in valid range
/// check if report reasoning is less than or equal to 750 characters /// check if report reasoning is less than or equal to 750 characters
/// check if reporting user is in the reporting room /// check if reporting user is in the reporting room
fn is_report_valid( async fn is_report_valid(
services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>, services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>,
score: Option<ruma::Int>, pdu: &std::sync::Arc<PduEvent>, score: Option<ruma::Int>, pdu: &std::sync::Arc<PduEvent>,
) -> Result<()> { ) -> Result<()> {
@ -123,8 +123,8 @@ fn is_report_valid(
.rooms .rooms
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
.filter_map(Result::ok) .ready_any(|user_id| user_id == sender_user)
.any(|user_id| user_id == *sender_user) .await
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,

View File

@ -2,6 +2,7 @@ use std::{cmp::max, collections::BTreeMap};
use axum::extract::State; use axum::extract::State;
use conduit::{debug_info, debug_warn, err, Err}; use conduit::{debug_info, debug_warn, err, Err};
use futures::{FutureExt, StreamExt};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -74,7 +75,7 @@ pub(crate) async fn create_room_route(
if !services.globals.allow_room_creation() if !services.globals.allow_room_creation()
&& body.appservice_info.is_none() && body.appservice_info.is_none()
&& !services.users.is_admin(sender_user)? && !services.users.is_admin(sender_user).await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled."));
} }
@ -86,7 +87,7 @@ pub(crate) async fn create_room_route(
}; };
// check if room ID doesn't already exist instead of erroring on auth check // check if room ID doesn't already exist instead of erroring on auth check
if services.rooms.short.get_shortroomid(&room_id)?.is_some() { if services.rooms.short.get_shortroomid(&room_id).await.is_ok() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::RoomInUse, ErrorKind::RoomInUse,
"Room with that custom room ID already exists", "Room with that custom room ID already exists",
@ -95,7 +96,7 @@ pub(crate) async fn create_room_route(
if body.visibility == room::Visibility::Public if body.visibility == room::Visibility::Public
&& services.globals.config.lockdown_public_room_directory && services.globals.config.lockdown_public_room_directory
&& !services.users.is_admin(sender_user)? && !services.users.is_admin(sender_user).await
&& body.appservice_info.is_none() && body.appservice_info.is_none()
{ {
info!( info!(
@ -118,7 +119,11 @@ pub(crate) async fn create_room_route(
return Err!(Request(Forbidden("Publishing rooms to the room directory is not allowed"))); return Err!(Request(Forbidden("Publishing rooms to the room directory is not allowed")));
} }
let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; let _short_id = services
.rooms
.short
.get_or_create_shortroomid(&room_id)
.await;
let state_lock = services.rooms.state.mutex.lock(&room_id).await; let state_lock = services.rooms.state.mutex.lock(&room_id).await;
let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name { let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name {
@ -218,6 +223,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
// 2. Let the room creator join // 2. Let the room creator join
@ -229,11 +235,11 @@ pub(crate) async fn create_room_route(
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?, displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user)?, avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: Some(body.is_direct), is_direct: Some(body.is_direct),
third_party_invite: None, third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?, blurhash: services.users.blurhash(sender_user).await.ok(),
reason: None, reason: None,
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}) })
@ -247,6 +253,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
// 3. Power levels // 3. Power levels
@ -284,6 +291,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
// 4. Canonical room alias // 4. Canonical room alias
@ -308,6 +316,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
} }
@ -335,6 +344,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
// 5.2 History Visibility // 5.2 History Visibility
@ -355,6 +365,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
// 5.3 Guest Access // 5.3 Guest Access
@ -378,6 +389,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
// 6. Events listed in initial_state // 6. Events listed in initial_state
@ -410,6 +422,7 @@ pub(crate) async fn create_room_route(
.rooms .rooms
.timeline .timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.boxed()
.await?; .await?;
} }
@ -432,6 +445,7 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
} }
@ -455,13 +469,17 @@ pub(crate) async fn create_room_route(
&room_id, &room_id,
&state_lock, &state_lock,
) )
.boxed()
.await?; .await?;
} }
// 8. Events implied by invite (and TODO: invite_3pid) // 8. Events implied by invite (and TODO: invite_3pid)
drop(state_lock); drop(state_lock);
for user_id in &body.invite { for user_id in &body.invite {
if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct).await { if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct)
.boxed()
.await
{
warn!(%e, "Failed to send invite"); warn!(%e, "Failed to send invite");
} }
} }
@ -475,7 +493,7 @@ pub(crate) async fn create_room_route(
} }
if body.visibility == room::Visibility::Public { if body.visibility == room::Visibility::Public {
services.rooms.directory.set_public(&room_id)?; services.rooms.directory.set_public(&room_id);
if services.globals.config.admin_room_notices { if services.globals.config.admin_room_notices {
services services
@ -505,13 +523,15 @@ pub(crate) async fn get_room_event_route(
let event = services let event = services
.rooms .rooms
.timeline .timeline
.get_pdu(&body.event_id)? .get_pdu(&body.event_id)
.ok_or_else(|| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; .await
.map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id))))?;
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &event.room_id, &body.event_id)? .user_can_see_event(sender_user, &event.room_id, &body.event_id)
.await
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
@ -541,7 +561,8 @@ pub(crate) async fn get_room_aliases_route(
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)
.await
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
@ -554,8 +575,9 @@ pub(crate) async fn get_room_aliases_route(
.rooms .rooms
.alias .alias
.local_aliases_for_room(&body.room_id) .local_aliases_for_room(&body.room_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(), .collect()
.await,
}) })
} }
@ -591,7 +613,8 @@ pub(crate) async fn upgrade_room_route(
let _short_id = services let _short_id = services
.rooms .rooms
.short .short
.get_or_create_shortroomid(&replacement_room)?; .get_or_create_shortroomid(&replacement_room)
.await;
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
@ -629,12 +652,12 @@ pub(crate) async fn upgrade_room_route(
services services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .await
.map_err(|_| err!(Database("Found room without m.room.create event.")))?
.content .content
.get(), .get(),
) )?;
.map_err(|_| Error::bad_database("Invalid room event in database."))?;
// Use the m.room.tombstone event as the predecessor // Use the m.room.tombstone event as the predecessor
let predecessor = Some(ruma::events::room::create::PreviousRoom::new( let predecessor = Some(ruma::events::room::create::PreviousRoom::new(
@ -714,11 +737,11 @@ pub(crate) async fn upgrade_room_route(
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: services.users.displayname(sender_user)?, displayname: services.users.displayname(sender_user).await.ok(),
avatar_url: services.users.avatar_url(sender_user)?, avatar_url: services.users.avatar_url(sender_user).await.ok(),
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: services.users.blurhash(sender_user)?, blurhash: services.users.blurhash(sender_user).await.ok(),
reason: None, reason: None,
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}) })
@ -739,10 +762,11 @@ pub(crate) async fn upgrade_room_route(
let event_content = match services let event_content = match services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, event_type, "")? .room_state_get(&body.room_id, event_type, "")
.await
{ {
Some(v) => v.content.clone(), Ok(v) => v.content.clone(),
None => continue, // Skipping missing events. Err(_) => continue, // Skipping missing events.
}; };
services services
@ -765,21 +789,23 @@ pub(crate) async fn upgrade_room_route(
} }
// Moves any local aliases to the new room // Moves any local aliases to the new room
for alias in services let mut local_aliases = services
.rooms .rooms
.alias .alias
.local_aliases_for_room(&body.room_id) .local_aliases_for_room(&body.room_id)
.filter_map(Result::ok) .boxed();
{
while let Some(alias) = local_aliases.next().await {
services services
.rooms .rooms
.alias .alias
.remove_alias(&alias, sender_user) .remove_alias(alias, sender_user)
.await?; .await?;
services services
.rooms .rooms
.alias .alias
.set_alias(&alias, &replacement_room, sender_user)?; .set_alias(alias, &replacement_room, sender_user)?;
} }
// Get the old room power levels // Get the old room power levels
@ -787,12 +813,12 @@ pub(crate) async fn upgrade_room_route(
services services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .await
.map_err(|_| err!(Database("Found room without m.room.create event.")))?
.content .content
.get(), .get(),
) )?;
.map_err(|_| Error::bad_database("Invalid room event in database."))?;
// Setting events_default and invite to the greater of 50 and users_default + 1 // Setting events_default and invite to the greater of 50 and users_default + 1
let new_level = max( let new_level = max(
@ -800,9 +826,7 @@ pub(crate) async fn upgrade_room_route(
power_levels_event_content power_levels_event_content
.users_default .users_default
.checked_add(int!(1)) .checked_add(int!(1))
.ok_or_else(|| { .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?,
Error::BadRequest(ErrorKind::BadJson, "users_default power levels event content is not valid")
})?,
); );
power_levels_event_content.events_default = new_level; power_levels_event_content.events_default = new_level;
power_levels_event_content.invite = new_level; power_levels_event_content.invite = new_level;
@ -921,8 +945,9 @@ async fn room_alias_check(
if services if services
.rooms .rooms
.alias .alias
.resolve_local_alias(&full_room_alias)? .resolve_local_alias(&full_room_alias)
.is_some() .await
.is_ok()
{ {
return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")); return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists."));
} }

View File

@ -1,6 +1,12 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State; use axum::extract::State;
use conduit::{
debug,
utils::{IterStream, ReadyExt},
Err,
};
use futures::{FutureExt, StreamExt};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -13,7 +19,6 @@ use ruma::{
serde::Raw, serde::Raw,
uint, OwnedRoomId, uint, OwnedRoomId,
}; };
use tracing::debug;
use crate::{Error, Result, Ruma}; use crate::{Error, Result, Ruma};
@ -32,14 +37,17 @@ pub(crate) async fn search_events_route(
let filter = &search_criteria.filter; let filter = &search_criteria.filter;
let include_state = &search_criteria.include_state; let include_state = &search_criteria.include_state;
let room_ids = filter.rooms.clone().unwrap_or_else(|| { let room_ids = if let Some(room_ids) = &filter.rooms {
room_ids.clone()
} else {
services services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect() .collect()
}); .await
};
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit: usize = filter let limit: usize = filter
@ -53,18 +61,21 @@ pub(crate) async fn search_events_route(
if include_state.is_some_and(|include_state| include_state) { if include_state.is_some_and(|include_state| include_state) {
for room_id in &room_ids { for room_id in &room_ids {
if !services.rooms.state_cache.is_joined(sender_user, room_id)? { if !services
return Err(Error::BadRequest( .rooms
ErrorKind::forbidden(), .state_cache
"You don't have permission to view this room.", .is_joined(sender_user, room_id)
)); .await
{
return Err!(Request(Forbidden("You don't have permission to view this room.")));
} }
// check if sender_user can see state events // check if sender_user can see state events
if services if services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, room_id)? .user_can_see_state_events(sender_user, room_id)
.await
{ {
let room_state = services let room_state = services
.rooms .rooms
@ -87,10 +98,15 @@ pub(crate) async fn search_events_route(
} }
} }
let mut searches = Vec::new(); let mut search_vecs = Vec::new();
for room_id in &room_ids { for room_id in &room_ids {
if !services.rooms.state_cache.is_joined(sender_user, room_id)? { if !services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -100,12 +116,18 @@ pub(crate) async fn search_events_route(
if let Some(search) = services if let Some(search) = services
.rooms .rooms
.search .search
.search_pdus(room_id, &search_criteria.search_term)? .search_pdus(room_id, &search_criteria.search_term)
.await
{ {
searches.push(search.0.peekable()); search_vecs.push(search.0);
} }
} }
let mut searches: Vec<_> = search_vecs
.iter()
.map(|vec| vec.iter().peekable())
.collect();
let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) { let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) {
Some(Ok(s)) => s, Some(Ok(s)) => s,
Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")),
@ -118,8 +140,8 @@ pub(crate) async fn search_events_route(
for _ in 0..next_batch { for _ in 0..next_batch {
if let Some(s) = searches if let Some(s) = searches
.iter_mut() .iter_mut()
.map(|s| (s.peek().cloned(), s)) .map(|s| (s.peek().copied(), s))
.max_by_key(|(peek, _)| peek.clone()) .max_by_key(|(peek, _)| *peek)
.and_then(|(_, i)| i.next()) .and_then(|(_, i)| i.next())
{ {
results.push(s); results.push(s);
@ -127,26 +149,22 @@ pub(crate) async fn search_events_route(
} }
let results: Vec<_> = results let results: Vec<_> = results
.iter() .into_iter()
.skip(skip) .skip(skip)
.filter_map(|result| { .stream()
.filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok))
.ready_filter(|pdu| !pdu.is_redacted())
.filter_map(|pdu| async move {
services services
.rooms
.timeline
.get_pdu_from_id(result)
.ok()?
.filter(|pdu| {
!pdu.is_redacted()
&& services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.unwrap_or(false) .await
.then_some(pdu)
}) })
.take(limit)
.map(|pdu| pdu.to_room_event()) .map(|pdu| pdu.to_room_event())
}) .map(|result| SearchResult {
.map(|result| {
Ok::<_, Error>(SearchResult {
context: EventContextResult { context: EventContextResult {
end: None, end: None,
events_after: Vec::new(), events_after: Vec::new(),
@ -157,12 +175,12 @@ pub(crate) async fn search_events_route(
rank: None, rank: None,
result: Some(result), result: Some(result),
}) })
}) .collect()
.filter_map(Result::ok) .boxed()
.take(limit) .await;
.collect();
let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some()); let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some());
let next_batch = more_unloaded_results.then(|| next_batch.to_string()); let next_batch = more_unloaded_results.then(|| next_batch.to_string());
Ok(search_events::v3::Response::new(ResultCategories { Ok(search_events::v3::Response::new(ResultCategories {

View File

@ -1,5 +1,7 @@
use axum::extract::State; use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{debug, err, info, utils::ReadyExt, warn, Err};
use futures::StreamExt;
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -19,7 +21,6 @@ use ruma::{
UserId, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
use tracing::{debug, info, warn};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{utils, utils::hash, Error, Result, Ruma}; use crate::{utils, utils::hash, Error, Result, Ruma};
@ -79,21 +80,22 @@ pub(crate) async fn login_route(
UserId::parse(user) UserId::parse(user)
} else { } else {
warn!("Bad login type: {:?}", &body.login_info); warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(ErrorKind::forbidden(), "Bad login type.")); return Err!(Request(Forbidden("Bad login type.")));
} }
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
let hash = services let hash = services
.users .users
.password_hash(&user_id)? .password_hash(&user_id)
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?; .await
.map_err(|_| err!(Request(Forbidden("Wrong username or password."))))?;
if hash.is_empty() { if hash.is_empty() {
return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated")); return Err!(Request(UserDeactivated("The user has been deactivated")));
} }
if hash::verify_password(password, &hash).is_err() { if hash::verify_password(password, &hash).is_err() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password.")); return Err!(Request(Forbidden("Wrong username or password.")));
} }
user_id user_id
@ -112,15 +114,12 @@ pub(crate) async fn login_route(
let username = token.claims.sub.to_lowercase(); let username = token.claims.sub.to_lowercase();
UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| { UserId::parse_with_server_name(username, services.globals.server_name())
warn!("Failed to parse username from user logging in: {e}"); .map_err(|e| err!(Request(InvalidUsername(debug_error!(?e, "Failed to parse login username")))))?
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?
} else { } else {
return Err(Error::BadRequest( return Err!(Request(Unknown(
ErrorKind::Unknown, "Token login is not supported (server has no jwt decoding key)."
"Token login is not supported (server has no jwt decoding key).", )));
));
} }
}, },
#[allow(deprecated)] #[allow(deprecated)]
@ -169,23 +168,32 @@ pub(crate) async fn login_route(
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Determine if device_id was provided and exists in the db for this user // Determine if device_id was provided and exists in the db for this user
let device_exists = body.device_id.as_ref().map_or(false, |device_id| { let device_exists = if body.device_id.is_some() {
services services
.users .users
.all_device_ids(&user_id) .all_device_ids(&user_id)
.any(|x| x.as_ref().map_or(false, |v| v == device_id)) .ready_any(|v| v == device_id)
}); .await
} else {
false
};
if device_exists { if device_exists {
services.users.set_token(&user_id, &device_id, &token)?; services
.users
.set_token(&user_id, &device_id, &token)
.await?;
} else { } else {
services.users.create_device( services
.users
.create_device(
&user_id, &user_id,
&device_id, &device_id,
&token, &token,
body.initial_device_display_name.clone(), body.initial_device_display_name.clone(),
Some(client.to_string()), Some(client.to_string()),
)?; )
.await?;
} }
// send client well-known if specified so the client knows to reconfigure itself // send client well-known if specified so the client knows to reconfigure itself
@ -228,10 +236,13 @@ pub(crate) async fn logout_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
services.users.remove_device(sender_user, sender_device)?; services
.users
.remove_device(sender_user, sender_device)
.await;
// send device list update for user after logout // send device list update for user after logout
services.users.mark_device_key_update(sender_user)?; services.users.mark_device_key_update(sender_user).await;
Ok(logout::v3::Response::new()) Ok(logout::v3::Response::new())
} }
@ -256,12 +267,14 @@ pub(crate) async fn logout_all_route(
) -> Result<logout_all::v3::Response> { ) -> Result<logout_all::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for device_id in services.users.all_device_ids(sender_user).flatten() { services
services.users.remove_device(sender_user, &device_id)?; .users
} .all_device_ids(sender_user)
.for_each(|device_id| services.users.remove_device(sender_user, device_id))
.await;
// send device list update for user after logout // send device list update for user after logout
services.users.mark_device_key_update(sender_user)?; services.users.mark_device_key_update(sender_user).await;
Ok(logout_all::v3::Response::new()) Ok(logout_all::v3::Response::new())
} }

View File

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State; use axum::extract::State;
use conduit::{debug_info, error, pdu::PduBuilder, Error, Result}; use conduit::{err, error, pdu::PduBuilder, Err, Error, Result};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -84,12 +84,10 @@ pub(crate) async fn get_state_events_route(
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)
.await
{ {
return Err(Error::BadRequest( return Err!(Request(Forbidden("You don't have permission to view the room state.")));
ErrorKind::forbidden(),
"You don't have permission to view the room state.",
));
} }
Ok(get_state_events::v3::Response { Ok(get_state_events::v3::Response {
@ -120,22 +118,25 @@ pub(crate) async fn get_state_events_for_key_route(
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)
.await
{ {
return Err(Error::BadRequest( return Err!(Request(Forbidden("You don't have permission to view the room state.")));
ErrorKind::forbidden(),
"You don't have permission to view the room state.",
));
} }
let event = services let event = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &body.event_type, &body.state_key)? .room_state_get(&body.room_id, &body.event_type, &body.state_key)
.ok_or_else(|| { .await
debug_info!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); .map_err(|_| {
Error::BadRequest(ErrorKind::NotFound, "State event not found.") err!(Request(NotFound(error!(
room_id = ?body.room_id,
event_type = ?body.event_type,
"State event not found in room.",
))))
})?; })?;
if body if body
.format .format
.as_ref() .as_ref()
@ -204,7 +205,7 @@ async fn send_state_event_for_key_helper(
async fn allowed_to_send_state_event( async fn allowed_to_send_state_event(
services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>,
) -> Result<()> { ) -> Result {
match event_type { match event_type {
// Forbid m.room.encryption if encryption is disabled // Forbid m.room.encryption if encryption is disabled
StateEventType::RoomEncryption => { StateEventType::RoomEncryption => {
@ -214,7 +215,7 @@ async fn allowed_to_send_state_event(
}, },
// admin room is a sensitive room, it should not ever be made public // admin room is a sensitive room, it should not ever be made public
StateEventType::RoomJoinRules => { StateEventType::RoomJoinRules => {
if let Some(admin_room_id) = services.admin.get_admin_room()? { if let Ok(admin_room_id) = services.admin.get_admin_room().await {
if admin_room_id == room_id { if admin_room_id == room_id {
if let Ok(join_rule) = serde_json::from_str::<RoomJoinRulesEventContent>(json.json().get()) { if let Ok(join_rule) = serde_json::from_str::<RoomJoinRulesEventContent>(json.json().get()) {
if join_rule.join_rule == JoinRule::Public { if join_rule.join_rule == JoinRule::Public {
@ -229,7 +230,7 @@ async fn allowed_to_send_state_event(
}, },
// admin room is a sensitive room, it should not ever be made world readable // admin room is a sensitive room, it should not ever be made world readable
StateEventType::RoomHistoryVisibility => { StateEventType::RoomHistoryVisibility => {
if let Some(admin_room_id) = services.admin.get_admin_room()? { if let Ok(admin_room_id) = services.admin.get_admin_room().await {
if admin_room_id == room_id { if admin_room_id == room_id {
if let Ok(visibility_content) = if let Ok(visibility_content) =
serde_json::from_str::<RoomHistoryVisibilityEventContent>(json.json().get()) serde_json::from_str::<RoomHistoryVisibilityEventContent>(json.json().get())
@ -254,23 +255,27 @@ async fn allowed_to_send_state_event(
} }
for alias in aliases { for alias in aliases {
if !services.globals.server_is_ours(alias.server_name()) if !services.globals.server_is_ours(alias.server_name()) {
|| services return Err!(Request(Forbidden("canonical_alias must be for this server")));
}
if !services
.rooms .rooms
.alias .alias
.resolve_local_alias(&alias)? .resolve_local_alias(&alias)
.filter(|room| room == room_id) // Make sure it's the right room .await
.is_none() .is_ok_and(|room| room == room_id)
// Make sure it's the right room
{ {
return Err(Error::BadRequest( return Err!(Request(Forbidden(
ErrorKind::forbidden(), "You are only allowed to send canonical_alias events when its aliases already exist"
"You are only allowed to send canonical_alias events when its aliases already exist", )));
));
} }
} }
} }
}, },
_ => (), _ => (),
} }
Ok(()) Ok(())
} }

File diff suppressed because it is too large Load Diff

View File

@ -23,10 +23,11 @@ pub(crate) async fn update_tag_route(
let event = services let event = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else( let mut tags_event = event.map_or_else(
|| { |_| {
Ok(TagEvent { Ok(TagEvent {
content: TagEventContent { content: TagEventContent {
tags: BTreeMap::new(), tags: BTreeMap::new(),
@ -41,12 +42,15 @@ pub(crate) async fn update_tag_route(
.tags .tags
.insert(body.tag.clone().into(), body.tag_info.clone()); .insert(body.tag.clone().into(), body.tag_info.clone());
services.account_data.update( services
.account_data
.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"), &serde_json::to_value(tags_event).expect("to json value always works"),
)?; )
.await?;
Ok(create_tag::v3::Response {}) Ok(create_tag::v3::Response {})
} }
@ -63,10 +67,11 @@ pub(crate) async fn delete_tag_route(
let event = services let event = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)
.await;
let mut tags_event = event.map_or_else( let mut tags_event = event.map_or_else(
|| { |_| {
Ok(TagEvent { Ok(TagEvent {
content: TagEventContent { content: TagEventContent {
tags: BTreeMap::new(), tags: BTreeMap::new(),
@ -78,12 +83,15 @@ pub(crate) async fn delete_tag_route(
tags_event.content.tags.remove(&body.tag.clone().into()); tags_event.content.tags.remove(&body.tag.clone().into());
services.account_data.update( services
.account_data
.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&serde_json::to_value(tags_event).expect("to json value always works"), &serde_json::to_value(tags_event).expect("to json value always works"),
)?; )
.await?;
Ok(delete_tag::v3::Response {}) Ok(delete_tag::v3::Response {})
} }
@ -100,10 +108,11 @@ pub(crate) async fn get_tags_route(
let event = services let event = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)
.await;
let tags_event = event.map_or_else( let tags_event = event.map_or_else(
|| { |_| {
Ok(TagEvent { Ok(TagEvent {
content: TagEventContent { content: TagEventContent {
tags: BTreeMap::new(), tags: BTreeMap::new(),

View File

@ -1,4 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduit::PduEvent;
use futures::StreamExt;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, threads::get_threads}, api::client::{error::ErrorKind, threads::get_threads},
uint, uint,
@ -27,20 +29,23 @@ pub(crate) async fn get_threads_route(
u64::MAX u64::MAX
}; };
let threads = services let room_id = &body.room_id;
let threads: Vec<(u64, PduEvent)> = services
.rooms .rooms
.threads .threads
.threads_until(sender_user, &body.room_id, from, &body.include)? .threads_until(sender_user, &body.room_id, from, &body.include)
.await?
.take(limit) .take(limit)
.filter_map(Result::ok) .filter_map(|(count, pdu)| async move {
.filter(|(_, pdu)| {
services services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .await
.then_some((count, pdu))
}) })
.collect::<Vec<_>>(); .collect()
.await;
let next_batch = threads.last().map(|(count, _)| count.to_string()); let next_batch = threads.last().map(|(count, _)| count.to_string());

View File

@ -2,6 +2,7 @@ use std::collections::BTreeMap;
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use futures::StreamExt;
use ruma::{ use ruma::{
api::{ api::{
client::{error::ErrorKind, to_device::send_event_to_device}, client::{error::ErrorKind, to_device::send_event_to_device},
@ -24,8 +25,9 @@ pub(crate) async fn send_event_to_device_route(
// Check if this is a new transaction id // Check if this is a new transaction id
if services if services
.transaction_ids .transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)? .existing_txnid(sender_user, sender_device, &body.txn_id)
.is_some() .await
.is_ok()
{ {
return Ok(send_event_to_device::v3::Response {}); return Ok(send_event_to_device::v3::Response {});
} }
@ -53,31 +55,35 @@ pub(crate) async fn send_event_to_device_route(
continue; continue;
} }
let event_type = &body.event_type.to_string();
let event = event
.deserialize_as()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?;
match target_device_id_maybe { match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => { DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services
.users
.add_to_device_event(sender_user, target_user_id, target_device_id, event_type, event)
.await;
},
DeviceIdOrAllDevices::AllDevices => {
let (event_type, event) = (&event_type, &event);
services
.users
.all_device_ids(target_user_id)
.for_each(|target_device_id| {
services.users.add_to_device_event( services.users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
target_device_id, target_device_id,
&body.event_type.to_string(), event_type,
event event.clone(),
.deserialize_as() )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, })
)?; .await;
},
DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services.users.all_device_ids(target_user_id) {
services.users.add_to_device_event(
sender_user,
target_user_id,
&target_device_id?,
&body.event_type.to_string(),
event
.deserialize_as()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
)?;
}
}, },
} }
} }
@ -86,7 +92,7 @@ pub(crate) async fn send_event_to_device_route(
// Save transaction id with empty data // Save transaction id with empty data
services services
.transaction_ids .transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?; .add_txnid(sender_user, sender_device, &body.txn_id, &[]);
Ok(send_event_to_device::v3::Response {}) Ok(send_event_to_device::v3::Response {})
} }

View File

@ -16,7 +16,8 @@ pub(crate) async fn create_typing_event_route(
if !services if !services
.rooms .rooms
.state_cache .state_cache
.is_joined(sender_user, &body.room_id)? .is_joined(sender_user, &body.room_id)
.await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "You are not in this room.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "You are not in this room."));
} }

View File

@ -2,7 +2,8 @@ use std::collections::BTreeMap;
use axum::extract::State; use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{warn, Err}; use conduit::Err;
use futures::StreamExt;
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -45,7 +46,7 @@ pub(crate) async fn get_mutual_rooms_route(
)); ));
} }
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
return Ok(mutual_rooms::unstable::Response { return Ok(mutual_rooms::unstable::Response {
joined: vec![], joined: vec![],
next_batch_token: None, next_batch_token: None,
@ -55,9 +56,10 @@ pub(crate) async fn get_mutual_rooms_route(
let mutual_rooms: Vec<OwnedRoomId> = services let mutual_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? .get_shared_rooms(sender_user, &body.user_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(); .collect()
.await;
Ok(mutual_rooms::unstable::Response { Ok(mutual_rooms::unstable::Response {
joined: mutual_rooms, joined: mutual_rooms,
@ -99,7 +101,7 @@ pub(crate) async fn get_room_summary(
let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?; let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?;
if !services.rooms.metadata.exists(&room_id)? { if !services.rooms.metadata.exists(&room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"));
} }
@ -108,7 +110,7 @@ pub(crate) async fn get_room_summary(
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&room_id) .is_world_readable(&room_id)
.unwrap_or(false) .await
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
@ -122,50 +124,58 @@ pub(crate) async fn get_room_summary(
.rooms .rooms
.state_accessor .state_accessor
.get_canonical_alias(&room_id) .get_canonical_alias(&room_id)
.unwrap_or(None), .await
.ok(),
avatar_url: services avatar_url: services
.rooms .rooms
.state_accessor .state_accessor
.get_avatar(&room_id)? .get_avatar(&room_id)
.await
.into_option() .into_option()
.unwrap_or_default() .unwrap_or_default()
.url, .url,
guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?, guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await,
name: services name: services.rooms.state_accessor.get_name(&room_id).await.ok(),
.rooms
.state_accessor
.get_name(&room_id)
.unwrap_or(None),
num_joined_members: services num_joined_members: services
.rooms .rooms
.state_cache .state_cache
.room_joined_count(&room_id) .room_joined_count(&room_id)
.unwrap_or_default() .await
.unwrap_or_else(|| { .unwrap_or(0)
warn!("Room {room_id} has no member count"); .try_into()?,
0
})
.try_into()
.expect("user count should not be that big"),
topic: services topic: services
.rooms .rooms
.state_accessor .state_accessor
.get_room_topic(&room_id) .get_room_topic(&room_id)
.unwrap_or(None), .await
.ok(),
world_readable: services world_readable: services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&room_id) .is_world_readable(&room_id)
.unwrap_or(false), .await,
join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0, join_rule: services
room_type: services.rooms.state_accessor.get_room_type(&room_id)?, .rooms
room_version: Some(services.rooms.state.get_room_version(&room_id)?), .state_accessor
.get_join_rule(&room_id)
.await
.unwrap_or_default()
.0,
room_type: services
.rooms
.state_accessor
.get_room_type(&room_id)
.await
.ok(),
room_version: services.rooms.state.get_room_version(&room_id).await.ok(),
membership: if let Some(sender_user) = sender_user { membership: if let Some(sender_user) = sender_user {
services services
.rooms .rooms
.state_accessor .state_accessor
.get_member(&room_id, sender_user)? .get_member(&room_id, sender_user)
.map_or_else(|| Some(MembershipState::Leave), |content| Some(content.membership)) .await
.map_or_else(|_| MembershipState::Leave, |content| content.membership)
.into()
} else { } else {
None None
}, },
@ -173,7 +183,8 @@ pub(crate) async fn get_room_summary(
.rooms .rooms
.state_accessor .state_accessor
.get_room_encryption(&room_id) .get_room_encryption(&room_id)
.unwrap_or_else(|_e| None), .await
.ok(),
}) })
} }
@ -191,13 +202,14 @@ pub(crate) async fn delete_timezone_key_route(
return Err!(Request(Forbidden("You cannot update the profile of another user"))); return Err!(Request(Forbidden("You cannot update the profile of another user")));
} }
services.users.set_timezone(&body.user_id, None).await?; services.users.set_timezone(&body.user_id, None);
if services.globals.allow_local_presence() { if services.globals.allow_local_presence() {
// Presence update // Presence update
services services
.presence .presence
.ping_presence(&body.user_id, &PresenceState::Online)?; .ping_presence(&body.user_id, &PresenceState::Online)
.await?;
} }
Ok(delete_timezone_key::unstable::Response {}) Ok(delete_timezone_key::unstable::Response {})
@ -217,16 +229,14 @@ pub(crate) async fn set_timezone_key_route(
return Err!(Request(Forbidden("You cannot update the profile of another user"))); return Err!(Request(Forbidden("You cannot update the profile of another user")));
} }
services services.users.set_timezone(&body.user_id, body.tz.clone());
.users
.set_timezone(&body.user_id, body.tz.clone())
.await?;
if services.globals.allow_local_presence() { if services.globals.allow_local_presence() {
// Presence update // Presence update
services services
.presence .presence
.ping_presence(&body.user_id, &PresenceState::Online)?; .ping_presence(&body.user_id, &PresenceState::Online)
.await?;
} }
Ok(set_timezone_key::unstable::Response {}) Ok(set_timezone_key::unstable::Response {})
@ -280,10 +290,11 @@ pub(crate) async fn set_profile_key_route(
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&body.user_id) .rooms_joined(&body.user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), all_joined_rooms).await?; update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), &all_joined_rooms).await?;
} else if body.key == "avatar_url" { } else if body.key == "avatar_url" {
let mxc = ruma::OwnedMxcUri::from(profile_key_value.to_string()); let mxc = ruma::OwnedMxcUri::from(profile_key_value.to_string());
@ -291,21 +302,23 @@ pub(crate) async fn set_profile_key_route(
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&body.user_id) .rooms_joined(&body.user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
update_avatar_url(&services, &body.user_id, Some(mxc), None, all_joined_rooms).await?; update_avatar_url(&services, &body.user_id, Some(mxc), None, &all_joined_rooms).await?;
} else { } else {
services services
.users .users
.set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()))?; .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()));
} }
if services.globals.allow_local_presence() { if services.globals.allow_local_presence() {
// Presence update // Presence update
services services
.presence .presence
.ping_presence(&body.user_id, &PresenceState::Online)?; .ping_presence(&body.user_id, &PresenceState::Online)
.await?;
} }
Ok(set_profile_key::unstable::Response {}) Ok(set_profile_key::unstable::Response {})
@ -335,30 +348,33 @@ pub(crate) async fn delete_profile_key_route(
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&body.user_id) .rooms_joined(&body.user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
update_displayname(&services, &body.user_id, None, all_joined_rooms).await?; update_displayname(&services, &body.user_id, None, &all_joined_rooms).await?;
} else if body.key == "avatar_url" { } else if body.key == "avatar_url" {
let all_joined_rooms: Vec<OwnedRoomId> = services let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&body.user_id) .rooms_joined(&body.user_id)
.filter_map(Result::ok) .map(Into::into)
.collect(); .collect()
.await;
update_avatar_url(&services, &body.user_id, None, None, all_joined_rooms).await?; update_avatar_url(&services, &body.user_id, None, None, &all_joined_rooms).await?;
} else { } else {
services services
.users .users
.set_profile_key(&body.user_id, &body.key, None)?; .set_profile_key(&body.user_id, &body.key, None);
} }
if services.globals.allow_local_presence() { if services.globals.allow_local_presence() {
// Presence update // Presence update
services services
.presence .presence
.ping_presence(&body.user_id, &PresenceState::Online)?; .ping_presence(&body.user_id, &PresenceState::Online)
.await?;
} }
Ok(delete_profile_key::unstable::Response {}) Ok(delete_profile_key::unstable::Response {})
@ -386,26 +402,25 @@ pub(crate) async fn get_timezone_key_route(
) )
.await .await
{ {
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone());
.await?;
services services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone());
.await?;
services services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone());
.await?;
services services
.users .users
.set_timezone(&body.user_id, response.tz.clone()) .set_timezone(&body.user_id, response.tz.clone());
.await?;
return Ok(get_timezone_key::unstable::Response { return Ok(get_timezone_key::unstable::Response {
tz: response.tz, tz: response.tz,
@ -413,14 +428,14 @@ pub(crate) async fn get_timezone_key_route(
} }
} }
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
} }
Ok(get_timezone_key::unstable::Response { Ok(get_timezone_key::unstable::Response {
tz: services.users.timezone(&body.user_id)?, tz: services.users.timezone(&body.user_id).await.ok(),
}) })
} }
@ -448,32 +463,31 @@ pub(crate) async fn get_profile_key_route(
) )
.await .await
{ {
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
services.users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone());
.await?;
services services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone());
.await?;
services services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone());
.await?;
services services
.users .users
.set_timezone(&body.user_id, response.tz.clone()) .set_timezone(&body.user_id, response.tz.clone());
.await?;
if let Some(value) = response.custom_profile_fields.get(&body.key) { if let Some(value) = response.custom_profile_fields.get(&body.key) {
profile_key_value.insert(body.key.clone(), value.clone()); profile_key_value.insert(body.key.clone(), value.clone());
services services
.users .users
.set_profile_key(&body.user_id, &body.key, Some(value.clone()))?; .set_profile_key(&body.user_id, &body.key, Some(value.clone()));
} else { } else {
return Err!(Request(NotFound("The requested profile key does not exist."))); return Err!(Request(NotFound("The requested profile key does not exist.")));
} }
@ -484,13 +498,13 @@ pub(crate) async fn get_profile_key_route(
} }
} }
if !services.users.exists(&body.user_id)? { if !services.users.exists(&body.user_id).await {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err!(Request(NotFound("Profile was not found.")));
} }
if let Some(value) = services.users.profile_key(&body.user_id, &body.key)? { if let Ok(value) = services.users.profile_key(&body.user_id, &body.key).await {
profile_key_value.insert(body.key.clone(), value); profile_key_value.insert(body.key.clone(), value);
} else { } else {
return Err!(Request(NotFound("The requested profile key does not exist."))); return Err!(Request(NotFound("The requested profile key does not exist.")));

View File

@ -1,6 +1,7 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use futures::StreamExt;
use ruma::api::client::{ use ruma::api::client::{
discovery::{ discovery::{
discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo}, discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo},
@ -173,7 +174,7 @@ pub(crate) async fn conduwuit_server_version() -> Result<impl IntoResponse> {
/// homeserver. Endpoint is disabled if federation is disabled for privacy. This /// homeserver. Endpoint is disabled if federation is disabled for privacy. This
/// only includes active users (not deactivated, no guests, etc) /// only includes active users (not deactivated, no guests, etc)
pub(crate) async fn conduwuit_local_user_count(State(services): State<crate::State>) -> Result<impl IntoResponse> { pub(crate) async fn conduwuit_local_user_count(State(services): State<crate::State>) -> Result<impl IntoResponse> {
let user_count = services.users.list_local_users()?.len(); let user_count = services.users.list_local_users().count().await;
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"count": user_count "count": user_count

View File

@ -1,4 +1,5 @@
use axum::extract::State; use axum::extract::State;
use futures::{pin_mut, StreamExt};
use ruma::{ use ruma::{
api::client::user_directory::search_users, api::client::user_directory::search_users,
events::{ events::{
@ -21,14 +22,12 @@ pub(crate) async fn search_users_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10
let mut users = services.users.iter().filter_map(|user_id| { let users = services.users.stream().filter_map(|user_id| async {
// Filter out buggy users (they should not exist, but you never know...) // Filter out buggy users (they should not exist, but you never know...)
let user_id = user_id.ok()?;
let user = search_users::v3::User { let user = search_users::v3::User {
user_id: user_id.clone(), user_id: user_id.to_owned(),
display_name: services.users.displayname(&user_id).ok()?, display_name: services.users.displayname(user_id).await.ok(),
avatar_url: services.users.avatar_url(&user_id).ok()?, avatar_url: services.users.avatar_url(user_id).await.ok(),
}; };
let user_id_matches = user let user_id_matches = user
@ -56,20 +55,19 @@ pub(crate) async fn search_users_route(
let user_is_in_public_rooms = services let user_is_in_public_rooms = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&user_id) .rooms_joined(&user.user_id)
.filter_map(Result::ok) .any(|room| async move {
.any(|room| {
services services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room, &StateEventType::RoomJoinRules, "") .room_state_get(room, &StateEventType::RoomJoinRules, "")
.await
.map_or(false, |event| { .map_or(false, |event| {
event.map_or(false, |event| {
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get())
.map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
}) })
}) })
}); .await;
if user_is_in_public_rooms { if user_is_in_public_rooms {
user_visible = true; user_visible = true;
@ -77,25 +75,22 @@ pub(crate) async fn search_users_route(
let user_is_in_shared_rooms = services let user_is_in_shared_rooms = services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), user_id]) .has_shared_rooms(sender_user, &user.user_id)
.ok()? .await;
.next()
.is_some();
if user_is_in_shared_rooms { if user_is_in_shared_rooms {
user_visible = true; user_visible = true;
} }
} }
if !user_visible { user_visible.then_some(user)
return None;
}
Some(user)
}); });
let results = users.by_ref().take(limit).collect(); pin_mut!(users);
let limited = users.next().is_some();
let limited = users.by_ref().next().await.is_some();
let results = users.take(limit).collect().await;
Ok(search_users::v3::Response { Ok(search_users::v3::Response {
results, results,

View File

@ -22,101 +22,101 @@ use crate::{client, server};
pub fn build(router: Router<State>, server: &Server) -> Router<State> { pub fn build(router: Router<State>, server: &Server) -> Router<State> {
let config = &server.config; let config = &server.config;
let mut router = router let mut router = router
.ruma_route(client::get_timezone_key_route) .ruma_route(&client::get_timezone_key_route)
.ruma_route(client::get_profile_key_route) .ruma_route(&client::get_profile_key_route)
.ruma_route(client::set_profile_key_route) .ruma_route(&client::set_profile_key_route)
.ruma_route(client::delete_profile_key_route) .ruma_route(&client::delete_profile_key_route)
.ruma_route(client::set_timezone_key_route) .ruma_route(&client::set_timezone_key_route)
.ruma_route(client::delete_timezone_key_route) .ruma_route(&client::delete_timezone_key_route)
.ruma_route(client::appservice_ping) .ruma_route(&client::appservice_ping)
.ruma_route(client::get_supported_versions_route) .ruma_route(&client::get_supported_versions_route)
.ruma_route(client::get_register_available_route) .ruma_route(&client::get_register_available_route)
.ruma_route(client::register_route) .ruma_route(&client::register_route)
.ruma_route(client::get_login_types_route) .ruma_route(&client::get_login_types_route)
.ruma_route(client::login_route) .ruma_route(&client::login_route)
.ruma_route(client::whoami_route) .ruma_route(&client::whoami_route)
.ruma_route(client::logout_route) .ruma_route(&client::logout_route)
.ruma_route(client::logout_all_route) .ruma_route(&client::logout_all_route)
.ruma_route(client::change_password_route) .ruma_route(&client::change_password_route)
.ruma_route(client::deactivate_route) .ruma_route(&client::deactivate_route)
.ruma_route(client::third_party_route) .ruma_route(&client::third_party_route)
.ruma_route(client::request_3pid_management_token_via_email_route) .ruma_route(&client::request_3pid_management_token_via_email_route)
.ruma_route(client::request_3pid_management_token_via_msisdn_route) .ruma_route(&client::request_3pid_management_token_via_msisdn_route)
.ruma_route(client::check_registration_token_validity) .ruma_route(&client::check_registration_token_validity)
.ruma_route(client::get_capabilities_route) .ruma_route(&client::get_capabilities_route)
.ruma_route(client::get_pushrules_all_route) .ruma_route(&client::get_pushrules_all_route)
.ruma_route(client::set_pushrule_route) .ruma_route(&client::set_pushrule_route)
.ruma_route(client::get_pushrule_route) .ruma_route(&client::get_pushrule_route)
.ruma_route(client::set_pushrule_enabled_route) .ruma_route(&client::set_pushrule_enabled_route)
.ruma_route(client::get_pushrule_enabled_route) .ruma_route(&client::get_pushrule_enabled_route)
.ruma_route(client::get_pushrule_actions_route) .ruma_route(&client::get_pushrule_actions_route)
.ruma_route(client::set_pushrule_actions_route) .ruma_route(&client::set_pushrule_actions_route)
.ruma_route(client::delete_pushrule_route) .ruma_route(&client::delete_pushrule_route)
.ruma_route(client::get_room_event_route) .ruma_route(&client::get_room_event_route)
.ruma_route(client::get_room_aliases_route) .ruma_route(&client::get_room_aliases_route)
.ruma_route(client::get_filter_route) .ruma_route(&client::get_filter_route)
.ruma_route(client::create_filter_route) .ruma_route(&client::create_filter_route)
.ruma_route(client::create_openid_token_route) .ruma_route(&client::create_openid_token_route)
.ruma_route(client::set_global_account_data_route) .ruma_route(&client::set_global_account_data_route)
.ruma_route(client::set_room_account_data_route) .ruma_route(&client::set_room_account_data_route)
.ruma_route(client::get_global_account_data_route) .ruma_route(&client::get_global_account_data_route)
.ruma_route(client::get_room_account_data_route) .ruma_route(&client::get_room_account_data_route)
.ruma_route(client::set_displayname_route) .ruma_route(&client::set_displayname_route)
.ruma_route(client::get_displayname_route) .ruma_route(&client::get_displayname_route)
.ruma_route(client::set_avatar_url_route) .ruma_route(&client::set_avatar_url_route)
.ruma_route(client::get_avatar_url_route) .ruma_route(&client::get_avatar_url_route)
.ruma_route(client::get_profile_route) .ruma_route(&client::get_profile_route)
.ruma_route(client::set_presence_route) .ruma_route(&client::set_presence_route)
.ruma_route(client::get_presence_route) .ruma_route(&client::get_presence_route)
.ruma_route(client::upload_keys_route) .ruma_route(&client::upload_keys_route)
.ruma_route(client::get_keys_route) .ruma_route(&client::get_keys_route)
.ruma_route(client::claim_keys_route) .ruma_route(&client::claim_keys_route)
.ruma_route(client::create_backup_version_route) .ruma_route(&client::create_backup_version_route)
.ruma_route(client::update_backup_version_route) .ruma_route(&client::update_backup_version_route)
.ruma_route(client::delete_backup_version_route) .ruma_route(&client::delete_backup_version_route)
.ruma_route(client::get_latest_backup_info_route) .ruma_route(&client::get_latest_backup_info_route)
.ruma_route(client::get_backup_info_route) .ruma_route(&client::get_backup_info_route)
.ruma_route(client::add_backup_keys_route) .ruma_route(&client::add_backup_keys_route)
.ruma_route(client::add_backup_keys_for_room_route) .ruma_route(&client::add_backup_keys_for_room_route)
.ruma_route(client::add_backup_keys_for_session_route) .ruma_route(&client::add_backup_keys_for_session_route)
.ruma_route(client::delete_backup_keys_for_room_route) .ruma_route(&client::delete_backup_keys_for_room_route)
.ruma_route(client::delete_backup_keys_for_session_route) .ruma_route(&client::delete_backup_keys_for_session_route)
.ruma_route(client::delete_backup_keys_route) .ruma_route(&client::delete_backup_keys_route)
.ruma_route(client::get_backup_keys_for_room_route) .ruma_route(&client::get_backup_keys_for_room_route)
.ruma_route(client::get_backup_keys_for_session_route) .ruma_route(&client::get_backup_keys_for_session_route)
.ruma_route(client::get_backup_keys_route) .ruma_route(&client::get_backup_keys_route)
.ruma_route(client::set_read_marker_route) .ruma_route(&client::set_read_marker_route)
.ruma_route(client::create_receipt_route) .ruma_route(&client::create_receipt_route)
.ruma_route(client::create_typing_event_route) .ruma_route(&client::create_typing_event_route)
.ruma_route(client::create_room_route) .ruma_route(&client::create_room_route)
.ruma_route(client::redact_event_route) .ruma_route(&client::redact_event_route)
.ruma_route(client::report_event_route) .ruma_route(&client::report_event_route)
.ruma_route(client::create_alias_route) .ruma_route(&client::create_alias_route)
.ruma_route(client::delete_alias_route) .ruma_route(&client::delete_alias_route)
.ruma_route(client::get_alias_route) .ruma_route(&client::get_alias_route)
.ruma_route(client::join_room_by_id_route) .ruma_route(&client::join_room_by_id_route)
.ruma_route(client::join_room_by_id_or_alias_route) .ruma_route(&client::join_room_by_id_or_alias_route)
.ruma_route(client::joined_members_route) .ruma_route(&client::joined_members_route)
.ruma_route(client::leave_room_route) .ruma_route(&client::leave_room_route)
.ruma_route(client::forget_room_route) .ruma_route(&client::forget_room_route)
.ruma_route(client::joined_rooms_route) .ruma_route(&client::joined_rooms_route)
.ruma_route(client::kick_user_route) .ruma_route(&client::kick_user_route)
.ruma_route(client::ban_user_route) .ruma_route(&client::ban_user_route)
.ruma_route(client::unban_user_route) .ruma_route(&client::unban_user_route)
.ruma_route(client::invite_user_route) .ruma_route(&client::invite_user_route)
.ruma_route(client::set_room_visibility_route) .ruma_route(&client::set_room_visibility_route)
.ruma_route(client::get_room_visibility_route) .ruma_route(&client::get_room_visibility_route)
.ruma_route(client::get_public_rooms_route) .ruma_route(&client::get_public_rooms_route)
.ruma_route(client::get_public_rooms_filtered_route) .ruma_route(&client::get_public_rooms_filtered_route)
.ruma_route(client::search_users_route) .ruma_route(&client::search_users_route)
.ruma_route(client::get_member_events_route) .ruma_route(&client::get_member_events_route)
.ruma_route(client::get_protocols_route) .ruma_route(&client::get_protocols_route)
.route("/_matrix/client/unstable/thirdparty/protocols", .route("/_matrix/client/unstable/thirdparty/protocols",
get(client::get_protocols_route_unstable)) get(client::get_protocols_route_unstable))
.ruma_route(client::send_message_event_route) .ruma_route(&client::send_message_event_route)
.ruma_route(client::send_state_event_for_key_route) .ruma_route(&client::send_state_event_for_key_route)
.ruma_route(client::get_state_events_route) .ruma_route(&client::get_state_events_route)
.ruma_route(client::get_state_events_for_key_route) .ruma_route(&client::get_state_events_for_key_route)
// Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes
// share one Ruma request / response type pair with {get,send}_state_event_for_key_route // share one Ruma request / response type pair with {get,send}_state_event_for_key_route
.route( .route(
@ -140,46 +140,46 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
get(client::get_state_events_for_empty_key_route) get(client::get_state_events_for_empty_key_route)
.put(client::send_state_event_for_empty_key_route), .put(client::send_state_event_for_empty_key_route),
) )
.ruma_route(client::sync_events_route) .ruma_route(&client::sync_events_route)
.ruma_route(client::sync_events_v4_route) .ruma_route(&client::sync_events_v4_route)
.ruma_route(client::get_context_route) .ruma_route(&client::get_context_route)
.ruma_route(client::get_message_events_route) .ruma_route(&client::get_message_events_route)
.ruma_route(client::search_events_route) .ruma_route(&client::search_events_route)
.ruma_route(client::turn_server_route) .ruma_route(&client::turn_server_route)
.ruma_route(client::send_event_to_device_route) .ruma_route(&client::send_event_to_device_route)
.ruma_route(client::create_content_route) .ruma_route(&client::create_content_route)
.ruma_route(client::get_content_thumbnail_route) .ruma_route(&client::get_content_thumbnail_route)
.ruma_route(client::get_content_route) .ruma_route(&client::get_content_route)
.ruma_route(client::get_content_as_filename_route) .ruma_route(&client::get_content_as_filename_route)
.ruma_route(client::get_media_preview_route) .ruma_route(&client::get_media_preview_route)
.ruma_route(client::get_media_config_route) .ruma_route(&client::get_media_config_route)
.ruma_route(client::get_devices_route) .ruma_route(&client::get_devices_route)
.ruma_route(client::get_device_route) .ruma_route(&client::get_device_route)
.ruma_route(client::update_device_route) .ruma_route(&client::update_device_route)
.ruma_route(client::delete_device_route) .ruma_route(&client::delete_device_route)
.ruma_route(client::delete_devices_route) .ruma_route(&client::delete_devices_route)
.ruma_route(client::get_tags_route) .ruma_route(&client::get_tags_route)
.ruma_route(client::update_tag_route) .ruma_route(&client::update_tag_route)
.ruma_route(client::delete_tag_route) .ruma_route(&client::delete_tag_route)
.ruma_route(client::upload_signing_keys_route) .ruma_route(&client::upload_signing_keys_route)
.ruma_route(client::upload_signatures_route) .ruma_route(&client::upload_signatures_route)
.ruma_route(client::get_key_changes_route) .ruma_route(&client::get_key_changes_route)
.ruma_route(client::get_pushers_route) .ruma_route(&client::get_pushers_route)
.ruma_route(client::set_pushers_route) .ruma_route(&client::set_pushers_route)
.ruma_route(client::upgrade_room_route) .ruma_route(&client::upgrade_room_route)
.ruma_route(client::get_threads_route) .ruma_route(&client::get_threads_route)
.ruma_route(client::get_relating_events_with_rel_type_and_event_type_route) .ruma_route(&client::get_relating_events_with_rel_type_and_event_type_route)
.ruma_route(client::get_relating_events_with_rel_type_route) .ruma_route(&client::get_relating_events_with_rel_type_route)
.ruma_route(client::get_relating_events_route) .ruma_route(&client::get_relating_events_route)
.ruma_route(client::get_hierarchy_route) .ruma_route(&client::get_hierarchy_route)
.ruma_route(client::get_mutual_rooms_route) .ruma_route(&client::get_mutual_rooms_route)
.ruma_route(client::get_room_summary) .ruma_route(&client::get_room_summary)
.route( .route(
"/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary", "/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary",
get(client::get_room_summary_legacy) get(client::get_room_summary_legacy)
) )
.ruma_route(client::well_known_support) .ruma_route(&client::well_known_support)
.ruma_route(client::well_known_client) .ruma_route(&client::well_known_client)
.route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_conduwuit/server_version", get(client::conduwuit_server_version))
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync))
.route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync))
@ -187,35 +187,35 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
if config.allow_federation { if config.allow_federation {
router = router router = router
.ruma_route(server::get_server_version_route) .ruma_route(&server::get_server_version_route)
.route("/_matrix/key/v2/server", get(server::get_server_keys_route)) .route("/_matrix/key/v2/server", get(server::get_server_keys_route))
.route("/_matrix/key/v2/server/:key_id", get(server::get_server_keys_deprecated_route)) .route("/_matrix/key/v2/server/:key_id", get(server::get_server_keys_deprecated_route))
.ruma_route(server::get_public_rooms_route) .ruma_route(&server::get_public_rooms_route)
.ruma_route(server::get_public_rooms_filtered_route) .ruma_route(&server::get_public_rooms_filtered_route)
.ruma_route(server::send_transaction_message_route) .ruma_route(&server::send_transaction_message_route)
.ruma_route(server::get_event_route) .ruma_route(&server::get_event_route)
.ruma_route(server::get_backfill_route) .ruma_route(&server::get_backfill_route)
.ruma_route(server::get_missing_events_route) .ruma_route(&server::get_missing_events_route)
.ruma_route(server::get_event_authorization_route) .ruma_route(&server::get_event_authorization_route)
.ruma_route(server::get_room_state_route) .ruma_route(&server::get_room_state_route)
.ruma_route(server::get_room_state_ids_route) .ruma_route(&server::get_room_state_ids_route)
.ruma_route(server::create_leave_event_template_route) .ruma_route(&server::create_leave_event_template_route)
.ruma_route(server::create_leave_event_v1_route) .ruma_route(&server::create_leave_event_v1_route)
.ruma_route(server::create_leave_event_v2_route) .ruma_route(&server::create_leave_event_v2_route)
.ruma_route(server::create_join_event_template_route) .ruma_route(&server::create_join_event_template_route)
.ruma_route(server::create_join_event_v1_route) .ruma_route(&server::create_join_event_v1_route)
.ruma_route(server::create_join_event_v2_route) .ruma_route(&server::create_join_event_v2_route)
.ruma_route(server::create_invite_route) .ruma_route(&server::create_invite_route)
.ruma_route(server::get_devices_route) .ruma_route(&server::get_devices_route)
.ruma_route(server::get_room_information_route) .ruma_route(&server::get_room_information_route)
.ruma_route(server::get_profile_information_route) .ruma_route(&server::get_profile_information_route)
.ruma_route(server::get_keys_route) .ruma_route(&server::get_keys_route)
.ruma_route(server::claim_keys_route) .ruma_route(&server::claim_keys_route)
.ruma_route(server::get_openid_userinfo_route) .ruma_route(&server::get_openid_userinfo_route)
.ruma_route(server::get_hierarchy_route) .ruma_route(&server::get_hierarchy_route)
.ruma_route(server::well_known_server) .ruma_route(&server::well_known_server)
.ruma_route(server::get_content_route) .ruma_route(&server::get_content_route)
.ruma_route(server::get_content_thumbnail_route) .ruma_route(&server::get_content_thumbnail_route)
.route("/_conduwuit/local_user_count", get(client::conduwuit_local_user_count)); .route("/_conduwuit/local_user_count", get(client::conduwuit_local_user_count));
} else { } else {
router = router router = router
@ -227,11 +227,11 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
if config.allow_legacy_media { if config.allow_legacy_media {
router = router router = router
.ruma_route(client::get_media_config_legacy_route) .ruma_route(&client::get_media_config_legacy_route)
.ruma_route(client::get_media_preview_legacy_route) .ruma_route(&client::get_media_preview_legacy_route)
.ruma_route(client::get_content_legacy_route) .ruma_route(&client::get_content_legacy_route)
.ruma_route(client::get_content_as_filename_legacy_route) .ruma_route(&client::get_content_as_filename_legacy_route)
.ruma_route(client::get_content_thumbnail_legacy_route) .ruma_route(&client::get_content_thumbnail_legacy_route)
.route("/_matrix/media/v1/config", get(client::get_media_config_legacy_legacy_route)) .route("/_matrix/media/v1/config", get(client::get_media_config_legacy_legacy_route))
.route("/_matrix/media/v1/upload", post(client::create_content_legacy_route)) .route("/_matrix/media/v1/upload", post(client::create_content_legacy_route))
.route( .route(

View File

@ -10,7 +10,10 @@ use super::{auth, auth::Auth, request, request::Request};
use crate::{service::appservice::RegistrationInfo, State}; use crate::{service::appservice::RegistrationInfo, State};
/// Extractor for Ruma request structs /// Extractor for Ruma request structs
pub(crate) struct Args<T> { pub(crate) struct Args<T>
where
T: IncomingRequest + Send + Sync + 'static,
{
/// Request struct body /// Request struct body
pub(crate) body: T, pub(crate) body: T,
@ -38,7 +41,7 @@ pub(crate) struct Args<T> {
#[async_trait] #[async_trait]
impl<T> FromRequest<State, Body> for Args<T> impl<T> FromRequest<State, Body> for Args<T>
where where
T: IncomingRequest, T: IncomingRequest + Send + Sync + 'static,
{ {
type Rejection = Error; type Rejection = Error;
@ -57,7 +60,10 @@ where
} }
} }
impl<T> Deref for Args<T> { impl<T> Deref for Args<T>
where
T: IncomingRequest + Send + Sync + 'static,
{
type Target = T; type Target = T;
fn deref(&self) -> &Self::Target { &self.body } fn deref(&self) -> &Self::Target { &self.body }
@ -67,7 +73,7 @@ fn make_body<T>(
services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth, services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
) -> Result<T> ) -> Result<T>
where where
T: IncomingRequest, T: IncomingRequest + Send + Sync + 'static,
{ {
let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body {
let user_id = auth.sender_user.clone().unwrap_or_else(|| { let user_id = auth.sender_user.clone().unwrap_or_else(|| {
@ -77,15 +83,13 @@ where
let uiaa_request = json_body let uiaa_request = json_body
.get("auth") .get("auth")
.and_then(|auth| auth.as_object()) .and_then(CanonicalJsonValue::as_object)
.and_then(|auth| auth.get("session")) .and_then(|auth| auth.get("session"))
.and_then(|session| session.as_str()) .and_then(CanonicalJsonValue::as_str)
.and_then(|session| { .and_then(|session| {
services.uiaa.get_uiaa_request( services
&user_id, .uiaa
&auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()), .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session)
session,
)
}); });
if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request {

View File

@ -44,8 +44,8 @@ pub(super) async fn auth(
let token = if let Some(token) = token { let token = if let Some(token) = token {
if let Some(reg_info) = services.appservice.find_from_token(token).await { if let Some(reg_info) = services.appservice.find_from_token(token).await {
Token::Appservice(Box::new(reg_info)) Token::Appservice(Box::new(reg_info))
} else if let Some((user_id, device_id)) = services.users.find_from_token(token)? { } else if let Ok((user_id, device_id)) = services.users.find_from_token(token).await {
Token::User((user_id, OwnedDeviceId::from(device_id))) Token::User((user_id, device_id))
} else { } else {
Token::Invalid Token::Invalid
} }
@ -98,7 +98,7 @@ pub(super) async fn auth(
)) ))
} }
}, },
(AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?), (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info).await?),
(AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => {
Ok(Auth { Ok(Auth {
origin: None, origin: None,
@ -150,7 +150,7 @@ pub(super) async fn auth(
} }
} }
fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> { async fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> {
let user_id = request let user_id = request
.query .query
.user_id .user_id
@ -170,7 +170,7 @@ fn auth_appservice(services: &Services, request: &Request, info: Box<Registratio
return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace.")); return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace."));
} }
if !services.users.exists(&user_id)? { if !services.users.exists(&user_id).await {
return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist."));
} }

View File

@ -1,5 +1,3 @@
use std::future::Future;
use axum::{ use axum::{
extract::FromRequestParts, extract::FromRequestParts,
response::IntoResponse, response::IntoResponse,
@ -7,19 +5,25 @@ use axum::{
Router, Router,
}; };
use conduit::Result; use conduit::Result;
use futures::{Future, TryFutureExt};
use http::Method; use http::Method;
use ruma::api::IncomingRequest; use ruma::api::IncomingRequest;
use super::{Ruma, RumaResponse, State}; use super::{Ruma, RumaResponse, State};
pub(in super::super) trait RumaHandler<T> {
fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State>;
fn add_routes(&'static self, router: Router<State>) -> Router<State>;
}
pub(in super::super) trait RouterExt { pub(in super::super) trait RouterExt {
fn ruma_route<H, T>(self, handler: H) -> Self fn ruma_route<H, T>(self, handler: &'static H) -> Self
where where
H: RumaHandler<T>; H: RumaHandler<T>;
} }
impl RouterExt for Router<State> { impl RouterExt for Router<State> {
fn ruma_route<H, T>(self, handler: H) -> Self fn ruma_route<H, T>(self, handler: &'static H) -> Self
where where
H: RumaHandler<T>, H: RumaHandler<T>,
{ {
@ -27,34 +31,28 @@ impl RouterExt for Router<State> {
} }
} }
pub(in super::super) trait RumaHandler<T> {
fn add_routes(&self, router: Router<State>) -> Router<State>;
fn add_route(&self, router: Router<State>, path: &str) -> Router<State>;
}
macro_rules! ruma_handler { macro_rules! ruma_handler {
( $($tx:ident),* $(,)? ) => { ( $($tx:ident),* $(,)? ) => {
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl<Req, Ret, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun impl<Err, Req, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun
where where
Req: IncomingRequest + Send + 'static, Fun: Fn($($tx,)* Ruma<Req>,) -> Fut + Send + Sync + 'static,
Ret: IntoResponse, Fut: Future<Output = Result<Req::OutgoingResponse, Err>> + Send,
Fut: Future<Output = Result<Req::OutgoingResponse, Ret>> + Send, Req: IncomingRequest + Send + Sync,
Fun: FnOnce($($tx,)* Ruma<Req>,) -> Fut + Clone + Send + Sync + 'static, Err: IntoResponse + Send,
$( $tx: FromRequestParts<State> + Send + 'static, )* <Req as IncomingRequest>::OutgoingResponse: Send,
$( $tx: FromRequestParts<State> + Send + Sync + 'static, )*
{ {
fn add_routes(&self, router: Router<State>) -> Router<State> { fn add_routes(&'static self, router: Router<State>) -> Router<State> {
Req::METADATA Req::METADATA
.history .history
.all_paths() .all_paths()
.fold(router, |router, path| self.add_route(router, path)) .fold(router, |router, path| self.add_route(router, path))
} }
fn add_route(&self, router: Router<State>, path: &str) -> Router<State> { fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State> {
let handle = self.clone(); let action = |$($tx,)* req| self($($tx,)* req).map_ok(RumaResponse);
let method = method_to_filter(&Req::METADATA.method); let method = method_to_filter(&Req::METADATA.method);
let action = |$($tx,)* req| async { handle($($tx,)* req).await.map(RumaResponse) };
router.route(path, on(method, action)) router.route(path, on(method, action))
} }
} }

View File

@ -5,13 +5,18 @@ use http::StatusCode;
use http_body_util::Full; use http_body_util::Full;
use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse};
pub(crate) struct RumaResponse<T>(pub(crate) T); pub(crate) struct RumaResponse<T>(pub(crate) T)
where
T: OutgoingResponse;
impl From<Error> for RumaResponse<UiaaResponse> { impl From<Error> for RumaResponse<UiaaResponse> {
fn from(t: Error) -> Self { Self(t.into()) } fn from(t: Error) -> Self { Self(t.into()) }
} }
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> { impl<T> IntoResponse for RumaResponse<T>
where
T: OutgoingResponse,
{
fn into_response(self) -> Response { fn into_response(self) -> Response {
self.0 self.0
.try_into_http_response::<BytesMut>() .try_into_http_response::<BytesMut>()

View File

@ -1,9 +1,13 @@
use std::cmp;
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{
use ruma::{ is_equal_to,
api::{client::error::ErrorKind, federation::backfill::get_backfill}, utils::{IterStream, ReadyExt},
uint, user_id, MilliSecondsSinceUnixEpoch, Err, PduCount, Result,
}; };
use futures::{FutureExt, StreamExt};
use ruma::{api::federation::backfill::get_backfill, uint, user_id, MilliSecondsSinceUnixEpoch};
use crate::Ruma; use crate::Ruma;
@ -19,27 +23,35 @@ pub(crate) async fn get_backfill_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)
&& !services .await && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)
.await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err!(Request(Forbidden("Server is not in room.")));
} }
let until = body let until = body
.v .v
.iter() .iter()
.map(|event_id| services.rooms.timeline.get_pdu_count(event_id)) .stream()
.filter_map(|r| r.ok().flatten()) .filter_map(|event_id| {
.max() services
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?; .rooms
.timeline
.get_pdu_count(event_id)
.map(Result::ok)
})
.ready_fold(PduCount::Backfilled(0), cmp::max)
.await;
let limit = body let limit = body
.limit .limit
@ -47,31 +59,37 @@ pub(crate) async fn get_backfill_route(
.try_into() .try_into()
.expect("UInt could not be converted to usize"); .expect("UInt could not be converted to usize");
let all_events = services let pdus = services
.rooms .rooms
.timeline .timeline
.pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)
.take(limit); .await?
.take(limit)
let events = all_events .filter_map(|(_, pdu)| async move {
.filter_map(Result::ok) if !services
.filter(|(_, e)| {
matches!(
services
.rooms .rooms
.state_accessor .state_accessor
.server_can_see_event(origin, &e.room_id, &e.event_id,), .server_can_see_event(origin, &pdu.room_id, &pdu.event_id)
Ok(true), .await
) .is_ok_and(is_equal_to!(true))
{
return None;
}
services
.rooms
.timeline
.get_pdu_json(&pdu.event_id)
.await
.ok()
}) })
.map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.filter_map(|r| r.ok().flatten()) .collect()
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .await;
.collect();
Ok(get_backfill::v1::Response { Ok(get_backfill::v1::Response {
origin: services.globals.server_name().to_owned(), origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdus: events, pdus,
}) })
} }

View File

@ -1,9 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{err, Err, Result};
use ruma::{ use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId};
api::{client::error::ErrorKind, federation::event::get_event},
MilliSecondsSinceUnixEpoch, RoomId,
};
use crate::Ruma; use crate::Ruma;
@ -21,34 +18,46 @@ pub(crate) async fn get_event_route(
let event = services let event = services
.rooms .rooms
.timeline .timeline
.get_pdu_json(&body.event_id)? .get_pdu_json(&body.event_id)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; .await
.map_err(|_| err!(Request(NotFound("Event not found."))))?;
let room_id_str = event let room_id_str = event
.get("room_id") .get("room_id")
.and_then(|val| val.as_str()) .and_then(|val| val.as_str())
.ok_or_else(|| Error::bad_database("Invalid event in database."))?; .ok_or_else(|| err!(Database("Invalid event in database.")))?;
let room_id = let room_id =
<&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; <&RoomId>::try_from(room_id_str).map_err(|_| err!(Database("Invalid room_id in event in database.")))?;
if !services.rooms.state_accessor.is_world_readable(room_id)? if !services
&& !services.rooms.state_cache.server_in_room(origin, room_id)? .rooms
.state_accessor
.is_world_readable(room_id)
.await && !services
.rooms
.state_cache
.server_in_room(origin, room_id)
.await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err!(Request(Forbidden("Server is not in room.")));
} }
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.server_can_see_event(origin, room_id, &body.event_id)? .server_can_see_event(origin, room_id, &body.event_id)
.await?
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not allowed to see event.")); return Err!(Request(Forbidden("Server is not allowed to see event.")));
} }
Ok(get_event::v1::Response { Ok(get_event::v1::Response {
origin: services.globals.server_name().to_owned(), origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdu: services.sending.convert_to_outgoing_federation_event(event), pdu: services
.sending
.convert_to_outgoing_federation_event(event)
.await,
}) })
} }

View File

@ -2,6 +2,7 @@ use std::sync::Arc;
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use futures::StreamExt;
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, api::{client::error::ErrorKind, federation::authorization::get_event_authorization},
RoomId, RoomId,
@ -22,16 +23,18 @@ pub(crate) async fn get_event_authorization_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)
&& !services .await && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)
.await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
} }
@ -39,8 +42,9 @@ pub(crate) async fn get_event_authorization_route(
let event = services let event = services
.rooms .rooms
.timeline .timeline
.get_pdu_json(&body.event_id)? .get_pdu_json(&body.event_id)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?;
let room_id_str = event let room_id_str = event
.get("room_id") .get("room_id")
@ -50,16 +54,17 @@ pub(crate) async fn get_event_authorization_route(
let room_id = let room_id =
<&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
let auth_chain_ids = services let auth_chain = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)])
.await?; .await?
.filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() })
.then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect()
.await;
Ok(get_event_authorization::v1::Response { Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids auth_chain,
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?)
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
}) })
} }

View File

@ -18,16 +18,18 @@ pub(crate) async fn get_missing_events_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)
&& !services .await && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)
.await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room"));
} }
@ -43,7 +45,12 @@ pub(crate) async fn get_missing_events_route(
let mut i: usize = 0; let mut i: usize = 0;
while i < queued_events.len() && events.len() < limit { while i < queued_events.len() && events.len() < limit {
if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? { if let Ok(pdu) = services
.rooms
.timeline
.get_pdu_json(&queued_events[i])
.await
{
let room_id_str = pdu let room_id_str = pdu
.get("room_id") .get("room_id")
.and_then(|val| val.as_str()) .and_then(|val| val.as_str())
@ -64,7 +71,8 @@ pub(crate) async fn get_missing_events_route(
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.server_can_see_event(origin, &body.room_id, &queued_events[i])? .server_can_see_event(origin, &body.room_id, &queued_events[i])
.await?
{ {
i = i.saturating_add(1); i = i.saturating_add(1);
continue; continue;
@ -81,7 +89,12 @@ pub(crate) async fn get_missing_events_route(
) )
.map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?,
); );
events.push(services.sending.convert_to_outgoing_federation_event(pdu)); events.push(
services
.sending
.convert_to_outgoing_federation_event(pdu)
.await,
);
} }
i = i.saturating_add(1); i = i.saturating_add(1);
} }

View File

@ -12,7 +12,7 @@ pub(crate) async fn get_hierarchy_route(
) -> Result<get_hierarchy::v1::Response> { ) -> Result<get_hierarchy::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
if services.rooms.metadata.exists(&body.room_id)? { if services.rooms.metadata.exists(&body.room_id).await {
services services
.rooms .rooms
.spaces .spaces

View File

@ -24,7 +24,8 @@ pub(crate) async fn create_invite_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
if !services if !services
.globals .globals
@ -98,7 +99,8 @@ pub(crate) async fn create_invite_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(invited_user.server_name(), &body.room_id)?; .acl_check(invited_user.server_name(), &body.room_id)
.await?;
ruma::signatures::hash_and_sign_event( ruma::signatures::hash_and_sign_event(
services.globals.server_name().as_str(), services.globals.server_name().as_str(),
@ -128,14 +130,14 @@ pub(crate) async fn create_invite_route(
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?;
if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? { if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"This room is banned on this homeserver.", "This room is banned on this homeserver.",
)); ));
} }
if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? { if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user).await {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"This server does not allow room invites.", "This server does not allow room invites.",
@ -159,9 +161,13 @@ pub(crate) async fn create_invite_route(
if !services if !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(services.globals.server_name(), &body.room_id)? .server_in_room(services.globals.server_name(), &body.room_id)
.await
{ {
services.rooms.state_cache.update_membership( services
.rooms
.state_cache
.update_membership(
&body.room_id, &body.room_id,
&invited_user, &invited_user,
RoomMemberEventContent::new(MembershipState::Invite), RoomMemberEventContent::new(MembershipState::Invite),
@ -169,12 +175,14 @@ pub(crate) async fn create_invite_route(
Some(invite_state), Some(invite_state),
body.via.clone(), body.via.clone(),
true, true,
)?; )
.await?;
} }
Ok(create_invite::v2::Response { Ok(create_invite::v2::Response {
event: services event: services
.sending .sending
.convert_to_outgoing_federation_event(signed_event), .convert_to_outgoing_federation_event(signed_event)
.await,
}) })
} }

View File

@ -1,4 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduit::utils::{IterStream, ReadyExt};
use futures::StreamExt;
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::membership::prepare_join_event}, api::{client::error::ErrorKind, federation::membership::prepare_join_event},
events::{ events::{
@ -24,7 +26,7 @@ use crate::{
pub(crate) async fn create_join_event_template_route( pub(crate) async fn create_join_event_template_route(
State(services): State<crate::State>, body: Ruma<prepare_join_event::v1::Request>, State(services): State<crate::State>, body: Ruma<prepare_join_event::v1::Request>,
) -> Result<prepare_join_event::v1::Response> { ) -> Result<prepare_join_event::v1::Response> {
if !services.rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
@ -40,7 +42,8 @@ pub(crate) async fn create_join_event_template_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
if services if services
.globals .globals
@ -73,7 +76,7 @@ pub(crate) async fn create_join_event_template_route(
} }
} }
let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
@ -81,22 +84,24 @@ pub(crate) async fn create_join_event_template_route(
.rooms .rooms
.state_cache .state_cache
.is_left(&body.user_id, &body.room_id) .is_left(&body.user_id, &body.room_id)
.unwrap_or(true)) .await)
&& user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id)? && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id).await?
{ {
let auth_user = services let auth_user = services
.rooms .rooms
.state_cache .state_cache
.room_members(&body.room_id) .room_members(&body.room_id)
.filter_map(Result::ok) .ready_filter(|user| user.server_name() == services.globals.server_name())
.filter(|user| user.server_name() == services.globals.server_name()) .filter(|user| {
.find(|user| {
services services
.rooms .rooms
.state_accessor .state_accessor
.user_can_invite(&body.room_id, user, &body.user_id, &state_lock) .user_can_invite(&body.room_id, user, &body.user_id, &state_lock)
.unwrap_or(false) })
}); .boxed()
.next()
.await
.map(ToOwned::to_owned);
if auth_user.is_some() { if auth_user.is_some() {
auth_user auth_user
@ -110,7 +115,7 @@ pub(crate) async fn create_join_event_template_route(
None None
}; };
let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
if !body.ver.contains(&room_version_id) { if !body.ver.contains(&room_version_id) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion { ErrorKind::IncompatibleRoomVersion {
@ -132,7 +137,10 @@ pub(crate) async fn create_join_event_template_route(
}) })
.expect("member event is valid value"); .expect("member event is valid value");
let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( let (_pdu, mut pdu_json) = services
.rooms
.timeline
.create_hash_and_sign_event(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content, content,
@ -144,7 +152,8 @@ pub(crate) async fn create_join_event_template_route(
&body.user_id, &body.user_id,
&body.room_id, &body.room_id,
&state_lock, &state_lock,
)?; )
.await?;
drop(state_lock); drop(state_lock);
@ -161,7 +170,7 @@ pub(crate) async fn create_join_event_template_route(
/// This doesn't check the current user's membership. This should be done /// This doesn't check the current user's membership. This should be done
/// externally, either by using the state cache or attempting to authorize the /// externally, either by using the state cache or attempting to authorize the
/// event. /// event.
pub(crate) fn user_can_perform_restricted_join( pub(crate) async fn user_can_perform_restricted_join(
services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId,
) -> Result<bool> { ) -> Result<bool> {
use RoomVersionId::*; use RoomVersionId::*;
@ -169,18 +178,15 @@ pub(crate) fn user_can_perform_restricted_join(
let join_rules_event = services let join_rules_event = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; .room_state_get(room_id, &StateEventType::RoomJoinRules, "")
.await;
let Some(join_rules_event_content) = join_rules_event let Ok(Ok(join_rules_event_content)) = join_rules_event.as_ref().map(|join_rules_event| {
.as_ref()
.map(|join_rules_event| {
serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get()).map_err(|e| { serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get()).map_err(|e| {
warn!("Invalid join rules event in database: {e}"); warn!("Invalid join rules event in database: {e}");
Error::bad_database("Invalid join rules event in database") Error::bad_database("Invalid join rules event in database")
}) })
}) }) else {
.transpose()?
else {
return Ok(false); return Ok(false);
}; };
@ -201,13 +207,10 @@ pub(crate) fn user_can_perform_restricted_join(
None None
} }
}) })
.any(|m| { .stream()
services .any(|m| services.rooms.state_cache.is_joined(user_id, &m.room_id))
.rooms .await
.state_cache {
.is_joined(user_id, &m.room_id)
.unwrap_or(false)
}) {
Ok(true) Ok(true)
} else { } else {
Err(Error::BadRequest( Err(Error::BadRequest(

View File

@ -18,7 +18,7 @@ use crate::{service::pdu::PduBuilder, Ruma};
pub(crate) async fn create_leave_event_template_route( pub(crate) async fn create_leave_event_template_route(
State(services): State<crate::State>, body: Ruma<prepare_leave_event::v1::Request>, State(services): State<crate::State>, body: Ruma<prepare_leave_event::v1::Request>,
) -> Result<prepare_leave_event::v1::Response> { ) -> Result<prepare_leave_event::v1::Response> {
if !services.rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
@ -34,9 +34,10 @@ pub(crate) async fn create_leave_event_template_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?;
let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let content = to_raw_value(&RoomMemberEventContent { let content = to_raw_value(&RoomMemberEventContent {
avatar_url: None, avatar_url: None,
@ -50,7 +51,10 @@ pub(crate) async fn create_leave_event_template_route(
}) })
.expect("member event is valid value"); .expect("member event is valid value");
let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( let (_pdu, mut pdu_json) = services
.rooms
.timeline
.create_hash_and_sign_event(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content, content,
@ -62,7 +66,8 @@ pub(crate) async fn create_leave_event_template_route(
&body.user_id, &body.user_id,
&body.room_id, &body.room_id,
&state_lock, &state_lock,
)?; )
.await?;
drop(state_lock); drop(state_lock);

View File

@ -10,6 +10,9 @@ pub(crate) async fn get_openid_userinfo_route(
State(services): State<crate::State>, body: Ruma<get_openid_userinfo::v1::Request>, State(services): State<crate::State>, body: Ruma<get_openid_userinfo::v1::Request>,
) -> Result<get_openid_userinfo::v1::Response> { ) -> Result<get_openid_userinfo::v1::Response> {
Ok(get_openid_userinfo::v1::Response::new( Ok(get_openid_userinfo::v1::Response::new(
services.users.find_from_openid_token(&body.access_token)?, services
.users
.find_from_openid_token(&body.access_token)
.await?,
)) ))
} }

View File

@ -1,7 +1,8 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{err, Error, Result};
use futures::StreamExt;
use get_profile_information::v1::ProfileField; use get_profile_information::v1::ProfileField;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use ruma::{ use ruma::{
@ -23,15 +24,17 @@ pub(crate) async fn get_room_information_route(
let room_id = services let room_id = services
.rooms .rooms
.alias .alias
.resolve_local_alias(&body.room_alias)? .resolve_local_alias(&body.room_alias)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; .await
.map_err(|_| err!(Request(NotFound("Room alias not found."))))?;
let mut servers: Vec<OwnedServerName> = services let mut servers: Vec<OwnedServerName> = services
.rooms .rooms
.state_cache .state_cache
.room_servers(&room_id) .room_servers(&room_id)
.filter_map(Result::ok) .map(ToOwned::to_owned)
.collect(); .collect()
.await;
servers.sort_unstable(); servers.sort_unstable();
servers.dedup(); servers.dedup();
@ -82,30 +85,31 @@ pub(crate) async fn get_profile_information_route(
match &body.field { match &body.field {
Some(ProfileField::DisplayName) => { Some(ProfileField::DisplayName) => {
displayname = services.users.displayname(&body.user_id)?; displayname = services.users.displayname(&body.user_id).await.ok();
}, },
Some(ProfileField::AvatarUrl) => { Some(ProfileField::AvatarUrl) => {
avatar_url = services.users.avatar_url(&body.user_id)?; avatar_url = services.users.avatar_url(&body.user_id).await.ok();
blurhash = services.users.blurhash(&body.user_id)?; blurhash = services.users.blurhash(&body.user_id).await.ok();
}, },
Some(custom_field) => { Some(custom_field) => {
if let Some(value) = services if let Ok(value) = services
.users .users
.profile_key(&body.user_id, custom_field.as_str())? .profile_key(&body.user_id, custom_field.as_str())
.await
{ {
custom_profile_fields.insert(custom_field.to_string(), value); custom_profile_fields.insert(custom_field.to_string(), value);
} }
}, },
None => { None => {
displayname = services.users.displayname(&body.user_id)?; displayname = services.users.displayname(&body.user_id).await.ok();
avatar_url = services.users.avatar_url(&body.user_id)?; avatar_url = services.users.avatar_url(&body.user_id).await.ok();
blurhash = services.users.blurhash(&body.user_id)?; blurhash = services.users.blurhash(&body.user_id).await.ok();
tz = services.users.timezone(&body.user_id)?; tz = services.users.timezone(&body.user_id).await.ok();
custom_profile_fields = services custom_profile_fields = services
.users .users
.all_profile_keys(&body.user_id) .all_profile_keys(&body.user_id)
.filter_map(Result::ok) .collect()
.collect(); .await;
}, },
} }

View File

@ -2,7 +2,8 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant};
use axum::extract::State; use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{debug, debug_warn, err, trace, warn, Err}; use conduit::{debug, debug_warn, err, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result};
use futures::StreamExt;
use ruma::{ use ruma::{
api::{ api::{
client::error::ErrorKind, client::error::ErrorKind,
@ -23,10 +24,13 @@ use tokio::sync::RwLock;
use crate::{ use crate::{
services::Services, services::Services,
utils::{self}, utils::{self},
Error, Result, Ruma, Ruma,
}; };
type ResolvedMap = BTreeMap<OwnedEventId, Result<(), Error>>; const PDU_LIMIT: usize = 50;
const EDU_LIMIT: usize = 100;
type ResolvedMap = BTreeMap<OwnedEventId, Result<()>>;
/// # `PUT /_matrix/federation/v1/send/{txnId}` /// # `PUT /_matrix/federation/v1/send/{txnId}`
/// ///
@ -44,12 +48,16 @@ pub(crate) async fn send_transaction_message_route(
))); )));
} }
if body.pdus.len() > 50_usize { if body.pdus.len() > PDU_LIMIT {
return Err!(Request(Forbidden("Not allowed to send more than 50 PDUs in one transaction"))); return Err!(Request(Forbidden(
"Not allowed to send more than {PDU_LIMIT} PDUs in one transaction"
)));
} }
if body.edus.len() > 100_usize { if body.edus.len() > EDU_LIMIT {
return Err!(Request(Forbidden("Not allowed to send more than 100 EDUs in one transaction"))); return Err!(Request(Forbidden(
"Not allowed to send more than {EDU_LIMIT} EDUs in one transaction"
)));
} }
let txn_start_time = Instant::now(); let txn_start_time = Instant::now();
@ -62,8 +70,8 @@ pub(crate) async fn send_transaction_message_route(
"Starting txn", "Starting txn",
); );
let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await?; let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await;
handle_edus(&services, &client, &body, origin).await?; handle_edus(&services, &client, &body, origin).await;
debug!( debug!(
pdus = ?body.pdus.len(), pdus = ?body.pdus.len(),
@ -85,10 +93,10 @@ pub(crate) async fn send_transaction_message_route(
async fn handle_pdus( async fn handle_pdus(
services: &Services, _client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName, services: &Services, _client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName,
txn_start_time: &Instant, txn_start_time: &Instant,
) -> Result<ResolvedMap> { ) -> ResolvedMap {
let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); let mut parsed_pdus = Vec::with_capacity(body.pdus.len());
for pdu in &body.pdus { for pdu in &body.pdus {
parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) { parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await {
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => {
debug_warn!("Could not parse PDU: {e}"); debug_warn!("Could not parse PDU: {e}");
@ -151,38 +159,34 @@ async fn handle_pdus(
} }
} }
Ok(resolved_map) resolved_map
} }
async fn handle_edus( async fn handle_edus(
services: &Services, client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName, services: &Services, client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName,
) -> Result<()> { ) {
for edu in body for edu in body
.edus .edus
.iter() .iter()
.filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok()) .filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok())
{ {
match edu { match edu {
Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await?, Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await,
Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await?, Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await,
Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await?, Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await,
Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await?, Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await,
Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await?, Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await,
Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await?, Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await,
Edu::_Custom(ref _custom) => { Edu::_Custom(ref _custom) => {
debug_warn!(?body.edus, "received custom/unknown EDU"); debug_warn!(?body.edus, "received custom/unknown EDU");
}, },
} }
} }
Ok(())
} }
async fn handle_edu_presence( async fn handle_edu_presence(services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent) {
services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent,
) -> Result<()> {
if !services.globals.allow_incoming_presence() { if !services.globals.allow_incoming_presence() {
return Ok(()); return;
} }
for update in presence.push { for update in presence.push {
@ -194,23 +198,24 @@ async fn handle_edu_presence(
continue; continue;
} }
services.presence.set_presence( services
.presence
.set_presence(
&update.user_id, &update.user_id,
&update.presence, &update.presence,
Some(update.currently_active), Some(update.currently_active),
Some(update.last_active_ago), Some(update.last_active_ago),
update.status_msg.clone(), update.status_msg.clone(),
)?; )
.await
.log_err()
.ok();
} }
Ok(())
} }
async fn handle_edu_receipt( async fn handle_edu_receipt(services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) {
services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent,
) -> Result<()> {
if !services.globals.allow_incoming_read_receipts() { if !services.globals.allow_incoming_read_receipts() {
return Ok(()); return;
} }
for (room_id, room_updates) in receipt.receipts { for (room_id, room_updates) in receipt.receipts {
@ -218,6 +223,7 @@ async fn handle_edu_receipt(
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &room_id) .acl_check(origin, &room_id)
.await
.is_err() .is_err()
{ {
debug_warn!( debug_warn!(
@ -240,8 +246,8 @@ async fn handle_edu_receipt(
.rooms .rooms
.state_cache .state_cache
.room_members(&room_id) .room_members(&room_id)
.filter_map(Result::ok) .ready_any(|member| member.server_name() == user_id.server_name())
.any(|member| member.server_name() == user_id.server_name()) .await
{ {
for event_id in &user_updates.event_ids { for event_id in &user_updates.event_ids {
let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]); let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]);
@ -255,7 +261,8 @@ async fn handle_edu_receipt(
services services
.rooms .rooms
.read_receipt .read_receipt
.readreceipt_update(&user_id, &room_id, &event)?; .readreceipt_update(&user_id, &room_id, &event)
.await;
} }
} else { } else {
debug_warn!( debug_warn!(
@ -266,15 +273,11 @@ async fn handle_edu_receipt(
} }
} }
} }
Ok(())
} }
async fn handle_edu_typing( async fn handle_edu_typing(services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent) {
services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent,
) -> Result<()> {
if !services.globals.config.allow_incoming_typing { if !services.globals.config.allow_incoming_typing {
return Ok(()); return;
} }
if typing.user_id.server_name() != origin { if typing.user_id.server_name() != origin {
@ -282,26 +285,28 @@ async fn handle_edu_typing(
%typing.user_id, %origin, %typing.user_id, %origin,
"received typing EDU for user not belonging to origin" "received typing EDU for user not belonging to origin"
); );
return Ok(()); return;
} }
if services if services
.rooms .rooms
.event_handler .event_handler
.acl_check(typing.user_id.server_name(), &typing.room_id) .acl_check(typing.user_id.server_name(), &typing.room_id)
.await
.is_err() .is_err()
{ {
debug_warn!( debug_warn!(
%typing.user_id, %typing.room_id, %origin, %typing.user_id, %typing.room_id, %origin,
"received typing EDU for ACL'd user's server" "received typing EDU for ACL'd user's server"
); );
return Ok(()); return;
} }
if services if services
.rooms .rooms
.state_cache .state_cache
.is_joined(&typing.user_id, &typing.room_id)? .is_joined(&typing.user_id, &typing.room_id)
.await
{ {
if typing.typing { if typing.typing {
let timeout = utils::millis_since_unix_epoch().saturating_add( let timeout = utils::millis_since_unix_epoch().saturating_add(
@ -315,28 +320,29 @@ async fn handle_edu_typing(
.rooms .rooms
.typing .typing
.typing_add(&typing.user_id, &typing.room_id, timeout) .typing_add(&typing.user_id, &typing.room_id, timeout)
.await?; .await
.log_err()
.ok();
} else { } else {
services services
.rooms .rooms
.typing .typing
.typing_remove(&typing.user_id, &typing.room_id) .typing_remove(&typing.user_id, &typing.room_id)
.await?; .await
.log_err()
.ok();
} }
} else { } else {
debug_warn!( debug_warn!(
%typing.user_id, %typing.room_id, %origin, %typing.user_id, %typing.room_id, %origin,
"received typing EDU for user not in room" "received typing EDU for user not in room"
); );
return Ok(());
} }
Ok(())
} }
async fn handle_edu_device_list_update( async fn handle_edu_device_list_update(
services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent, services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent,
) -> Result<()> { ) {
let DeviceListUpdateContent { let DeviceListUpdateContent {
user_id, user_id,
.. ..
@ -347,17 +353,15 @@ async fn handle_edu_device_list_update(
%user_id, %origin, %user_id, %origin,
"received device list update EDU for user not belonging to origin" "received device list update EDU for user not belonging to origin"
); );
return Ok(()); return;
} }
services.users.mark_device_key_update(&user_id)?; services.users.mark_device_key_update(&user_id).await;
Ok(())
} }
async fn handle_edu_direct_to_device( async fn handle_edu_direct_to_device(
services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent, services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent,
) -> Result<()> { ) {
let DirectDeviceContent { let DirectDeviceContent {
sender, sender,
ev_type, ev_type,
@ -370,45 +374,52 @@ async fn handle_edu_direct_to_device(
%sender, %origin, %sender, %origin,
"received direct to device EDU for user not belonging to origin" "received direct to device EDU for user not belonging to origin"
); );
return Ok(()); return;
} }
// Check if this is a new transaction id // Check if this is a new transaction id
if services if services
.transaction_ids .transaction_ids
.existing_txnid(&sender, None, &message_id)? .existing_txnid(&sender, None, &message_id)
.is_some() .await
.is_ok()
{ {
return Ok(()); return;
} }
for (target_user_id, map) in &messages { for (target_user_id, map) in &messages {
for (target_device_id_maybe, event) in map { for (target_device_id_maybe, event) in map {
let Ok(event) = event
.deserialize_as()
.map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))
else {
continue;
};
let ev_type = ev_type.to_string();
match target_device_id_maybe { match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => { DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services.users.add_to_device_event( services
&sender, .users
target_user_id, .add_to_device_event(&sender, target_user_id, target_device_id, &ev_type, event)
target_device_id, .await;
&ev_type.to_string(),
event
.deserialize_as()
.map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))?,
)?;
}, },
DeviceIdOrAllDevices::AllDevices => { DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services.users.all_device_ids(target_user_id) { let (sender, ev_type, event) = (&sender, &ev_type, &event);
services
.users
.all_device_ids(target_user_id)
.for_each(|target_device_id| {
services.users.add_to_device_event( services.users.add_to_device_event(
&sender, sender,
target_user_id, target_user_id,
&target_device_id?, target_device_id,
&ev_type.to_string(), ev_type,
event event.clone(),
.deserialize_as() )
.map_err(|e| err!(Request(InvalidParam("Event is invalid: {e}"))))?, })
)?; .await;
}
}, },
} }
} }
@ -417,14 +428,12 @@ async fn handle_edu_direct_to_device(
// Save transaction id with empty data // Save transaction id with empty data
services services
.transaction_ids .transaction_ids
.add_txnid(&sender, None, &message_id, &[])?; .add_txnid(&sender, None, &message_id, &[]);
Ok(())
} }
async fn handle_edu_signing_key_update( async fn handle_edu_signing_key_update(
services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent, services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent,
) -> Result<()> { ) {
let SigningKeyUpdateContent { let SigningKeyUpdateContent {
user_id, user_id,
master_key, master_key,
@ -436,14 +445,15 @@ async fn handle_edu_signing_key_update(
%user_id, %origin, %user_id, %origin,
"received signing key update EDU from server that does not belong to user's server" "received signing key update EDU from server that does not belong to user's server"
); );
return Ok(()); return;
} }
if let Some(master_key) = master_key { if let Some(master_key) = master_key {
services services
.users .users
.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)
.await
.log_err()
.ok();
} }
Ok(())
} }

View File

@ -3,7 +3,8 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State; use axum::extract::State;
use conduit::{pdu::gen_event_id_canonical_json, warn, Error, Result}; use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result};
use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::membership::create_join_event}, api::{client::error::ErrorKind, federation::membership::create_join_event},
events::{ events::{
@ -22,27 +23,32 @@ use crate::Ruma;
async fn create_join_event( async fn create_join_event(
services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue,
) -> Result<create_join_event::v1::RoomState> { ) -> Result<create_join_event::v1::RoomState> {
if !services.rooms.metadata.exists(room_id)? { if !services.rooms.metadata.exists(room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
// ACL check origin server // ACL check origin server
services.rooms.event_handler.acl_check(origin, room_id)?; services
.rooms
.event_handler
.acl_check(origin, room_id)
.await?;
// We need to return the state prior to joining, let's keep a reference to that // We need to return the state prior to joining, let's keep a reference to that
// here // here
let shortstatehash = services let shortstatehash = services
.rooms .rooms
.state .state
.get_room_shortstatehash(room_id)? .get_room_shortstatehash(room_id)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event state not found."))?; .await
.map_err(|_| err!(Request(NotFound("Event state not found."))))?;
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
// let mut auth_cache = EventMap::new(); // let mut auth_cache = EventMap::new();
// We do not add the event_id field to the pdu here because of signature and // We do not add the event_id field to the pdu here because of signature and
// hashes checks // hashes checks
let room_version_id = services.rooms.state.get_room_version(room_id)?; let room_version_id = services.rooms.state.get_room_version(room_id).await?;
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json // Event could not be converted to canonical json
@ -97,7 +103,8 @@ async fn create_join_event(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(sender.server_name(), room_id)?; .acl_check(sender.server_name(), room_id)
.await?;
// check if origin server is trying to send for another server // check if origin server is trying to send for another server
if sender.server_name() != origin { if sender.server_name() != origin {
@ -126,7 +133,9 @@ async fn create_join_event(
if content if content
.join_authorized_via_users_server .join_authorized_via_users_server
.is_some_and(|user| services.globals.user_is_local(&user)) .is_some_and(|user| services.globals.user_is_local(&user))
&& super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default() && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id)
.await
.unwrap_or_default()
{ {
ruma::signatures::hash_and_sign_event( ruma::signatures::hash_and_sign_event(
services.globals.server_name().as_str(), services.globals.server_name().as_str(),
@ -158,12 +167,14 @@ async fn create_join_event(
.mutex_federation .mutex_federation
.lock(room_id) .lock(room_id)
.await; .await;
let pdu_id: Vec<u8> = services let pdu_id: Vec<u8> = services
.rooms .rooms
.event_handler .event_handler
.handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map)
.await? .await?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?;
drop(mutex_lock); drop(mutex_lock);
let state_ids = services let state_ids = services
@ -171,29 +182,43 @@ async fn create_join_event(
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await?; .await?;
let auth_chain_ids = services
let state = state_ids
.iter()
.try_stream()
.and_then(|(_, event_id)| services.rooms.timeline.get_pdu_json(event_id))
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.try_collect()
.await?;
let auth_chain = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(room_id, state_ids.values().cloned().collect()) .event_ids_iter(room_id, state_ids.values().cloned().collect())
.await?
.map(Ok)
.and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await })
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.try_collect()
.await?; .await?;
services.sending.send_pdu_room(room_id, &pdu_id)?; services.sending.send_pdu_room(room_id, &pdu_id).await?;
Ok(create_join_event::v1::RoomState { Ok(create_join_event::v1::RoomState {
auth_chain: auth_chain_ids auth_chain,
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) state,
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
state: state_ids
.iter()
.filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten())
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
// Event field is required if the room version supports restricted join rules. // Event field is required if the room version supports restricted join rules.
event: Some( event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(),
to_raw_value(&CanonicalJsonValue::Object(value))
.expect("To raw json should not fail since only change was adding signature"),
),
}) })
} }

View File

@ -3,7 +3,7 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{utils::ReadyExt, Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::membership::create_leave_event}, api::{client::error::ErrorKind, federation::membership::create_leave_event},
events::{ events::{
@ -49,18 +49,22 @@ pub(crate) async fn create_leave_event_v2_route(
async fn create_leave_event( async fn create_leave_event(
services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue,
) -> Result<()> { ) -> Result<()> {
if !services.rooms.metadata.exists(room_id)? { if !services.rooms.metadata.exists(room_id).await {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
// ACL check origin // ACL check origin
services.rooms.event_handler.acl_check(origin, room_id)?; services
.rooms
.event_handler
.acl_check(origin, room_id)
.await?;
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
// We do not add the event_id field to the pdu here because of signature and // We do not add the event_id field to the pdu here because of signature and
// hashes checks // hashes checks
let room_version_id = services.rooms.state.get_room_version(room_id)?; let room_version_id = services.rooms.state.get_room_version(room_id).await?;
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json // Event could not be converted to canonical json
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -114,7 +118,8 @@ async fn create_leave_event(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(sender.server_name(), room_id)?; .acl_check(sender.server_name(), room_id)
.await?;
if sender.server_name() != origin { if sender.server_name() != origin {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -173,10 +178,9 @@ async fn create_leave_event(
.rooms .rooms
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
.filter_map(Result::ok) .ready_filter(|server| !services.globals.server_is_ours(server));
.filter(|server| !services.globals.server_is_ours(server));
services.sending.send_pdu_servers(servers, &pdu_id)?; services.sending.send_pdu_servers(servers, &pdu_id).await?;
Ok(()) Ok(())
} }

View File

@ -1,8 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{err, result::LogErr, utils::IterStream, Err, Result};
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; use futures::{FutureExt, StreamExt, TryStreamExt};
use ruma::api::federation::event::get_room_state;
use crate::Ruma; use crate::Ruma;
@ -17,56 +18,66 @@ pub(crate) async fn get_room_state_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)
&& !services .await && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)
.await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err!(Request(Forbidden("Server is not in room.")));
} }
let shortstatehash = services let shortstatehash = services
.rooms .rooms
.state_accessor .state_accessor
.pdu_shortstatehash(&body.event_id)? .pdu_shortstatehash(&body.event_id)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; .await
.map_err(|_| err!(Request(NotFound("PDU state not found."))))?;
let pdus = services let pdus = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await? .await
.into_values() .log_err()
.map(|id| { .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))?
.values()
.try_stream()
.and_then(|id| services.rooms.timeline.get_pdu_json(id))
.and_then(|pdu| {
services services
.sending .sending
.convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()) .convert_to_outgoing_federation_event(pdu)
.map(Ok)
}) })
.collect(); .try_collect()
.await?;
let auth_chain_ids = services let auth_chain = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
.await?
.map(Ok)
.and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await })
.and_then(|pdu| {
services
.sending
.convert_to_outgoing_federation_event(pdu)
.map(Ok)
})
.try_collect()
.await?; .await?;
Ok(get_room_state::v1::Response { Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids auth_chain,
.filter_map(|id| {
services
.rooms
.timeline
.get_pdu_json(&id)
.ok()?
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
})
.collect(),
pdus, pdus,
}) })
} }

View File

@ -1,9 +1,11 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State; use axum::extract::State;
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids}; use conduit::{err, Err};
use futures::StreamExt;
use ruma::api::federation::event::get_room_state_ids;
use crate::{Error, Result, Ruma}; use crate::{Result, Ruma};
/// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// # `GET /_matrix/federation/v1/state_ids/{roomId}`
/// ///
@ -17,31 +19,35 @@ pub(crate) async fn get_room_state_ids_route(
services services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)
.await?;
if !services if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)
&& !services .await && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)
.await
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err!(Request(Forbidden("Server is not in room.")));
} }
let shortstatehash = services let shortstatehash = services
.rooms .rooms
.state_accessor .state_accessor
.pdu_shortstatehash(&body.event_id)? .pdu_shortstatehash(&body.event_id)
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; .await
.map_err(|_| err!(Request(NotFound("Pdu state not found."))))?;
let pdu_ids = services let pdu_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await? .await
.map_err(|_| err!(Request(NotFound("State ids not found"))))?
.into_values() .into_values()
.map(|id| (*id).to_owned()) .map(|id| (*id).to_owned())
.collect(); .collect();
@ -50,10 +56,13 @@ pub(crate) async fn get_room_state_ids_route(
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
.await?; .await?
.map(|id| (*id).to_owned())
.collect()
.await;
Ok(get_room_state_ids::v1::Response { Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), auth_chain_ids,
pdu_ids, pdu_ids,
}) })
} }

View File

@ -1,5 +1,6 @@
use axum::extract::State; use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use futures::{FutureExt, StreamExt, TryFutureExt};
use ruma::api::{ use ruma::api::{
client::error::ErrorKind, client::error::ErrorKind,
federation::{ federation::{
@ -28,41 +29,51 @@ pub(crate) async fn get_devices_route(
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
let user_id = &body.user_id;
Ok(get_devices::v1::Response { Ok(get_devices::v1::Response {
user_id: body.user_id.clone(), user_id: user_id.clone(),
stream_id: services stream_id: services
.users .users
.get_devicelist_version(&body.user_id)? .get_devicelist_version(user_id)
.await
.unwrap_or(0) .unwrap_or(0)
.try_into() .try_into()?,
.expect("version will not grow that large"),
devices: services devices: services
.users .users
.all_devices_metadata(&body.user_id) .all_devices_metadata(user_id)
.filter_map(Result::ok) .filter_map(|metadata| async move {
.filter_map(|metadata| { let device_id = metadata.device_id.clone();
let device_id_string = metadata.device_id.as_str().to_owned(); let device_id_clone = device_id.clone();
let device_id_string = device_id.as_str().to_owned();
let device_display_name = if services.globals.allow_device_name_federation() { let device_display_name = if services.globals.allow_device_name_federation() {
metadata.display_name metadata.display_name.clone()
} else { } else {
Some(device_id_string) Some(device_id_string)
}; };
Some(UserDevice {
keys: services services
.users .users
.get_device_keys(&body.user_id, &metadata.device_id) .get_device_keys(user_id, &device_id_clone)
.ok()??, .map_ok(|keys| UserDevice {
device_id: metadata.device_id, device_id,
keys,
device_display_name, device_display_name,
}) })
.map(Result::ok)
.await
}) })
.collect(), .collect()
.await,
master_key: services master_key: services
.users .users
.get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?, .get_master_key(None, &body.user_id, &|u| u.server_name() == origin)
.await
.ok(),
self_signing_key: services self_signing_key: services
.users .users
.get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?, .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)
.await
.ok(),
}) })
} }

View File

@ -67,6 +67,7 @@ ctor.workspace = true
cyborgtime.workspace = true cyborgtime.workspace = true
either.workspace = true either.workspace = true
figment.workspace = true figment.workspace = true
futures.workspace = true
http-body-util.workspace = true http-body-util.workspace = true
http.workspace = true http.workspace = true
image.workspace = true image.workspace = true

View File

@ -86,7 +86,7 @@ pub enum Error {
#[error("There was a problem with the '{0}' directive in your configuration: {1}")] #[error("There was a problem with the '{0}' directive in your configuration: {1}")]
Config(&'static str, Cow<'static, str>), Config(&'static str, Cow<'static, str>),
#[error("{0}")] #[error("{0}")]
Conflict(&'static str), // This is only needed for when a room alias already exists Conflict(Cow<'static, str>), // This is only needed for when a room alias already exists
#[error(transparent)] #[error(transparent)]
ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError), ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError),
#[error("{0}")] #[error("{0}")]
@ -107,6 +107,8 @@ pub enum Error {
Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode), Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode),
#[error(transparent)] #[error(transparent)]
Ruma(#[from] ruma::api::client::error::Error), Ruma(#[from] ruma::api::client::error::Error),
#[error(transparent)]
StateRes(#[from] ruma::state_res::Error),
#[error("uiaa")] #[error("uiaa")]
Uiaa(ruma::api::client::uiaa::UiaaInfo), Uiaa(ruma::api::client::uiaa::UiaaInfo),

View File

@ -3,8 +3,6 @@ mod count;
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
pub use builder::PduBuilder;
pub use count::PduCount;
use ruma::{ use ruma::{
canonical_json::redact_content_in_place, canonical_json::redact_content_in_place,
events::{ events::{
@ -23,7 +21,8 @@ use serde_json::{
value::{to_raw_value, RawValue as RawJsonValue}, value::{to_raw_value, RawValue as RawJsonValue},
}; };
use crate::{err, warn, Error}; pub use self::{builder::PduBuilder, count::PduCount};
use crate::{err, warn, Error, Result};
#[derive(Deserialize)] #[derive(Deserialize)]
struct ExtractRedactedBecause { struct ExtractRedactedBecause {
@ -65,11 +64,12 @@ pub struct PduEvent {
impl PduEvent { impl PduEvent {
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result<()> {
self.unsigned = None; self.unsigned = None;
let mut content = serde_json::from_str(self.content.get()) let mut content = serde_json::from_str(self.content.get())
.map_err(|_| Error::bad_database("PDU in db has invalid content."))?; .map_err(|_| Error::bad_database("PDU in db has invalid content."))?;
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
.map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?;
@ -98,31 +98,38 @@ impl PduEvent {
unsigned.redacted_because.is_some() unsigned.redacted_because.is_some()
} }
pub fn remove_transaction_id(&mut self) -> crate::Result<()> { pub fn remove_transaction_id(&mut self) -> Result<()> {
if let Some(unsigned) = &self.unsigned { let Some(unsigned) = &self.unsigned else {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get()) return Ok(());
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; };
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> =
serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?;
unsigned.remove("transaction_id"); unsigned.remove("transaction_id");
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); self.unsigned = to_raw_value(&unsigned)
} .map(Some)
.expect("unsigned is valid");
Ok(()) Ok(())
} }
pub fn add_age(&mut self) -> crate::Result<()> { pub fn add_age(&mut self) -> Result<()> {
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self
.unsigned .unsigned
.as_ref() .as_ref()
.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get()))
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?;
// deliberately allowing for the possibility of negative age // deliberately allowing for the possibility of negative age
let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into();
let then: i128 = self.origin_server_ts.into(); let then: i128 = self.origin_server_ts.into();
let this_age = now.saturating_sub(then); let this_age = now.saturating_sub(then);
unsigned.insert("age".to_owned(), to_raw_value(&this_age).unwrap()); unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid"));
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); self.unsigned = to_raw_value(&unsigned)
.map(Some)
.expect("unsigned is valid");
Ok(()) Ok(())
} }
@ -369,9 +376,9 @@ impl state_res::Event for PduEvent {
fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } fn state_key(&self) -> Option<&str> { self.state_key.as_deref() }
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.prev_events.iter()) } fn prev_events(&self) -> impl DoubleEndedIterator<Item = &Self::Id> + Send + '_ { self.prev_events.iter() }
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.auth_events.iter()) } fn auth_events(&self) -> impl DoubleEndedIterator<Item = &Self::Id> + Send + '_ { self.auth_events.iter() }
fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() }
} }
@ -395,7 +402,7 @@ impl Ord for PduEvent {
/// CanonicalJsonValue>`. /// CanonicalJsonValue>`.
pub fn gen_event_id_canonical_json( pub fn gen_event_id_canonical_json(
pdu: &RawJsonValue, room_version_id: &RoomVersionId, pdu: &RawJsonValue, room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { ) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) let value: CanonicalJsonObject = serde_json::from_str(pdu.get())
.map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?;

View File

@ -1,18 +1,14 @@
use std::fmt; use std::fmt::Debug;
use tracing::Level; use tracing::Level;
use super::{DebugInspect, Result}; use super::{DebugInspect, Result};
use crate::error; use crate::error;
pub trait LogDebugErr<T, E> pub trait LogDebugErr<T, E: Debug> {
where
E: fmt::Debug,
{
#[must_use] #[must_use]
fn err_debug_log(self, level: Level) -> Self; fn err_debug_log(self, level: Level) -> Self;
#[inline]
#[must_use] #[must_use]
fn log_debug_err(self) -> Self fn log_debug_err(self) -> Self
where where
@ -22,15 +18,9 @@ where
} }
} }
impl<T, E> LogDebugErr<T, E> for Result<T, E> impl<T, E: Debug> LogDebugErr<T, E> for Result<T, E> {
where
E: fmt::Debug,
{
#[inline] #[inline]
fn err_debug_log(self, level: Level) -> Self fn err_debug_log(self, level: Level) -> Self {
where
Self: Sized,
{
self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level)) self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level))
} }
} }

View File

@ -1,18 +1,14 @@
use std::fmt; use std::fmt::Display;
use tracing::Level; use tracing::Level;
use super::Result; use super::Result;
use crate::error; use crate::error;
pub trait LogErr<T, E> pub trait LogErr<T, E: Display> {
where
E: fmt::Display,
{
#[must_use] #[must_use]
fn err_log(self, level: Level) -> Self; fn err_log(self, level: Level) -> Self;
#[inline]
#[must_use] #[must_use]
fn log_err(self) -> Self fn log_err(self) -> Self
where where
@ -22,15 +18,7 @@ where
} }
} }
impl<T, E> LogErr<T, E> for Result<T, E> impl<T, E: Display> LogErr<T, E> for Result<T, E> {
where
E: fmt::Display,
{
#[inline] #[inline]
fn err_log(self, level: Level) -> Self fn err_log(self, level: Level) -> Self { self.inspect_err(|error| error::inspect_log_level(&error, level)) }
where
Self: Sized,
{
self.inspect_err(|error| error::inspect_log_level(&error, level))
}
} }

View File

@ -1,25 +0,0 @@
use std::cmp::Ordering;
#[allow(clippy::impl_trait_in_params)]
pub fn common_elements(
mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering,
) -> Option<impl Iterator<Item = Vec<u8>>> {
let first_iterator = iterators.next()?;
let mut other_iterators = iterators.map(Iterator::peekable).collect::<Vec<_>>();
Some(first_iterator.filter(move |target| {
other_iterators.iter_mut().all(|it| {
while let Some(element) = it.peek() {
match check_order(element, target) {
Ordering::Greater => return false, // We went too far
Ordering::Equal => return true, // Element is in both iters
Ordering::Less => {
// Keep searching
it.next();
},
}
}
false
})
}))
}

View File

@ -1,4 +1,3 @@
pub mod algorithm;
pub mod bytes; pub mod bytes;
pub mod content_disposition; pub mod content_disposition;
pub mod debug; pub mod debug;
@ -9,25 +8,30 @@ pub mod json;
pub mod math; pub mod math;
pub mod mutex_map; pub mod mutex_map;
pub mod rand; pub mod rand;
pub mod set;
pub mod stream;
pub mod string; pub mod string;
pub mod sys; pub mod sys;
mod tests; mod tests;
pub mod time; pub mod time;
pub use ::conduit_macros::implement;
pub use ::ctor::{ctor, dtor}; pub use ::ctor::{ctor, dtor};
pub use algorithm::common_elements;
pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}; pub use self::{
pub use conduit_macros::implement; bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8},
pub use debug::slice_truncated as debug_slice_truncated; debug::slice_truncated as debug_slice_truncated,
pub use hash::calculate_hash; hash::calculate_hash,
pub use html::Escape as HtmlEscape; html::Escape as HtmlEscape,
pub use json::{deserialize_from_str, to_canonical_object}; json::{deserialize_from_str, to_canonical_object},
pub use math::clamp; math::clamp,
pub use mutex_map::{Guard as MutexMapGuard, MutexMap}; mutex_map::{Guard as MutexMapGuard, MutexMap},
pub use rand::string as random_string; rand::string as random_string,
pub use string::{str_from_bytes, string_from_bytes}; stream::{IterStream, ReadyExt, TryReadyExt},
pub use sys::available_parallelism; string::{str_from_bytes, string_from_bytes},
pub use time::now_millis as millis_since_unix_epoch; sys::available_parallelism,
time::now_millis as millis_since_unix_epoch,
};
#[inline] #[inline]
pub fn exchange<T>(state: &mut T, source: T) -> T { std::mem::replace(state, source) } pub fn exchange<T>(state: &mut T, source: T) -> T { std::mem::replace(state, source) }

47
src/core/utils/set.rs Normal file
View File

@ -0,0 +1,47 @@
use std::cmp::{Eq, Ord};
use crate::{is_equal_to, is_less_than};
/// Intersection of sets
///
/// Outputs the set of elements common to all input sets. Inputs do not have to
/// be sorted. If inputs are sorted a more optimized function is available in
/// this suite and should be used.
pub fn intersection<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send
where
Iters: Iterator<Item = Iter> + Clone + Send,
Iter: Iterator<Item = Item> + Send,
Item: Eq + Send,
{
input.next().into_iter().flat_map(move |first| {
let input = input.clone();
first.filter(move |targ| {
input
.clone()
.all(|mut other| other.any(is_equal_to!(*targ)))
})
})
}
/// Intersection of sets
///
/// Outputs the set of elements common to all input sets. Inputs must be sorted.
pub fn intersection_sorted<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send
where
Iters: Iterator<Item = Iter> + Clone + Send,
Iter: Iterator<Item = Item> + Send,
Item: Eq + Ord + Send,
{
input.next().into_iter().flat_map(move |first| {
let mut input = input.clone().collect::<Vec<_>>();
first.filter(move |targ| {
input.iter_mut().all(|it| {
it.by_ref()
.skip_while(is_less_than!(targ))
.peekable()
.peek()
.is_some_and(is_equal_to!(targ))
})
})
})
}

View File

@ -0,0 +1,20 @@
use std::clone::Clone;
use futures::{stream::Map, Stream, StreamExt};
pub trait Cloned<'a, T, S>
where
S: Stream<Item = &'a T>,
T: Clone + 'a,
{
fn cloned(self) -> Map<S, fn(&T) -> T>;
}
impl<'a, T, S> Cloned<'a, T, S> for S
where
S: Stream<Item = &'a T>,
T: Clone + 'a,
{
#[inline]
fn cloned(self) -> Map<S, fn(&T) -> T> { self.map(Clone::clone) }
}

View File

@ -0,0 +1,17 @@
use futures::{Stream, StreamExt, TryStream};
use crate::Result;
pub trait TryExpect<'a, Item> {
fn expect_ok(self) -> impl Stream<Item = Item> + Send + 'a;
}
impl<'a, T, Item> TryExpect<'a, Item> for T
where
T: Stream<Item = Result<Item>> + TryStream + Send + 'a,
{
#[inline]
fn expect_ok(self: T) -> impl Stream<Item = Item> + Send + 'a {
self.map(|res| res.expect("stream expectation failure"))
}
}

View File

@ -0,0 +1,21 @@
use futures::{future::ready, Stream, StreamExt, TryStream};
use crate::{Error, Result};
pub trait TryIgnore<'a, Item> {
fn ignore_err(self) -> impl Stream<Item = Item> + Send + 'a;
fn ignore_ok(self) -> impl Stream<Item = Error> + Send + 'a;
}
impl<'a, T, Item> TryIgnore<'a, Item> for T
where
T: Stream<Item = Result<Item>> + TryStream + Send + 'a,
Item: Send + 'a,
{
#[inline]
fn ignore_err(self: T) -> impl Stream<Item = Item> + Send + 'a { self.filter_map(|res| ready(res.ok())) }
#[inline]
fn ignore_ok(self: T) -> impl Stream<Item = Error> + Send + 'a { self.filter_map(|res| ready(res.err())) }
}

View File

@ -0,0 +1,27 @@
use futures::{
stream,
stream::{Stream, TryStream},
StreamExt,
};
pub trait IterStream<I: IntoIterator + Send> {
/// Convert an Iterator into a Stream
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send;
/// Convert an Iterator into a TryStream
fn try_stream(self) -> impl TryStream<Ok = <I as IntoIterator>::Item, Error = crate::Error> + Send;
}
impl<I> IterStream<I> for I
where
I: IntoIterator + Send,
<I as IntoIterator>::IntoIter: Send,
{
#[inline]
fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send { stream::iter(self) }
#[inline]
fn try_stream(self) -> impl TryStream<Ok = <I as IntoIterator>::Item, Error = crate::Error> + Send {
self.stream().map(Ok)
}
}

View File

@ -0,0 +1,13 @@
mod cloned;
mod expect;
mod ignore;
mod iter_stream;
mod ready;
mod try_ready;
pub use cloned::Cloned;
pub use expect::TryExpect;
pub use ignore::TryIgnore;
pub use iter_stream::IterStream;
pub use ready::ReadyExt;
pub use try_ready::TryReadyExt;

View File

@ -0,0 +1,109 @@
//! Synchronous combinator extensions to futures::Stream
use futures::{
future::{ready, Ready},
stream::{Any, Filter, FilterMap, Fold, ForEach, SkipWhile, Stream, StreamExt, TakeWhile},
};
/// Synchronous combinators to augment futures::StreamExt. Most Stream
/// combinators take asynchronous arguments, but often only simple predicates
/// are required to steer a Stream like an Iterator. This suite provides a
/// convenience to reduce boilerplate by de-cluttering non-async predicates.
///
/// This interface is not necessarily complete; feel free to add as-needed.
pub trait ReadyExt<Item, S>
where
S: Stream<Item = Item> + Send + ?Sized,
Self: Stream + Send + Sized,
{
fn ready_any<F>(self, f: F) -> Any<Self, Ready<bool>, impl FnMut(S::Item) -> Ready<bool>>
where
F: Fn(S::Item) -> bool;
fn ready_filter<'a, F>(self, f: F) -> Filter<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a;
fn ready_filter_map<F, U>(self, f: F) -> FilterMap<Self, Ready<Option<U>>, impl FnMut(S::Item) -> Ready<Option<U>>>
where
F: Fn(S::Item) -> Option<U>;
fn ready_fold<T, F>(self, init: T, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, S::Item) -> Ready<T>>
where
F: Fn(T, S::Item) -> T;
fn ready_for_each<F>(self, f: F) -> ForEach<Self, Ready<()>, impl FnMut(S::Item) -> Ready<()>>
where
F: FnMut(S::Item);
fn ready_take_while<'a, F>(self, f: F) -> TakeWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a;
fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a;
}
impl<Item, S> ReadyExt<Item, S> for S
where
S: Stream<Item = Item> + Send + ?Sized,
Self: Stream + Send + Sized,
{
#[inline]
fn ready_any<F>(self, f: F) -> Any<Self, Ready<bool>, impl FnMut(S::Item) -> Ready<bool>>
where
F: Fn(S::Item) -> bool,
{
self.any(move |t| ready(f(t)))
}
#[inline]
fn ready_filter<'a, F>(self, f: F) -> Filter<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a,
{
self.filter(move |t| ready(f(t)))
}
#[inline]
fn ready_filter_map<F, U>(self, f: F) -> FilterMap<Self, Ready<Option<U>>, impl FnMut(S::Item) -> Ready<Option<U>>>
where
F: Fn(S::Item) -> Option<U>,
{
self.filter_map(move |t| ready(f(t)))
}
#[inline]
fn ready_fold<T, F>(self, init: T, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, S::Item) -> Ready<T>>
where
F: Fn(T, S::Item) -> T,
{
self.fold(init, move |a, t| ready(f(a, t)))
}
#[inline]
#[allow(clippy::unit_arg)]
fn ready_for_each<F>(self, mut f: F) -> ForEach<Self, Ready<()>, impl FnMut(S::Item) -> Ready<()>>
where
F: FnMut(S::Item),
{
self.for_each(move |t| ready(f(t)))
}
#[inline]
fn ready_take_while<'a, F>(self, f: F) -> TakeWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a,
{
self.take_while(move |t| ready(f(t)))
}
#[inline]
fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile<Self, Ready<bool>, impl FnMut(&S::Item) -> Ready<bool> + 'a>
where
F: Fn(&S::Item) -> bool + 'a,
{
self.skip_while(move |t| ready(f(t)))
}
}

View File

@ -0,0 +1,35 @@
//! Synchronous combinator extensions to futures::TryStream
use futures::{
future::{ready, Ready},
stream::{AndThen, TryStream, TryStreamExt},
};
use crate::Result;
/// Synchronous combinators to augment futures::TryStreamExt.
///
/// This interface is not necessarily complete; feel free to add as-needed.
pub trait TryReadyExt<T, E, S>
where
S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
Self: TryStream + Send + Sized,
{
fn ready_and_then<U, F>(self, f: F) -> AndThen<Self, Ready<Result<U, E>>, impl FnMut(S::Ok) -> Ready<Result<U, E>>>
where
F: Fn(S::Ok) -> Result<U, E>;
}
impl<T, E, S> TryReadyExt<T, E, S> for S
where
S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
Self: TryStream + Send + Sized,
{
#[inline]
fn ready_and_then<U, F>(self, f: F) -> AndThen<Self, Ready<Result<U, E>>, impl FnMut(S::Ok) -> Ready<Result<U, E>>>
where
F: Fn(S::Ok) -> Result<U, E>,
{
self.and_then(move |t| ready(f(t)))
}
}

View File

@ -107,3 +107,133 @@ async fn mutex_map_contend() {
tokio::try_join!(join_b, join_a).expect("joined"); tokio::try_join!(join_b, join_a).expect("joined");
assert!(map.is_empty(), "Must be empty"); assert!(map.is_empty(), "Must be empty");
} }
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_none() {
use utils::set::intersection;
let a: [&str; 0] = [];
let b: [&str; 0] = [];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let a: [&str; 0] = [];
let b = ["abc", "def"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let i = [b.iter(), a.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let i = [a.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
let a = ["foo", "bar", "baz"];
let b = ["def", "hij", "klm", "nop"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert_eq!(r.count(), 0);
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_all() {
use utils::set::intersection;
let a = ["foo"];
let b = ["foo"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo"].iter()));
let a = ["foo", "bar"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo", "bar"].iter()));
let i = [b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["bar", "foo"].iter()));
let a = ["foo", "bar", "baz"];
let b = ["baz", "foo", "bar"];
let c = ["bar", "baz", "foo"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo", "bar", "baz"].iter()));
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_some() {
use utils::set::intersection;
let a = ["foo"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo"].iter()));
let i = [b.iter(), a.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo"].iter()));
let a = ["abcdef", "foo", "hijkl", "abc"];
let b = ["hij", "bar", "baz", "abc", "foo"];
let c = ["abc", "xyz", "foo", "ghi"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection(i.into_iter());
assert!(r.eq(["foo", "abc"].iter()));
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_sorted_some() {
use utils::set::intersection_sorted;
let a = ["bar"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar"].iter()));
let i = [b.iter(), a.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar"].iter()));
let a = ["aaa", "ccc", "eee", "ggg"];
let b = ["aaa", "bbb", "ccc", "ddd", "eee"];
let c = ["bbb", "ccc", "eee", "fff"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["ccc", "eee"].iter()));
}
#[test]
#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)]
fn set_intersection_sorted_all() {
use utils::set::intersection_sorted;
let a = ["foo"];
let b = ["foo"];
let i = [a.iter(), b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["foo"].iter()));
let a = ["bar", "foo"];
let b = ["bar", "foo"];
let i = [a.iter(), b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar", "foo"].iter()));
let i = [b.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar", "foo"].iter()));
let a = ["bar", "baz", "foo"];
let b = ["bar", "baz", "foo"];
let c = ["bar", "baz", "foo"];
let i = [a.iter(), b.iter(), c.iter()];
let r = intersection_sorted(i.into_iter());
assert!(r.eq(["bar", "baz", "foo"].iter()));
}

View File

@ -37,8 +37,11 @@ zstd_compression = [
[dependencies] [dependencies]
conduit-core.workspace = true conduit-core.workspace = true
const-str.workspace = true const-str.workspace = true
futures.workspace = true
log.workspace = true log.workspace = true
rust-rocksdb.workspace = true rust-rocksdb.workspace = true
serde.workspace = true
serde_json.workspace = true
tokio.workspace = true tokio.workspace = true
tracing.workspace = true tracing.workspace = true

View File

@ -37,7 +37,7 @@ impl Database {
pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) }
#[inline] #[inline]
pub fn iter_maps(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + '_ { self.map.iter() } pub fn iter_maps(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + Send + '_ { self.map.iter() }
} }
impl Index<&str> for Database { impl Index<&str> for Database {

261
src/database/de.rs Normal file
View File

@ -0,0 +1,261 @@
use conduit::{checked, debug::DebugInspect, err, utils::string, Error, Result};
use serde::{
de,
de::{DeserializeSeed, Visitor},
Deserialize,
};
pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result<T>
where
T: Deserialize<'a>,
{
let mut deserializer = Deserializer {
buf,
pos: 0,
};
T::deserialize(&mut deserializer).debug_inspect(|_| {
deserializer
.finished()
.expect("deserialization failed to consume trailing bytes");
})
}
pub(crate) struct Deserializer<'de> {
buf: &'de [u8],
pos: usize,
}
/// Directive to ignore a record. This type can be used to skip deserialization
/// until the next separator is found.
#[derive(Debug, Deserialize)]
pub struct Ignore;
impl<'de> Deserializer<'de> {
const SEP: u8 = b'\xFF';
fn finished(&self) -> Result<()> {
let pos = self.pos;
let len = self.buf.len();
let parsed = &self.buf[0..pos];
let unparsed = &self.buf[pos..];
let remain = checked!(len - pos)?;
let trailing_sep = remain == 1 && unparsed[0] == Self::SEP;
(remain == 0 || trailing_sep)
.then_some(())
.ok_or(err!(SerdeDe(
"{remain} trailing of {len} bytes not deserialized.\n{parsed:?}\n{unparsed:?}",
)))
}
#[inline]
fn record_next(&mut self) -> &'de [u8] {
self.buf[self.pos..]
.split(|b| *b == Deserializer::SEP)
.inspect(|record| self.inc_pos(record.len()))
.next()
.expect("remainder of buf even if SEP was not found")
}
#[inline]
fn record_trail(&mut self) -> &'de [u8] {
let record = &self.buf[self.pos..];
self.inc_pos(record.len());
record
}
#[inline]
fn record_start(&mut self) {
let started = self.pos != 0;
debug_assert!(
!started || self.buf[self.pos] == Self::SEP,
"Missing expected record separator at current position"
);
self.inc_pos(started.into());
}
#[inline]
fn inc_pos(&mut self, n: usize) {
self.pos = self.pos.saturating_add(n);
debug_assert!(self.pos <= self.buf.len(), "pos out of range");
}
}
impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
type Error = Error;
fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Map not implemented")
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_struct<V>(
self, _name: &'static str, _fields: &'static [&'static str], _visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Struct not implemented")
}
fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
match name {
"Ignore" => self.record_next(),
_ => unimplemented!("Unrecognized deserialization Directive {name:?}"),
};
visitor.visit_unit()
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Newtype Struct not implemented")
}
fn deserialize_enum<V>(
self, _name: &'static str, _variants: &'static [&'static str], _visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!("deserialize Enum not implemented")
}
fn deserialize_option<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Option not implemented")
}
fn deserialize_bool<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize bool not implemented")
}
fn deserialize_i8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize i8 not implemented")
}
fn deserialize_i16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize i16 not implemented")
}
fn deserialize_i32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize i32 not implemented")
}
fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let bytes: [u8; size_of::<i64>()] = self.buf[self.pos..].try_into()?;
self.pos = self.pos.saturating_add(size_of::<i64>());
visitor.visit_i64(i64::from_be_bytes(bytes))
}
fn deserialize_u8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize u8 not implemented")
}
fn deserialize_u16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize u16 not implemented")
}
fn deserialize_u32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize u32 not implemented")
}
fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let bytes: [u8; size_of::<u64>()] = self.buf[self.pos..].try_into()?;
self.pos = self.pos.saturating_add(size_of::<u64>());
visitor.visit_u64(u64::from_be_bytes(bytes))
}
fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize f32 not implemented")
}
fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize f64 not implemented")
}
fn deserialize_char<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize char not implemented")
}
fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_next();
let out = string::str_from_bytes(input)?;
visitor.visit_borrowed_str(out)
}
fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_next();
let out = string::string_from_bytes(input)?;
visitor.visit_string(out)
}
fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
let input = self.record_trail();
visitor.visit_borrowed_bytes(input)
}
fn deserialize_byte_buf<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Byte Buf not implemented")
}
fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Unit Struct not implemented")
}
fn deserialize_identifier<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Identifier not implemented")
}
fn deserialize_ignored_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize Ignored Any not implemented")
}
fn deserialize_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
unimplemented!("deserialize any not implemented")
}
}
impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: DeserializeSeed<'de>,
{
if self.pos >= self.buf.len() {
return Ok(None);
}
self.record_start();
seed.deserialize(&mut **self).map(Some)
}
}

View File

@ -0,0 +1,34 @@
use std::convert::identity;
use conduit::Result;
use serde::Deserialize;
pub trait Deserialized {
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>;
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>;
#[inline]
fn deserialized<T>(self) -> Result<T>
where
T: for<'de> Deserialize<'de>,
Self: Sized,
{
self.map_de(identity::<T>)
}
#[inline]
fn deserialized_json<T>(self) -> Result<T>
where
T: for<'de> Deserialize<'de>,
Self: Sized,
{
self.map_json(identity::<T>)
}
}

View File

@ -106,7 +106,7 @@ impl Engine {
})) }))
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self), level = "trace")]
pub(crate) fn open_cf(&self, name: &str) -> Result<Arc<BoundColumnFamily<'_>>> { pub(crate) fn open_cf(&self, name: &str) -> Result<Arc<BoundColumnFamily<'_>>> {
let mut cfs = self.cfs.lock().expect("locked"); let mut cfs = self.cfs.lock().expect("locked");
if !cfs.contains(name) { if !cfs.contains(name) {

View File

@ -1,6 +1,10 @@
use std::ops::Deref; use std::{fmt, fmt::Debug, ops::Deref};
use conduit::Result;
use rocksdb::DBPinnableSlice; use rocksdb::DBPinnableSlice;
use serde::{Deserialize, Serialize, Serializer};
use crate::{keyval::deserialize_val, Deserialized, Slice};
pub struct Handle<'a> { pub struct Handle<'a> {
val: DBPinnableSlice<'a>, val: DBPinnableSlice<'a>,
@ -14,14 +18,91 @@ impl<'a> From<DBPinnableSlice<'a>> for Handle<'a> {
} }
} }
impl Debug for Handle<'_> {
fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
let val: &Slice = self;
let ptr = val.as_ptr();
let len = val.len();
write!(out, "Handle {{val: {{ptr: {ptr:?}, len: {len}}}}}")
}
}
impl Serialize for Handle<'_> {
#[inline]
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let bytes: &Slice = self;
serializer.serialize_bytes(bytes)
}
}
impl Deref for Handle<'_> { impl Deref for Handle<'_> {
type Target = [u8]; type Target = Slice;
#[inline] #[inline]
fn deref(&self) -> &Self::Target { &self.val } fn deref(&self) -> &Self::Target { &self.val }
} }
impl AsRef<[u8]> for Handle<'_> { impl AsRef<Slice> for Handle<'_> {
#[inline] #[inline]
fn as_ref(&self) -> &[u8] { &self.val } fn as_ref(&self) -> &Slice { &self.val }
}
impl Deserialized for Result<Handle<'_>> {
#[inline]
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self?.map_json(f)
}
#[inline]
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self?.map_de(f)
}
}
impl<'a> Deserialized for Result<&'a Handle<'a>> {
#[inline]
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self.and_then(|handle| handle.map_json(f))
}
#[inline]
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
self.and_then(|handle| handle.map_de(f))
}
}
impl<'a> Deserialized for &'a Handle<'a> {
fn map_json<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
serde_json::from_slice::<T>(self.as_ref())
.map_err(Into::into)
.map(f)
}
fn map_de<T, U, F>(self, f: F) -> Result<U>
where
F: FnOnce(T) -> U,
T: for<'de> Deserialize<'de>,
{
deserialize_val(self.as_ref()).map(f)
}
} }

View File

@ -1,110 +0,0 @@
use std::{iter::FusedIterator, sync::Arc};
use conduit::Result;
use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, Direction, IteratorMode, ReadOptions};
use crate::{
engine::Db,
result,
slice::{OwnedKeyVal, OwnedKeyValPair},
Engine,
};
type Cursor<'cursor> = DBRawIteratorWithThreadMode<'cursor, Db>;
struct State<'cursor> {
cursor: Cursor<'cursor>,
direction: Direction,
valid: bool,
init: bool,
}
impl<'cursor> State<'cursor> {
pub(crate) fn new(
db: &'cursor Arc<Engine>, cf: &'cursor Arc<ColumnFamily>, opts: ReadOptions, mode: &IteratorMode<'_>,
) -> Self {
let mut cursor = db.db.raw_iterator_cf_opt(&**cf, opts);
let direction = into_direction(mode);
let valid = seek_init(&mut cursor, mode);
Self {
cursor,
direction,
valid,
init: true,
}
}
}
pub struct Iter<'cursor> {
state: State<'cursor>,
}
impl<'cursor> Iter<'cursor> {
pub(crate) fn new(
db: &'cursor Arc<Engine>, cf: &'cursor Arc<ColumnFamily>, opts: ReadOptions, mode: &IteratorMode<'_>,
) -> Self {
Self {
state: State::new(db, cf, opts, mode),
}
}
}
impl Iterator for Iter<'_> {
type Item = OwnedKeyValPair;
fn next(&mut self) -> Option<Self::Item> {
if !self.state.init && self.state.valid {
seek_next(&mut self.state.cursor, self.state.direction);
} else if self.state.init {
self.state.init = false;
}
self.state
.cursor
.item()
.map(OwnedKeyVal::from)
.map(OwnedKeyVal::to_tuple)
.or_else(|| {
when_invalid(&mut self.state).expect("iterator invalidated due to error");
None
})
}
}
impl FusedIterator for Iter<'_> {}
fn when_invalid(state: &mut State<'_>) -> Result<()> {
state.valid = false;
result(state.cursor.status())
}
fn seek_next(cursor: &mut Cursor<'_>, direction: Direction) {
match direction {
Direction::Forward => cursor.next(),
Direction::Reverse => cursor.prev(),
}
}
fn seek_init(cursor: &mut Cursor<'_>, mode: &IteratorMode<'_>) -> bool {
use Direction::{Forward, Reverse};
use IteratorMode::{End, From, Start};
match mode {
Start => cursor.seek_to_first(),
End => cursor.seek_to_last(),
From(key, Forward) => cursor.seek(key),
From(key, Reverse) => cursor.seek_for_prev(key),
};
cursor.valid()
}
fn into_direction(mode: &IteratorMode<'_>) -> Direction {
use Direction::{Forward, Reverse};
use IteratorMode::{End, From, Start};
match mode {
Start | From(_, Forward) => Forward,
End | From(_, Reverse) => Reverse,
}
}

Some files were not shown because too many files have changed in this diff Show More