Also apply holder lifetime extension to slot implementations.

This commit is contained in:
Adam Reichold 2023-05-06 13:19:52 +02:00
parent 27019b5523
commit f03ccf204c

View file

@ -913,11 +913,13 @@ impl Ty {
ident: &syn::Ident,
arg: &FnArg<'_>,
extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
) -> TokenStream {
let name_str = arg.name.unraw().to_string();
match self {
Ty::Object => extract_object(
extract_error_mode,
holders,
&name_str,
quote! {
py.from_borrowed_ptr::<_pyo3::PyAny>(#ident)
@ -925,6 +927,7 @@ impl Ty {
),
Ty::MaybeNullObject => extract_object(
extract_error_mode,
holders,
&name_str,
quote! {
py.from_borrowed_ptr::<_pyo3::PyAny>(
@ -938,6 +941,7 @@ impl Ty {
),
Ty::NonNullObject => extract_object(
extract_error_mode,
holders,
&name_str,
quote! {
py.from_borrowed_ptr::<_pyo3::PyAny>(#ident.as_ptr())
@ -945,6 +949,7 @@ impl Ty {
),
Ty::IPowModulo => extract_object(
extract_error_mode,
holders,
&name_str,
quote! {
#ident.to_borrowed_any(py)
@ -972,13 +977,19 @@ impl Ty {
fn extract_object(
extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
name: &str,
source: TokenStream,
) -> TokenStream {
let holder = syn::Ident::new(&format!("holder_{}", holders.len()), Span::call_site());
holders.push(quote! {
#[allow(clippy::let_unit_value)]
let mut #holder = _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT;
});
extract_error_mode.handle_error(quote! {
_pyo3::impl_::extract_argument::extract_argument(
#source,
&mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT },
&mut #holder,
#name
)
})
@ -1088,11 +1099,13 @@ impl SlotDef {
.collect();
let wrapper_ident = format_ident!("__pymethod_{}__", method_name);
let ret_ty = ret_ty.ffi_type();
let mut holders = Vec::new();
let body = generate_method_body(
cls,
spec,
arguments,
*extract_error_mode,
&mut holders,
return_mode.as_ref(),
)?;
let name = spec.name;
@ -1104,6 +1117,7 @@ impl SlotDef {
) -> _pyo3::PyResult<#ret_ty> {
let function = #cls::#name; // Shadow the method name to avoid #3017
let _slf = _raw_slf;
#( #holders )*
#body
}
};
@ -1137,11 +1151,12 @@ fn generate_method_body(
spec: &FnSpec<'_>,
arguments: &[Ty],
extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
return_mode: Option<&ReturnMode>,
) -> Result<TokenStream> {
let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode);
let rust_name = spec.name;
let args = extract_proto_arguments(spec, arguments, extract_error_mode)?;
let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders)?;
let call = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) };
Ok(if let Some(return_mode) = return_mode {
return_mode.return_call_output(call)
@ -1191,7 +1206,15 @@ impl SlotFragmentDef {
let arg_idents: &Vec<_> = &(0..arguments.len())
.map(|i| format_ident!("arg{}", i))
.collect();
let body = generate_method_body(cls, spec, arguments, *extract_error_mode, None)?;
let mut holders = Vec::new();
let body = generate_method_body(
cls,
spec,
arguments,
*extract_error_mode,
&mut holders,
None,
)?;
let ret_ty = ret_ty.ffi_type();
Ok(quote! {
impl _pyo3::impl_::pyclass::#fragment_trait<#cls> for _pyo3::impl_::pyclass::PyClassImplCollector<#cls> {
@ -1210,6 +1233,7 @@ impl SlotFragmentDef {
#(#arg_idents: #arg_types),*
) -> _pyo3::PyResult<#ret_ty> {
let _slf = _raw_slf;
#( #holders )*
#body
}
}
@ -1298,6 +1322,7 @@ fn extract_proto_arguments(
spec: &FnSpec<'_>,
proto_args: &[Ty],
extract_error_mode: ExtractErrorMode,
holders: &mut Vec<TokenStream>,
) -> Result<Vec<TokenStream>> {
let mut args = Vec::with_capacity(spec.signature.arguments.len());
let mut non_python_args = 0;
@ -1309,7 +1334,7 @@ fn extract_proto_arguments(
let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site());
let conversions = proto_args.get(non_python_args)
.ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))?
.extract(&ident, arg, extract_error_mode);
.extract(&ident, arg, extract_error_mode, holders);
non_python_args += 1;
args.push(conversions);
}