Replace IterNextOutput by autoref-based specialization to allow returning arbitrary values

This commit is contained in:
Adam Reichold 2023-12-18 16:26:33 +01:00
parent d75d4bdf81
commit ca7d90dcf3
8 changed files with 141 additions and 69 deletions

View File

@ -792,9 +792,11 @@ pub const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc"
const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc") const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc")
.arguments(&[Ty::MaybeNullObject, Ty::MaybeNullObject]); .arguments(&[Ty::MaybeNullObject, Ty::MaybeNullObject]);
const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc"); const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc");
const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc").return_conversion( const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc")
TokenGenerator(|| quote! { _pyo3::class::iter::IterNextOutput::<_, _> }), .return_specialized_conversion(
); TokenGenerator(|| quote! { IterBaseKind, IterOptionKind, IterResultOptionKind }),
TokenGenerator(|| quote! { iter_tag }),
);
const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc"); const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc");
const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc"); const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc");
const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion( const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion(
@ -1003,17 +1005,23 @@ fn extract_object(
enum ReturnMode { enum ReturnMode {
ReturnSelf, ReturnSelf,
Conversion(TokenGenerator), Conversion(TokenGenerator),
SpecializedConversion(TokenGenerator, TokenGenerator),
} }
impl ReturnMode { impl ReturnMode {
fn return_call_output(&self, call: TokenStream) -> TokenStream { fn return_call_output(&self, call: TokenStream) -> TokenStream {
match self { match self {
ReturnMode::Conversion(conversion) => quote! { ReturnMode::Conversion(conversion) => quote! {
let _result: _pyo3::PyResult<#conversion> = #call; let _result: _pyo3::PyResult<#conversion> = _pyo3::callback::convert(py, #call);
_pyo3::callback::convert(py, _result) _pyo3::callback::convert(py, _result)
}, },
ReturnMode::SpecializedConversion(traits, tag) => quote! {
let _result = #call;
use _pyo3::impl_::pymethods::{#traits};
(&_result).#tag().convert(py, _result)
},
ReturnMode::ReturnSelf => quote! { ReturnMode::ReturnSelf => quote! {
let _result: _pyo3::PyResult<()> = #call; let _result: _pyo3::PyResult<()> = _pyo3::callback::convert(py, #call);
_result?; _result?;
_pyo3::ffi::Py_XINCREF(_raw_slf); _pyo3::ffi::Py_XINCREF(_raw_slf);
::std::result::Result::Ok(_raw_slf) ::std::result::Result::Ok(_raw_slf)
@ -1062,6 +1070,15 @@ impl SlotDef {
self self
} }
const fn return_specialized_conversion(
mut self,
traits: TokenGenerator,
tag: TokenGenerator,
) -> Self {
self.return_mode = Some(ReturnMode::SpecializedConversion(traits, tag));
self
}
const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self {
self.extract_error_mode = extract_error_mode; self.extract_error_mode = extract_error_mode;
self self
@ -1162,11 +1179,11 @@ fn generate_method_body(
let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode, holders); let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode, holders);
let rust_name = spec.name; let rust_name = spec.name;
let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders)?; let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders)?;
let call = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) }; let call = quote! { #cls::#rust_name(#self_arg #(#args),*) };
Ok(if let Some(return_mode) = return_mode { Ok(if let Some(return_mode) = return_mode {
return_mode.return_call_output(call) return_mode.return_call_output(call)
} else { } else {
call quote! { _pyo3::callback::convert(py, #call) }
}) })
} }

View File

@ -5,7 +5,8 @@
//! when awaited, see guide examples related to pyo3-asyncio for ways //! when awaited, see guide examples related to pyo3-asyncio for ways
//! to suspend tasks and await results. //! to suspend tasks and await results.
use pyo3::{prelude::*, pyclass::IterNextOutput}; use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;
#[pyclass] #[pyclass]
#[derive(Debug)] #[derive(Debug)]
@ -30,13 +31,13 @@ impl IterAwaitable {
pyself pyself
} }
fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> { fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
match self.result.take() { match self.result.take() {
Some(res) => match res { Some(res) => match res {
Ok(v) => Ok(IterNextOutput::Return(v)), Ok(v) => Err(PyStopIteration::new_err(v)),
Err(err) => Err(err), Err(err) => Err(err),
}, },
_ => Ok(IterNextOutput::Yield(py.None().into())), _ => Ok(py.None().into()),
} }
} }
} }
@ -66,15 +67,13 @@ impl FutureAwaitable {
pyself pyself
} }
fn __next__( fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
mut pyself: PyRefMut<'_, Self>,
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
match pyself.result { match pyself.result {
Some(_) => match pyself.result.take().unwrap() { Some(_) => match pyself.result.take().unwrap() {
Ok(v) => Ok(IterNextOutput::Return(v)), Ok(v) => Err(PyStopIteration::new_err(v)),
Err(err) => Err(err), Err(err) => Err(err),
}, },
_ => Ok(IterNextOutput::Yield(pyself)), _ => Ok(pyself),
} }
} }
} }

