Replace IterNextOutput by autoref-based specialization to allow returning arbitrary values

This commit is contained in:
Adam Reichold 2023-12-18 16:26:33 +01:00
parent d75d4bdf81
commit ca7d90dcf3
8 changed files with 141 additions and 69 deletions

View File

@ -792,9 +792,11 @@ pub const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc"
const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc")
.arguments(&[Ty::MaybeNullObject, Ty::MaybeNullObject]);
const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc");
const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc").return_conversion(
TokenGenerator(|| quote! { _pyo3::class::iter::IterNextOutput::<_, _> }),
);
const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc")
.return_specialized_conversion(
TokenGenerator(|| quote! { IterBaseKind, IterOptionKind, IterResultOptionKind }),
TokenGenerator(|| quote! { iter_tag }),
);
const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc");
const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc");
const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion(
@ -1003,17 +1005,23 @@ fn extract_object(
enum ReturnMode {
ReturnSelf,
Conversion(TokenGenerator),
SpecializedConversion(TokenGenerator, TokenGenerator),
}
impl ReturnMode {
fn return_call_output(&self, call: TokenStream) -> TokenStream {
match self {
ReturnMode::Conversion(conversion) => quote! {
let _result: _pyo3::PyResult<#conversion> = #call;
let _result: _pyo3::PyResult<#conversion> = _pyo3::callback::convert(py, #call);
_pyo3::callback::convert(py, _result)
},
ReturnMode::SpecializedConversion(traits, tag) => quote! {
let _result = #call;
use _pyo3::impl_::pymethods::{#traits};
(&_result).#tag().convert(py, _result)
},
ReturnMode::ReturnSelf => quote! {
let _result: _pyo3::PyResult<()> = #call;
let _result: _pyo3::PyResult<()> = _pyo3::callback::convert(py, #call);
_result?;
_pyo3::ffi::Py_XINCREF(_raw_slf);
::std::result::Result::Ok(_raw_slf)
@ -1062,6 +1070,15 @@ impl SlotDef {
self
}
const fn return_specialized_conversion(
mut self,
traits: TokenGenerator,
tag: TokenGenerator,
) -> Self {
self.return_mode = Some(ReturnMode::SpecializedConversion(traits, tag));
self
}
const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self {
self.extract_error_mode = extract_error_mode;
self
@ -1162,11 +1179,11 @@ fn generate_method_body(
let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode, holders);
let rust_name = spec.name;
let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders)?;
let call = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) };
let call = quote! { #cls::#rust_name(#self_arg #(#args),*) };
Ok(if let Some(return_mode) = return_mode {
return_mode.return_call_output(call)
} else {
call
quote! { _pyo3::callback::convert(py, #call) }
})
}

View File

@ -5,7 +5,8 @@
//! when awaited, see guide examples related to pyo3-asyncio for ways
//! to suspend tasks and await results.
use pyo3::{prelude::*, pyclass::IterNextOutput};
use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;
#[pyclass]
#[derive(Debug)]
@ -30,13 +31,13 @@ impl IterAwaitable {
pyself
}
fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
match self.result.take() {
Some(res) => match res {
Ok(v) => Ok(IterNextOutput::Return(v)),
Ok(v) => Err(PyStopIteration::new_err(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(py.None().into())),
_ => Ok(py.None().into()),
}
}
}
@ -66,15 +67,13 @@ impl FutureAwaitable {
pyself
}
fn __next__(
mut pyself: PyRefMut<'_, Self>,
) -> PyResult<IterNextOutput<PyRefMut<'_, Self>, PyObject>> {
fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
match pyself.result {
Some(_) => match pyself.result.take().unwrap() {
Ok(v) => Ok(IterNextOutput::Return(v)),
Ok(v) => Err(PyStopIteration::new_err(v)),
Err(err) => Err(err),
},
_ => Ok(IterNextOutput::Yield(pyself)),
_ => Ok(pyself),
}
}
}

View File

@ -1,5 +1,4 @@
use pyo3::exceptions::PyValueError;
use pyo3::iter::IterNextOutput;
use pyo3::exceptions::{PyStopIteration, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyType;
@ -28,12 +27,12 @@ impl PyClassIter {
Default::default()
}
fn __next__(&mut self) -> IterNextOutput<usize, &'static str> {
fn __next__(&mut self) -> PyResult<usize> {
if self.count < 5 {
self.count += 1;
IterNextOutput::Yield(self.count)
Ok(self.count)
} else {
IterNextOutput::Return("Ended")
Err(PyStopIteration::new_err("Ended"))
}
}
}

View File

@ -14,7 +14,6 @@ use crate::{
coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
panic::PanicException,
pyclass::IterNextOutput,
types::{PyIterator, PyString},
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};
@ -68,11 +67,7 @@ impl Coroutine {
}
}
fn poll(
&mut self,
py: Python<'_>,
throw: Option<PyObject>,
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
fn poll(&mut self, py: Python<'_>, throw: Option<PyObject>) -> PyResult<PyObject> {
// raise if the coroutine has already been run to completion
let future_rs = match self.future {
Some(ref mut fut) => fut,
@ -100,7 +95,7 @@ impl Coroutine {
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
return Ok(IterNextOutput::Return(res?));
return Err(PyStopIteration::new_err(res?));
}
Err(err) => {
self.close();
@ -115,19 +110,12 @@ impl Coroutine {
if let Some(future) = PyIterator::from_object(future).unwrap().next() {
// future has not been leaked into Python for now, and Rust code can only call
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap
return Ok(IterNextOutput::Yield(future.unwrap().into()));
return Ok(future.unwrap().into());
}
}
// if waker has been waken during future polling, this is roughly equivalent to
// `await asyncio.sleep(0)`, so just yield `None`.
Ok(IterNextOutput::Yield(py.None().into()))
}
}
pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResult<PyObject> {
match result {
IterNextOutput::Yield(ob) => Ok(ob),
IterNextOutput::Return(ob) => Err(PyStopIteration::new_err(ob)),
Ok(py.None().into())
}
}
@ -153,11 +141,11 @@ impl Coroutine {
}
fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> {
iter_result(self.poll(py, None)?)
self.poll(py, None)
}
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
iter_result(self.poll(py, Some(exc))?)
self.poll(py, Some(exc))
}
fn close(&mut self) {
@ -170,7 +158,7 @@ impl Coroutine {
self_
}
fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
self.poll(py, None)
}
}

View File

@ -1,3 +1,4 @@
use crate::callback::IntoPyCallbackOutput;
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string;
@ -7,6 +8,7 @@ use std::ffi::CStr;
use std::fmt;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr::null_mut;
/// Python 3.8 and up - __ipow__ has modulo argument correctly populated.
#[cfg(Py_3_8)]
@ -299,3 +301,83 @@ pub(crate) fn get_name(name: &'static str) -> PyResult<Cow<'static, CStr>> {
pub(crate) fn get_doc(doc: &'static str) -> PyResult<Cow<'static, CStr>> {
extract_c_string(doc, "function doc cannot contain NUL byte.")
}
// Autoref-based specialization for handling `__next__` returning `Option`
pub struct IterBaseTag;
impl IterBaseTag {
#[inline]
pub fn convert<Value, Target>(self, py: Python<'_>, value: Value) -> PyResult<Target>
where
Value: IntoPyCallbackOutput<Target>,
{
value.convert(py)
}
}
pub trait IterBaseKind {
#[inline]
fn iter_tag(&self) -> IterBaseTag {
IterBaseTag
}
}
impl<Value> IterBaseKind for &Value {}
pub struct IterOptionTag;
impl IterOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: Option<Value>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Some(value) => value.convert(py),
None => Ok(null_mut()),
}
}
}
pub trait IterOptionKind {
#[inline]
fn iter_tag(&self) -> IterOptionTag {
IterOptionTag
}
}
impl<Value> IterOptionKind for Option<Value> {}
pub struct IterResultOptionTag;
impl IterResultOptionTag {
#[inline]
pub fn convert<Value>(
self,
py: Python<'_>,
value: PyResult<Option<Value>>,
) -> PyResult<*mut ffi::PyObject>
where
Value: IntoPyCallbackOutput<*mut ffi::PyObject>,
{
match value {
Ok(Some(value)) => value.convert(py),
Ok(None) => Ok(null_mut()),
Err(err) => Err(err),
}
}
}
pub trait IterResultOptionKind {
#[inline]
fn iter_tag(&self) -> IterResultOptionTag {
IterResultOptionTag
}
}
impl<Value> IterResultOptionKind for PyResult<Option<Value>> {}

View File

@ -362,6 +362,7 @@ pub mod class {
/// For compatibility reasons this has not yet been removed, however will be done so
/// once <https://github.com/rust-lang/rust/issues/30827> is resolved.
pub mod iter {
#[allow(deprecated)]
pub use crate::pyclass::{IterNextOutput, PyIterNextOutput};
}

View File

@ -91,6 +91,7 @@ impl CompareOp {
/// Usage example:
///
/// ```rust
/// # #![allow(deprecated)]
/// use pyo3::prelude::*;
/// use pyo3::iter::IterNextOutput;
///
@ -122,6 +123,7 @@ impl CompareOp {
/// }
/// }
/// ```
#[deprecated(since = "0.21.0", note = "Use `Option` or `PyStopIteration` instead.")]
pub enum IterNextOutput<T, U> {
/// The value yielded by the iterator.
Yield(T),
@ -130,39 +132,23 @@ pub enum IterNextOutput<T, U> {
}
/// Alias of `IterNextOutput` with `PyObject` yield & return values.
#[deprecated(since = "0.21.0", note = "Use `Option` or `PyStopIteration` instead.")]
#[allow(deprecated)]
pub type PyIterNextOutput = IterNextOutput<PyObject, PyObject>;
impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterNextOutput {
fn convert(self, _py: Python<'_>) -> PyResult<*mut ffi::PyObject> {
match self {
IterNextOutput::Yield(o) => Ok(o.into_ptr()),
IterNextOutput::Return(opt) => Err(crate::exceptions::PyStopIteration::new_err((opt,))),
}
}
}
impl<T, U> IntoPyCallbackOutput<PyIterNextOutput> for IterNextOutput<T, U>
#[allow(deprecated)]
impl<T, U> IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput<T, U>
where
T: IntoPy<PyObject>,
U: IntoPy<PyObject>,
{
fn convert(self, py: Python<'_>) -> PyResult<PyIterNextOutput> {
fn convert(self, py: Python<'_>) -> PyResult<*mut ffi::PyObject> {
match self {
IterNextOutput::Yield(o) => Ok(IterNextOutput::Yield(o.into_py(py))),
IterNextOutput::Return(o) => Ok(IterNextOutput::Return(o.into_py(py))),
IterNextOutput::Yield(o) => Ok(o.into_py(py).into_ptr()),
IterNextOutput::Return(o) => {
Err(crate::exceptions::PyStopIteration::new_err(o.into_py(py)))
}
}
}
impl<T> IntoPyCallbackOutput<PyIterNextOutput> for Option<T>
where
T: IntoPy<PyObject>,
{
fn convert(self, py: Python<'_>) -> PyResult<PyIterNextOutput> {
match self {
Some(o) => Ok(PyIterNextOutput::Yield(o.into_py(py))),
None => Ok(PyIterNextOutput::Return(py.None().into())),
}
}
}

View File

@ -658,10 +658,10 @@ impl OnceFuture {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<PyObject> {
if !slf.polled {
slf.polled = true;
Some(slf.future.clone())
fn __next__<'py>(&'py mut self, py: Python<'py>) -> Option<&'py PyAny> {
if !self.polled {
self.polled = true;
Some(self.future.as_ref(py))
} else {
None
}