From 0857fe7907e3d70575e8e2bf4fd43759e831adf4 Mon Sep 17 00:00:00 2001 From: strawberry Date: Wed, 5 Jun 2024 01:52:53 -0400 Subject: [PATCH] abstract+add more "users in room" accessors, check membership state on `active_local_joined_users_in_room` `roomuserid_joined` cf seems unreliable, so in the mean time we need to check membership state (or maybe this is a more reliable check anyways) Signed-off-by: strawberry --- src/service/rooms/state_cache/data.rs | 66 +++++++++++++++++++++------ src/service/rooms/state_cache/mod.rs | 23 ++++++++-- src/service/rooms/timeline/mod.rs | 4 +- 3 files changed, 76 insertions(+), 17 deletions(-) diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 5694b22c..6f491d89 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use itertools::Itertools; use ruma::{ - events::{AnyStrippedStateEvent, AnySyncStateEvent}, + events::{room::member::MembershipState, AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; @@ -43,12 +43,23 @@ pub trait Data: Send + Sync { /// know). fn server_rooms<'a>(&'a self, server: &ServerName) -> Box> + 'a>; - /// Returns an iterator over all joined members of a room. + /// Returns an iterator of all joined members of a room. fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - /// Returns a vec of all the users joined in a room who are active - /// (not guests, not deactivated users) - fn active_local_users_in_room(&self, room_id: &RoomId) -> Vec; + /// Returns an iterator of all our local users + /// in the room, even if they're deactivated/guests + fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a>; + + /// Returns an iterator of all our local users in a room who are active (not + /// deactivated, not guest) + fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a>; + + /// Returns an iterator of all our local users joined in a room who are + /// active (not deactivated, not guest) and have a joined membership state + /// in the room + fn active_local_joined_users_in_room<'a>( + &'a self, room_id: &'a RoomId, + ) -> Box + 'a>; fn room_joined_count(&self, room_id: &RoomId) -> Result>; @@ -381,7 +392,7 @@ impl Data for KeyValueDatabase { })) } - /// Returns an iterator over all joined members of a room. + /// Returns an iterator of all joined members of a room. #[tracing::instrument(skip(self))] fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); @@ -400,14 +411,43 @@ impl Data for KeyValueDatabase { })) } - /// Returns a vec of all our local users joined in a room who are active - /// (not guests / not deactivated users) + /// Returns an iterator of all our local users in the room, even if they're + /// deactivated/guests + fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { + Box::new( + self.room_members(room_id) + .filter_map(Result::ok) + .filter(|user| user_is_local(user)), + ) + } + + /// Returns an iterator of all our local users in a room who are active (not + /// deactivated, not guest) #[tracing::instrument(skip(self))] - fn active_local_users_in_room(&self, room_id: &RoomId) -> Vec { - self.room_members(room_id) - .filter_map(Result::ok) - .filter(|user| user_is_local(user) && !services().users.is_deactivated(user).unwrap_or(true)) - .collect_vec() + fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { + Box::new( + self.local_users_in_room(room_id) + .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)), + ) + } + + /// Returns an iterator of all our local users joined in a room who are + /// active (not deactivated, not guest) and have a joined membership state + /// in the room + /// + /// TODO: why is `roomuserid_joined` not reliable? + #[tracing::instrument(skip(self))] + fn active_local_joined_users_in_room<'a>( + &'a self, room_id: &'a RoomId, + ) -> Box + 'a> { + Box::new(self.active_local_users_in_room(room_id).filter(|user_id| { + services() + .rooms + .state_accessor + .get_member(room_id, user_id) + .unwrap_or(None) + .map_or(false, |membership| membership.membership == MembershipState::Join) + })) } /// Returns the number of users which are currently in a room diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index ab59c0ad..3f47ff4c 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -274,12 +274,29 @@ impl Service { pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } #[tracing::instrument(skip(self))] - /// Returns a vec of all the users joined in a room who are active - /// (not guests, not deactivated users) - pub fn active_local_users_in_room(&self, room_id: &RoomId) -> Vec { + /// Returns an iterator of all our local users in the room, even if they're + /// deactivated/guests + pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { + self.db.local_users_in_room(room_id) + } + + #[tracing::instrument(skip(self))] + /// Returns an iterator of all our local users in a room who are active (not + /// deactivated, not guest) + pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { self.db.active_local_users_in_room(room_id) } + #[tracing::instrument(skip(self))] + /// Returns an iterator of all our local users joined in a room who are + /// active (not deactivated, not guest) and have a joined membership state + /// in the room + pub fn active_local_joined_users_in_room<'a>( + &'a self, room_id: &'a RoomId, + ) -> impl Iterator + 'a { + self.db.active_local_joined_users_in_room(room_id) + } + #[tracing::instrument(skip(self))] pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index a9482853..34a60507 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -6,6 +6,7 @@ use std::{ }; use data::Data; +use itertools::Itertools; use rand::prelude::SliceRandom; use ruma::{ api::{client::error::ErrorKind, federation}, @@ -309,7 +310,8 @@ impl Service { let mut push_target = services() .rooms .state_cache - .active_local_users_in_room(&pdu.room_id); + .active_local_joined_users_in_room(&pdu.room_id) + .collect_vec(); if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key {