cleanup/split/dedup sending/send callstack

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-04-16 20:54:16 -07:00 committed by June
parent 9361acadcb
commit 68aa368450
1 changed files with 275 additions and 255 deletions

View File

@ -43,9 +43,16 @@ pub enum FedDest {
Named(String, String),
}
struct ActualDestination {
destination: FedDest,
host: String,
string: String,
cached: bool,
}
#[tracing::instrument(skip_all, name = "send")]
pub(crate) async fn send_request<T>(
client: &reqwest::Client, destination: &ServerName, request: T,
client: &reqwest::Client, destination: &ServerName, req: T,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
@ -54,286 +61,150 @@ where
return Err(Error::bad_config("Federation is disabled."));
}
if destination == services().globals.server_name() {
return Err(Error::bad_config("Won't send federation request to ourselves"));
}
if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) {
debug!(
"Destination {} is an IP literal, checking against IP range denylist.",
destination
);
let ip = IPAddress::parse(destination.host()).map_err(|e| {
warn!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
trace!("Preparing to send request");
validate_destination(destination)?;
let actual = get_actual_destination(destination).await;
let mut http_request = req
.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5])
.map_err(|e| {
warn!("Failed to find destination {}: {}", actual.string, e);
Error::BadServerResponse("Invalid destination")
})?;
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
sign_request::<T>(destination, &mut http_request);
let request = reqwest::Request::try_from(http_request)?;
let method = request.method().clone();
let url = request.url().clone();
validate_url(&url)?;
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
debug!(
method = ?method,
url = ?url,
"Sending request",
);
match client.execute(request).await {
Ok(response) => handle_response::<T>(destination, actual, &method, &url, response).await,
Err(e) => handle_error::<T>(destination, &actual, &method, &url, e),
}
}
debug!("List of pushed CIDR ranges: {:?}", cidr_ranges);
async fn handle_response<T>(
destination: &ServerName, actual: ActualDestination, method: &reqwest::Method, url: &reqwest::Url,
mut response: reqwest::Response,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
trace!("Received response from {} for {} with {}", actual.string, url, response.url());
validate_response(&response)?;
for cidr in cidr_ranges {
if cidr.includes(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
let status = response.status();
let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap(
response.headers_mut(),
http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
);
debug!("IP literal {} is allowed.", destination);
trace!("Waiting for response body");
let body = response.bytes().await.unwrap_or_else(|e| {
debug!("server error {}", e);
Vec::new().into()
}); // TODO: handle timeout
let http_response = http_response_builder
.body(body)
.expect("reqwest body is valid http body");
debug!("Got {status:?} for {method} {url}");
if !status.is_success() {
return Err(Error::FederationError(
destination.to_owned(),
RumaError::from_http_response(http_response),
));
}
trace!("Preparing to send request to {destination}");
let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && !actual.cached {
services()
.globals
.actual_destinations()
.write()
.await
.insert(OwnedServerName::from(destination), (actual.destination, actual.host));
}
let mut write_destination_to_cache = false;
match response {
Err(_e) => Err(Error::BadServerResponse("Server returned bad 200 response.")),
Ok(response) => Ok(response),
}
}
fn handle_error<T>(
_destination: &ServerName, actual: &ActualDestination, method: &reqwest::Method, url: &reqwest::Url,
e: reqwest::Error,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
// we do not need to log that servers in a room are dead, this is normal in
// public rooms and just spams the logs.
if e.is_timeout() {
debug!("Timed out sending request to {}: {}", actual.string, e,);
} else if e.is_connect() {
debug!("Failed to connect to {}: {}", actual.string, e);
} else if e.is_redirect() {
debug!(
method = ?method,
url = ?url,
final_url = ?e.url(),
"Redirect loop sending request to {}: {}",
actual.string,
e,
);
} else {
debug!("Could not send request to {}: {}", actual.string, e);
}
Err(e.into())
}
#[tracing::instrument(skip_all, name = "resolve")]
async fn get_actual_destination(server_name: &ServerName) -> ActualDestination {
let cached;
let cached_result = services()
.globals
.actual_destinations()
.read()
.await
.get(destination)
.get(server_name)
.cloned();
let (actual_destination, host) = if let Some(result) = cached_result {
let (destination, host) = if let Some(result) = cached_result {
cached = true;
result
} else {
write_destination_to_cache = true;
let result = resolve_actual_destination(destination).await;
(result.0, result.1.into_uri_string())
cached = false;
resolve_actual_destination(server_name).await
};
let actual_destination_str = actual_destination.clone().into_https_string();
let mut http_request = request
.try_into_http_request::<Vec<u8>>(
&actual_destination_str,
SendAccessToken::IfRequired(""),
&[MatrixVersion::V1_5],
)
.map_err(|e| {
warn!("Failed to find destination {}: {}", actual_destination_str, e);
Error::BadServerResponse("Invalid destination")
})?;
let mut request_map = serde_json::Map::new();
if !http_request.body().is_empty() {
request_map.insert(
"content".to_owned(),
serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"),
);
};
request_map.insert("method".to_owned(), T::METADATA.method.to_string().into());
request_map.insert(
"uri".to_owned(),
http_request
.uri()
.path_and_query()
.expect("all requests have a path")
.to_string()
.into(),
);
request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into());
request_map.insert("destination".to_owned(), destination.as_str().into());
let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap");
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut request_json,
)
.expect("our request json is what ruma expects");
let request_json: serde_json::Map<String, serde_json::Value> =
serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap();
let signatures = request_json["signatures"]
.as_object()
.unwrap()
.values()
.map(|v| {
v.as_object()
.unwrap()
.iter()
.map(|(k, v)| (k, v.as_str().unwrap()))
});
for signature_server in signatures {
for s in signature_server {
http_request.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from_str(&format!(
"X-Matrix origin={},key=\"{}\",sig=\"{}\"",
services().globals.server_name(),
s.0,
s.1
))
.unwrap(),
);
}
let string = destination.clone().into_https_string();
ActualDestination {
destination,
host,
string,
cached,
}
let reqwest_request = reqwest::Request::try_from(http_request)?;
let method = reqwest_request.method().clone();
let url = reqwest_request.url().clone();
if let Some(url_host) = url.host_str() {
trace!("Checking request URL for IP");
if let Ok(ip) = IPAddress::parse(url_host) {
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
for cidr in cidr_ranges {
if cidr.includes(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
}
}
debug!("Sending request {} {}", method, url);
let response = client.execute(reqwest_request).await;
trace!("Received resonse {} {}", method, url);
match response {
Ok(mut response) => {
// reqwest::Response -> http::Response conversion
trace!("Checking response destination's IP");
if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
for cidr in cidr_ranges {
if cidr.includes(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
}
}
let status = response.status();
let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap(
response.headers_mut(),
http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
);
trace!("Waiting for response body");
let body = response.bytes().await.unwrap_or_else(|e| {
debug!("server error {}", e);
Vec::new().into()
}); // TODO: handle timeout
if !status.is_success() {
debug!(
"Got {status:?} for {method} {url}: {}",
String::from_utf8_lossy(&body)
.lines()
.collect::<Vec<_>>()
.join(" ")
);
}
let http_response = http_response_builder
.body(body)
.expect("reqwest body is valid http body");
if status.is_success() {
debug!("Got {status:?} for {method} {url}");
let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && write_destination_to_cache {
services()
.globals
.actual_destinations()
.write()
.await
.insert(OwnedServerName::from(destination), (actual_destination, host));
}
response.map_err(|e| {
debug!("Invalid 200 response for {} {}", url, e);
Error::BadServerResponse("Server returned bad 200 response.")
})
} else {
Err(Error::FederationError(
destination.to_owned(),
RumaError::from_http_response(http_response),
))
}
},
Err(e) => {
// we do not need to log that servers in a room are dead, this is normal in
// public rooms and just spams the logs.
if e.is_timeout() {
debug!(
"Timed out sending request to {} at {}: {}",
destination, actual_destination_str, e
);
} else if e.is_connect() {
debug!("Failed to connect to {} at {}: {}", destination, actual_destination_str, e);
} else if e.is_redirect() {
debug!(
"Redirect loop sending request to {} at {}: {}\nFinal URL: {:?}",
destination,
actual_destination_str,
e,
e.url()
);
} else {
debug!("Could not send request to {} at {}: {}", destination, actual_destination_str, e);
}
Err(e.into())
},
}
}
fn get_ip_with_port(destination_str: &str) -> Option<FedDest> {
if let Ok(destination) = destination_str.parse::<SocketAddr>() {
Some(FedDest::Literal(destination))
} else if let Ok(ip_addr) = destination_str.parse::<IpAddr>() {
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448)))
} else {
None
}
}
fn add_port_to_hostname(destination_str: &str) -> FedDest {
let (host, port) = match destination_str.find(':') {
None => (destination_str, ":8448"),
Some(pos) => destination_str.split_at(pos),
};
FedDest::Named(host.to_owned(), port.to_owned())
}
/// Returns: `actual_destination`, host header
/// Implemented according to the specification at <https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names>
/// Numbers in comments below refer to bullet points in linked section of
/// specification
#[tracing::instrument(skip_all, name = "resolve")]
async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) {
async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, String) {
trace!("Finding actual destination for {destination}");
let destination_str = destination.as_str().to_owned();
let mut hostname = destination_str.clone();
@ -429,7 +300,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, Fe
};
debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}");
(actual_destination, hostname)
(actual_destination, hostname.into_uri_string())
}
async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) {
@ -441,7 +312,6 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
{
Ok(override_ip) => {
trace!("Caching result of {:?} overriding {:?}", hostname, overname);
services()
.globals
.resolver
@ -533,6 +403,156 @@ async fn request_well_known(destination: &str) -> Option<String> {
Some(body.get("m.server")?.as_str()?.to_owned())
}
fn sign_request<T>(destination: &ServerName, http_request: &mut http::Request<Vec<u8>>)
where
T: OutgoingRequest + Debug,
{
let mut req_map = serde_json::Map::new();
if !http_request.body().is_empty() {
req_map.insert(
"content".to_owned(),
serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"),
);
};
req_map.insert("method".to_owned(), T::METADATA.method.to_string().into());
req_map.insert(
"uri".to_owned(),
http_request
.uri()
.path_and_query()
.expect("all requests have a path")
.to_string()
.into(),
);
req_map.insert("origin".to_owned(), services().globals.server_name().as_str().into());
req_map.insert("destination".to_owned(), destination.as_str().into());
let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap");
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut req_json,
)
.expect("our request json is what ruma expects");
let req_json: serde_json::Map<String, serde_json::Value> =
serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap();
let signatures = req_json["signatures"]
.as_object()
.expect("signatures object")
.values()
.map(|v| {
v.as_object()
.expect("server signatures object")
.iter()
.map(|(k, v)| (k, v.as_str().expect("server signature string")))
});
for signature_server in signatures {
for s in signature_server {
http_request.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from_str(&format!(
"X-Matrix origin={},key=\"{}\",sig=\"{}\"",
services().globals.server_name(),
s.0,
s.1
))
.expect("formatted X-Matrix header"),
);
}
}
}
fn validate_response(response: &reqwest::Response) -> Result<()> {
if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
trace!("Checking response destination's IP");
validate_ip(&ip)?;
}
}
Ok(())
}
fn validate_url(url: &reqwest::Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
validate_ip(&ip)?;
}
}
Ok(())
}
fn validate_destination(destination: &ServerName) -> Result<()> {
if destination == services().globals.server_name() {
return Err(Error::bad_config("Won't send federation request to ourselves"));
}
if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) {
validate_destination_ip_literal(destination)?;
}
trace!("Destination ServerName is valid");
Ok(())
}
fn validate_destination_ip_literal(destination: &ServerName) -> Result<()> {
debug_assert!(
destination.is_ip_literal() || !IPAddress::is_valid(destination.host()),
"Destination is not an IP literal."
);
debug!("Destination is an IP literal, checking against IP range denylist.",);
let ip = IPAddress::parse(destination.host()).map_err(|e| {
warn!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
validate_ip(&ip)?;
Ok(())
}
fn validate_ip(ip: &IPAddress) -> Result<()> {
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
trace!("List of pushed CIDR ranges: {:?}", cidr_ranges);
for cidr in cidr_ranges {
if cidr.includes(ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
Ok(())
}
fn get_ip_with_port(destination_str: &str) -> Option<FedDest> {
if let Ok(destination) = destination_str.parse::<SocketAddr>() {
Some(FedDest::Literal(destination))
} else if let Ok(ip_addr) = destination_str.parse::<IpAddr>() {
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448)))
} else {
None
}
}
fn add_port_to_hostname(destination_str: &str) -> FedDest {
let (host, port) = match destination_str.find(':') {
None => (destination_str, ":8448"),
Some(pos) => destination_str.split_at(pos),
};
FedDest::Named(host.to_owned(), port.to_owned())
}
impl FedDest {
fn into_https_string(self) -> String {
match self {