View File

@ -1,5 +1,4 @@
use pyo3::exceptions::PyValueError; use pyo3::exceptions::{PyStopIteration, PyValueError};
use pyo3::iter::IterNextOutput;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::PyType; use pyo3::types::PyType;
@ -28,12 +27,12 @@ impl PyClassIter {
Default::default() Default::default()
} }
fn __next__(&mut self) -> IterNextOutput<usize, &'static str> { fn __next__(&mut self) -> PyResult<usize> {
if self.count < 5 { if self.count < 5 {
self.count += 1; self.count += 1;
IterNextOutput::Yield(self.count) Ok(self.count)
} else { } else {
IterNextOutput::Return("Ended") Err(PyStopIteration::new_err("Ended"))
} }
} }
} }

View File

@ -14,7 +14,6 @@ use crate::{
coroutine::{cancel::ThrowCallback, waker::AsyncioWaker}, coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration}, exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
panic::PanicException, panic::PanicException,
pyclass::IterNextOutput,
types::{PyIterator, PyString}, types::{PyIterator, PyString},
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
}; };
@ -68,11 +67,7 @@ impl Coroutine {
} }
} }
fn poll( fn poll(&mut self, py: Python<'_>, throw: Option<PyObject>) -> PyResult<PyObject> {
&mut self,
py: Python<'_>,
throw: Option<PyObject>,
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
// raise if the coroutine has already been run to completion // raise if the coroutine has already been run to completion
let future_rs = match self.future { let future_rs = match self.future {
Some(ref mut fut) => fut, Some(ref mut fut) => fut,
@ -100,7 +95,7 @@ impl Coroutine {
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) { match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => { Ok(Poll::Ready(res)) => {
self.close(); self.close();
return Ok(IterNextOutput::Return(res?)); return Err(PyStopIteration::new_err(res?));
} }
Err(err) => { Err(err) => {
self.close(); self.close();
@ -115,19 +110,12 @@ impl Coroutine {
if let Some(future) = PyIterator::from_object(future).unwrap().next() { if let Some(future) = PyIterator::from_object(future).unwrap().next() {
// future has not been leaked into Python for now, and Rust code can only call // future has not been leaked into Python for now, and Rust code can only call
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap // `set_result(None)` in `Wake` implementation, so it's safe to unwrap
return Ok(IterNextOutput::Yield(future.unwrap().into())); return Ok(future.unwrap().into());
} }
} }
// if waker has been waken during future polling, this is roughly equivalent to // if waker has been waken during future polling, this is roughly equivalent to
// `await asyncio.sleep(0)`, so just yield `None`. // `await asyncio.sleep(0)`, so just yield `None`.
Ok(IterNextOutput::Yield(py.None().into())) Ok(py.None().into())
}
}
pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResult<PyObject> {
match result {
IterNextOutput::Yield(ob) => Ok(ob),
IterNextOutput::Return(ob) => Err(PyStopIteration::new_err(ob)),
} }
} }
@ -153,11 +141,11 @@ impl Coroutine {
} }
fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> { fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> {
iter_result(self.poll(py, None)?) self.poll(py, None)
} }
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> { fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
iter_result(self.poll(py, Some(exc))?) self.poll(py, Some(exc))
} }
fn close(&mut self) { fn close(&mut self) {
@ -170,7 +158,7 @@ impl Coroutine {
self_ self_
} }
fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> { fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
self.poll(py, None) self.poll(py, None)
} }
} }

View File

@ -1,3 +1,4 @@
use crate::callback::IntoPyCallbackOutput;
use crate::gil::LockGIL; use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap; use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string; use crate::internal_tricks::extract_c_string;
@ -7,6 +8,7 @@ use std::ffi::CStr;
use std::fmt; use std::fmt;
use std::os::raw::{c_int, c_void}; use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe}; use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr::null_mut;
/// Python 3.8 and up - __ipow__ has modulo argument correctly populated. /// Python 3.8 and up - __ipow__ has modulo argument correctly populated.
#[cfg(Py_3_8)] #[cfg(Py_3_8)]
@ -299,3 +301,83 @@ pub(crate) fn get_name(name: &'static str) -> PyResult<Cow<'static, CStr>> {
pub(crate) fn get_doc(doc: &'static str) -> PyResult<Cow<'static, CStr>> { pub(crate) fn get_doc(doc: &'static str) -> PyResult<Cow<'static, CStr>> {
extract_c_string(doc, "function doc cannot contain NUL byte.") extract_c_string(doc, "function doc cannot contain NUL byte.")
} }
// Autoref-based specialization for handling `__next__` returning `Option`
pub struct IterBaseTag;
impl IterBaseTag {
#[inline]
pub fn convert<Value, Target>(self, py: Python<'_>, value: Value) -> PyResult<Target>
where
Value: IntoPyCallbackOutput<Target>,
{
value.convert(py)
}
}
pub trait IterBaseKind {
#[inline]
fn iter_tag(&self) -> IterBaseTag {
IterBaseTag
}
}
impl<Value> IterBaseKind for &Value {}
pub struct IterOptionTag;
impl IterOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: Option<Value>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Some(value) => value.convert(py),
None => Ok(null_mut()),
}
}
}
pub trait IterOptionKind {
#[inline]
fn iter_tag(&self) -> IterOptionTag {
IterOptionTag
}
}
impl<Value> IterOptionKind for Option<Value> {}
pub struct IterResultOptionTag;
impl IterResultOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: PyResult<Option<Value>>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Ok(Some(value)) => value.convert(py),
Ok(None) => Ok(null_mut()),
Err(err) => Err(err),
}
}
}
pub trait IterResultOptionKind {
#[inline]
fn iter_tag(&self) -> IterResultOptionTag {
IterResultOptionTag
}
}
impl<Value> IterResultOptionKind for PyResult<Option<Value>> {}

