From 86694f2d1d55605af2058b5347c71ebf977c5daf Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 12 Nov 2024 08:01:23 +0000 Subject: [PATCH] move non-generic code out of generic; reduce codegen Signed-off-by: Jason Volk --- src/api/router/args.rs | 77 +++++++++++++++--------------- src/service/sending/send.rs | 93 ++++++++++++++++++++----------------- 2 files changed, 90 insertions(+), 80 deletions(-) diff --git a/src/api/router/args.rs b/src/api/router/args.rs index 4c0aff4c..0b693956 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -66,6 +66,15 @@ where } } +impl Deref for Args +where + T: IncomingRequest + Send + Sync + 'static, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { &self.body } +} + #[async_trait] impl FromRequest for Args where @@ -78,7 +87,7 @@ where let mut json_body = serde_json::from_slice::(&request.body).ok(); let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?; Ok(Self { - body: make_body::(services, &mut request, &mut json_body, &auth)?, + body: make_body::(services, &mut request, json_body.as_mut(), &auth)?, origin: auth.origin, sender_user: auth.sender_user, sender_device: auth.sender_device, @@ -88,20 +97,11 @@ where } } -impl Deref for Args -where - T: IncomingRequest + Send + Sync + 'static, -{ - type Target = T; - - fn deref(&self) -> &Self::Target { &self.body } -} - fn make_body( - services: &Services, request: &mut Request, json_body: &mut Option, auth: &Auth, + services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth, ) -> Result where - T: IncomingRequest + Send + Sync + 'static, + T: IncomingRequest, { let body = take_body(services, request, json_body, auth); let http_request = into_http_request(request, body); @@ -125,36 +125,37 @@ fn into_http_request(request: &Request, body: Bytes) -> hyper::Request { http_request } +#[allow(clippy::needless_pass_by_value)] fn take_body( - services: &Services, request: &mut Request, json_body: &mut Option, auth: &Auth, + services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth, ) -> Bytes { - if let Some(CanonicalJsonValue::Object(json_body)) = json_body { - let user_id = auth.sender_user.clone().unwrap_or_else(|| { - let server_name = services.globals.server_name(); - UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") + let Some(CanonicalJsonValue::Object(json_body)) = json_body else { + return mem::take(&mut request.body); + }; + + let user_id = auth.sender_user.clone().unwrap_or_else(|| { + let server_name = services.globals.server_name(); + UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") + }); + + let uiaa_request = json_body + .get("auth") + .and_then(CanonicalJsonValue::as_object) + .and_then(|auth| auth.get("session")) + .and_then(CanonicalJsonValue::as_str) + .and_then(|session| { + services + .uiaa + .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session) }); - let uiaa_request = json_body - .get("auth") - .and_then(CanonicalJsonValue::as_object) - .and_then(|auth| auth.get("session")) - .and_then(CanonicalJsonValue::as_str) - .and_then(|session| { - services - .uiaa - .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session) - }); - - if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { - for (key, value) in initial_request { - json_body.entry(key).or_insert(value); - } + if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + for (key, value) in initial_request { + json_body.entry(key).or_insert(value); } - - let mut buf = BytesMut::new().writer(); - serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail"); - buf.into_inner().freeze() - } else { - mem::take(&mut request.body) } + + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail"); + buf.into_inner().freeze() } diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 939d6e73..5bf48aaa 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,5 +1,6 @@ -use std::{fmt::Debug, mem}; +use std::mem; +use bytes::Bytes; use conduit::{ debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error, Result, @@ -23,10 +24,10 @@ use crate::{ }; impl super::Service { - #[tracing::instrument(skip(self, client, req), name = "send")] - pub async fn send(&self, client: &Client, dest: &ServerName, req: T) -> Result + #[tracing::instrument(skip(self, client, request), name = "send")] + pub async fn send(&self, client: &Client, dest: &ServerName, request: T) -> Result where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { if !self.server.config.allow_federation { return Err!(Config("allow_federation", "Federation is disabled.")); @@ -42,7 +43,8 @@ impl super::Service { } let actual = self.services.resolver.get_actual_dest(dest).await?; - let request = self.prepare::(dest, &actual, req).await?; + let request = into_http_request::(&actual, request)?; + let request = self.prepare(dest, request)?; self.execute::(dest, &actual, request, client).await } @@ -50,7 +52,7 @@ impl super::Service { &self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client, ) -> Result where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { let url = request.url().clone(); let method = request.method().clone(); @@ -58,25 +60,14 @@ impl super::Service { debug!(?method, ?url, "Sending request"); match client.execute(request).await { Ok(response) => handle_response::(&self.services.resolver, dest, actual, &method, &url, response).await, - Err(error) => handle_error::(dest, actual, &method, &url, error), + Err(error) => Err(handle_error(actual, &method, &url, error).expect_err("always returns error")), } } - async fn prepare(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result - where - T: OutgoingRequest + Debug + Send, - { - const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; - const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); + fn prepare(&self, dest: &ServerName, mut request: http::Request>) -> Result { + self.sign_request(&mut request, dest); - trace!("Preparing request"); - let mut http_request = req - .try_into_http_request::>(actual.string().as_str(), SATIR, &VERSIONS) - .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; - - self.sign_request(&mut http_request, dest); - - let request = Request::try_from(http_request)?; + let request = Request::try_from(request)?; self.validate_url(request.url())?; Ok(request) @@ -96,11 +87,31 @@ impl super::Service { async fn handle_response( resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, - mut response: Response, + response: Response, ) -> Result where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { + let response = into_http_response(dest, actual, method, url, response).await?; + let result = T::IncomingResponse::try_from_http_response(response); + + if result.is_ok() && !actual.cached { + resolver.set_cached_destination( + dest.to_owned(), + CachedDest { + dest: actual.dest.clone(), + host: actual.host.clone(), + expire: CachedDest::default_expire(), + }, + ); + } + + result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) +} + +async fn into_http_response( + dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, +) -> Result> { let status = response.status(); trace!( ?status, ?method, @@ -113,6 +124,7 @@ where let mut http_response_builder = http::Response::builder() .status(status) .version(response.version()); + mem::swap( response.headers_mut(), http_response_builder @@ -137,27 +149,10 @@ where return Err(Error::Federation(dest.to_owned(), RumaError::from_http_response(http_response))); } - let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && !actual.cached { - resolver.set_cached_destination( - dest.to_owned(), - CachedDest { - dest: actual.dest.clone(), - host: actual.host.clone(), - expire: CachedDest::default_expire(), - }, - ); - } - - response.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) + Ok(http_response) } -fn handle_error( - _dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error, -) -> Result -where - T: OutgoingRequest + Debug + Send, -{ +fn handle_error(actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error) -> Result { if e.is_timeout() || e.is_connect() { e = e.without_url(); debug_warn!("{e:?}"); @@ -246,3 +241,17 @@ fn sign_request(&self, http_request: &mut http::Request>, dest: &ServerN debug_assert!(authorization.is_none(), "Authorization header already present"); } + +fn into_http_request(actual: &ActualDest, request: T) -> Result>> +where + T: OutgoingRequest + Send, +{ + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; + const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); + + let http_request = request + .try_into_http_request::>(actual.string().as_str(), SATIR, &VERSIONS) + .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; + + Ok(http_request) +}