Hot-Reloading Refactor

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-05-09 15:59:08 -07:00 committed by June 🍓🦴
parent ae1a4fd283
commit 6c1434c165
212 changed files with 5679 additions and 4206 deletions

View File

@ -1,115 +1,95 @@
[package]
# TODO: when can we rename to conduwuit?
name = "conduit"
#cargo-features = ["profile-rustflags"]
[workspace]
resolver = "2"
members = ["src/*"]
default-members = ["src/*"]
[workspace.package]
description = "a very cool fork of Conduit, a Matrix homeserver written in Rust"
license = "Apache-2.0"
authors = [
"strawberry <strawberry@puppygock.gay>",
"timokoesters <timo@koesters.xyz>",
]
version = "0.3.4"
edition = "2021"
# See also `rust-toolchain.toml`
rust-version = "1.77.0"
homepage = "https://conduwuit.puppyirl.gay/"
repository = "https://github.com/girlbossceo/conduwuit"
readme = "README.md"
version = "0.3.4"
edition = "2021"
# See also `rust-toolchain.toml`
rust-version = "1.77.0"
[dependencies]
# 1.1.17 seems broken on nix from a permission error?
libz-sys = "=1.1.16"
[workspace.dependencies.libz-sys]
version = "=1.1.16"
console-subscriber = { version = "0.2", optional = true }
[workspace.dependencies.sanitize-filename]
version = "0.5.0"
infer = { version = "0.15", default-features = false }
[workspace.dependencies.infer]
version = "0.15"
default-features = false
# for hot lib reload
hot-lib-reloader = { version = "^0.7", optional = true }
[workspace.dependencies.jsonwebtoken]
version = "9.3.0"
# Used for secure identifiers
rand = "0.8.5"
# Used for conduit::Error type
thiserror = "1.0.60"
# Used to encode server public key
base64 = "0.22.1"
# Used when hashing the state
ring = "0.17.8"
# Used to find matching events for appservices
regex = "1.10.4"
# Used to load forbidden room/user regex from config
serde_regex = "1.1.0"
# Used to make working with iterators easier, was already a transitive depdendency
itertools = "0.12.1"
# jwt jsonwebtokens
jsonwebtoken = "9.3.0"
# Used for ruma wrapper
serde_html_form = "0.2.6"
[workspace.dependencies.base64]
version = "0.22.1"
# used for TURN server authentication
hmac = "0.12.1"
sha-1 = "0.10.1"
[workspace.dependencies.hmac]
version = "0.12.1"
[workspace.dependencies.sha-1]
version = "0.10.1"
# used for checking if an IP is in specific subnets / CIDR ranges easier
ipaddress = "0.1.3"
[workspace.dependencies.ipaddress]
version = "0.1.3"
# to get the client IP address of requests
#axum-client-ip = "0.4.2"
[workspace.dependencies.rand]
version = "0.8.5"
# to parse user-friendly time durations in admin commands
cyborgtime = "2.1.1"
# all the web/HTTP dependencies
# Used for the http request / response body type for Ruma endpoints used with reqwest
bytes = "1.6.0"
http = "1.1.0"
http-body-util = "0.1.1"
[workspace.dependencies.bytes]
version = "1.6.0"
# used to replace the channels of the tokio runtime
loole = "0.3.0"
[workspace.dependencies.http-body-util]
version = "0.1.1"
# Validating urls in config, was already a transitive dependency
url = { version = "2.5.0", features = ["serde"] }
[workspace.dependencies.http]
version = "1.1.0"
async-trait = "0.1.80"
[workspace.dependencies.regex]
version = "1.10.4"
lru-cache = "0.1.2"
sanitize-filename = "0.5.0"
# standard date and time tools
[dependencies.chrono]
version = "0.4.38"
features = ["alloc"]
default-features = false
# Web framework
[dependencies.axum]
[workspace.dependencies.axum]
version = "0.7.5"
default-features = false
features = ["form", "http1", "http2", "json", "matched-path"]
features = [
"form",
"http1",
"http2",
"json",
"matched-path",
"tokio",
]
[dependencies.axum-extra]
[workspace.dependencies.axum-extra]
version = "0.9.3"
default-features = false
features = ["typed-header"]
[dependencies.axum-server]
[workspace.dependencies.axum-server]
version = "0.6.0"
features = ["tls-rustls"]
[dependencies.tower]
[workspace.dependencies.tower]
version = "0.4.13"
features = ["util"]
[dependencies.tower-http]
[workspace.dependencies.tower-http]
version = "0.5.2"
features = [
"add-extension",
@ -121,153 +101,170 @@ features = [
"catch-panic",
]
[dependencies.hyper]
version = "1.3.1"
features = ["server", "http1", "http2"]
[dependencies.hyper-util]
version = "0.1.3"
[dependencies.reqwest]
[workspace.dependencies.reqwest]
version = "0.12.4"
default-features = false
features = ["rustls-tls-native-roots", "socks", "hickory-dns"]
features = [
"rustls-tls-native-roots",
"socks",
"hickory-dns",
]
# all the serde stuff
# Used for pdu definition
[dependencies.serde]
[workspace.dependencies.serde]
version = "1.0.201"
features = ["rc"]
# Used for appservice registration files
[dependencies.serde_yaml]
version = "0.9.34"
# Used for ruma wrapper
[dependencies.serde_json]
[workspace.dependencies.serde_json]
version = "1.0.117"
features = ["raw_value"]
# Used for appservice registration files
[workspace.dependencies.serde_yaml]
version = "0.9.34"
# Used to load forbidden room/user regex from config
[workspace.dependencies.serde_regex]
version = "1.1.0"
# Used for ruma wrapper
[workspace.dependencies.serde_html_form]
version = "0.2.6"
# Used for password hashing
[dependencies.argon2]
[workspace.dependencies.argon2]
version = "0.5.3"
features = ["alloc", "rand"]
default-features = false
# Used to generate thumbnails for images
[dependencies.image]
[workspace.dependencies.image]
version = "0.25.1"
default-features = false
features = ["jpeg", "png", "gif", "webp"]
features = [
"jpeg",
"png",
"gif",
"webp",
]
# logging
[dependencies.log]
[workspace.dependencies.log]
version = "0.4.21"
default-features = false
[dependencies.tracing]
[workspace.dependencies.tracing]
version = "0.1.40"
default-features = false
[dependencies.tracing-subscriber]
[workspace.dependencies.tracing-subscriber]
version = "0.3.18"
features = ["env-filter"]
# optional SHA256 media keys feature
[dependencies.sha2]
version = "0.10.8"
optional = true
# optional opentelemetry, performance measurements, flamegraphs, etc for performance measurements and monitoring
[dependencies.opentelemetry]
version = "0.21.0"
optional = true
[dependencies.tracing-flame]
version = "0.2.0"
optional = true
[dependencies.tracing-opentelemetry]
version = "0.22.0"
optional = true
[dependencies.opentelemetry_sdk]
version = "0.21.2"
optional = true
features = ["rt-tokio"]
[dependencies.opentelemetry-jaeger]
version = "0.20.0"
optional = true
features = ["rt-tokio"]
# optional sentry metrics for crash/panic reporting
[dependencies.sentry]
version = "0.32.3"
optional = true
default-features = false
features = [
"backtrace",
"contexts",
"debug-images",
"panic",
"rustls",
"tower",
"tower-http",
"tracing",
"reqwest",
"log",
]
[dependencies.sentry-tracing]
version = "0.32.3"
optional = true
[dependencies.sentry-tower]
version = "0.32.3"
optional = true
# optional jemalloc usage
[dependencies.tikv-jemalloc-sys]
version = "0.5.4"
optional = true
default-features = false
features = ["stats", "unprefixed_malloc_on_supported_platforms"]
[dependencies.tikv-jemallocator]
version = "0.5.4"
optional = true
default-features = false
features = ["stats", "unprefixed_malloc_on_supported_platforms"]
[dependencies.tikv-jemalloc-ctl]
version = "0.5.4"
optional = true
default-features = false
features = ["use_std"]
# for URL previews
[dependencies.webpage]
[workspace.dependencies.webpage]
version = "2.0.1"
default-features = false
# to support multiple variations of setting a config option
[dependencies.either]
version = "1.11.0"
features = ["serde"]
# to listen on both HTTP and HTTPS if listening on TLS dierctly from conduwuit for complement or sytest
[dependencies.axum-server-dual-protocol]
version = "0.6"
optional = true
# used for conduit's CLI and admin room command parsing
[dependencies.clap]
[workspace.dependencies.clap]
version = "4.5.4"
default-features = false
features = ["std", "derive", "help", "usage", "error-context", "string"]
features = [
"std",
"derive",
"help",
"usage",
"error-context",
"string",
]
[dependencies.futures-util]
[workspace.dependencies.futures-util]
version = "0.3.30"
default-features = false
[workspace.dependencies.tokio]
version = "1.37.0"
features = [
"fs",
"net",
"macros",
"sync",
"signal",
"time",
"rt-multi-thread",
"io-util",
]
[workspace.dependencies.libloading]
version = "0.8.3"
# Validating urls in config, was already a transitive dependency
[workspace.dependencies.url]
version = "2.5.0"
features = ["serde"]
# standard date and time tools
[workspace.dependencies.chrono]
version = "0.4.38"
features = ["alloc"]
default-features = false
[workspace.dependencies.hyper]
version = "1.3.1"
features = [
"server",
"http1",
"http2",
]
[workspace.dependencies.hyper-util]
version = "0.1.3"
# to support multiple variations of setting a config option
[workspace.dependencies.either]
version = "1.11.0"
features = ["serde"]
# Used for reading the configuration from conduwuit.toml & environment variables
[dependencies.figment]
[workspace.dependencies.figment]
version = "0.10.18"
features = ["env", "toml"]
[workspace.dependencies.hickory-resolver]
version = "0.24.1"
default-features = false
# Used for conduit::Error type
[workspace.dependencies.thiserror]
version = "1.0.60"
# Used when hashing the state
[workspace.dependencies.ring]
version = "0.17.8"
# Used to make working with iterators easier, was already a transitive depdendency
[workspace.dependencies.itertools]
version = "0.12.1"
# to parse user-friendly time durations in admin commands
#TODO: overlaps chrono?
[workspace.dependencies.cyborgtime]
version = "2.1.1"
# used to replace the channels of the tokio runtime
[workspace.dependencies.loole]
version = "0.3.0"
[workspace.dependencies.async-trait]
version = "0.1.80"
[workspace.dependencies.lru-cache]
version = "0.1.2"
[workspace.dependencies.num_cpus]
version = "1.16.0"
# Used for matrix spec type definitions and helpers
[dependencies.ruma]
git = "https://github.com/girlbossceo/ruma"
[workspace.dependencies.ruma]
git = "https://github.com/girlbossceo/ruwuma"
branch = "conduwuit-changes"
features = [
"compat",
@ -292,60 +289,128 @@ features = [
"unstable-extensible-events",
]
[dependencies.ruma-identifiers-validation]
git = "https://github.com/girlbossceo/ruma"
[workspace.dependencies.ruma-identifiers-validation]
git = "https://github.com/girlbossceo/ruwuma"
branch = "conduwuit-changes"
[dependencies.hickory-resolver]
version = "0.24.1"
[workspace.dependencies.rust-rocksdb]
path = "deps/rust-rocksdb"
package = "rust-rocksdb-uwu"
features = [
"multi-threaded-cf",
"mt_static",
"snappy",
"lz4",
"zstd",
"zlib",
"bzip2",
]
[workspace.dependencies.zstd]
version = "0.13.1"
# to listen on both HTTP and HTTPS if listening on TLS dierctly from conduwuit for complement or sytest
[workspace.dependencies.axum-server-dual-protocol]
version = "0.6"
# optional SHA256 media keys feature
[workspace.dependencies.sha2]
version = "0.10.8"
# optional opentelemetry, performance measurements, flamegraphs, etc for performance measurements and monitoring
[workspace.dependencies.opentelemetry]
version = "0.21.0"
[workspace.dependencies.tracing-flame]
version = "0.2.0"
[workspace.dependencies.tracing-opentelemetry]
version = "0.22.0"
[workspace.dependencies.opentelemetry_sdk]
version = "0.21.2"
features = ["rt-tokio"]
[workspace.dependencies.opentelemetry-jaeger]
version = "0.20.0"
features = ["rt-tokio"]
# optional sentry metrics for crash/panic reporting
[workspace.dependencies.sentry]
version = "0.32.3"
default-features = false
features = [
"backtrace",
"contexts",
"debug-images",
"panic",
"rustls",
"tower",
"tower-http",
"tracing",
"reqwest",
"log",
]
[dependencies.rust-rocksdb]
git = "https://github.com/zaidoon1/rust-rocksdb"
branch = "master"
optional = true
default-features = true
features = ["multi-threaded-cf", "zstd"]
[workspace.dependencies.sentry-tracing]
version = "0.32.3"
[workspace.dependencies.sentry-tower]
version = "0.32.3"
[dependencies.rusqlite]
# optional jemalloc usage
[workspace.dependencies.tikv-jemalloc-sys]
version = "0.5.4"
default-features = false
features = ["unprefixed_malloc_on_supported_platforms"]
[workspace.dependencies.tikv-jemallocator]
version = "0.5.4"
default-features = false
features = ["unprefixed_malloc_on_supported_platforms"]
[workspace.dependencies.tikv-jemalloc-ctl]
version = "0.5.4"
default-features = false
features = ["use_std"]
[workspace.dependencies.rusqlite]
git = "https://github.com/rusqlite/rusqlite"
#branch = "master"
rev = "e00b626e2b1c67347d789fb7f600281705c89381"
optional = true
features = ["bundled"]
# used only by rusqlite
[dependencies.parking_lot]
[workspace.dependencies.parking_lot]
version = "0.12.2"
optional = true
# used only by rusqlite
[dependencies.thread_local]
[workspace.dependencies.thread_local]
version = "1.1.8"
optional = true
# used only by rusqlite and rust-rocksdb
[dependencies.num_cpus]
version = "1.16.0"
[workspace.dependencies.tokio-metrics]
version = "0.3.1"
default-features = false
[dependencies.tokio]
version = "1.37.0"
features = ["fs", "macros", "sync", "signal"]
[workspace.dependencies.console-subscriber]
version = "0.2"
# *nix-specific dependencies
[target.'cfg(unix)'.dependencies]
nix = { version = "0.28.0", features = ["resource"] }
sd-notify = { version = "0.4.1", optional = true } # systemd is only available/relevant on *nix platforms
[workspace.dependencies.nix]
version = "0.28.0"
features = ["resource"]
[workspace.dependencies.sd-notify]
version = "0.4.1"
[target.'cfg(all(not(target_env = "msvc"), target_os = "linux"))'.dependencies]
hardened_malloc-rs = { version = "0.1.2", optional = true, features = [
[workspace.dependencies.hardened_malloc-rs]
version = "0.1.2"
default-features = false
features = [
"static",
"gcc",
"light",
], default-features = false }
#hardened_malloc-rs = { optional = true, features = ["static","clang","light"], path = "../hardened_malloc-rs", default-features = false }
]
#
# Patches
#
# backport of [https://github.com/tokio-rs/tracing/pull/2956] to the 0.1.x branch of tracing.
# we can switch back to upstream if #2956 is merged and backported in the upstream repo.
@ -359,166 +424,213 @@ branch = "tracing-subscriber/env-filter-clone-0.1.x-backport"
git = "https://github.com/girlbossceo/tracing"
branch = "tracing-subscriber/env-filter-clone-0.1.x-backport"
[features]
default = [
"backend_rocksdb",
"systemd",
"element_hacks",
"sentry_telemetry",
"gzip_compression",
"brotli_compression",
"zstd_compression",
"release_max_log_level",
"io_uring",
]
backend_sqlite = ["sqlite"]
backend_rocksdb = ["rocksdb"]
rocksdb = ["dep:rust-rocksdb"]
jemalloc = [
"dep:tikv-jemalloc-sys",
"dep:tikv-jemalloc-ctl",
"dep:tikv-jemallocator",
"rust-rocksdb/jemalloc",
]
jemalloc_prof = ["tikv-jemalloc-sys/profiling"]
sqlite = ["dep:rusqlite", "dep:parking_lot", "dep:thread_local"]
systemd = ["dep:sd-notify"]
sentry_telemetry = ["dep:sentry", "dep:sentry-tracing", "dep:sentry-tower"]
gzip_compression = ["tower-http/compression-gzip", "reqwest/gzip"]
zstd_compression = ["tower-http/compression-zstd"]
brotli_compression = ["tower-http/compression-br", "reqwest/brotli"]
sha256_media = ["dep:sha2"]
io_uring = ["rust-rocksdb/io-uring"]
axum_dual_protocol = ["dep:axum-server-dual-protocol"]
perf_measurements = [
"dep:opentelemetry",
"dep:tracing-flame",
"dep:tracing-opentelemetry",
"dep:opentelemetry_sdk",
"dep:opentelemetry-jaeger",
]
# enable the tokio_console server
# incompatible with release_max_log_level
tokio_console = ["dep:console-subscriber", "tokio/tracing"]
hot_reload = ["dep:hot-lib-reloader"]
hardened_malloc = ["dep:hardened_malloc-rs"]
# increases performance, reduces build times, and reduces binary size by not compiling or
# genreating code for log level filters that users will generally not use (debug and trace) only in release builds
#
# the expense is obviously losing those log level filters for usage at runtime. debug builds will still have all log levels
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
# developer feature useful only in debug builds.
dev_release_log_level = []
# client/server interopability hacks
# Our crates
#
## element has various non-spec compliant behaviour
element_hacks = []
[workspace.dependencies.conduit-router]
package = "conduit_router"
path = "src/router"
default-features = false
[package.metadata.deb]
name = "conduwuit"
maintainer = "strawberry <strawberry@puppygock.gay>"
copyright = "2024, strawberry <strawberry@puppygock.gay>"
license-file = ["LICENSE", "3"]
depends = "$auto, ca-certificates"
extended-description = """\
a cool hard fork of Conduit, a Matrix homeserver written in Rust"""
section = "net"
priority = "optional"
assets = [
[
"debian/README.md",
"usr/share/doc/conduwuit/README.Debian",
"644",
],
[
"README.md",
"usr/share/doc/conduwuit/",
"644",
],
[
"target/release/conduwuit",
"usr/sbin/conduwuit",
"755",
],
[
"conduwuit-example.toml",
"etc/conduwuit/conduwuit.toml",
"640",
],
]
conf-files = ["/etc/conduwuit/conduwuit.toml"]
maintainer-scripts = "debian/"
systemd-units = { unit-name = "conduwuit", start = false }
[workspace.dependencies.conduit-admin]
package = "conduit_admin"
path = "src/admin"
default-features = false
[workspace.dependencies.conduit-api]
package = "conduit_api"
path = "src/api"
default-features = false
[profile.dev]
#debug = 0
lto = 'off'
codegen-units = 512
incremental = true
overflow-checks = true
#panic = "abort"
[workspace.dependencies.conduit-service]
package = "conduit_service"
path = "src/service"
default-features = false
# seems to speed up continuous debug compilations
[profile.dev.build-override]
opt-level = 3
[profile.dev.package."*"] # external dependencies
opt-level = 1
[profile.dev.package."tokio"]
opt-level = 3
[workspace.dependencies.conduit-database]
package = "conduit_database"
path = "src/database"
default-features = false
[workspace.dependencies.conduit-core]
package = "conduit_core"
path = "src/core"
default-features = false
###############################################################################
#
# Release profiles
#
# default release profile
[profile.release]
lto = 'thin'
incremental = false
opt-level = 3
strip = "symbols"
control-flow-guard = true # Windows only
debug = 0
lto = "thin"
# release profile with debug symbols
[profile.release-debuginfo]
inherits = "release"
strip = "none"
debug = "full"
strip = "none"
# high performance release profile which uses fat LTO across all crates, 1 codegen unit, max opt-level, and optimises across all crates
[profile.release-high-perf]
inherits = "release"
lto = 'fat'
lto = "fat"
codegen-units = 1
panic = "abort"
# For releases also try to max optimizations for dependencies:
[profile.release-high-perf.build-override]
debug = 0
opt-level = 3
# do not use without profile-rustflags enabled
[profile.release-max-perf]
inherits = "release"
strip = "symbols"
lto = "fat"
#rustflags = [
# '-Ctarget-cpu=native',
# '-Ztune-cpu=native',
# '-Ctarget-feature=+crt-static',
# '-Crelocation-model=static',
# '-Ztls-model=local-exec',
# '-Zinline-in-all-cgus=true',
# '-Zinline-mir=true',
# '-Zmir-opt-level=3',
# '-Clink-arg=-fuse-ld=mold',
# '-Clink-arg=-Wl,--threads',
# '-Clink-arg=-Wl,--gc-sections',
# '-Ztime-passes',
# '-Ztime-llvm-passes',
#]
[profile.release-max-perf.build-override]
inherits = "release-max-perf"
opt-level = 0
#rustflags = [
# '-Ctarget-feature=-crt-static',
#]
[profile.bench]
inherits = "release"
#rustflags = [
# "-Cremark=all",
# '-Ztime-passes',
# '-Ztime-llvm-passes',
#]
###############################################################################
#
# Developer profile
#
# To enable hot-reloading:
# 1. Uncomment all of the rustflags here.
# 2. Uncomment crate-type=dylib in src/*/Cargo.toml and deps/rust-rocksdb/Cargo.toml
# 2. Build with the 'mods' feature.
#
# opt-level, mir-opt-level, validate-mir are not known to interfere with reloading
# and can be raised if build times are tolerable.
[profile.dev]
debug = 1
opt-level = 0
panic = "unwind"
debug-assertions = true
incremental = true
codegen-units = 64
rpath = true
#rustflags = [
# '-Ztime-passes',
# '-Zmir-opt-level=0',
# '-Zvalidate-mir=false',
# '-Ztls-model=global-dynamic',
# '-Cprefer-dynamic=true',
# '-Zstaticlib-prefer-dynamic=true',
# '-Zstaticlib-allow-rdylib-deps=true',
# '-Zpacked-bundled-libs=false',
# '-Zplt=true',
# '-Crpath=true',
# '-Clink-arg=-Wl,--as-needed',
# '-Clink-arg=-Wl,--allow-shlib-undefined',
# '-Clink-arg=-Wl,-z,keep-text-section-prefix',
# '-Clink-arg=-Wl,-z,lazy',
#]
[profile.dev.package.conduit_core]
inherits = "dev"
incremental = false
#rustflags = [
# '-Ztime-passes',
# '-Zmir-opt-level=0',
# '-Ztls-model=initial-exec',
# '-Cprefer-dynamic=true',
# '-Zstaticlib-prefer-dynamic=true',
# '-Zstaticlib-allow-rdylib-deps=true',
# '-Zpacked-bundled-libs=false',
# '-Zplt=true',
# '-Clink-arg=-Wl,--as-needed',
# '-Clink-arg=-Wl,--allow-shlib-undefined',
# '-Clink-arg=-Wl,-z,lazy',
# '-Clink-arg=-Wl,-z,unique',
# '-Clink-arg=-Wl,-z,nodlopen',
# '-Clink-arg=-Wl,-z,nodelete',
#]
[profile.dev.package.conduit]
inherits = "dev"
incremental = false
#rustflags = [
# '-Ztime-passes',
# '-Zmir-opt-level=0',
# '-Zvalidate-mir=false',
# '-Ztls-model=global-dynamic',
# '-Cprefer-dynamic=true',
# '-Zexport-executable-symbols=true',
# '-Zplt=true',
# '-Crpath=true',
# '-Clink-arg=-Wl,--as-needed',
# '-Clink-arg=-Wl,--allow-shlib-undefined',
# '-Clink-arg=-Wl,--export-dynamic',
# '-Clink-arg=-Wl,-z,lazy',
#]
[profile.dev.package.rust-rocksdb-uwu]
inherits = "dev"
debug = 'limited'
incremental = false
codegen-units = 1
opt-level = 'z'
#rustflags = [
# '-Ztls-model=initial-exec',
# '-Cprefer-dynamic=true',
# '-Zstaticlib-prefer-dynamic=true',
# '-Zstaticlib-allow-rdylib-deps=true',
# '-Zpacked-bundled-libs=true',
# '-Zplt=true',
# '-Clink-arg=-Wl,--no-as-needed',
# '-Clink-arg=-Wl,--allow-shlib-undefined',
# '-Clink-arg=-Wl,-z,lazy',
# '-Clink-arg=-Wl,-z,nodlopen',
# '-Clink-arg=-Wl,-z,nodelete',
#]
[profile.release-high-perf.package."*"]
debug = 0
opt-level = 3
[profile.dev.package.'*']
inherits = "dev"
debug = 'limited'
incremental = false
codegen-units = 1
opt-level = 'z'
#rustflags = [
# '-Ztls-model=global-dynamic',
# '-Cprefer-dynamic=true',
# '-Zstaticlib-prefer-dynamic=true',
# '-Zstaticlib-allow-rdylib-deps=true',
# '-Zpacked-bundled-libs=true',
# '-Zplt=true',
# '-Clink-arg=-Wl,--as-needed',
# '-Clink-arg=-Wl,-z,lazy',
# '-Clink-arg=-Wl,-z,nodelete',
#]
[lints]
workspace = true
[profile.test]
incremental = false
[workspace.lints.rust]
missing_abi = "warn"
@ -543,7 +655,6 @@ unused_braces = "allow"
# some sadness
missing_docs = "allow"
[workspace.lints.clippy]
# pedantic = "warn"
@ -615,7 +726,6 @@ unnecessary_box_returns = "warn"
map_unwrap_or = "warn"
implicit_clone = "warn"
match_wildcard_for_single_variants = "warn"
unnecessary_wraps = "warn"
match_same_arms = "warn"
ignored_unit_patterns = "warn"
redundant_else = "warn"
@ -650,6 +760,7 @@ unwrap_used = "allow"
expect_used = "allow"
if_then_some_else_none = "allow"
let_underscore_must_use = "allow"
let_underscore_future = "allow"
map_err_ignore = "allow"
missing_docs_in_private_items = "allow"
multiple_inherent_impl = "allow"
@ -657,3 +768,4 @@ error_impl_error = "allow"
string_add = "allow"
string_slice = "allow"
ref_patterns = "allow"
unnecessary_wraps = "allow"

32
deps/rust-rocksdb/Cargo.toml vendored Normal file
View File

@ -0,0 +1,32 @@
[package]
name = "rust-rocksdb-uwu"
version = "0.0.1"
edition = "2021"
[features]
default = ["snappy", "lz4", "zstd", "zlib", "bzip2"]
jemalloc = ["rust-rocksdb/jemalloc"]
io-uring = ["rust-rocksdb/io-uring"]
valgrind = ["rust-rocksdb/valgrind"]
snappy = ["rust-rocksdb/snappy"]
lz4 = ["rust-rocksdb/lz4"]
zstd = ["rust-rocksdb/zstd"]
zlib = ["rust-rocksdb/zlib"]
bzip2 = ["rust-rocksdb/bzip2"]
rtti = ["rust-rocksdb/rtti"]
mt_static = ["rust-rocksdb/mt_static"]
multi-threaded-cf = ["rust-rocksdb/multi-threaded-cf"]
serde1 = ["rust-rocksdb/serde1"]
malloc-usable-size = ["rust-rocksdb/malloc-usable-size"]
[dependencies.rust-rocksdb]
git = "https://github.com/zaidoon1/rust-rocksdb"
branch = "master"
default-features = false
[lib]
path = "lib.rs"
crate-type = [
"rlib",
# "dylib"
]

59
deps/rust-rocksdb/lib.rs vendored Normal file
View File

@ -0,0 +1,59 @@
pub use rust_rocksdb::*;
#[link(name = "rocksdb", kind = "static")]
extern "C" {
pub fn rocksdb_list_column_families();
pub fn rocksdb_logger_create_stderr_logger();
pub fn rocksdb_options_set_info_log();
pub fn rocksdb_get_options_from_string();
pub fn rocksdb_writebatch_create();
pub fn rocksdb_writebatch_put_cf();
pub fn rocksdb_writebatch_delete_cf();
pub fn rocksdb_iter_value();
pub fn rocksdb_iter_seek_to_last();
pub fn rocksdb_iter_seek_for_prev();
pub fn rocksdb_iter_seek_to_first();
pub fn rocksdb_iter_next();
pub fn rocksdb_iter_prev();
pub fn rocksdb_iter_seek();
pub fn rocksdb_iter_valid();
pub fn rocksdb_iter_get_error();
pub fn rocksdb_iter_key();
pub fn rocksdb_iter_destroy();
pub fn rocksdb_livefiles();
pub fn rocksdb_livefiles_count();
pub fn rocksdb_livefiles_destroy();
pub fn rocksdb_livefiles_column_family_name();
pub fn rocksdb_livefiles_name();
pub fn rocksdb_livefiles_size();
pub fn rocksdb_livefiles_level();
pub fn rocksdb_livefiles_smallestkey();
pub fn rocksdb_livefiles_largestkey();
pub fn rocksdb_livefiles_entries();
pub fn rocksdb_livefiles_deletions();
pub fn rocksdb_put_cf();
pub fn rocksdb_delete_cf();
pub fn rocksdb_get_pinned_cf();
pub fn rocksdb_create_column_family();
pub fn rocksdb_get_latest_sequence_number();
pub fn rocksdb_batched_multi_get_cf();
pub fn rocksdb_cancel_all_background_work();
pub fn rocksdb_repair_db();
pub fn rocksdb_list_column_families_destroy();
pub fn rocksdb_flush();
pub fn rocksdb_flush_wal();
pub fn rocksdb_open_column_families();
pub fn rocksdb_open_for_read_only_column_families();
pub fn rocksdb_open_as_secondary_column_families();
pub fn rocksdb_open_column_families_with_ttl();
pub fn rocksdb_open();
pub fn rocksdb_open_for_read_only();
pub fn rocksdb_open_with_ttl();
pub fn rocksdb_open_as_secondary();
pub fn rocksdb_write();
pub fn rocksdb_create_iterator_cf();
pub fn rocksdb_backup_engine_create_new_backup_flush();
pub fn rocksdb_backup_engine_options_create();
pub fn rocksdb_write_buffer_manager_destroy();
pub fn rocksdb_options_set_ttl();
}

View File

@ -55,7 +55,7 @@ commonAttrs = {
include = [
"Cargo.lock"
"Cargo.toml"
"hot_lib"
"deps"
"src"
];
};

63
src/admin/Cargo.toml Normal file
View File

@ -0,0 +1,63 @@
[package]
name = "conduit_admin"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"rocksdb",
"io_uring",
"jemalloc",
"zstd_compression",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
rocksdb = [
"dep:rust-rocksdb",
]
jemalloc = [
"rust-rocksdb/jemalloc",
]
io_uring = [
"rust-rocksdb/io-uring",
]
zstd_compression = [
"rust-rocksdb/zstd",
]
[dependencies]
clap.workspace = true
conduit-api.workspace = true
conduit-core.workspace = true
conduit-database.workspace = true
conduit-service.workspace = true
futures-util.workspace = true
log.workspace = true
loole.workspace = true
regex.workspace = true
ruma.workspace = true
rust-rocksdb.optional = true
rust-rocksdb.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_yaml.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
[lints]
workspace = true

View File

@ -1,6 +1,6 @@
use ruma::{api::appservice::Registration, events::room::message::RoomMessageEventContent};
use crate::{service::admin::escape_html, services, Result};
use crate::{escape_html, services, Result};
pub(crate) async fn register(body: Vec<&str>) -> Result<RoomMessageEventContent> {
if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" {

View File

@ -1,8 +1,8 @@
use clap::Subcommand;
use conduit::Result;
use ruma::events::room::message::RoomMessageEventContent;
use self::appservice_command::{list, register, show, unregister};
use crate::Result;
pub(crate) mod appservice_command;

View File

@ -1,18 +1,15 @@
use std::{collections::BTreeMap, sync::Arc, time::Instant};
use conduit::{utils::HtmlEscape, Error, Result};
use ruma::{
api::client::error::ErrorKind, events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId,
RoomId, RoomVersionId, ServerName,
};
use service::{rooms::event_handler::parse_incoming_pdu, sending::send::resolve_actual_dest, services, PduEvent};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use tracing_subscriber::EnvFilter;
use crate::{
api::server_server::parse_incoming_pdu, service::sending::send::resolve_actual_dest, services, utils::HtmlEscape,
Error, PduEvent, Result,
};
pub(crate) async fn get_auth_chain(_body: Vec<&str>, event_id: Box<EventId>) -> Result<RoomMessageEventContent> {
let event_id = Arc::<EventId>::from(event_id);
if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? {
@ -332,7 +329,7 @@ pub(crate) async fn change_log_level(
};
match services()
.globals
.server
.tracing_reload_handle
.reload(&old_filter_layer)
{
@ -361,7 +358,7 @@ pub(crate) async fn change_log_level(
};
match services()
.globals
.server
.tracing_reload_handle
.reload(&new_filter_layer)
{
@ -447,15 +444,16 @@ pub(crate) async fn resolve_true_destination(
));
}
let (actual_dest, hostname_uri) = resolve_actual_dest(&server_name, no_cache, true).await?;
let (actual_dest, hostname_uri) = resolve_actual_dest(&server_name, no_cache).await?;
Ok(RoomMessageEventContent::text_plain(format!(
"Actual destination: {actual_dest:?} | Hostname URI: {hostname_uri}"
)))
}
#[must_use]
pub(crate) fn memory_stats() -> RoomMessageEventContent {
let html_body = crate::alloc::memory_stats();
let html_body = conduit::alloc::memory_stats();
if html_body.is_empty() {
return RoomMessageEventContent::text_plain("malloc stats are not supported on your compiled malloc.");

View File

@ -2,12 +2,7 @@ use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId};
use crate::{
service::admin::{escape_html, get_room_info},
services,
utils::HtmlEscape,
Result,
};
use crate::{escape_html, get_room_info, services, utils::HtmlEscape, Result};
pub(crate) async fn disable_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> {
services().rooms.metadata.disable_room(&room_id, true)?;

305
src/admin/handler.rs Normal file
View File

@ -0,0 +1,305 @@
use std::sync::Arc;
use clap::Parser;
use regex::Regex;
use ruma::{
events::{
relation::InReplyTo,
room::message::{Relation::Reply, RoomMessageEventContent},
TimelineEventType,
},
OwnedRoomId, OwnedUserId, ServerName, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::MutexGuard;
use tracing::error;
extern crate conduit_service as service;
use conduit::{Error, Result};
pub(crate) use service::admin::{AdminRoomEvent, Service};
use service::{admin::HandlerResult, pdu::PduBuilder};
use self::{fsck::FsckCommand, tester::TesterCommands};
use crate::{
appservice, appservice::AppserviceCommand, debug, debug::DebugCommand, escape_html, federation,
federation::FederationCommand, fsck, media, media::MediaCommand, query, query::QueryCommand, room,
room::RoomCommand, server, server::ServerCommand, services, tester, user, user::UserCommand,
};
pub(crate) const PAGE_SIZE: usize = 100;
#[cfg_attr(test, derive(Debug))]
#[derive(Parser)]
#[command(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))]
pub(crate) enum AdminCommand {
#[command(subcommand)]
/// - Commands for managing appservices
Appservices(AppserviceCommand),
#[command(subcommand)]
/// - Commands for managing local users
Users(UserCommand),
#[command(subcommand)]
/// - Commands for managing rooms
Rooms(RoomCommand),
#[command(subcommand)]
/// - Commands for managing federation
Federation(FederationCommand),
#[command(subcommand)]
/// - Commands for managing the server
Server(ServerCommand),
#[command(subcommand)]
/// - Commands for managing media
Media(MediaCommand),
#[command(subcommand)]
/// - Commands for debugging things
Debug(DebugCommand),
#[command(subcommand)]
/// - Query all the database getters and iterators
Query(QueryCommand),
#[command(subcommand)]
/// - Query all the database getters and iterators
Fsck(FsckCommand),
#[command(subcommand)]
Tester(TesterCommands),
}
#[must_use]
pub fn handle(event: AdminRoomEvent, room: OwnedRoomId, user: OwnedUserId) -> HandlerResult {
Box::pin(handle_event(event, room, user))
}
async fn handle_event(event: AdminRoomEvent, admin_room: OwnedRoomId, server_user: OwnedUserId) -> Result<()> {
let (mut message_content, reply) = match event {
AdminRoomEvent::SendMessage(content) => (content, None),
AdminRoomEvent::ProcessMessage(room_message, reply_id) => {
(process_admin_message(room_message).await, Some(reply_id))
},
};
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(admin_room.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
if let Some(reply) = reply {
message_content.relates_to = Some(Reply {
in_reply_to: InReplyTo {
event_id: reply.into(),
},
});
}
let response_pdu = PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&message_content).expect("event is valid, we just created it"),
unsigned: None,
state_key: None,
redacts: None,
};
if let Err(e) = services()
.rooms
.timeline
.build_and_append_pdu(response_pdu, &server_user, &admin_room, &state_lock)
.await
{
handle_response_error(&e, &admin_room, &server_user, &state_lock).await?;
}
Ok(())
}
async fn handle_response_error(
e: &Error, admin_room: &OwnedRoomId, server_user: &UserId, state_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
error!("Failed to build and append admin room response PDU: \"{e}\"");
let error_room_message = RoomMessageEventContent::text_plain(format!(
"Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \
successfully, but we could not return the output."
));
let response_pdu = PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&error_room_message).expect("event is valid, we just created it"),
unsigned: None,
state_key: None,
redacts: None,
};
services()
.rooms
.timeline
.build_and_append_pdu(response_pdu, server_user, admin_room, state_lock)
.await?;
Ok(())
}
// Parse and process a message from the admin room
async fn process_admin_message(room_message: String) -> RoomMessageEventContent {
let mut lines = room_message.lines().filter(|l| !l.trim().is_empty());
let command_line = lines.next().expect("each string has at least one line");
let body = lines.collect::<Vec<_>>();
let admin_command = match parse_admin_command(command_line) {
Ok(command) => command,
Err(error) => {
let server_name = services().globals.server_name();
let message = error.replace("server.name", server_name.as_str());
let html_message = usage_to_html(&message, server_name);
return RoomMessageEventContent::text_html(message, html_message);
},
};
match process_admin_command(admin_command, body).await {
Ok(reply_message) => reply_message,
Err(error) => {
let markdown_message = format!("Encountered an error while handling the command:\n```\n{error}\n```",);
let html_message = format!("Encountered an error while handling the command:\n<pre>\n{error}\n</pre>",);
RoomMessageEventContent::text_html(markdown_message, html_message)
},
}
}
// Parse chat messages from the admin room into an AdminCommand object
fn parse_admin_command(command_line: &str) -> Result<AdminCommand, String> {
// Note: argv[0] is `@conduit:servername:`, which is treated as the main command
let mut argv = command_line.split_whitespace().collect::<Vec<_>>();
// Replace `help command` with `command --help`
// Clap has a help subcommand, but it omits the long help description.
if argv.len() > 1 && argv[1] == "help" {
argv.remove(1);
argv.push("--help");
}
// Backwards compatibility with `register_appservice`-style commands
let command_with_dashes_argv1;
if argv.len() > 1 && argv[1].contains('_') {
command_with_dashes_argv1 = argv[1].replace('_', "-");
argv[1] = &command_with_dashes_argv1;
}
// Backwards compatibility with `register_appservice`-style commands
let command_with_dashes_argv2;
if argv.len() > 2 && argv[2].contains('_') {
command_with_dashes_argv2 = argv[2].replace('_', "-");
argv[2] = &command_with_dashes_argv2;
}
// if the user is using the `query` command (argv[1]), replace the database
// function/table calls with underscores to match the codebase
let command_with_dashes_argv3;
if argv.len() > 3 && argv[1].eq("query") {
command_with_dashes_argv3 = argv[3].replace('_', "-");
argv[3] = &command_with_dashes_argv3;
}
AdminCommand::try_parse_from(argv).map_err(|error| error.to_string())
}
async fn process_admin_command(command: AdminCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
let reply_message_content = match command {
AdminCommand::Appservices(command) => appservice::process(command, body).await?,
AdminCommand::Media(command) => media::process(command, body).await?,
AdminCommand::Users(command) => user::process(command, body).await?,
AdminCommand::Rooms(command) => room::process(command, body).await?,
AdminCommand::Federation(command) => federation::process(command, body).await?,
AdminCommand::Server(command) => server::process(command, body).await?,
AdminCommand::Debug(command) => debug::process(command, body).await?,
AdminCommand::Query(command) => query::process(command, body).await?,
AdminCommand::Fsck(command) => fsck::process(command, body).await?,
AdminCommand::Tester(command) => tester::process(command, body).await?,
};
Ok(reply_message_content)
}
// Utility to turn clap's `--help` text to HTML.
fn usage_to_html(text: &str, server_name: &ServerName) -> String {
// Replace `@conduit:servername:-subcmdname` with `@conduit:servername:
// subcmdname`
let text = text.replace(&format!("@conduit:{server_name}:-"), &format!("@conduit:{server_name}: "));
// For the conduit admin room, subcommands become main commands
let text = text.replace("SUBCOMMAND", "COMMAND");
let text = text.replace("subcommand", "command");
// Escape option names (e.g. `<element-id>`) since they look like HTML tags
let text = escape_html(&text);
// Italicize the first line (command name and version text)
let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<em>$1</em>\n");
// Unmerge wrapped lines
let text = text.replace("\n ", " ");
// Wrap option names in backticks. The lines look like:
// -V, --version Prints version information
// And are converted to:
// <code>-V, --version</code>: Prints version information
// (?m) enables multi-line mode for ^ and $
let re = Regex::new("(?m)^ {4}(([a-zA-Z_&;-]+(, )?)+) +(.*)$").expect("Regex compilation should not fail");
let text = re.replace_all(&text, "<code>$1</code>: $4");
// Look for a `[commandbody]` tag. If it exists, use all lines below it that
// start with a `#` in the USAGE section.
let mut text_lines = text.lines().collect::<Vec<&str>>();
let mut command_body = String::new();
if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") {
text_lines.remove(line_index);
while text_lines
.get(line_index)
.is_some_and(|line| line.starts_with('#'))
{
command_body += if text_lines[line_index].starts_with("# ") {
&text_lines[line_index][2..]
} else {
&text_lines[line_index][1..]
};
command_body += "[nobr]\n";
text_lines.remove(line_index);
}
}
let text = text_lines.join("\n");
// Improve the usage section
let text = if command_body.is_empty() {
// Wrap the usage line in code tags
let re = Regex::new("(?m)^USAGE:\n {4}(@conduit:.*)$").expect("Regex compilation should not fail");
re.replace_all(&text, "USAGE:\n<code>$1</code>").to_string()
} else {
// Wrap the usage line in a code block, and add a yaml block example
// This makes the usage of e.g. `register-appservice` more accurate
let re = Regex::new("(?m)^USAGE:\n {4}(.*?)\n\n").expect("Regex compilation should not fail");
re.replace_all(&text, "USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>")
.replace("[commandbodyblock]", &command_body)
};
// Add HTML line-breaks
text.replace("\n\n\n", "\n\n")
.replace('\n', "<br>\n")
.replace("[nobr]<br>", "")
}

View File

@ -1,7 +1,7 @@
use ruma::{events::room::message::RoomMessageEventContent, EventId};
use ruma::{events::room::message::RoomMessageEventContent, EventId, MxcUri};
use tracing::{debug, info};
use crate::{service::admin::MxcUri, services, Result};
use crate::{services, Result};
pub(crate) async fn delete(
_body: Vec<&str>, mxc: Option<Box<MxcUri>>, event_id: Option<Box<EventId>>,

View File

@ -1,8 +1,8 @@
use clap::Subcommand;
use ruma::{events::room::message::RoomMessageEventContent, EventId};
use ruma::{events::room::message::RoomMessageEventContent, EventId, MxcUri};
use self::media_commands::{delete, delete_list, delete_past_remote_media};
use crate::{service::admin::MxcUri, Result};
use crate::Result;
pub(crate) mod media_commands;

55
src/admin/mod.rs Normal file
View File

@ -0,0 +1,55 @@
pub(crate) mod appservice;
pub(crate) mod debug;
pub(crate) mod federation;
pub(crate) mod fsck;
pub(crate) mod handler;
pub(crate) mod media;
pub(crate) mod query;
pub(crate) mod room;
pub(crate) mod server;
pub(crate) mod tester;
pub(crate) mod user;
pub(crate) mod utils;
extern crate conduit_api as api;
extern crate conduit_core as conduit;
extern crate conduit_service as service;
pub(crate) use conduit::{mod_ctor, mod_dtor, Result};
pub use handler::handle;
pub(crate) use service::{services, user_is_local};
pub(crate) use crate::{
handler::Service,
utils::{escape_html, get_room_info},
};
mod_ctor! {}
mod_dtor! {}
#[cfg(test)]
mod test {
use clap::Parser;
use crate::handler::AdminCommand;
#[test]
fn get_help_short() { get_help_inner("-h"); }
#[test]
fn get_help_long() { get_help_inner("--help"); }
#[test]
fn get_help_subcommand() { get_help_inner("help"); }
fn get_help_inner(input: &str) {
let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input])
.unwrap_err()
.to_string();
// Search for a handful of keywords that suggest the help printed properly
assert!(error.contains("Usage:"));
assert!(error.contains("Commands:"));
assert!(error.contains("Options:"));
}
}

View File

@ -3,7 +3,7 @@ use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId};
use super::RoomAliasCommand;
use crate::{service::admin::escape_html, services, Result};
use crate::{escape_html, services, Result};
pub(crate) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
match command {

View File

@ -2,10 +2,7 @@ use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId};
use crate::{
service::admin::{escape_html, get_room_info, PAGE_SIZE},
services, Result,
};
use crate::{escape_html, get_room_info, handler::PAGE_SIZE, services, Result};
pub(crate) async fn list(_body: Vec<&str>, page: Option<usize>) -> Result<RoomMessageEventContent> {
// TODO: i know there's a way to do this with clap, but i can't seem to find it

View File

@ -3,10 +3,7 @@ use std::fmt::Write as _;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId};
use super::RoomDirectoryCommand;
use crate::{
service::admin::{escape_html, get_room_info, PAGE_SIZE},
services, Result,
};
use crate::{escape_html, get_room_info, handler::PAGE_SIZE, services, Result};
pub(crate) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
match command {

View File

@ -1,18 +1,16 @@
use std::fmt::Write as _;
use std::fmt::Write;
use api::client_server::{get_alias_helper, leave_room};
use ruma::{
events::room::message::RoomMessageEventContent, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId,
};
use tracing::{debug, error, info, warn};
use super::RoomModerationCommand;
use crate::{
api::client_server::{get_alias_helper, leave_room},
service::admin::{escape_html, Service},
services,
utils::user_id::user_is_local,
Result,
use super::{
super::{escape_html, Service},
RoomModerationCommand,
};
use crate::{services, user_is_local, Result};
pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> {
match command {

View File

@ -4,7 +4,7 @@ use crate::{services, Result};
pub(crate) async fn uptime(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let seconds = services()
.globals
.server
.started
.elapsed()
.expect("standard duration")
@ -28,7 +28,7 @@ pub(crate) async fn show_config(_body: Vec<&str>) -> Result<RoomMessageEventCont
pub(crate) async fn memory_usage(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
let response0 = services().memory_usage().await;
let response1 = services().globals.db.memory_usage();
let response2 = crate::alloc::memory_usage();
let response2 = conduit::alloc::memory_usage();
Ok(RoomMessageEventContent::text_plain(format!(
"Services:\n{response0}\n\nDatabase:\n{response1}\n{}",
@ -69,7 +69,10 @@ pub(crate) async fn backup_database(_body: Vec<&str>) -> Result<RoomMessageEvent
));
}
let mut result = tokio::task::spawn_blocking(move || match services().globals.db.backup() {
let mut result = services()
.server
.runtime()
.spawn_blocking(move || match services().globals.db.backup() {
Ok(()) => String::new(),
Err(e) => (*e).to_string(),
})

View File

@ -9,6 +9,6 @@ pub(crate) enum TesterCommands {
}
pub(crate) async fn process(command: TesterCommands, _body: Vec<&str>) -> Result<RoomMessageEventContent> {
Ok(match command {
TesterCommands::Tester => RoomMessageEventContent::notice_plain(String::from("complete")),
TesterCommands::Tester => RoomMessageEventContent::notice_plain(String::from("completed")),
})
}

View File

@ -1,15 +1,13 @@
use std::{fmt::Write as _, sync::Arc};
use api::client_server::{join_room_by_id_helper, leave_all_rooms};
use conduit::utils;
use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, UserId};
use tracing::{error, info, warn};
use crate::{
api::client_server::{join_room_by_id_helper, leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
service::admin::{escape_html, get_room_info},
services,
utils::{self, user_id::user_is_local},
Result,
};
use crate::{escape_html, get_room_info, services, user_is_local, Result};
const AUTO_GEN_PASSWORD_LENGTH: usize = 25;
pub(crate) async fn list(_body: Vec<&str>) -> Result<RoomMessageEventContent> {
match services().users.list_local_users() {
@ -111,7 +109,7 @@ pub(crate) async fn create(
)
.await
{
Ok(_) => {
Ok(_response) => {
info!("Automatically joined room {room} for user {user_id}");
},
Err(e) => {

30
src/admin/utils.rs Normal file
View File

@ -0,0 +1,30 @@
pub(crate) use conduit::utils::HtmlEscape;
use ruma::OwnedRoomId;
use crate::services;
pub(crate) fn escape_html(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
}
pub(crate) fn get_room_info(id: &OwnedRoomId) -> (OwnedRoomId, u64, String) {
(
id.clone(),
services()
.rooms
.state_cache
.room_joined_count(id)
.ok()
.flatten()
.unwrap_or(0),
services()
.rooms
.state_accessor
.get_name(id)
.ok()
.flatten()
.unwrap_or_else(|| id.to_string()),
)
}

View File

@ -1,7 +0,0 @@
//! Default allocator with no special features
/// Always returns the empty string
pub(crate) fn memory_stats() -> String { Default::default() }
/// Always returns the empty string
pub(crate) fn memory_usage() -> String { Default::default() }

View File

@ -1,8 +0,0 @@
#[global_allocator]
static HMALLOC: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc;
pub(crate) fn memory_usage() -> String {
String::default() //TODO: get usage
}
pub(crate) fn memory_stats() -> String { "Extended statistics are not available from hardened_malloc.".to_owned() }

66
src/api/Cargo.toml Normal file
View File

@ -0,0 +1,66 @@
[package]
name = "conduit_api"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"element_hacks",
"gzip_compression",
"brotli_compression",
"release_max_log_level",
]
element_hacks = []
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
gzip_compression = [
"reqwest/gzip",
]
brotli_compression = [
"reqwest/brotli",
]
[dependencies]
argon2.workspace = true
axum-extra.workspace = true
axum.workspace = true
base64.workspace = true
bytes.workspace = true
conduit-core.workspace = true
conduit-database.workspace = true
conduit-service.workspace = true
futures-util.workspace = true
hmac.workspace = true
http.workspace = true
hyper.workspace = true
image.workspace = true
ipaddress.workspace = true
jsonwebtoken.workspace = true
log.workspace = true
rand.workspace = true
reqwest.workspace = true
ruma.workspace = true
serde_html_form.workspace = true
serde_json.workspace = true
serde.workspace = true
sha-1.workspace = true
thiserror.workspace = true
tokio.workspace = true
tracing.workspace = true
webpage.workspace = true
[lints]
workspace = true

View File

@ -19,9 +19,10 @@ use tracing::{error, info, warn};
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{
api::client_server::{self, join_room_by_id_helper},
service, services,
utils::{self, user_id::user_is_local},
client_server::{self, join_room_by_id_helper},
service::user_is_local,
services,
utils::{self},
Error, Result, Ruma,
};

View File

@ -13,8 +13,9 @@ use ruma::{
use tracing::debug;
use crate::{
debug_info, debug_warn, service::appservice::RegistrationInfo, services, utils::server_name::server_is_ours, Error,
Result, Ruma,
debug_info, debug_warn,
service::{appservice::RegistrationInfo, server_is_ours},
services, Error, Result, Ruma,
};
/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
@ -65,7 +66,6 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) ->
/// - TODO: Update canonical alias event
pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> {
alias_checks(&body.room_alias, &body.appservice_info).await?;
if services()
.rooms
.alias
@ -99,7 +99,7 @@ pub(crate) async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Resul
get_alias_helper(body.body.room_alias, None).await
}
pub(crate) async fn get_alias_helper(
pub async fn get_alias_helper(
room_alias: OwnedRoomAliasId, servers: Option<Vec<OwnedServerName>>,
) -> Result<get_alias::v3::Response> {
debug!("get_alias_helper servers: {servers:?}");

View File

@ -24,7 +24,7 @@ use ruma::{
};
use tracing::{error, info, warn};
use crate::{services, utils::server_name::server_is_ours, Error, Result, Ruma};
use crate::{service::server_is_ours, services, Error, Result, Ruma};
/// # `POST /_matrix/client/v3/publicRooms`
///

View File

@ -22,8 +22,9 @@ use tracing::debug;
use super::SESSION_ID_LENGTH;
use crate::{
service::user_is_local,
services,
utils::{self, user_id::user_is_local},
utils::{self},
Error, Result, Ruma,
};

View File

@ -15,14 +15,16 @@ use webpage::HTML;
use crate::{
debug_warn,
service::media::{FileMeta, UrlPreviewData},
service::{
media::{FileMeta, UrlPreviewData},
server_is_ours,
},
services,
utils::{
self,
content_disposition::{
content_disposition_type, make_content_disposition, make_content_type, sanitise_filename,
},
server_name::server_is_ours,
},
Error, Result, Ruma, RumaResponse,
};

View File

@ -35,9 +35,12 @@ use tracing::{debug, error, info, trace, warn};
use super::get_alias_helper;
use crate::{
service::pdu::{gen_event_id_canonical_json, PduBuilder},
service::{
pdu::{gen_event_id_canonical_json, PduBuilder},
server_is_ours, user_is_local,
},
services,
utils::{self, server_name::server_is_ours, user_id::user_is_local},
utils::{self},
Error, PduEvent, Result, Ruma,
};
@ -607,7 +610,7 @@ pub(crate) async fn joined_members_route(
})
}
pub(crate) async fn join_room_by_id_helper(
pub async fn join_room_by_id_helper(
sender_user: Option<&UserId>, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName],
_third_party_signed: Option<&ThirdPartySigned>,
) -> Result<join_room_by_id::v3::Response> {
@ -1525,7 +1528,7 @@ pub(crate) async fn invite_helper(
// Make a user leave all their joined rooms, forgets all rooms, and ignores
// errors
pub(crate) async fn leave_all_rooms(user_id: &UserId) {
pub async fn leave_all_rooms(user_id: &UserId) {
let all_rooms = services()
.rooms
.state_cache
@ -1550,7 +1553,7 @@ pub(crate) async fn leave_all_rooms(user_id: &UserId) {
}
}
pub(crate) async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> {
pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> {
// Ask a remote server if we don't have this room
if !services()
.rooms

View File

@ -3,6 +3,7 @@ use std::{
sync::Arc,
};
use conduit::PduCount;
use ruma::{
api::client::{
error::ErrorKind,
@ -14,10 +15,7 @@ use ruma::{
};
use serde_json::{from_str, Value};
use crate::{
service::{pdu::PduBuilder, rooms::timeline::PduCount},
services, utils, Error, PduEvent, Result, Ruma,
};
use crate::{service::pdu::PduBuilder, services, utils, Error, PduEvent, Result, Ruma};
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
///

View File

@ -1,40 +1,41 @@
mod account;
mod alias;
mod backup;
mod capabilities;
mod config;
mod context;
mod device;
mod directory;
mod filter;
mod keys;
mod media;
mod membership;
mod message;
mod presence;
mod profile;
mod push;
mod read_marker;
mod redact;
mod relations;
mod report;
mod room;
mod search;
mod session;
mod space;
mod state;
mod sync;
mod tag;
mod thirdparty;
mod threads;
mod to_device;
mod typing;
mod unstable;
mod unversioned;
mod user_directory;
mod voip;
pub(crate) mod account;
pub(crate) mod alias;
pub(crate) mod backup;
pub(crate) mod capabilities;
pub(crate) mod config;
pub(crate) mod context;
pub(crate) mod device;
pub(crate) mod directory;
pub(crate) mod filter;
pub(crate) mod keys;
pub(crate) mod media;
pub(crate) mod membership;
pub(crate) mod message;
pub(crate) mod presence;
pub(crate) mod profile;
pub(crate) mod push;
pub(crate) mod read_marker;
pub(crate) mod redact;
pub(crate) mod relations;
pub(crate) mod report;
pub(crate) mod room;
pub(crate) mod search;
pub(crate) mod session;
pub(crate) mod space;
pub(crate) mod state;
pub(crate) mod sync;
pub(crate) mod tag;
pub(crate) mod thirdparty;
pub(crate) mod threads;
pub(crate) mod to_device;
pub(crate) mod typing;
pub(crate) mod unstable;
pub(crate) mod unversioned;
pub(crate) mod user_directory;
pub(crate) mod voip;
pub(crate) use account::*;
pub use alias::get_alias_helper;
pub(crate) use alias::*;
pub(crate) use backup::*;
pub(crate) use capabilities::*;
@ -46,6 +47,7 @@ pub(crate) use filter::*;
pub(crate) use keys::*;
pub(crate) use media::*;
pub(crate) use membership::*;
pub use membership::{join_room_by_id_helper, leave_all_rooms, leave_room};
pub(crate) use message::*;
pub(crate) use presence::*;
pub(crate) use profile::*;
@ -77,7 +79,4 @@ const DEVICE_ID_LENGTH: usize = 10;
const TOKEN_LENGTH: usize = 32;
/// generated user session ID length
pub(crate) const SESSION_ID_LENGTH: usize = 32;
/// auto-generated password length
pub(crate) const AUTO_GEN_PASSWORD_LENGTH: usize = 25;
const SESSION_ID_LENGTH: usize = service::uiaa::SESSION_ID_LENGTH;

View File

@ -13,7 +13,10 @@ use ruma::{
};
use serde_json::value::to_raw_value;
use crate::{service::pdu::PduBuilder, services, utils::user_id::user_is_local, Error, Result, Ruma};
use crate::{
service::{pdu::PduBuilder, user_is_local},
services, Error, Result, Ruma,
};
/// # `PUT /_matrix/client/r0/profile/{userId}/displayname`
///

View File

@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use conduit::PduCount;
use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
events::{
@ -9,7 +10,7 @@ use ruma::{
MilliSecondsSinceUnixEpoch,
};
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
use crate::{services, Error, Result, Ruma};
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
///

View File

@ -1,5 +1,6 @@
use std::{cmp::max, collections::BTreeMap, sync::Arc};
use conduit::{debug_info, debug_warn};
use ruma::{
api::client::{
error::ErrorKind,
@ -28,8 +29,7 @@ use serde_json::{json, value::to_raw_value};
use tracing::{error, info, warn};
use crate::{
api::client_server::invite_helper,
debug_info, debug_warn,
client_server::invite_helper,
service::{appservice::RegistrationInfo, pdu::PduBuilder},
services, Error, Result, Ruma,
};

View File

@ -19,10 +19,8 @@ use ruma::{
use tracing::{error, log::warn};
use crate::{
service::{self, pdu::PduBuilder},
services,
utils::server_name::server_is_ours,
Error, Result, Ruma, RumaResponse,
service::{pdu::PduBuilder, server_is_ours},
services, Error, Result, Ruma, RumaResponse,
};
/// # `PUT /_matrix/client/*/rooms/{roomId}/state/{eventType}/{stateKey}`

View File

@ -5,6 +5,7 @@ use std::{
time::Duration,
};
use conduit::PduCount;
use ruma::{
api::client::{
filter::{FilterDefinition, LazyLoadOptions},
@ -29,10 +30,7 @@ use ruma::{
};
use tracing::{error, Instrument as _, Span};
use crate::{
service::{pdu::EventHash, rooms::timeline::PduCount},
services, utils, Error, PduEvent, Result, Ruma, RumaResponse,
};
use crate::{service::pdu::EventHash, services, utils, Error, PduEvent, Result, Ruma, RumaResponse};
/// # `GET /_matrix/client/r0/sync`
///

View File

@ -8,7 +8,7 @@ use ruma::{
to_device::DeviceIdOrAllDevices,
};
use crate::{services, utils::user_id::user_is_local, Error, Result, Ruma};
use crate::{services, user_is_local, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
///

View File

@ -1,3 +1,15 @@
pub(crate) mod client_server;
pub mod client_server;
pub mod router;
pub(crate) mod ruma_wrapper;
pub(crate) mod server_server;
pub mod server_server;
extern crate conduit_core as conduit;
extern crate conduit_service as service;
pub use client_server::membership::{join_room_by_id_helper, leave_all_rooms};
pub(crate) use conduit::{debug_error, debug_info, debug_warn, error::RumaResponse, utils, Error, Result};
pub(crate) use ruma_wrapper::Ruma;
pub(crate) use service::{pdu::PduEvent, services, user_is_local};
conduit::mod_ctor! {}
conduit::mod_dtor! {}

View File

@ -6,16 +6,15 @@ use axum::{
routing::{any, get, on, post, MethodFilter},
Router,
};
use conduit::{Error, Result, Server};
use http::{Method, Uri};
use ruma::api::{client::error::ErrorKind, IncomingRequest};
use crate::{
api::{client_server, server_server},
Config, Error, Result, Ruma, RumaResponse,
};
use crate::{client_server, server_server, Ruma, RumaResponse};
pub(crate) fn routes(config: &Config) -> Router {
let router = Router::new()
pub fn build(router: Router, server: &Server) -> Router {
let config = &server.config;
let router = router
.ruma_route(client_server::get_supported_versions_route)
.ruma_route(client_server::get_register_available_route)
.ruma_route(client_server::register_route)
@ -187,9 +186,7 @@ pub(crate) fn routes(config: &Config) -> Router {
.route("/_conduwuit/server_version", get(client_server::conduwuit_server_version))
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync))
.route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync))
.route("/client/server.json", get(client_server::syncv3_client_server_json))
.route("/", get(it_works))
.fallback(not_found);
.route("/client/server.json", get(client_server::syncv3_client_server_json));
if config.allow_federation {
router
@ -230,16 +227,10 @@ pub(crate) fn routes(config: &Config) -> Router {
}
}
async fn not_found(_uri: Uri) -> impl IntoResponse {
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
}
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented")
}
async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }
async fn federation_disabled() -> impl IntoResponse { Error::bad_config("Federation is disabled.") }
trait RouterExt {
@ -259,7 +250,7 @@ impl RouterExt for Router {
}
}
pub(crate) trait RumaHandler<T> {
trait RumaHandler<T> {
// Can't transform to a handler without boxing or relying on the nightly-only
// impl-trait-in-traits feature. Moving a small amount of extra logic into the
// trait allows bypassing both.

View File

@ -3,30 +3,26 @@ use std::{collections::BTreeMap, str};
use axum::{
async_trait,
extract::{FromRequest, Path},
response::{IntoResponse, Response},
RequestExt, RequestPartsExt,
};
use axum_extra::{
headers::{
authorization::{Bearer, Credentials},
Authorization,
},
headers::{authorization::Bearer, Authorization},
typed_header::TypedHeaderRejectionReason,
TypedHeader,
};
use bytes::{BufMut, BytesMut};
use http::{uri::PathAndQuery, StatusCode};
use http_body_util::Full;
use conduit::debug_warn;
use http::uri::PathAndQuery;
use hyper::Request;
use ruma::{
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
api::{client::error::ErrorKind, AuthScheme, IncomingRequest},
CanonicalJsonValue, OwnedDeviceId, OwnedUserId, UserId,
};
use serde::Deserialize;
use tracing::{debug, error, trace, warn};
use super::{Ruma, RumaResponse};
use crate::{debug_warn, service::appservice::RegistrationInfo, services, Error, Result};
use super::{xmatrix::XMatrix, Ruma};
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
enum Token {
Appservice(Box<RegistrationInfo>),
@ -332,68 +328,3 @@ where
})
}
}
struct XMatrix {
origin: OwnedServerName,
destination: Option<String>,
key: String, // KeyName?
sig: String,
}
impl Credentials for XMatrix {
const SCHEME: &'static str = "X-Matrix";
fn decode(value: &http::HeaderValue) -> Option<Self> {
debug_assert!(
value.as_bytes().starts_with(b"X-Matrix "),
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
);
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
.ok()?
.trim_start();
let mut origin = None;
let mut destination = None;
let mut key = None;
let mut sig = None;
for entry in parameters.split_terminator(',') {
let (name, value) = entry.split_once('=')?;
// It's not at all clear why some fields are quoted and others not in the spec,
// let's simply accept either form for every field.
let value = value
.strip_prefix('"')
.and_then(|rest| rest.strip_suffix('"'))
.unwrap_or(value);
// FIXME: Catch multiple fields of the same name
match name {
"origin" => origin = Some(value.try_into().ok()?),
"key" => key = Some(value.to_owned()),
"sig" => sig = Some(value.to_owned()),
"destination" => destination = Some(value.to_owned()),
_ => debug!("Unexpected field `{name}` in X-Matrix Authorization header"),
}
}
Some(Self {
origin: origin?,
key: key?,
sig: sig?,
destination,
})
}
fn encode(&self) -> http::HeaderValue { todo!() }
}
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
fn into_response(self) -> Response {
match self.0.try_into_http_response::<BytesMut>() {
Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
}
}
}

View File

@ -1,10 +1,11 @@
pub(crate) mod axum;
mod xmatrix;
use std::ops::Deref;
use ruma::{api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId};
use ruma::{CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId};
use crate::{service::appservice::RegistrationInfo, Error};
mod axum;
use crate::service::appservice::RegistrationInfo;
/// Extractor for Ruma request structs
pub(crate) struct Ruma<T> {
@ -21,14 +22,3 @@ impl<T> Deref for Ruma<T> {
fn deref(&self) -> &Self::Target { &self.body }
}
#[derive(Clone)]
pub(crate) struct RumaResponse<T>(pub(crate) T);
impl<T> From<T> for RumaResponse<T> {
fn from(t: T) -> Self { Self(t) }
}
impl From<Error> for RumaResponse<UiaaResponse> {
fn from(t: Error) -> Self { t.to_response() }
}

View File

@ -0,0 +1,61 @@
use std::str;
use axum_extra::headers::authorization::Credentials;
use ruma::OwnedServerName;
use tracing::debug;
pub(crate) struct XMatrix {
pub(crate) origin: OwnedServerName,
pub(crate) destination: Option<String>,
pub(crate) key: String, // KeyName?
pub(crate) sig: String,
}
impl Credentials for XMatrix {
const SCHEME: &'static str = "X-Matrix";
fn decode(value: &http::HeaderValue) -> Option<Self> {
debug_assert!(
value.as_bytes().starts_with(b"X-Matrix "),
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
);
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
.ok()?
.trim_start();
let mut origin = None;
let mut destination = None;
let mut key = None;
let mut sig = None;
for entry in parameters.split_terminator(',') {
let (name, value) = entry.split_once('=')?;
// It's not at all clear why some fields are quoted and others not in the spec,
// let's simply accept either form for every field.
let value = value
.strip_prefix('"')
.and_then(|rest| rest.strip_suffix('"'))
.unwrap_or(value);
// FIXME: Catch multiple fields of the same name
match name {
"origin" => origin = Some(value.try_into().ok()?),
"key" => key = Some(value.to_owned()),
"sig" => sig = Some(value.to_owned()),
"destination" => destination = Some(value.to_owned()),
_ => debug!("Unexpected field `{name}` in X-Matrix Authorization header"),
}
}
Some(Self {
origin: origin?,
key: key?,
sig: sig?,
destination,
})
}
fn encode(&self) -> http::HeaderValue { todo!() }
}

View File

@ -44,19 +44,23 @@ use ruma::{
},
serde::{Base64, JsonObject, Raw},
to_device::DeviceIdOrAllDevices,
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId,
OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, RoomVersionId, ServerName,
uint, user_id, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedServerName,
OwnedServerSigningKeyId, OwnedUserId, RoomId, RoomVersionId, ServerName,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use tracing::{debug, error, trace, warn};
use crate::{
api::client_server::{self, claim_keys_helper, get_keys_helper},
client_server::{self, claim_keys_helper, get_keys_helper},
debug_error,
service::pdu::{gen_event_id_canonical_json, PduBuilder},
service::{
pdu::{gen_event_id_canonical_json, PduBuilder},
rooms::event_handler::parse_incoming_pdu,
server_is_ours, user_is_local,
},
services,
utils::{self, server_name::server_is_ours, user_id::user_is_local},
utils::{self},
Error, PduEvent, Result, Ruma,
};
@ -196,32 +200,6 @@ pub(crate) async fn get_public_rooms_route(
})
}
pub(crate) fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let room_id: OwnedRoomId = value
.get("room_id")
.and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?;
let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else {
return Err(Error::Err(format!("Server is not in room {room_id}")));
};
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Could not convert event to canonical json.",
));
};
Ok((event_id, value, room_id))
}
/// # `PUT /_matrix/federation/v1/send/{txnId}`
///
/// Push EDUs and PDUs to this server.

123
src/bin/Cargo.toml Normal file
View File

@ -0,0 +1,123 @@
[package]
# TODO: when can we rename to conduwuit?
name = "conduit"
default-run = "conduit"
description.workspace = true
license.workspace = true
authors.workspace = true
homepage.workspace = true
repository.workspace = true
readme.workspace = true
version.workspace = true
edition.workspace = true
rust-version.workspace = true
[package.metadata.deb]
name = "conduwuit"
maintainer = "strawberry <strawberry@puppygock.gay>"
copyright = "2024, strawberry <strawberry@puppygock.gay>"
license-file = ["LICENSE", "3"]
depends = "$auto, ca-certificates"
extended-description = """\
a cool hard fork of Conduit, a Matrix homeserver written in Rust"""
section = "net"
priority = "optional"
conf-files = ["/etc/conduwuit/conduwuit.toml"]
maintainer-scripts = "debian/"
systemd-units = { unit-name = "conduwuit", start = false }
assets = [
["debian/README.md", "usr/share/doc/conduwuit/README.Debian", "644"],
["README.md", "usr/share/doc/conduwuit/", "644"],
["target/release/conduwuit", "usr/sbin/conduwuit", "755"],
["conduwuit-example.toml", "etc/conduwuit/conduwuit.toml", "640"],
]
[features]
default = [
"sentry_telemetry",
"release_max_log_level",
]
# increases performance, reduces build times, and reduces binary size by not compiling or
# genreating code for log level filters that users will generally not use (debug and trace)
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sentry_telemetry = [
"dep:sentry",
"dep:sentry-tracing",
"dep:sentry-tower",
]
# enable the tokio_console server ncompatible with release_max_log_level
tokio_console = [
"dep:console-subscriber",
"tokio/tracing",
]
perf_measurements = [
"dep:opentelemetry",
"dep:tracing-flame",
"dep:tracing-opentelemetry",
"dep:opentelemetry_sdk",
"dep:opentelemetry-jaeger",
]
jemalloc = [
"dep:tikv-jemallocator",
]
panic_trap = []
mods = []
[dependencies]
conduit-router.workspace = true
conduit-admin.workspace = true
conduit-api.workspace = true
conduit-service.workspace = true
conduit-database.workspace = true
conduit-core.workspace = true
tokio.workspace = true
log.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
clap.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
opentelemetry.optional = true
tracing-flame.workspace = true
tracing-flame.optional = true
tracing-opentelemetry.workspace = true
tracing-opentelemetry.optional = true
opentelemetry_sdk.workspace = true
opentelemetry_sdk.optional = true
opentelemetry-jaeger.workspace = true
opentelemetry-jaeger.optional = true
sentry.workspace = true
sentry.optional = true
sentry-tracing.workspace = true
sentry-tracing.optional = true
sentry-tower.workspace = true
sentry-tower.optional = true
tikv-jemallocator.workspace = true
tikv-jemallocator.optional = true
tokio-metrics.workspace = true
tokio-metrics.optional = true
console-subscriber.workspace = true
console-subscriber.optional = true
[target.'cfg(all(not(target_env = "msvc"), target_os = "linux"))'.dependencies]
hardened_malloc-rs.workspace = true
hardened_malloc-rs.optional = true
[lints]
workspace = true
[[bin]]
name = "conduit"
path = "main.rs"

96
src/bin/main.rs Normal file
View File

@ -0,0 +1,96 @@
mod mods;
mod server;
extern crate conduit_core as conduit;
use std::{cmp, sync::Arc, time::Duration};
use conduit::{debug_info, error, utils::clap, Error, Result};
use server::Server;
use tokio::runtime;
const WORKER_NAME: &str = "conduwuit:worker";
const WORKER_MIN: usize = 2;
const WORKER_KEEPALIVE_MS: u64 = 2500;
fn main() -> Result<(), Error> {
let args = clap::parse();
let runtime = runtime::Builder::new_multi_thread()
.enable_io()
.enable_time()
.thread_name(WORKER_NAME)
.worker_threads(cmp::max(WORKER_MIN, num_cpus::get()))
.thread_keep_alive(Duration::from_millis(WORKER_KEEPALIVE_MS))
.build()
.expect("built runtime");
let handle = runtime.handle();
let server: Arc<Server> = Server::build(args, Some(handle))?;
runtime.block_on(async { async_main(server.clone()).await })?;
// explicit drop here to trace thread and tls dtors
drop(runtime);
debug_info!("Exit");
Ok(())
}
/// Operate the server normally in release-mode static builds. This will start,
/// run and stop the server within the asynchronous runtime.
#[cfg(not(feature = "mods"))]
async fn async_main(server: Arc<Server>) -> Result<(), Error> {
extern crate conduit_router as router;
use tracing::error;
if let Err(error) = router::start(&server.server).await {
error!("Critical error starting server: {error}");
return Err(error);
}
if let Err(error) = router::run(&server.server).await {
error!("Critical error running server: {error}");
return Err(error);
}
if let Err(error) = router::stop(&server.server).await {
error!("Critical error stopping server: {error}");
return Err(error);
}
debug_info!("Exit runtime");
Ok(())
}
/// Operate the server in developer-mode dynamic builds. This will start, run,
/// and hot-reload portions of the server as-needed before returning for an
/// actual shutdown. This is not available in release-mode or static builds.
#[cfg(feature = "mods")]
async fn async_main(server: Arc<Server>) -> Result<(), Error> {
let mut starts = true;
let mut reloads = true;
while reloads {
if let Err(error) = mods::open(&server).await {
error!("Loading router: {error}");
return Err(error);
}
let result = mods::run(&server, starts).await;
if let Ok(result) = result {
(starts, reloads) = result;
}
let force = !reloads || result.is_err();
if let Err(error) = mods::close(&server, force).await {
error!("Unloading router: {error}");
return Err(error);
}
if let Err(error) = result {
error!("{error}");
return Err(error);
}
}
debug_info!("Exit runtime");
Ok(())
}

129
src/bin/mods.rs Normal file
View File

@ -0,0 +1,129 @@
#![cfg(feature = "mods")]
#[cfg(not(any(clippy, debug_assertions, doctest, test)))]
compile_error!("Feature 'mods' is only available in developer builds");
use std::{
future::Future,
pin::Pin,
sync::{atomic::Ordering, Arc},
};
use conduit::{mods, Error, Result};
use tracing::{debug, error};
use crate::Server;
type RunFuncResult = Pin<Box<dyn Future<Output = Result<(), Error>>>>;
type RunFuncProto = fn(&Arc<conduit::Server>) -> RunFuncResult;
const RESTART_THRESH: &str = "conduit_service";
const MODULE_NAMES: &[&str] = &[
//"conduit_core",
"conduit_database",
"conduit_service",
"conduit_api",
"conduit_admin",
"conduit_router",
];
#[cfg(feature = "panic_trap")]
conduit::mod_init! {{
conduit::debug::set_panic_trap();
}}
pub(crate) async fn run(server: &Arc<Server>, starts: bool) -> Result<(bool, bool), Error> {
let main_lock = server.mods.read().await;
let main_mod = (*main_lock).last().expect("main module loaded");
if starts {
let start = main_mod.get::<RunFuncProto>("start")?;
if let Err(error) = start(&server.server).await {
error!("Starting server: {error}");
return Err(error);
}
}
let run = main_mod.get::<RunFuncProto>("run")?;
if let Err(error) = run(&server.server).await {
error!("Running server: {error}");
return Err(error);
}
let reloads = server.server.reload.swap(false, Ordering::AcqRel);
let stops = !reloads || stale(server).await? <= restart_thresh();
let starts = reloads && stops;
if stops {
let stop = main_mod.get::<RunFuncProto>("stop")?;
if let Err(error) = stop(&server.server).await {
error!("Stopping server: {error}");
return Err(error);
}
}
Ok((starts, reloads))
}
pub(crate) async fn open(server: &Arc<Server>) -> Result<usize, Error> {
let mut mods_lock = server.mods.write().await;
let mods: &mut Vec<mods::Module> = &mut mods_lock;
debug!(
available = %available(),
loaded = %mods.len(),
"Loading modules",
);
for (i, name) in MODULE_NAMES.iter().enumerate() {
if mods.get(i).is_none() {
mods.push(mods::Module::from_name(name)?);
}
}
Ok(mods.len())
}
pub(crate) async fn close(server: &Arc<Server>, force: bool) -> Result<usize, Error> {
let stale = stale_count(server).await;
let mut mods_lock = server.mods.write().await;
let mods: &mut Vec<mods::Module> = &mut mods_lock;
debug!(
available = %available(),
loaded = %mods.len(),
stale = %stale,
force,
"Unloading modules",
);
while mods.last().is_some() {
let module = &mods.last().expect("module");
if force || module.deleted()? {
mods.pop();
} else {
break;
}
}
Ok(mods.len())
}
async fn stale_count(server: &Arc<Server>) -> usize {
let watermark = stale(server).await.unwrap_or(available());
available() - watermark
}
async fn stale(server: &Arc<Server>) -> Result<usize, Error> {
let mods_lock = server.mods.read().await;
let mods: &Vec<mods::Module> = &mods_lock;
for (i, module) in mods.iter().enumerate() {
if module.deleted()? {
return Ok(i);
}
}
Ok(mods.len())
}
fn restart_thresh() -> usize {
MODULE_NAMES
.iter()
.position(|&name| name.ends_with(RESTART_THRESH))
.unwrap_or(MODULE_NAMES.len())
}
const fn available() -> usize { MODULE_NAMES.len() }

186
src/bin/server.rs Normal file
View File

@ -0,0 +1,186 @@
use std::sync::Arc;
use conduit::{
conduwuit_version,
config::Config,
info,
log::{LogLevelReloadHandles, ReloadHandle},
utils::{clap, maximize_fd_limit},
Error, Result,
};
use tokio::runtime;
use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry};
/// Server runtime state; complete
pub(crate) struct Server {
/// Server runtime state; public portion
pub(crate) server: Arc<conduit::Server>,
_tracing_flame_guard: TracingFlameGuard,
#[cfg(feature = "sentry_telemetry")]
_sentry_guard: Option<sentry::ClientInitGuard>,
// Module instances; TODO: move to mods::loaded mgmt vector
#[cfg(feature = "mods")]
pub(crate) mods: tokio::sync::RwLock<Vec<conduit::mods::Module>>,
}
impl Server {
pub(crate) fn build(args: clap::Args, runtime: Option<&runtime::Handle>) -> Result<Arc<Server>, Error> {
let config = Config::new(args.config)?;
#[cfg(feature = "sentry_telemetry")]
let sentry_guard = init_sentry(&config);
let (tracing_reload_handle, tracing_flame_guard) = init_tracing(&config);
config.check()?;
#[cfg(unix)]
maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit");
info!(
server_name = %config.server_name,
database_path = ?config.database_path,
log_levels = %config.log,
"{}",
conduwuit_version(),
);
Ok(Arc::new(Server {
server: Arc::new(conduit::Server::new(config, runtime.cloned(), tracing_reload_handle)),
_tracing_flame_guard: tracing_flame_guard,
#[cfg(feature = "sentry_telemetry")]
_sentry_guard: sentry_guard,
#[cfg(feature = "mods")]
mods: tokio::sync::RwLock::new(Vec::new()),
}))
}
}
#[cfg(feature = "sentry_telemetry")]
fn init_sentry(config: &Config) -> Option<sentry::ClientInitGuard> {
if !config.sentry {
return None;
}
let sentry_endpoint = config
.sentry_endpoint
.as_ref()
.expect("init_sentry should only be called if sentry is enabled and this is not None")
.as_str();
let server_name = if config.sentry_send_server_name {
Some(config.server_name.to_string().into())
} else {
None
};
Some(sentry::init((
sentry_endpoint,
sentry::ClientOptions {
release: sentry::release_name!(),
traces_sample_rate: config.sentry_traces_sample_rate,
server_name,
..Default::default()
},
)))
}
#[cfg(feature = "perf_measurements")]
type TracingFlameGuard = Option<tracing_flame::FlushGuard<std::io::BufWriter<std::fs::File>>>;
#[cfg(not(feature = "perf_measurements"))]
type TracingFlameGuard = ();
// clippy thinks the filter_layer clones are redundant if the next usage is
// behind a disabled feature.
#[allow(clippy::redundant_clone)]
fn init_tracing(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard) {
let registry = Registry::default();
let fmt_layer = tracing_subscriber::fmt::Layer::new();
let filter_layer = match EnvFilter::try_new(&config.log) {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}");
EnvFilter::try_new("warn").unwrap()
},
};
let mut reload_handles = Vec::<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>::new();
let subscriber = registry;
#[cfg(feature = "tokio_console")]
let subscriber = {
let console_layer = console_subscriber::spawn();
subscriber.with(console_layer)
};
let (fmt_reload_filter, fmt_reload_handle) = reload::Layer::new(filter_layer.clone());
reload_handles.push(Box::new(fmt_reload_handle));
let subscriber = subscriber.with(fmt_layer.with_filter(fmt_reload_filter));
#[cfg(feature = "sentry_telemetry")]
let subscriber = {
let sentry_layer = sentry_tracing::layer();
let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(filter_layer.clone());
reload_handles.push(Box::new(sentry_reload_handle));
subscriber.with(sentry_layer.with_filter(sentry_reload_filter))
};
#[cfg(feature = "perf_measurements")]
let (subscriber, flame_guard) = {
let (flame_layer, flame_guard) = if config.tracing_flame {
let flame_filter = match EnvFilter::try_new(&config.tracing_flame_filter) {
Ok(flame_filter) => flame_filter,
Err(e) => panic!("tracing_flame_filter config value is invalid: {e}"),
};
let (flame_layer, flame_guard) =
match tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) {
Ok(ok) => ok,
Err(e) => {
panic!("failed to initialize tracing-flame: {e}");
},
};
let flame_layer = flame_layer
.with_empty_samples(false)
.with_filter(flame_filter);
(Some(flame_layer), Some(flame_guard))
} else {
(None, None)
};
let jaeger_layer = if config.allow_jaeger {
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_auto_split_batch(true)
.with_service_name("conduwuit")
.install_batch(opentelemetry_sdk::runtime::Tokio)
.unwrap();
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(filter_layer);
reload_handles.push(Box::new(jaeger_reload_handle));
Some(telemetry.with_filter(jaeger_reload_filter))
} else {
None
};
let subscriber = subscriber.with(flame_layer).with(jaeger_layer);
(subscriber, flame_guard)
};
#[cfg(not(feature = "perf_measurements"))]
#[cfg_attr(not(feature = "perf_measurements"), allow(clippy::let_unit_value))]
let flame_guard = ();
tracing::subscriber::set_global_default(subscriber).unwrap();
#[cfg(all(feature = "tokio_console", feature = "release_max_log_level"))]
tracing::error!(
"'tokio_console' feature and 'release_max_log_level' feature are incompatible, because console-subscriber \
needs access to trace-level events. 'release_max_log_level' must be disabled to use tokio-console."
);
(LogLevelReloadHandles::new(reload_handles), flame_guard)
}

133
src/core/Cargo.toml Normal file
View File

@ -0,0 +1,133 @@
[package]
name = "conduit_core"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"rocksdb",
"io_uring",
"jemalloc",
"gzip_compression",
"zstd_compression",
"brotli_compression",
"sentry_telemetry",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sqlite = [
"dep:rusqlite",
"dep:parking_lot",
"dep:thread_local",
]
rocksdb = [
"dep:rust-rocksdb",
]
jemalloc = [
"dep:tikv-jemalloc-sys",
"dep:tikv-jemalloc-ctl",
"dep:tikv-jemallocator",
"rust-rocksdb/jemalloc",
]
jemalloc_prof = [
"tikv-jemalloc-sys/profiling",
]
hardened_malloc = [
"dep:hardened_malloc-rs"
]
io_uring = [
"rust-rocksdb/io-uring",
]
zstd_compression = [
"rust-rocksdb/zstd",
]
gzip_compression = [
"reqwest/gzip",
]
brotli_compression = [
"reqwest/brotli",
]
perf_measurements = []
sentry_telemetry = []
mods = [
"dep:libloading"
]
panic_trap = []
[dependencies]
async-trait.workspace = true
axum-server.workspace = true
axum.workspace = true
base64.workspace = true
bytes.workspace = true
clap.workspace = true
cyborgtime.workspace = true
either.workspace = true
figment.workspace = true
futures-util.workspace = true
http-body-util.workspace = true
http.workspace = true
image.workspace = true
infer.workspace = true
ipaddress.workspace = true
itertools.workspace = true
libloading.workspace = true
libloading.optional = true
log.workspace = true
lru-cache.workspace = true
parking_lot.optional = true
parking_lot.workspace = true
rand.workspace = true
regex.workspace = true
reqwest.workspace = true
ring.workspace = true
ruma.workspace = true
rusqlite.optional = true
rusqlite.workspace = true
rust-rocksdb.optional = true
rust-rocksdb.workspace = true
sanitize-filename.workspace = true
serde_json.workspace = true
serde_regex.workspace = true
serde.workspace = true
serde_yaml.workspace = true
sha-1.workspace = true
thiserror.workspace = true
thread_local.optional = true
thread_local.workspace = true
tikv-jemallocator.optional = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl.optional = true
tikv-jemalloc-ctl.workspace = true
tikv-jemalloc-sys.optional = true
tikv-jemalloc-sys.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
url.workspace = true
zstd.optional = true
zstd.workspace = true
[target.'cfg(unix)'.dependencies]
nix.workspace = true
[target.'cfg(all(not(target_env = "msvc"), target_os = "linux"))'.dependencies]
hardened_malloc-rs.workspace = true
hardened_malloc-rs.optional = true
[lints]
workspace = true

View File

@ -0,0 +1,9 @@
//! Default allocator with no special features
/// Always returns the empty string
#[must_use]
pub fn memory_stats() -> String { Default::default() }
/// Always returns the empty string
#[must_use]
pub fn memory_usage() -> String { Default::default() }

View File

@ -0,0 +1,10 @@
#[global_allocator]
static HMALLOC: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc;
#[must_use]
pub fn memory_usage() -> String {
String::default() //TODO: get usage
}
#[must_use]
pub fn memory_stats() -> String { "Extended statistics are not available from hardened_malloc.".to_owned() }

View File

@ -7,7 +7,8 @@ use tikv_jemallocator as jemalloc;
#[global_allocator]
static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc;
pub(crate) fn memory_usage() -> String {
#[must_use]
pub fn memory_usage() -> String {
use mallctl::stats;
let allocated = stats::allocated::read().unwrap_or_default() as f64 / 1024.0 / 1024.0;
let active = stats::active::read().unwrap_or_default() as f64 / 1024.0 / 1024.0;
@ -21,7 +22,8 @@ pub(crate) fn memory_usage() -> String {
)
}
pub(crate) fn memory_stats() -> String {
#[must_use]
pub fn memory_stats() -> String {
const MAX_LENGTH: usize = 65536 - 4096;
let opts_s = "d";

View File

@ -2,24 +2,24 @@
// jemalloc
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")))]
mod je;
pub mod je;
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc", not(feature = "hardened_malloc")))]
pub(crate) use je::{memory_stats, memory_usage};
pub use je::{memory_stats, memory_usage};
// hardened_malloc
#[cfg(all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")))]
mod hardened;
pub mod hardened;
#[cfg(all(not(target_env = "msvc"), feature = "hardened_malloc", target_os = "linux", not(feature = "jemalloc")))]
pub(crate) use hardened::{memory_stats, memory_usage};
pub use hardened::{memory_stats, memory_usage};
// default, enabled when none or multiple of the above are enabled
#[cfg(any(
not(any(feature = "jemalloc", feature = "hardened_malloc")),
all(feature = "jemalloc", feature = "hardened_malloc"),
))]
mod default;
pub mod default;
#[cfg(any(
not(any(feature = "jemalloc", feature = "hardened_malloc")),
all(feature = "jemalloc", feature = "hardened_malloc"),
))]
pub(crate) use default::{memory_stats, memory_usage};
pub use default::{memory_stats, memory_usage};

View File

@ -3,9 +3,9 @@ use std::path::Path; // not unix specific, just only for UNIX sockets stuff and
use tracing::{debug, error, info, warn};
use crate::{utils::error::Error, Config};
use crate::{error::Error, Config};
pub(crate) fn check(config: &Config) -> Result<(), Error> {
pub fn check(config: &Config) -> Result<(), Error> {
config.warn_deprecated();
config.warn_unknown_key();

View File

@ -22,11 +22,12 @@ use serde::{de::IgnoredAny, Deserialize};
use tracing::{debug, error, warn};
use url::Url;
use self::{check::check, proxy::ProxyConfig};
use crate::utils::error::Error;
pub use self::check::check;
use self::proxy::ProxyConfig;
use crate::error::Error;
pub(crate) mod check;
mod proxy;
pub mod check;
pub mod proxy;
#[derive(Deserialize, Clone, Debug)]
#[serde(transparent)]
@ -38,310 +39,310 @@ struct ListeningPort {
/// all the config options for conduwuit
#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::struct_excessive_bools)]
pub(crate) struct Config {
pub struct Config {
/// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6)
#[serde(default = "default_address")]
pub(crate) address: IpAddr,
pub address: IpAddr,
/// default TCP port(s) conduwuit will listen on
#[serde(default = "default_port")]
port: ListeningPort,
pub(crate) tls: Option<TlsConfig>,
pub(crate) unix_socket_path: Option<PathBuf>,
pub tls: Option<TlsConfig>,
pub unix_socket_path: Option<PathBuf>,
#[serde(default = "default_unix_socket_perms")]
pub(crate) unix_socket_perms: u32,
pub(crate) server_name: OwnedServerName,
pub unix_socket_perms: u32,
pub server_name: OwnedServerName,
#[serde(default = "default_database_backend")]
pub(crate) database_backend: String,
pub(crate) database_path: PathBuf,
pub(crate) database_backup_path: Option<PathBuf>,
pub database_backend: String,
pub database_path: PathBuf,
pub database_backup_path: Option<PathBuf>,
#[serde(default = "default_database_backups_to_keep")]
pub(crate) database_backups_to_keep: i16,
pub database_backups_to_keep: i16,
#[serde(default = "default_db_cache_capacity_mb")]
pub(crate) db_cache_capacity_mb: f64,
pub db_cache_capacity_mb: f64,
#[serde(default = "default_new_user_displayname_suffix")]
pub(crate) new_user_displayname_suffix: String,
pub new_user_displayname_suffix: String,
#[serde(default)]
pub(crate) allow_check_for_updates: bool,
pub allow_check_for_updates: bool,
#[serde(default = "default_pdu_cache_capacity")]
pub(crate) pdu_cache_capacity: u32,
pub pdu_cache_capacity: u32,
#[serde(default = "default_conduit_cache_capacity_modifier")]
pub(crate) conduit_cache_capacity_modifier: f64,
pub conduit_cache_capacity_modifier: f64,
#[serde(default = "default_auth_chain_cache_capacity")]
pub(crate) auth_chain_cache_capacity: u32,
pub auth_chain_cache_capacity: u32,
#[serde(default = "default_shorteventid_cache_capacity")]
pub(crate) shorteventid_cache_capacity: u32,
pub shorteventid_cache_capacity: u32,
#[serde(default = "default_eventidshort_cache_capacity")]
pub(crate) eventidshort_cache_capacity: u32,
pub eventidshort_cache_capacity: u32,
#[serde(default = "default_shortstatekey_cache_capacity")]
pub(crate) shortstatekey_cache_capacity: u32,
pub shortstatekey_cache_capacity: u32,
#[serde(default = "default_statekeyshort_cache_capacity")]
pub(crate) statekeyshort_cache_capacity: u32,
pub statekeyshort_cache_capacity: u32,
#[serde(default = "default_server_visibility_cache_capacity")]
pub(crate) server_visibility_cache_capacity: u32,
pub server_visibility_cache_capacity: u32,
#[serde(default = "default_user_visibility_cache_capacity")]
pub(crate) user_visibility_cache_capacity: u32,
pub user_visibility_cache_capacity: u32,
#[serde(default = "default_stateinfo_cache_capacity")]
pub(crate) stateinfo_cache_capacity: u32,
pub stateinfo_cache_capacity: u32,
#[serde(default = "default_roomid_spacehierarchy_cache_capacity")]
pub(crate) roomid_spacehierarchy_cache_capacity: u32,
pub roomid_spacehierarchy_cache_capacity: u32,
#[serde(default = "default_cleanup_second_interval")]
pub(crate) cleanup_second_interval: u32,
pub cleanup_second_interval: u32,
#[serde(default = "default_dns_cache_entries")]
pub(crate) dns_cache_entries: u32,
pub dns_cache_entries: u32,
#[serde(default = "default_dns_min_ttl")]
pub(crate) dns_min_ttl: u64,
pub dns_min_ttl: u64,
#[serde(default = "default_dns_min_ttl_nxdomain")]
pub(crate) dns_min_ttl_nxdomain: u64,
pub dns_min_ttl_nxdomain: u64,
#[serde(default = "default_dns_attempts")]
pub(crate) dns_attempts: u16,
pub dns_attempts: u16,
#[serde(default = "default_dns_timeout")]
pub(crate) dns_timeout: u64,
pub dns_timeout: u64,
#[serde(default = "true_fn")]
pub(crate) dns_tcp_fallback: bool,
pub dns_tcp_fallback: bool,
#[serde(default = "true_fn")]
pub(crate) query_all_nameservers: bool,
pub query_all_nameservers: bool,
#[serde(default)]
pub(crate) query_over_tcp_only: bool,
pub query_over_tcp_only: bool,
#[serde(default = "default_ip_lookup_strategy")]
pub(crate) ip_lookup_strategy: u8,
pub ip_lookup_strategy: u8,
#[serde(default = "default_max_request_size")]
pub(crate) max_request_size: u32,
pub max_request_size: u32,
#[serde(default = "default_max_fetch_prev_events")]
pub(crate) max_fetch_prev_events: u16,
pub max_fetch_prev_events: u16,
#[serde(default = "default_request_conn_timeout")]
pub(crate) request_conn_timeout: u64,
pub request_conn_timeout: u64,
#[serde(default = "default_request_timeout")]
pub(crate) request_timeout: u64,
pub request_timeout: u64,
#[serde(default = "default_request_total_timeout")]
pub(crate) request_total_timeout: u64,
pub request_total_timeout: u64,
#[serde(default = "default_request_idle_timeout")]
pub(crate) request_idle_timeout: u64,
pub request_idle_timeout: u64,
#[serde(default = "default_request_idle_per_host")]
pub(crate) request_idle_per_host: u16,
pub request_idle_per_host: u16,
#[serde(default = "default_well_known_conn_timeout")]
pub(crate) well_known_conn_timeout: u64,
pub well_known_conn_timeout: u64,
#[serde(default = "default_well_known_timeout")]
pub(crate) well_known_timeout: u64,
pub well_known_timeout: u64,
#[serde(default = "default_federation_timeout")]
pub(crate) federation_timeout: u64,
pub federation_timeout: u64,
#[serde(default = "default_federation_idle_timeout")]
pub(crate) federation_idle_timeout: u64,
pub federation_idle_timeout: u64,
#[serde(default = "default_federation_idle_per_host")]
pub(crate) federation_idle_per_host: u16,
pub federation_idle_per_host: u16,
#[serde(default = "default_sender_timeout")]
pub(crate) sender_timeout: u64,
pub sender_timeout: u64,
#[serde(default = "default_sender_idle_timeout")]
pub(crate) sender_idle_timeout: u64,
pub sender_idle_timeout: u64,
#[serde(default = "default_sender_retry_backoff_limit")]
pub(crate) sender_retry_backoff_limit: u64,
pub sender_retry_backoff_limit: u64,
#[serde(default = "default_appservice_timeout")]
pub(crate) appservice_timeout: u64,
pub appservice_timeout: u64,
#[serde(default = "default_appservice_idle_timeout")]
pub(crate) appservice_idle_timeout: u64,
pub appservice_idle_timeout: u64,
#[serde(default = "default_pusher_idle_timeout")]
pub(crate) pusher_idle_timeout: u64,
pub pusher_idle_timeout: u64,
#[serde(default)]
pub(crate) allow_registration: bool,
pub allow_registration: bool,
#[serde(default)]
pub(crate) yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool,
pub(crate) registration_token: Option<String>,
pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool,
pub registration_token: Option<String>,
#[serde(default = "true_fn")]
pub(crate) allow_encryption: bool,
pub allow_encryption: bool,
#[serde(default = "true_fn")]
pub(crate) allow_federation: bool,
pub allow_federation: bool,
#[serde(default)]
pub(crate) allow_public_room_directory_over_federation: bool,
pub allow_public_room_directory_over_federation: bool,
#[serde(default)]
pub(crate) allow_public_room_directory_without_auth: bool,
pub allow_public_room_directory_without_auth: bool,
#[serde(default)]
pub(crate) lockdown_public_room_directory: bool,
pub lockdown_public_room_directory: bool,
#[serde(default)]
pub(crate) allow_device_name_federation: bool,
pub allow_device_name_federation: bool,
#[serde(default = "true_fn")]
pub(crate) allow_profile_lookup_federation_requests: bool,
pub allow_profile_lookup_federation_requests: bool,
#[serde(default = "true_fn")]
pub(crate) allow_room_creation: bool,
pub allow_room_creation: bool,
#[serde(default = "true_fn")]
pub(crate) allow_unstable_room_versions: bool,
pub allow_unstable_room_versions: bool,
#[serde(default = "default_default_room_version")]
pub(crate) default_room_version: RoomVersionId,
pub default_room_version: RoomVersionId,
#[serde(default)]
pub(crate) well_known: WellKnownConfig,
pub well_known: WellKnownConfig,
#[serde(default)]
#[cfg(feature = "perf_measurements")]
pub(crate) allow_jaeger: bool,
pub allow_jaeger: bool,
#[serde(default)]
#[cfg(feature = "perf_measurements")]
pub(crate) tracing_flame: bool,
pub tracing_flame: bool,
#[serde(default = "default_tracing_flame_filter")]
#[cfg(feature = "perf_measurements")]
pub(crate) tracing_flame_filter: String,
pub tracing_flame_filter: String,
#[serde(default = "default_tracing_flame_output_path")]
#[cfg(feature = "perf_measurements")]
pub(crate) tracing_flame_output_path: String,
pub tracing_flame_output_path: String,
#[serde(default)]
pub(crate) proxy: ProxyConfig,
pub(crate) jwt_secret: Option<String>,
pub proxy: ProxyConfig,
pub jwt_secret: Option<String>,
#[serde(default = "default_trusted_servers")]
pub(crate) trusted_servers: Vec<OwnedServerName>,
pub trusted_servers: Vec<OwnedServerName>,
#[serde(default = "true_fn")]
pub(crate) query_trusted_key_servers_first: bool,
pub query_trusted_key_servers_first: bool,
#[serde(default = "default_log")]
pub(crate) log: String,
pub log: String,
#[serde(default)]
pub(crate) turn_username: String,
pub turn_username: String,
#[serde(default)]
pub(crate) turn_password: String,
pub turn_password: String,
#[serde(default = "Vec::new")]
pub(crate) turn_uris: Vec<String>,
pub turn_uris: Vec<String>,
#[serde(default)]
pub(crate) turn_secret: String,
pub turn_secret: String,
#[serde(default = "default_turn_ttl")]
pub(crate) turn_ttl: u64,
pub turn_ttl: u64,
#[serde(default = "Vec::new")]
pub(crate) auto_join_rooms: Vec<OwnedRoomId>,
pub auto_join_rooms: Vec<OwnedRoomId>,
#[serde(default)]
pub(crate) auto_deactivate_banned_room_attempts: bool,
pub auto_deactivate_banned_room_attempts: bool,
#[serde(default = "default_rocksdb_log_level")]
pub(crate) rocksdb_log_level: String,
pub rocksdb_log_level: String,
#[serde(default)]
pub(crate) rocksdb_log_stderr: bool,
pub rocksdb_log_stderr: bool,
#[serde(default = "default_rocksdb_max_log_file_size")]
pub(crate) rocksdb_max_log_file_size: usize,
pub rocksdb_max_log_file_size: usize,
#[serde(default = "default_rocksdb_log_time_to_roll")]
pub(crate) rocksdb_log_time_to_roll: usize,
pub rocksdb_log_time_to_roll: usize,
#[serde(default)]
pub(crate) rocksdb_optimize_for_spinning_disks: bool,
pub rocksdb_optimize_for_spinning_disks: bool,
#[serde(default = "true_fn")]
pub(crate) rocksdb_direct_io: bool,
pub rocksdb_direct_io: bool,
#[serde(default = "default_rocksdb_parallelism_threads")]
pub(crate) rocksdb_parallelism_threads: usize,
pub rocksdb_parallelism_threads: usize,
#[serde(default = "default_rocksdb_max_log_files")]
pub(crate) rocksdb_max_log_files: usize,
pub rocksdb_max_log_files: usize,
#[serde(default = "default_rocksdb_compression_algo")]
pub(crate) rocksdb_compression_algo: String,
pub rocksdb_compression_algo: String,
#[serde(default = "default_rocksdb_compression_level")]
pub(crate) rocksdb_compression_level: i32,
pub rocksdb_compression_level: i32,
#[serde(default = "default_rocksdb_bottommost_compression_level")]
pub(crate) rocksdb_bottommost_compression_level: i32,
pub rocksdb_bottommost_compression_level: i32,
#[serde(default)]
pub(crate) rocksdb_bottommost_compression: bool,
pub rocksdb_bottommost_compression: bool,
#[serde(default = "default_rocksdb_recovery_mode")]
pub(crate) rocksdb_recovery_mode: u8,
pub rocksdb_recovery_mode: u8,
#[serde(default)]
pub(crate) rocksdb_repair: bool,
pub rocksdb_repair: bool,
#[serde(default)]
pub(crate) rocksdb_read_only: bool,
pub rocksdb_read_only: bool,
#[serde(default)]
pub(crate) rocksdb_periodic_cleanup: bool,
pub rocksdb_periodic_cleanup: bool,
#[serde(default)]
pub(crate) rocksdb_compaction_prio_idle: bool,
pub rocksdb_compaction_prio_idle: bool,
#[serde(default = "true_fn")]
pub(crate) rocksdb_compaction_ioprio_idle: bool,
pub rocksdb_compaction_ioprio_idle: bool,
pub(crate) emergency_password: Option<String>,
pub emergency_password: Option<String>,
#[serde(default = "default_notification_push_path")]
pub(crate) notification_push_path: String,
pub notification_push_path: String,
#[serde(default = "true_fn")]
pub(crate) allow_local_presence: bool,
pub allow_local_presence: bool,
#[serde(default = "true_fn")]
pub(crate) allow_incoming_presence: bool,
pub allow_incoming_presence: bool,
#[serde(default = "true_fn")]
pub(crate) allow_outgoing_presence: bool,
pub allow_outgoing_presence: bool,
#[serde(default = "default_presence_idle_timeout_s")]
pub(crate) presence_idle_timeout_s: u64,
pub presence_idle_timeout_s: u64,
#[serde(default = "default_presence_offline_timeout_s")]
pub(crate) presence_offline_timeout_s: u64,
pub presence_offline_timeout_s: u64,
#[serde(default = "true_fn")]
pub(crate) presence_timeout_remote_users: bool,
pub presence_timeout_remote_users: bool,
#[serde(default = "true_fn")]
pub(crate) allow_incoming_read_receipts: bool,
pub allow_incoming_read_receipts: bool,
#[serde(default = "true_fn")]
pub(crate) allow_outgoing_read_receipts: bool,
pub allow_outgoing_read_receipts: bool,
#[serde(default = "true_fn")]
pub(crate) allow_outgoing_typing: bool,
pub allow_outgoing_typing: bool,
#[serde(default = "true_fn")]
pub(crate) allow_incoming_typing: bool,
pub allow_incoming_typing: bool,
#[serde(default = "default_typing_federation_timeout_s")]
pub(crate) typing_federation_timeout_s: u64,
pub typing_federation_timeout_s: u64,
#[serde(default = "default_typing_client_timeout_min_s")]
pub(crate) typing_client_timeout_min_s: u64,
pub typing_client_timeout_min_s: u64,
#[serde(default = "default_typing_client_timeout_max_s")]
pub(crate) typing_client_timeout_max_s: u64,
pub typing_client_timeout_max_s: u64,
#[serde(default)]
pub(crate) zstd_compression: bool,
pub zstd_compression: bool,
#[serde(default)]
pub(crate) gzip_compression: bool,
pub gzip_compression: bool,
#[serde(default)]
pub(crate) brotli_compression: bool,
pub brotli_compression: bool,
#[serde(default)]
pub(crate) allow_guest_registration: bool,
pub allow_guest_registration: bool,
#[serde(default)]
pub(crate) log_guest_registrations: bool,
pub log_guest_registrations: bool,
#[serde(default)]
pub(crate) allow_guests_auto_join_rooms: bool,
pub allow_guests_auto_join_rooms: bool,
#[serde(default = "Vec::new")]
pub(crate) prevent_media_downloads_from: Vec<OwnedServerName>,
pub prevent_media_downloads_from: Vec<OwnedServerName>,
#[serde(default = "Vec::new")]
pub(crate) forbidden_remote_server_names: Vec<OwnedServerName>,
pub forbidden_remote_server_names: Vec<OwnedServerName>,
#[serde(default = "Vec::new")]
pub(crate) forbidden_remote_room_directory_server_names: Vec<OwnedServerName>,
pub forbidden_remote_room_directory_server_names: Vec<OwnedServerName>,
#[serde(default = "default_ip_range_denylist")]
pub(crate) ip_range_denylist: Vec<String>,
pub ip_range_denylist: Vec<String>,
#[serde(default = "Vec::new")]
pub(crate) url_preview_domain_contains_allowlist: Vec<String>,
pub url_preview_domain_contains_allowlist: Vec<String>,
#[serde(default = "Vec::new")]
pub(crate) url_preview_domain_explicit_allowlist: Vec<String>,
pub url_preview_domain_explicit_allowlist: Vec<String>,
#[serde(default = "Vec::new")]
pub(crate) url_preview_domain_explicit_denylist: Vec<String>,
pub url_preview_domain_explicit_denylist: Vec<String>,
#[serde(default = "Vec::new")]
pub(crate) url_preview_url_contains_allowlist: Vec<String>,
pub url_preview_url_contains_allowlist: Vec<String>,
#[serde(default = "default_url_preview_max_spider_size")]
pub(crate) url_preview_max_spider_size: usize,
pub url_preview_max_spider_size: usize,
#[serde(default)]
pub(crate) url_preview_check_root_domain: bool,
pub url_preview_check_root_domain: bool,
#[serde(default = "RegexSet::empty")]
#[serde(with = "serde_regex")]
pub(crate) forbidden_alias_names: RegexSet,
pub forbidden_alias_names: RegexSet,
#[serde(default = "RegexSet::empty")]
#[serde(with = "serde_regex")]
pub(crate) forbidden_usernames: RegexSet,
pub forbidden_usernames: RegexSet,
#[serde(default = "true_fn")]
pub(crate) startup_netburst: bool,
pub startup_netburst: bool,
#[serde(default = "default_startup_netburst_keep")]
pub(crate) startup_netburst_keep: i64,
pub startup_netburst_keep: i64,
#[serde(default)]
pub(crate) block_non_admin_invites: bool,
pub block_non_admin_invites: bool,
#[serde(default)]
pub(crate) sentry: bool,
pub sentry: bool,
#[serde(default = "default_sentry_endpoint")]
pub(crate) sentry_endpoint: Option<Url>,
pub sentry_endpoint: Option<Url>,
#[serde(default)]
pub(crate) sentry_send_server_name: bool,
pub sentry_send_server_name: bool,
#[serde(default = "default_sentry_traces_sample_rate")]
pub(crate) sentry_traces_sample_rate: f32,
pub sentry_traces_sample_rate: f32,
#[serde(flatten)]
#[allow(clippy::zero_sized_map_values)] // this is a catchall, the map shouldn't be zero at runtime
@ -349,24 +350,24 @@ pub(crate) struct Config {
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct TlsConfig {
pub(crate) certs: String,
pub(crate) key: String,
pub struct TlsConfig {
pub certs: String,
pub key: String,
#[serde(default)]
/// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
/// Only works / does something if the `axum_dual_protocol` feature flag was
/// built
pub(crate) dual_protocol: bool,
pub dual_protocol: bool,
}
#[derive(Clone, Debug, Deserialize, Default)]
pub(crate) struct WellKnownConfig {
pub(crate) client: Option<Url>,
pub(crate) server: Option<OwnedServerName>,
pub(crate) support_page: Option<Url>,
pub(crate) support_role: Option<ContactRole>,
pub(crate) support_email: Option<String>,
pub(crate) support_mxid: Option<OwnedUserId>,
pub struct WellKnownConfig {
pub client: Option<Url>,
pub server: Option<OwnedServerName>,
pub support_page: Option<Url>,
pub support_role: Option<ContactRole>,
pub support_email: Option<String>,
pub support_mxid: Option<OwnedUserId>,
}
const DEPRECATED_KEYS: &[&str] = &[
@ -382,7 +383,7 @@ const DEPRECATED_KEYS: &[&str] = &[
impl Config {
/// Initialize config
pub(crate) fn new(path: Option<PathBuf>) -> Result<Self, Error> {
pub fn new(path: Option<PathBuf>) -> Result<Self, Error> {
let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") {
Figment::new()
.merge(Toml::file(config_file_env).nested())
@ -469,7 +470,7 @@ impl Config {
}
#[must_use]
pub(crate) fn get_bind_addrs(&self) -> Vec<SocketAddr> {
pub fn get_bind_addrs(&self) -> Vec<SocketAddr> {
match &self.port.ports {
Left(port) => {
// Left is only 1 value, so make a vec with 1 value only
@ -489,7 +490,7 @@ impl Config {
}
}
pub(crate) fn check(&self) -> Result<(), Error> { check(self) }
pub fn check(&self) -> Result<(), Error> { check(self) }
}
impl fmt::Display for Config {
@ -1027,7 +1028,8 @@ fn default_rocksdb_compression_level() -> i32 { 32767 }
fn default_rocksdb_bottommost_compression_level() -> i32 { 32767 }
// I know, it's a great name
pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 }
#[must_use]
pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 }
fn default_ip_range_denylist() -> Vec<String> {
vec![

View File

@ -30,7 +30,7 @@ use crate::Result;
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
#[derive(Clone, Default, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum ProxyConfig {
pub enum ProxyConfig {
#[default]
None,
Global {
@ -40,7 +40,7 @@ pub(crate) enum ProxyConfig {
ByDomain(Vec<PartialProxyConfig>),
}
impl ProxyConfig {
pub(crate) fn to_proxy(&self) -> Result<Option<Proxy>> {
pub fn to_proxy(&self) -> Result<Option<Proxy>> {
Ok(match self.clone() {
ProxyConfig::None => None,
ProxyConfig::Global {
@ -55,7 +55,7 @@ impl ProxyConfig {
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct PartialProxyConfig {
pub struct PartialProxyConfig {
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
url: Url,
#[serde(default)]
@ -64,7 +64,8 @@ pub(crate) struct PartialProxyConfig {
exclude: Vec<WildCardedDomain>,
}
impl PartialProxyConfig {
pub(crate) fn for_url(&self, url: &Url) -> Option<&Url> {
#[must_use]
pub fn for_url(&self, url: &Url) -> Option<&Url> {
let domain = url.domain()?;
let mut included_because = None; // most specific reason it was included
let mut excluded_because = None; // most specific reason it was excluded

View File

@ -1,3 +1,7 @@
#![allow(dead_code)] // this is a developer's toolbox
use std::{panic, panic::PanicInfo};
/// Log event at given level in debug-mode (when debug-assertions are enabled).
/// In release-mode it becomes DEBUG level, and possibly subject to elision.
///
@ -43,3 +47,32 @@ macro_rules! debug_info {
$crate::debug_event!(tracing::Level::INFO, $($x)+ );
}
}
pub fn set_panic_trap() {
let next = panic::take_hook();
panic::set_hook(Box::new(move |info| {
panic_handler(info, &next);
}));
}
#[inline(always)]
fn panic_handler(info: &PanicInfo<'_>, next: &dyn Fn(&PanicInfo<'_>)) {
trap();
next(info);
}
#[inline(always)]
#[allow(unexpected_cfgs)]
pub fn trap() {
#[cfg(core_intrinsics)]
//SAFETY: embeds llvm intrinsic for hardware breakpoint
unsafe {
std::intrinsics::breakpoint();
}
#[cfg(all(not(core_intrinsics), target_arch = "x86_64"))]
//SAFETY: embeds instruction for hardware breakpoint
unsafe {
std::arch::asm!("int3");
}
}

View File

@ -1,11 +1,17 @@
use std::{convert::Infallible, fmt};
use axum::response::{IntoResponse, Response};
use bytes::BytesMut;
use http::StatusCode;
use http_body_util::Full;
use ruma::{
api::client::{
api::{
client::{
error::{Error as RumaError, ErrorBody, ErrorKind},
uiaa::{UiaaInfo, UiaaResponse},
},
OutgoingResponse,
},
OwnedServerName,
};
use thiserror::Error;
@ -15,12 +21,10 @@ use ErrorKind::{
TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, WrongRoomKeysVersion,
};
use crate::RumaResponse;
pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Error)]
pub(crate) enum Error {
pub enum Error {
#[cfg(feature = "sqlite")]
#[error("There was a problem with the connection to the sqlite database: {source}")]
Sqlite {
@ -83,17 +87,70 @@ pub(crate) enum Error {
}
impl Error {
pub(crate) fn bad_database(message: &'static str) -> Self {
pub fn bad_database(message: &'static str) -> Self {
error!("BadDatabase: {}", message);
Self::BadDatabase(message)
}
pub(crate) fn bad_config(message: &str) -> Self {
pub fn bad_config(message: &str) -> Self {
error!("BadConfig: {}", message);
Self::BadConfig(message.to_owned())
}
pub(crate) fn to_response(&self) -> RumaResponse<UiaaResponse> {
/// Returns the Matrix error code / error kind
pub fn error_code(&self) -> ErrorKind {
if let Self::Federation(_, error) = self {
return error.error_kind().unwrap_or_else(|| &Unknown).clone();
}
match self {
Self::BadRequest(kind, _) => kind.clone(),
_ => Unknown,
}
}
/// Sanitizes public-facing errors that can leak sensitive information.
pub fn sanitized_error(&self) -> String {
let db_error = String::from("Database or I/O error occurred.");
match self {
#[cfg(feature = "sqlite")]
Self::Sqlite {
..
} => db_error,
#[cfg(feature = "rocksdb")]
Self::RocksDb {
..
} => db_error,
Self::Io {
..
} => db_error,
_ => self.to_string(),
}
}
}
impl From<Infallible> for Error {
fn from(i: Infallible) -> Self { match i {} }
}
impl fmt::Debug for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self) }
}
#[derive(Clone)]
pub struct RumaResponse<T>(pub T);
impl<T> From<T> for RumaResponse<T> {
fn from(t: T) -> Self { Self(t) }
}
impl From<Error> for RumaResponse<UiaaResponse> {
fn from(t: Error) -> Self { t.to_response() }
}
impl Error {
pub fn to_response(&self) -> RumaResponse<UiaaResponse> {
if let Self::Uiaa(uiaainfo) = self {
return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone()));
}
@ -147,48 +204,17 @@ impl Error {
status_code,
}))
}
}
/// Returns the Matrix error code / error kind
pub(crate) fn error_code(&self) -> ErrorKind {
if let Self::Federation(_, error) = self {
return error.error_kind().unwrap_or_else(|| &Unknown).clone();
}
impl ::axum::response::IntoResponse for Error {
fn into_response(self) -> ::axum::response::Response { self.to_response().into_response() }
}
match self {
Self::BadRequest(kind, _) => kind.clone(),
_ => Unknown,
}
}
/// Sanitizes public-facing errors that can leak sensitive information.
pub(crate) fn sanitized_error(&self) -> String {
let db_error = String::from("Database or I/O error occurred.");
match self {
#[cfg(feature = "sqlite")]
Self::Sqlite {
..
} => db_error,
#[cfg(feature = "rocksdb")]
Self::RocksDb {
..
} => db_error,
Self::Io {
..
} => db_error,
_ => self.to_string(),
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
fn into_response(self) -> Response {
match self.0.try_into_http_response::<BytesMut>() {
Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
}
}
}
impl From<Infallible> for Error {
fn from(i: Infallible) -> Self { match i {} }
}
impl axum::response::IntoResponse for Error {
fn into_response(self) -> axum::response::Response { self.to_response().into_response() }
}
impl fmt::Debug for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self) }
}

79
src/core/log.rs Normal file
View File

@ -0,0 +1,79 @@
use std::sync::Arc;
use tracing_subscriber::{reload, EnvFilter};
/// We need to store a reload::Handle value, but can't name it's type explicitly
/// because the S type parameter depends on the subscriber's previous layers. In
/// our case, this includes unnameable 'impl Trait' types.
///
/// This is fixed[1] in the unreleased tracing-subscriber from the master
/// branch, which removes the S parameter. Unfortunately can't use it without
/// pulling in a version of tracing that's incompatible with the rest of our
/// deps.
///
/// To work around this, we define an trait without the S paramter that forwards
/// to the reload::Handle::reload method, and then store the handle as a trait
/// object.
///
/// [1]: <https://github.com/tokio-rs/tracing/pull/1035/commits/8a87ea52425098d3ef8f56d92358c2f6c144a28f>
pub trait ReloadHandle<L> {
fn reload(&self, new_value: L) -> Result<(), reload::Error>;
}
impl<L, S> ReloadHandle<L> for reload::Handle<L, S> {
fn reload(&self, new_value: L) -> Result<(), reload::Error> { reload::Handle::reload(self, new_value) }
}
struct LogLevelReloadHandlesInner {
handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>,
}
/// Wrapper to allow reloading the filter on several several
/// [`tracing_subscriber::reload::Handle`]s at once, with the same value.
#[derive(Clone)]
pub struct LogLevelReloadHandles {
inner: Arc<LogLevelReloadHandlesInner>,
}
impl LogLevelReloadHandles {
#[must_use]
pub fn new(handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>) -> LogLevelReloadHandles {
LogLevelReloadHandles {
inner: Arc::new(LogLevelReloadHandlesInner {
handles,
}),
}
}
pub fn reload(&self, new_value: &EnvFilter) -> Result<(), reload::Error> {
for handle in &self.inner.handles {
handle.reload(new_value.clone())?;
}
Ok(())
}
}
#[macro_export]
macro_rules! error {
( $($x:tt)+ ) => { tracing::error!( $($x)+ ); }
}
#[macro_export]
macro_rules! warn {
( $($x:tt)+ ) => { tracing::warn!( $($x)+ ); }
}
#[macro_export]
macro_rules! info {
( $($x:tt)+ ) => { tracing::info!( $($x)+ ); }
}
#[macro_export]
macro_rules! debug {
( $($x:tt)+ ) => { tracing::debug!( $($x)+ ); }
}
#[macro_export]
macro_rules! trace {
( $($x:tt)+ ) => { tracing::trace!( $($x)+ ); }
}

27
src/core/mod.rs Normal file
View File

@ -0,0 +1,27 @@
pub mod alloc;
pub mod config;
pub mod debug;
pub mod error;
pub mod log;
pub mod mods;
pub mod pducount;
pub mod server;
pub mod utils;
pub use config::Config;
pub use error::{Error, Result, RumaResponse};
pub use pducount::PduCount;
pub use server::Server;
pub use utils::conduwuit_version;
#[cfg(not(feature = "mods"))]
mod mods {
#[macro_export]
macro_rules! mod_ctor {
() => {};
}
#[macro_export]
macro_rules! mod_dtor {
() => {};
}
}

28
src/core/mods/canary.rs Normal file
View File

@ -0,0 +1,28 @@
use std::sync::atomic::{AtomicI32, Ordering};
const ORDERING: Ordering = Ordering::Relaxed;
static STATIC_DTORS: AtomicI32 = AtomicI32::new(0);
/// Called by Module::unload() to indicate module is about to be unloaded and
/// static destruction is intended. This will allow verifying it actually took
/// place.
pub(crate) fn prepare() {
let count = STATIC_DTORS.fetch_sub(1, ORDERING);
debug_assert!(count <= 0, "STATIC_DTORS should not be greater than zero.");
}
/// Called by static destructor of a module. This call should only be found
/// inside a mod_fini! macro. Do not call from anywhere else.
#[inline(always)]
pub fn report() { let _count = STATIC_DTORS.fetch_add(1, ORDERING); }
/// Called by Module::unload() (see check()) with action in case a check()
/// failed. This can allow a stuck module to be noted while allowing for other
/// independent modules to be diagnosed.
pub(crate) fn check_and_reset() -> bool { STATIC_DTORS.swap(0, ORDERING) == 0 }
/// Called by Module::unload() after unload to verify static destruction took
/// place. A call to prepare() must be made prior to Module::unload() and making
/// this call.
#[allow(dead_code)]
pub(crate) fn check() -> bool { STATIC_DTORS.load(ORDERING) == 0 }

44
src/core/mods/macros.rs Normal file
View File

@ -0,0 +1,44 @@
#[macro_export]
macro_rules! mod_ctor {
( $($body:block)? ) => {
$crate::mod_init! {{
$crate::debug_info!("Module loaded");
$($body)?
}}
}
}
#[macro_export]
macro_rules! mod_dtor {
( $($body:block)? ) => {
$crate::mod_fini! {{
$crate::debug_info!("Module unloading");
$($body)?
$crate::mods::canary::report();
}}
}
}
#[macro_export]
macro_rules! mod_init {
($body:block) => {
#[used]
#[cfg_attr(target_family = "unix", link_section = ".init_array")]
static MOD_INIT: extern "C" fn() = { _mod_init };
#[cfg_attr(target_family = "unix", link_section = ".text.startup")]
extern "C" fn _mod_init() -> () $body
};
}
#[macro_export]
macro_rules! mod_fini {
($body:block) => {
#[used]
#[cfg_attr(target_family = "unix", link_section = ".fini_array")]
static MOD_FINI: extern "C" fn() = { _mod_fini };
#[cfg_attr(target_family = "unix", link_section = ".text.startup")]
extern "C" fn _mod_fini() -> () $body
};
}

11
src/core/mods/mod.rs Normal file
View File

@ -0,0 +1,11 @@
#![cfg(feature = "mods")]
pub(crate) use libloading::os::unix::{Library, Symbol};
pub mod canary;
pub mod macros;
pub mod module;
pub mod new;
pub mod path;
pub use module::Module;

74
src/core/mods/module.rs Normal file
View File

@ -0,0 +1,74 @@
use std::{
ffi::{CString, OsString},
time::SystemTime,
};
use super::{canary, new, path, Library, Symbol};
use crate::{error, Result};
pub struct Module {
handle: Option<Library>,
loaded: SystemTime,
path: OsString,
}
impl Module {
pub fn from_name(name: &str) -> Result<Self> { Self::from_path(path::from_name(name)?) }
pub fn from_path(path: OsString) -> Result<Self> {
Ok(Self {
handle: Some(new::from_path(&path)?),
loaded: SystemTime::now(),
path,
})
}
pub fn unload(&mut self) {
canary::prepare();
self.close();
if !canary::check_and_reset() {
let name = self.name().expect("Module is named");
error!("Module {name:?} is stuck and failed to unload.");
}
}
pub(crate) fn close(&mut self) {
if let Some(handle) = self.handle.take() {
handle.close().expect("Module handle closed");
}
}
pub fn get<Prototype>(&self, name: &str) -> Result<Symbol<Prototype>> {
let cname = CString::new(name.to_owned()).expect("terminated string from provided name");
let handle = self
.handle
.as_ref()
.expect("backing library loaded by this instance");
// SAFETY: Calls dlsym(3) on unix platforms. This might not have to be unsafe
// if wrapped in libloading with_dlerror().
let sym = unsafe { handle.get::<Prototype>(cname.as_bytes()) };
let sym = sym.expect("symbol found; binding successful");
Ok(sym)
}
pub fn deleted(&self) -> Result<bool> {
let mtime = path::mtime(self.path())?;
let res = mtime > self.loaded;
Ok(res)
}
pub fn name(&self) -> Result<String> { path::to_name(self.path()) }
#[must_use]
pub fn path(&self) -> &OsString { &self.path }
}
impl Drop for Module {
fn drop(&mut self) {
if self.handle.is_some() {
self.unload();
}
}
}

23
src/core/mods/new.rs Normal file
View File

@ -0,0 +1,23 @@
use std::ffi::OsStr;
use super::{path, Library};
use crate::{Error, Result};
const OPEN_FLAGS: i32 = libloading::os::unix::RTLD_LAZY | libloading::os::unix::RTLD_GLOBAL;
pub fn from_name(name: &str) -> Result<Library> {
let path = path::from_name(name)?;
from_path(&path)
}
pub fn from_path(path: &OsStr) -> Result<Library> {
//SAFETY: Calls dlopen(3) on unix platforms. This might not have to be unsafe
// if wrapped in with_dlerror.
let lib = unsafe { Library::open(Some(path), OPEN_FLAGS) };
if let Err(e) = lib {
let name = path::to_name(path)?;
return Err(Error::Err(format!("Loading module {name:?} failed: {e}")));
}
Ok(lib.expect("module loaded"))
}

40
src/core/mods/path.rs Normal file
View File

@ -0,0 +1,40 @@
use std::{
env::current_exe,
ffi::{OsStr, OsString},
path::{Path, PathBuf},
time::SystemTime,
};
use libloading::library_filename;
use crate::Result;
pub fn from_name(name: &str) -> Result<OsString> {
let root = PathBuf::new();
let exe_path = current_exe()?;
let exe_dir = exe_path.parent().unwrap_or(&root);
let mut mod_path = exe_dir.to_path_buf();
let mod_file = library_filename(name);
mod_path.push(mod_file);
Ok(mod_path.into_os_string())
}
pub fn to_name(path: &OsStr) -> Result<String> {
let path = Path::new(path);
let name = path
.file_stem()
.expect("path file stem")
.to_str()
.expect("name string");
let name = name.strip_prefix("lib").unwrap_or(name).to_owned();
Ok(name)
}
pub fn mtime(path: &OsStr) -> Result<SystemTime> {
let meta = std::fs::metadata(path)?;
let mtime = meta.modified()?;
Ok(mtime)
}

51
src/core/pducount.rs Normal file
View File

@ -0,0 +1,51 @@
use std::cmp::Ordering;
use ruma::api::client::error::ErrorKind;
use crate::{Error, Result};
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
pub enum PduCount {
Backfilled(u64),
Normal(u64),
}
impl PduCount {
#[must_use]
pub fn min() -> Self { Self::Backfilled(u64::MAX) }
#[must_use]
pub fn max() -> Self { Self::Normal(u64::MAX) }
pub fn try_from_string(token: &str) -> Result<Self> {
if let Some(stripped_token) = token.strip_prefix('-') {
stripped_token.parse().map(PduCount::Backfilled)
} else {
token.parse().map(PduCount::Normal)
}
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token."))
}
#[must_use]
pub fn stringify(&self) -> String {
match self {
PduCount::Backfilled(x) => format!("-{x}"),
PduCount::Normal(x) => x.to_string(),
}
}
}
impl PartialOrd for PduCount {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}
impl Ord for PduCount {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o),
(PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s),
(PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater,
(PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less,
}
}
}

72
src/core/server.rs Normal file
View File

@ -0,0 +1,72 @@
use std::{
sync::{
atomic::{AtomicBool, AtomicU32},
Mutex,
},
time::SystemTime,
};
use tokio::runtime;
use crate::{config::Config, log::LogLevelReloadHandles};
/// Server runtime state; public portion
pub struct Server {
/// Server-wide configuration instance
pub config: Config,
/// Timestamp server was started; used for uptime.
pub started: SystemTime,
/// Reload/shutdown signal channel. Called from the signal handler or admin
/// command to initiate shutdown.
pub shutdown: Mutex<Option<axum_server::Handle>>,
/// Reload/shutdown desired indicator; when false, shutdown is desired. This
/// is an observable used on shutdown and modifying is not recommended.
pub reload: AtomicBool,
/// Reload/shutdown pending indicator; server is shutting down. This is an
/// observable used on shutdown and should not be modified.
pub interrupt: AtomicBool,
/// Handle to the runtime
pub runtime: Option<runtime::Handle>,
/// Log level reload handles.
pub tracing_reload_handle: LogLevelReloadHandles,
/// TODO: move stats
pub requests_spawn_active: AtomicU32,
pub requests_spawn_finished: AtomicU32,
pub requests_handle_active: AtomicU32,
pub requests_handle_finished: AtomicU32,
pub requests_panic: AtomicU32,
}
impl Server {
#[must_use]
pub fn new(config: Config, runtime: Option<runtime::Handle>, tracing_reload_handle: LogLevelReloadHandles) -> Self {
Self {
config,
started: SystemTime::now(),
shutdown: Mutex::new(None),
reload: AtomicBool::new(false),
interrupt: AtomicBool::new(false),
runtime,
tracing_reload_handle,
requests_spawn_active: AtomicU32::new(0),
requests_spawn_finished: AtomicU32::new(0),
requests_handle_active: AtomicU32::new(0),
requests_handle_finished: AtomicU32::new(0),
requests_panic: AtomicU32::new(0),
}
}
#[inline]
pub fn runtime(&self) -> &runtime::Handle {
self.runtime
.as_ref()
.expect("runtime handle available in Server")
}
}

View File

@ -2,19 +2,19 @@
use std::path::PathBuf;
use clap::Parser;
pub use clap::Parser;
use super::conduwuit_version;
/// Commandline arguments
#[derive(Parser, Debug)]
#[clap(version = conduwuit_version(), about, long_about = None)]
pub(crate) struct Args {
pub struct Args {
#[arg(short, long)]
/// Optional argument to the path of a conduwuit config TOML file
pub(crate) config: Option<PathBuf>,
pub config: Option<PathBuf>,
}
/// Parse commandline arguments into structured data
#[must_use]
pub(crate) fn parse() -> Args { Args::parse() }
pub fn parse() -> Args { Args::parse() }

View File

@ -17,8 +17,9 @@ const IMAGE_SVG_XML: &str = "image/svg+xml";
///
/// TODO: add a "strict" function for comparing the Content-Type with what we
/// detected: `file_type.mime_type() != content_type`
#[must_use]
#[tracing::instrument(skip(buf))]
pub(crate) fn content_disposition_type(buf: &[u8], content_type: &Option<String>) -> &'static str {
pub fn content_disposition_type(buf: &[u8], content_type: &Option<String>) -> &'static str {
let Some(file_type) = infer::get(buf) else {
return ATTACHMENT;
};
@ -41,8 +42,9 @@ pub(crate) fn content_disposition_type(buf: &[u8], content_type: &Option<String>
///
/// SVG is special-cased due to the MIME type being classified as `text/xml` but
/// browsers need `image/svg+xml`
#[must_use]
#[tracing::instrument(skip(buf))]
pub(crate) fn make_content_type(buf: &[u8], content_type: &Option<String>) -> &'static str {
pub fn make_content_type(buf: &[u8], content_type: &Option<String>) -> &'static str {
let Some(file_type) = infer::get(buf) else {
debug_info!("Failed to infer the file's contents");
return APPLICATION_OCTET_STREAM;
@ -62,7 +64,7 @@ pub(crate) fn make_content_type(buf: &[u8], content_type: &Option<String>) -> &'
/// sanitises the file name for the Content-Disposition using
/// `sanitize_filename` crate
#[tracing::instrument]
pub(crate) fn sanitise_filename(filename: String) -> String {
pub fn sanitise_filename(filename: String) -> String {
let options = sanitize_filename::Options {
truncate: false,
..Default::default()
@ -79,7 +81,7 @@ pub(crate) fn sanitise_filename(filename: String) -> String {
///
/// else: `Content-Disposition: attachment/inline`
#[tracing::instrument(skip(file))]
pub(crate) fn make_content_disposition(
pub fn make_content_disposition(
file: &[u8], content_type: &Option<String>, content_disposition: Option<String>,
) -> String {
let filename = content_disposition.map_or_else(String::new, |content_disposition| {

22
src/core/utils/defer.rs Normal file
View File

@ -0,0 +1,22 @@
#[macro_export]
macro_rules! defer {
($body:block) => {
struct _Defer_<F>
where
F: FnMut(),
{
closure: F,
}
impl<F> Drop for _Defer_<F>
where
F: FnMut(),
{
fn drop(&mut self) { (self.closure)(); }
}
let _defer_ = _Defer_ {
closure: || $body,
};
};
}

View File

@ -1,10 +1,3 @@
pub(crate) mod clap;
pub(crate) mod content_disposition;
pub(crate) mod debug;
pub(crate) mod error;
pub(crate) mod server_name;
pub(crate) mod user_id;
use std::{
cmp,
cmp::Ordering,
@ -13,24 +6,29 @@ use std::{
time::{SystemTime, UNIX_EPOCH},
};
use argon2::{password_hash::SaltString, PasswordHasher};
use rand::prelude::*;
use ring::digest;
use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId};
use tracing::debug;
use crate::{services, Error, Result};
use crate::{Error, Result};
pub(crate) fn clamp<T: Ord>(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) }
pub mod clap;
pub mod content_disposition;
pub mod defer;
pub fn clamp<T: Ord>(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) }
#[must_use]
#[allow(clippy::as_conversions)]
pub(crate) fn millis_since_unix_epoch() -> u64 {
pub fn millis_since_unix_epoch() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time is valid")
.as_millis() as u64
}
pub(crate) fn increment(old: Option<&[u8]>) -> Vec<u8> {
pub fn increment(old: Option<&[u8]>) -> Vec<u8> {
let number = match old.map(TryInto::try_into) {
Some(Ok(bytes)) => {
let number = u64::from_be_bytes(bytes);
@ -42,7 +40,8 @@ pub(crate) fn increment(old: Option<&[u8]>) -> Vec<u8> {
number.to_be_bytes().to_vec()
}
pub(crate) fn generate_keypair() -> Vec<u8> {
#[must_use]
pub fn generate_keypair() -> Vec<u8> {
let mut value = random_string(8).as_bytes().to_vec();
value.push(0xFF);
value.extend_from_slice(
@ -52,25 +51,25 @@ pub(crate) fn generate_keypair() -> Vec<u8> {
}
/// Parses the bytes into an u64.
pub(crate) fn u64_from_bytes(bytes: &[u8]) -> Result<u64, std::array::TryFromSliceError> {
pub fn u64_from_bytes(bytes: &[u8]) -> Result<u64, std::array::TryFromSliceError> {
let array: [u8; 8] = bytes.try_into()?;
Ok(u64::from_be_bytes(array))
}
/// Parses the bytes into a string.
pub(crate) fn string_from_bytes(bytes: &[u8]) -> Result<String, std::string::FromUtf8Error> {
pub fn string_from_bytes(bytes: &[u8]) -> Result<String, std::string::FromUtf8Error> {
String::from_utf8(bytes.to_vec())
}
/// Parses a `OwnedUserId` from bytes.
pub(crate) fn user_id_from_bytes(bytes: &[u8]) -> Result<OwnedUserId> {
pub fn user_id_from_bytes(bytes: &[u8]) -> Result<OwnedUserId> {
OwnedUserId::try_from(
string_from_bytes(bytes).map_err(|_| Error::bad_database("Failed to parse string from bytes"))?,
)
.map_err(|_| Error::bad_database("Failed to parse user id from bytes"))
}
pub(crate) fn random_string(length: usize) -> String {
pub fn random_string(length: usize) -> String {
thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(length)
@ -78,25 +77,16 @@ pub(crate) fn random_string(length: usize) -> String {
.collect()
}
/// Calculate a new hash for the given password
pub(crate) fn calculate_password_hash(password: &str) -> Result<String, argon2::password_hash::Error> {
let salt = SaltString::generate(thread_rng());
services()
.globals
.argon
.hash_password(password.as_bytes(), &salt)
.map(|it| it.to_string())
}
#[tracing::instrument(skip(keys))]
pub(crate) fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
pub fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
// We only hash the pdu's event ids, not the whole pdu
let bytes = keys.join(&0xFF);
let hash = digest::digest(&digest::SHA256, &bytes);
hash.as_ref().to_owned()
}
pub(crate) fn common_elements(
#[allow(clippy::impl_trait_in_params)]
pub fn common_elements(
mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering,
) -> Option<impl Iterator<Item = Vec<u8>>> {
let first_iterator = iterators.next()?;
@ -123,7 +113,7 @@ pub(crate) fn common_elements(
/// `CanonicalJsonObject`.
///
/// `value` must serialize to an `serde_json::Value::Object`.
pub(crate) fn to_canonical_object<T: serde::Serialize>(value: T) -> Result<CanonicalJsonObject, CanonicalJsonError> {
pub fn to_canonical_object<T: serde::Serialize>(value: T) -> Result<CanonicalJsonObject, CanonicalJsonError> {
use serde::ser::Error;
match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? {
@ -132,7 +122,7 @@ pub(crate) fn to_canonical_object<T: serde::Serialize>(value: T) -> Result<Canon
}
}
pub(crate) fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: FromStr<Err = E>, E: fmt::Display>(
pub fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: FromStr<Err = E>, E: fmt::Display>(
deserializer: D,
) -> Result<T, D::Error> {
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
@ -158,7 +148,7 @@ pub(crate) fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: From
/// Wrapper struct which will emit the HTML-escaped version of the contained
/// string when passed to a format string.
pub(crate) struct HtmlEscape<'a>(pub(crate) &'a str);
pub struct HtmlEscape<'a>(pub &'a str);
impl fmt::Display for HtmlEscape<'_> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
@ -196,7 +186,8 @@ impl fmt::Display for HtmlEscape<'_> {
/// Set the environment variable `CONDUWUIT_VERSION_EXTRA` to any UTF-8 string
/// to include it in parenthesis after the SemVer version. A common value are
/// git commit hashes.
pub(crate) fn conduwuit_version() -> String {
#[must_use]
pub fn conduwuit_version() -> String {
match option_env!("CONDUWUIT_VERSION_EXTRA") {
Some(extra) => {
if extra.is_empty() {
@ -222,7 +213,7 @@ pub(crate) fn conduwuit_version() -> String {
/// Any further elements are replaced by an ellipsis.
///
/// See also [`debug_slice_truncated()`],
pub(crate) struct TruncatedDebugSlice<'a, T> {
pub struct TruncatedDebugSlice<'a, T> {
inner: &'a [T],
max_len: usize,
}
@ -243,11 +234,12 @@ impl<T: fmt::Debug> fmt::Debug for TruncatedDebugSlice<'_, T> {
/// See [`TruncatedDebugSlice`]. Useful for `#[instrument]`:
///
/// ```
/// #[tracing::instrument(fields(
/// foos = debug_slice_truncated(foos, N)
/// ))]
/// use conduit_core::utils::debug_slice_truncated;
///
/// #[tracing::instrument(fields(foos = debug_slice_truncated(foos, 42)))]
/// fn bar(foos: &[&str]);
/// ```
pub(crate) fn debug_slice_truncated<T: fmt::Debug>(
pub fn debug_slice_truncated<T: fmt::Debug>(
slice: &[T], max_len: usize,
) -> tracing::field::DebugValue<TruncatedDebugSlice<'_, T>> {
tracing::field::debug(TruncatedDebugSlice {
@ -255,3 +247,24 @@ pub(crate) fn debug_slice_truncated<T: fmt::Debug>(
max_len,
})
}
/// This is needed for opening lots of file descriptors, which tends to
/// happen more often when using RocksDB and making lots of federation
/// connections at startup. The soft limit is usually 1024, and the hard
/// limit is usually 512000; I've personally seen it hit >2000.
///
/// * <https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6>
/// * <https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741>
#[cfg(unix)]
pub fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE};
let (soft_limit, hard_limit) = getrlimit(NOFILE)?;
if soft_limit < hard_limit {
setrlimit(NOFILE, hard_limit, hard_limit)?;
assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit");
debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",);
}
Ok(())
}

81
src/database/Cargo.toml Normal file
View File

@ -0,0 +1,81 @@
[package]
name = "conduit_database"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"rocksdb",
"io_uring",
"jemalloc",
"zstd_compression",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sqlite = [
"dep:rusqlite",
"dep:parking_lot",
"dep:thread_local",
]
rocksdb = [
"dep:rust-rocksdb",
]
jemalloc = [
"dep:tikv-jemalloc-sys",
"dep:tikv-jemalloc-ctl",
"dep:tikv-jemallocator",
"rust-rocksdb/jemalloc",
]
jemalloc_prof = [
"tikv-jemalloc-sys/profiling",
]
io_uring = [
"rust-rocksdb/io-uring",
]
zstd_compression = [
"rust-rocksdb/zstd",
]
[dependencies]
chrono.workspace = true
conduit-core.workspace = true
futures-util.workspace = true
log.workspace = true
lru-cache.workspace = true
num_cpus.workspace = true
parking_lot.optional = true
parking_lot.workspace = true
ruma.workspace = true
rusqlite.optional = true
rusqlite.workspace = true
rust-rocksdb.optional = true
rust-rocksdb.workspace = true
thread_local.optional = true
thread_local.workspace = true
tikv-jemallocator.optional = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl.optional = true
tikv-jemalloc-ctl.workspace = true
tikv-jemalloc-sys.optional = true
tikv-jemalloc-sys.workspace = true
tokio.workspace = true
tracing.workspace = true
zstd.optional = true
zstd.workspace = true
[lints]
workspace = true

View File

@ -2,14 +2,14 @@ use std::sync::Arc;
use super::KeyValueDatabaseEngine;
pub(crate) struct Cork {
pub struct Cork {
db: Arc<dyn KeyValueDatabaseEngine>,
flush: bool,
sync: bool,
}
impl Cork {
pub(crate) fn new(db: &Arc<dyn KeyValueDatabaseEngine>, flush: bool, sync: bool) -> Self {
pub fn new(db: &Arc<dyn KeyValueDatabaseEngine>, flush: bool, sync: bool) -> Self {
db.cork().unwrap();
Cork {
db: db.clone(),

320
src/database/kvdatabase.rs Normal file
View File

@ -0,0 +1,320 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
path::Path,
sync::{Arc, Mutex, RwLock},
};
use conduit::{Config, Error, PduCount, Result, Server};
use lru_cache::LruCache;
use ruma::{CanonicalJsonValue, OwnedDeviceId, OwnedRoomId, OwnedUserId};
use tracing::debug;
use crate::{KeyValueDatabaseEngine, KvTree};
pub struct KeyValueDatabase {
pub db: Arc<dyn KeyValueDatabaseEngine>,
//pub globals: globals::Globals,
pub global: Arc<dyn KvTree>,
pub server_signingkeys: Arc<dyn KvTree>,
pub roomid_inviteviaservers: Arc<dyn KvTree>,
//pub users: users::Users,
pub userid_password: Arc<dyn KvTree>,
pub userid_displayname: Arc<dyn KvTree>,
pub userid_avatarurl: Arc<dyn KvTree>,
pub userid_blurhash: Arc<dyn KvTree>,
pub userdeviceid_token: Arc<dyn KvTree>,
pub userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub token_userdeviceid: Arc<dyn KvTree>,
pub onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
pub keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
pub keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type)
pub userid_masterkeyid: Arc<dyn KvTree>,
pub userid_selfsigningkeyid: Arc<dyn KvTree>,
pub userid_usersigningkeyid: Arc<dyn KvTree>,
pub userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
pub todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
pub userid_presenceid: Arc<dyn KvTree>, // UserId => Count
pub presenceid_presence: Arc<dyn KvTree>, // Count + UserId => Presence
//pub uiaa: uiaa::Uiaa,
pub userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub userdevicesessionid_uiaarequest: RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>,
//pub edus: RoomEdus,
pub readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId
pub roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count
pub roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count
//pub rooms: rooms::Rooms,
pub pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count
pub eventid_pduid: Arc<dyn KvTree>,
pub roomid_pduleaves: Arc<dyn KvTree>,
pub alias_roomid: Arc<dyn KvTree>,
pub aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
pub publicroomids: Arc<dyn KvTree>,
pub threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count
pub tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
pub roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
pub serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId
pub userroomid_joined: Arc<dyn KvTree>,
pub roomuserid_joined: Arc<dyn KvTree>,
pub roomid_joinedcount: Arc<dyn KvTree>,
pub roomid_invitedcount: Arc<dyn KvTree>,
pub roomuseroncejoinedids: Arc<dyn KvTree>,
pub userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>>
pub roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
pub userroomid_leftstate: Arc<dyn KvTree>,
pub roomuserid_leftcount: Arc<dyn KvTree>,
pub disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled
pub bannedroomids: Arc<dyn KvTree>, // Rooms where local users are not allowed to join
pub lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
pub userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64
pub userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64
pub roomuserid_lastnotificationread: Arc<dyn KvTree>, // LastNotificationRead = u64
/// Remember the current state hash of a room.
pub roomid_shortstatehash: Arc<dyn KvTree>,
pub roomsynctoken_shortstatehash: Arc<dyn KvTree>,
/// Remember the state hash at events in the past.
pub shorteventid_shortstatehash: Arc<dyn KvTree>,
pub statekey_shortstatekey: Arc<dyn KvTree>, /* StateKey = EventType + StateKey, ShortStateKey =
* Count */
pub shortstatekey_statekey: Arc<dyn KvTree>,
pub roomid_shortroomid: Arc<dyn KvTree>,
pub shorteventid_eventid: Arc<dyn KvTree>,
pub eventid_shorteventid: Arc<dyn KvTree>,
pub statehash_shortstatehash: Arc<dyn KvTree>,
pub shortstatehash_statediff: Arc<dyn KvTree>, /* StateDiff = parent (or 0) +
* (shortstatekey+shorteventid++) + 0_u64 +
* (shortstatekey+shorteventid--) */
pub shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event
/// /federation/send/txn.
pub eventid_outlierpdu: Arc<dyn KvTree>,
pub softfailedeventids: Arc<dyn KvTree>,
/// ShortEventId + ShortEventId -> ().
pub tofrom_relation: Arc<dyn KvTree>,
/// RoomId + EventId -> Parent PDU EventId.
pub referencedevents: Arc<dyn KvTree>,
//pub account_data: account_data::AccountData,
pub roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type
pub roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type
//pub media: media::Media,
pub mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
pub url_previews: Arc<dyn KvTree>,
pub mediaid_user: Arc<dyn KvTree>,
//pub key_backups: key_backups::KeyBackups,
pub backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
//pub transaction_ids: transaction_ids::TransactionIds,
pub userdevicetxnid_response: Arc<dyn KvTree>, /* Response can be empty (/sendToDevice) or the event id
* (/send) */
//pub sending: sending::Sending,
pub servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
pub servernameevent_data: Arc<dyn KvTree>, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId +
* PduId / Id (for edus), Data = EDU content */
pub servercurrentevent_data: Arc<dyn KvTree>, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId
* / Id (for edus), Data = EDU content */
//pub appservice: appservice::Appservice,
pub id_appserviceregistrations: Arc<dyn KvTree>,
//pub pusher: pusher::PushData,
pub senderkey_pusher: Arc<dyn KvTree>,
pub auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[u64]>>>,
pub our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
}
impl KeyValueDatabase {
/// Load an existing database or create a new one.
#[allow(clippy::too_many_lines)]
pub async fn load_or_create(server: &Arc<Server>) -> Result<KeyValueDatabase> {
let config = &server.config;
check_db_setup(config)?;
let builder = build(config)?;
Ok(Self {
db: builder.clone(),
userid_password: builder.open_tree("userid_password")?,
userid_displayname: builder.open_tree("userid_displayname")?,
userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userid_blurhash: builder.open_tree("userid_blurhash")?,
userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,
userid_presenceid: builder.open_tree("userid_presenceid")?,
presenceid_presence: builder.open_tree("presenceid_presence")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?,
pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
alias_roomid: builder.open_tree("alias_roomid")?,
aliasid_alias: builder.open_tree("aliasid_alias")?,
publicroomids: builder.open_tree("publicroomids")?,
threadid_userids: builder.open_tree("threadid_userids")?,
tokenids: builder.open_tree("tokenids")?,
roomserverids: builder.open_tree("roomserverids")?,
serverroomids: builder.open_tree("serverroomids")?,
userroomid_joined: builder.open_tree("userroomid_joined")?,
roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
disabledroomids: builder.open_tree("disabledroomids")?,
bannedroomids: builder.open_tree("bannedroomids")?,
lazyloadedids: builder.open_tree("lazyloadedids")?,
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?,
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?,
tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
mediaid_file: builder.open_tree("mediaid_file")?,
url_previews: builder.open_tree("url_previews")?,
mediaid_user: builder.open_tree("mediaid_user")?,
backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?,
roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?,
auth_chain_cache: Mutex::new(LruCache::new(
(f64::from(config.auth_chain_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
our_real_users_cache: RwLock::new(HashMap::new()),
appservice_in_room_cache: RwLock::new(HashMap::new()),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
})
}
}
fn build(config: &Config) -> Result<Arc<dyn KeyValueDatabaseEngine>> {
match &*config.database_backend {
"sqlite" => {
debug!("Got sqlite database backend");
#[cfg(not(feature = "sqlite"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "sqlite")]
Ok(Arc::new(Arc::<crate::sqlite::Engine>::open(config)?))
},
"rocksdb" => {
debug!("Got rocksdb database backend");
#[cfg(not(feature = "rocksdb"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "rocksdb")]
Ok(Arc::new(Arc::<crate::rocksdb::Engine>::open(config)?))
},
_ => Err(Error::bad_config(
"Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.",
)),
}
}
fn check_db_setup(config: &Config) -> Result<()> {
let path = Path::new(&config.database_path);
let sqlite_exists = path.join("conduit.db").exists();
let rocksdb_exists = path.join("IDENTITY").exists();
if sqlite_exists && rocksdb_exists {
return Err(Error::bad_config("Multiple databases at database_path detected."));
}
if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in config.",
));
}
if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in config.",
));
}
Ok(())
}

View File

@ -3,7 +3,7 @@ use std::{error::Error, sync::Arc};
use super::{Config, KvTree};
use crate::Result;
pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
pub trait KeyValueDatabaseEngine: Send + Sync {
fn open(config: &Config) -> Result<Self>
where
Self: Sized;

View File

@ -2,7 +2,7 @@ use std::{future::Future, pin::Pin};
use crate::Result;
pub(crate) trait KvTree: Send + Sync {
pub trait KvTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
#[allow(dead_code)]

View File

@ -1,573 +1,23 @@
mod cork;
mod key_value;
pub mod cork;
mod kvdatabase;
mod kvengine;
mod kvtree;
mod migrations;
#[cfg(feature = "rocksdb")]
mod rocksdb;
pub(crate) mod rocksdb;
#[cfg(feature = "sqlite")]
mod sqlite;
pub(crate) mod sqlite;
#[cfg(any(feature = "sqlite", feature = "rocksdb"))]
pub(crate) mod watchers;
use std::{
collections::{BTreeMap, HashMap, HashSet},
fs::{self},
path::Path,
sync::{Arc, Mutex, RwLock},
time::Duration,
};
pub(crate) use cork::Cork;
pub(crate) use kvengine::KeyValueDatabaseEngine;
pub(crate) use kvtree::KvTree;
use lru_cache::LruCache;
use ruma::{
events::{
push_rules::PushRulesEventContent, room::message::RoomMessageEventContent, GlobalAccountDataEvent,
GlobalAccountDataEventType,
},
push::Ruleset,
CanonicalJsonValue, OwnedDeviceId, OwnedRoomId, OwnedUserId, UserId,
};
use serde::Deserialize;
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
use tokio::time::{interval, Instant};
use tracing::{debug, error, warn};
use crate::{
database::migrations::migrations, service::rooms::timeline::PduCount, services, Config, Error,
LogLevelReloadHandles, Result, Services, SERVICES,
};
pub(crate) struct KeyValueDatabase {
db: Arc<dyn KeyValueDatabaseEngine>,
//pub(crate) globals: globals::Globals,
pub(crate) global: Arc<dyn KvTree>,
pub(crate) server_signingkeys: Arc<dyn KvTree>,
pub(crate) roomid_inviteviaservers: Arc<dyn KvTree>,
//pub(crate) users: users::Users,
pub(crate) userid_password: Arc<dyn KvTree>,
pub(crate) userid_displayname: Arc<dyn KvTree>,
pub(crate) userid_avatarurl: Arc<dyn KvTree>,
pub(crate) userid_blurhash: Arc<dyn KvTree>,
pub(crate) userdeviceid_token: Arc<dyn KvTree>,
pub(crate) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub(crate) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub(crate) token_userdeviceid: Arc<dyn KvTree>,
pub(crate) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(crate) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
pub(crate) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
pub(crate) keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type)
pub(crate) userid_masterkeyid: Arc<dyn KvTree>,
pub(crate) userid_selfsigningkeyid: Arc<dyn KvTree>,
pub(crate) userid_usersigningkeyid: Arc<dyn KvTree>,
pub(crate) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
pub(crate) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
pub(crate) userid_presenceid: Arc<dyn KvTree>, // UserId => Count
pub(crate) presenceid_presence: Arc<dyn KvTree>, // Count + UserId => Presence
//pub(crate) uiaa: uiaa::Uiaa,
pub(crate) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub(crate) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>,
//pub(crate) edus: RoomEdus,
pub(crate) readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId
pub(crate) roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count
pub(crate) roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count
//pub(crate) rooms: rooms::Rooms,
pub(crate) pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count
pub(crate) eventid_pduid: Arc<dyn KvTree>,
pub(crate) roomid_pduleaves: Arc<dyn KvTree>,
pub(crate) alias_roomid: Arc<dyn KvTree>,
pub(crate) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
pub(crate) publicroomids: Arc<dyn KvTree>,
pub(crate) threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count
pub(crate) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
pub(crate) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
pub(crate) serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId
pub(crate) userroomid_joined: Arc<dyn KvTree>,
pub(crate) roomuserid_joined: Arc<dyn KvTree>,
pub(crate) roomid_joinedcount: Arc<dyn KvTree>,
pub(crate) roomid_invitedcount: Arc<dyn KvTree>,
pub(crate) roomuseroncejoinedids: Arc<dyn KvTree>,
pub(crate) userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>>
pub(crate) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
pub(crate) userroomid_leftstate: Arc<dyn KvTree>,
pub(crate) roomuserid_leftcount: Arc<dyn KvTree>,
pub(crate) disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled
pub(crate) bannedroomids: Arc<dyn KvTree>, // Rooms where local users are not allowed to join
pub(crate) lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
pub(crate) userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64
pub(crate) userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64
pub(crate) roomuserid_lastnotificationread: Arc<dyn KvTree>, // LastNotificationRead = u64
/// Remember the current state hash of a room.
pub(crate) roomid_shortstatehash: Arc<dyn KvTree>,
pub(crate) roomsynctoken_shortstatehash: Arc<dyn KvTree>,
/// Remember the state hash at events in the past.
pub(crate) shorteventid_shortstatehash: Arc<dyn KvTree>,
pub(crate) statekey_shortstatekey: Arc<dyn KvTree>, /* StateKey = EventType + StateKey, ShortStateKey =
* Count */
pub(crate) shortstatekey_statekey: Arc<dyn KvTree>,
pub(crate) roomid_shortroomid: Arc<dyn KvTree>,
pub(crate) shorteventid_eventid: Arc<dyn KvTree>,
pub(crate) eventid_shorteventid: Arc<dyn KvTree>,
pub(crate) statehash_shortstatehash: Arc<dyn KvTree>,
pub(crate) shortstatehash_statediff: Arc<dyn KvTree>, /* StateDiff = parent (or 0) +
* (shortstatekey+shorteventid++) + 0_u64 +
* (shortstatekey+shorteventid--) */
pub(crate) shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event
/// /federation/send/txn.
pub(crate) eventid_outlierpdu: Arc<dyn KvTree>,
pub(crate) softfailedeventids: Arc<dyn KvTree>,
/// ShortEventId + ShortEventId -> ().
pub(crate) tofrom_relation: Arc<dyn KvTree>,
/// RoomId + EventId -> Parent PDU EventId.
pub(crate) referencedevents: Arc<dyn KvTree>,
//pub(crate) account_data: account_data::AccountData,
pub(crate) roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type
pub(crate) roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type
//pub(crate) media: media::Media,
pub(crate) mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
pub(crate) url_previews: Arc<dyn KvTree>,
pub(crate) mediaid_user: Arc<dyn KvTree>,
//pub(crate) key_backups: key_backups::KeyBackups,
pub(crate) backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(crate) backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(crate) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
//pub(crate) transaction_ids: transaction_ids::TransactionIds,
pub(crate) userdevicetxnid_response: Arc<dyn KvTree>, /* Response can be empty (/sendToDevice) or the event id
* (/send) */
//pub(crate) sending: sending::Sending,
pub(crate) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
pub(crate) servernameevent_data: Arc<dyn KvTree>, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId +
* PduId / Id (for edus), Data = EDU content */
pub(crate) servercurrentevent_data: Arc<dyn KvTree>, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId
* / Id (for edus), Data = EDU content */
//pub(crate) appservice: appservice::Appservice,
pub(crate) id_appserviceregistrations: Arc<dyn KvTree>,
//pub(crate) pusher: pusher::PushData,
pub(crate) senderkey_pusher: Arc<dyn KvTree>,
pub(crate) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[u64]>>>,
pub(crate) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(crate) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
pub(crate) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponseEntry {
id: u64,
date: String,
message: String,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponse {
updates: Vec<CheckForUpdatesResponseEntry>,
}
impl KeyValueDatabase {
/// Load an existing database or create a new one.
#[allow(clippy::too_many_lines)]
pub(crate) async fn load_or_create(config: Config, tracing_reload_handler: LogLevelReloadHandles) -> Result<()> {
Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() {
debug!("Database path does not exist, assuming this is a new setup and creating it");
fs::create_dir_all(&config.database_path).map_err(|e| {
error!("Failed to create database path: {e}");
Error::bad_config(
"Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please \
create the database folder yourself or allow conduwuit the permissions to create directories and \
files.",
)
})?;
}
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
"sqlite" => {
debug!("Got sqlite database backend");
#[cfg(not(feature = "sqlite"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "sqlite")]
Arc::new(Arc::<sqlite::Engine>::open(&config)?)
},
"rocksdb" => {
debug!("Got rocksdb database backend");
#[cfg(not(feature = "rocksdb"))]
return Err(Error::bad_config("Database backend not found."));
#[cfg(feature = "rocksdb")]
Arc::new(Arc::<rocksdb::Engine>::open(&config)?)
},
_ => {
return Err(Error::bad_config(
"Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.",
));
},
};
let db_raw = Box::new(Self {
db: builder.clone(),
userid_password: builder.open_tree("userid_password")?,
userid_displayname: builder.open_tree("userid_displayname")?,
userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userid_blurhash: builder.open_tree("userid_blurhash")?,
userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,
userid_presenceid: builder.open_tree("userid_presenceid")?,
presenceid_presence: builder.open_tree("presenceid_presence")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?,
pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
alias_roomid: builder.open_tree("alias_roomid")?,
aliasid_alias: builder.open_tree("aliasid_alias")?,
publicroomids: builder.open_tree("publicroomids")?,
threadid_userids: builder.open_tree("threadid_userids")?,
tokenids: builder.open_tree("tokenids")?,
roomserverids: builder.open_tree("roomserverids")?,
serverroomids: builder.open_tree("serverroomids")?,
userroomid_joined: builder.open_tree("userroomid_joined")?,
roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
disabledroomids: builder.open_tree("disabledroomids")?,
bannedroomids: builder.open_tree("bannedroomids")?,
lazyloadedids: builder.open_tree("lazyloadedids")?,
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?,
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?,
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?,
tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
mediaid_file: builder.open_tree("mediaid_file")?,
url_previews: builder.open_tree("url_previews")?,
mediaid_user: builder.open_tree("mediaid_user")?,
backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?,
roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?,
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
auth_chain_cache: Mutex::new(LruCache::new(
(f64::from(config.auth_chain_cache_capacity) * config.conduit_cache_capacity_modifier) as usize,
)),
our_real_users_cache: RwLock::new(HashMap::new()),
appservice_in_room_cache: RwLock::new(HashMap::new()),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
});
let db = Box::leak(db_raw);
let services_raw = Box::new(Services::build(db, &config, tracing_reload_handler)?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
migrations(db, &config).await?;
services().admin.start_handler();
// Set emergency access for the conduit user
match set_emergency_access() {
Ok(pwd_set) => {
if pwd_set {
warn!(
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
account recovery!"
);
services()
.admin
.send_message(RoomMessageEventContent::text_plain(
"The Conduit account emergency password is set! Please unset it as soon as you finish \
admin account recovery!",
))
.await;
}
},
Err(e) => {
error!("Could not set the configured emergency password for the conduit user: {}", e);
},
};
services().sending.start_handler();
if config.allow_local_presence {
services().presence.start_handler();
}
Self::start_cleanup_task().await;
if services().globals.allow_check_for_updates() {
Self::start_check_for_updates_task().await;
}
Ok(())
}
fn check_db_setup(config: &Config) -> Result<()> {
let path = Path::new(&config.database_path);
let sqlite_exists = path.join("conduit.db").exists();
let rocksdb_exists = path.join("IDENTITY").exists();
if sqlite_exists && rocksdb_exists {
return Err(Error::bad_config("Multiple databases at database_path detected."));
}
if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in config.",
));
}
if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in config.",
));
}
Ok(())
}
#[tracing::instrument]
async fn start_check_for_updates_task() {
let timer_interval = Duration::from_secs(7200); // 2 hours
tokio::spawn(async move {
let mut i = interval(timer_interval);
loop {
tokio::select! {
_ = i.tick() => {
debug!(target: "start_check_for_updates_task", "Timer ticked");
},
}
_ = Self::try_handle_updates().await;
}
});
}
async fn try_handle_updates() -> Result<()> {
let response = services()
.globals
.client
.default
.get("https://pupbrain.dev/check-for-updates/stable")
.send()
.await?;
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| {
error!("Bad check for updates response: {e}");
Error::BadServerResponse("Bad version check response")
})?;
let mut last_update_id = services().globals.last_check_for_updates_id()?;
for update in response.updates {
last_update_id = last_update_id.max(update.id);
if update.id > services().globals.last_check_for_updates_id()? {
error!("{}", update.message);
services()
.admin
.send_message(RoomMessageEventContent::text_plain(format!(
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
update.date, update.message
)))
.await;
}
}
services()
.globals
.update_check_for_updates_id(last_update_id)?;
Ok(())
}
#[tracing::instrument]
async fn start_cleanup_task() {
let timer_interval = Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
tokio::spawn(async move {
let mut i = interval(timer_interval);
#[cfg(unix)]
let mut hangup = signal(SignalKind::hangup()).expect("Failed to register SIGHUP signal receiver");
#[cfg(unix)]
let mut ctrl_c = signal(SignalKind::interrupt()).expect("Failed to register SIGINT signal receiver");
#[cfg(unix)]
let mut terminate = signal(SignalKind::terminate()).expect("Failed to register SIGTERM signal receiver");
loop {
#[cfg(unix)]
tokio::select! {
_ = i.tick() => {
debug!(target: "database-cleanup", "Timer ticked");
}
_ = hangup.recv() => {
debug!(target: "database-cleanup","Received SIGHUP");
}
_ = ctrl_c.recv() => {
debug!(target: "database-cleanup", "Received Ctrl+C");
}
_ = terminate.recv() => {
debug!(target: "database-cleanup","Received SIGTERM");
}
}
#[cfg(not(unix))]
{
i.tick().await;
debug!(target: "database-cleanup", "Timer ticked")
}
Self::perform_cleanup();
}
});
}
fn perform_cleanup() {
if !services().globals.config.rocksdb_periodic_cleanup {
return;
}
let start = Instant::now();
if let Err(e) = services().globals.cleanup() {
error!(target: "database-cleanup", "Ran into an error during cleanup: {}", e);
} else {
debug!(target: "database-cleanup", "Finished cleanup in {:#?}.", start.elapsed());
}
}
#[allow(dead_code)]
fn flush(&self) -> Result<()> {
let start = std::time::Instant::now();
let res = self.db.flush();
debug!("flush: took {:?}", start.elapsed());
res
}
}
/// Sets the emergency password and push rules for the @conduit account in case
/// emergency password is set
fn set_emergency_access() -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is a valid UserId");
services()
.users
.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
services().account_data.update(
None,
&conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
res
}
extern crate conduit_core as conduit;
pub(crate) use conduit::{Config, Result};
pub use cork::Cork;
pub use kvdatabase::KeyValueDatabase;
pub use kvengine::KeyValueDatabaseEngine;
pub use kvtree::KvTree;
conduit::mod_ctor! {}
conduit::mod_dtor! {}

View File

@ -1,9 +1,8 @@
use std::{future::Future, pin::Pin, sync::Arc};
use rust_rocksdb::WriteBatchWithTransaction;
use conduit::{utils, Result};
use super::{watchers::Watchers, Engine, KeyValueDatabaseEngine, KvTree};
use crate::{utils, Result};
use super::{rust_rocksdb::WriteBatchWithTransaction, watchers::Watchers, Engine, KeyValueDatabaseEngine, KvTree};
pub(crate) struct RocksDbEngineTree<'a> {
pub(crate) db: Arc<Engine>,

View File

@ -1,3 +1,8 @@
// no_link to prevent double-inclusion of librocksdb.a here and with
// libconduit_core.so
#[no_link]
extern crate rust_rocksdb;
use std::{
collections::HashMap,
sync::{atomic::AtomicU32, Arc},
@ -6,12 +11,12 @@ use std::{
use chrono::{DateTime, Utc};
use rust_rocksdb::{
backup::{BackupEngine, BackupEngineOptions},
perf::get_memory_usage_stats,
Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode as Db, Env, MultiThreaded, Options,
};
use tracing::{debug, error, info, warn};
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::Result;
use crate::{watchers::Watchers, Config, KeyValueDatabaseEngine, KvTree, Result};
pub(crate) mod kvtree;
pub(crate) mod opts;
@ -22,13 +27,13 @@ use opts::{cf_options, db_options};
use super::watchers;
pub(crate) struct Engine {
rocks: Db<MultiThreaded>,
config: Config,
row_cache: Cache,
col_cache: HashMap<String, Cache>,
old_cfs: Vec<String>,
opts: Options,
env: Env,
config: Config,
old_cfs: Vec<String>,
rocks: Db<MultiThreaded>,
corks: AtomicU32,
}
@ -79,13 +84,13 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
load_time.elapsed()
);
Ok(Arc::new(Engine {
rocks: db,
config: config.clone(),
row_cache,
col_cache,
old_cfs: cfs,
opts: db_opts,
env: db_env,
config: config.clone(),
old_cfs: cfs,
rocks: db,
corks: AtomicU32::new(0),
}))
}
@ -135,7 +140,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
fn memory_usage(&self) -> Result<String> {
let mut res = String::new();
let stats = rust_rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.row_cache]))?;
let stats = get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.row_cache]))?;
_ = std::fmt::write(
&mut res,
format_args!(
@ -258,3 +263,20 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
#[allow(dead_code)]
fn clear_caches(&self) {}
}
impl Drop for Engine {
fn drop(&mut self) {
debug!("Waiting for background tasks to finish...");
const BLOCKING: bool = true;
self.rocks.cancel_all_background_work(BLOCKING);
debug!("Shutting down background threads");
self.env.set_high_priority_background_threads(0);
self.env.set_low_priority_background_threads(0);
self.env.set_bottom_priority_background_threads(0);
self.env.set_background_threads(0);
debug!("Joining background threads...");
self.env.join_all_threads();
}
}

View File

@ -1,14 +1,14 @@
#![allow(dead_code)]
use std::collections::HashMap;
use rust_rocksdb::{
use super::{
rust_rocksdb::{
BlockBasedOptions, Cache, DBCompactionStyle, DBCompressionType, DBRecoveryMode, Env, LogLevel, Options,
UniversalCompactOptions, UniversalCompactionStopStyle,
},
Config,
};
use super::Config;
/// Create database-wide options suitable for opening the database. This also
/// sets our default column options in case of opening a column with the same
/// resulting value. Note that we require special per-column options on some

View File

@ -6,13 +6,13 @@ use std::{
sync::Arc,
};
use conduit::{Config, Result};
use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
use thread_local::ThreadLocal;
use tracing::debug;
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result};
thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = const { RefCell::new(None) };
@ -224,7 +224,7 @@ impl KvTree for SqliteTable {
guard.execute("BEGIN", [])?;
for key in iter {
let old = self.get_with_guard(&guard, &key)?;
let new = crate::utils::increment(old.as_deref());
let new = conduit::utils::increment(old.as_deref());
self.insert_with_guard(&guard, &key, &new)?;
}
guard.execute("COMMIT", [])?;
@ -307,7 +307,7 @@ impl KvTree for SqliteTable {
let old = self.get_with_guard(&guard, key)?;
let new = crate::utils::increment(old.as_deref());
let new = conduit::utils::increment(old.as_deref());
self.insert_with_guard(&guard, key, &new)?;

View File

@ -1,503 +0,0 @@
#[cfg(unix)]
use std::fs::Permissions; // not unix specific, just only for UNIX sockets stuff and *nix container checks
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt as _; /* not unix specific, just only for UNIX sockets stuff and *nix
* container checks */
// Not async due to services() being used in many closures, and async closures
// are not stable as of writing This is the case for every other occurence of
// sync Mutex/RwLock, except for database related ones
use std::sync::{Arc, RwLock};
use std::{io, net::SocketAddr, time::Duration};
use api::ruma_wrapper::{Ruma, RumaResponse};
use axum::Router;
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
#[cfg(feature = "axum_dual_protocol")]
use axum_server_dual_protocol::ServerExt;
use config::Config;
use database::KeyValueDatabase;
use service::{pdu::PduEvent, Services};
use tokio::{
signal,
sync::oneshot::{self, Sender},
task::JoinSet,
};
use tracing::{debug, error, info, warn};
use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry};
use utils::{
clap,
error::{Error, Result},
};
pub(crate) mod alloc;
mod api;
mod config;
mod database;
mod router;
mod service;
mod utils;
pub(crate) static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None);
#[must_use]
pub(crate) fn services() -> &'static Services<'static> {
SERVICES
.read()
.unwrap()
.expect("SERVICES should be initialized when this is called")
}
pub(crate) struct Server {
config: Config,
runtime: tokio::runtime::Runtime,
tracing_reload_handle: LogLevelReloadHandles,
#[cfg(feature = "sentry_telemetry")]
_sentry_guard: Option<sentry::ClientInitGuard>,
_tracing_flame_guard: TracingFlameGuard,
}
fn main() -> Result<(), Error> {
let args = clap::parse();
let conduwuit: Server = init(args)?;
conduwuit
.runtime
.block_on(async { async_main(&conduwuit).await })
}
async fn async_main(server: &Server) -> Result<(), Error> {
if let Err(error) = start(server).await {
error!("Critical error starting server: {error}");
return Err(Error::Err(format!("{error}")));
}
if let Err(error) = run(server).await {
error!("Critical error running server: {error}");
return Err(Error::Err(format!("{error}")));
}
if let Err(error) = stop(server).await {
error!("Critical error stopping server: {error}");
return Err(Error::Err(format!("{error}")));
}
Ok(())
}
async fn run(server: &Server) -> io::Result<()> {
let app = router::build(server).await?;
let (tx, rx) = oneshot::channel::<()>();
let handle = ServerHandle::new();
tokio::spawn(shutdown(handle.clone(), tx));
#[cfg(unix)]
if server.config.unix_socket_path.is_some() {
return run_unix_socket_server(server, app, rx).await;
}
let addrs = server.config.get_bind_addrs();
if server.config.tls.is_some() {
return run_tls_server(server, app, handle, addrs).await;
}
let mut join_set = JoinSet::new();
for addr in &addrs {
join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone()));
}
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
info!("Listening on {:?}", addrs);
join_set.join_next().await;
Ok(())
}
async fn run_tls_server(
server: &Server, app: axum::routing::IntoMakeService<Router>, handle: ServerHandle, addrs: Vec<SocketAddr>,
) -> io::Result<()> {
let tls = server.config.tls.as_ref().unwrap();
debug!(
"Using direct TLS. Certificate path {} and certificate private key path {}",
&tls.certs, &tls.key
);
info!(
"Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS."
);
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
if cfg!(feature = "axum_dual_protocol") {
info!(
"conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only \
take affect if `dual_protocol` is enabled in `[global.tls]`"
);
}
let mut join_set = JoinSet::new();
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
#[cfg(feature = "axum_dual_protocol")]
for addr in &addrs {
join_set.spawn(
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
.set_upgrade(false)
.handle(handle.clone())
.serve(app.clone()),
);
}
} else {
for addr in &addrs {
join_set.spawn(
bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
);
}
}
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
warn!(
"Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)",
addrs, &tls.certs
);
} else {
info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs);
}
join_set.join_next().await;
Ok(())
}
#[cfg(unix)]
#[allow(unused_variables)]
async fn run_unix_socket_server(
server: &Server, app: axum::routing::IntoMakeService<Router>, rx: oneshot::Receiver<()>,
) -> io::Result<()> {
let path = server.config.unix_socket_path.as_ref().unwrap();
if path.exists() {
warn!(
"UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.",
path.display()
);
tokio::fs::remove_file(&path).await?;
}
tokio::fs::create_dir_all(path.parent().unwrap()).await?;
let socket_perms = server.config.unix_socket_perms.to_string();
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
tokio::fs::set_permissions(&path, Permissions::from_mode(octal_perms))
.await
.unwrap();
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
let bind = tokio::net::UnixListener::bind(path)?;
info!("Listening at {:?}", path);
Ok(())
}
async fn shutdown(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler")
.recv()
.await;
};
let sig: &str;
#[cfg(unix)]
tokio::select! {
() = ctrl_c => { sig = "Ctrl+C"; },
() = terminate => { sig = "SIGTERM"; },
}
#[cfg(not(unix))]
tokio::select! {
_ = ctrl_c => { sig = "Ctrl+C"; },
}
warn!("Received {}, shutting down...", sig);
handle.graceful_shutdown(Some(Duration::from_secs(180)));
services().globals.shutdown();
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]);
tx.send(()).expect(
"failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your \
system may not be in an okay/ideal state.)",
);
Ok(())
}
async fn stop(_server: &Server) -> io::Result<()> {
info!("Shutdown complete.");
Ok(())
}
/// Async initializations
async fn start(server: &Server) -> Result<(), Error> {
KeyValueDatabase::load_or_create(server.config.clone(), server.tracing_reload_handle.clone()).await?;
Ok(())
}
/// Non-async initializations
fn init(args: clap::Args) -> Result<Server, Error> {
let config = Config::new(args.config)?;
#[cfg(feature = "sentry_telemetry")]
let sentry_guard = if config.sentry {
Some(init_sentry(&config))
} else {
None
};
let (tracing_reload_handle, tracing_flame_guard) = init_tracing(&config);
config.check()?;
info!(
server_name = ?config.server_name,
database_path = ?config.database_path,
log_levels = ?config.log,
"{}",
utils::conduwuit_version(),
);
#[cfg(unix)]
maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit");
Ok(Server {
config,
runtime: tokio::runtime::Builder::new_multi_thread()
.enable_io()
.enable_time()
.thread_name("conduwuit:worker")
.worker_threads(std::cmp::max(2, num_cpus::get()))
.build()
.unwrap(),
tracing_reload_handle,
#[cfg(feature = "sentry_telemetry")]
_sentry_guard: sentry_guard,
_tracing_flame_guard: tracing_flame_guard,
})
}
#[cfg(feature = "sentry_telemetry")]
fn init_sentry(config: &Config) -> sentry::ClientInitGuard {
sentry::init((
config
.sentry_endpoint
.as_ref()
.expect("init_sentry should only be called if sentry is enabled and this is not None")
.as_str(),
sentry::ClientOptions {
release: sentry::release_name!(),
traces_sample_rate: config.sentry_traces_sample_rate,
server_name: if config.sentry_send_server_name {
Some(config.server_name.to_string().into())
} else {
None
},
..Default::default()
},
))
}
/// We need to store a reload::Handle value, but can't name it's type explicitly
/// because the S type parameter depends on the subscriber's previous layers. In
/// our case, this includes unnameable 'impl Trait' types.
///
/// This is fixed[1] in the unreleased tracing-subscriber from the master
/// branch, which removes the S parameter. Unfortunately can't use it without
/// pulling in a version of tracing that's incompatible with the rest of our
/// deps.
///
/// To work around this, we define an trait without the S paramter that forwards
/// to the reload::Handle::reload method, and then store the handle as a trait
/// object.
///
/// [1]: <https://github.com/tokio-rs/tracing/pull/1035/commits/8a87ea52425098d3ef8f56d92358c2f6c144a28f>
trait ReloadHandle<L> {
fn reload(&self, new_value: L) -> Result<(), reload::Error>;
}
impl<L, S> ReloadHandle<L> for reload::Handle<L, S> {
fn reload(&self, new_value: L) -> Result<(), reload::Error> { reload::Handle::reload(self, new_value) }
}
struct LogLevelReloadHandlesInner {
handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>,
}
/// Wrapper to allow reloading the filter on several several
/// [`tracing_subscriber::reload::Handle`]s at once, with the same value.
#[derive(Clone)]
struct LogLevelReloadHandles {
inner: Arc<LogLevelReloadHandlesInner>,
}
impl LogLevelReloadHandles {
fn new(handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>) -> LogLevelReloadHandles {
LogLevelReloadHandles {
inner: Arc::new(LogLevelReloadHandlesInner {
handles,
}),
}
}
fn reload(&self, new_value: &EnvFilter) -> Result<(), reload::Error> {
for handle in &self.inner.handles {
handle.reload(new_value.clone())?;
}
Ok(())
}
}
#[cfg(feature = "perf_measurements")]
type TracingFlameGuard = Option<tracing_flame::FlushGuard<io::BufWriter<std::fs::File>>>;
#[cfg(not(feature = "perf_measurements"))]
type TracingFlameGuard = ();
// clippy thinks the filter_layer clones are redundant if the next usage is
// behind a disabled feature.
#[allow(clippy::redundant_clone)]
fn init_tracing(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard) {
let registry = Registry::default();
let fmt_layer = tracing_subscriber::fmt::Layer::new();
let filter_layer = match EnvFilter::try_new(&config.log) {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}");
EnvFilter::try_new("warn").unwrap()
},
};
let mut reload_handles = Vec::<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>::new();
let subscriber = registry;
#[cfg(feature = "tokio_console")]
let subscriber = {
let console_layer = console_subscriber::spawn();
subscriber.with(console_layer)
};
let (fmt_reload_filter, fmt_reload_handle) = reload::Layer::new(filter_layer.clone());
reload_handles.push(Box::new(fmt_reload_handle));
let subscriber = subscriber.with(fmt_layer.with_filter(fmt_reload_filter));
#[cfg(feature = "sentry_telemetry")]
let subscriber = {
let sentry_layer = sentry_tracing::layer();
let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(filter_layer.clone());
reload_handles.push(Box::new(sentry_reload_handle));
subscriber.with(sentry_layer.with_filter(sentry_reload_filter))
};
#[cfg(feature = "perf_measurements")]
let (subscriber, flame_guard) = {
let (flame_layer, flame_guard) = if config.tracing_flame {
let flame_filter = match EnvFilter::try_new(&config.tracing_flame_filter) {
Ok(flame_filter) => flame_filter,
Err(e) => panic!("tracing_flame_filter config value is invalid: {e}"),
};
let (flame_layer, flame_guard) =
match tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) {
Ok(ok) => ok,
Err(e) => {
panic!("failed to initialize tracing-flame: {e}");
},
};
let flame_layer = flame_layer
.with_empty_samples(false)
.with_filter(flame_filter);
(Some(flame_layer), Some(flame_guard))
} else {
(None, None)
};
let jaeger_layer = if config.allow_jaeger {
opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new());
let tracer = opentelemetry_jaeger::new_agent_pipeline()
.with_auto_split_batch(true)
.with_service_name("conduwuit")
.install_batch(opentelemetry_sdk::runtime::Tokio)
.unwrap();
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(filter_layer);
reload_handles.push(Box::new(jaeger_reload_handle));
Some(telemetry.with_filter(jaeger_reload_filter))
} else {
None
};
let subscriber = subscriber.with(flame_layer).with(jaeger_layer);
(subscriber, flame_guard)
};
#[cfg(not(feature = "perf_measurements"))]
#[cfg_attr(not(feature = "perf_measurements"), allow(clippy::let_unit_value))]
let flame_guard = ();
tracing::subscriber::set_global_default(subscriber).unwrap();
#[cfg(all(feature = "tokio_console", feature = "release_max_log_level"))]
error!(
"'tokio_console' feature and 'release_max_log_level' feature are incompatible, because console-subscriber \
needs access to trace-level events. 'release_max_log_level' must be disabled to use tokio-console."
);
(LogLevelReloadHandles::new(reload_handles), flame_guard)
}
/// This is needed for opening lots of file descriptors, which tends to
/// happen more often when using RocksDB and making lots of federation
/// connections at startup. The soft limit is usually 1024, and the hard
/// limit is usually 512000; I've personally seen it hit >2000.
///
/// * <https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6>
/// * <https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741>
#[cfg(unix)]
fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
use nix::sys::resource::{getrlimit, setrlimit, Resource::RLIMIT_NOFILE as NOFILE};
let (soft_limit, hard_limit) = getrlimit(NOFILE)?;
if soft_limit < hard_limit {
setrlimit(NOFILE, hard_limit, hard_limit)?;
assert_eq!((hard_limit, hard_limit), getrlimit(NOFILE)?, "getrlimit != setrlimit");
debug!(to = hard_limit, from = soft_limit, "Raised RLIMIT_NOFILE",);
}
Ok(())
}

85
src/router/Cargo.toml Normal file
View File

@ -0,0 +1,85 @@
[package]
name = "conduit_router"
version.workspace = true
edition.workspace = true
[lib]
path = "mod.rs"
crate-type = [
"rlib",
# "dylib",
]
[features]
default = [
"systemd",
"sentry_telemetry",
"gzip_compression",
"zstd_compression",
"brotli_compression",
"release_max_log_level",
]
dev_release_log_level = []
release_max_log_level = [
"tracing/max_level_trace",
"tracing/release_max_level_info",
"log/max_level_trace",
"log/release_max_level_info",
]
sentry_telemetry = [
"dep:sentry",
"dep:sentry-tracing",
"dep:sentry-tower",
]
zstd_compression = [
"tower-http/compression-zstd",
]
gzip_compression = [
"tower-http/compression-gzip",
]
brotli_compression = [
"tower-http/compression-br",
]
systemd = [
"dep:sd-notify",
]
axum_dual_protocol = [
"dep:axum-server-dual-protocol"
]
[dependencies]
axum-server-dual-protocol.optional = true
axum-server-dual-protocol.workspace = true
axum-server.workspace = true
axum.workspace = true
conduit-admin.workspace = true
conduit-api.workspace = true
conduit-core.workspace = true
conduit-database.workspace = true
conduit-service.workspace = true
log.workspace = true
tokio.workspace = true
tower.workspace = true
tracing.workspace = true
bytes.workspace = true
clap.workspace = true
http-body-util.workspace = true
http.workspace = true
regex.workspace = true
ruma.workspace = true
sentry.optional = true
sentry-tower.optional = true
sentry-tower.workspace = true
sentry-tracing.optional = true
sentry-tracing.workspace = true
sentry.workspace = true
serde_json.workspace = true
tower-http.workspace = true
[target.'cfg(unix)'.dependencies]
sd-notify.workspace = true
sd-notify.optional = true
[lints]
workspace = true

190
src/router/layers.rs Normal file
View File

@ -0,0 +1,190 @@
use std::{any::Any, io, sync::Arc, time::Duration};
use axum::{
extract::{DefaultBodyLimit, MatchedPath},
Router,
};
use conduit::Server;
use http::{
header::{self, HeaderName},
HeaderValue, Method, StatusCode,
};
use tower::ServiceBuilder;
use tower_http::{
catch_panic::CatchPanicLayer,
cors::{self, CorsLayer},
set_header::SetResponseHeaderLayer,
trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
ServiceBuilderExt as _,
};
use tracing::Level;
use crate::{request, router};
pub(crate) fn build(server: &Arc<Server>) -> io::Result<axum::routing::IntoMakeService<Router>> {
let layers = ServiceBuilder::new();
#[cfg(feature = "sentry_telemetry")]
let layers = layers.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
let layers = layers.layer(compression_layer(server));
let layers = layers
.sensitive_headers([header::AUTHORIZATION])
.sensitive_request_headers([HeaderName::from_static("x-forwarded-for")].into())
.layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::spawn))
.layer(
TraceLayer::new_for_http()
.make_span_with(tracing_span::<_>)
.on_failure(DefaultOnFailure::new().level(Level::ERROR))
.on_request(DefaultOnRequest::new().level(Level::TRACE))
.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
)
.layer(axum::middleware::from_fn_with_state(Arc::clone(server), request::handle))
.layer(SetResponseHeaderLayer::if_not_present(
HeaderName::from_static("origin-agent-cluster"), // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin-Agent-Cluster
HeaderValue::from_static("?1"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_XSS_PROTECTION,
HeaderValue::from_static("0"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_FRAME_OPTIONS,
HeaderValue::from_static("DENY"),
))
.layer(SetResponseHeaderLayer::if_not_present(
HeaderName::from_static("permissions-policy"),
HeaderValue::from_static("interest-cohort=(),browsing-topics=()"),
))
.layer(SetResponseHeaderLayer::if_not_present(
header::CONTENT_SECURITY_POLICY,
HeaderValue::from_static(
"sandbox; default-src 'none'; font-src 'none'; script-src 'none'; plugin-types application/pdf; \
style-src 'unsafe-inline'; object-src 'self'; frame-ancesors 'none';",
),
))
.layer(cors_layer(server))
.layer(body_limit_layer(server))
.layer(CatchPanicLayer::custom(catch_panic));
let routes = router::build(server);
let layers = routes.layer(layers);
Ok(layers.into_make_service())
}
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
let mut compression_layer = tower_http::compression::CompressionLayer::new();
#[cfg(feature = "zstd_compression")]
{
if server.config.zstd_compression {
compression_layer = compression_layer.zstd(true);
} else {
compression_layer = compression_layer.no_zstd();
};
};
#[cfg(feature = "gzip_compression")]
{
if server.config.gzip_compression {
compression_layer = compression_layer.gzip(true);
} else {
compression_layer = compression_layer.no_gzip();
};
};
#[cfg(feature = "brotli_compression")]
{
if server.config.brotli_compression {
compression_layer = compression_layer.br(true);
} else {
compression_layer = compression_layer.no_br();
};
};
compression_layer
}
fn cors_layer(_server: &Server) -> CorsLayer {
const METHODS: [Method; 7] = [
Method::GET,
Method::HEAD,
Method::PATCH,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
];
let headers: [HeaderName; 5] = [
header::ORIGIN,
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
header::CONTENT_TYPE,
header::ACCEPT,
header::AUTHORIZATION,
];
CorsLayer::new()
.allow_origin(cors::Any)
.allow_methods(METHODS)
.allow_headers(headers)
.max_age(Duration::from_secs(86400))
}
fn body_limit_layer(server: &Server) -> DefaultBodyLimit {
DefaultBodyLimit::max(
server
.config
.max_request_size
.try_into()
.expect("failed to convert max request size"),
)
}
#[allow(clippy::needless_pass_by_value)]
#[tracing::instrument(skip_all)]
fn catch_panic(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> {
conduit_service::services()
.server
.requests_panic
.fetch_add(1, std::sync::atomic::Ordering::Release);
let details = if let Some(s) = err.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = err.downcast_ref::<&str>() {
s.to_string()
} else {
"Unknown internal server error occurred.".to_owned()
};
let body = serde_json::json!({
"errcode": "M_UNKNOWN",
"error": "M_UNKNOWN: Internal server error occurred",
"details": details,
})
.to_string();
http::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "application/json")
.body(http_body_util::Full::from(body))
.expect("Failed to create response for our panic catcher?")
}
fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
path.as_str()
} else {
request.uri().path()
};
tracing::info_span!("router:", %path)
}

Some files were not shown because too many files have changed in this diff Show More