[derive_utils] Copy kwargs not to modify it

This commit is contained in:
kngwyu 2019-09-01 23:59:24 +09:00
parent 96b71bfb76
commit c7e377a472
2 changed files with 23 additions and 16 deletions

View file

@ -475,12 +475,12 @@ pub fn impl_arg_params(spec: &FnSpec<'_>, body: TokenStream) -> TokenStream {
// Workaround to use the question mark operator without rewriting everything
let _result = (|| {
pyo3::derive_utils::parse_fn_args(
let (_args, _kwargs) = pyo3::derive_utils::parse_fn_args(
_py,
Some(_LOCATION),
PARAMS,
&mut _args,
&mut _kwargs,
_args,
_kwargs,
#accept_args,
#accept_kwargs,
&mut output

View file

@ -36,12 +36,12 @@ pub fn parse_fn_args<'p>(
py: Python<'p>,
fname: Option<&str>,
params: &[ParamDescription],
args: &mut &'p PyTuple,
kwargs: &mut Option<&'p PyDict>,
args: &'p PyTuple,
kwargs: Option<&'p PyDict>,
accept_args: bool,
accept_kwargs: bool,
output: &mut [Option<&'p PyAny>],
) -> PyResult<()> {
) -> PyResult<(&'p PyTuple, Option<&'p PyDict>)> {
let nargs = args.len();
let mut used_args = 0;
macro_rules! raise_error {
@ -49,6 +49,11 @@ pub fn parse_fn_args<'p>(
concat!("{} ", $s), fname.unwrap_or("function") $(,$arg)*
))))
}
// Copy kwargs not to modify it
let kwargs = match kwargs {
Some(k) => Some(k.copy()?),
None => None,
};
// Iterate through the parameters and assign values to output:
for (i, (p, out)) in params.iter().zip(output).enumerate() {
*out = match kwargs.and_then(|d| d.get_item(p.name)) {
@ -93,16 +98,18 @@ pub fn parse_fn_args<'p>(
)
}
// Adjust the remaining args
if accept_args {
let slice = args
.slice(used_args as isize, nargs as isize)
.into_object(py);
*args = py.checked_cast_as(slice).unwrap();
}
if accept_kwargs && is_kwargs_empty {
*kwargs = None;
}
Ok(())
let args = if accept_args {
let slice = args.slice(used_args as isize, nargs as isize).into_py(py);
py.checked_cast_as(slice).unwrap()
} else {
args
};
let kwargs = if accept_kwargs && is_kwargs_empty {
None
} else {
kwargs
};
Ok((args, kwargs))
}
/// Builds a module (or null) from a user given initializer. Used for `#[pymodule]`.