diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 2e22d688..9bdd38fc 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -26,7 +26,7 @@ pub(super) async fn globals(subcommand: Globals) -> Result { let timer = tokio::time::Instant::now(); - let results = services().globals.db.last_check_for_updates_id(); + let results = services().updates.last_check_for_updates_id(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 254a3d9c..281c2a94 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -14,9 +14,6 @@ use ruma::{ use crate::services; -const COUNTER: &[u8] = b"c"; -const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; - pub struct Data { global: Arc, todeviceid_events: Arc, @@ -35,6 +32,8 @@ pub struct Data { counter: RwLock, } +const COUNTER: &[u8] = b"c"; + impl Data { pub(super) fn new(db: &Arc) -> Self { Self { @@ -93,23 +92,6 @@ impl Data { .map_or(Ok(0_u64), utils::u64_from_bytes) } - pub fn last_check_for_updates_id(&self) -> Result { - self.global - .get(LAST_CHECK_FOR_UPDATES_COUNT)? - .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) - } - - #[inline] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.global - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - - Ok(()) - } - #[tracing::instrument(skip(self), level = "debug")] pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 16830d87..0a0d0d8e 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,9 +1,8 @@ mod client; mod data; -pub(super) mod emerg_access; +mod emerg_access; pub(super) mod migrations; pub(crate) mod resolver; -pub(super) mod updates; use std::{ collections::{BTreeMap, HashMap}, @@ -12,6 +11,7 @@ use std::{ time::Instant, }; +use async_trait::async_trait; use conduit::{error, trace, Config, Result}; use data::Data; use ipaddress::IPAddress; @@ -22,7 +22,7 @@ use ruma::{ DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId, }; -use tokio::{sync::Mutex, task::JoinHandle}; +use tokio::sync::Mutex; use url::Url; use crate::services; @@ -41,7 +41,6 @@ pub struct Service { pub bad_event_ratelimiter: Arc>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub bad_query_ratelimiter: Arc>>, - pub updates_handle: Mutex>>, pub stateres_mutex: Arc>, pub server_user: OwnedUserId, pub admin_alias: OwnedRoomAliasId, @@ -49,6 +48,7 @@ pub struct Service { type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; @@ -103,7 +103,6 @@ impl crate::Service for Service { bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - updates_handle: Mutex::new(None), stateres_mutex: Arc::new(Mutex::new(())), admin_alias: RoomAliasId::parse(format!("#admins:{}", &config.server_name)) .expect("#admins:server_name is valid alias name"), @@ -122,6 +121,12 @@ impl crate::Service for Service { Ok(Arc::new(s)) } + async fn worker(self: Arc) -> Result<()> { + emerg_access::init_emergency_access(); + + Ok(()) + } + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { self.resolver.memory_usage(out)?; @@ -181,12 +186,6 @@ impl Service { #[inline] pub fn current_count(&self) -> Result { Ok(self.db.current_count()) } - #[tracing::instrument(skip(self), level = "debug")] - pub fn last_check_for_updates_id(&self) -> Result { self.db.last_check_for_updates_id() } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) } - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { self.db.watch(user_id, device_id).await } diff --git a/src/service/globals/updates.rs b/src/service/globals/updates.rs deleted file mode 100644 index c6ac9fff..00000000 --- a/src/service/globals/updates.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::time::Duration; - -use ruma::events::room::message::RoomMessageEventContent; -use serde::Deserialize; -use tokio::{task::JoinHandle, time::interval}; -use tracing::{error, warn}; - -use crate::{ - conduit::{Error, Result}, - services, -}; - -const CHECK_FOR_UPDATES_URL: &str = "https://pupbrain.dev/check-for-updates/stable"; -const CHECK_FOR_UPDATES_INTERVAL: u64 = 7200; // 2 hours - -#[derive(Deserialize)] -struct CheckForUpdatesResponseEntry { - id: u64, - date: String, - message: String, -} -#[derive(Deserialize)] -struct CheckForUpdatesResponse { - updates: Vec, -} - -#[tracing::instrument] -pub fn start_check_for_updates_task() -> JoinHandle<()> { - let timer_interval = Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL); - - services().server.runtime().spawn(async move { - let mut i = interval(timer_interval); - - loop { - i.tick().await; - - if let Err(e) = try_handle_updates().await { - warn!(%e, "Failed to check for updates"); - } - } - }) -} - -#[tracing::instrument(skip_all)] -async fn try_handle_updates() -> Result<()> { - let response = services() - .globals - .client - .default - .get(CHECK_FOR_UPDATES_URL) - .send() - .await?; - - let response = serde_json::from_str::(&response.text().await?) - .map_err(|e| Error::Err(format!("Bad check for updates response: {e}")))?; - - let mut last_update_id = services().globals.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > services().globals.last_check_for_updates_id()? { - error!("{}", update.message); - services() - .admin - .send_message(RoomMessageEventContent::text_plain(format!( - "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", - update.date, update.message - ))) - .await; - } - } - services() - .globals - .update_check_for_updates_id(last_update_id)?; - - Ok(()) -} diff --git a/src/service/mod.rs b/src/service/mod.rs index ba68fae2..6f2f4ee5 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -16,6 +16,7 @@ pub mod rooms; pub mod sending; pub mod transaction_ids; pub mod uiaa; +pub mod updates; pub mod users; extern crate conduit_core as conduit; diff --git a/src/service/services.rs b/src/service/services.rs index 88b58299..cc9ec290 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -9,7 +9,7 @@ use crate::{ manager::Manager, media, presence, pusher, rooms, sending, service::{Args, Map, Service}, - transaction_ids, uiaa, users, + transaction_ids, uiaa, updates, users, }; pub struct Services { @@ -22,10 +22,11 @@ pub struct Services { pub account_data: Arc, pub presence: Arc, pub admin: Arc, - pub globals: Arc, pub key_backups: Arc, pub media: Arc, pub sending: Arc, + pub updates: Arc, + pub globals: Arc, manager: Mutex>>, pub(crate) service: Map, @@ -82,6 +83,7 @@ impl Services { key_backups: build!(key_backups::Service), media: build!(media::Service), sending: build!(sending::Service), + updates: build!(updates::Service), globals: build!(globals::Service), manager: Mutex::new(None), service, @@ -93,9 +95,7 @@ impl Services { pub(super) async fn start(&self) -> Result<()> { debug_info!("Starting services..."); - globals::migrations::migrations(&self.db, &self.globals.config).await?; - globals::emerg_access::init_emergency_access(); - + globals::migrations::migrations(&self.db, &self.server.config).await?; self.manager .lock() .await @@ -104,25 +104,14 @@ impl Services { .start() .await?; - if self.globals.allow_check_for_updates() { - let handle = globals::updates::start_check_for_updates_task(); - _ = self.globals.updates_handle.lock().await.insert(handle); - } - debug_info!("Services startup complete."); Ok(()) } pub(super) async fn stop(&self) { info!("Shutting down services..."); + self.interrupt(); - - debug!("Waiting for update worker..."); - if let Some(updates_handle) = self.globals.updates_handle.lock().await.take() { - updates_handle.abort(); - _ = updates_handle.await; - } - if let Some(manager) = self.manager.lock().await.as_ref() { manager.stop().await; } @@ -173,6 +162,7 @@ impl Services { fn interrupt(&self) { debug!("Interrupting services..."); + for (name, service) in &self.service { trace!("Interrupting {name}"); service.interrupt(); diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs new file mode 100644 index 00000000..a7088aba --- /dev/null +++ b/src/service/updates/mod.rs @@ -0,0 +1,112 @@ +use std::{sync::Arc, time::Duration}; + +use async_trait::async_trait; +use conduit::{err, info, utils, warn, Error, Result}; +use database::Map; +use ruma::events::room::message::RoomMessageEventContent; +use serde::Deserialize; +use tokio::{sync::Notify, time::interval}; + +use crate::services; + +pub struct Service { + db: Arc, + interrupt: Notify, + interval: Duration, +} + +#[derive(Deserialize)] +struct CheckForUpdatesResponse { + updates: Vec, +} + +#[derive(Deserialize)] +struct CheckForUpdatesResponseEntry { + id: u64, + date: String, + message: String, +} + +const CHECK_FOR_UPDATES_URL: &str = "https://pupbrain.dev/check-for-updates/stable"; +const CHECK_FOR_UPDATES_INTERVAL: u64 = 7200; // 2 hours +const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; + +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: args.db["global"].clone(), + interrupt: Notify::new(), + interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), + })) + } + + async fn worker(self: Arc) -> Result<()> { + let mut i = interval(self.interval); + loop { + tokio::select! { + () = self.interrupt.notified() => return Ok(()), + _ = i.tick() => (), + } + + if let Err(e) = self.handle_updates().await { + warn!(%e, "Failed to check for updates"); + } + } + } + + fn interrupt(&self) { self.interrupt.notify_waiters(); } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[tracing::instrument(skip_all)] + async fn handle_updates(&self) -> Result<()> { + let response = services() + .globals + .client + .default + .get(CHECK_FOR_UPDATES_URL) + .send() + .await?; + + let response = serde_json::from_str::(&response.text().await?) + .map_err(|e| Error::Err(format!("Bad check for updates response: {e}")))?; + + let mut last_update_id = self.last_check_for_updates_id()?; + for update in response.updates { + last_update_id = last_update_id.max(update.id); + if update.id > self.last_check_for_updates_id()? { + info!("{:#}", update.message); + services() + .admin + .send_message(RoomMessageEventContent::text_markdown(format!( + "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", + update.date, update.message + ))) + .await; + } + } + self.update_check_for_updates_id(last_update_id)?; + + Ok(()) + } + + #[inline] + pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + self.db + .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; + + Ok(()) + } + + pub fn last_check_for_updates_id(&self) -> Result { + self.db + .get(LAST_CHECK_FOR_UPDATES_COUNT)? + .map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) + }) + } +}