From 90479ddae4453e56736c476677f61a072bd62956 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 23 Dec 2021 23:33:51 +0000 Subject: [PATCH] opt: make argument extraction code smaller --- CHANGELOG.md | 5 +- pyo3-macros-backend/src/params.rs | 12 +-- pytests/pyo3-benchmarks/tox.ini | 3 + src/derive_utils.rs | 141 ++++++++++++++++-------------- src/impl_.rs | 1 + src/impl_/extract_argument.rs | 30 +++++++ 6 files changed, 120 insertions(+), 72 deletions(-) create mode 100644 src/impl_/extract_argument.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index a4929599..3e58036e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,8 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 accompanies your error type in your crate's documentation. - Improve performance and error messages for `#[derive(FromPyObject)]` for enums. [#2068](https://github.com/PyO3/pyo3/pull/2068) - Reduce generated LLVM code size (to improve compile times) for: - - internal `handle_panic` helper [#2073](https://github.com/PyO3/pyo3/pull/2073) - - `#[pyclass]` type object creation [#2075](https://github.com/PyO3/pyo3/pull/2075) + - internal `handle_panic` helper [#2074](https://github.com/PyO3/pyo3/pull/2074) + - `#[pyfunction]` and `#[pymethods]` argument extraction [#2075](https://github.com/PyO3/pyo3/pull/2075) + - `#[pyclass]` type object creation [#2076](https://github.com/PyO3/pyo3/pull/2076) ### Removed diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index c15b3e2e..5e73a7bf 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -213,8 +213,9 @@ fn impl_arg_param( let ty = arg.ty; let name = arg.name; + let name_str = name.to_string(); let transform_error = quote! { - |e| _pyo3::derive_utils::argument_extraction_error(#py, stringify!(#name), e) + |e| _pyo3::impl_::extract_argument::argument_extraction_error(#py, #name_str, e) }; if is_args(&spec.attrs, name) { @@ -223,7 +224,7 @@ fn impl_arg_param( arg.name.span() => "args cannot be optional" ); return Ok(quote_arg_span! { - let #arg_name = _args.unwrap().extract().map_err(#transform_error)?; + let #arg_name = _pyo3::impl_::extract_argument::extract_argument(_args.unwrap(), #name_str)?; }); } else if is_kwargs(&spec.attrs, name) { ensure_spanned!( @@ -231,9 +232,8 @@ fn impl_arg_param( arg.name.span() => "kwargs must be Option<_>" ); return Ok(quote_arg_span! { - let #arg_name = _kwargs.map(|kwargs| kwargs.extract()) - .transpose() - .map_err(#transform_error)?; + let #arg_name = _kwargs.map(|kwargs| _pyo3::impl_::extract_argument::extract_argument(kwargs, #name_str)) + .transpose()?; }); } @@ -243,7 +243,7 @@ fn impl_arg_param( let extract = if let Some(FromPyWithAttribute(expr_path)) = &arg.attrs.from_py_with { quote_arg_span! { #expr_path(_obj).map_err(#transform_error) } } else { - quote_arg_span! { _obj.extract().map_err(#transform_error) } + quote_arg_span! { _pyo3::impl_::extract_argument::extract_argument(_obj, #name_str) } }; let arg_value_or_default = match (spec.default_value(name), arg.optional.is_some()) { diff --git a/pytests/pyo3-benchmarks/tox.ini b/pytests/pyo3-benchmarks/tox.ini index 1713e7eb..635f66aa 100644 --- a/pytests/pyo3-benchmarks/tox.ini +++ b/pytests/pyo3-benchmarks/tox.ini @@ -3,3 +3,6 @@ usedevelop = True description = Run the unit tests under {basepython} deps = -rrequirements-dev.txt commands = pytest --benchmark-sort=name {posargs} +# Use recreate so that tox always rebuilds, otherwise changes to Rust are not +# picked up. +recreate = True diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 0fc18ae2..ab4e6a0c 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -102,9 +102,12 @@ impl FunctionDescription { varkeywords } (Some(kwargs), false) => { - self.extract_keyword_arguments(kwargs, output, |name, _| { - Err(self.unexpected_keyword_argument(name)) - })?; + self.extract_keyword_arguments( + kwargs, + output, + #[cold] + |name, _| Err(self.unexpected_keyword_argument(name)), + )?; None } (None, _) => None, @@ -112,58 +115,39 @@ impl FunctionDescription { // Check that there's sufficient positional arguments once keyword arguments are specified if args_provided < self.required_positional_parameters { - let missing_positional_arguments: Vec<_> = self - .positional_parameter_names - .iter() - .take(self.required_positional_parameters) - .zip(output.iter()) - .filter_map(|(param, out)| if out.is_none() { Some(*param) } else { None }) - .collect(); - if !missing_positional_arguments.is_empty() { - return Err( - self.missing_required_arguments("positional", &missing_positional_arguments) - ); + for out in &output[..self.required_positional_parameters] { + if out.is_none() { + return Err(self.missing_required_positional_arguments(output)); + } } } // Check no missing required keyword arguments - let missing_keyword_only_arguments: Vec<_> = self - .keyword_only_parameters - .iter() - .zip(&output[num_positional_parameters..]) - .filter_map(|(keyword_desc, out)| { - if keyword_desc.required && out.is_none() { - Some(keyword_desc.name) - } else { - None - } - }) - .collect(); - - if !missing_keyword_only_arguments.is_empty() { - return Err(self.missing_required_arguments("keyword", &missing_keyword_only_arguments)); + let keyword_output = &output[num_positional_parameters..]; + for (param, out) in self.keyword_only_parameters.iter().zip(keyword_output) { + if param.required && out.is_none() { + return Err(self.missing_required_keyword_arguments(keyword_output)); + } } Ok((varargs, varkeywords)) } - #[inline] fn extract_keyword_arguments<'p>( &self, kwargs: impl Iterator, output: &mut [Option<&'p PyAny>], mut unexpected_keyword_handler: impl FnMut(&'p PyAny, &'p PyAny) -> PyResult<()>, ) -> PyResult<()> { - let (args_output, kwargs_output) = - output.split_at_mut(self.positional_parameter_names.len()); + let positional_args_count = self.positional_parameter_names.len(); let mut positional_only_keyword_arguments = Vec::new(); - for (kwarg_name, value) in kwargs { - let utf8_string = match kwarg_name.downcast::()?.to_str() { - Ok(utf8_string) => utf8_string, + 'for_each_kwarg: for (kwarg_name_py, value) in kwargs { + let kwarg_name = match kwarg_name_py.downcast::()?.to_str() { + Ok(kwarg_name) => kwarg_name, // This keyword is not a UTF8 string: all PyO3 argument names are guaranteed to be // UTF8 by construction. Err(_) => { - unexpected_keyword_handler(kwarg_name, value)?; + unexpected_keyword_handler(kwarg_name_py, value)?; continue; } }; @@ -171,31 +155,24 @@ impl FunctionDescription { // Compare the keyword name against each parameter in turn. This is exactly the same method // which CPython uses to map keyword names. Although it's O(num_parameters), the number of // parameters is expected to be small so it's not worth constructing a mapping. - if let Some(i) = self - .keyword_only_parameters - .iter() - .position(|param| utf8_string == param.name) - { - kwargs_output[i] = Some(value); - continue; + for (i, param) in self.keyword_only_parameters.iter().enumerate() { + if param.name == kwarg_name { + output[positional_args_count + i] = Some(value); + continue 'for_each_kwarg; + } } // Repeat for positional parameters - if let Some((i, param)) = self - .positional_parameter_names - .iter() - .enumerate() - .find(|&(_, param)| utf8_string == *param) - { + if let Some(i) = self.find_keyword_parameter_in_positionals(kwarg_name) { if i < self.positional_only_parameters { - positional_only_keyword_arguments.push(*param); - } else if args_output[i].replace(value).is_some() { - return Err(self.multiple_values_for_argument(param)); + positional_only_keyword_arguments.push(kwarg_name); + } else if output[i].replace(value).is_some() { + return Err(self.multiple_values_for_argument(kwarg_name)); } continue; } - unexpected_keyword_handler(kwarg_name, value)?; + unexpected_keyword_handler(kwarg_name_py, value)?; } if positional_only_keyword_arguments.is_empty() { @@ -205,6 +182,16 @@ impl FunctionDescription { } } + fn find_keyword_parameter_in_positionals(&self, kwarg_name: &str) -> Option { + for (i, param_name) in self.positional_parameter_names.iter().enumerate() { + if *param_name == kwarg_name { + return Some(i); + } + } + None + } + + #[cold] fn too_many_positional_arguments(&self, args_provided: usize) -> PyErr { let was = if args_provided == 1 { "was" } else { "were" }; let msg = if self.required_positional_parameters != self.positional_parameter_names.len() { @@ -228,6 +215,7 @@ impl FunctionDescription { PyTypeError::new_err(msg) } + #[cold] fn multiple_values_for_argument(&self, argument: &str) -> PyErr { PyTypeError::new_err(format!( "{} got multiple values for argument '{}'", @@ -236,6 +224,7 @@ impl FunctionDescription { )) } + #[cold] fn unexpected_keyword_argument(&self, argument: &PyAny) -> PyErr { PyTypeError::new_err(format!( "{} got an unexpected keyword argument '{}'", @@ -244,6 +233,7 @@ impl FunctionDescription { )) } + #[cold] fn positional_only_keyword_arguments(&self, parameter_names: &[&str]) -> PyErr { let mut msg = format!( "{} got some positional-only arguments passed as keyword arguments: ", @@ -253,6 +243,7 @@ impl FunctionDescription { PyTypeError::new_err(msg) } + #[cold] fn missing_required_arguments(&self, argument_type: &str, parameter_names: &[&str]) -> PyErr { let arguments = if parameter_names.len() == 1 { "argument" @@ -269,18 +260,40 @@ impl FunctionDescription { push_parameter_list(&mut msg, parameter_names); PyTypeError::new_err(msg) } -} -/// Add the argument name to the error message of an error which occurred during argument extraction -pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr { - if error.is_instance_of::(py) { - let reason = error - .value(py) - .str() - .unwrap_or_else(|_| PyString::new(py, "")); - PyTypeError::new_err(format!("argument '{}': {}", arg_name, reason)) - } else { - error + #[cold] + fn missing_required_keyword_arguments(&self, keyword_outputs: &[Option<&PyAny>]) -> PyErr { + debug_assert_eq!(self.keyword_only_parameters.len(), keyword_outputs.len()); + + let missing_keyword_only_arguments: Vec<_> = self + .keyword_only_parameters + .iter() + .zip(keyword_outputs) + .filter_map(|(keyword_desc, out)| { + if keyword_desc.required && out.is_none() { + Some(keyword_desc.name) + } else { + None + } + }) + .collect(); + + debug_assert!(!missing_keyword_only_arguments.is_empty()); + self.missing_required_arguments("keyword", &missing_keyword_only_arguments) + } + + #[cold] + fn missing_required_positional_arguments(&self, output: &[Option<&PyAny>]) -> PyErr { + let missing_positional_arguments: Vec<_> = self + .positional_parameter_names + .iter() + .take(self.required_positional_parameters) + .zip(output) + .filter_map(|(param, out)| if out.is_none() { Some(*param) } else { None }) + .collect(); + + debug_assert!(!missing_positional_arguments.is_empty()); + self.missing_required_arguments("positional", &missing_positional_arguments) } } diff --git a/src/impl_.rs b/src/impl_.rs index 586ce1a5..0f61b91e 100644 --- a/src/impl_.rs +++ b/src/impl_.rs @@ -5,6 +5,7 @@ //! breaking semver guarantees. pub mod deprecations; +pub mod extract_argument; pub mod freelist; #[doc(hidden)] pub mod frompyobject; diff --git a/src/impl_/extract_argument.rs b/src/impl_/extract_argument.rs new file mode 100644 index 00000000..610e76cd --- /dev/null +++ b/src/impl_/extract_argument.rs @@ -0,0 +1,30 @@ +use crate::{ + exceptions::PyTypeError, type_object::PyTypeObject, FromPyObject, PyAny, PyErr, PyResult, + Python, +}; + +#[doc(hidden)] +#[inline] +pub fn extract_argument<'py, T>(obj: &'py PyAny, arg_name: &str) -> PyResult +where + T: FromPyObject<'py>, +{ + match obj.extract() { + Ok(e) => Ok(e), + Err(e) => Err(argument_extraction_error(obj.py(), arg_name, e)), + } +} + +/// Adds the argument name to the error message of an error which occurred during argument extraction. +/// +/// Only modifies TypeError. (Cannot guarantee all exceptions have constructors from +/// single string.) +#[doc(hidden)] +#[cold] +pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr { + if error.get_type(py) == PyTypeError::type_object(py) { + PyTypeError::new_err(format!("argument '{}': {}", arg_name, error.value(py))) + } else { + error + } +}