feat: add `coroutine::CancelHandle`
This commit is contained in:
parent
81ad2e8bab
commit
8a674c2bd3
|
@ -69,10 +69,27 @@ where
|
||||||
|
|
||||||
## Cancellation
|
## Cancellation
|
||||||
|
|
||||||
*To be implemented*
|
Cancellation on the Python side can be caught using [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) type, by annotating a function parameter with `#[pyo3(cancel_handle)].
|
||||||
|
|
||||||
|
```rust
|
||||||
|
# #![allow(dead_code)]
|
||||||
|
use futures::FutureExt;
|
||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::coroutine::CancelHandle;
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
async fn cancellable(#[pyo3(cancel_handle)] mut cancel: CancelHandle) {
|
||||||
|
futures::select! {
|
||||||
|
/* _ = ... => println!("done"), */
|
||||||
|
_ = cancel.cancelled().fuse() => println!("cancelled"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## The `Coroutine` type
|
## The `Coroutine` type
|
||||||
|
|
||||||
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine). Each `coroutine.send` call is translated to `Future::poll` call, while `coroutine.throw` call reraise the exception *(this behavior will be configurable with cancellation support)*.
|
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine).
|
||||||
|
|
||||||
|
Each `coroutine.send` call is translated to a `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;
|
||||||
|
|
||||||
*The type does not yet have a public constructor until the design is finalized.*
|
*The type does not yet have a public constructor until the design is finalized.*
|
|
@ -0,0 +1 @@
|
||||||
|
Add `coroutine::CancelHandle` to catch coroutine cancellation
|
|
@ -11,6 +11,7 @@ use syn::{
|
||||||
pub mod kw {
|
pub mod kw {
|
||||||
syn::custom_keyword!(annotation);
|
syn::custom_keyword!(annotation);
|
||||||
syn::custom_keyword!(attribute);
|
syn::custom_keyword!(attribute);
|
||||||
|
syn::custom_keyword!(cancel_handle);
|
||||||
syn::custom_keyword!(dict);
|
syn::custom_keyword!(dict);
|
||||||
syn::custom_keyword!(extends);
|
syn::custom_keyword!(extends);
|
||||||
syn::custom_keyword!(freelist);
|
syn::custom_keyword!(freelist);
|
||||||
|
|
|
@ -24,6 +24,7 @@ pub struct FnArg<'a> {
|
||||||
pub attrs: PyFunctionArgPyO3Attributes,
|
pub attrs: PyFunctionArgPyO3Attributes,
|
||||||
pub is_varargs: bool,
|
pub is_varargs: bool,
|
||||||
pub is_kwargs: bool,
|
pub is_kwargs: bool,
|
||||||
|
pub is_cancel_handle: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> FnArg<'a> {
|
impl<'a> FnArg<'a> {
|
||||||
|
@ -44,6 +45,8 @@ impl<'a> FnArg<'a> {
|
||||||
other => return Err(handle_argument_error(other)),
|
other => return Err(handle_argument_error(other)),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let is_cancel_handle = arg_attrs.cancel_handle.is_some();
|
||||||
|
|
||||||
Ok(FnArg {
|
Ok(FnArg {
|
||||||
name: ident,
|
name: ident,
|
||||||
ty: &cap.ty,
|
ty: &cap.ty,
|
||||||
|
@ -53,6 +56,7 @@ impl<'a> FnArg<'a> {
|
||||||
attrs: arg_attrs,
|
attrs: arg_attrs,
|
||||||
is_varargs: false,
|
is_varargs: false,
|
||||||
is_kwargs: false,
|
is_kwargs: false,
|
||||||
|
is_cancel_handle,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -455,9 +459,27 @@ impl<'a> FnSpec<'a> {
|
||||||
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
|
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
|
||||||
let func_name = &self.name;
|
let func_name = &self.name;
|
||||||
|
|
||||||
|
let mut cancel_handle_iter = self
|
||||||
|
.signature
|
||||||
|
.arguments
|
||||||
|
.iter()
|
||||||
|
.filter(|arg| arg.is_cancel_handle);
|
||||||
|
let cancel_handle = cancel_handle_iter.next();
|
||||||
|
if let Some(arg) = cancel_handle {
|
||||||
|
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
|
||||||
|
if let Some(arg2) = cancel_handle_iter.next() {
|
||||||
|
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let rust_call = |args: Vec<TokenStream>| {
|
let rust_call = |args: Vec<TokenStream>| {
|
||||||
let mut call = quote! { function(#self_arg #(#args),*) };
|
let mut call = quote! { function(#self_arg #(#args),*) };
|
||||||
if self.asyncness.is_some() {
|
if self.asyncness.is_some() {
|
||||||
|
let throw_callback = if cancel_handle.is_some() {
|
||||||
|
quote! { Some(__throw_callback) }
|
||||||
|
} else {
|
||||||
|
quote! { None }
|
||||||
|
};
|
||||||
let python_name = &self.python_name;
|
let python_name = &self.python_name;
|
||||||
let qualname_prefix = match cls {
|
let qualname_prefix = match cls {
|
||||||
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
|
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
|
||||||
|
@ -468,9 +490,17 @@ impl<'a> FnSpec<'a> {
|
||||||
_pyo3::impl_::coroutine::new_coroutine(
|
_pyo3::impl_::coroutine::new_coroutine(
|
||||||
_pyo3::intern!(py, stringify!(#python_name)),
|
_pyo3::intern!(py, stringify!(#python_name)),
|
||||||
#qualname_prefix,
|
#qualname_prefix,
|
||||||
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }
|
#throw_callback,
|
||||||
|
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
|
||||||
)
|
)
|
||||||
}};
|
}};
|
||||||
|
if cancel_handle.is_some() {
|
||||||
|
call = quote! {{
|
||||||
|
let __cancel_handle = _pyo3::coroutine::CancelHandle::new();
|
||||||
|
let __throw_callback = __cancel_handle.throw_callback();
|
||||||
|
#call
|
||||||
|
}};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
quotes::map_result_into_ptr(quotes::ok_wrap(call))
|
quotes::map_result_into_ptr(quotes::ok_wrap(call))
|
||||||
};
|
};
|
||||||
|
@ -483,12 +513,21 @@ impl<'a> FnSpec<'a> {
|
||||||
|
|
||||||
Ok(match self.convention {
|
Ok(match self.convention {
|
||||||
CallingConvention::Noargs => {
|
CallingConvention::Noargs => {
|
||||||
let call = if !self.signature.arguments.is_empty() {
|
let args = self
|
||||||
// Only `py` arg can be here
|
.signature
|
||||||
rust_call(vec![quote!(py)])
|
.arguments
|
||||||
} else {
|
.iter()
|
||||||
rust_call(vec![])
|
.map(|arg| {
|
||||||
};
|
if arg.py {
|
||||||
|
quote!(py)
|
||||||
|
} else if arg.is_cancel_handle {
|
||||||
|
quote!(__cancel_handle)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let call = rust_call(args);
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
unsafe fn #ident<'py>(
|
unsafe fn #ident<'py>(
|
||||||
|
|
|
@ -155,6 +155,10 @@ fn impl_arg_param(
|
||||||
return Ok(quote! { py });
|
return Ok(quote! { py });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if arg.is_cancel_handle {
|
||||||
|
return Ok(quote! { __cancel_handle });
|
||||||
|
}
|
||||||
|
|
||||||
let name = arg.name;
|
let name = arg.name;
|
||||||
let name_str = name.to_string();
|
let name_str = name.to_string();
|
||||||
|
|
||||||
|
|
|
@ -23,16 +23,20 @@ pub use self::signature::{FunctionSignature, SignatureAttribute};
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct PyFunctionArgPyO3Attributes {
|
pub struct PyFunctionArgPyO3Attributes {
|
||||||
pub from_py_with: Option<FromPyWithAttribute>,
|
pub from_py_with: Option<FromPyWithAttribute>,
|
||||||
|
pub cancel_handle: Option<attributes::kw::cancel_handle>,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum PyFunctionArgPyO3Attribute {
|
enum PyFunctionArgPyO3Attribute {
|
||||||
FromPyWith(FromPyWithAttribute),
|
FromPyWith(FromPyWithAttribute),
|
||||||
|
CancelHandle(attributes::kw::cancel_handle),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Parse for PyFunctionArgPyO3Attribute {
|
impl Parse for PyFunctionArgPyO3Attribute {
|
||||||
fn parse(input: ParseStream<'_>) -> Result<Self> {
|
fn parse(input: ParseStream<'_>) -> Result<Self> {
|
||||||
let lookahead = input.lookahead1();
|
let lookahead = input.lookahead1();
|
||||||
if lookahead.peek(attributes::kw::from_py_with) {
|
if lookahead.peek(attributes::kw::cancel_handle) {
|
||||||
|
input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
|
||||||
|
} else if lookahead.peek(attributes::kw::from_py_with) {
|
||||||
input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
|
input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
|
||||||
} else {
|
} else {
|
||||||
Err(lookahead.error())
|
Err(lookahead.error())
|
||||||
|
@ -43,7 +47,10 @@ impl Parse for PyFunctionArgPyO3Attribute {
|
||||||
impl PyFunctionArgPyO3Attributes {
|
impl PyFunctionArgPyO3Attributes {
|
||||||
/// Parses #[pyo3(from_python_with = "func")]
|
/// Parses #[pyo3(from_python_with = "func")]
|
||||||
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
|
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
|
||||||
let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None };
|
let mut attributes = PyFunctionArgPyO3Attributes {
|
||||||
|
from_py_with: None,
|
||||||
|
cancel_handle: None,
|
||||||
|
};
|
||||||
take_attributes(attrs, |attr| {
|
take_attributes(attrs, |attr| {
|
||||||
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
|
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
|
||||||
for attr in pyo3_attrs {
|
for attr in pyo3_attrs {
|
||||||
|
@ -55,7 +62,18 @@ impl PyFunctionArgPyO3Attributes {
|
||||||
);
|
);
|
||||||
attributes.from_py_with = Some(from_py_with);
|
attributes.from_py_with = Some(from_py_with);
|
||||||
}
|
}
|
||||||
|
PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
|
||||||
|
ensure_spanned!(
|
||||||
|
attributes.cancel_handle.is_none(),
|
||||||
|
cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
|
||||||
|
);
|
||||||
|
attributes.cancel_handle = Some(cancel_handle);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
ensure_spanned!(
|
||||||
|
attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
|
||||||
|
attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Ok(true)
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
|
||||||
// Otherwise try next argument.
|
// Otherwise try next argument.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if fn_arg.is_cancel_handle {
|
||||||
|
// If the user incorrectly tried to include cancel: CoroutineCancel in the
|
||||||
|
// signature, give a useful error as a hint.
|
||||||
|
ensure_spanned!(
|
||||||
|
name != fn_arg.name,
|
||||||
|
name.span() => "`cancel_handle` argument must not be part of the signature"
|
||||||
|
);
|
||||||
|
// Otherwise try next argument.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
ensure_spanned!(
|
ensure_spanned!(
|
||||||
name == fn_arg.name,
|
name == fn_arg.name,
|
||||||
|
@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure no non-py arguments remain
|
// Ensure no non-py arguments remain
|
||||||
if let Some(arg) = args_iter.find(|arg| !arg.py) {
|
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.is_cancel_handle) {
|
||||||
bail_spanned!(
|
bail_spanned!(
|
||||||
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
|
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
|
||||||
);
|
);
|
||||||
|
@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
|
||||||
let mut python_signature = PythonSignature::default();
|
let mut python_signature = PythonSignature::default();
|
||||||
for arg in &arguments {
|
for arg in &arguments {
|
||||||
// Python<'_> arguments don't show in Python signature
|
// Python<'_> arguments don't show in Python signature
|
||||||
if arg.py {
|
if arg.py || arg.is_cancel_handle {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,12 @@ use crate::{
|
||||||
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
|
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub(crate) mod cancel;
|
||||||
mod waker;
|
mod waker;
|
||||||
|
|
||||||
|
use crate::coroutine::cancel::ThrowCallback;
|
||||||
|
pub use cancel::CancelHandle;
|
||||||
|
|
||||||
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
|
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
|
||||||
|
|
||||||
type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
|
type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
|
||||||
|
@ -32,6 +36,7 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
|
||||||
pub struct Coroutine {
|
pub struct Coroutine {
|
||||||
name: Option<Py<PyString>>,
|
name: Option<Py<PyString>>,
|
||||||
qualname_prefix: Option<&'static str>,
|
qualname_prefix: Option<&'static str>,
|
||||||
|
throw_callback: Option<ThrowCallback>,
|
||||||
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
|
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
|
||||||
waker: Option<Arc<AsyncioWaker>>,
|
waker: Option<Arc<AsyncioWaker>>,
|
||||||
}
|
}
|
||||||
|
@ -46,6 +51,7 @@ impl Coroutine {
|
||||||
pub(crate) fn new<F, T, E>(
|
pub(crate) fn new<F, T, E>(
|
||||||
name: Option<Py<PyString>>,
|
name: Option<Py<PyString>>,
|
||||||
qualname_prefix: Option<&'static str>,
|
qualname_prefix: Option<&'static str>,
|
||||||
|
throw_callback: Option<ThrowCallback>,
|
||||||
future: F,
|
future: F,
|
||||||
) -> Self
|
) -> Self
|
||||||
where
|
where
|
||||||
|
@ -61,6 +67,7 @@ impl Coroutine {
|
||||||
Self {
|
Self {
|
||||||
name,
|
name,
|
||||||
qualname_prefix,
|
qualname_prefix,
|
||||||
|
throw_callback,
|
||||||
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
|
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
|
||||||
waker: None,
|
waker: None,
|
||||||
}
|
}
|
||||||
|
@ -77,9 +84,13 @@ impl Coroutine {
|
||||||
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
|
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
|
||||||
};
|
};
|
||||||
// reraise thrown exception it
|
// reraise thrown exception it
|
||||||
if let Some(exc) = throw {
|
match (throw, &self.throw_callback) {
|
||||||
self.close();
|
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
|
||||||
return Err(PyErr::from_value(exc.as_ref(py)));
|
(Some(exc), None) => {
|
||||||
|
self.close();
|
||||||
|
return Err(PyErr::from_value(exc.as_ref(py)));
|
||||||
|
}
|
||||||
|
(None, _) => {}
|
||||||
}
|
}
|
||||||
// create a new waker, or try to reset it in place
|
// create a new waker, or try to reset it in place
|
||||||
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
|
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
use crate::{PyAny, PyObject};
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll, Waker};
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct Inner {
|
||||||
|
exception: Option<PyObject>,
|
||||||
|
waker: Option<Waker>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine).
|
||||||
|
///
|
||||||
|
/// Only the last exception thrown can be retrieved.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct CancelHandle(Arc<Mutex<Inner>>);
|
||||||
|
|
||||||
|
impl CancelHandle {
|
||||||
|
/// Create a new `CoroutineCancel`.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Default::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns whether the associated coroutine has been cancelled.
|
||||||
|
pub fn is_cancelled(&self) -> bool {
|
||||||
|
self.0.lock().exception.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Poll to retrieve the exception thrown in the associated coroutine.
|
||||||
|
pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
|
||||||
|
let mut inner = self.0.lock();
|
||||||
|
if let Some(exc) = inner.exception.take() {
|
||||||
|
return Poll::Ready(exc);
|
||||||
|
}
|
||||||
|
if let Some(ref waker) = inner.waker {
|
||||||
|
if cx.waker().will_wake(waker) {
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inner.waker = Some(cx.waker().clone());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve the exception thrown in the associated coroutine.
|
||||||
|
pub async fn cancelled(&mut self) -> PyObject {
|
||||||
|
Cancelled(self).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub fn throw_callback(&self) -> ThrowCallback {
|
||||||
|
ThrowCallback(self.0.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Because `poll_fn` is not available in MSRV
|
||||||
|
struct Cancelled<'a>(&'a mut CancelHandle);
|
||||||
|
|
||||||
|
impl Future for Cancelled<'_> {
|
||||||
|
type Output = PyObject;
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
self.0.poll_cancelled(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub struct ThrowCallback(Arc<Mutex<Inner>>);
|
||||||
|
|
||||||
|
impl ThrowCallback {
|
||||||
|
pub(super) fn throw(&self, exc: &PyAny) {
|
||||||
|
let mut inner = self.0.lock();
|
||||||
|
inner.exception = Some(exc.into());
|
||||||
|
if let Some(waker) = inner.waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,10 +1,12 @@
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
|
||||||
|
use crate::coroutine::cancel::ThrowCallback;
|
||||||
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
|
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
|
||||||
|
|
||||||
pub fn new_coroutine<F, T, E>(
|
pub fn new_coroutine<F, T, E>(
|
||||||
name: &PyString,
|
name: &PyString,
|
||||||
qualname_prefix: Option<&'static str>,
|
qualname_prefix: Option<&'static str>,
|
||||||
|
throw_callback: Option<ThrowCallback>,
|
||||||
future: F,
|
future: F,
|
||||||
) -> Coroutine
|
) -> Coroutine
|
||||||
where
|
where
|
||||||
|
@ -12,5 +14,5 @@ where
|
||||||
T: IntoPy<PyObject>,
|
T: IntoPy<PyObject>,
|
||||||
E: Into<PyErr>,
|
E: Into<PyErr>,
|
||||||
{
|
{
|
||||||
Coroutine::new(Some(name.into()), qualname_prefix, future)
|
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1038,7 +1038,7 @@ impl<T> Py<T> {
|
||||||
/// # Safety
|
/// # Safety
|
||||||
/// `ptr` must point to a Python object of type T.
|
/// `ptr` must point to a Python object of type T.
|
||||||
#[inline]
|
#[inline]
|
||||||
unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
|
pub(crate) unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
|
||||||
Self(ptr, PhantomData)
|
Self(ptr, PhantomData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::{task::Poll, thread, time::Duration};
|
use std::{task::Poll, thread, time::Duration};
|
||||||
|
|
||||||
use futures::{channel::oneshot, future::poll_fn};
|
use futures::{channel::oneshot, future::poll_fn, FutureExt};
|
||||||
|
use pyo3::coroutine::CancelHandle;
|
||||||
use pyo3::types::{IntoPyDict, PyType};
|
use pyo3::types::{IntoPyDict, PyType};
|
||||||
use pyo3::{prelude::*, py_run};
|
use pyo3::{prelude::*, py_run};
|
||||||
|
|
||||||
|
@ -136,3 +137,69 @@ fn cancelled_coroutine() {
|
||||||
assert_eq!(err.value(gil).get_type().name().unwrap(), "CancelledError");
|
assert_eq!(err.value(gil).get_type().name().unwrap(), "CancelledError");
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coroutine_cancel_handle() {
|
||||||
|
#[pyfunction]
|
||||||
|
async fn cancellable_sleep(
|
||||||
|
seconds: f64,
|
||||||
|
#[pyo3(cancel_handle)] mut cancel: CancelHandle,
|
||||||
|
) -> usize {
|
||||||
|
futures::select! {
|
||||||
|
_ = sleep(seconds).fuse() => 42,
|
||||||
|
_ = cancel.cancelled().fuse() => 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Python::with_gil(|gil| {
|
||||||
|
let cancellable_sleep = wrap_pyfunction!(cancellable_sleep, gil).unwrap();
|
||||||
|
let test = r#"
|
||||||
|
import asyncio;
|
||||||
|
async def main():
|
||||||
|
task = asyncio.create_task(cancellable_sleep(1))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
return await task
|
||||||
|
assert asyncio.run(main()) == 0
|
||||||
|
"#;
|
||||||
|
let globals = gil.import("__main__").unwrap().dict();
|
||||||
|
globals
|
||||||
|
.set_item("cancellable_sleep", cancellable_sleep)
|
||||||
|
.unwrap();
|
||||||
|
gil.run(
|
||||||
|
&pyo3::unindent::unindent(&handle_windows(test)),
|
||||||
|
Some(globals),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn coroutine_is_cancelled() {
|
||||||
|
#[pyfunction]
|
||||||
|
async fn sleep_loop(#[pyo3(cancel_handle)] cancel: CancelHandle) {
|
||||||
|
while !cancel.is_cancelled() {
|
||||||
|
sleep(0.001).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Python::with_gil(|gil| {
|
||||||
|
let sleep_loop = wrap_pyfunction!(sleep_loop, gil).unwrap();
|
||||||
|
let test = r#"
|
||||||
|
import asyncio;
|
||||||
|
async def main():
|
||||||
|
task = asyncio.create_task(sleep_loop())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
await task
|
||||||
|
asyncio.run(main())
|
||||||
|
"#;
|
||||||
|
let globals = gil.import("__main__").unwrap().dict();
|
||||||
|
globals.set_item("sleep_loop", sleep_loop).unwrap();
|
||||||
|
gil.run(
|
||||||
|
&pyo3::unindent::unindent(&handle_windows(test)),
|
||||||
|
Some(globals),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -12,4 +12,32 @@ fn from_py_with_string(#[pyo3("from_py_with")] param: String) {}
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}
|
fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] param: String) {}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
async fn from_py_with_value_and_cancel_handle(
|
||||||
|
#[pyo3(from_py_with = "func", cancel_handle)] _param: String,
|
||||||
|
) {
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
async fn cancel_handle_repeated(#[pyo3(cancel_handle, cancel_handle)] _param: String) {}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
async fn cancel_handle_repeated2(
|
||||||
|
#[pyo3(cancel_handle)] _param: String,
|
||||||
|
#[pyo3(cancel_handle)] _param2: String,
|
||||||
|
) {
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn cancel_handle_synchronous(#[pyo3(cancel_handle)] _param: String) {}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
async fn cancel_handle_wrong_type(#[pyo3(cancel_handle)] _param: String) {}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
async fn missing_cancel_handle_attribute(_param: pyo3::coroutine::CancelHandle) {}
|
||||||
|
|
||||||
fn main() {}
|
fn main() {}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
error: expected `from_py_with`
|
error: expected `cancel_handle` or `from_py_with`
|
||||||
--> tests/ui/invalid_argument_attributes.rs:4:29
|
--> tests/ui/invalid_argument_attributes.rs:4:29
|
||||||
|
|
|
|
||||||
4 | fn invalid_attribute(#[pyo3(get)] param: String) {}
|
4 | fn invalid_attribute(#[pyo3(get)] param: String) {}
|
||||||
|
@ -10,7 +10,7 @@ error: expected `=`
|
||||||
7 | fn from_py_with_no_value(#[pyo3(from_py_with)] param: String) {}
|
7 | fn from_py_with_no_value(#[pyo3(from_py_with)] param: String) {}
|
||||||
| ^
|
| ^
|
||||||
|
|
||||||
error: expected `from_py_with`
|
error: expected `cancel_handle` or `from_py_with`
|
||||||
--> tests/ui/invalid_argument_attributes.rs:10:31
|
--> tests/ui/invalid_argument_attributes.rs:10:31
|
||||||
|
|
|
|
||||||
10 | fn from_py_with_string(#[pyo3("from_py_with")] param: String) {}
|
10 | fn from_py_with_string(#[pyo3("from_py_with")] param: String) {}
|
||||||
|
@ -21,3 +21,87 @@ error: expected string literal
|
||||||
|
|
|
|
||||||
13 | fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}
|
13 | fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}
|
||||||
| ^^^^
|
| ^^^^
|
||||||
|
|
||||||
|
error: `from_py_with` may only be specified once per argument
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:16:56
|
||||||
|
|
|
||||||
|
16 | fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] param: String) {}
|
||||||
|
| ^^^^^^^^^^^^
|
||||||
|
|
||||||
|
error: `from_py_with` and `cancel_handle` cannot be specified together
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:20:35
|
||||||
|
|
|
||||||
|
20 | #[pyo3(from_py_with = "func", cancel_handle)] _param: String,
|
||||||
|
| ^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
error: `cancel_handle` may only be specified once per argument
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:25:55
|
||||||
|
|
|
||||||
|
25 | async fn cancel_handle_repeated(#[pyo3(cancel_handle, cancel_handle)] _param: String) {}
|
||||||
|
| ^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
error: `cancel_handle` may only be specified once
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:30:28
|
||||||
|
|
|
||||||
|
30 | #[pyo3(cancel_handle)] _param2: String,
|
||||||
|
| ^^^^^^^
|
||||||
|
|
||||||
|
error: `cancel_handle` attribute can only be used with `async fn`
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:35:53
|
||||||
|
|
|
||||||
|
35 | fn cancel_handle_synchronous(#[pyo3(cancel_handle)] _param: String) {}
|
||||||
|
| ^^^^^^
|
||||||
|
|
||||||
|
error[E0308]: mismatched types
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:37:1
|
||||||
|
|
|
||||||
|
37 | #[pyfunction]
|
||||||
|
| ^^^^^^^^^^^^^
|
||||||
|
| |
|
||||||
|
| expected `String`, found `CancelHandle`
|
||||||
|
| arguments to this function are incorrect
|
||||||
|
|
|
||||||
|
note: function defined here
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:38:10
|
||||||
|
|
|
||||||
|
38 | async fn cancel_handle_wrong_type(#[pyo3(cancel_handle)] _param: String) {}
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^ --------------
|
||||||
|
= note: this error originates in the attribute macro `pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
|
||||||
|
error[E0277]: the trait bound `CancelHandle: PyClass` is not satisfied
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:41:50
|
||||||
|
|
|
||||||
|
41 | async fn missing_cancel_handle_attribute(_param: pyo3::coroutine::CancelHandle) {}
|
||||||
|
| ^^^^ the trait `PyClass` is not implemented for `CancelHandle`
|
||||||
|
|
|
||||||
|
= help: the trait `PyClass` is implemented for `Coroutine`
|
||||||
|
= note: required for `CancelHandle` to implement `FromPyObject<'_>`
|
||||||
|
= note: required for `CancelHandle` to implement `PyFunctionArgument<'_, '_>`
|
||||||
|
note: required by a bound in `extract_argument`
|
||||||
|
--> src/impl_/extract_argument.rs
|
||||||
|
|
|
||||||
|
| pub fn extract_argument<'a, 'py, T>(
|
||||||
|
| ---------------- required by a bound in this function
|
||||||
|
...
|
||||||
|
| T: PyFunctionArgument<'a, 'py>,
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `extract_argument`
|
||||||
|
|
||||||
|
error[E0277]: the trait bound `CancelHandle: Clone` is not satisfied
|
||||||
|
--> tests/ui/invalid_argument_attributes.rs:41:50
|
||||||
|
|
|
||||||
|
41 | async fn missing_cancel_handle_attribute(_param: pyo3::coroutine::CancelHandle) {}
|
||||||
|
| ^^^^ the trait `Clone` is not implemented for `CancelHandle`
|
||||||
|
|
|
||||||
|
= help: the following other types implement trait `PyFunctionArgument<'a, 'py>`:
|
||||||
|
&'a Coroutine
|
||||||
|
&'a mut Coroutine
|
||||||
|
= note: required for `CancelHandle` to implement `FromPyObject<'_>`
|
||||||
|
= note: required for `CancelHandle` to implement `PyFunctionArgument<'_, '_>`
|
||||||
|
note: required by a bound in `extract_argument`
|
||||||
|
--> src/impl_/extract_argument.rs
|
||||||
|
|
|
||||||
|
| pub fn extract_argument<'a, 'py, T>(
|
||||||
|
| ---------------- required by a bound in this function
|
||||||
|
...
|
||||||
|
| T: PyFunctionArgument<'a, 'py>,
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `extract_argument`
|
||||||
|
|
Loading…
Reference in New Issue