Merge pull request #3599 from wyfo/coroutine_cancel

feat: add `coroutine::CancelHandle`
This commit is contained in:
David Hewitt 2023-12-04 12:44:36 +00:00 committed by GitHub
commit e8f852bce7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 381 additions and 21 deletions

View file

@ -69,10 +69,27 @@ where
## Cancellation
*To be implemented*
Cancellation on the Python side can be caught using [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) type, by annotating a function parameter with `#[pyo3(cancel_handle)].
```rust
# #![allow(dead_code)]
use futures::FutureExt;
use pyo3::prelude::*;
use pyo3::coroutine::CancelHandle;
#[pyfunction]
async fn cancellable(#[pyo3(cancel_handle)] mut cancel: CancelHandle) {
futures::select! {
/* _ = ... => println!("done"), */
_ = cancel.cancelled().fuse() => println!("cancelled"),
}
}
```
## The `Coroutine` type
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine). Each `coroutine.send` call is translated to `Future::poll` call, while `coroutine.throw` call reraise the exception *(this behavior will be configurable with cancellation support)*.
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine).
Each `coroutine.send` call is translated to a `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;
*The type does not yet have a public constructor until the design is finalized.*

View file

@ -0,0 +1 @@
Add `coroutine::CancelHandle` to catch coroutine cancellation

View file

