diff --git a/src/admin/query/appservice.rs b/src/admin/query/appservice.rs index 4b97ef4e..02e89e7a 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -26,11 +26,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> appservice_id, } => { let timer = tokio::time::Instant::now(); - let results = services - .appservice - .db - .get_registration(appservice_id.as_ref()) - .await; + let results = services.appservice.get_registration(&appservice_id).await; let query_time = timer.elapsed(); diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs deleted file mode 100644 index 8fb7d958..00000000 --- a/src/service/appservice/data.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::sync::Arc; - -use conduit::{err, utils::stream::TryIgnore, Result}; -use database::{Database, Map}; -use futures::Stream; -use ruma::api::appservice::Registration; - -pub struct Data { - id_appserviceregistrations: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - id_appserviceregistrations: db["id_appserviceregistrations"].clone(), - } - } - - /// Registers an appservice and returns the ID to the caller - pub(super) fn register_appservice(&self, yaml: &Registration) -> Result { - let id = yaml.id.as_str(); - self.id_appserviceregistrations - .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes()); - - Ok(id.to_owned()) - } - - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.id_appserviceregistrations - .remove(service_name.as_bytes()); - Ok(()) - } - - pub async fn get_registration(&self, id: &str) -> Result { - self.id_appserviceregistrations - .get(id) - .await - .and_then(|ref bytes| serde_yaml::from_slice(bytes).map_err(Into::into)) - .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) - } - - pub(super) fn iter_ids(&self) -> impl Stream + Send + '_ { - self.id_appserviceregistrations.keys().ignore_err() - } -} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 7e2dc738..1617e6e6 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,147 +1,49 @@ -mod data; +mod namespace_regex; +mod registration_info; use std::{collections::BTreeMap, sync::Arc}; use async_trait::async_trait; -use conduit::{err, Result}; -use data::Data; +use conduit::{err, utils::stream::TryIgnore, Result}; +use database::Map; use futures::{Future, StreamExt, TryStreamExt}; -use regex::RegexSet; -use ruma::{ - api::appservice::{Namespace, Registration}, - RoomAliasId, RoomId, UserId, -}; +use ruma::{api::appservice::Registration, RoomAliasId, RoomId, UserId}; use tokio::sync::RwLock; +pub use self::{namespace_regex::NamespaceRegex, registration_info::RegistrationInfo}; use crate::{sending, Dep}; -/// Compiled regular expressions for a namespace -#[derive(Clone, Debug)] -pub struct NamespaceRegex { - pub exclusive: Option, - pub non_exclusive: Option, -} - -impl NamespaceRegex { - /// Checks if this namespace has rights to a namespace - #[inline] - #[must_use] - pub fn is_match(&self, heystack: &str) -> bool { - if self.is_exclusive_match(heystack) { - return true; - } - - if let Some(non_exclusive) = &self.non_exclusive { - if non_exclusive.is_match(heystack) { - return true; - } - } - false - } - - /// Checks if this namespace has exlusive rights to a namespace - #[inline] - #[must_use] - pub fn is_exclusive_match(&self, heystack: &str) -> bool { - if let Some(exclusive) = &self.exclusive { - if exclusive.is_match(heystack) { - return true; - } - } - false - } -} - -impl RegistrationInfo { - #[must_use] - pub fn is_user_match(&self, user_id: &UserId) -> bool { - self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() - } - - #[inline] - #[must_use] - pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { - self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() - } -} - -impl TryFrom> for NamespaceRegex { - type Error = regex::Error; - - fn try_from(value: Vec) -> Result { - let mut exclusive = Vec::with_capacity(value.len()); - let mut non_exclusive = Vec::with_capacity(value.len()); - - for namespace in value { - if namespace.exclusive { - exclusive.push(namespace.regex); - } else { - non_exclusive.push(namespace.regex); - } - } - - Ok(Self { - exclusive: if exclusive.is_empty() { - None - } else { - Some(RegexSet::new(exclusive)?) - }, - non_exclusive: if non_exclusive.is_empty() { - None - } else { - Some(RegexSet::new(non_exclusive)?) - }, - }) - } -} - -/// Appservice registration combined with its compiled regular expressions. -#[derive(Clone, Debug)] -pub struct RegistrationInfo { - pub registration: Registration, - pub users: NamespaceRegex, - pub aliases: NamespaceRegex, - pub rooms: NamespaceRegex, -} - -impl TryFrom for RegistrationInfo { - type Error = regex::Error; - - fn try_from(value: Registration) -> Result { - Ok(Self { - users: value.namespaces.users.clone().try_into()?, - aliases: value.namespaces.aliases.clone().try_into()?, - rooms: value.namespaces.rooms.clone().try_into()?, - registration: value, - }) - } -} - pub struct Service { - pub db: Data, - services: Services, registration_info: RwLock>, + services: Services, + db: Data, } struct Services { sending: Dep, } +struct Data { + id_appserviceregistrations: Arc, +} + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - db: Data::new(args.db), + registration_info: RwLock::new(BTreeMap::new()), services: Services { sending: args.depend::("sending"), }, - registration_info: RwLock::new(BTreeMap::new()), + db: Data { + id_appserviceregistrations: args.db["id_appserviceregistrations"].clone(), + }, })) } async fn worker(self: Arc) -> Result<()> { // Inserting registrations into cache - for appservice in iter_ids(&self.db).await? { + for appservice in self.iter_db_ids().await? { self.registration_info.write().await.insert( appservice.0, appservice @@ -158,9 +60,6 @@ impl crate::Service for Service { } impl Service { - #[inline] - pub async fn all(&self) -> Result> { iter_ids(&self.db).await } - /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result { //TODO: Check for collisions between exclusive appservice namespaces @@ -169,7 +68,11 @@ impl Service { .await .insert(yaml.id.clone(), yaml.clone().try_into()?); - self.db.register_appservice(&yaml) + let id = yaml.id.as_str(); + let yaml = serde_yaml::to_string(&yaml)?; + self.db.id_appserviceregistrations.insert(id, yaml); + + Ok(id.to_owned()) } /// Remove an appservice registration @@ -186,7 +89,7 @@ impl Service { .ok_or(err!("Appservice not found"))?; // remove the appservice from the database - self.db.unregister_appservice(service_name)?; + self.db.id_appserviceregistrations.remove(service_name); // deletes all active requests for the appservice if there are any so we stop // sending to the URL @@ -254,11 +157,29 @@ impl Service { pub fn read(&self) -> impl Future>> { self.registration_info.read() } -} -async fn iter_ids(db: &Data) -> Result> { - db.iter_ids() - .then(|id| async move { Ok((id.clone(), db.get_registration(&id).await?)) }) - .try_collect() - .await + #[inline] + pub async fn all(&self) -> Result> { self.iter_db_ids().await } + + pub async fn get_db_registration(&self, id: &str) -> Result { + self.db + .id_appserviceregistrations + .get(id) + .await + .and_then(|ref bytes| serde_yaml::from_slice(bytes).map_err(Into::into)) + .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) + } + + async fn iter_db_ids(&self) -> Result> { + self.db + .id_appserviceregistrations + .keys() + .ignore_err() + .then(|id: String| async move { + let reg = self.get_db_registration(&id).await?; + Ok((id, reg)) + }) + .try_collect() + .await + } } diff --git a/src/service/appservice/namespace_regex.rs b/src/service/appservice/namespace_regex.rs new file mode 100644 index 00000000..3529fc0e --- /dev/null +++ b/src/service/appservice/namespace_regex.rs @@ -0,0 +1,70 @@ +use conduit::Result; +use regex::RegexSet; +use ruma::api::appservice::Namespace; + +/// Compiled regular expressions for a namespace +#[derive(Clone, Debug)] +pub struct NamespaceRegex { + pub exclusive: Option, + pub non_exclusive: Option, +} + +impl NamespaceRegex { + /// Checks if this namespace has rights to a namespace + #[inline] + #[must_use] + pub fn is_match(&self, heystack: &str) -> bool { + if self.is_exclusive_match(heystack) { + return true; + } + + if let Some(non_exclusive) = &self.non_exclusive { + if non_exclusive.is_match(heystack) { + return true; + } + } + false + } + + /// Checks if this namespace has exlusive rights to a namespace + #[inline] + #[must_use] + pub fn is_exclusive_match(&self, heystack: &str) -> bool { + if let Some(exclusive) = &self.exclusive { + if exclusive.is_match(heystack) { + return true; + } + } + false + } +} + +impl TryFrom> for NamespaceRegex { + type Error = regex::Error; + + fn try_from(value: Vec) -> Result { + let mut exclusive = Vec::with_capacity(value.len()); + let mut non_exclusive = Vec::with_capacity(value.len()); + + for namespace in value { + if namespace.exclusive { + exclusive.push(namespace.regex); + } else { + non_exclusive.push(namespace.regex); + } + } + + Ok(Self { + exclusive: if exclusive.is_empty() { + None + } else { + Some(RegexSet::new(exclusive)?) + }, + non_exclusive: if non_exclusive.is_empty() { + None + } else { + Some(RegexSet::new(non_exclusive)?) + }, + }) + } +} diff --git a/src/service/appservice/registration_info.rs b/src/service/appservice/registration_info.rs new file mode 100644 index 00000000..2c8595b1 --- /dev/null +++ b/src/service/appservice/registration_info.rs @@ -0,0 +1,39 @@ +use conduit::Result; +use ruma::{api::appservice::Registration, UserId}; + +use super::NamespaceRegex; + +/// Appservice registration combined with its compiled regular expressions. +#[derive(Clone, Debug)] +pub struct RegistrationInfo { + pub registration: Registration, + pub users: NamespaceRegex, + pub aliases: NamespaceRegex, + pub rooms: NamespaceRegex, +} + +impl RegistrationInfo { + #[must_use] + pub fn is_user_match(&self, user_id: &UserId) -> bool { + self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() + } + + #[inline] + #[must_use] + pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { + self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() + } +} + +impl TryFrom for RegistrationInfo { + type Error = regex::Error; + + fn try_from(value: Registration) -> Result { + Ok(Self { + users: value.namespaces.users.clone().try_into()?, + aliases: value.namespaces.aliases.clone().try_into()?, + rooms: value.namespaces.rooms.clone().try_into()?, + registration: value, + }) + } +}