From 8f87b8636df73ef7d1690bd848389382280aa335 Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Mon, 1 Apr 2024 14:10:18 +0200 Subject: [PATCH] refactor `#[setter]` argument extraction (#4002) --- pyo3-macros-backend/src/method.rs | 4 ++ pyo3-macros-backend/src/params.rs | 39 ++++++------ pyo3-macros-backend/src/pymethod.rs | 95 +++++++++++++++++------------ src/impl_/extract_argument.rs | 4 +- tests/ui/static_ref.stderr | 23 +++---- 5 files changed, 96 insertions(+), 69 deletions(-) diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index f4fdb193..155c5540 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -63,6 +63,10 @@ impl<'a> FnArg<'a> { } } } + + pub fn is_regular(&self) -> bool { + !self.py && !self.is_cancel_handle && !self.is_kwargs && !self.is_varargs + } } fn handle_argument_error(pat: &syn::Pat) -> syn::Error { diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index cab28698..fa50d260 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -73,7 +73,7 @@ pub fn is_forwarded_args(signature: &FunctionSignature<'_>) -> bool { ) } -fn check_arg_for_gil_refs( +pub(crate) fn check_arg_for_gil_refs( tokens: TokenStream, gil_refs_checker: syn::Ident, ctx: &Ctx, @@ -120,7 +120,11 @@ pub fn impl_arg_params( .iter() .enumerate() .map(|(i, arg)| { - impl_arg_param(arg, i, &mut 0, &args_array, holders, ctx).map(|tokens| { + let from_py_with = + syn::Ident::new(&format!("from_py_with_{}", i), Span::call_site()); + let arg_value = quote!(#args_array[0].as_deref()); + + impl_arg_param(arg, from_py_with, arg_value, holders, ctx).map(|tokens| { check_arg_for_gil_refs( tokens, holders.push_gil_refs_checker(arg.ty.span()), @@ -161,14 +165,20 @@ pub fn impl_arg_params( let num_params = positional_parameter_names.len() + keyword_only_parameters.len(); - let mut option_pos = 0; + let mut option_pos = 0usize; let param_conversion = spec .signature .arguments .iter() .enumerate() .map(|(i, arg)| { - impl_arg_param(arg, i, &mut option_pos, &args_array, holders, ctx).map(|tokens| { + let from_py_with = syn::Ident::new(&format!("from_py_with_{}", i), Span::call_site()); + let arg_value = quote!(#args_array[#option_pos].as_deref()); + if arg.is_regular() { + option_pos += 1; + } + + impl_arg_param(arg, from_py_with, arg_value, holders, ctx).map(|tokens| { check_arg_for_gil_refs(tokens, holders.push_gil_refs_checker(arg.ty.span()), ctx) }) }) @@ -234,11 +244,10 @@ pub fn impl_arg_params( /// Re option_pos: The option slice doesn't contain the py: Python argument, so the argument /// index and the index in option diverge when using py: Python -fn impl_arg_param( +pub(crate) fn impl_arg_param( arg: &FnArg<'_>, - pos: usize, - option_pos: &mut usize, - args_array: &syn::Ident, + from_py_with: syn::Ident, + arg_value: TokenStream, // expected type: Option<&'a Bound<'py, PyAny>> holders: &mut Holders, ctx: &Ctx, ) -> Result { @@ -291,9 +300,6 @@ fn impl_arg_param( }); } - let arg_value = quote_arg_span!(#args_array[#option_pos]); - *option_pos += 1; - let mut default = arg.default.as_ref().map(|expr| quote!(#expr)); // Option arguments have special treatment: the default should be specified _without_ the @@ -312,11 +318,10 @@ fn impl_arg_param( .map(|attr| &attr.value) .is_some() { - let from_py_with = syn::Ident::new(&format!("from_py_with_{}", pos), Span::call_site()); if let Some(default) = default { quote_arg_span! { #pyo3_path::impl_::extract_argument::from_py_with_with_default( - #arg_value.as_deref(), + #arg_value, #name_str, #from_py_with as fn(_) -> _, #[allow(clippy::redundant_closure)] @@ -328,7 +333,7 @@ fn impl_arg_param( } else { quote_arg_span! { #pyo3_path::impl_::extract_argument::from_py_with( - &#pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value), + #pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value), #name_str, #from_py_with as fn(_) -> _, )? @@ -338,7 +343,7 @@ fn impl_arg_param( let holder = holders.push_holder(arg.name.span()); quote_arg_span! { #pyo3_path::impl_::extract_argument::extract_optional_argument( - #arg_value.as_deref(), + #arg_value, &mut #holder, #name_str, #[allow(clippy::redundant_closure)] @@ -351,7 +356,7 @@ fn impl_arg_param( let holder = holders.push_holder(arg.name.span()); quote_arg_span! { #pyo3_path::impl_::extract_argument::extract_argument_with_default( - #arg_value.as_deref(), + #arg_value, &mut #holder, #name_str, #[allow(clippy::redundant_closure)] @@ -364,7 +369,7 @@ fn impl_arg_param( let holder = holders.push_holder(arg.name.span()); quote_arg_span! { #pyo3_path::impl_::extract_argument::extract_argument( - &#pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value), + #pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value), &mut #holder, #name_str )? diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 22802a01..ee7d3d7a 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; use crate::attributes::{NameAttribute, RenamingRule}; use crate::method::{CallingConvention, ExtractErrorMode}; -use crate::params::Holders; +use crate::params::{check_arg_for_gil_refs, impl_arg_param, Holders}; use crate::utils::Ctx; use crate::utils::PythonDoc; use crate::{ @@ -586,48 +586,63 @@ pub fn impl_py_setter_def( } }; - // TODO: rework this to make use of `impl_::params::impl_arg_param` which - // handles all these cases already. - let extract = if let PropertyType::Function { spec, .. } = &property_type { - Some(spec) - } else { - None - } - .and_then(|spec| { - let (_, args) = split_off_python_arg(&spec.signature.arguments); - let value_arg = &args[0]; - let from_py_with = &value_arg.attrs.from_py_with.as_ref()?.value; - let name = value_arg.name.to_string(); + let extract = match &property_type { + PropertyType::Function { spec, .. } => { + let (_, args) = split_off_python_arg(&spec.signature.arguments); + let value_arg = &args[0]; + let (from_py_with, ident) = if let Some(from_py_with) = + &value_arg.attrs.from_py_with.as_ref().map(|f| &f.value) + { + let ident = syn::Ident::new("from_py_with", from_py_with.span()); + ( + quote_spanned! { from_py_with.span() => + let e = #pyo3_path::impl_::deprecations::GilRefs::new(); + let #ident = #pyo3_path::impl_::deprecations::inspect_fn(#from_py_with, &e); + e.from_py_with_arg(); + }, + ident, + ) + } else { + (quote!(), syn::Ident::new("dummy", Span::call_site())) + }; - Some(quote_spanned! { from_py_with.span() => - let e = #pyo3_path::impl_::deprecations::GilRefs::new(); - let from_py_with = #pyo3_path::impl_::deprecations::inspect_fn(#from_py_with, &e); - e.from_py_with_arg(); - let _val = #pyo3_path::impl_::extract_argument::from_py_with( - &_value.into(), - #name, - from_py_with as fn(_) -> _, - )?; - }) - }) - .unwrap_or_else(|| { - let (span, name) = match &property_type { - PropertyType::Descriptor { field, .. } => (field.ty.span(), field.ident.as_ref().map(|i|i.to_string()).unwrap_or_default()), - PropertyType::Function { spec, .. } => { - let (_, args) = split_off_python_arg(&spec.signature.arguments); - (args[0].ty.span(), args[0].name.to_string()) + let extract = impl_arg_param( + &args[0], + ident, + quote!(::std::option::Option::Some(_value.into())), + &mut holders, + ctx, + ) + .map(|tokens| { + check_arg_for_gil_refs( + tokens, + holders.push_gil_refs_checker(value_arg.ty.span()), + ctx, + ) + })?; + quote! { + #from_py_with + let _val = #extract; } - }; - - let holder = holders.push_holder(span); - let gil_refs_checker = holders.push_gil_refs_checker(span); - quote! { - let _val = #pyo3_path::impl_::deprecations::inspect_type( - #pyo3_path::impl_::extract_argument::extract_argument(_value.into(), &mut #holder, #name)?, - &#gil_refs_checker - ); } - }); + PropertyType::Descriptor { field, .. } => { + let span = field.ty.span(); + let name = field + .ident + .as_ref() + .map(|i| i.to_string()) + .unwrap_or_default(); + + let holder = holders.push_holder(span); + let gil_refs_checker = holders.push_gil_refs_checker(span); + quote! { + let _val = #pyo3_path::impl_::deprecations::inspect_type( + #pyo3_path::impl_::extract_argument::extract_argument(_value.into(), &mut #holder, #name)?, + &#gil_refs_checker + ); + } + } + }; let mut cfg_attrs = TokenStream::new(); if let PropertyType::Descriptor { field, .. } = &property_type { diff --git a/src/impl_/extract_argument.rs b/src/impl_/extract_argument.rs index 4dcef02c..485b8645 100644 --- a/src/impl_/extract_argument.rs +++ b/src/impl_/extract_argument.rs @@ -223,7 +223,9 @@ pub fn argument_extraction_error(py: Python<'_>, arg_name: &str, error: PyErr) - /// `argument` must not be `None` #[doc(hidden)] #[inline] -pub unsafe fn unwrap_required_argument(argument: Option>) -> PyArg<'_> { +pub unsafe fn unwrap_required_argument<'a, 'py>( + argument: Option<&'a Bound<'py, PyAny>>, +) -> &'a Bound<'py, PyAny> { match argument { Some(value) => value, #[cfg(debug_assertions)] diff --git a/tests/ui/static_ref.stderr b/tests/ui/static_ref.stderr index 50b054f6..6004c403 100644 --- a/tests/ui/static_ref.stderr +++ b/tests/ui/static_ref.stderr @@ -9,6 +9,18 @@ error: lifetime may not live long enough | = note: this error originates in the attribute macro `pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info) +error[E0597]: `output[_]` does not live long enough + --> tests/ui/static_ref.rs:4:1 + | +4 | #[pyfunction] + | ^^^^^^^^^^^^- + | | | + | | `output[_]` dropped here while still borrowed + | borrowed value does not live long enough + | argument requires that `output[_]` is borrowed for `'static` + | + = note: this error originates in the attribute macro `pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info) + error[E0597]: `holder_0` does not live long enough --> tests/ui/static_ref.rs:5:15 | @@ -21,17 +33,6 @@ error[E0597]: `holder_0` does not live long enough 5 | fn static_ref(list: &'static Bound<'_, PyList>) -> usize { | ^^^^^^^ borrowed value does not live long enough -error[E0716]: temporary value dropped while borrowed - --> tests/ui/static_ref.rs:5:21 - | -4 | #[pyfunction] - | ------------- - | | | - | | temporary value is freed at the end of this statement - | argument requires that borrow lasts for `'static` -5 | fn static_ref(list: &'static Bound<'_, PyList>) -> usize { - | ^ creates a temporary value which is freed while still in use - error: lifetime may not live long enough --> tests/ui/static_ref.rs:9:1 |