diff --git a/newsfragments/4015.fixed.md b/newsfragments/4015.fixed.md new file mode 100644 index 00000000..a8f4f636 --- /dev/null +++ b/newsfragments/4015.fixed.md @@ -0,0 +1 @@ +Fix the bug that an async `#[pymethod]` with receiver can't have any other args. \ No newline at end of file diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 31dfb075..f4fdb193 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -1,7 +1,7 @@ use std::fmt::Display; use proc_macro2::{Span, TokenStream}; -use quote::{quote, quote_spanned, ToTokens}; +use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::{ext::IdentExt, spanned::Spanned, Ident, Result}; use crate::utils::Ctx; @@ -518,17 +518,33 @@ impl<'a> FnSpec<'a> { Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)), None => quote!(None), }; + let evaluate_args = || -> (Vec, TokenStream) { + let mut arg_names = Vec::with_capacity(args.len()); + let mut evaluate_arg = quote! {}; + for arg in &args { + let arg_name = format_ident!("arg_{}", arg_names.len()); + arg_names.push(arg_name.clone()); + evaluate_arg.extend(quote! { + let #arg_name = #arg + }); + } + (arg_names, evaluate_arg) + }; let future = match self.tp { FnType::Fn(SelfType::Receiver { mutable: false, .. }) => { + let (arg_name, evaluate_arg) = evaluate_args(); quote! {{ + #evaluate_arg; let __guard = #pyo3_path::impl_::coroutine::RefGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?; - async move { function(&__guard, #(#args),*).await } + async move { function(&__guard, #(#arg_name),*).await } }} } FnType::Fn(SelfType::Receiver { mutable: true, .. }) => { + let (arg_name, evaluate_arg) = evaluate_args(); quote! {{ + #evaluate_arg; let mut __guard = #pyo3_path::impl_::coroutine::RefMutGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?; - async move { function(&mut __guard, #(#args),*).await } + async move { function(&mut __guard, #(#arg_name),*).await } }} } _ => { diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 23f6a672..0e698deb 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -245,6 +245,39 @@ fn coroutine_panic() { }) } +#[test] +fn test_async_method_receiver_with_other_args() { + #[pyclass] + struct Value(i32); + #[pymethods] + impl Value { + #[new] + fn new() -> Self { + Self(0) + } + async fn get_value_plus_with(&self, v: i32) -> i32 { + self.0 + v + } + async fn set_value(&mut self, new_value: i32) -> i32 { + self.0 = new_value; + self.0 + } + } + + Python::with_gil(|gil| { + let test = r#" + import asyncio + + v = Value() + assert asyncio.run(v.get_value_plus_with(3)) == 3 + assert asyncio.run(v.set_value(10)) == 10 + assert asyncio.run(v.get_value_plus_with(1)) == 11 + "#; + let locals = [("Value", gil.get_type_bound::())].into_py_dict_bound(gil); + py_run!(gil, *locals, test); + }); +} + #[test] fn test_async_method_receiver() { #[pyclass]