async method should allow args not only receiver (#4015)
* async method should allow args not only receiver * add changelog md
This commit is contained in:
parent
4d033c4497
commit
74d9d23ba0
|
@ -0,0 +1 @@
|
|||
Fix the bug that an async `#[pymethod]` with receiver can't have any other args.
|
|
@ -1,7 +1,7 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::{quote, quote_spanned, ToTokens};
|
||||
use quote::{format_ident, quote, quote_spanned, ToTokens};
|
||||
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
|
||||
|
||||
use crate::utils::Ctx;
|
||||
|
@ -518,17 +518,33 @@ impl<'a> FnSpec<'a> {
|
|||
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)),
|
||||
None => quote!(None),
|
||||
};
|
||||
let evaluate_args = || -> (Vec<Ident>, TokenStream) {
|
||||
let mut arg_names = Vec::with_capacity(args.len());
|
||||
let mut evaluate_arg = quote! {};
|
||||
for arg in &args {
|
||||
let arg_name = format_ident!("arg_{}", arg_names.len());
|
||||
arg_names.push(arg_name.clone());
|
||||
evaluate_arg.extend(quote! {
|
||||
let #arg_name = #arg
|
||||
});
|
||||
}
|
||||
(arg_names, evaluate_arg)
|
||||
};
|
||||
let future = match self.tp {
|
||||
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
|
||||
let (arg_name, evaluate_arg) = evaluate_args();
|
||||
quote! {{
|
||||
#evaluate_arg;
|
||||
let __guard = #pyo3_path::impl_::coroutine::RefGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
|
||||
async move { function(&__guard, #(#args),*).await }
|
||||
async move { function(&__guard, #(#arg_name),*).await }
|
||||
}}
|
||||
}
|
||||
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => {
|
||||
let (arg_name, evaluate_arg) = evaluate_args();
|
||||
quote! {{
|
||||
#evaluate_arg;
|
||||
let mut __guard = #pyo3_path::impl_::coroutine::RefMutGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
|
||||
async move { function(&mut __guard, #(#args),*).await }
|
||||
async move { function(&mut __guard, #(#arg_name),*).await }
|
||||
}}
|
||||
}
|
||||
_ => {
|
||||
|
|
|
@ -245,6 +245,39 @@ fn coroutine_panic() {
|
|||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_async_method_receiver_with_other_args() {
|
||||
#[pyclass]
|
||||
struct Value(i32);
|
||||
#[pymethods]
|
||||
impl Value {
|
||||
#[new]
|
||||
fn new() -> Self {
|
||||
Self(0)
|
||||
}
|
||||
async fn get_value_plus_with(&self, v: i32) -> i32 {
|
||||
self.0 + v
|
||||
}
|
||||
async fn set_value(&mut self, new_value: i32) -> i32 {
|
||||
self.0 = new_value;
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
Python::with_gil(|gil| {
|
||||
let test = r#"
|
||||
import asyncio
|
||||
|
||||
v = Value()
|
||||
assert asyncio.run(v.get_value_plus_with(3)) == 3
|
||||
assert asyncio.run(v.set_value(10)) == 10
|
||||
assert asyncio.run(v.get_value_plus_with(1)) == 11
|
||||
"#;
|
||||
let locals = [("Value", gil.get_type_bound::<Value>())].into_py_dict_bound(gil);
|
||||
py_run!(gil, *locals, test);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_async_method_receiver() {
|
||||
#[pyclass]
|
||||
|
|
Loading…
Reference in New Issue