From e4cc98607e6bfce235b496982afb3155441c38ec Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 2 Apr 2024 18:43:51 +0100 Subject: [PATCH] fix compile error for multiple async method arguments (#4035) --- newsfragments/4035.fixed.md | 1 + pyo3-macros-backend/src/method.rs | 25 ++++-------- tests/test_coroutine.rs | 66 +++++++++++++++---------------- 3 files changed, 41 insertions(+), 51 deletions(-) create mode 100644 newsfragments/4035.fixed.md diff --git a/newsfragments/4035.fixed.md b/newsfragments/4035.fixed.md new file mode 100644 index 00000000..5425c5cb --- /dev/null +++ b/newsfragments/4035.fixed.md @@ -0,0 +1 @@ +Fix compile error for `async fn` in `#[pymethods]` with a `&self` receiver and more than one additional argument. diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 155c5540..6af0ec97 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -522,33 +522,22 @@ impl<'a> FnSpec<'a> { Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)), None => quote!(None), }; - let evaluate_args = || -> (Vec, TokenStream) { - let mut arg_names = Vec::with_capacity(args.len()); - let mut evaluate_arg = quote! {}; - for arg in &args { - let arg_name = format_ident!("arg_{}", arg_names.len()); - arg_names.push(arg_name.clone()); - evaluate_arg.extend(quote! { - let #arg_name = #arg - }); - } - (arg_names, evaluate_arg) - }; + let arg_names = (0..args.len()) + .map(|i| format_ident!("arg_{}", i)) + .collect::>(); let future = match self.tp { FnType::Fn(SelfType::Receiver { mutable: false, .. }) => { - let (arg_name, evaluate_arg) = evaluate_args(); quote! {{ - #evaluate_arg; + #(let #arg_names = #args;)* let __guard = #pyo3_path::impl_::coroutine::RefGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?; - async move { function(&__guard, #(#arg_name),*).await } + async move { function(&__guard, #(#arg_names),*).await } }} } FnType::Fn(SelfType::Receiver { mutable: true, .. }) => { - let (arg_name, evaluate_arg) = evaluate_args(); quote! {{ - #evaluate_arg; + #(let #arg_names = #args;)* let mut __guard = #pyo3_path::impl_::coroutine::RefMutGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?; - async move { function(&mut __guard, #(#arg_name),*).await } + async move { function(&mut __guard, #(#arg_names),*).await } }} } _ => { diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 0e698deb..4abba9f3 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -245,39 +245,6 @@ fn coroutine_panic() { }) } -#[test] -fn test_async_method_receiver_with_other_args() { - #[pyclass] - struct Value(i32); - #[pymethods] - impl Value { - #[new] - fn new() -> Self { - Self(0) - } - async fn get_value_plus_with(&self, v: i32) -> i32 { - self.0 + v - } - async fn set_value(&mut self, new_value: i32) -> i32 { - self.0 = new_value; - self.0 - } - } - - Python::with_gil(|gil| { - let test = r#" - import asyncio - - v = Value() - assert asyncio.run(v.get_value_plus_with(3)) == 3 - assert asyncio.run(v.set_value(10)) == 10 - assert asyncio.run(v.get_value_plus_with(1)) == 11 - "#; - let locals = [("Value", gil.get_type_bound::())].into_py_dict_bound(gil); - py_run!(gil, *locals, test); - }); -} - #[test] fn test_async_method_receiver() { #[pyclass] @@ -341,3 +308,36 @@ fn test_async_method_receiver() { assert!(IS_DROPPED.load(Ordering::SeqCst)); } + +#[test] +fn test_async_method_receiver_with_other_args() { + #[pyclass] + struct Value(i32); + #[pymethods] + impl Value { + #[new] + fn new() -> Self { + Self(0) + } + async fn get_value_plus_with(&self, v1: i32, v2: i32) -> i32 { + self.0 + v1 + v2 + } + async fn set_value(&mut self, new_value: i32) -> i32 { + self.0 = new_value; + self.0 + } + } + + Python::with_gil(|gil| { + let test = r#" + import asyncio + + v = Value() + assert asyncio.run(v.get_value_plus_with(3, 0)) == 3 + assert asyncio.run(v.set_value(10)) == 10 + assert asyncio.run(v.get_value_plus_with(1, 1)) == 12 + "#; + let locals = [("Value", gil.get_type_bound::())].into_py_dict_bound(gil); + py_run!(gil, *locals, test); + }); +}