Merge pull request #3609 from wyfo/async_receiver

feat: allow async methods to accept `&self`/`&mut self`
This commit is contained in:
David Hewitt 2023-12-07 07:38:25 +00:00 committed by GitHub
commit 07726aefc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 167 additions and 22 deletions

View File

@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + '
As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`.
It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.*
However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self`
## Implicit GIL holding

View File

@ -0,0 +1 @@
Allow async methods to accept `&self`/`&mut self`

View File

@ -1,18 +1,19 @@
use std::fmt::Display;
use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue};
use crate::deprecations::{Deprecation, Deprecations};
use crate::params::impl_arg_params;
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
use crate::quotes;
use crate::utils::{self, PythonDoc};
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use quote::{quote, quote_spanned};
use syn::ext::IdentExt;
use syn::spanned::Spanned;
use syn::{Ident, Result};
use quote::{quote, quote_spanned, ToTokens};
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
use crate::{
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::impl_arg_params,
pyfunction::{
FunctionSignature, PyFunctionArgPyO3Attributes, PyFunctionOptions, SignatureAttribute,
},
quotes,
utils::{self, PythonDoc},
};
#[derive(Clone, Debug)]
pub struct FnArg<'a> {
@ -473,8 +474,7 @@ impl<'a> FnSpec<'a> {
}
let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
@ -485,8 +485,19 @@ impl<'a> FnSpec<'a> {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
None => quote!(None),
};
call = quote! {{
let future = #call;
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {{
let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&__guard, #(#args),*).await }
}},
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {{
let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&mut __guard, #(#args),*).await }
}},
_ => quote! { function(#self_arg #(#args),*) },
};
let mut call = quote! {{
let future = #future;
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
@ -501,7 +512,10 @@ impl<'a> FnSpec<'a> {
#call
}};
}
}
call
} else {
quote! { function(#self_arg #(#args),*) }
};
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};

View File

@ -1,7 +1,15 @@
use std::future::Future;
use std::{
future::Future,
mem,
ops::{Deref, DerefMut},
};
use crate::coroutine::cancel::ThrowCallback;
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
use crate::{
coroutine::{cancel::ThrowCallback, Coroutine},
pyclass::boolean_struct::False,
types::PyString,
IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, Python,
};
pub fn new_coroutine<F, T, E>(
name: &PyString,
@ -16,3 +24,63 @@ where
{
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
}
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
// SAFETY: Py<T> can be casted as *const PyCell<T>
unsafe { &*(obj.as_ptr() as *const PyCell<T>) }.get_ptr()
}
pub struct RefGuard<T: PyClass>(Py<T>);
impl<T: PyClass> RefGuard<T> {
pub fn new(obj: &PyAny) -> PyResult<Self> {
let owned: Py<T> = obj.extract()?;
mem::forget(owned.try_borrow(obj.py())?);
Ok(RefGuard(owned))
}
}
impl<T: PyClass> Deref for RefGuard<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// SAFETY: `RefGuard` has been built from `PyRef` and provides the same guarantees
unsafe { &*get_ptr(&self.0) }
}
}
impl<T: PyClass> Drop for RefGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_ref())
}
}
pub struct RefMutGuard<T: PyClass<Frozen = False>>(Py<T>);
impl<T: PyClass<Frozen = False>> RefMutGuard<T> {
pub fn new(obj: &PyAny) -> PyResult<Self> {
let owned: Py<T> = obj.extract()?;
mem::forget(owned.try_borrow_mut(obj.py())?);
Ok(RefMutGuard(owned))
}
}
impl<T: PyClass<Frozen = False>> Deref for RefMutGuard<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
unsafe { &*get_ptr(&self.0) }
}
}
impl<T: PyClass<Frozen = False>> DerefMut for RefMutGuard<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
unsafe { &mut *get_ptr(&self.0) }
}
}
impl<T: PyClass<Frozen = False>> Drop for RefMutGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_mut())
}
}

View File

@ -516,6 +516,16 @@ impl<T: PyClass> PyCell<T> {
#[allow(clippy::useless_conversion)]
offset.try_into().expect("offset should fit in Py_ssize_t")
}
#[cfg(feature = "macros")]
pub(crate) fn release_ref(&self) {
self.borrow_checker().release_borrow();
}
#[cfg(feature = "macros")]
pub(crate) fn release_mut(&self) {
self.borrow_checker().release_borrow_mut();
}
}
impl<T: PyClassImpl> PyCell<T> {

View File

@ -234,3 +234,56 @@ fn coroutine_panic() {
py_run!(gil, panic, &handle_windows(test));
})
}
#[test]
fn test_async_method_receiver() {
#[pyclass]
struct Counter(usize);
#[pymethods]
impl Counter {
#[new]
fn new() -> Self {
Self(0)
}
async fn get(&self) -> usize {
self.0
}
async fn incr(&mut self) -> usize {
self.0 += 1;
self.0
}
}
Python::with_gil(|gil| {
let test = r#"
import asyncio
obj = Counter()
coro1 = obj.get()
coro2 = obj.get()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro1) == 0
coro2.close()
coro3 = obj.incr()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
try:
obj.get() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro3) == 1
"#;
let locals = [("Counter", gil.get_type::<Counter>())].into_py_dict(gil);
py_run!(gil, *locals, test);
})
}