pyfunction: refactor argument extraction

This commit is contained in:
David Hewitt 2021-02-21 18:02:22 +00:00
parent ffd5874c3a
commit 29a525b327
5 changed files with 293 additions and 127 deletions

View File

@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fix FFI definition `_PyEval_RequestCodeExtraIndex` which took an argument of the wrong type. [#1429](https://github.com/PyO3/pyo3/pull/1429)
- Fix FFI definition `PyIndex_Check` missing with the `abi3` feature. [#1436](https://github.com/PyO3/pyo3/pull/1436)
- Fix incorrect `TypeError` raised when keyword-only argument passed along with a positional argument in `*args`. [#1440](https://github.com/PyO3/pyo3/pull/1440)
- Fix inability to use a named lifetime for `&PyTuple` of `*args` in `#[pyfunction]`. [#1440](https://github.com/PyO3/pyo3/pull/1440)
## [0.13.2] - 2021-02-12
### Packaging

View File

@ -143,12 +143,11 @@ def test_time_fold(fold):
assert t.fold == fold
@pytest.mark.xfail
@pytest.mark.parametrize(
"args", [(-1, 0, 0, 0), (0, -1, 0, 0), (0, 0, -1, 0), (0, 0, 0, -1)]
)
def test_invalid_time_fails_xfail(args):
with pytest.raises(ValueError):
def test_invalid_time_fails_overflow(args):
with pytest.raises(OverflowError):
rdt.make_time(*args)

View File

@ -406,8 +406,9 @@ pub fn impl_arg_params(
};
}
let mut params = Vec::new();
let mut num_positional_params = 0usize;
let mut positional_parameter_names = Vec::new();
let mut required_positional_parameters = 0usize;
let mut keyword_only_parameters = Vec::new();
for arg in spec.args.iter() {
if arg.py || spec.is_args(&arg.name) || spec.is_kwargs(&arg.name) {
@ -415,22 +416,24 @@ pub fn impl_arg_params(
}
let name = arg.name.unraw().to_string();
let kwonly = spec.is_kw_only(&arg.name);
let opt = arg.optional.is_some() || spec.default_value(&arg.name).is_some();
let required = !(arg.optional.is_some() || spec.default_value(&arg.name).is_some());
if !kwonly {
num_positional_params += 1;
}
params.push(quote! {
pyo3::derive_utils::ParamDescription {
name: #name,
is_optional: #opt,
kw_only: #kwonly
if kwonly {
keyword_only_parameters.push(quote! {
pyo3::derive_utils::KeywordOnlyParameterDescription {
name: #name,
required: #required,
}
});
} else {
if required {
required_positional_parameters += 1;
}
});
positional_parameter_names.push(name);
}
}
let num_normal_params = params.len();
let num_params = positional_parameter_names.len() + keyword_only_parameters.len();
let mut param_conversion = Vec::new();
let mut option_pos = 0;
@ -451,21 +454,19 @@ pub fn impl_arg_params(
// create array of arguments, and then parse
quote! {{
const PARAMS: &'static [pyo3::derive_utils::ParamDescription] = &[
#(#params),*
];
const DESCRIPTION: pyo3::derive_utils::FunctionDescription = pyo3::derive_utils::FunctionDescription {
fname: _LOCATION,
positional_parameter_names: &[#(#positional_parameter_names),*],
// TODO: https://github.com/PyO3/pyo3/issues/1439 - support specifying these
positional_only_parameters: 0,
required_positional_parameters: #required_positional_parameters,
keyword_only_parameters: &[#(#keyword_only_parameters),*],
accept_varargs: #accept_args,
accept_varkeywords: #accept_kwargs,
};
let mut output = [None; #num_normal_params];
let (_args, _kwargs) = pyo3::derive_utils::parse_fn_args(
_LOCATION,
PARAMS,
_args,
_kwargs,
#num_positional_params,
#accept_args,
#accept_kwargs,
&mut output
)?;
let mut output = [None; #num_params];
let (_args, _kwargs) = DESCRIPTION.extract_arguments(_args, _kwargs, &mut output)?;
#(#param_conversion)*
@ -497,15 +498,31 @@ fn impl_arg_param(
};
if spec.is_args(&name) {
return quote! {
let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref())
.map_err(#transform_error)?;
return if arg.optional.is_some() {
quote! {
let #arg_name = _args.map(|args| args.extract())
.transpose()
.map_err(#transform_error)?;
}
} else {
quote! {
let #arg_name = _args.unwrap().extract()
.map_err(#transform_error)?;
}
};
} else if spec.is_kwargs(&name) {
// FIXME: check the below?
// ensure_spanned!(
// arg.optional.is_some(),
// arg.name.span() => "kwargs must be Option<_>"
// );
return quote! {
let #arg_name = _kwargs;
let #arg_name = _kwargs.map(|kwargs| kwargs.extract())
.transpose()
.map_err(#transform_error)?;
};
}
let arg_value = quote!(output[#option_pos]);
*option_pos += 1;

View File

@ -10,111 +10,245 @@ use crate::exceptions::PyTypeError;
use crate::instance::PyNativeType;
use crate::pyclass::PyClass;
use crate::types::{PyAny, PyDict, PyModule, PyString, PyTuple};
use crate::{ffi, GILPool, IntoPy, PyCell, Python};
use crate::{ffi, GILPool, PyCell, Python};
use std::cell::UnsafeCell;
/// Description of a python parameter; used for `parse_args()`.
#[derive(Debug)]
pub struct ParamDescription {
/// The name of the parameter.
pub struct KeywordOnlyParameterDescription {
pub name: &'static str,
/// Whether the parameter is optional.
pub is_optional: bool,
/// Whether the parameter is optional.
pub kw_only: bool,
pub required: bool,
}
/// Parse argument list
///
/// * fname: Name of the current function
/// * params: Declared parameters of the function
/// * args: Positional arguments
/// * kwargs: Keyword arguments
/// * output: Output array that receives the arguments.
/// Must have same length as `params` and must be initialized to `None`.
pub fn parse_fn_args<'p>(
fname: &str,
params: &[ParamDescription],
args: &'p PyTuple,
kwargs: Option<&'p PyDict>,
num_positional_params: usize,
accept_args: bool,
accept_kwargs: bool,
output: &mut [Option<&'p PyAny>],
) -> PyResult<(&'p PyTuple, Option<&'p PyDict>)> {
macro_rules! raise_error {
($s: expr $(,$arg:expr)*) => (return Err(PyTypeError::new_err(format!(
concat!("{} ", $s), fname $(,$arg)*
))))
}
#[derive(Debug)]
pub struct FunctionDescription {
pub fname: &'static str,
pub positional_parameter_names: &'static [&'static str],
pub positional_only_parameters: usize,
pub required_positional_parameters: usize,
pub keyword_only_parameters: &'static [KeywordOnlyParameterDescription],
pub accept_varargs: bool,
pub accept_varkeywords: bool,
}
let nargs = args.len();
let provided_positional_args = std::cmp::min(nargs, num_positional_params);
impl FunctionDescription {
pub fn extract_arguments<'p>(
&self,
args: &'p PyTuple,
kwargs: Option<&'p PyDict>,
output: &mut [Option<&'p PyAny>],
) -> PyResult<(Option<&'p PyTuple>, Option<&'p PyDict>)> {
let num_positional_parameters = self.positional_parameter_names.len();
// Copy kwargs not to modify it
let kwargs = match kwargs {
Some(k) => Some(k.copy()?),
None => None,
};
// Iterate through the parameters and assign values to output:
for (i, (p, out)) in params.iter().zip(output).enumerate() {
*out = match kwargs.and_then(|d| d.get_item(p.name)) {
Some(kwarg) => {
if i < provided_positional_args {
raise_error!("got multiple values for argument '{}'", p.name)
}
kwargs.as_ref().unwrap().del_item(p.name)?;
Some(kwarg)
debug_assert!(self.positional_only_parameters <= num_positional_parameters);
debug_assert!(self.required_positional_parameters <= num_positional_parameters);
debug_assert_eq!(
output.len(),
num_positional_parameters + self.keyword_only_parameters.len()
);
// Handle positional arguments
let (args_provided, varargs) = {
let args_provided = args.len();
if self.accept_varargs {
(
std::cmp::min(num_positional_parameters, args_provided),
Some(args.slice(num_positional_parameters as isize, args_provided as isize)),
)
} else if args_provided > num_positional_parameters {
return Err(self.too_many_positional_arguments(args_provided));
} else {
(args_provided, None)
}
None => {
if p.kw_only {
if !p.is_optional {
raise_error!("missing required keyword-only argument '{}'", p.name)
}
None
} else if i < nargs {
Some(args.get_item(i))
} else {
if !p.is_optional {
raise_error!("missing required positional argument '{}'", p.name)
}
None
}
};
// Copy positional arguments into output
for (out, arg) in output[..args_provided].iter_mut().zip(args) {
*out = Some(arg);
}
// Handle keyword arguments
let varkeywords = match (kwargs, self.accept_varkeywords) {
(Some(kwargs), true) => {
let mut varkeywords = None;
self.extract_keyword_arguments(kwargs, output, |name, value| {
varkeywords
.get_or_insert_with(|| PyDict::new(kwargs.py()))
.set_item(name, value)
})?;
varkeywords
}
(Some(kwargs), false) => {
self.extract_keyword_arguments(kwargs, output, |name, _| {
Err(self.unexpected_keyword_argument(name))
})?;
None
}
(None, _) => None,
};
// Check that there's sufficient positional arguments once keyword arguments are specified
if args_provided < self.required_positional_parameters {
let missing_positional_arguments: Vec<_> = self.positional_parameter_names
[..self.required_positional_parameters]
.iter()
.copied()
.zip(output.iter())
.filter_map(|(param, out)| if out.is_none() { Some(param) } else { None })
.collect();
if !missing_positional_arguments.is_empty() {
return Err(
self.missing_required_arguments("positional", &missing_positional_arguments)
);
}
}
// Check no missing required keyword arguments
let missing_keyword_only_arguments: Vec<_> = self
.keyword_only_parameters
.iter()
.zip(&output[num_positional_parameters..])
.filter_map(|(keyword_desc, out)| {
if keyword_desc.required && out.is_none() {
Some(keyword_desc.name)
} else {
None
}
})
.collect();
if !missing_keyword_only_arguments.is_empty() {
return Err(self.missing_required_arguments("keyword", &missing_keyword_only_arguments));
}
Ok((varargs, varkeywords))
}
let is_kwargs_empty = kwargs.as_ref().map_or(true, |dict| dict.is_empty());
// Raise an error when we get an unknown key
if !accept_kwargs && !is_kwargs_empty {
let (key, _) = kwargs.unwrap().iter().next().unwrap();
raise_error!("got an unexpected keyword argument: {}", key)
#[inline]
fn extract_keyword_arguments<'p>(
&self,
kwargs: &'p PyDict,
output: &mut [Option<&'p PyAny>],
mut unexpected_keyword_handler: impl FnMut(&'p PyAny, &'p PyAny) -> PyResult<()>,
) -> PyResult<()> {
let (args_output, kwargs_output) =
output.split_at_mut(self.positional_parameter_names.len());
let mut positional_only_keyword_arguments = Vec::new();
'kwarg_loop: for (kwarg_name, value) in kwargs {
let utf8_string = match kwarg_name.downcast::<PyString>()?.to_str() {
Ok(utf8_string) => utf8_string,
// This keyword is not a UTF8 string: all PyO3 argument names are guaranteed to be
// UTF8 by construction.
Err(_) => {
unexpected_keyword_handler(kwarg_name, value)?;
continue 'kwarg_loop;
}
};
// Compare the keyword name against each parameter in turn. This is exactly the same method
// which CPython uses to map keyword names. Although it's O(num_parameters), the number of
// parameters is expected to be small so it's not worth constructing a mapping.
for (param, out) in self.keyword_only_parameters.iter().zip(&mut *kwargs_output) {
if utf8_string == param.name {
*out = Some(value);
continue 'kwarg_loop;
}
}
// Repeat for positional parameters
for (i, (&param, out)) in self
.positional_parameter_names
.iter()
.zip(&mut *args_output)
.enumerate()
{
if utf8_string == param {
if i < self.positional_only_parameters {
positional_only_keyword_arguments.push(param);
} else {
match out {
Some(_) => return Err(self.multiple_values_for_argument(param)),
None => {
*out = Some(value);
}
}
}
continue 'kwarg_loop;
}
}
unexpected_keyword_handler(kwarg_name, value)?;
}
if positional_only_keyword_arguments.is_empty() {
Ok(())
} else {
Err(self.positional_only_keyword_arguments(&positional_only_keyword_arguments))
}
}
// Raise an error when we get too many positional args
if !accept_args && num_positional_params < nargs {
raise_error!(
"takes {} positional argument{} but {} {} given",
num_positional_params,
if num_positional_params == 1 { "" } else { "s" },
nargs,
if nargs == 1 { "was" } else { "were" }
)
fn too_many_positional_arguments(&self, args_provided: usize) -> PyErr {
let was = if args_provided == 1 { "was" } else { "were" };
let msg = if self.required_positional_parameters != self.positional_parameter_names.len() {
format!(
"{} takes from {} to {} positional arguments but {} {} given",
self.fname,
self.required_positional_parameters,
self.positional_parameter_names.len(),
args_provided,
was
)
} else {
format!(
"{} takes {} positional arguments but {} {} given",
self.fname,
self.positional_parameter_names.len(),
args_provided,
was
)
};
PyTypeError::new_err(msg)
}
fn multiple_values_for_argument(&self, argument: &str) -> PyErr {
PyTypeError::new_err(format!(
"{} got multiple values for argument '{}'",
self.fname, argument
))
}
fn unexpected_keyword_argument(&self, argument: &PyAny) -> PyErr {
PyTypeError::new_err(format!(
"{} got an unexpected keyword argument '{}'",
self.fname, argument
))
}
fn positional_only_keyword_arguments(&self, parameter_names: &[&str]) -> PyErr {
let mut msg = format!(
"{} got some positional-only arguments passed as keyword arguments: ",
self.fname
);
write_parameter_list(&mut msg, parameter_names);
PyTypeError::new_err(msg)
}
fn missing_required_arguments(&self, argument_type: &str, parameter_names: &[&str]) -> PyErr {
let arguments = if parameter_names.len() == 1 {
"argument"
} else {
"arguments"
};
let mut msg = format!(
"{} missing {} required {} {}: ",
self.fname,
parameter_names.len(),
argument_type,
arguments,
);
write_parameter_list(&mut msg, parameter_names);
PyTypeError::new_err(msg)
}
// Adjust the remaining args
let args = if accept_args {
let py = args.py();
let slice = args
.slice(num_positional_params as isize, nargs as isize)
.into_py(py);
py.checked_cast_as(slice).unwrap()
} else {
args
};
let kwargs = if accept_kwargs && is_kwargs_empty {
None
} else {
kwargs
};
Ok((args, kwargs))
}
/// Add the argument name to the error message of an error which occurred during argument extraction
@ -257,3 +391,19 @@ impl<'a> From<&'a PyModule> for PyFunctionArguments<'a> {
PyFunctionArguments::PyModule(module)
}
}
fn write_parameter_list(msg: &mut String, parameter_names: &[&str]) {
for (i, parameter) in parameter_names.iter().enumerate() {
if i != 0 && parameter_names.len() > 2 {
msg.push(',');
}
if i == parameter_names.len() - 1 {
msg.push_str(" and ")
}
msg.push('\'');
msg.push_str(parameter);
msg.push('\'');
}
}

View File

@ -49,7 +49,6 @@ impl InstanceMethodWithArgs {
}
#[test]
#[allow(dead_code)]
fn instance_method_with_args() {
let gil = Python::acquire_gil();
let py = gil.python();