From c111d2e39536672224e4a2efe19d674e5ac7c6a0 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 11 Jul 2024 21:00:30 +0000 Subject: [PATCH] abstract service worker pattern; restart on panic. Signed-off-by: Jason Volk --- src/router/run.rs | 6 +- src/service/admin/mod.rs | 64 +++-------- src/service/mod.rs | 6 +- src/service/presence/mod.rs | 57 +++------ src/service/sending/mod.rs | 50 +------- src/service/sending/sender.rs | 39 ++++--- src/service/service.rs | 17 +-- src/service/services.rs | 210 +++++++++++++++++++++++++--------- 8 files changed, 233 insertions(+), 216 deletions(-) diff --git a/src/router/run.rs b/src/router/run.rs index 3e09823a..02cec781 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -11,7 +11,6 @@ extern crate conduit_service as service; use std::sync::atomic::Ordering; use conduit::{debug_info, trace, Error, Result, Server}; -use service::services; use crate::{layers, serve}; @@ -50,7 +49,6 @@ pub(crate) async fn start(server: Arc) -> Result<(), Error> { debug!("Starting..."); service::init(&server).await?; - services().start().await?; #[cfg(feature = "systemd")] sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).expect("failed to notify systemd of ready state"); @@ -66,9 +64,7 @@ pub(crate) async fn stop(_server: Arc) -> Result<(), Error> { // Wait for all completions before dropping or we'll lose them to the module // unload and explode. - services().stop().await; - // Deactivate services(). Any further use will panic the caller. - service::fini(); + service::fini().await; debug!("Cleaning up..."); diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index e9729d2d..41019cd1 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -21,17 +21,13 @@ use ruma::{ OwnedEventId, OwnedRoomId, RoomId, UserId, }; use serde_json::value::to_raw_value; -use tokio::{ - sync::{Mutex, RwLock}, - task::JoinHandle, -}; +use tokio::sync::{Mutex, RwLock}; use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, user_is_local, PduEvent}; pub struct Service { sender: Sender, receiver: Mutex>, - handler_join: Mutex>>, pub handle: RwLock>, pub complete: StdRwLock>, #[cfg(feature = "console")] @@ -59,7 +55,6 @@ impl crate::Service for Service { Ok(Arc::new(Self { sender, receiver: Mutex::new(receiver), - handler_join: Mutex::new(None), handle: RwLock::new(None), complete: StdRwLock::new(None), #[cfg(feature = "console")] @@ -67,16 +62,25 @@ impl crate::Service for Service { })) } - async fn start(self: Arc) -> Result<()> { - let self_ = Arc::clone(&self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to initialize admin room handler"); - }); + async fn worker(self: Arc) -> Result<()> { + let receiver = self.receiver.lock().await; + let mut signals = services().server.signal.subscribe(); + loop { + tokio::select! 
{ + command = receiver.recv_async() => match command { + Ok(command) => self.handle_command(command).await, + Err(_) => break, + }, + sig = signals.recv() => match sig { + Ok(sig) => self.handle_signal(sig).await, + Err(_) => continue, + }, + } + } - _ = self.handler_join.lock().await.insert(handle); + //TODO: not unwind safe + #[cfg(feature = "console")] + self.console.close().await; Ok(()) } @@ -90,19 +94,6 @@ impl crate::Service for Service { } } - async fn stop(&self) { - self.interrupt(); - - #[cfg(feature = "console")] - self.console.close().await; - - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); - } - } - } - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } @@ -149,23 +140,6 @@ impl Service { self.sender.send_async(message).await.expect("message sent"); } - async fn handler(self: &Arc) -> Result<()> { - let receiver = self.receiver.lock().await; - let mut signals = services().server.signal.subscribe(); - loop { - tokio::select! { - command = receiver.recv_async() => match command { - Ok(command) => self.handle_command(command).await, - Err(_) => return Ok(()), - }, - sig = signals.recv() => match sig { - Ok(sig) => self.handle_signal(sig).await, - Err(_) => continue, - }, - } - } - } - async fn handle_signal(&self, #[allow(unused_variables)] sig: &'static str) { #[cfg(feature = "console")] self.console.handle_signal(sig).await; diff --git a/src/service/mod.rs b/src/service/mod.rs index 4b19073d..15c4cc35 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -42,10 +42,12 @@ pub async fn init(server: &Arc) -> Result<()> { let s = Box::new(Services::build(server.clone(), d)?); _ = SERVICES.write().expect("write locked").insert(Box::leak(s)); - Ok(()) + services().start().await } -pub fn fini() { +pub async fn fini() { + services().stop().await; + // Deactivate services(). Any further use will panic the caller. 
let s = SERVICES .write() diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index f5400379..254304ba 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -12,7 +12,7 @@ use ruma::{ OwnedUserId, UInt, UserId, }; use serde::{Deserialize, Serialize}; -use tokio::{sync::Mutex, task::JoinHandle, time::sleep}; +use tokio::{sync::Mutex, time::sleep}; use crate::{services, user_is_local}; @@ -77,7 +77,6 @@ pub struct Service { pub db: Data, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, timer_receiver: Mutex>, - handler_join: Mutex>>, timeout_remote_users: bool, idle_timeout: u64, offline_timeout: u64, @@ -94,34 +93,26 @@ impl crate::Service for Service { db: Data::new(args.db), timer_sender, timer_receiver: Mutex::new(timer_receiver), - handler_join: Mutex::new(None), timeout_remote_users: config.presence_timeout_remote_users, idle_timeout: checked!(idle_timeout_s * 1_000)?, offline_timeout: checked!(offline_timeout_s * 1_000)?, })) } - async fn start(self: Arc) -> Result<()> { - //TODO: if self.globals.config.allow_local_presence { return; } - - let self_ = Arc::clone(&self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to start presence handler"); - }); - - _ = self.handler_join.lock().await.insert(handle); - - Ok(()) - } - - async fn stop(&self) { - self.interrupt(); - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); + async fn worker(self: Arc) -> Result<()> { + let mut presence_timers = FuturesUnordered::new(); + let receiver = self.timer_receiver.lock().await; + loop { + debug_assert!(!receiver.is_closed(), "channel error"); + tokio::select! { + Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, + event = receiver.recv_async() => match event { + Err(_e) => return Ok(()), + Ok((user_id, timeout)) => { + debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); + presence_timers.push(presence_timer(user_id, timeout)); + }, + }, } } } @@ -219,24 +210,6 @@ impl Service { self.db.presence_since(since) } - async fn handler(&self) -> Result<()> { - let mut presence_timers = FuturesUnordered::new(); - let receiver = self.timer_receiver.lock().await; - loop { - debug_assert!(!receiver.is_closed(), "channel error"); - tokio::select! 
{ - Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, - event = receiver.recv_async() => match event { - Err(_e) => return Ok(()), - Ok((user_id, timeout)) => { - debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); - presence_timers.push(presence_timer(user_id, timeout)); - }, - }, - } - } - } - fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index d7a9c0fc..eb708fcf 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -4,29 +4,26 @@ mod resolve; mod send; mod sender; -use std::{fmt::Debug, sync::Arc}; +use std::fmt::Debug; -use async_trait::async_trait; use conduit::{Error, Result}; -use data::Data; pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; pub use sender::convert_to_outgoing_federation_event; -use tokio::{sync::Mutex, task::JoinHandle}; -use tracing::{error, warn}; +use tokio::sync::Mutex; +use tracing::warn; use crate::{server_is_ours, services}; pub struct Service { - pub db: Data, + pub db: data::Data, /// The state for a given state hash. sender: loole::Sender, receiver: Mutex>, - handler_join: Mutex>>, startup_netburst: bool, startup_netburst_keep: i64, } @@ -53,45 +50,6 @@ pub enum SendingEvent { Flush, // none } -#[async_trait] -impl crate::Service for Service { - fn build(args: crate::Args<'_>) -> Result> { - let config = &args.server.config; - let (sender, receiver) = loole::unbounded(); - Ok(Arc::new(Self { - db: Data::new(args.db.clone()), - sender, - receiver: Mutex::new(receiver), - handler_join: Mutex::new(None), - startup_netburst: config.startup_netburst, - startup_netburst_keep: config.startup_netburst_keep, - })) - } - - async fn start(self: Arc) -> Result<()> { - self.start_handler().await; - - Ok(()) - } - - async fn stop(&self) { - self.interrupt(); - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); - } - } - } - - fn interrupt(&self) { - if !self.sender.is_closed() { - self.sender.close(); - } - } - - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } -} - impl Service { #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")] pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 0fb0d9dc..cfd5b4bc 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -6,6 +6,7 @@ use std::{ time::Instant, }; +use async_trait::async_trait; use base64::{engine::general_purpose, Engine as _}; use conduit::{debug, error, utils::math::continue_exponential_backoff_secs, warn}; use federation::transactions::send_transaction_message; @@ -23,8 +24,9 @@ use ruma::{ ServerName, UInt, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use tokio::sync::Mutex; -use super::{appservice, send, Destination, Msg, SendingEvent, Service}; +use super::{appservice, data::Data, send, Destination, Msg, SendingEvent, Service}; use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; #[derive(Debug)] @@ -43,21 +45,22 @@ type CurTransactionStatus = HashMap; const 
DEQUEUE_LIMIT: usize = 48; const SELECT_EDU_LIMIT: usize = 16; -impl Service { - pub async fn start_handler(self: &Arc) { - let self_ = Arc::clone(self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to start sending handler"); - }); - - _ = self.handler_join.lock().await.insert(handle); +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let (sender, receiver) = loole::unbounded(); + Ok(Arc::new(Self { + db: Data::new(args.db.clone()), + sender, + receiver: Mutex::new(receiver), + startup_netburst: config.startup_netburst, + startup_netburst_keep: config.startup_netburst_keep, + })) } #[tracing::instrument(skip_all, name = "sender")] - async fn handler(&self) -> Result<()> { + async fn worker(self: Arc) -> Result<()> { let receiver = self.receiver.lock().await; let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); @@ -77,6 +80,16 @@ impl Service { } } + fn interrupt(&self) { + if !self.sender.is_closed() { + self.sender.close(); + } + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { fn handle_response( &self, response: SendingResult, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus, ) { diff --git a/src/service/service.rs b/src/service/service.rs index ef60f359..3b8f4231 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -15,19 +15,12 @@ pub(crate) trait Service: Send + Sync { where Self: Sized; - /// Start the service. Implement the spawning of any service workers. This - /// is called after all other services have been constructed. Failure will - /// shutdown the server with an error. - async fn start(self: Arc) -> Result<()> { Ok(()) } + /// Implement the service's worker loop. The service manager spawns a + /// task and calls this function after all services have been built. + async fn worker(self: Arc) -> Result<()> { Ok(()) } - /// Stop the service. Implement the joining of any service workers and - /// cleanup of any other state. This function is asynchronous to allow that - /// gracefully, but errors cannot propagate. - async fn stop(&self) {} - - /// Interrupt the service. This may be sent prior to `stop()` as a - /// notification to improve the shutdown sequence. Implementations must be - /// robust to this being called multiple times. + /// Interrupt the service. This is sent to initiate a graceful shutdown. + /// The service worker should return from its work loop. fn interrupt(&self) {} /// Clear any caches or similar runtime state. 
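The `worker()`/`interrupt()` contract above is the heart of the refactor: a service no longer spawns and joins its own task; it exposes a long-running `worker()` loop, and `interrupt()` just closes the channel that loop receives on so it can return. A minimal standalone sketch of a service following that contract, assuming `async-trait`, `loole`, and `tokio` as dependencies and a plain `String` error in place of conduit's `Result`; the simplified trait and the `EchoService` type are illustrative, not items from this patch:

use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::Mutex;

// Simplified stand-in for the trait above; the real trait also has build(),
// memory_usage(), clear_cache(), etc. and uses conduit's Result type.
#[async_trait]
trait Service: Send + Sync {
	// Worker loop, run in a task spawned by the service manager.
	async fn worker(self: Arc<Self>) -> Result<(), String> {
		Ok(())
	}

	// Ask the worker to wind down; must tolerate being called repeatedly.
	fn interrupt(&self) {}

	fn name(&self) -> &str;
}

// Hypothetical service: owns both ends of a loole channel, in the same way
// the admin and sending services in this patch do.
struct EchoService {
	sender: loole::Sender<String>,
	receiver: Mutex<loole::Receiver<String>>,
}

#[async_trait]
impl Service for EchoService {
	async fn worker(self: Arc<Self>) -> Result<(), String> {
		let receiver = self.receiver.lock().await;
		// recv_async() returns Err once interrupt() has closed the channel,
		// which is the worker's cue to return.
		while let Ok(message) = receiver.recv_async().await {
			println!("echo: {message}");
		}
		Ok(())
	}

	fn interrupt(&self) {
		if !self.sender.is_closed() {
			self.sender.close();
		}
	}

	fn name(&self) -> &str { "echo" }
}

#[tokio::main]
async fn main() {
	let (sender, receiver) = loole::unbounded();
	let service = Arc::new(EchoService {
		sender,
		receiver: Mutex::new(receiver),
	});

	let worker = tokio::spawn(Arc::clone(&service).worker());
	if service.sender.send_async("hello".to_owned()).await.is_err() {
		eprintln!("channel already closed");
	}
	service.interrupt();

	// The worker returns once the channel is closed.
	worker.await.expect("join").expect("worker result");
	println!("{} worker finished", service.name());
}

Closing the sender rather than aborting the task is what keeps shutdown graceful: the worker observes the closed channel and returns normally, and interrupt() can safely be called more than once, as the trait documentation requires.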
diff --git a/src/service/services.rs b/src/service/services.rs index aeed8204..13689008 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,7 +1,13 @@ -use std::{collections::BTreeMap, fmt::Write, sync::Arc}; +use std::{collections::BTreeMap, fmt::Write, panic::AssertUnwindSafe, sync::Arc, time::Duration}; -use conduit::{debug, debug_info, info, trace, Result, Server}; +use conduit::{debug, debug_info, error, info, trace, utils::time, warn, Error, Result, Server}; use database::Database; +use futures_util::FutureExt; +use tokio::{ + sync::{Mutex, MutexGuard}, + task::{JoinHandle, JoinSet}, + time::sleep, +}; use crate::{ account_data, admin, appservice, globals, key_backups, media, presence, pusher, rooms, sending, @@ -24,11 +30,19 @@ pub struct Services { pub media: Arc, pub sending: Arc, + workers: Mutex, + manager: Mutex>>>, pub(crate) service: Map, pub server: Arc, pub db: Arc, } +type Workers = JoinSet; +type WorkerResult = (Arc, Result<()>); +type WorkersLocked<'a> = MutexGuard<'a, Workers>; + +const RESTART_DELAY_MS: u64 = 2500; + impl Services { pub fn build(server: Arc, db: Arc) -> Result { let mut service: Map = BTreeMap::new(); @@ -79,12 +93,74 @@ impl Services { media: build!(media::Service), sending: build!(sending::Service), globals: build!(globals::Service), + workers: Mutex::new(JoinSet::new()), + manager: Mutex::new(None), service, server, db, }) } + pub async fn start(&self) -> Result<()> { + debug_info!("Starting services..."); + + self.media.create_media_dir().await?; + globals::migrations::migrations(&self.db, &self.globals.config).await?; + globals::emerg_access::init_emergency_access(); + + let mut workers = self.workers.lock().await; + for service in self.service.values() { + self.start_worker(&mut workers, service).await?; + } + + debug!("Starting service manager..."); + let manager = async move { crate::services().manager().await }; + let manager = self.server.runtime().spawn(manager); + _ = self.manager.lock().await.insert(manager); + + if self.globals.allow_check_for_updates() { + let handle = globals::updates::start_check_for_updates_task(); + _ = self.globals.updates_handle.lock().await.insert(handle); + } + + debug_info!("Services startup complete."); + Ok(()) + } + + pub async fn stop(&self) { + info!("Shutting down services..."); + self.interrupt(); + + debug!("Waiting for update worker..."); + if let Some(updates_handle) = self.globals.updates_handle.lock().await.take() { + updates_handle.abort(); + _ = updates_handle.await; + } + + debug!("Stopping service manager..."); + if let Some(manager) = self.manager.lock().await.take() { + if let Err(e) = manager.await { + error!("Manager shutdown error: {e:?}"); + } + } + + debug_info!("Services shutdown complete."); + } + + pub async fn clear_cache(&self) { + for service in self.service.values() { + service.clear_cache(); + } + + //TODO + self.rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .clear(); + } + pub async fn memory_usage(&self) -> Result { let mut out = String::new(); for service in self.service.values() { @@ -104,65 +180,97 @@ impl Services { Ok(out) } - pub async fn clear_cache(&self) { - for service in self.service.values() { - service.clear_cache(); - } - - //TODO - self.rooms - .spaces - .roomid_spacehierarchy_cache - .lock() - .await - .clear(); - } - - pub async fn start(&self) -> Result<()> { - debug_info!("Starting services"); - - self.media.create_media_dir().await?; - globals::migrations::migrations(&self.db, &self.globals.config).await?; - 
globals::emerg_access::init_emergency_access(); - - for (name, service) in &self.service { - debug!("Starting {name}"); - service.clone().start().await?; - } - - if self.globals.allow_check_for_updates() { - let handle = globals::updates::start_check_for_updates_task(); - _ = self.globals.updates_handle.lock().await.insert(handle); - } - - debug_info!("Services startup complete."); - Ok(()) - } - - pub fn interrupt(&self) { + fn interrupt(&self) { debug!("Interrupting services..."); - for (name, service) in &self.service { trace!("Interrupting {name}"); service.interrupt(); } } - pub async fn stop(&self) { - info!("Shutting down services"); - self.interrupt(); - - debug!("Waiting for update worker..."); - if let Some(updates_handle) = self.globals.updates_handle.lock().await.take() { - updates_handle.abort(); - _ = updates_handle.await; + async fn manager(&self) -> Result<()> { + loop { + let mut workers = self.workers.lock().await; + tokio::select! { + result = workers.join_next() => match result { + Some(Ok(result)) => self.handle_result(&mut workers, result).await?, + Some(Err(error)) => self.handle_abort(&mut workers, Error::from(error)).await?, + None => break, + } + } } - for (name, service) in &self.service { - debug!("Waiting for {name} ..."); - service.stop().await; + debug!("Worker manager finished"); + Ok(()) + } + + async fn handle_abort(&self, _workers: &mut WorkersLocked<'_>, error: Error) -> Result<()> { + // not supported until service can be associated with abort + unimplemented!("unexpected worker task abort {error:?}"); + } + + async fn handle_result(&self, workers: &mut WorkersLocked<'_>, result: WorkerResult) -> Result<()> { + let (service, result) = result; + match result { + Ok(()) => self.handle_finished(workers, &service).await, + Err(error) => self.handle_error(workers, &service, error).await, + } + } + + async fn handle_finished(&self, _workers: &mut WorkersLocked<'_>, service: &Arc) -> Result<()> { + debug!("service {:?} worker finished", service.name()); + Ok(()) + } + + async fn handle_error( + &self, workers: &mut WorkersLocked<'_>, service: &Arc, error: Error, + ) -> Result<()> { + let name = service.name(); + error!("service {name:?} worker error: {error}"); + + if !error.is_panic() { + return Ok(()); } - debug_info!("Services shutdown complete."); + if !self.server.running() { + return Ok(()); + } + + let delay = Duration::from_millis(RESTART_DELAY_MS); + warn!("service {name:?} worker restarting after {} delay", time::pretty(delay)); + sleep(delay).await; + + self.start_worker(workers, service).await + } + + /// Start the worker in a task for the service. + async fn start_worker(&self, workers: &mut WorkersLocked<'_>, service: &Arc) -> Result<()> { + if !self.server.running() { + return Err(Error::Err(format!( + "Service {:?} worker not starting during server shutdown.", + service.name() + ))); + } + + debug!("Service {:?} worker starting...", service.name()); + workers.spawn_on(worker(service.clone()), self.server.runtime()); + + Ok(()) } } + +/// Base frame for service worker. This runs in a tokio::task. All errors and +/// panics from the worker are caught and returned cleanly. The JoinHandle +/// should never error with a panic, and if so it should propagate, but it may +/// error with an Abort which the manager should handle along with results to +/// determine if the worker should be restarted. 
+
+async fn worker(service: Arc<dyn Service>) -> WorkerResult {
+	let service_ = Arc::clone(&service);
+	let result = AssertUnwindSafe(service_.worker())
+		.catch_unwind()
+		.await
+		.map_err(Error::from_panic);
+
+	// flattens JoinError for panic into worker's Error
+	(service, result.unwrap_or_else(Err))
+}
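Taken together, the services.rs changes implement a small supervision loop: every service's `worker()` runs as a task in a `JoinSet`, a panic inside a worker is converted into an ordinary error via `AssertUnwindSafe` plus `catch_unwind`, and the manager restarts the failed worker after a fixed delay while the server is still running. A self-contained sketch of the same idea using only `tokio` and `futures-util`; the names `service_worker`, `run`, and `manager` are illustrative placeholders, not items from this patch:

use std::{panic::AssertUnwindSafe, time::Duration};

use futures_util::FutureExt; // provides catch_unwind() on futures
use tokio::{task::JoinSet, time::sleep};

const RESTART_DELAY: Duration = Duration::from_millis(2500);

// Stand-in for a service worker loop; a real worker selects on its command
// channel and a shutdown signal until the channel is closed.
async fn service_worker(name: &'static str) -> Result<(), String> {
	println!("{name} worker running");
	Ok(())
}

// Run one worker, converting a panic into an Err carried next to the worker's
// identity so the JoinSet never yields a panicked task.
async fn run(name: &'static str) -> (&'static str, Result<(), String>) {
	let result = AssertUnwindSafe(service_worker(name))
		.catch_unwind()
		.await
		.unwrap_or_else(|_panic| Err(format!("{name} panicked")));
	(name, result)
}

// Manager: collect worker results from the JoinSet and restart any worker
// that failed (including caught panics) after a delay.
async fn manager(names: &[&'static str]) {
	let mut workers = JoinSet::new();
	for &name in names {
		workers.spawn(run(name));
	}

	while let Some(joined) = workers.join_next().await {
		// run() catches panics itself, so a JoinError here means the task was
		// aborted; skip it rather than crash the manager.
		let Ok((name, result)) = joined else { continue };
		if let Err(e) = result {
			eprintln!("worker {name} failed: {e}; restarting");
			sleep(RESTART_DELAY).await;
			workers.spawn(run(name));
		}
	}
}

#[tokio::main]
async fn main() {
	manager(&["admin", "presence", "sending"]).await;
}

Catching the unwind inside the task keeps the worker's identity attached to the failure; if the task itself panicked, join_next() would only yield a JoinError with no way to tell which service to restart, which is why handle_abort() above is left unimplemented.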