Also apply holder lifetime extension to slot implementations.

This commit is contained in:
Adam Reichold 2023-05-06 13:19:52 +02:00
parent 27019b5523
commit f03ccf204c

View file

@ -913,11 +913,13 @@ impl Ty {
ident: &syn::Ident, ident: &syn::Ident,
arg: &FnArg<'_>, arg: &FnArg<'_>,
extract_error_mode: ExtractErrorMode, extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
) -> TokenStream { ) -> TokenStream {
let name_str = arg.name.unraw().to_string(); let name_str = arg.name.unraw().to_string();
match self { match self {
Ty::Object => extract_object( Ty::Object => extract_object(
extract_error_mode, extract_error_mode,
holders,
&name_str, &name_str,
quote! { quote! {
py.from_borrowed_ptr::<_pyo3::PyAny>(#ident) py.from_borrowed_ptr::<_pyo3::PyAny>(#ident)
@ -925,6 +927,7 @@ impl Ty {
), ),
Ty::MaybeNullObject => extract_object( Ty::MaybeNullObject => extract_object(
extract_error_mode, extract_error_mode,
holders,
&name_str, &name_str,
quote! { quote! {
py.from_borrowed_ptr::<_pyo3::PyAny>( py.from_borrowed_ptr::<_pyo3::PyAny>(
@ -938,6 +941,7 @@ impl Ty {
), ),
Ty::NonNullObject => extract_object( Ty::NonNullObject => extract_object(
extract_error_mode, extract_error_mode,
holders,
&name_str, &name_str,
quote! { quote! {
py.from_borrowed_ptr::<_pyo3::PyAny>(#ident.as_ptr()) py.from_borrowed_ptr::<_pyo3::PyAny>(#ident.as_ptr())
@ -945,6 +949,7 @@ impl Ty {
), ),
Ty::IPowModulo => extract_object( Ty::IPowModulo => extract_object(
extract_error_mode, extract_error_mode,
holders,
&name_str, &name_str,
quote! { quote! {
#ident.to_borrowed_any(py) #ident.to_borrowed_any(py)
@ -972,13 +977,19 @@ impl Ty {
fn extract_object( fn extract_object(
extract_error_mode: ExtractErrorMode, extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
name: &str, name: &str,
source: TokenStream, source: TokenStream,
) -> TokenStream { ) -> TokenStream {
let holder = syn::Ident::new(&format!("holder_{}", holders.len()), Span::call_site());
holders.push(quote! {
#[allow(clippy::let_unit_value)]
let mut #holder = _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT;
});
extract_error_mode.handle_error(quote! { extract_error_mode.handle_error(quote! {
_pyo3::impl_::extract_argument::extract_argument( _pyo3::impl_::extract_argument::extract_argument(
#source, #source,
&mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, &mut #holder,
#name #name
) )
}) })
@ -1088,11 +1099,13 @@ impl SlotDef {
.collect(); .collect();
let wrapper_ident = format_ident!("__pymethod_{}__", method_name); let wrapper_ident = format_ident!("__pymethod_{}__", method_name);
let ret_ty = ret_ty.ffi_type(); let ret_ty = ret_ty.ffi_type();
let mut holders = Vec::new();
let body = generate_method_body( let body = generate_method_body(
cls, cls,
spec, spec,
arguments, arguments,
*extract_error_mode, *extract_error_mode,
&mut holders,
return_mode.as_ref(), return_mode.as_ref(),
)?; )?;
let name = spec.name; let name = spec.name;
@ -1104,6 +1117,7 @@ impl SlotDef {
) -> _pyo3::PyResult<#ret_ty> { ) -> _pyo3::PyResult<#ret_ty> {
let function = #cls::#name; // Shadow the method name to avoid #3017 let function = #cls::#name; // Shadow the method name to avoid #3017
let _slf = _raw_slf; let _slf = _raw_slf;
#( #holders )*
#body #body
} }
}; };
@ -1137,11 +1151,12 @@ fn generate_method_body(
spec: &FnSpec<'_>, spec: &FnSpec<'_>,
arguments: &[Ty], arguments: &[Ty],
extract_error_mode: ExtractErrorMode, extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
return_mode: Option<&ReturnMode>, return_mode: Option<&ReturnMode>,
) -> Result<TokenStream> { ) -> Result<TokenStream> {
let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode); let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode);
let rust_name = spec.name; let rust_name = spec.name;
let args = extract_proto_arguments(spec, arguments, extract_error_mode)?; let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders)?;
let call = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) }; let call = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) };
Ok(if let Some(return_mode) = return_mode { Ok(if let Some(return_mode) = return_mode {
return_mode.return_call_output(call) return_mode.return_call_output(call)
@ -1191,7 +1206,15 @@ impl SlotFragmentDef {
let arg_idents: &Vec<_> = &(0..arguments.len()) let arg_idents: &Vec<_> = &(0..arguments.len())
.map(|i| format_ident!("arg{}", i)) .map(|i| format_ident!("arg{}", i))
.collect(); .collect();
let body = generate_method_body(cls, spec, arguments, *extract_error_mode, None)?; let mut holders = Vec::new();
let body = generate_method_body(
cls,
spec,
arguments,
*extract_error_mode,
&mut holders,
None,
)?;
let ret_ty = ret_ty.ffi_type(); let ret_ty = ret_ty.ffi_type();
Ok(quote! { Ok(quote! {
impl _pyo3::impl_::pyclass::#fragment_trait<#cls> for _pyo3::impl_::pyclass::PyClassImplCollector<#cls> { impl _pyo3::impl_::pyclass::#fragment_trait<#cls> for _pyo3::impl_::pyclass::PyClassImplCollector<#cls> {
@ -1210,6 +1233,7 @@ impl SlotFragmentDef {
#(#arg_idents: #arg_types),* #(#arg_idents: #arg_types),*
) -> _pyo3::PyResult<#ret_ty> { ) -> _pyo3::PyResult<#ret_ty> {
let _slf = _raw_slf; let _slf = _raw_slf;
#( #holders )*
#body #body
} }
} }
@ -1298,6 +1322,7 @@ fn extract_proto_arguments(
spec: &FnSpec<'_>, spec: &FnSpec<'_>,
proto_args: &[Ty], proto_args: &[Ty],
extract_error_mode: ExtractErrorMode, extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
) -> Result<Vec<TokenStream>> { ) -> Result<Vec<TokenStream>> {
let mut args = Vec::with_capacity(spec.signature.arguments.len()); let mut args = Vec::with_capacity(spec.signature.arguments.len());
let mut non_python_args = 0; let mut non_python_args = 0;
@ -1309,7 +1334,7 @@ fn extract_proto_arguments(
let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site()); let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site());
let conversions = proto_args.get(non_python_args) let conversions = proto_args.get(non_python_args)
.ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))? .ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))?
.extract(&ident, arg, extract_error_mode); .extract(&ident, arg, extract_error_mode, holders);
non_python_args += 1; non_python_args += 1;
args.push(conversions); args.push(conversions);
} }