@ -11,6 +11,7 @@ use syn::{
pub mod kw {
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
syn::custom_keyword!(dict);
syn::custom_keyword!(extends);
syn::custom_keyword!(freelist);

View file

@ -24,6 +24,7 @@ pub struct FnArg<'a> {
pub attrs: PyFunctionArgPyO3Attributes,
pub is_varargs: bool,
pub is_kwargs: bool,
pub is_cancel_handle: bool,
}
impl<'a> FnArg<'a> {
@ -44,6 +45,8 @@ impl<'a> FnArg<'a> {
other => return Err(handle_argument_error(other)),
};
let is_cancel_handle = arg_attrs.cancel_handle.is_some();
Ok(FnArg {
name: ident,
ty: &cap.ty,
@ -53,6 +56,7 @@ impl<'a> FnArg<'a> {
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
is_cancel_handle,
})
}
}
@ -455,9 +459,27 @@ impl<'a> FnSpec<'a> {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let func_name = &self.name;
let mut cancel_handle_iter = self
.signature
.arguments
.iter()
.filter(|arg| arg.is_cancel_handle);
let cancel_handle = cancel_handle_iter.next();
if let Some(arg) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(arg2) = cancel_handle_iter.next() {
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
}
}
let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
quote! { None }
};
let python_name = &self.python_name;
let qualname_prefix = match cls {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
@ -468,9 +490,17 @@ impl<'a> FnSpec<'a> {
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }
#throw_callback,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
if cancel_handle.is_some() {
call = quote! {{
let __cancel_handle = _pyo3::coroutine::CancelHandle::new();
let __throw_callback = __cancel_handle.throw_callback();
#call
}};
}
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};
@ -483,12 +513,21 @@ impl<'a> FnSpec<'a> {
Ok(match self.convention {
CallingConvention::Noargs => {
let call = if !self.signature.arguments.is_empty() {
// Only `py` arg can be here
rust_call(vec![quote!(py)])
} else {
rust_call(vec![])
};
let args = self
.signature
.arguments
.iter()
.map(|arg| {
if arg.py {
quote!(py)
} else if arg.is_cancel_handle {
quote!(__cancel_handle)
} else {
unreachable!()
}
})
.collect();
let call = rust_call(args);
quote! {
unsafe fn #ident<'py>(

View file

@ -155,6 +155,10 @@ fn impl_arg_param(
return Ok(quote! { py });
}
if arg.is_cancel_handle {
return Ok(quote! { __cancel_handle });
}
let name = arg.name;
let name_str = name.to_string();

View file

@ -23,16 +23,20 @@ pub use self::signature::{FunctionSignature, SignatureAttribute};
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
pub from_py_with: Option<FromPyWithAttribute>,
pub cancel_handle: Option<attributes::kw::cancel_handle>,
}
enum PyFunctionArgPyO3Attribute {
FromPyWith(FromPyWithAttribute),
CancelHandle(attributes::kw::cancel_handle),
}
impl Parse for PyFunctionArgPyO3Attribute {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::from_py_with) {
if lookahead.peek(attributes::kw::cancel_handle) {
input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
} else {
Err(lookahead.error())
@ -43,7 +47,10 @@ impl Parse for PyFunctionArgPyO3Attribute {
impl PyFunctionArgPyO3Attributes {
/// Parses #[pyo3(from_python_with = "func")]
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None };
let mut attributes = PyFunctionArgPyO3Attributes {
from_py_with: None,
cancel_handle: None,
};
take_attributes(attrs, |attr| {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
for attr in pyo3_attrs {
@ -55,7 +62,18 @@ impl PyFunctionArgPyO3Attributes {
);
attributes.from_py_with = Some(from_py_with);
}
PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
ensure_spanned!(
attributes.cancel_handle.is_none(),
cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
);
attributes.cancel_handle = Some(cancel_handle);
}
}
ensure_spanned!(
attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
);
}
Ok(true)
} else {

View file

@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
// Otherwise try next argument.
continue;
}
if fn_arg.is_cancel_handle {
// If the user incorrectly tried to include cancel: CoroutineCancel in the
// signature, give a useful error as a hint.
ensure_spanned!(
name != fn_arg.name,
name.span() => "`cancel_handle` argument must not be part of the signature"
);
// Otherwise try next argument.
continue;
}
ensure_spanned!(
name == fn_arg.name,
@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
}
// Ensure no non-py arguments remain
if let Some(arg) = args_iter.find(|arg| !arg.py) {
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.is_cancel_handle) {
bail_spanned!(
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
);
@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
let mut python_signature = PythonSignature::default();
for arg in &arguments {
// Python<'_> arguments don't show in Python signature
if arg.py {
if arg.py || arg.is_cancel_handle {
continue;
}

View file

@ -21,8 +21,12 @@ use crate::{
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};
pub(crate) mod cancel;
mod waker;
use crate::coroutine::cancel::ThrowCallback;
pub use cancel::CancelHandle;
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
@ -32,6 +36,7 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
@ -46,6 +51,7 @@ impl Coroutine {
pub(crate) fn new<F, T, E>(
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: F,
) -> Self
where
@ -61,6 +67,7 @@ impl Coroutine {
Self {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
waker: None,
}
@ -77,9 +84,13 @@ impl Coroutine {
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
};
// reraise thrown exception it
if let Some(exc) = throw {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
match (throw, &self.throw_callback) {
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
(Some(exc), None) => {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
}
(None, _) => {}
}
// create a new waker, or try to reset it in place
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {

78
src/coroutine/cancel.rs Normal file
View file

@ -0,0 +1,78 @@
use crate::{PyAny, PyObject};
use parking_lot::Mutex;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
#[derive(Debug, Default)]
struct Inner {
exception: Option<PyObject>,
waker: Option<Waker>,
}
/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine).
///
/// Only the last exception thrown can be retrieved.
#[derive(Debug, Default)]
pub struct CancelHandle(Arc<Mutex<Inner>>);
impl CancelHandle {
/// Create a new `CoroutineCancel`.
pub fn new() -> Self {
Default::default()
}
/// Returns whether the associated coroutine has been cancelled.
pub fn is_cancelled(&self) -> bool {
self.0.lock().exception.is_some()
}
/// Poll to retrieve the exception thrown in the associated coroutine.
pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
let mut inner = self.0.lock();
if let Some(exc) = inner.exception.take() {
return Poll::Ready(exc);
}
if let Some(ref waker) = inner.waker {
if cx.waker().will_wake(waker) {
return Poll::Pending;
}
}
inner.waker = Some(cx.waker().clone());
Poll::Pending
}
/// Retrieve the exception thrown in the associated coroutine.
pub async fn cancelled(&mut self) -> PyObject {
Cancelled(self).await
}
#[doc(hidden)]
pub fn throw_callback(&self) -> ThrowCallback {
ThrowCallback(self.0.clone())
}
}
// Because `poll_fn` is not available in MSRV
struct Cancelled<'a>(&'a mut CancelHandle);
impl Future for Cancelled<'_> {
type Output = PyObject;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_cancelled(cx)
}
}
#[doc(hidden)]
pub struct ThrowCallback(Arc<Mutex<Inner>>);
impl ThrowCallback {
pub(super) fn throw(&self, exc: &PyAny) {
let mut inner = self.0.lock();
inner.exception = Some(exc.into());
if let Some(waker) = inner.waker.take() {
waker.wake();
}
}
}

View file

@ -1,10 +1,12 @@
use std::future::Future;
use crate::coroutine::cancel::ThrowCallback;
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
pub fn new_coroutine<F, T, E>(
name: &PyString,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: F,
) -> Coroutine
where
@ -12,5 +14,5 @@ where
T: IntoPy<PyObject>,
E: Into<PyErr>,
{
Coroutine::new(Some(name.into()), qualname_prefix, future)
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
}

View file

@ -1038,7 +1038,7 @@ impl<T> Py<T> {
/// # Safety
/// `ptr` must point to a Python object of type T.
#[inline]
unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
pub(crate) unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
Self(ptr, PhantomData)
}

View file

@ -3,7 +3,8 @@
use std::ops::Deref;
use std::{task::Poll, thread, time::Duration};
use futures::{channel::oneshot, future::poll_fn};
use futures::{channel::oneshot, future::poll_fn, FutureExt};
use pyo3::coroutine::CancelHandle;
use pyo3::types::{IntoPyDict, PyType};
use pyo3::{prelude::*, py_run};
@ -136,3 +137,69 @@ fn cancelled_coroutine() {
assert_eq!(err.value(gil).get_type().name().unwrap(), "CancelledError");
})
}
#[test]
fn coroutine_cancel_handle() {
#[pyfunction]
async fn cancellable_sleep(
seconds: f64,
#[pyo3(cancel_handle)] mut cancel: CancelHandle,
) -> usize {
futures::select! {
_ = sleep(seconds).fuse() => 42,
_ = cancel.cancelled().fuse() => 0,
}
}
Python::with_gil(|gil| {
let cancellable_sleep = wrap_pyfunction!(cancellable_sleep, gil).unwrap();
let test = r#"
import asyncio;
async def main():
task = asyncio.create_task(cancellable_sleep(1))
await asyncio.sleep(0)
task.cancel()
return await task
assert asyncio.run(main()) == 0
"#;
let globals = gil.import("__main__").unwrap().dict();
globals
.set_item("cancellable_sleep", cancellable_sleep)
.unwrap();
gil.run(
&pyo3::unindent::unindent(&handle_windows(test)),
Some(globals),
None,
)
.unwrap();
})
}
#[test]
fn coroutine_is_cancelled() {
#[pyfunction]
async fn sleep_loop(#[pyo3(cancel_handle)] cancel: CancelHandle) {
while !cancel.is_cancelled() {
sleep(0.001).await;
}
}
Python::with_gil(|gil| {
let sleep_loop = wrap_pyfunction!(sleep_loop, gil).unwrap();
let test = r#"
import asyncio;
async def main():
task = asyncio.create_task(sleep_loop())
await asyncio.sleep(0)
task.cancel()
await task
asyncio.run(main())
"#;
let globals = gil.import("__main__").unwrap().dict();
globals.set_item("sleep_loop", sleep_loop).unwrap();
gil.run(
&pyo3::unindent::unindent(&handle_windows(test)),
Some(globals),
None,
)
.unwrap();
})
}

View file

@ -12,4 +12,32 @@ fn from_py_with_string(#[pyo3("from_py_with")] param: String) {}
#[pyfunction]
fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}
#[pyfunction]
fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] param: String) {}
#[pyfunction]
async fn from_py_with_value_and_cancel_handle(
#[pyo3(from_py_with = "func", cancel_handle)] _param: String,
) {
}
#[pyfunction]
async fn cancel_handle_repeated(#[pyo3(cancel_handle, cancel_handle)] _param: String) {}
#[pyfunction]
async fn cancel_handle_repeated2(
#[pyo3(cancel_handle)] _param: String,
#[pyo3(cancel_handle)] _param2: String,
) {
}
#[pyfunction]
fn cancel_handle_synchronous(#[pyo3(cancel_handle)] _param: String) {}
#[pyfunction]
async fn cancel_handle_wrong_type(#[pyo3(cancel_handle)] _param: String) {}
#[pyfunction]
async fn missing_cancel_handle_attribute(_param: pyo3::coroutine::CancelHandle) {}
fn main() {}

View file

@ -1,4 +1,4 @@
error: expected `from_py_with`
error: expected `cancel_handle` or `from_py_with`
--> tests/ui/invalid_argument_attributes.rs:4:29
|
4 | fn invalid_attribute(#[pyo3(get)] param: String) {}
@ -10,7 +10,7 @@ error: expected `=`
7 | fn from_py_with_no_value(#[pyo3(from_py_with)] param: String) {}
| ^
error: expected `from_py_with`
error: expected `cancel_handle` or `from_py_with`
--> tests/ui/invalid_argument_attributes.rs:10:31
|
10 | fn from_py_with_string(#[pyo3("from_py_with")] param: String) {}
@ -21,3 +21,87 @@ error: expected string literal
|
13 | fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}
| ^^^^
error: `from_py_with` may only be specified once per argument
--> tests/ui/invalid_argument_attributes.rs:16:56
|
16 | fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] param: String) {}
| ^^^^^^^^^^^^
error: `from_py_with` and `cancel_handle` cannot be specified together
--> tests/ui/invalid_argument_attributes.rs:20:35
|
20 | #[pyo3(from_py_with = "func", cancel_handle)] _param: String,
| ^^^^^^^^^^^^^
error: `cancel_handle` may only be specified once per argument
--> tests/ui/invalid_argument_attributes.rs:25:55
|
25 | async fn cancel_handle_repeated(#[pyo3(cancel_handle, cancel_handle)] _param: String) {}
| ^^^^^^^^^^^^^
error: `cancel_handle` may only be specified once
--> tests/ui/invalid_argument_attributes.rs:30:28
|
30 | #[pyo3(cancel_handle)] _param2: String,
| ^^^^^^^
error: `cancel_handle` attribute can only be used with `async fn`
--> tests/ui/invalid_argument_attributes.rs:35:53
|
35 | fn cancel_handle_synchronous(#[pyo3(cancel_handle)] _param: String) {}
| ^^^^^^
error[E0308]: mismatched types
--> tests/ui/invalid_argument_attributes.rs:37:1
|
37 | #[pyfunction]
| ^^^^^^^^^^^^^
| |
| expected `String`, found `CancelHandle`
| arguments to this function are incorrect
|
note: function defined here
--> tests/ui/invalid_argument_attributes.rs:38:10
|
38 | async fn cancel_handle_wrong_type(#[pyo3(cancel_handle)] _param: String) {}
| ^^^^^^^^^^^^^^^^^^^^^^^^ --------------
= note: this error originates in the attribute macro `pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0277]: the trait bound `CancelHandle: PyClass` is not satisfied
--> tests/ui/invalid_argument_attributes.rs:41:50
|
41 | async fn missing_cancel_handle_attribute(_param: pyo3::coroutine::CancelHandle) {}
| ^^^^ the trait `PyClass` is not implemented for `CancelHandle`
|
= help: the trait `PyClass` is implemented for `Coroutine`
= note: required for `CancelHandle` to implement `FromPyObject<'_>`
= note: required for `CancelHandle` to implement `PyFunctionArgument<'_, '_>`
note: required by a bound in `extract_argument`
--> src/impl_/extract_argument.rs
|
| pub fn extract_argument<'a, 'py, T>(
| ---------------- required by a bound in this function
...
| T: PyFunctionArgument<'a, 'py>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `extract_argument`
error[E0277]: the trait bound `CancelHandle: Clone` is not satisfied
--> tests/ui/invalid_argument_attributes.rs:41:50
|
41 | async fn missing_cancel_handle_attribute(_param: pyo3::coroutine::CancelHandle) {}
| ^^^^ the trait `Clone` is not implemented for `CancelHandle`
|
= help: the following other types implement trait `PyFunctionArgument<'a, 'py>`:
&'a Coroutine
&'a mut Coroutine
= note: required for `CancelHandle` to implement `FromPyObject<'_>`
= note: required for `CancelHandle` to implement `PyFunctionArgument<'_, '_>`
note: required by a bound in `extract_argument`
--> src/impl_/extract_argument.rs
|
| pub fn extract_argument<'a, 'py, T>(
| ---------------- required by a bound in this function
...
| T: PyFunctionArgument<'a, 'py>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `extract_argument`