diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs
index 5c56faa0..2ac4c723 100644
--- a/src/api/client/membership.rs
+++ b/src/api/client/membership.rs
@@ -7,9 +7,8 @@ use std::{
 
 use axum_client_ip::InsecureClientIp;
 use conduit::{
-	debug, debug_warn, error, info, trace, utils,
-	utils::{math::continue_exponential_backoff_secs, mutex_map},
-	warn, Error, PduEvent, Result,
+	debug, debug_warn, error, info, trace, utils, utils::math::continue_exponential_backoff_secs, warn, Error,
+	PduEvent, Result,
 };
 use ruma::{
 	api::{
@@ -36,13 +35,14 @@ use ruma::{
 	OwnedUserId, RoomId, RoomVersionId, ServerName, UserId,
 };
 use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
-use service::sending::convert_to_outgoing_federation_event;
 use tokio::sync::RwLock;
 
 use crate::{
 	client::{update_avatar_url, update_displayname},
 	service::{
+		globals::RoomMutexGuard,
 		pdu::{gen_event_id_canonical_json, PduBuilder},
+		sending::convert_to_outgoing_federation_event,
 		server_is_ours, user_is_local,
 	},
 	services, Ruma,
@@ -682,7 +682,7 @@ pub async fn join_room_by_id_helper(
 #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_remote")]
 async fn join_room_by_id_helper_remote(
 	sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName],
-	_third_party_signed: Option<&ThirdPartySigned>, state_lock: mutex_map::Guard<()>,
+	_third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard,
 ) -> Result<join_room::v3::Response> {
 	info!("Joining {room_id} over federation.");
@@ -1018,7 +1018,7 @@ async fn join_room_by_id_helper_remote(
 #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_local")]
 async fn join_room_by_id_helper_local(
 	sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName],
-	_third_party_signed: Option<&ThirdPartySigned>, state_lock: mutex_map::Guard<()>,
+	_third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard,
 ) -> Result<join_room::v3::Response> {
 	debug!("We can join locally");
diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs
index 2b79c3c4..4e5ba480 100644
--- a/src/core/utils/mod.rs
+++ b/src/core/utils/mod.rs
@@ -20,7 +20,7 @@ pub use debug::slice_truncated as debug_slice_truncated;
 pub use hash::calculate_hash;
 pub use html::Escape as HtmlEscape;
 pub use json::{deserialize_from_str, to_canonical_object};
-pub use mutex_map::MutexMap;
+pub use mutex_map::{Guard as MutexMapGuard, MutexMap};
 pub use rand::string as random_string;
 pub use string::{str_from_bytes, string_from_bytes};
 pub use sys::available_parallelism;
diff --git a/src/core/utils/mutex_map.rs b/src/core/utils/mutex_map.rs
index f102487c..c3c51798 100644
--- a/src/core/utils/mutex_map.rs
+++ b/src/core/utils/mutex_map.rs
@@ -1,20 +1,22 @@
-use std::{hash::Hash, sync::Arc};
+use std::{fmt::Debug, hash::Hash, sync::Arc};
 
-type Value<Val> = tokio::sync::Mutex<Val>;
-type ArcMutex<Val> = Arc<Value<Val>>;
-type HashMap<Key, Val> = std::collections::HashMap<Key, ArcMutex<Val>>;
-type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
-type Map<Key, Val> = MapMutex<Key, Val>;
+use tokio::sync::OwnedMutexGuard as Omg;
 
 /// Map of Mutexes
 pub struct MutexMap<Key, Val> {
 	map: Map<Key, Val>,
 }
 
-pub struct Guard<Val> {
-	_guard: tokio::sync::OwnedMutexGuard<Val>,
+pub struct Guard<Key, Val> {
+	map: Map<Key, Val>,
+	val: Omg<Val>,
 }
 
+type Map<Key, Val> = Arc<MapMutex<Key, Val>>;
+type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>;
+type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>;
+type Value<Val> = Arc<tokio::sync::Mutex<Val>>;
+
 impl<Key, Val> MutexMap<Key, Val>
 where
 	Key: Send + Hash + Eq + Clone,
@@ -23,28 +25,38 @@ where
 	#[must_use]
 	pub fn new() -> Self {
 		Self {
-			map: Map::<Key, Val>::new(HashMap::<Key, Val>::new()),
+			map: Map::new(MapMutex::new(HashMap::new())),
 		}
 	}
 
-	pub async fn lock<K>(&self, k: &K) -> Guard<Val>
+	#[tracing::instrument(skip(self), level = "debug")]
+	pub async fn lock<K>(&self, k: &K) -> Guard<Key, Val>
 	where
-		K: ?Sized + Send + Sync,
+		K: ?Sized + Send + Sync + Debug,
 		Key: for<'a> From<&'a K>,
 	{
 		let val = self
 			.map
 			.lock()
-			.expect("map mutex locked")
+			.expect("locked")
 			.entry(k.into())
 			.or_default()
 			.clone();
 
-		let guard = val.lock_owned().await;
-		Guard::<Val> {
-			_guard: guard,
+		Guard::<Key, Val> {
+			map: Arc::clone(&self.map),
+			val: val.lock_owned().await,
 		}
 	}
+
+	#[must_use]
+	pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) }
+
+	#[must_use]
+	pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() }
+
+	#[must_use]
+	pub fn len(&self) -> usize { self.map.lock().expect("locked").len() }
 }
 
 impl<Key, Val> Default for MutexMap<Key, Val>
@@ -54,3 +66,14 @@ where
 {
 	fn default() -> Self { Self::new() }
 }
+
+impl<Key, Val> Drop for Guard<Key, Val> {
+	fn drop(&mut self) {
+		if Arc::strong_count(Omg::mutex(&self.val)) <= 2 {
+			self.map
+				.lock()
+				.expect("locked")
+				.retain(|_, val| !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2);
+		}
+	}
+}
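
Note on the `Drop` impl above: the map holds one strong reference to each entry's `Arc`, and the live `OwnedMutexGuard` holds a second, so a count of 2 at drop time means no other guard or queued `lock()` caller exists and the entry can be evicted; the `retain` re-checks the count under the map lock in case a contender arrived in the meantime. A standalone sketch of that arithmetic (tokio only, not part of the patch):

```rust
use std::sync::Arc;

use tokio::sync::{Mutex, OwnedMutexGuard};

#[tokio::main]
async fn main() {
	// One strong reference: the "map side" of an entry.
	let entry: Arc<Mutex<()>> = Arc::new(Mutex::new(()));
	assert_eq!(Arc::strong_count(&entry), 1);

	// lock_owned() moves a clone of the Arc into the guard, so a lone
	// holder sees a count of exactly 2 -- the `<= 2` threshold in Drop.
	let guard = entry.clone().lock_owned().await;
	assert_eq!(Arc::strong_count(OwnedMutexGuard::mutex(&guard)), 2);

	// Any additional waiter pushes the count above 2, so a dropping
	// guard would leave the entry in place for it.
	let waiter = Arc::clone(&entry);
	assert_eq!(Arc::strong_count(&entry), 3);

	drop(waiter);
	drop(guard);
	assert_eq!(Arc::strong_count(&entry), 1); // only the map side remains
}
```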
diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs
index add15861..43968947 100644
--- a/src/core/utils/tests.rs
+++ b/src/core/utils/tests.rs
@@ -81,3 +81,56 @@ fn checked_add_overflow() {
 	let res = checked!(a + 1).expect("overflow");
 	assert_eq!(res, 0);
 }
+
+#[tokio::test]
+async fn mutex_map_cleanup() {
+	use crate::utils::MutexMap;
+
+	let map = MutexMap::<String, ()>::new();
+
+	let lock = map.lock("foo").await;
+	assert!(!map.is_empty(), "map must not be empty");
+
+	drop(lock);
+	assert!(map.is_empty(), "map must be empty");
+}
+
+#[tokio::test]
+async fn mutex_map_contend() {
+	use std::sync::Arc;
+
+	use tokio::sync::Barrier;
+
+	use crate::utils::MutexMap;
+
+	let map = Arc::new(MutexMap::<String, ()>::new());
+	let seq = Arc::new([Barrier::new(2), Barrier::new(2)]);
+	let str = "foo".to_owned();
+
+	let seq_ = seq.clone();
+	let map_ = map.clone();
+	let str_ = str.clone();
+	let join_a = tokio::spawn(async move {
+		let _lock = map_.lock(&str_).await;
+		assert!(!map_.is_empty(), "A0 must not be empty");
+		seq_[0].wait().await;
+		assert!(map_.contains(&str_), "A1 must contain key");
+	});
+
+	let seq_ = seq.clone();
+	let map_ = map.clone();
+	let str_ = str.clone();
+	let join_b = tokio::spawn(async move {
+		let _lock = map_.lock(&str_).await;
+		assert!(!map_.is_empty(), "B0 must not be empty");
+		seq_[1].wait().await;
+		assert!(map_.contains(&str_), "B1 must contain key");
+	});
+
+	seq[0].wait().await;
+	assert!(map.contains(&str), "Must contain key");
+	seq[1].wait().await;
+
+	tokio::try_join!(join_b, join_a).expect("joined");
+	assert!(map.is_empty(), "Must be empty");
+}
diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs
index ca0e551b..e0dd1760 100644
--- a/src/service/admin/mod.rs
+++ b/src/service/admin/mod.rs
@@ -9,7 +9,7 @@ use std::{
 };
 
 use async_trait::async_trait;
-use conduit::{error, utils::mutex_map, Error, Result};
+use conduit::{error, Error, Result};
 pub use create::create_admin_room;
 pub use grant::make_user_admin;
 use loole::{Receiver, Sender};
@@ -26,7 +26,7 @@ use tokio::{
 	task::JoinHandle,
 };
 
-use crate::{pdu::PduBuilder, services, user_is_local, PduEvent};
+use crate::{globals::RoomMutexGuard, pdu::PduBuilder, services, user_is_local, PduEvent};
 
 const COMMAND_QUEUE_LIMIT: usize = 512;
 
@@ -270,7 +270,7 @@ async fn respond_to_room(content: RoomMessageEventContent, room_id: &RoomId, use
 }
 
 async fn handle_response_error(
-	e: &Error, room_id: &RoomId, user_id: &UserId, state_lock: &mutex_map::Guard<()>,
+	e: &Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard,
 ) -> Result<()> {
 	error!("Failed to build and append admin room response PDU: \"{e}\"");
 	let error_room_message = RoomMessageEventContent::text_plain(format!(
diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs
index 11bfc88c..a5b70835 100644
--- a/src/service/globals/mod.rs
+++ b/src/service/globals/mod.rs
@@ -12,7 +12,11 @@ use std::{
 	time::Instant,
 };
 
-use conduit::{error, trace, utils::MutexMap, Config, Result};
+use conduit::{
+	error, trace,
+	utils::{MutexMap, MutexMapGuard},
+	Config, Result,
+};
 use data::Data;
 use ipaddress::IPAddress;
 use regex::RegexSet;
@@ -27,8 +31,6 @@ use url::Url;
 
 use crate::services;
 
-type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
-
 pub struct Service {
 	pub db: Data,
 
@@ -43,9 +45,9 @@ pub struct Service {
 	pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>,
 	pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
 	pub bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
-	pub roomid_mutex_insert: MutexMap<OwnedRoomId, ()>,
-	pub roomid_mutex_state: MutexMap<OwnedRoomId, ()>,
-	pub roomid_mutex_federation: MutexMap<OwnedRoomId, ()>,
+	pub roomid_mutex_insert: RoomMutexMap,
+	pub roomid_mutex_state: RoomMutexMap,
+	pub roomid_mutex_federation: RoomMutexMap,
 	pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
 	pub updates_handle: Mutex<Option<JoinHandle<()>>>,
 	pub stateres_mutex: Arc<Mutex<()>>,
@@ -53,6 +55,10 @@ pub struct Service {
 	pub admin_alias: OwnedRoomAliasId,
 }
 
+pub type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
+pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
+type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
+
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		let config = &args.server.config;
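
With the aliases in place, a call site only needs the `RoomMutexMap` field and the guard type. A hypothetical caller, following the field names and `services()` accessor visible elsewhere in this diff (the exact call-site shape is an assumption, not part of the patch):

```rust
use conduit::Result;
use ruma::RoomId;

use crate::{globals::RoomMutexGuard, services};

async fn with_state_lock(room_id: &RoomId) -> Result<()> {
	// Key is OwnedRoomId, but lock() accepts any K with Key: From<&K>,
	// so a plain &RoomId works; it is converted to an owned key internally.
	let state_lock: RoomMutexGuard = services().globals.roomid_mutex_state.lock(room_id).await;

	// ... read or modify room state while holding the lock ...

	drop(state_lock); // Drop evicts the map entry if no one else holds or awaits it
	Ok(())
}
```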
diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs
index aad3bede..b62adf60 100644
--- a/src/service/rooms/state/data.rs
+++ b/src/service/rooms/state/data.rs
@@ -3,7 +3,8 @@ use std::{collections::HashSet, sync::Arc};
 use conduit::{utils, Error, Result};
 use database::{Database, Map};
 use ruma::{EventId, OwnedEventId, RoomId};
-use utils::mutex_map;
+
+use crate::globals::RoomMutexGuard;
 
 pub(super) struct Data {
 	shorteventid_shortstatehash: Arc<Map>,
@@ -35,7 +36,7 @@ impl Data {
 		&self,
 		room_id: &RoomId,
 		new_shortstatehash: u64,
-		_mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<()> {
 		self.roomid_shortstatehash
 			.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
@@ -68,7 +69,7 @@ impl Data {
 		&self,
 		room_id: &RoomId,
 		event_ids: Vec<OwnedEventId>,
-		_mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<()> {
 		let mut prefix = room_id.as_bytes().to_vec();
 		prefix.push(0xFF);
diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs
index 52ee89d1..b46a9d04 100644
--- a/src/service/rooms/state/mod.rs
+++ b/src/service/rooms/state/mod.rs
@@ -5,10 +5,7 @@ use std::{
 	sync::Arc,
 };
 
-use conduit::{
-	utils::{calculate_hash, mutex_map},
-	warn, Error, Result,
-};
+use conduit::{utils::calculate_hash, warn, Error, Result};
 use data::Data;
 use ruma::{
 	api::client::error::ErrorKind,
@@ -22,7 +19,7 @@ use ruma::{
 };
 
 use super::state_compressor::CompressedStateEvent;
-use crate::{services, PduEvent};
+use crate::{globals::RoomMutexGuard, services, PduEvent};
 
 pub struct Service {
 	db: Data,
@@ -46,7 +43,7 @@ impl Service {
 		shortstatehash: u64,
 		statediffnew: Arc<HashSet<CompressedStateEvent>>,
 		_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
-		state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<()> {
 		for event_id in statediffnew.iter().filter_map(|new| {
 			services()
@@ -318,7 +315,7 @@ impl Service {
 		&self,
 		room_id: &RoomId,
 		shortstatehash: u64,
-		mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<()> {
 		self.db.set_room_state(room_id, shortstatehash, mutex_lock)
 	}
@@ -358,7 +355,7 @@ impl Service {
 		&self,
 		room_id: &RoomId,
 		event_ids: Vec<OwnedEventId>,
-		state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<()> {
 		self.db
 			.set_forward_extremities(room_id, event_ids, state_lock)
diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs
index a3567857..35719c15 100644
--- a/src/service/rooms/state_accessor/mod.rs
+++ b/src/service/rooms/state_accessor/mod.rs
@@ -6,11 +6,7 @@ use std::{
 	sync::{Arc, Mutex as StdMutex, Mutex},
 };
 
-use conduit::{
-	error,
-	utils::{math::usize_from_f64, mutex_map},
-	warn, Error, Result,
-};
+use conduit::{error, utils::math::usize_from_f64, warn, Error, Result};
 use data::Data;
 use lru_cache::LruCache;
 use ruma::{
@@ -37,7 +33,7 @@ use ruma::{
 };
 use serde_json::value::to_raw_value;
 
-use crate::{pdu::PduBuilder, services, PduEvent};
+use crate::{globals::RoomMutexGuard, pdu::PduBuilder, services, PduEvent};
 
 pub struct Service {
 	db: Data,
@@ -333,7 +329,7 @@ impl Service {
 	}
 
 	pub fn user_can_invite(
-		&self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &mutex_map::Guard<()>,
+		&self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard,
 	) -> Result<bool> {
 		let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite))
 			.expect("Event content always serializes");
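
The `_mutex_lock: &RoomMutexGuard` parameters above are deliberately unused: requiring the guard makes "caller holds the room's state mutex" a compile-time obligation rather than a code-review convention. A minimal standalone illustration of that proof-of-lock pattern, using `std::sync` for brevity:

```rust
use std::sync::{Mutex, MutexGuard};

struct RoomState {
	version: u64,
}

// The guard is never read; it only proves the caller holds the lock
// for the duration of the call.
fn set_state(state: &mut RoomState, new_version: u64, _proof: &MutexGuard<'_, ()>) {
	state.version = new_version;
}

fn main() {
	let room_lock = Mutex::new(());
	let mut state = RoomState { version: 0 };

	let guard = room_lock.lock().expect("locked");
	set_state(&mut state, 1, &guard); // cannot be called without a live guard
	drop(guard);

	assert_eq!(state.version, 1);
}
```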
diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs
index c82098ba..9bfc2715 100644
--- a/src/service/rooms/timeline/mod.rs
+++ b/src/service/rooms/timeline/mod.rs
@@ -6,7 +6,7 @@ use std::{
 	sync::Arc,
 };
 
-use conduit::{debug, error, info, utils, utils::mutex_map, validated, warn, Error, Result};
+use conduit::{debug, error, info, utils, validated, warn, Error, Result};
 use data::Data;
 use itertools::Itertools;
 use ruma::{
@@ -36,6 +36,7 @@ use tokio::sync::RwLock;
 use crate::{
 	admin, appservice::NamespaceRegex,
+	globals::RoomMutexGuard,
 	pdu::{EventHash, PduBuilder},
 	rooms::{event_handler::parse_incoming_pdu, state_compressor::CompressedStateEvent},
 	server_is_ours, services, PduCount, PduEvent,
@@ -203,7 +204,7 @@ impl Service {
 		pdu: &PduEvent,
 		mut pdu_json: CanonicalJsonObject,
 		leaves: Vec<OwnedEventId>,
-		state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<Vec<u8>> {
 		// Coalesce database writes for the remainder of this scope.
 		let _cork = services().db.cork_and_flush();
@@ -593,7 +594,7 @@ impl Service {
 		pdu_builder: PduBuilder,
 		sender: &UserId,
 		room_id: &RoomId,
-		_mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<(PduEvent, CanonicalJsonObject)> {
 		let PduBuilder {
 			event_type,
@@ -780,7 +781,7 @@ impl Service {
 		pdu_builder: PduBuilder,
 		sender: &UserId,
 		room_id: &RoomId,
-		state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<Arc<PduEvent>> {
 		let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?;
 		if let Some(admin_room) = admin::Service::get_admin_room()? {
@@ -963,7 +964,7 @@ impl Service {
 		new_room_leaves: Vec<OwnedEventId>,
 		state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
 		soft_fail: bool,
-		state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex
+		state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
 	) -> Result<Option<Vec<u8>>> {
 		// We append to state before appending the pdu, so we don't have a moment in
 		// time with the pdu without it's state. This is okay because append_pdu can't
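
Taken together, the timeline changes keep the state-lock discipline end to end: the lock is acquired once per room, threaded by reference through `create_hash_and_sign_event` and `build_and_append_pdu`, and the map entry is cleaned up when the guard drops. A hypothetical caller; the `services()` path and the `async` signature of `build_and_append_pdu` are assumptions based on call sites elsewhere in the tree, not shown in this patch:

```rust
use std::sync::Arc;

use conduit::Result;
use ruma::{RoomId, UserId};

use crate::{globals::RoomMutexGuard, pdu::PduBuilder, services, PduEvent};

async fn append_under_state_lock(
	pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId,
) -> Result<Arc<PduEvent>> {
	// Serialize all state mutations for this room.
	let state_lock: RoomMutexGuard = services().globals.roomid_mutex_state.lock(room_id).await;

	// Passes &RoomMutexGuard, matching the timeline hunks above.
	let pdu = services()
		.rooms
		.timeline
		.build_and_append_pdu(pdu_builder, sender, room_id, &state_lock)
		.await?;

	Ok(pdu)
	// state_lock drops here; the room's map entry is evicted if uncontended
}
```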