From 9d4fa9a2201399a652e52cc1d76f0e2a3a5608d2 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 14 Jul 2021 07:07:08 +0000 Subject: [PATCH] Sqlite --- .gitignore | 1 + Cargo.lock | 143 ++++++++- Cargo.toml | 8 +- DEPLOY.md | 4 +- conduit-example.toml | 4 +- debian/postinst | 4 +- docker-compose.yml | 2 +- src/client_server/account.rs | 14 +- src/client_server/alias.rs | 11 +- src/client_server/backup.rs | 33 +-- src/client_server/config.rs | 13 +- src/client_server/context.rs | 7 +- src/client_server/device.rs | 15 +- src/client_server/directory.rs | 13 +- src/client_server/keys.rs | 21 +- src/client_server/media.rs | 15 +- src/client_server/membership.rs | 32 +- src/client_server/message.rs | 8 +- src/client_server/mod.rs | 4 +- src/client_server/presence.rs | 9 +- src/client_server/profile.rs | 15 +- src/client_server/push.rs | 25 +- src/client_server/read_marker.rs | 9 +- src/client_server/redact.rs | 6 +- src/client_server/room.rs | 14 +- src/client_server/search.rs | 6 +- src/client_server/session.rs | 12 +- src/client_server/state.rs | 17 +- src/client_server/sync.rs | 19 +- src/client_server/tag.rs | 11 +- src/client_server/to_device.rs | 7 +- src/client_server/typing.rs | 7 +- src/client_server/user_directory.rs | 7 +- src/database.rs | 382 ++++++++++++++++++------ src/database/abstraction.rs | 297 +------------------ src/database/abstraction/rocksdb.rs | 176 +++++++++++ src/database/abstraction/sled.rs | 119 ++++++++ src/database/abstraction/sqlite.rs | 444 ++++++++++++++++++++++++++++ src/database/account_data.rs | 2 +- src/database/admin.rs | 68 +++-- src/database/appservice.rs | 4 +- src/database/globals.rs | 30 +- src/database/pusher.rs | 2 +- src/database/rooms.rs | 12 +- src/database/sending.rs | 57 +++- src/error.rs | 6 + src/main.rs | 31 +- src/ruma_wrapper.rs | 9 +- src/server_server.rs | 51 ++-- 49 files changed, 1525 insertions(+), 681 deletions(-) create mode 100644 src/database/abstraction/rocksdb.rs create mode 100644 src/database/abstraction/sled.rs create mode 100644 src/database/abstraction/sqlite.rs diff --git a/.gitignore b/.gitignore index e2f4e882..1f5f395f 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ $RECYCLE.BIN/ # Conduit Rocket.toml conduit.toml +conduit.db # Etc. **/*.rs.bk diff --git a/Cargo.lock b/Cargo.lock index 7efeeac9..a0d7a700 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,6 +6,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" +[[package]] +name = "ahash" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43bb833f0bf979d8475d38fbf09ed3b8a55e1885fe93ad3f93239fc6a4f17b98" +dependencies = [ + "getrandom 0.2.2", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.7.15" @@ -238,14 +249,17 @@ version = "0.1.0" dependencies = [ "base64 0.13.0", "bytes", + "crossbeam", "directories", "http", "image", "jsonwebtoken", "log", "lru-cache", + "num_cpus", "opentelemetry", "opentelemetry-jaeger", + "parking_lot", "pretty_env_logger", "rand 0.8.3", "regex", @@ -254,6 +268,7 @@ dependencies = [ "rocket", "rocksdb", "ruma", + "rusqlite", "rust-argon2", "rustls", "rustls-native-certs", @@ -340,10 +355,45 @@ dependencies = [ ] [[package]] -name = "crossbeam-epoch" -version = "0.9.3" +name = "crossbeam" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2584f639eb95fea8c798496315b297cf81b9b58b6d30ab066a75455333cf4b12" +checksum = "4ae5588f6b3c3cb05239e90bd110f257254aecd01e4635400391aeae07497845" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd" dependencies = [ "cfg-if 1.0.0", "crossbeam-utils", @@ -353,12 +403,21 @@ dependencies = [ ] [[package]] -name = "crossbeam-utils" -version = "0.8.3" +name = "crossbeam-queue" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7e9d99fa91428effe99c5c6d4634cdeba32b8cf784fc428a2a687f61a952c49" +checksum = "9b10ddc024425c88c2ad148c1b0fd53f4c6d38db9697c9f1588381212fa657c9" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" dependencies = [ - "autocfg", "cfg-if 1.0.0", "lazy_static", ] @@ -547,6 +606,18 @@ dependencies = [ "termcolor", ] +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "figment" version = "0.10.5" @@ -774,6 +845,24 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashlink" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" +dependencies = [ + "hashbrown 0.11.2", +] + [[package]] name = "heck" version = "0.3.2" @@ -920,7 +1009,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.9.1", "serde", ] @@ -1083,6 +1172,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290b64917f8b0cb885d9de0f9959fe1f775d7fa12f1da2db9001c1c8ab60f89d" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.4" @@ -1484,6 +1584,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "pkg-config" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" + [[package]] name = "png" version = "0.16.8" @@ -2136,6 +2242,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "rusqlite" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57adcf67c8faaf96f3248c2a7b419a0dbc52ebe36ba83dd57fe83827c1ea4eb3" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "memchr", + "smallvec", +] + [[package]] name = "rust-argon2" version = "0.8.3" @@ -3007,6 +3128,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "vcpkg" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "025ce40a007e1907e58d5bc1a594def78e5573bb0b1160bc389634e8f12e4faa" + [[package]] name = "version_check" version = "0.9.3" diff --git a/Cargo.toml b/Cargo.toml index 426d242c..896140cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,11 +73,17 @@ tracing-opentelemetry = "0.11.0" opentelemetry-jaeger = "0.11.0" pretty_env_logger = "0.4.0" lru-cache = "0.1.2" +rusqlite = { version = "0.25.3", optional = true, features = ["bundled"] } +parking_lot = { version = "0.11.1", optional = true } +crossbeam = { version = "0.8.1", optional = true } +num_cpus = { version = "1.13.0", optional = true } [features] -default = ["conduit_bin", "backend_sled"] +default = ["conduit_bin", "backend_sqlite"] backend_sled = ["sled"] backend_rocksdb = ["rocksdb"] +backend_sqlite = ["sqlite"] +sqlite = ["rusqlite", "parking_lot", "crossbeam", "num_cpus", "tokio/signal"] conduit_bin = [] # TODO: add rocket to this when it is optional [[bin]] diff --git a/DEPLOY.md b/DEPLOY.md index 778d0e08..8e16c19a 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -114,11 +114,13 @@ allow_federation = true trusted_servers = ["matrix.org"] -#cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #workers = 4 # default: cpu core count * 2 address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy + +# The total amount of memory that the database will use. +#db_cache_capacity_mb = 200 ``` ## Setting the correct file permissions diff --git a/conduit-example.toml b/conduit-example.toml index db0bbb77..d184991a 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -35,7 +35,6 @@ max_request_size = 20_000_000 # in bytes trusted_servers = ["matrix.org"] -#cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #log = "info,state_res=warn,rocket=off,_=off,sled=off" #workers = 4 # default: cpu core count * 2 @@ -43,3 +42,6 @@ trusted_servers = ["matrix.org"] address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy proxy = "none" # more examples can be found at src/database/proxy.rs:6 + +# The total amount of memory that the database will use. +#db_cache_capacity_mb = 200 \ No newline at end of file diff --git a/debian/postinst b/debian/postinst index 6a4cdb8a..824fd64e 100644 --- a/debian/postinst +++ b/debian/postinst @@ -73,10 +73,12 @@ max_request_size = 20_000_000 # in bytes # Enable jaeger to support monitoring and troubleshooting through jaeger. #allow_jaeger = false -#cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #log = "info,state_res=warn,rocket=off,_=off,sled=off" #workers = 4 # default: cpu core count * 2 + +# The total amount of memory that the database will use. +#db_cache_capacity_mb = 200 EOF fi ;; diff --git a/docker-compose.yml b/docker-compose.yml index cf0d2c1d..d6437094 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -56,4 +56,4 @@ services: # - homeserver volumes: - db: + db: diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 5326a798..7f38eb18 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -1,7 +1,7 @@ -use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; +use std::{collections::BTreeMap, convert::TryInto}; -use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -42,7 +42,7 @@ const GUEST_NAME_LENGTH: usize = 10; )] #[tracing::instrument(skip(db, body))] pub async fn get_register_available_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { // Validate user id @@ -85,7 +85,7 @@ pub async fn get_register_available_route( )] #[tracing::instrument(skip(db, body))] pub async fn register_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_registration() && !body.from_appservice { @@ -500,7 +500,7 @@ pub async fn register_route( )] #[tracing::instrument(skip(db, body))] pub async fn change_password_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -592,7 +592,7 @@ pub async fn whoami_route(body: Ruma) -> ConduitResult>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index a54bd36d..f5d9f64b 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Ruma}; use regex::Regex; use ruma::{ api::{ @@ -24,7 +21,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_alias_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if db.rooms.id_from_alias(&body.room_alias)?.is_some() { @@ -45,7 +42,7 @@ pub async fn create_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_alias_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; @@ -61,7 +58,7 @@ pub async fn delete_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_alias_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { get_alias_helper(&db, &body.room_alias).await diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index fcca676f..ccb17faa 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::backup::{ @@ -21,7 +18,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_backup_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -40,7 +37,7 @@ pub async fn create_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_backup_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -58,7 +55,7 @@ pub async fn update_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_latest_backup_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +83,7 @@ pub async fn get_latest_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -113,7 +110,7 @@ pub async fn get_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -132,7 +129,7 @@ pub async fn delete_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -166,7 +163,7 @@ pub async fn add_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_sessions_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -198,7 +195,7 @@ pub async fn add_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_session_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -227,7 +224,7 @@ pub async fn add_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -243,7 +240,7 @@ pub async fn get_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_sessions_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -261,7 +258,7 @@ pub async fn get_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_session_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -283,7 +280,7 @@ pub async fn get_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -306,7 +303,7 @@ pub async fn delete_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_sessions_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -329,7 +326,7 @@ pub async fn delete_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_session_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/config.rs b/src/client_server/config.rs index 829bf94a..4f33689a 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -25,7 +22,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_global_account_data_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -60,7 +57,7 @@ pub async fn set_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_account_data_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -92,7 +89,7 @@ pub async fn set_room_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_global_account_data_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -119,7 +116,7 @@ pub async fn get_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_account_data_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/context.rs b/src/client_server/context.rs index b86fd0bf..dbc121e3 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -1,7 +1,6 @@ -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::context::get_context}; -use std::{convert::TryFrom, sync::Arc}; +use std::convert::TryFrom; #[cfg(feature = "conduit_bin")] use rocket::get; @@ -12,7 +11,7 @@ use rocket::get; )] #[tracing::instrument(skip(db, body))] pub async fn get_context_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 2c4b527c..a10d7887 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::{ @@ -20,7 +17,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_devices_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -40,7 +37,7 @@ pub async fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_device_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -59,7 +56,7 @@ pub async fn get_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_device_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -85,7 +82,7 @@ pub async fn update_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_device_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -139,7 +136,7 @@ pub async fn delete_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_devices_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 1b6b1d7b..4a440fd9 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma}; use log::info; use ruma::{ api::{ @@ -35,7 +32,7 @@ use rocket::{get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_filtered_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { get_public_rooms_filtered_helper( @@ -55,7 +52,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let response = get_public_rooms_filtered_helper( @@ -84,7 +81,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_visibility_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -114,7 +111,7 @@ pub async fn set_room_visibility_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_visibility_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { Ok(get_room_visibility::Response { diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 60269813..621e5ddd 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -1,5 +1,5 @@ -use super::{State, SESSION_ID_LENGTH}; -use crate::{utils, ConduitResult, Database, Error, Result, Ruma}; +use super::SESSION_ID_LENGTH; +use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -14,10 +14,7 @@ use ruma::{ encryption::UnsignedDeviceInfo, DeviceId, DeviceKeyAlgorithm, UserId, }; -use std::{ - collections::{BTreeMap, HashSet}, - sync::Arc, -}; +use std::collections::{BTreeMap, HashSet}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -28,7 +25,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn upload_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -77,7 +74,7 @@ pub async fn upload_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -98,7 +95,7 @@ pub async fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let response = claim_keys_helper(&body.one_time_keys, &db)?; @@ -114,7 +111,7 @@ pub async fn claim_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signing_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -177,7 +174,7 @@ pub async fn upload_signing_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signatures_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -238,7 +235,7 @@ pub async fn upload_signatures_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_key_changes_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/media.rs b/src/client_server/media.rs index 0b1fbd7b..eaaf9399 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -1,20 +1,21 @@ -use super::State; -use crate::{database::media::FileMeta, utils, ConduitResult, Database, Error, Ruma}; +use crate::{ + database::media::FileMeta, database::DatabaseGuard, utils, ConduitResult, Error, Ruma, +}; use ruma::api::client::{ error::ErrorKind, r0::media::{create_content, get_content, get_content_thumbnail, get_media_config}, }; +use std::convert::TryInto; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; -use std::{convert::TryInto, sync::Arc}; const MXC_LENGTH: usize = 32; #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] #[tracing::instrument(skip(db))] pub async fn get_media_config_route( - db: State<'_, Arc>, + db: DatabaseGuard, ) -> ConduitResult { Ok(get_media_config::Response { upload_size: db.globals.max_request_size().into(), @@ -28,7 +29,7 @@ pub async fn get_media_config_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_content_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!( @@ -66,7 +67,7 @@ pub async fn create_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -119,7 +120,7 @@ pub async fn get_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_thumbnail_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 5c57b68a..4667f25d 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -1,6 +1,6 @@ -use super::State; use crate::{ client_server, + database::DatabaseGuard, pdu::{PduBuilder, PduEvent}, server_server, utils, ConduitResult, Database, Error, Result, Ruma, }; @@ -44,7 +44,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -65,14 +65,18 @@ pub async fn join_room_by_id_route( servers.insert(body.room_id.server_name().to_owned()); - join_room_by_id_helper( + let ret = join_room_by_id_helper( &db, body.sender_user.as_ref(), &body.room_id, &servers, body.third_party_signed.as_ref(), ) - .await + .await; + + db.flush().await?; + + ret } #[cfg_attr( @@ -81,7 +85,7 @@ pub async fn join_room_by_id_route( )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_or_alias_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -135,7 +139,7 @@ pub async fn join_room_by_id_or_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn leave_room_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -153,7 +157,7 @@ pub async fn leave_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn invite_user_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -173,7 +177,7 @@ pub async fn invite_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn kick_user_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -223,7 +227,7 @@ pub async fn kick_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn ban_user_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -281,7 +285,7 @@ pub async fn ban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn unban_user_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -330,7 +334,7 @@ pub async fn unban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn forget_room_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -348,7 +352,7 @@ pub async fn forget_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_rooms_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -369,7 +373,7 @@ pub async fn joined_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_member_events_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -399,7 +403,7 @@ pub async fn get_member_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_members_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 0d19f347..7e898b11 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -1,5 +1,4 @@ -use super::State; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -11,7 +10,6 @@ use ruma::{ use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, - sync::Arc, }; #[cfg(feature = "conduit_bin")] @@ -23,7 +21,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_message_event_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +84,7 @@ pub async fn send_message_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_message_events_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/mod.rs b/src/client_server/mod.rs index 825dbbb9..f211a571 100644 --- a/src/client_server/mod.rs +++ b/src/client_server/mod.rs @@ -64,9 +64,7 @@ pub use voip::*; use super::State; #[cfg(feature = "conduit_bin")] use { - crate::ConduitResult, - rocket::{options, State}, - ruma::api::client::r0::to_device::send_event_to_device, + crate::ConduitResult, rocket::options, ruma::api::client::r0::to_device::send_event_to_device, }; pub const DEVICE_ID_LENGTH: usize = 10; diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index ce80dfd7..bfe638fb 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -1,7 +1,6 @@ -use super::State; -use crate::{utils, ConduitResult, Database, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Ruma}; use ruma::api::client::r0::presence::{get_presence, set_presence}; -use std::{convert::TryInto, sync::Arc, time::Duration}; +use std::{convert::TryInto, time::Duration}; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; @@ -12,7 +11,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_presence_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -53,7 +52,7 @@ pub async fn set_presence_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_presence_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 4e9a37b6..5281a4a2 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -1,5 +1,4 @@ -use super::State; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -10,10 +9,10 @@ use ruma::{ events::EventType, serde::Raw, }; +use std::convert::TryInto; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; -use std::{convert::TryInto, sync::Arc}; #[cfg_attr( feature = "conduit_bin", @@ -21,7 +20,7 @@ use std::{convert::TryInto, sync::Arc}; )] #[tracing::instrument(skip(db, body))] pub async fn set_displayname_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -108,7 +107,7 @@ pub async fn set_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_displayname_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { Ok(get_display_name::Response { @@ -123,7 +122,7 @@ pub async fn get_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_avatar_url_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -210,7 +209,7 @@ pub async fn set_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_avatar_url_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { Ok(get_avatar_url::Response { @@ -225,7 +224,7 @@ pub async fn get_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_profile_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.users.exists(&body.user_id)? { diff --git a/src/client_server/push.rs b/src/client_server/push.rs index d6f62126..794cbce4 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,7 +21,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrules_all_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -49,7 +46,7 @@ pub async fn get_pushrules_all_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -103,7 +100,7 @@ pub async fn get_pushrule_route( )] #[tracing::instrument(skip(db, req))] pub async fn set_pushrule_route( - db: State<'_, Arc>, + db: DatabaseGuard, req: Ruma>, ) -> ConduitResult { let sender_user = req.sender_user.as_ref().expect("user is authenticated"); @@ -206,7 +203,7 @@ pub async fn set_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_actions_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -265,7 +262,7 @@ pub async fn get_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_actions_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -339,7 +336,7 @@ pub async fn set_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_enabled_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -400,7 +397,7 @@ pub async fn get_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_enabled_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -479,7 +476,7 @@ pub async fn set_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_pushrule_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -548,7 +545,7 @@ pub async fn delete_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushers_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -565,7 +562,7 @@ pub async fn get_pushers_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushers_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index 837170ff..fe49af9d 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -1,5 +1,4 @@ -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -9,10 +8,10 @@ use ruma::{ receipt::ReceiptType, MilliSecondsSinceUnixEpoch, }; +use std::collections::BTreeMap; #[cfg(feature = "conduit_bin")] use rocket::post; -use std::{collections::BTreeMap, sync::Arc}; #[cfg_attr( feature = "conduit_bin", @@ -20,7 +19,7 @@ use std::{collections::BTreeMap, sync::Arc}; )] #[tracing::instrument(skip(db, body))] pub async fn set_read_marker_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -87,7 +86,7 @@ pub async fn set_read_marker_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_receipt_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index e1930823..3db27716 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -1,10 +1,8 @@ -use super::State; -use crate::{pdu::PduBuilder, ConduitResult, Database, Ruma}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Ruma}; use ruma::{ api::client::r0::redact::redact_event, events::{room::redaction, EventType}, }; -use std::sync::Arc; #[cfg(feature = "conduit_bin")] use rocket::put; @@ -15,7 +13,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn redact_event_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/room.rs b/src/client_server/room.rs index b33b5500..43625fe5 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -1,5 +1,7 @@ -use super::State; -use crate::{client_server::invite_helper, pdu::PduBuilder, ConduitResult, Database, Error, Ruma}; +use crate::{ + client_server::invite_helper, database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Error, + Ruma, +}; use log::info; use ruma::{ api::client::{ @@ -13,7 +15,7 @@ use ruma::{ serde::Raw, RoomAliasId, RoomId, RoomVersionId, }; -use std::{cmp::max, collections::BTreeMap, convert::TryFrom, sync::Arc}; +use std::{cmp::max, collections::BTreeMap, convert::TryFrom}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -24,7 +26,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn create_room_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -294,7 +296,7 @@ pub async fn create_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_event_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -322,7 +324,7 @@ pub async fn get_room_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn upgrade_room_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, _room_id: String, ) -> ConduitResult { diff --git a/src/client_server/search.rs b/src/client_server/search.rs index 5fc64d09..ec23dd40 100644 --- a/src/client_server/search.rs +++ b/src/client_server/search.rs @@ -1,7 +1,5 @@ -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::search::search_events}; -use std::sync::Arc; #[cfg(feature = "conduit_bin")] use rocket::post; @@ -14,7 +12,7 @@ use std::collections::BTreeMap; )] #[tracing::instrument(skip(db, body))] pub async fn search_events_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/session.rs b/src/client_server/session.rs index dd504f18..7ad792b2 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -1,7 +1,5 @@ -use std::sync::Arc; - -use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{utils, ConduitResult, Database, Error, Ruma}; +use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -52,7 +50,7 @@ pub async fn get_login_types_route() -> ConduitResult )] #[tracing::instrument(skip(db, body))] pub async fn login_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { // Validate login method @@ -169,7 +167,7 @@ pub async fn login_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -197,7 +195,7 @@ pub async fn logout_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_all_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/state.rs b/src/client_server/state.rs index be52834a..68246d54 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -1,7 +1,6 @@ -use std::sync::Arc; - -use super::State; -use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; +use crate::{ + database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma, +}; use ruma::{ api::client::{ error::ErrorKind, @@ -27,7 +26,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_key_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -53,7 +52,7 @@ pub async fn send_state_event_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_empty_key_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -79,7 +78,7 @@ pub async fn send_state_event_for_empty_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -126,7 +125,7 @@ pub async fn get_state_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_key_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -177,7 +176,7 @@ pub async fn get_state_events_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_empty_key_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 69511fa1..c57f1da1 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,5 +1,4 @@ -use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma, RumaResponse}; +use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use log::error; use ruma::{ api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, @@ -35,13 +34,15 @@ use rocket::{get, tokio}; )] #[tracing::instrument(skip(db, body))] pub async fn sync_events_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> std::result::Result, RumaResponse> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let mut rx = match db + let arc_db = Arc::new(db); + + let mut rx = match arc_db .globals .sync_receivers .write() @@ -52,7 +53,7 @@ pub async fn sync_events_route( let (tx, rx) = tokio::sync::watch::channel(None); tokio::spawn(sync_helper_wrapper( - Arc::clone(&db), + Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body.since.clone(), @@ -68,7 +69,7 @@ pub async fn sync_events_route( let (tx, rx) = tokio::sync::watch::channel(None); tokio::spawn(sync_helper_wrapper( - Arc::clone(&db), + Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body.since.clone(), @@ -104,7 +105,7 @@ pub async fn sync_events_route( } pub async fn sync_helper_wrapper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, @@ -142,11 +143,13 @@ pub async fn sync_helper_wrapper( } } + drop(db); + let _ = tx.send(Some(r.map(|(r, _)| r.into()))); } async fn sync_helper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index 2382fe0a..17df2c2e 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -1,10 +1,9 @@ -use super::State; -use crate::{ConduitResult, Database, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Ruma}; use ruma::{ api::client::r0::tag::{create_tag, delete_tag, get_tags}, events::EventType, }; -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; #[cfg(feature = "conduit_bin")] use rocket::{delete, get, put}; @@ -15,7 +14,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn update_tag_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -52,7 +51,7 @@ pub async fn update_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_tag_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +85,7 @@ pub async fn delete_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_tags_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index ada0c9a1..3bb135e7 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{error::ErrorKind, r0::to_device::send_event_to_device}, to_device::DeviceIdOrAllDevices, @@ -16,7 +13,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn send_event_to_device_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/typing.rs b/src/client_server/typing.rs index a0a5d430..7a590af9 100644 --- a/src/client_server/typing.rs +++ b/src/client_server/typing.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{utils, ConduitResult, Database, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Ruma}; use create_typing_event::Typing; use ruma::api::client::r0::typing::create_typing_event; @@ -14,7 +11,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub fn create_typing_event_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/user_directory.rs b/src/client_server/user_directory.rs index d7c16d7c..14b85a65 100644 --- a/src/client_server/user_directory.rs +++ b/src/client_server/user_directory.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{ConduitResult, Database, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Ruma}; use ruma::api::client::r0::user_directory::search_users; #[cfg(feature = "conduit_bin")] @@ -13,7 +10,7 @@ use rocket::post; )] #[tracing::instrument(skip(db, body))] pub async fn search_users_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let limit = u64::from(body.limit) as usize; diff --git a/src/database.rs b/src/database.rs index ec4052cb..ac173720 100644 --- a/src/database.rs +++ b/src/database.rs @@ -19,16 +19,23 @@ use abstraction::DatabaseEngine; use directories::ProjectDirs; use log::error; use lru_cache::LruCache; -use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; +use rocket::{ + futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}, + outcome::IntoOutcome, + request::{FromRequest, Request}, + try_outcome, State, +}; use ruma::{DeviceId, ServerName, UserId}; -use serde::Deserialize; +use serde::{de::IgnoredAny, Deserialize}; use std::{ - collections::HashMap, + collections::{BTreeMap, HashMap}, fs::{self, remove_dir_all}, io::Write, + ops::Deref, + path::Path, sync::{Arc, RwLock}, }; -use tokio::sync::Semaphore; +use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use self::proxy::ProxyConfig; @@ -36,8 +43,16 @@ use self::proxy::ProxyConfig; pub struct Config { server_name: Box, database_path: String, - #[serde(default = "default_cache_capacity")] - cache_capacity: u32, + #[serde(default = "default_db_cache_capacity_mb")] + db_cache_capacity_mb: f64, + #[serde(default = "default_sqlite_read_pool_size")] + sqlite_read_pool_size: usize, + #[serde(default = "true_fn")] + sqlite_wal_clean_timer: bool, + #[serde(default = "default_sqlite_wal_clean_second_interval")] + sqlite_wal_clean_second_interval: u32, + #[serde(default = "default_sqlite_wal_clean_second_timeout")] + sqlite_wal_clean_second_timeout: u32, #[serde(default = "default_max_request_size")] max_request_size: u32, #[serde(default = "default_max_concurrent_requests")] @@ -57,6 +72,29 @@ pub struct Config { trusted_servers: Vec>, #[serde(default = "default_log")] pub log: String, + + #[serde(flatten)] + catchall: BTreeMap, +} + +const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; + +impl Config { + pub fn warn_deprecated(&self) { + let mut was_deprecated = false; + for key in self + .catchall + .keys() + .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) + { + log::warn!("Config parameter {} is deprecated", key); + was_deprecated = true; + } + + if was_deprecated { + log::warn!("Read conduit documentation and check your configuration if any new configuration parameters should be adjusted"); + } + } } fn false_fn() -> bool { @@ -67,8 +105,20 @@ fn true_fn() -> bool { true } -fn default_cache_capacity() -> u32 { - 1024 * 1024 * 1024 +fn default_db_cache_capacity_mb() -> f64 { + 200.0 +} + +fn default_sqlite_read_pool_size() -> usize { + num_cpus::get().max(1) +} + +fn default_sqlite_wal_clean_second_interval() -> u32 { + 60 * 60 +} + +fn default_sqlite_wal_clean_second_timeout() -> u32 { + 2 } fn default_max_request_size() -> u32 { @@ -84,12 +134,16 @@ fn default_log() -> String { } #[cfg(feature = "sled")] -pub type Engine = abstraction::SledEngine; +pub type Engine = abstraction::sled::Engine; #[cfg(feature = "rocksdb")] -pub type Engine = abstraction::RocksDbEngine; +pub type Engine = abstraction::rocksdb::Engine; + +#[cfg(feature = "sqlite")] +pub type Engine = abstraction::sqlite::Engine; pub struct Database { + _db: Arc, pub globals: globals::Globals, pub users: users::Users, pub uiaa: uiaa::Uiaa, @@ -117,8 +171,37 @@ impl Database { Ok(()) } + fn check_sled_or_sqlite_db(config: &Config) -> Result<()> { + let path = Path::new(&config.database_path); + + #[cfg(feature = "backend_sqlite")] + { + let sled_exists = path.join("db").exists(); + let sqlite_exists = path.join("conduit.db").exists(); + if sled_exists { + if sqlite_exists { + // most likely an in-place directory, only warn + log::warn!("Both sled and sqlite databases are detected in database directory"); + log::warn!("Currently running from the sqlite database, but consider removing sled database files to free up space") + } else { + log::error!( + "Sled database detected, conduit now uses sqlite for database operations" + ); + log::error!("This database must be converted to sqlite, go to https://github.com/ShadowJonathan/conduit_toolbox#conduit_sled_to_sqlite"); + return Err(Error::bad_config( + "sled database detected, migrate to sqlite", + )); + } + } + } + + Ok(()) + } + /// Load an existing database or create a new one. - pub async fn load_or_create(config: Config) -> Result> { + pub async fn load_or_create(config: Config) -> Result>> { + Self::check_sled_or_sqlite_db(&config)?; + let builder = Engine::open(&config)?; if config.max_request_size < 1024 { @@ -128,7 +211,8 @@ impl Database { let (admin_sender, admin_receiver) = mpsc::unbounded(); let (sending_sender, sending_receiver) = mpsc::unbounded(); - let db = Arc::new(Self { + let db = Arc::new(TokioRwLock::from(Self { + _db: builder.clone(), users: users::Users { userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, @@ -231,100 +315,112 @@ impl Database { globals: globals::Globals::load( builder.open_tree("global")?, builder.open_tree("server_signingkeys")?, - config, + config.clone(), )?, - }); + })); - // MIGRATIONS - // TODO: database versions of new dbs should probably not be 0 - if db.globals.database_version()? < 1 { - for (roomserverid, _) in db.rooms.roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xff); - let room_id = parts.next().expect("split always returns one element"); - let servername = match parts.next() { - Some(s) => s, - None => { - error!("Migration: Invalid roomserverid in db."); + { + let db = db.read().await; + // MIGRATIONS + // TODO: database versions of new dbs should probably not be 0 + if db.globals.database_version()? < 1 { + for (roomserverid, _) in db.rooms.roomserverids.iter() { + let mut parts = roomserverid.split(|&b| b == 0xff); + let room_id = parts.next().expect("split always returns one element"); + let servername = match parts.next() { + Some(s) => s, + None => { + error!("Migration: Invalid roomserverid in db."); + continue; + } + }; + let mut serverroomid = servername.to_vec(); + serverroomid.push(0xff); + serverroomid.extend_from_slice(room_id); + + db.rooms.serverroomids.insert(&serverroomid, &[])?; + } + + db.globals.bump_database_version(1)?; + + println!("Migration: 0 -> 1 finished"); + } + + if db.globals.database_version()? < 2 { + // We accidentally inserted hashed versions of "" into the db instead of just "" + for (userid, password) in db.users.userid_password.iter() { + let password = utils::string_from_bytes(&password); + + let empty_hashed_password = password.map_or(false, |password| { + argon2::verify_encoded(&password, b"").unwrap_or(false) + }); + + if empty_hashed_password { + db.users.userid_password.insert(&userid, b"")?; + } + } + + db.globals.bump_database_version(2)?; + + println!("Migration: 1 -> 2 finished"); + } + + if db.globals.database_version()? < 3 { + // Move media to filesystem + for (key, content) in db.media.mediaid_file.iter() { + if content.len() == 0 { continue; } - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xff); - serverroomid.extend_from_slice(room_id); - db.rooms.serverroomids.insert(&serverroomid, &[])?; - } - - db.globals.bump_database_version(1)?; - - println!("Migration: 0 -> 1 finished"); - } - - if db.globals.database_version()? < 2 { - // We accidentally inserted hashed versions of "" into the db instead of just "" - for (userid, password) in db.users.userid_password.iter() { - let password = utils::string_from_bytes(&password); - - let empty_hashed_password = password.map_or(false, |password| { - argon2::verify_encoded(&password, b"").unwrap_or(false) - }); - - if empty_hashed_password { - db.users.userid_password.insert(&userid, b"")?; - } - } - - db.globals.bump_database_version(2)?; - - println!("Migration: 1 -> 2 finished"); - } - - if db.globals.database_version()? < 3 { - // Move media to filesystem - for (key, content) in db.media.mediaid_file.iter() { - if content.len() == 0 { - continue; + let path = db.globals.get_media_file(&key); + let mut file = fs::File::create(path)?; + file.write_all(&content)?; + db.media.mediaid_file.insert(&key, &[])?; } - let path = db.globals.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - db.media.mediaid_file.insert(&key, &[])?; + db.globals.bump_database_version(3)?; + + println!("Migration: 2 -> 3 finished"); } - db.globals.bump_database_version(3)?; - - println!("Migration: 2 -> 3 finished"); - } - - if db.globals.database_version()? < 4 { - // Add federated users to db as deactivated - for our_user in db.users.iter() { - let our_user = our_user?; - if db.users.is_deactivated(&our_user)? { - continue; - } - for room in db.rooms.rooms_joined(&our_user) { - for user in db.rooms.room_members(&room?) { - let user = user?; - if user.server_name() != db.globals.server_name() { - println!("Migration: Creating user {}", user); - db.users.create(&user, None)?; + if db.globals.database_version()? < 4 { + // Add federated users to db as deactivated + for our_user in db.users.iter() { + let our_user = our_user?; + if db.users.is_deactivated(&our_user)? { + continue; + } + for room in db.rooms.rooms_joined(&our_user) { + for user in db.rooms.room_members(&room?) { + let user = user?; + if user.server_name() != db.globals.server_name() { + println!("Migration: Creating user {}", user); + db.users.create(&user, None)?; + } } } } + + db.globals.bump_database_version(4)?; + + println!("Migration: 3 -> 4 finished"); } - - db.globals.bump_database_version(4)?; - - println!("Migration: 3 -> 4 finished"); } - // This data is probably outdated - db.rooms.edus.presenceid_presence.clear()?; + let guard = db.read().await; - db.admin.start_handler(Arc::clone(&db), admin_receiver); - db.sending.start_handler(Arc::clone(&db), sending_receiver); + // This data is probably outdated + guard.rooms.edus.presenceid_presence.clear()?; + + guard.admin.start_handler(Arc::clone(&db), admin_receiver); + guard + .sending + .start_handler(Arc::clone(&db), sending_receiver); + + drop(guard); + + #[cfg(feature = "sqlite")] + Self::start_wal_clean_task(&db, &config).await; Ok(db) } @@ -413,13 +509,113 @@ impl Database { .watch_prefix(&userid_bytes), ); + futures.push(Box::pin(self.globals.rotate.watch())); + // Wait until one of them finds something futures.next().await; } pub async fn flush(&self) -> Result<()> { - // noop while we don't use sled 1.0 - //self._db.flush_async().await?; - Ok(()) + let start = std::time::Instant::now(); + + let res = self._db.flush(); + + log::debug!("flush: took {:?}", start.elapsed()); + + res + } + + #[cfg(feature = "sqlite")] + pub fn flush_wal(&self) -> Result<()> { + self._db.flush_wal() + } + + #[cfg(feature = "sqlite")] + pub async fn start_wal_clean_task(lock: &Arc>, config: &Config) { + use tokio::{ + select, + signal::unix::{signal, SignalKind}, + time::{interval, timeout}, + }; + + use std::{ + sync::Weak, + time::{Duration, Instant}, + }; + + let weak: Weak> = Arc::downgrade(&lock); + + let lock_timeout = Duration::from_secs(config.sqlite_wal_clean_second_timeout as u64); + let timer_interval = Duration::from_secs(config.sqlite_wal_clean_second_interval as u64); + let do_timer = config.sqlite_wal_clean_timer; + + tokio::spawn(async move { + let mut i = interval(timer_interval); + let mut s = signal(SignalKind::hangup()).unwrap(); + + loop { + select! { + _ = i.tick(), if do_timer => { + log::info!(target: "wal-trunc", "Timer ticked") + } + _ = s.recv() => { + log::info!(target: "wal-trunc", "Received SIGHUP") + } + }; + + if let Some(arc) = Weak::upgrade(&weak) { + log::info!(target: "wal-trunc", "Rotating sync helpers..."); + // This actually creates a very small race condition between firing this and trying to acquire the subsequent write lock. + // Though it is not a huge deal if the write lock doesn't "catch", as it'll harmlessly time out. + arc.read().await.globals.rotate.fire(); + + log::info!(target: "wal-trunc", "Locking..."); + let guard = { + if let Ok(guard) = timeout(lock_timeout, arc.write()).await { + guard + } else { + log::info!(target: "wal-trunc", "Lock failed in timeout, canceled."); + continue; + } + }; + log::info!(target: "wal-trunc", "Locked, flushing..."); + let start = Instant::now(); + if let Err(e) = guard.flush_wal() { + log::error!(target: "wal-trunc", "Errored: {}", e); + } else { + log::info!(target: "wal-trunc", "Flushed in {:?}", start.elapsed()); + } + } else { + break; + } + } + }); + } +} + +pub struct DatabaseGuard(OwnedRwLockReadGuard); + +impl Deref for DatabaseGuard { + type Target = OwnedRwLockReadGuard; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for DatabaseGuard { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome { + let db = try_outcome!(req.guard::>>>().await); + + Ok(DatabaseGuard(Arc::clone(&db).read_owned().await)).or_forward(()) + } +} + +impl Into for OwnedRwLockReadGuard { + fn into(self) -> DatabaseGuard { + DatabaseGuard(self) } } diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index f81c9def..fb11ba0b 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -1,28 +1,21 @@ use super::Config; -use crate::{utils, Result}; -use log::warn; +use crate::Result; + use std::{future::Future, pin::Pin, sync::Arc}; #[cfg(feature = "rocksdb")] -use std::{collections::BTreeMap, sync::RwLock}; +pub mod rocksdb; #[cfg(feature = "sled")] -pub struct SledEngine(sled::Db); -#[cfg(feature = "sled")] -pub struct SledEngineTree(sled::Tree); +pub mod sled; -#[cfg(feature = "rocksdb")] -pub struct RocksDbEngine(rocksdb::DBWithThreadMode); -#[cfg(feature = "rocksdb")] -pub struct RocksDbEngineTree<'a> { - db: Arc, - name: &'a str, - watchers: RwLock, Vec>>>, -} +#[cfg(feature = "sqlite")] +pub mod sqlite; pub trait DatabaseEngine: Sized { fn open(config: &Config) -> Result>; fn open_tree(self: &Arc, name: &'static str) -> Result>; + fn flush(self: &Arc) -> Result<()>; } pub trait Tree: Send + Sync { @@ -32,20 +25,20 @@ pub trait Tree: Send + Sync { fn remove(&self, key: &[u8]) -> Result<()>; - fn iter<'a>(&'a self) -> Box, Box<[u8]>)> + Send + Sync + 'a>; + fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a>; fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, - ) -> Box, Box<[u8]>)> + 'a>; + ) -> Box, Vec)> + Send + 'a>; fn increment(&self, key: &[u8]) -> Result>; fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Box<[u8]>)> + Send + 'a>; + ) -> Box, Vec)> + Send + 'a>; fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; @@ -57,273 +50,3 @@ pub trait Tree: Send + Sync { Ok(()) } } - -#[cfg(feature = "sled")] -impl DatabaseEngine for SledEngine { - fn open(config: &Config) -> Result> { - Ok(Arc::new(SledEngine( - sled::Config::default() - .path(&config.database_path) - .cache_capacity(config.cache_capacity as u64) - .use_compression(true) - .open()?, - ))) - } - - fn open_tree(self: &Arc, name: &'static str) -> Result> { - Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?))) - } -} - -#[cfg(feature = "sled")] -impl Tree for SledEngineTree { - fn get(&self, key: &[u8]) -> Result>> { - Ok(self.0.get(key)?.map(|v| v.to_vec())) - } - - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - self.0.insert(key, value)?; - Ok(()) - } - - fn remove(&self, key: &[u8]) -> Result<()> { - self.0.remove(key)?; - Ok(()) - } - - fn iter<'a>(&'a self) -> Box, Box<[u8]>)> + Send + Sync + 'a> { - Box::new( - self.0 - .iter() - .filter_map(|r| { - if let Err(e) = &r { - warn!("Error: {}", e); - } - r.ok() - }) - .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())), - ) - } - - fn iter_from( - &self, - from: &[u8], - backwards: bool, - ) -> Box, Box<[u8]>)>> { - let iter = if backwards { - self.0.range(..from) - } else { - self.0.range(from..) - }; - - let iter = iter - .filter_map(|r| { - if let Err(e) = &r { - warn!("Error: {}", e); - } - r.ok() - }) - .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); - - if backwards { - Box::new(iter.rev()) - } else { - Box::new(iter) - } - } - - fn increment(&self, key: &[u8]) -> Result> { - Ok(self - .0 - .update_and_fetch(key, utils::increment) - .map(|o| o.expect("increment always sets a value").to_vec())?) - } - - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box, Box<[u8]>)> + Send + 'a> { - let iter = self - .0 - .scan_prefix(prefix) - .filter_map(|r| { - if let Err(e) = &r { - warn!("Error: {}", e); - } - r.ok() - }) - .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); - - Box::new(iter) - } - - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { - let prefix = prefix.to_vec(); - Box::pin(async move { - self.0.watch_prefix(prefix).await; - }) - } -} - -#[cfg(feature = "rocksdb")] -impl DatabaseEngine for RocksDbEngine { - fn open(config: &Config) -> Result> { - let mut db_opts = rocksdb::Options::default(); - db_opts.create_if_missing(true); - db_opts.set_max_open_files(16); - db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); - db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); - db_opts.set_target_file_size_base(256 << 20); - db_opts.set_write_buffer_size(256 << 20); - - let mut block_based_options = rocksdb::BlockBasedOptions::default(); - block_based_options.set_block_size(512 << 10); - db_opts.set_block_based_table_factory(&block_based_options); - - let cfs = rocksdb::DBWithThreadMode::::list_cf( - &db_opts, - &config.database_path, - ) - .unwrap_or_default(); - - let mut options = rocksdb::Options::default(); - options.set_merge_operator_associative("increment", utils::increment_rocksdb); - - let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( - &db_opts, - &config.database_path, - cfs.iter() - .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), - )?; - - Ok(Arc::new(RocksDbEngine(db))) - } - - fn open_tree(self: &Arc, name: &'static str) -> Result> { - let mut options = rocksdb::Options::default(); - options.set_merge_operator_associative("increment", utils::increment_rocksdb); - - // Create if it doesn't exist - let _ = self.0.create_cf(name, &options); - - Ok(Arc::new(RocksDbEngineTree { - name, - db: Arc::clone(self), - watchers: RwLock::new(BTreeMap::new()), - })) - } -} - -#[cfg(feature = "rocksdb")] -impl RocksDbEngineTree<'_> { - fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { - self.db.0.cf_handle(self.name).unwrap() - } -} - -#[cfg(feature = "rocksdb")] -impl Tree for RocksDbEngineTree<'_> { - fn get(&self, key: &[u8]) -> Result>> { - Ok(self.db.0.get_cf(self.cf(), key)?) - } - - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - let watchers = self.watchers.read().unwrap(); - let mut triggered = Vec::new(); - - for length in 0..=key.len() { - if watchers.contains_key(&key[..length]) { - triggered.push(&key[..length]); - } - } - - drop(watchers); - - if !triggered.is_empty() { - let mut watchers = self.watchers.write().unwrap(); - for prefix in triggered { - if let Some(txs) = watchers.remove(prefix) { - for tx in txs { - let _ = tx.send(()); - } - } - } - } - - Ok(self.db.0.put_cf(self.cf(), key, value)?) - } - - fn remove(&self, key: &[u8]) -> Result<()> { - Ok(self.db.0.delete_cf(self.cf(), key)?) - } - - fn iter<'a>(&'a self) -> Box, Box<[u8]>)> + Send + Sync + 'a> { - Box::new( - self.db - .0 - .iterator_cf(self.cf(), rocksdb::IteratorMode::Start), - ) - } - - fn iter_from<'a>( - &'a self, - from: &[u8], - backwards: bool, - ) -> Box, Box<[u8]>)> + 'a> { - Box::new(self.db.0.iterator_cf( - self.cf(), - rocksdb::IteratorMode::From( - from, - if backwards { - rocksdb::Direction::Reverse - } else { - rocksdb::Direction::Forward - }, - ), - )) - } - - fn increment(&self, key: &[u8]) -> Result> { - let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); - dbg!(stats.mem_table_total); - dbg!(stats.mem_table_unflushed); - dbg!(stats.mem_table_readers_total); - dbg!(stats.cache_total); - // TODO: atomic? - let old = self.get(key)?; - let new = utils::increment(old.as_deref()).unwrap(); - self.insert(key, &new)?; - Ok(new) - } - - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box, Box<[u8]>)> + Send + 'a> { - Box::new( - self.db - .0 - .iterator_cf( - self.cf(), - rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), - ) - .take_while(move |(k, _)| k.starts_with(&prefix)), - ) - } - - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { - let (tx, rx) = tokio::sync::oneshot::channel(); - - self.watchers - .write() - .unwrap() - .entry(prefix.to_vec()) - .or_default() - .push(tx); - - Box::pin(async move { - // Tx is never destroyed - rx.await.unwrap(); - }) - } -} diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs new file mode 100644 index 00000000..b9961302 --- /dev/null +++ b/src/database/abstraction/rocksdb.rs @@ -0,0 +1,176 @@ +use super::super::Config; +use crate::{utils, Result}; + +use std::{future::Future, pin::Pin, sync::Arc}; + +use super::{DatabaseEngine, Tree}; + +use std::{collections::BTreeMap, sync::RwLock}; + +pub struct Engine(rocksdb::DBWithThreadMode); + +pub struct RocksDbEngineTree<'a> { + db: Arc, + name: &'a str, + watchers: RwLock, Vec>>>, +} + +impl DatabaseEngine for Engine { + fn open(config: &Config) -> Result> { + let mut db_opts = rocksdb::Options::default(); + db_opts.create_if_missing(true); + db_opts.set_max_open_files(16); + db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); + db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); + db_opts.set_target_file_size_base(256 << 20); + db_opts.set_write_buffer_size(256 << 20); + + let mut block_based_options = rocksdb::BlockBasedOptions::default(); + block_based_options.set_block_size(512 << 10); + db_opts.set_block_based_table_factory(&block_based_options); + + let cfs = rocksdb::DBWithThreadMode::::list_cf( + &db_opts, + &config.database_path, + ) + .unwrap_or_default(); + + let mut options = rocksdb::Options::default(); + options.set_merge_operator_associative("increment", utils::increment_rocksdb); + + let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( + &db_opts, + &config.database_path, + cfs.iter() + .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), + )?; + + Ok(Arc::new(Engine(db))) + } + + fn open_tree(self: &Arc, name: &'static str) -> Result> { + let mut options = rocksdb::Options::default(); + options.set_merge_operator_associative("increment", utils::increment_rocksdb); + + // Create if it doesn't exist + let _ = self.0.create_cf(name, &options); + + Ok(Arc::new(RocksDbEngineTree { + name, + db: Arc::clone(self), + watchers: RwLock::new(BTreeMap::new()), + })) + } +} + +impl RocksDbEngineTree<'_> { + fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { + self.db.0.cf_handle(self.name).unwrap() + } +} + +impl Tree for RocksDbEngineTree<'_> { + fn get(&self, key: &[u8]) -> Result>> { + Ok(self.db.0.get_cf(self.cf(), key)?) + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + let watchers = self.watchers.read().unwrap(); + let mut triggered = Vec::new(); + + for length in 0..=key.len() { + if watchers.contains_key(&key[..length]) { + triggered.push(&key[..length]); + } + } + + drop(watchers); + + if !triggered.is_empty() { + let mut watchers = self.watchers.write().unwrap(); + for prefix in triggered { + if let Some(txs) = watchers.remove(prefix) { + for tx in txs { + let _ = tx.send(()); + } + } + } + } + + Ok(self.db.0.put_cf(self.cf(), key, value)?) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + Ok(self.db.0.delete_cf(self.cf(), key)?) + } + + fn iter<'a>(&'a self) -> Box, Vec)> + Send + Sync + 'a> { + Box::new( + self.db + .0 + .iterator_cf(self.cf(), rocksdb::IteratorMode::Start), + ) + } + + fn iter_from<'a>( + &'a self, + from: &[u8], + backwards: bool, + ) -> Box, Vec)> + 'a> { + Box::new(self.db.0.iterator_cf( + self.cf(), + rocksdb::IteratorMode::From( + from, + if backwards { + rocksdb::Direction::Reverse + } else { + rocksdb::Direction::Forward + }, + ), + )) + } + + fn increment(&self, key: &[u8]) -> Result> { + let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); + dbg!(stats.mem_table_total); + dbg!(stats.mem_table_unflushed); + dbg!(stats.mem_table_readers_total); + dbg!(stats.cache_total); + // TODO: atomic? + let old = self.get(key)?; + let new = utils::increment(old.as_deref()).unwrap(); + self.insert(key, &new)?; + Ok(new) + } + + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box, Vec)> + Send + 'a> { + Box::new( + self.db + .0 + .iterator_cf( + self.cf(), + rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), + ) + .take_while(move |(k, _)| k.starts_with(&prefix)), + ) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.watchers + .write() + .unwrap() + .entry(prefix.to_vec()) + .or_default() + .push(tx); + + Box::pin(async move { + // Tx is never destroyed + rx.await.unwrap(); + }) + } +} diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs new file mode 100644 index 00000000..271be1e9 --- /dev/null +++ b/src/database/abstraction/sled.rs @@ -0,0 +1,119 @@ +use super::super::Config; +use crate::{utils, Result}; +use log::warn; +use std::{future::Future, pin::Pin, sync::Arc}; + +use super::{DatabaseEngine, Tree}; + +pub struct Engine(sled::Db); + +pub struct SledEngineTree(sled::Tree); + +impl DatabaseEngine for Engine { + fn open(config: &Config) -> Result> { + Ok(Arc::new(Engine( + sled::Config::default() + .path(&config.database_path) + .cache_capacity((config.db_cache_capacity_mb * 1024 * 1024) as u64) + .use_compression(true) + .open()?, + ))) + } + + fn open_tree(self: &Arc, name: &'static str) -> Result> { + Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?))) + } + + fn flush(self: &Arc) -> Result<()> { + Ok(()) // noop + } +} + +impl Tree for SledEngineTree { + fn get(&self, key: &[u8]) -> Result>> { + Ok(self.0.get(key)?.map(|v| v.to_vec())) + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + self.0.insert(key, value)?; + Ok(()) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + self.0.remove(key)?; + Ok(()) + } + + fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a> { + Box::new( + self.0 + .iter() + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())), + ) + } + + fn iter_from( + &self, + from: &[u8], + backwards: bool, + ) -> Box, Vec)> + Send> { + let iter = if backwards { + self.0.range(..from) + } else { + self.0.range(from..) + }; + + let iter = iter + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); + + if backwards { + Box::new(iter.rev()) + } else { + Box::new(iter) + } + } + + fn increment(&self, key: &[u8]) -> Result> { + Ok(self + .0 + .update_and_fetch(key, utils::increment) + .map(|o| o.expect("increment always sets a value").to_vec())?) + } + + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box, Vec)> + Send + 'a> { + let iter = self + .0 + .scan_prefix(prefix) + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); + + Box::new(iter) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + let prefix = prefix.to_vec(); + Box::pin(async move { + self.0.watch_prefix(prefix).await; + }) + } +} diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs new file mode 100644 index 00000000..22a55593 --- /dev/null +++ b/src/database/abstraction/sqlite.rs @@ -0,0 +1,444 @@ +use std::{ + collections::BTreeMap, + future::Future, + ops::Deref, + path::{Path, PathBuf}, + pin::Pin, + sync::Arc, + thread, + time::{Duration, Instant}, +}; + +use crate::{database::Config, Result}; + +use super::{DatabaseEngine, Tree}; + +use log::debug; + +use crossbeam::channel::{bounded, Sender as ChannelSender}; +use parking_lot::{Mutex, MutexGuard, RwLock}; +use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension}; + +use tokio::sync::oneshot::Sender; + +// const SQL_CREATE_TABLE: &str = +// "CREATE TABLE IF NOT EXISTS {} {{ \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL }}"; +// const SQL_SELECT: &str = "SELECT value FROM {} WHERE key = ?"; +// const SQL_INSERT: &str = "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)"; +// const SQL_DELETE: &str = "DELETE FROM {} WHERE key = ?"; +// const SQL_SELECT_ITER: &str = "SELECT key, value FROM {}"; +// const SQL_SELECT_PREFIX: &str = "SELECT key, value FROM {} WHERE key LIKE ?||'%' ORDER BY key ASC"; +// const SQL_SELECT_ITER_FROM_FORWARDS: &str = "SELECT key, value FROM {} WHERE key >= ? ORDER BY ASC"; +// const SQL_SELECT_ITER_FROM_BACKWARDS: &str = +// "SELECT key, value FROM {} WHERE key <= ? ORDER BY DESC"; + +struct Pool { + writer: Mutex, + readers: Vec>, + spill_tracker: Arc<()>, + path: PathBuf, +} + +pub const MILLI: Duration = Duration::from_millis(1); + +enum HoldingConn<'a> { + FromGuard(MutexGuard<'a, Connection>), + FromOwned(Connection, Arc<()>), +} + +impl<'a> Deref for HoldingConn<'a> { + type Target = Connection; + + fn deref(&self) -> &Self::Target { + match self { + HoldingConn::FromGuard(guard) => guard.deref(), + HoldingConn::FromOwned(conn, _) => conn, + } + } +} + +impl Pool { + fn new>(path: P, num_readers: usize, total_cache_size_mb: f64) -> Result { + // calculates cache-size per permanent connection + // 1. convert MB to KiB + // 2. divide by permanent connections + // 3. round down to nearest integer + let cache_size: u32 = ((total_cache_size_mb * 1024.0) / (num_readers + 1) as f64) as u32; + + let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); + + let mut readers = Vec::new(); + + for _ in 0..num_readers { + readers.push(Mutex::new(Self::prepare_conn(&path, Some(cache_size))?)) + } + + Ok(Self { + writer, + readers, + spill_tracker: Arc::new(()), + path: path.as_ref().to_path_buf(), + }) + } + + fn prepare_conn>(path: P, cache_size: Option) -> Result { + let conn = Connection::open(path)?; + + conn.pragma_update(Some(Main), "journal_mode", &"WAL".to_owned())?; + + // conn.pragma_update(Some(Main), "wal_autocheckpoint", &250)?; + + // conn.pragma_update(Some(Main), "wal_checkpoint", &"FULL".to_owned())?; + + conn.pragma_update(Some(Main), "synchronous", &"OFF".to_owned())?; + + if let Some(cache_kib) = cache_size { + conn.pragma_update(Some(Main), "cache_size", &(-Into::::into(cache_kib)))?; + } + + Ok(conn) + } + + fn write_lock(&self) -> MutexGuard<'_, Connection> { + self.writer.lock() + } + + fn read_lock(&self) -> HoldingConn<'_> { + for r in &self.readers { + if let Some(reader) = r.try_lock() { + return HoldingConn::FromGuard(reader); + } + } + + let spill_arc = self.spill_tracker.clone(); + let now_count = Arc::strong_count(&spill_arc) - 1 /* because one is held by the pool */; + + log::warn!("read_lock: all readers locked, creating spillover reader..."); + + if now_count > 1 { + log::warn!("read_lock: now {} spillover readers exist", now_count); + } + + let spilled = Self::prepare_conn(&self.path, None).unwrap(); + + return HoldingConn::FromOwned(spilled, spill_arc); + } +} + +pub struct Engine { + pool: Pool, +} + +impl DatabaseEngine for Engine { + fn open(config: &Config) -> Result> { + let pool = Pool::new( + Path::new(&config.database_path).join("conduit.db"), + config.sqlite_read_pool_size, + config.db_cache_capacity_mb, + )?; + + pool.write_lock() + .execute("CREATE TABLE IF NOT EXISTS _noop (\"key\" INT)", params![])?; + + let arc = Arc::new(Engine { pool }); + + Ok(arc) + } + + fn open_tree(self: &Arc, name: &str) -> Result> { + self.pool.write_lock().execute(format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name).as_str(), [])?; + + Ok(Arc::new(SqliteTable { + engine: Arc::clone(self), + name: name.to_owned(), + watchers: RwLock::new(BTreeMap::new()), + })) + } + + fn flush(self: &Arc) -> Result<()> { + self.pool + .write_lock() + .execute_batch( + " + PRAGMA synchronous=FULL; + BEGIN; + DELETE FROM _noop; + INSERT INTO _noop VALUES (1); + COMMIT; + PRAGMA synchronous=OFF; + ", + ) + .map_err(Into::into) + } +} + +impl Engine { + pub fn flush_wal(self: &Arc) -> Result<()> { + self.pool + .write_lock() + .execute_batch( + " + PRAGMA synchronous=FULL; PRAGMA wal_checkpoint=TRUNCATE; + BEGIN; + DELETE FROM _noop; + INSERT INTO _noop VALUES (1); + COMMIT; + PRAGMA wal_checkpoint=PASSIVE; PRAGMA synchronous=OFF; + ", + ) + .map_err(Into::into) + } +} + +pub struct SqliteTable { + engine: Arc, + name: String, + watchers: RwLock, Vec>>>, +} + +type TupleOfBytes = (Vec, Vec); + +impl SqliteTable { + fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { + Ok(guard + .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? + .query_row([key], |row| row.get(0)) + .optional()?) + } + + fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { + guard.execute( + format!( + "INSERT INTO {} (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + self.name + ) + .as_str(), + [key, value], + )?; + Ok(()) + } + + fn _iter_from_thread(&self, f: F) -> Box + Send> + where + F: (for<'a> FnOnce(&'a Connection, ChannelSender)) + Send + 'static, + { + let (s, r) = bounded::(5); + + let engine = self.engine.clone(); + + thread::spawn(move || { + let _ = f(&engine.pool.read_lock(), s); + }); + + Box::new(r.into_iter()) + } +} + +macro_rules! iter_from_thread { + ($self:expr, $sql:expr, $param:expr) => { + $self._iter_from_thread(move |guard, s| { + let _ = guard + .prepare($sql) + .unwrap() + .query_map($param, |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()) + .try_for_each(|bob| s.send(bob)); + }) + }; +} + +impl Tree for SqliteTable { + fn get(&self, key: &[u8]) -> Result>> { + let guard = self.engine.pool.read_lock(); + + // let start = Instant::now(); + + let val = self.get_with_guard(&guard, key); + + // debug!("get: took {:?}", start.elapsed()); + // debug!("get key: {:?}", &key) + + val + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + let guard = self.engine.pool.write_lock(); + + let start = Instant::now(); + + self.insert_with_guard(&guard, key, value)?; + + let elapsed = start.elapsed(); + if elapsed > MILLI { + debug!("insert: took {:012?} : {}", elapsed, &self.name); + } + + drop(guard); + + let watchers = self.watchers.read(); + let mut triggered = Vec::new(); + + for length in 0..=key.len() { + if watchers.contains_key(&key[..length]) { + triggered.push(&key[..length]); + } + } + + drop(watchers); + + if !triggered.is_empty() { + let mut watchers = self.watchers.write(); + for prefix in triggered { + if let Some(txs) = watchers.remove(prefix) { + for tx in txs { + let _ = tx.send(()); + } + } + } + }; + + Ok(()) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + let guard = self.engine.pool.write_lock(); + + let start = Instant::now(); + + guard.execute( + format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), + [key], + )?; + + let elapsed = start.elapsed(); + + if elapsed > MILLI { + debug!("remove: took {:012?} : {}", elapsed, &self.name); + } + // debug!("remove key: {:?}", &key); + + Ok(()) + } + + fn iter<'a>(&'a self) -> Box + Send + 'a> { + let name = self.name.clone(); + iter_from_thread!( + self, + format!("SELECT key, value FROM {}", name).as_str(), + params![] + ) + } + + fn iter_from<'a>( + &'a self, + from: &[u8], + backwards: bool, + ) -> Box + Send + 'a> { + let name = self.name.clone(); + let from = from.to_vec(); // TODO change interface? + if backwards { + iter_from_thread!( + self, + format!( + "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", + name + ) + .as_str(), + [from] + ) + } else { + iter_from_thread!( + self, + format!( + "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", + name + ) + .as_str(), + [from] + ) + } + } + + fn increment(&self, key: &[u8]) -> Result> { + let guard = self.engine.pool.write_lock(); + + let start = Instant::now(); + + let old = self.get_with_guard(&guard, key)?; + + let new = + crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some"); + + self.insert_with_guard(&guard, key, &new)?; + + let elapsed = start.elapsed(); + + if elapsed > MILLI { + debug!("increment: took {:012?} : {}", elapsed, &self.name); + } + // debug!("increment key: {:?}", &key); + + Ok(new) + } + + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box + Send + 'a> { + // let name = self.name.clone(); + // iter_from_thread!( + // self, + // format!( + // "SELECT key, value FROM {} WHERE key BETWEEN ?1 AND ?1 || X'FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF' ORDER BY key ASC", + // name + // ) + // .as_str(), + // [prefix] + // ) + Box::new( + self.iter_from(&prefix, false) + .take_while(move |(key, _)| key.starts_with(&prefix)), + ) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.watchers + .write() + .entry(prefix.to_vec()) + .or_default() + .push(tx); + + Box::pin(async move { + // Tx is never destroyed + rx.await.unwrap(); + }) + } + + fn clear(&self) -> Result<()> { + debug!("clear: running"); + self.engine + .pool + .write_lock() + .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; + debug!("clear: ran"); + Ok(()) + } +} + +// TODO +// struct Pool { +// writer: Mutex, +// readers: [Mutex; NUM_READERS], +// } + +// // then, to pick a reader: +// for r in &pool.readers { +// if let Ok(reader) = r.try_lock() { +// // use reader +// } +// } +// // none unlocked, pick the next reader +// pool.readers[pool.counter.fetch_add(1, Relaxed) % NUM_READERS].lock() diff --git a/src/database/account_data.rs b/src/database/account_data.rs index 2ba7bc3d..b1d5b6b5 100644 --- a/src/database/account_data.rs +++ b/src/database/account_data.rs @@ -127,7 +127,7 @@ impl AccountData { room_id: Option<&RoomId>, user_id: &UserId, kind: &EventType, - ) -> Result, Box<[u8]>)>> { + ) -> Result, Vec)>> { let mut prefix = room_id .map(|r| r.to_string()) .unwrap_or_default() diff --git a/src/database/admin.rs b/src/database/admin.rs index 7826cfea..cd5fa847 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -10,6 +10,7 @@ use ruma::{ events::{room::message, EventType}, UserId, }; +use tokio::sync::{RwLock, RwLockReadGuard}; pub enum AdminCommand { RegisterAppservice(serde_yaml::Value), @@ -25,20 +26,23 @@ pub struct Admin { impl Admin { pub fn start_handler( &self, - db: Arc, + db: Arc>, mut receiver: mpsc::UnboundedReceiver, ) { tokio::spawn(async move { // TODO: Use futures when we have long admin commands //let mut futures = FuturesUnordered::new(); - let conduit_user = UserId::try_from(format!("@conduit:{}", db.globals.server_name())) - .expect("@conduit:server_name is valid"); + let guard = db.read().await; - let conduit_room = db + let conduit_user = + UserId::try_from(format!("@conduit:{}", guard.globals.server_name())) + .expect("@conduit:server_name is valid"); + + let conduit_room = guard .rooms .id_from_alias( - &format!("#admins:{}", db.globals.server_name()) + &format!("#admins:{}", guard.globals.server_name()) .try_into() .expect("#admins:server_name is a valid room alias"), ) @@ -48,48 +52,54 @@ impl Admin { warn!("Conduit instance does not have an #admins room. Logging to that room will not work. Restart Conduit after creating a user to fix this."); } - let send_message = |message: message::MessageEventContent| { - if let Some(conduit_room) = &conduit_room { - db.rooms - .build_and_append_pdu( - PduBuilder { - event_type: EventType::RoomMessage, - content: serde_json::to_value(message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &db, - ) - .unwrap(); - } - }; + drop(guard); + + let send_message = + |message: message::MessageEventContent, guard: RwLockReadGuard<'_, Database>| { + if let Some(conduit_room) = &conduit_room { + guard + .rooms + .build_and_append_pdu( + PduBuilder { + event_type: EventType::RoomMessage, + content: serde_json::to_value(message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + &guard, + ) + .unwrap(); + } + }; loop { tokio::select! { Some(event) = receiver.next() => { + let guard = db.read().await; + match event { AdminCommand::RegisterAppservice(yaml) => { - db.appservice.register_appservice(yaml).unwrap(); // TODO handle error + guard.appservice.register_appservice(yaml).unwrap(); // TODO handle error } AdminCommand::ListAppservices => { - if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::>()) { + if let Ok(appservices) = guard.appservice.iter_ids().map(|ids| ids.collect::>()) { let count = appservices.len(); let output = format!( "Appservices ({}): {}", count, appservices.into_iter().filter_map(|r| r.ok()).collect::>().join(", ") ); - send_message(message::MessageEventContent::text_plain(output)); + send_message(message::MessageEventContent::text_plain(output), guard); } else { - send_message(message::MessageEventContent::text_plain("Failed to get appservices.")); + send_message(message::MessageEventContent::text_plain("Failed to get appservices."), guard); } } AdminCommand::SendMessage(message) => { - send_message(message); + send_message(message, guard) } } } diff --git a/src/database/appservice.rs b/src/database/appservice.rs index 4bf3a218..f39520c7 100644 --- a/src/database/appservice.rs +++ b/src/database/appservice.rs @@ -49,7 +49,7 @@ impl Appservice { ) } - pub fn iter_ids(&self) -> Result> + Send + Sync + '_> { + pub fn iter_ids(&self) -> Result> + Send + '_> { Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) @@ -58,7 +58,7 @@ impl Appservice { pub fn iter_all( &self, - ) -> Result> + '_ + Send + Sync> { + ) -> Result> + '_ + Send> { Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| { Ok(( id.clone(), diff --git a/src/database/globals.rs b/src/database/globals.rs index eef478a1..4242cf5c 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -11,11 +11,12 @@ use rustls::{ServerCertVerifier, WebPKIVerifier}; use std::{ collections::{BTreeMap, HashMap}, fs, + future::Future, path::PathBuf, sync::{Arc, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::Semaphore; +use tokio::sync::{broadcast, Semaphore}; use trust_dns_resolver::TokioAsyncResolver; use super::abstraction::Tree; @@ -47,6 +48,7 @@ pub struct Globals { ), // since, rx >, >, + pub rotate: RotationHandler, } struct MatrixServerVerifier { @@ -82,6 +84,31 @@ impl ServerCertVerifier for MatrixServerVerifier { } } +/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like. +/// +/// This is utilized to have sync workers return early and release read locks on the database. +pub struct RotationHandler(broadcast::Sender<()>, broadcast::Receiver<()>); + +impl RotationHandler { + pub fn new() -> Self { + let (s, r) = broadcast::channel::<()>(1); + + Self(s, r) + } + + pub fn watch(&self) -> impl Future { + let mut r = self.0.subscribe(); + + async move { + let _ = r.recv().await; + } + } + + pub fn fire(&self) { + let _ = self.0.send(()); + } +} + impl Globals { pub fn load( globals: Arc, @@ -168,6 +195,7 @@ impl Globals { bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), sync_receivers: RwLock::new(BTreeMap::new()), + rotate: RotationHandler::new(), }; fs::create_dir_all(s.get_media_folder())?; diff --git a/src/database/pusher.rs b/src/database/pusher.rs index a27bf2ce..3210cb18 100644 --- a/src/database/pusher.rs +++ b/src/database/pusher.rs @@ -73,7 +73,7 @@ impl PushData { pub fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, - ) -> impl Iterator> + 'a { + ) -> impl Iterator> + 'a { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index e23b8046..7b64c462 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -1078,13 +1078,13 @@ impl Rooms { .scan_prefix(old_shortstatehash.clone()) // Chop the old_shortstatehash out leaving behind the short state key .map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v)) - .collect::, Box<[u8]>>>() + .collect::, Vec>>() } else { HashMap::new() }; if let Some(state_key) = &new_pdu.state_key { - let mut new_state: HashMap, Box<[u8]>> = old_state; + let mut new_state: HashMap, Vec> = old_state; let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); new_state_key.push(0xff); @@ -1450,7 +1450,7 @@ impl Rooms { &'a self, user_id: &UserId, room_id: &RoomId, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { self.pdus_since(user_id, room_id, 0) } @@ -1462,7 +1462,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -1491,7 +1491,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { // Create the first part of the full pdu id let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -1523,7 +1523,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { // Create the first part of the full pdu id let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/sending.rs b/src/database/sending.rs index ecf07618..7c9cf644 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -30,7 +30,10 @@ use ruma::{ receipt::ReceiptType, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, }; -use tokio::{select, sync::Semaphore}; +use tokio::{ + select, + sync::{RwLock, Semaphore}, +}; use super::abstraction::Tree; @@ -90,7 +93,11 @@ enum TransactionStatus { } impl Sending { - pub fn start_handler(&self, db: Arc, mut receiver: mpsc::UnboundedReceiver>) { + pub fn start_handler( + &self, + db: Arc>, + mut receiver: mpsc::UnboundedReceiver>, + ) { tokio::spawn(async move { let mut futures = FuturesUnordered::new(); @@ -98,8 +105,12 @@ impl Sending { // Retry requests we could not finish yet let mut initial_transactions = HashMap::>::new(); + + let guard = db.read().await; + for (key, outgoing_kind, event) in - db.sending + guard + .sending .servercurrentevents .iter() .filter_map(|(key, _)| { @@ -117,17 +128,23 @@ impl Sending { "Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event ); - db.sending.servercurrentevents.remove(&key).unwrap(); + guard.sending.servercurrentevents.remove(&key).unwrap(); continue; } entry.push(event); } + drop(guard); + for (outgoing_kind, events) in initial_transactions { current_transaction_status .insert(outgoing_kind.get_prefix(), TransactionStatus::Running); - futures.push(Self::handle_events(outgoing_kind.clone(), events, &db)); + futures.push(Self::handle_events( + outgoing_kind.clone(), + events, + Arc::clone(&db), + )); } loop { @@ -135,15 +152,17 @@ impl Sending { Some(response) = futures.next() => { match response { Ok(outgoing_kind) => { + let guard = db.read().await; + let prefix = outgoing_kind.get_prefix(); - for (key, _) in db.sending.servercurrentevents + for (key, _) in guard.sending.servercurrentevents .scan_prefix(prefix.clone()) { - db.sending.servercurrentevents.remove(&key).unwrap(); + guard.sending.servercurrentevents.remove(&key).unwrap(); } // Find events that have been added since starting the last request - let new_events = db.sending.servernamepduids + let new_events = guard.sending.servernamepduids .scan_prefix(prefix.clone()) .map(|(k, _)| { SendingEventType::Pdu(k[prefix.len()..].to_vec()) @@ -161,17 +180,19 @@ impl Sending { SendingEventType::Pdu(b) | SendingEventType::Edu(b) => { current_key.extend_from_slice(&b); - db.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); - db.sending.servernamepduids.remove(¤t_key).unwrap(); + guard.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); + guard.sending.servernamepduids.remove(¤t_key).unwrap(); } } } + drop(guard); + futures.push( Self::handle_events( outgoing_kind.clone(), new_events, - &db, + Arc::clone(&db), ) ); } else { @@ -192,13 +213,15 @@ impl Sending { }, Some(key) = receiver.next() => { if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { + let guard = db.read().await; + if let Ok(Some(events)) = Self::select_events( &outgoing_kind, vec![(event, key)], &mut current_transaction_status, - &db + &guard ) { - futures.push(Self::handle_events(outgoing_kind, events, &db)); + futures.push(Self::handle_events(outgoing_kind, events, Arc::clone(&db))); } } } @@ -357,7 +380,7 @@ impl Sending { } #[tracing::instrument(skip(self))] - pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Box<[u8]>) -> Result<()> { + pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Vec) -> Result<()> { let mut key = b"$".to_vec(); key.extend_from_slice(&senderkey); key.push(0xff); @@ -403,8 +426,10 @@ impl Sending { async fn handle_events( kind: OutgoingKind, events: Vec, - db: &Database, + db: Arc>, ) -> std::result::Result { + let db = db.read().await; + match &kind { OutgoingKind::Appservice(server) => { let mut pdu_jsons = Vec::new(); @@ -543,7 +568,7 @@ impl Sending { &pusher, rules_for_user, &pdu, - db, + &db, ) .await .map(|_response| kind.clone()) diff --git a/src/error.rs b/src/error.rs index 1017fb16..f62bdee0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -36,6 +36,12 @@ pub enum Error { #[from] source: rocksdb::Error, }, + #[cfg(feature = "sqlite")] + #[error("There was a problem with the connection to the sqlite database: {source}")] + SqliteError { + #[from] + source: rusqlite::Error, + }, #[error("Could not generate an image.")] ImageError { #[from] diff --git a/src/main.rs b/src/main.rs index 99d45607..e0d2e3df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,10 +30,11 @@ use rocket::{ }, routes, Request, }; +use tokio::sync::RwLock; use tracing::span; use tracing_subscriber::{prelude::*, Registry}; -fn setup_rocket(config: Figment, data: Arc) -> rocket::Rocket { +fn setup_rocket(config: Figment, data: Arc>) -> rocket::Rocket { rocket::custom(config) .manage(data) .mount( @@ -193,13 +194,14 @@ async fn main() { ) .merge(Env::prefixed("CONDUIT_").global()); + std::env::set_var("RUST_LOG", "warn"); + let config = raw_config .extract::() .expect("It looks like your config is invalid. Please take a look at the error"); - let db = Database::load_or_create(config.clone()) - .await - .expect("config is valid"); + let mut _span: Option = None; + let mut _enter: Option> = None; if config.allow_jaeger { let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() @@ -209,18 +211,21 @@ async fn main() { let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); Registry::default().with(telemetry).try_init().unwrap(); - let root = span!(tracing::Level::INFO, "app_start", work_units = 2); - let _enter = root.enter(); - - let rocket = setup_rocket(raw_config, db); - rocket.launch().await.unwrap(); + _span = Some(span!(tracing::Level::INFO, "app_start", work_units = 2)); + _enter = Some(_span.as_ref().unwrap().enter()); } else { - std::env::set_var("RUST_LOG", config.log); + std::env::set_var("RUST_LOG", &config.log); tracing_subscriber::fmt::init(); - - let rocket = setup_rocket(raw_config, db); - rocket.launch().await.unwrap(); } + + config.warn_deprecated(); + + let db = Database::load_or_create(config) + .await + .expect("config is valid"); + + let rocket = setup_rocket(raw_config, db); + rocket.launch().await.unwrap(); } #[catch(404)] diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 8c22f79b..347406da 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,4 +1,4 @@ -use crate::Error; +use crate::{database::DatabaseGuard, Error}; use ruma::{ api::{client::r0::uiaa::UiaaResponse, OutgoingResponse}, identifiers::{DeviceId, UserId}, @@ -9,7 +9,7 @@ use std::ops::Deref; #[cfg(feature = "conduit_bin")] use { - crate::{server_server, Database}, + crate::server_server, log::{debug, warn}, rocket::{ data::{self, ByteUnit, Data, FromData}, @@ -17,13 +17,12 @@ use { outcome::Outcome::*, response::{self, Responder}, tokio::io::AsyncReadExt, - Request, State, + Request, }, ruma::api::{AuthScheme, IncomingRequest}, std::collections::BTreeMap, std::convert::TryFrom, std::io::Cursor, - std::sync::Arc, }; /// This struct converts rocket requests into ruma structs by converting them into http requests @@ -49,7 +48,7 @@ where async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome { let metadata = T::Incoming::METADATA; let db = request - .guard::>>() + .guard::() .await .expect("database was loaded"); diff --git a/src/server_server.rs b/src/server_server.rs index d00e3d67..25cdd99e 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,12 +1,13 @@ use crate::{ client_server::{self, claim_keys_helper, get_keys_helper}, + database::DatabaseGuard, utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, }; use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION, HOST}; use log::{debug, error, info, trace, warn}; use regex::Regex; -use rocket::{response::content::Json, State}; +use rocket::response::content::Json; use ruma::{ api::{ client::error::{Error as RumaError, ErrorKind}, @@ -432,7 +433,7 @@ pub async fn request_well_known( #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] #[tracing::instrument(skip(db))] pub fn get_server_version_route( - db: State<'_, Arc>, + db: DatabaseGuard, ) -> ConduitResult { if !db.globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); @@ -450,7 +451,7 @@ pub fn get_server_version_route( // Response type for this endpoint is Json because we need to calculate a signature for the response #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_route(db: State<'_, Arc>) -> Json { +pub fn get_server_keys_route(db: DatabaseGuard) -> Json { if !db.globals.allow_federation() { // TODO: Use proper types return Json("Federation is disabled.".to_owned()); @@ -497,7 +498,7 @@ pub fn get_server_keys_route(db: State<'_, Arc>) -> Json { #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_deprecated_route(db: State<'_, Arc>) -> Json { +pub fn get_server_keys_deprecated_route(db: DatabaseGuard) -> Json { get_server_keys_route(db) } @@ -507,7 +508,7 @@ pub fn get_server_keys_deprecated_route(db: State<'_, Arc>) -> Json>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -551,7 +552,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -595,7 +596,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_transaction_message_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -774,6 +775,8 @@ pub async fn send_transaction_message_route( } } + db.flush().await?; + Ok(send_transaction_message::v1::Response { pdus: resolved_map }.into()) } @@ -1673,7 +1676,7 @@ pub(crate) fn append_incoming_pdu( )] #[tracing::instrument(skip(db, body))] pub fn get_event_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1698,7 +1701,7 @@ pub fn get_event_route( )] #[tracing::instrument(skip(db, body))] pub fn get_missing_events_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1747,7 +1750,7 @@ pub fn get_missing_events_route( )] #[tracing::instrument(skip(db, body))] pub fn get_event_authorization_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1791,7 +1794,7 @@ pub fn get_event_authorization_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1854,7 +1857,7 @@ pub fn get_room_state_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_ids_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1906,7 +1909,7 @@ pub fn get_room_state_ids_route( )] #[tracing::instrument(skip(db, body))] pub fn create_join_event_template_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2075,7 +2078,7 @@ pub fn create_join_event_template_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_join_event_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2160,6 +2163,8 @@ pub async fn create_join_event_route( db.sending.send_pdu(&server, &pdu_id)?; } + db.flush().await?; + Ok(create_join_event::v2::Response { room_state: RoomState { auth_chain: auth_chain_ids @@ -2183,7 +2188,7 @@ pub async fn create_join_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_invite_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2276,6 +2281,8 @@ pub async fn create_invite_route( )?; } + db.flush().await?; + Ok(create_invite::v2::Response { event: PduEvent::convert_to_outgoing_federation_event(signed_event), } @@ -2288,7 +2295,7 @@ pub async fn create_invite_route( )] #[tracing::instrument(skip(db, body))] pub fn get_devices_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2328,7 +2335,7 @@ pub fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_information_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2356,7 +2363,7 @@ pub fn get_room_information_route( )] #[tracing::instrument(skip(db, body))] pub fn get_profile_information_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2389,8 +2396,8 @@ pub fn get_profile_information_route( post("/_matrix/federation/v1/user/keys/query", data = "") )] #[tracing::instrument(skip(db, body))] -pub fn get_keys_route( - db: State<'_, Arc>, +pub async fn get_keys_route( + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2404,6 +2411,8 @@ pub fn get_keys_route( &db, )?; + db.flush().await?; + Ok(get_keys::v1::Response { device_keys: result.device_keys, master_keys: result.master_keys, @@ -2418,7 +2427,7 @@ pub fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: State<'_, Arc>, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() {