View File

@ -362,6 +362,7 @@ pub mod class {
/// For compatibility reasons this has not yet been removed, however will be done so /// For compatibility reasons this has not yet been removed, however will be done so
/// once <https://github.com/rust-lang/rust/issues/30827> is resolved. /// once <https://github.com/rust-lang/rust/issues/30827> is resolved.
pub mod iter { pub mod iter {
#[allow(deprecated)]
pub use crate::pyclass::{IterNextOutput, PyIterNextOutput}; pub use crate::pyclass::{IterNextOutput, PyIterNextOutput};
} }

View File

@ -91,6 +91,7 @@ impl CompareOp {
/// Usage example: /// Usage example:
/// ///
/// ```rust /// ```rust
/// # #![allow(deprecated)]
/// use pyo3::prelude::*; /// use pyo3::prelude::*;
/// use pyo3::iter::IterNextOutput; /// use pyo3::iter::IterNextOutput;
/// ///
@ -122,6 +123,7 @@ impl CompareOp {
/// } /// }
/// } /// }
/// ``` /// ```
#[deprecated(since = "0.21.0", note = "Use `Option` or `PyStopIteration` instead.")]
pub enum IterNextOutput<T, U> { pub enum IterNextOutput<T, U> {
/// The value yielded by the iterator. /// The value yielded by the iterator.
Yield(T), Yield(T),
@ -130,38 +132,22 @@ pub enum IterNextOutput<T, U> {
} }
/// Alias of `IterNextOutput` with `PyObject` yield & return values. /// Alias of `IterNextOutput` with `PyObject` yield & return values.
#[deprecated(since = "0.21.0", note = "Use `Option` or `PyStopIteration` instead.")]
#[allow(deprecated)]
pub type PyIterNextOutput = IterNextOutput<PyObject, PyObject>; pub type PyIterNextOutput = IterNextOutput<PyObject, PyObject>;
impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterNextOutput { #[allow(deprecated)]
fn convert(self, _py: Python<'_>) -> PyResult<*mut ffi::PyObject> { impl<T, U> IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput<T, U>
match self {
IterNextOutput::Yield(o) => Ok(o.into_ptr()),
IterNextOutput::Return(opt) => Err(crate::exceptions::PyStopIteration::new_err((opt,))),
}
}
}
impl<T, U> IntoPyCallbackOutput<PyIterNextOutput> for IterNextOutput<T, U>
where where
T: IntoPy<PyObject>, T: IntoPy<PyObject>,
U: IntoPy<PyObject>, U: IntoPy<PyObject>,
{ {
fn convert(self, py: Python<'_>) -> PyResult<PyIterNextOutput> { fn convert(self, py: Python<'_>) -> PyResult<*mut ffi::PyObject> {
match self { match self {
IterNextOutput::Yield(o) => Ok(IterNextOutput::Yield(o.into_py(py))), IterNextOutput::Yield(o) => Ok(o.into_py(py).into_ptr()),
IterNextOutput::Return(o) => Ok(IterNextOutput::Return(o.into_py(py))), IterNextOutput::Return(o) => {
} Err(crate::exceptions::PyStopIteration::new_err(o.into_py(py)))
} }
}
impl<T> IntoPyCallbackOutput<PyIterNextOutput> for Option<T>
where
T: IntoPy<PyObject>,
{
fn convert(self, py: Python<'_>) -> PyResult<PyIterNextOutput> {
match self {
Some(o) => Ok(PyIterNextOutput::Yield(o.into_py(py))),
None => Ok(PyIterNextOutput::Return(py.None().into())),
} }
} }
} }

View File

@ -658,10 +658,10 @@ impl OnceFuture {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf slf
} }
fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<PyObject> { fn __next__<'py>(&'py mut self, py: Python<'py>) -> Option<&'py PyAny> {
if !slf.polled { if !self.polled {
slf.polled = true; self.polled = true;
Some(slf.future.clone()) Some(self.future.as_ref(py))
} else { } else {
None None
} }