From 10aa23a9bf383419aa131723922be1c82c0f2e8c Mon Sep 17 00:00:00 2001 From: strawberry Date: Sat, 16 Nov 2024 11:46:05 -0500 Subject: [PATCH] implement knocking on rooms Signed-off-by: strawberry --- Cargo.lock | 28 +- Cargo.toml | 2 +- src/api/client/membership.rs | 591 +++++++++++++++++++++++++-- src/api/client/sync/v3.rs | 41 +- src/api/client/sync/v4.rs | 9 + src/api/client/user_directory.rs | 2 +- src/api/router.rs | 3 + src/api/server/invite.rs | 4 +- src/api/server/make_join.rs | 12 +- src/api/server/make_knock.rs | 18 +- src/api/server/mod.rs | 4 + src/api/server/send_join.rs | 39 +- src/api/server/send_knock.rs | 4 +- src/api/server/send_leave.rs | 22 +- src/database/maps.rs | 2 + src/service/migrations.rs | 85 +++- src/service/rooms/state_cache/mod.rs | 121 +++++- src/service/rooms/timeline/mod.rs | 13 +- 18 files changed, 869 insertions(+), 131 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3a95f83a..aea1dff3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2862,7 +2862,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" dependencies = [ "anyhow", - "itertools 0.13.0", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.87", @@ -3128,7 +3128,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "assign", "js_int", @@ -3150,7 +3150,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "js_int", "ruma-common", @@ -3162,7 +3162,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "as_variant", "assign", @@ -3185,7 +3185,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "as_variant", "base64 0.22.1", @@ -3215,7 +3215,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3239,7 +3239,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "bytes", "http", @@ -3257,7 +3257,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "js_int", "thiserror 2.0.3", @@ -3266,7 +3266,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "js_int", "ruma-common", @@ -3276,7 +3276,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "cfg-if", "once_cell", @@ -3292,7 +3292,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "js_int", "ruma-common", @@ -3304,7 +3304,7 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "headers", "http", @@ -3317,7 +3317,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3333,7 +3333,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +source = "git+https://github.com/girlbossceo/ruwuma?rev=97e2fb6df13f65532d33fc2f0f097ad5a449dd70#97e2fb6df13f65532d33fc2f0f097ad5a449dd70" dependencies = [ "futures-util", "itertools 0.13.0", diff --git a/Cargo.toml b/Cargo.toml index 68c87c57..7132fd63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -322,7 +322,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "2ab432fba19eb8862c594d24af39d8f9f6b4eac6" +rev = "97e2fb6df13f65532d33fc2f0f097ad5a449dd70" features = [ "compat", "rand", diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 9478e383..b28880ea 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -19,6 +19,7 @@ use ruma::{ api::{ client::{ error::ErrorKind, + knock::knock_room, membership::{ ban_user, forget_room, get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, joined_members::{self, v3::RoomMember}, @@ -151,21 +152,14 @@ async fn banned_room_check( /// rules locally /// - If the server does not know about the room: asks other servers over /// federation -#[tracing::instrument(skip_all, fields(%client_ip), name = "join")] +#[tracing::instrument(skip_all, fields(%client), name = "join")] pub(crate) async fn join_room_by_id_route( - State(services): State, InsecureClientIp(client_ip): InsecureClientIp, + State(services): State, InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - banned_room_check( - &services, - sender_user, - Some(&body.room_id), - body.room_id.server_name(), - client_ip, - ) - .await?; + banned_room_check(&services, sender_user, Some(&body.room_id), body.room_id.server_name(), client).await?; // There is no body.server_name for /roomId/join let mut servers: Vec<_> = services @@ -324,6 +318,101 @@ pub(crate) async fn join_room_by_id_or_alias_route( }) } +/// # `POST /_matrix/client/*/knock/{roomIdOrAlias}` +/// +/// Tries to knock the room to ask permission to join for the sender user. +#[tracing::instrument(skip_all, fields(%client), name = "knock")] +pub(crate) async fn knock_room_route( + State(services): State, InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let body = body.body; + + let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { + Ok(room_id) => { + banned_room_check(&services, sender_user, Some(&room_id), room_id.server_name(), client).await?; + + let mut servers = body.via.clone(); + servers.extend( + services + .rooms + .state_cache + .servers_invite_via(&room_id) + .map(ToOwned::to_owned) + .collect::>() + .await, + ); + + servers.extend( + services + .rooms + .state_cache + .invite_state(sender_user, &room_id) + .await + .unwrap_or_default() + .iter() + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + + if let Some(server) = room_id.server_name() { + servers.push(server.to_owned()); + } + + servers.sort_unstable(); + servers.dedup(); + shuffle(&mut servers); + + (servers, room_id) + }, + Err(room_alias) => { + let (room_id, mut servers) = services + .rooms + .alias + .resolve_alias(&room_alias, Some(body.via.clone())) + .await?; + + banned_room_check(&services, sender_user, Some(&room_id), Some(room_alias.server_name()), client).await?; + + let addl_via_servers = services + .rooms + .state_cache + .servers_invite_via(&room_id) + .map(ToOwned::to_owned); + + let addl_state_servers = services + .rooms + .state_cache + .invite_state(sender_user, &room_id) + .await + .unwrap_or_default(); + + let mut addl_servers: Vec<_> = addl_state_servers + .iter() + .map(|event| event.get_field("sender")) + .filter_map(FlatOk::flat_ok) + .map(|user: &UserId| user.server_name().to_owned()) + .stream() + .chain(addl_via_servers) + .collect() + .await; + + addl_servers.sort_unstable(); + addl_servers.dedup(); + shuffle(&mut addl_servers); + servers.append(&mut addl_servers); + + (servers, room_id) + }, + }; + + knock_room_by_id_helper(&services, sender_user, &room_id, body.reason.clone(), &servers) + .boxed() + .await +} + /// # `POST /_matrix/client/v3/rooms/{roomId}/leave` /// /// Tries to leave the sender user from a room. @@ -668,6 +757,18 @@ pub async fn join_room_by_id_helper( }); } + if let Ok(membership) = services + .rooms + .state_accessor + .get_member(room_id, sender_user) + .await + { + if membership.membership == MembershipState::Ban { + debug_warn!("{sender_user} is banned from {room_id} but attempted to join"); + return Err!(Request(Forbidden("You are banned from the room."))); + } + } + let server_in_room = services .rooms .state_cache @@ -1027,7 +1128,7 @@ async fn join_room_by_id_helper_local( services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, ) -> Result { - debug!("We can join locally"); + debug_info!("We can join locally"); let join_rules_event_content = services .rooms @@ -1053,8 +1154,7 @@ async fn join_room_by_id_helper_local( let local_members: Vec<_> = services .rooms .state_cache - .room_members(room_id) - .ready_filter(|user| services.globals.user_is_local(user)) + .local_users_in_room(room_id) .map(ToOwned::to_owned) .collect() .await; @@ -1142,7 +1242,7 @@ async fn join_room_by_id_helper_local( .as_str() }) .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - // TODO: Is origin needed? + join_event_stub.insert( "origin".to_owned(), CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), @@ -1494,7 +1594,8 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, .rooms .state_cache .invite_state(user_id, room_id) - .map_err(|_| services.rooms.state_cache.left_state(user_id, room_id)) + .or_else(|_| services.rooms.state_cache.knock_state(user_id, room_id)) + .or_else(|_| services.rooms.state_cache.left_state(user_id, room_id)) .await .ok(); @@ -1566,13 +1667,6 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut make_leave_response_and_server = Err!(BadServerResponse("No server available to assist in leaving.")); - let invite_state = services - .rooms - .state_cache - .invite_state(user_id, room_id) - .await - .map_err(|_| err!(Request(BadState("User is not invited."))))?; - let mut servers: HashSet = services .rooms .state_cache @@ -1581,15 +1675,45 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room .collect() .await; - servers.extend( - invite_state - .iter() - .filter_map(|event| event.get_field("sender").ok().flatten()) - .filter_map(|sender: &str| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + if let Ok(invite_state) = services + .rooms + .state_cache + .invite_state(user_id, room_id) + .await + { + servers.extend( + invite_state + .iter() + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + } else if let Ok(knock_state) = services + .rooms + .state_cache + .knock_state(user_id, room_id) + .await + { + servers.extend( + knock_state + .iter() + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) + .filter_map(|sender| { + if !services.globals.user_is_local(&sender) { + Some(sender.server_name().to_owned()) + } else { + None + } + }), + ); - debug!("servers in remote_leave_room: {servers:?}"); + if let Some(room_id_server_name) = room_id.server_name() { + servers.insert(room_id_server_name.to_owned()); + } + } + + debug_info!("servers in remote_leave_room: {servers:?}"); for remote_server in servers { let make_leave_response = services @@ -1683,3 +1807,410 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room Ok(()) } + +async fn knock_room_by_id_helper( + services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], +) -> Result { + let state_lock = services.rooms.state.mutex.lock(room_id).await; + + if services + .rooms + .state_cache + .is_invited(sender_user, room_id) + .await + { + debug_warn!("{sender_user} is already invited in {room_id} but attempted to knock"); + return Err!(Request(Forbidden( + "You cannot knock on a room you are already invited/accepted to." + ))); + } + + if services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { + debug_warn!("{sender_user} is already joined in {room_id} but attempted to knock"); + return Err!(Request(Forbidden("You cannot knock on a room you are already joined in."))); + } + + if services + .rooms + .state_cache + .is_knocked(sender_user, room_id) + .await + { + debug_warn!("{sender_user} is already knocked in {room_id}"); + return Ok(knock_room::v3::Response { + room_id: room_id.into(), + }); + } + + if let Ok(membership) = services + .rooms + .state_accessor + .get_member(room_id, sender_user) + .await + { + if membership.membership == MembershipState::Ban { + debug_warn!("{sender_user} is banned from {room_id} but attempted to knock"); + return Err!(Request(Forbidden("You cannot knock on a room you are banned from."))); + } + } + + let server_in_room = services + .rooms + .state_cache + .server_in_room(services.globals.server_name(), room_id) + .await; + + let local_knock = + server_in_room || servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])); + + if local_knock { + knock_room_helper_local(services, sender_user, room_id, reason, servers, state_lock) + .boxed() + .await?; + } else { + knock_room_helper_remote(services, sender_user, room_id, reason, servers, state_lock) + .boxed() + .await?; + } + + Ok(knock_room::v3::Response::new(room_id.to_owned())) +} + +async fn knock_room_helper_local( + services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], + state_lock: RoomMutexGuard, +) -> Result { + debug_info!("We can knock locally"); + + let room_version_id = services.rooms.state.get_room_version(room_id).await?; + + if matches!( + room_version_id, + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + ) { + return Err!(Request(Forbidden("This room does not support knocking."))); + } + + let content = RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), + reason: reason.clone(), + ..RoomMemberEventContent::new(MembershipState::Knock) + }; + + // Try normal knock first + let error = match services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder::state(sender_user.to_string(), &content), + sender_user, + room_id, + &state_lock, + ) + .await + { + Ok(_) => return Ok(()), + Err(e) => e, + }; + + if servers + .iter() + .any(|server_name| !services.globals.server_is_ours(server_name)) + { + warn!("We couldn't do the knock locally, maybe federation can help to satisfy the knock"); + + let (make_knock_response, remote_server) = make_knock_request(services, sender_user, room_id, servers).await?; + + info!("make_knock finished"); + + let room_version_id = make_knock_response.room_version; + + if !services + .globals + .supported_room_versions() + .contains(&room_version_id) + { + return Err!(BadServerResponse("Room version is not supported")); + } + + let mut knock_event_stub: CanonicalJsonObject = serde_json::from_str(make_knock_response.event.get()) + .map_err(|e| err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}")))?; + + knock_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), + ); + knock_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + knock_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), + reason, + ..RoomMemberEventContent::new(MembershipState::Knock) + }) + .expect("event is valid, we just created it"), + ); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + services + .server_keys + .hash_and_sign_event(&mut knock_event_stub, &room_version_id)?; + + // Generate event id + let event_id = pdu::gen_event_id(&knock_event_stub, &room_version_id)?; + + // Add event_id + knock_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); + + // It has enough fields to be called a proper event now + let knock_event = knock_event_stub; + + info!("Asking {remote_server} for send_knock in room {room_id}"); + let send_knock_request = federation::knock::send_knock::v1::Request { + room_id: room_id.to_owned(), + event_id: event_id.clone(), + pdu: services + .sending + .convert_to_outgoing_federation_event(knock_event.clone()) + .await, + }; + + let send_knock_response = services + .sending + .send_federation_request(&remote_server, send_knock_request) + .await?; + + info!("send_knock finished"); + + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; + + info!("Parsing knock event"); + + let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) + .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; + + info!("Updating membership locally to knock state with provided stripped state events"); + services + .rooms + .state_cache + .update_membership( + room_id, + sender_user, + parsed_knock_pdu + .get_content::() + .expect("we just created this"), + sender_user, + Some(send_knock_response.knock_room_state), + None, + false, + ) + .await?; + + info!("Appending room knock event locally"); + services + .rooms + .timeline + .append_pdu( + &parsed_knock_pdu, + knock_event, + vec![(*parsed_knock_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; + } else { + return Err(error); + } + + Ok(()) +} + +async fn knock_room_helper_remote( + services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], + state_lock: RoomMutexGuard, +) -> Result { + info!("Knocking {room_id} over federation."); + + let (make_knock_response, remote_server) = make_knock_request(services, sender_user, room_id, servers).await?; + + info!("make_knock finished"); + + let room_version_id = make_knock_response.room_version; + + if !services + .globals + .supported_room_versions() + .contains(&room_version_id) + { + return Err!(BadServerResponse("Room version is not supported")); + } + + let mut knock_event_stub: CanonicalJsonObject = serde_json::from_str(make_knock_response.event.get()) + .map_err(|e| err!(BadServerResponse("Invalid make_knock event json received from server: {e:?}")))?; + + knock_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), + ); + knock_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + knock_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), + reason, + ..RoomMemberEventContent::new(MembershipState::Knock) + }) + .expect("event is valid, we just created it"), + ); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + services + .server_keys + .hash_and_sign_event(&mut knock_event_stub, &room_version_id)?; + + // Generate event id + let event_id = pdu::gen_event_id(&knock_event_stub, &room_version_id)?; + + // Add event_id + knock_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); + + // It has enough fields to be called a proper event now + let knock_event = knock_event_stub; + + info!("Asking {remote_server} for send_knock in room {room_id}"); + let send_knock_request = federation::knock::send_knock::v1::Request { + room_id: room_id.to_owned(), + event_id: event_id.clone(), + pdu: services + .sending + .convert_to_outgoing_federation_event(knock_event.clone()) + .await, + }; + + let send_knock_response = services + .sending + .send_federation_request(&remote_server, send_knock_request) + .await?; + + info!("send_knock finished"); + + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; + + info!("Parsing knock event"); + + let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) + .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; + + info!("Updating membership locally to knock state with provided stripped state events"); + services + .rooms + .state_cache + .update_membership( + room_id, + sender_user, + parsed_knock_pdu + .get_content::() + .expect("we just created this"), + sender_user, + Some(send_knock_response.knock_room_state), + None, + false, + ) + .await?; + + info!("Appending room knock event locally"); + services + .rooms + .timeline + .append_pdu( + &parsed_knock_pdu, + knock_event, + vec![(*parsed_knock_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; + + Ok(()) +} + +async fn make_knock_request( + services: &Services, sender_user: &UserId, room_id: &RoomId, servers: &[OwnedServerName], +) -> Result<(federation::knock::create_knock_event_template::v1::Response, OwnedServerName)> { + let mut make_knock_response_and_server = Err!(BadServerResponse("No server available to assist in knocking.")); + + let mut make_knock_counter: usize = 0; + + for remote_server in servers { + if services.globals.server_is_ours(remote_server) { + continue; + } + info!("Asking {remote_server} for make_knock ({make_knock_counter})"); + let make_knock_response = services + .sending + .send_federation_request( + remote_server, + federation::knock::create_knock_event_template::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services.globals.supported_room_versions(), + }, + ) + .await; + + trace!("make_knock response: {make_knock_response:?}"); + make_knock_counter = make_knock_counter.saturating_add(1); + + make_knock_response_and_server = make_knock_response.map(|r| (r, remote_server.clone())); + + if make_knock_response_and_server.is_ok() { + break; + } + + if make_knock_counter > 50 { + warn!("50 servers failed to provide valid make_knock response, assuming no server can assist in knocking."); + make_knock_response_and_server = Err!(BadServerResponse("No server available to assist in knocking.")); + return make_knock_response_and_server; + } + } + + make_knock_response_and_server +} diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 7a78ea74..bb9a4afb 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -18,8 +18,8 @@ use ruma::{ sync::sync_events::{ self, v3::{ - Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence, - RoomAccountData, RoomSummary, Rooms, State as RoomState, Timeline, ToDevice, + Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, KnockedRoom, LeftRoom, + Presence, RoomAccountData, RoomSummary, Rooms, State as RoomState, Timeline, ToDevice, }, DeviceLists, UnreadNotificationsCount, }, @@ -241,6 +241,41 @@ pub(crate) async fn sync_events_route( ); } + let mut knocked_rooms = BTreeMap::new(); + let all_knocked_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_knocked(&sender_user) + .collect() + .await; + + for (room_id, knock_state_events) in all_knocked_rooms { + // Get and drop the lock to wait for remaining operations to finish + let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; + drop(insert_lock); + + let knock_count = services + .rooms + .state_cache + .get_knock_count(&room_id, &sender_user) + .await + .ok(); + + // Knocked before last sync + if Some(since) >= knock_count { + continue; + } + + knocked_rooms.insert( + room_id.clone(), + KnockedRoom { + knock_state: sync_events::v3::KnockState { + events: knock_state_events, + }, + }, + ); + } + for user_id in left_encrypted_users { let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; @@ -263,7 +298,7 @@ pub(crate) async fn sync_events_route( leave: left_rooms, join: joined_rooms, invite: invited_rooms, - knock: BTreeMap::new(), // TODO + knock: knocked_rooms, }, presence: Presence { events: presence_updates diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 57edc953..365fee65 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -107,9 +107,18 @@ pub(crate) async fn sync_events_v4_route( .collect() .await; + let all_knocked_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_knocked(sender_user) + .map(|r| r.0) + .collect() + .await; + let all_rooms = all_joined_rooms .iter() .chain(all_invited_rooms.iter()) + .chain(all_knocked_rooms.iter()) .map(Clone::clone) .collect(); diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index 868811a3..b903909c 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -21,7 +21,7 @@ pub(crate) async fn search_users_route( State(services): State, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 + let limit = usize::try_from(body.limit).map_or(10, usize::from).min(100); // default limit is 10 let users = services.users.stream().filter_map(|user_id| async { // Filter out buggy users (they should not exist, but you never know...) diff --git a/src/api/router.rs b/src/api/router.rs index 4bdd692d..2ff6bb67 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -99,6 +99,7 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::join_room_by_id_route) .ruma_route(&client::join_room_by_id_or_alias_route) .ruma_route(&client::joined_members_route) + .ruma_route(&client::knock_room_route) .ruma_route(&client::leave_room_route) .ruma_route(&client::forget_room_route) .ruma_route(&client::joined_rooms_route) @@ -200,9 +201,11 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&server::get_event_authorization_route) .ruma_route(&server::get_room_state_route) .ruma_route(&server::get_room_state_ids_route) + .ruma_route(&server::create_knock_event_template_route) .ruma_route(&server::create_leave_event_template_route) .ruma_route(&server::create_leave_event_v1_route) .ruma_route(&server::create_leave_event_v2_route) + .ruma_route(&server::create_knock_event_v1_route) .ruma_route(&server::create_join_event_template_route) .ruma_route(&server::create_join_event_v1_route) .ruma_route(&server::create_join_event_v2_route) diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 0ceb914f..361c64b5 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -117,12 +117,12 @@ pub(crate) async fn create_invite_route( let mut invite_state = body.invite_room_state.clone(); let mut event: JsonObject = serde_json::from_str(body.event.get()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event PDU."))?; event.insert("event_id".to_owned(), "$placeholder".into()); let pdu: PduEvent = serde_json::from_value(event.into()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event."))?; + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event PDU."))?; invite_state.push(pdu.to_stripped_state_event()); diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index d5ea675e..57ee9315 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -1,7 +1,7 @@ use axum::extract::State; use conduit::{ utils::{IterStream, ReadyExt}, - warn, + warn, Err, }; use futures::StreamExt; use ruma::{ @@ -59,10 +59,7 @@ pub(crate) async fn create_join_event_template_route( &body.user_id, &body.room_id, ); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } if let Some(server) = body.room_id.server_name() { @@ -72,10 +69,7 @@ pub(crate) async fn create_join_event_template_route( .forbidden_remote_server_names .contains(&server.to_owned()) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } diff --git a/src/api/server/make_knock.rs b/src/api/server/make_knock.rs index c1875a1f..6d998f77 100644 --- a/src/api/server/make_knock.rs +++ b/src/api/server/make_knock.rs @@ -1,5 +1,5 @@ use axum::extract::State; -use conduit::Err; +use conduit::{debug_warn, Err}; use ruma::{ api::{client::error::ErrorKind, federation::knock::create_knock_event_template}, events::room::member::{MembershipState, RoomMemberEventContent}, @@ -82,6 +82,22 @@ pub(crate) async fn create_knock_event_template_route( )); } + if let Ok(membership) = services + .rooms + .state_accessor + .get_member(&body.room_id, &body.user_id) + .await + { + if membership.membership == MembershipState::Ban { + debug_warn!( + "Remote user {} is banned from {} but attempted to knock", + &body.user_id, + &body.room_id + ); + return Err!(Request(Forbidden("You cannot knock on a room you are banned from."))); + } + } + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let (_pdu, mut pdu_json) = services diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs index 9b7d91cb..5c1ff3f7 100644 --- a/src/api/server/mod.rs +++ b/src/api/server/mod.rs @@ -6,6 +6,7 @@ pub(super) mod hierarchy; pub(super) mod invite; pub(super) mod key; pub(super) mod make_join; +pub(super) mod make_knock; pub(super) mod make_leave; pub(super) mod media; pub(super) mod openid; @@ -13,6 +14,7 @@ pub(super) mod publicrooms; pub(super) mod query; pub(super) mod send; pub(super) mod send_join; +pub(super) mod send_knock; pub(super) mod send_leave; pub(super) mod state; pub(super) mod state_ids; @@ -28,6 +30,7 @@ pub(super) use hierarchy::*; pub(super) use invite::*; pub(super) use key::*; pub(super) use make_join::*; +pub(super) use make_knock::*; pub(super) use make_leave::*; pub(super) use media::*; pub(super) use openid::*; @@ -35,6 +38,7 @@ pub(super) use publicrooms::*; pub(super) use query::*; pub(super) use send::*; pub(super) use send_join::*; +pub(super) use send_knock::*; pub(super) use send_leave::*; pub(super) use state::*; pub(super) use state_ids::*; diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 60ec8c1f..a823d6ab 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use axum::extract::State; -use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; +use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Err, Error, Result}; use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, @@ -126,18 +126,19 @@ async fn create_join_event( )); }; - if content + let should_sign_join_event = content .join_authorized_via_users_server .is_some_and(|user| services.globals.user_is_local(&user)) && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id) .await - .unwrap_or_default() - { + .unwrap_or_default(); + + if should_sign_join_event { services .server_keys .hash_and_sign_event(&mut value, &room_version_id) .map_err(|e| err!(Request(InvalidParam("Failed to sign event: {e}"))))?; - } + }; let origin: OwnedServerName = serde_json::from_value( serde_json::to_value( @@ -206,8 +207,12 @@ async fn create_join_event( Ok(create_join_event::v1::RoomState { auth_chain, state, - // Event field is required if the room version supports restricted join rules. - event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(), + // Event field is required if the room is using restricted join rules and we sign the event + event: if should_sign_join_event { + to_raw_value(&CanonicalJsonValue::Object(value)).ok() + } else { + None + }, }) } @@ -228,10 +233,7 @@ pub(crate) async fn create_join_event_v1_route( body.origin(), &body.room_id, ); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } if let Some(server) = body.room_id.server_name() { @@ -246,10 +248,7 @@ pub(crate) async fn create_join_event_v1_route( body.origin(), &body.room_id, ); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } @@ -274,10 +273,7 @@ pub(crate) async fn create_join_event_v2_route( .forbidden_remote_server_names .contains(body.origin()) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } if let Some(server) = body.room_id.server_name() { @@ -287,10 +283,7 @@ pub(crate) async fn create_join_event_v2_route( .forbidden_remote_server_names .contains(&server.to_owned()) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } diff --git a/src/api/server/send_knock.rs b/src/api/server/send_knock.rs index c57998ae..1a1519d4 100644 --- a/src/api/server/send_knock.rs +++ b/src/api/server/send_knock.rs @@ -177,13 +177,13 @@ pub(crate) async fn create_knock_event_v1_route( drop(mutex_lock); - let knock_room_state = services.rooms.state.summary_stripped(&pdu).await; - services .sending .send_pdu_room(&body.room_id, &pdu_id) .await?; + let knock_room_state = services.rooms.state.summary_stripped(&pdu).await; + Ok(send_knock::v1::Response { knock_room_state, }) diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index e4f41833..7451d5bf 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -1,7 +1,7 @@ #![allow(deprecated)] use axum::extract::State; -use conduit::{err, utils::ReadyExt, Error, Result}; +use conduit::{err, Err, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -74,10 +74,9 @@ async fn create_leave_event( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event content is empty or invalid"))?; if content.membership != MembershipState::Leave { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Not allowed to send a non-leave membership event to leave endpoint.", - )); + return Err!(Request(InvalidParam( + "Not allowed to send a non-leave membership event to leave endpoint." + ))); } let event_type: StateEventType = serde_json::from_value( @@ -90,10 +89,9 @@ async fn create_leave_event( .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Event does not have a valid state event type."))?; if event_type != StateEventType::RoomMember { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, + return Err!(Request(InvalidParam( "Not allowed to send non-membership state event to leave endpoint.", - )); + ))); } // ACL check sender server name @@ -151,11 +149,5 @@ async fn create_leave_event( drop(mutex_lock); - let servers = services - .rooms - .state_cache - .room_servers(room_id) - .ready_filter(|server| !services.globals.server_is_ours(server)); - - services.sending.send_pdu_servers(servers, &pdu_id).await + services.sending.send_pdu_room(room_id, &pdu_id).await } diff --git a/src/database/maps.rs b/src/database/maps.rs index 0e835abf..d3924590 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -56,6 +56,7 @@ pub const MAPS: &[&str] = &[ "roomuserid_joined", "roomuserid_lastprivatereadupdate", "roomuserid_leftcount", + "roomuserid_knockedcount", "roomuserid_privateread", "roomuseroncejoinedids", "roomusertype_roomuserdataid", @@ -100,5 +101,6 @@ pub const MAPS: &[&str] = &[ "userroomid_invitestate", "userroomid_joined", "userroomid_leftstate", + "userroomid_knockedstate", "userroomid_notificationcount", ]; diff --git a/src/service/migrations.rs b/src/service/migrations.rs index 126d3c7e..60729abf 100644 --- a/src/service/migrations.rs +++ b/src/service/migrations.rs @@ -1,4 +1,4 @@ -use std::cmp; +use std::{cmp, collections::HashSet}; use conduit::{ debug, debug_info, debug_warn, error, info, @@ -14,7 +14,7 @@ use itertools::Itertools; use ruma::{ events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, push::Ruleset, - OwnedUserId, UserId, + UserId, }; use crate::{media, Services}; @@ -69,6 +69,7 @@ async fn fresh(services: &Services) -> Result<()> { db["global"].insert(b"fix_bad_double_separator_in_state_cache", []); db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); db["global"].insert(b"fix_referencedevents_missing_sep", []); + db["global"].insert(b"update_knocked_user_memberships_locally", []); // Create the admin room and server user on first run crate::admin::create_admin_room(services).boxed().await?; @@ -130,6 +131,14 @@ async fn migrate(services: &Services) -> Result<()> { fix_referencedevents_missing_sep(services).await?; } + if db["global"] + .get(b"update_knocked_user_memberships_locally") + .await + .is_not_found() + { + update_knocked_user_memberships_locally(services).await?; + } + let version_match = services.globals.db.database_version().await == DATABASE_VERSION || services.globals.db.database_version().await == CONDUIT_DATABASE_VERSION; @@ -371,24 +380,24 @@ async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result< Ok(()) } -async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result<()> { +async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result { warn!("Retroactively fixing bad data from broken roomuserid_joined"); let db = &services.db; - let _cork = db.cork_and_sync(); + let cork = db.cork_and_sync(); let room_ids = services .rooms .metadata .iter_ids() .map(ToOwned::to_owned) - .collect::>() + .collect::>() .await; for room_id in &room_ids { debug_info!("Fixing room {room_id}"); - let users_in_room: Vec = services + let users_in_room: HashSet<_> = services .rooms .state_cache .room_members(room_id) @@ -406,7 +415,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) .get_member(room_id, user_id) .map(|member| member.map_or(false, |member| member.membership == MembershipState::Join)) }) - .collect::>() + .collect::>() .await; let non_joined_members = users_in_room @@ -419,7 +428,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) .get_member(room_id, user_id) .map(|member| member.map_or(false, |member| member.membership == MembershipState::Join)) }) - .collect::>() + .collect::>() .await; for user_id in &joined_members { @@ -445,11 +454,11 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) .await; } - db.db.cleanup()?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); - + drop(cork); info!("Finished fixing"); - Ok(()) + + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); + db.db.cleanup() } async fn fix_referencedevents_missing_sep(services: &Services) -> Result { @@ -493,3 +502,55 @@ async fn fix_referencedevents_missing_sep(services: &Services) -> Result { db["global"].insert(b"fix_referencedevents_missing_sep", []); db.db.cleanup() } + +async fn update_knocked_user_memberships_locally(services: &Services) -> Result { + info!("Updating database of knocked users locally"); + + let db = &services.db; + let cork = db.cork_and_sync(); + + let room_ids = services + .rooms + .metadata + .iter_ids() + .collect::>() + .await; + + for room_id in room_ids { + debug_info!("Updating {room_id}"); + + let users_in_room: HashSet<_> = services + .rooms + .state_cache + .room_members(room_id) + .collect() + .await; + + let knocked_members = users_in_room + .iter() + .stream() + .filter(|user_id| { + services + .rooms + .state_accessor + .get_member(room_id, user_id) + .map(|member| member.map_or(false, |member| member.membership == MembershipState::Knock)) + }) + .collect::>() + .await; + + for user_id in knocked_members { + debug_info!("Making {user_id} as knocked"); + services + .rooms + .state_cache + .mark_as_knocked(user_id, room_id, None); + } + } + + drop(cork); + info!("Finished updating knocked user memberships locally in database"); + + db["global"].insert(b"update_knocked_user_memberships_locally", []); + db.db.cleanup() +} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 6e330fdc..6587b69f 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -4,13 +4,13 @@ use std::{ }; use conduit::{ - err, is_not_empty, + debug_warn, is_not_empty, result::LogErr, utils::{stream::TryIgnore, ReadyExt, StreamTools}, warn, Result, }; use database::{serialize_to_vec, Deserialized, Ignore, Interfix, Json, Map}; -use futures::{future::join4, stream::iter, Stream, StreamExt}; +use futures::{future::join5, stream::iter, Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ @@ -50,11 +50,13 @@ struct Data { roomuserid_invitecount: Arc, roomuserid_joined: Arc, roomuserid_leftcount: Arc, + roomuserid_knockedcount: Arc, roomuseroncejoinedids: Arc, serverroomids: Arc, userroomid_invitestate: Arc, userroomid_joined: Arc, userroomid_leftstate: Arc, + userroomid_knockedstate: Arc, } type AppServiceInRoomCache = RwLock>>; @@ -79,11 +81,13 @@ impl crate::Service for Service { roomuserid_invitecount: args.db["roomuserid_invitecount"].clone(), roomuserid_joined: args.db["roomuserid_joined"].clone(), roomuserid_leftcount: args.db["roomuserid_leftcount"].clone(), + roomuserid_knockedcount: args.db["roomuserid_knockedcount"].clone(), roomuseroncejoinedids: args.db["roomuseroncejoinedids"].clone(), serverroomids: args.db["serverroomids"].clone(), userroomid_invitestate: args.db["userroomid_invitestate"].clone(), userroomid_joined: args.db["userroomid_joined"].clone(), userroomid_leftstate: args.db["userroomid_leftstate"].clone(), + userroomid_knockedstate: args.db["userroomid_knockedstate"].clone(), }, })) } @@ -235,7 +239,12 @@ impl Service { MembershipState::Leave | MembershipState::Ban => { self.mark_as_left(user_id, room_id); }, - _ => {}, + MembershipState::Knock => { + self.mark_as_knocked(user_id, room_id, last_state); + }, + _ => { + debug_warn!("unknown membership state received: {membership:?}"); + }, } if update_joined_count { @@ -303,6 +312,9 @@ impl Service { self.db.userroomid_leftstate.remove(&userroom_id); self.db.roomuserid_leftcount.remove(&roomuser_id); + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + self.db.roomid_inviteviaservers.remove(room_id); } @@ -332,6 +344,41 @@ impl Service { self.db.userroomid_invitestate.remove(&userroom_id); self.db.roomuserid_invitecount.remove(&roomuser_id); + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + + self.db.roomid_inviteviaservers.remove(room_id); + } + + /// Direct DB function to directly mark a user as knocked. It is not + /// recommended to use this directly. You most likely should use + /// `update_membership` instead + #[tracing::instrument(skip(self), level = "debug")] + pub fn mark_as_knocked( + &self, user_id: &UserId, room_id: &RoomId, knocked_state: Option>>, + ) { + let userroom_id = (user_id, room_id); + let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); + + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); + + self.db + .userroomid_knockedstate + .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default())); + self.db + .roomuserid_knockedcount + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + + self.db.userroomid_invitestate.remove(&userroom_id); + self.db.roomuserid_invitecount.remove(&roomuser_id); + + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + self.db.roomid_inviteviaservers.remove(room_id); } @@ -472,6 +519,16 @@ impl Service { .deserialized() } + #[tracing::instrument(skip(self), level = "debug")] + pub async fn get_knock_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + let key = (room_id, user_id); + self.db + .roomuserid_knockedcount + .qry(&key) + .await + .deserialized() + } + #[tracing::instrument(skip(self), level = "debug")] pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { let key = (room_id, user_id); @@ -504,6 +561,22 @@ impl Service { .ignore_err() } + /// Returns an iterator over all rooms a user is currently knocking. + #[tracing::instrument(skip(self), level = "debug")] + pub fn rooms_knocked<'a>(&'a self, user_id: &'a UserId) -> impl Stream + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_knockedstate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() + } + #[tracing::instrument(skip(self), level = "debug")] pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { let key = (user_id, room_id); @@ -515,6 +588,17 @@ impl Service { .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) } + #[tracing::instrument(skip(self), level = "debug")] + pub async fn knock_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { + let key = (user_id, room_id); + self.db + .userroomid_knockedstate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw>| val.deserialize_as().map_err(Into::into)) + } + #[tracing::instrument(skip(self), level = "debug")] pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>> { let key = (user_id, room_id); @@ -554,6 +638,12 @@ impl Service { self.db.userroomid_joined.qry(&key).await.is_ok() } + #[tracing::instrument(skip(self), level = "debug")] + pub async fn is_knocked<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_knockedstate.qry(&key).await.is_ok() + } + #[tracing::instrument(skip(self), level = "debug")] pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { let key = (user_id, room_id); @@ -567,9 +657,10 @@ impl Service { } pub async fn user_membership(&self, user_id: &UserId, room_id: &RoomId) -> Option { - let states = join4( + let states = join5( self.is_joined(user_id, room_id), self.is_left(user_id, room_id), + self.is_knocked(user_id, room_id), self.is_invited(user_id, room_id), self.once_joined(user_id, room_id), ) @@ -578,8 +669,9 @@ impl Service { match states { (true, ..) => Some(MembershipState::Join), (_, true, ..) => Some(MembershipState::Leave), - (_, _, true, ..) => Some(MembershipState::Invite), - (false, false, false, true) => Some(MembershipState::Ban), + (_, _, true, ..) => Some(MembershipState::Knock), + (_, _, _, true, ..) => Some(MembershipState::Invite), + (false, false, false, false, true) => Some(MembershipState::Ban), _ => None, } } @@ -595,10 +687,10 @@ impl Service { .map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server")) } - /// Gets up to three servers that are likely to be in the room in the + /// Gets up to five servers that are likely to be in the room in the /// distant future. /// - /// See + /// See #[tracing::instrument(skip(self))] pub async fn servers_route_via(&self, room_id: &RoomId) -> Result> { let most_powerful_user_server = self @@ -613,23 +705,23 @@ impl Service { .max_by_key(|(_, power)| *power) .and_then(|x| (x.1 >= &int!(50)).then_some(x)) .map(|(user, _power)| user.server_name().to_owned()) - }) - .map_err(|e| err!(Database(error!(?e, "Invalid power levels event content in database."))))?; + }); let mut servers: Vec = self .room_members(room_id) .counts_by(|user| user.server_name().to_owned()) .await .into_iter() + .filter(|(server, _)| !server.is_ip_literal()) .sorted_by_key(|(_, users)| *users) .map(|(server, _)| server) .rev() - .take(3) + .take(5) .collect(); - if let Some(server) = most_powerful_user_server { + if let Ok(Some(server)) = most_powerful_user_server { servers.insert(0, server); - servers.truncate(3); + servers.truncate(5); } Ok(servers) @@ -730,6 +822,9 @@ impl Service { self.db.userroomid_leftstate.remove(&userroom_id); self.db.roomuserid_leftcount.remove(&roomuser_id); + self.db.userroomid_knockedstate.remove(&userroom_id); + self.db.roomuserid_knockedcount.remove(&roomuser_id); + if let Some(servers) = invite_via.filter(is_not_empty!()) { self.add_servers_invite_via(room_id, servers).await; } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index a3fc6a0b..8a675913 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -511,13 +511,16 @@ impl Service { UserId::parse(state_key.clone()).expect("This state_key was previously validated"); let content: RoomMemberEventContent = pdu.get_content()?; - let invite_state = match content.membership { - MembershipState::Invite => self.services.state.summary_stripped(pdu).await.into(), + let stripped_state = match content.membership { + MembershipState::Invite | MembershipState::Knock => { + self.services.state.summary_stripped(pdu).await.into() + }, _ => None, }; - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth + // Update our membership info, we do this here incase a user is invited or + // knocked and immediately leaves we need the DB to record the invite or + // knock event for auth self.services .state_cache .update_membership( @@ -525,7 +528,7 @@ impl Service { &target_user_id, content, &pdu.sender, - invite_state, + stripped_state, None, true, )