From ca7d90dcf3b764c133023bc8558338bd14b16835 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Mon, 18 Dec 2023 16:26:33 +0100 Subject: [PATCH] Replace IterNextOutput by autoref-based specialization to allow returning arbitrary values --- pyo3-macros-backend/src/pymethod.rs | 31 ++++++++--- pytests/src/awaitable.rs | 17 +++--- pytests/src/pyclasses.rs | 9 ++-- src/coroutine.rs | 26 +++------ src/impl_/pymethods.rs | 82 +++++++++++++++++++++++++++++ src/lib.rs | 1 + src/pyclass.rs | 36 ++++--------- tests/test_proto_methods.rs | 8 +-- 8 files changed, 141 insertions(+), 69 deletions(-) diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 239bd96a..2a730fd2 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -792,9 +792,11 @@ pub const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc" const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc") .arguments(&[Ty::MaybeNullObject, Ty::MaybeNullObject]); const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc"); -const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc").return_conversion( - TokenGenerator(|| quote! { _pyo3::class::iter::IterNextOutput::<_, _> }), -); +const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc") + .return_specialized_conversion( + TokenGenerator(|| quote! { IterBaseKind, IterOptionKind, IterResultOptionKind }), + TokenGenerator(|| quote! { iter_tag }), + ); const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc"); const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc"); const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion( @@ -1003,17 +1005,23 @@ fn extract_object( enum ReturnMode { ReturnSelf, Conversion(TokenGenerator), + SpecializedConversion(TokenGenerator, TokenGenerator), } impl ReturnMode { fn return_call_output(&self, call: TokenStream) -> TokenStream { match self { ReturnMode::Conversion(conversion) => quote! { - let _result: _pyo3::PyResult<#conversion> = #call; + let _result: _pyo3::PyResult<#conversion> = _pyo3::callback::convert(py, #call); _pyo3::callback::convert(py, _result) }, + ReturnMode::SpecializedConversion(traits, tag) => quote! { + let _result = #call; + use _pyo3::impl_::pymethods::{#traits}; + (&_result).#tag().convert(py, _result) + }, ReturnMode::ReturnSelf => quote! { - let _result: _pyo3::PyResult<()> = #call; + let _result: _pyo3::PyResult<()> = _pyo3::callback::convert(py, #call); _result?; _pyo3::ffi::Py_XINCREF(_raw_slf); ::std::result::Result::Ok(_raw_slf) @@ -1062,6 +1070,15 @@ impl SlotDef { self } + const fn return_specialized_conversion( + mut self, + traits: TokenGenerator, + tag: TokenGenerator, + ) -> Self { + self.return_mode = Some(ReturnMode::SpecializedConversion(traits, tag)); + self + } + const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { self.extract_error_mode = extract_error_mode; self @@ -1162,11 +1179,11 @@ fn generate_method_body( let self_arg = spec.tp.self_arg(Some(cls), extract_error_mode, holders); let rust_name = spec.name; let args = extract_proto_arguments(spec, arguments, extract_error_mode, holders)?; - let call = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) }; + let call = quote! { #cls::#rust_name(#self_arg #(#args),*) }; Ok(if let Some(return_mode) = return_mode { return_mode.return_call_output(call) } else { - call + quote! { _pyo3::callback::convert(py, #call) } }) } diff --git a/pytests/src/awaitable.rs b/pytests/src/awaitable.rs index 1f798aa4..0cc17333 100644 --- a/pytests/src/awaitable.rs +++ b/pytests/src/awaitable.rs @@ -5,7 +5,8 @@ //! when awaited, see guide examples related to pyo3-asyncio for ways //! to suspend tasks and await results. -use pyo3::{prelude::*, pyclass::IterNextOutput}; +use pyo3::exceptions::PyStopIteration; +use pyo3::prelude::*; #[pyclass] #[derive(Debug)] @@ -30,13 +31,13 @@ impl IterAwaitable { pyself } - fn __next__(&mut self, py: Python<'_>) -> PyResult> { + fn __next__(&mut self, py: Python<'_>) -> PyResult { match self.result.take() { Some(res) => match res { - Ok(v) => Ok(IterNextOutput::Return(v)), + Ok(v) => Err(PyStopIteration::new_err(v)), Err(err) => Err(err), }, - _ => Ok(IterNextOutput::Yield(py.None().into())), + _ => Ok(py.None().into()), } } } @@ -66,15 +67,13 @@ impl FutureAwaitable { pyself } - fn __next__( - mut pyself: PyRefMut<'_, Self>, - ) -> PyResult, PyObject>> { + fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult> { match pyself.result { Some(_) => match pyself.result.take().unwrap() { - Ok(v) => Ok(IterNextOutput::Return(v)), + Ok(v) => Err(PyStopIteration::new_err(v)), Err(err) => Err(err), }, - _ => Ok(IterNextOutput::Yield(pyself)), + _ => Ok(pyself), } } } diff --git a/pytests/src/pyclasses.rs b/pytests/src/pyclasses.rs index 46c8523c..326893d1 100644 --- a/pytests/src/pyclasses.rs +++ b/pytests/src/pyclasses.rs @@ -1,5 +1,4 @@ -use pyo3::exceptions::PyValueError; -use pyo3::iter::IterNextOutput; +use pyo3::exceptions::{PyStopIteration, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyType; @@ -28,12 +27,12 @@ impl PyClassIter { Default::default() } - fn __next__(&mut self) -> IterNextOutput { + fn __next__(&mut self) -> PyResult { if self.count < 5 { self.count += 1; - IterNextOutput::Yield(self.count) + Ok(self.count) } else { - IterNextOutput::Return("Ended") + Err(PyStopIteration::new_err("Ended")) } } } diff --git a/src/coroutine.rs b/src/coroutine.rs index 6380b4e0..7dd73cbb 100644 --- a/src/coroutine.rs +++ b/src/coroutine.rs @@ -14,7 +14,6 @@ use crate::{ coroutine::{cancel::ThrowCallback, waker::AsyncioWaker}, exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration}, panic::PanicException, - pyclass::IterNextOutput, types::{PyIterator, PyString}, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python, }; @@ -68,11 +67,7 @@ impl Coroutine { } } - fn poll( - &mut self, - py: Python<'_>, - throw: Option, - ) -> PyResult> { + fn poll(&mut self, py: Python<'_>, throw: Option) -> PyResult { // raise if the coroutine has already been run to completion let future_rs = match self.future { Some(ref mut fut) => fut, @@ -100,7 +95,7 @@ impl Coroutine { match panic::catch_unwind(panic::AssertUnwindSafe(poll)) { Ok(Poll::Ready(res)) => { self.close(); - return Ok(IterNextOutput::Return(res?)); + return Err(PyStopIteration::new_err(res?)); } Err(err) => { self.close(); @@ -115,19 +110,12 @@ impl Coroutine { if let Some(future) = PyIterator::from_object(future).unwrap().next() { // future has not been leaked into Python for now, and Rust code can only call // `set_result(None)` in `Wake` implementation, so it's safe to unwrap - return Ok(IterNextOutput::Yield(future.unwrap().into())); + return Ok(future.unwrap().into()); } } // if waker has been waken during future polling, this is roughly equivalent to // `await asyncio.sleep(0)`, so just yield `None`. - Ok(IterNextOutput::Yield(py.None().into())) - } -} - -pub(crate) fn iter_result(result: IterNextOutput) -> PyResult { - match result { - IterNextOutput::Yield(ob) => Ok(ob), - IterNextOutput::Return(ob) => Err(PyStopIteration::new_err(ob)), + Ok(py.None().into()) } } @@ -153,11 +141,11 @@ impl Coroutine { } fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult { - iter_result(self.poll(py, None)?) + self.poll(py, None) } fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult { - iter_result(self.poll(py, Some(exc))?) + self.poll(py, Some(exc)) } fn close(&mut self) { @@ -170,7 +158,7 @@ impl Coroutine { self_ } - fn __next__(&mut self, py: Python<'_>) -> PyResult> { + fn __next__(&mut self, py: Python<'_>) -> PyResult { self.poll(py, None) } } diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index ff2857ec..3cc77341 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -1,3 +1,4 @@ +use crate::callback::IntoPyCallbackOutput; use crate::gil::LockGIL; use crate::impl_::panic::PanicTrap; use crate::internal_tricks::extract_c_string; @@ -7,6 +8,7 @@ use std::ffi::CStr; use std::fmt; use std::os::raw::{c_int, c_void}; use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::ptr::null_mut; /// Python 3.8 and up - __ipow__ has modulo argument correctly populated. #[cfg(Py_3_8)] @@ -299,3 +301,83 @@ pub(crate) fn get_name(name: &'static str) -> PyResult> { pub(crate) fn get_doc(doc: &'static str) -> PyResult> { extract_c_string(doc, "function doc cannot contain NUL byte.") } + +// Autoref-based specialization for handling `__next__` returning `Option` + +pub struct IterBaseTag; + +impl IterBaseTag { + #[inline] + pub fn convert(self, py: Python<'_>, value: Value) -> PyResult + where + Value: IntoPyCallbackOutput, + { + value.convert(py) + } +} + +pub trait IterBaseKind { + #[inline] + fn iter_tag(&self) -> IterBaseTag { + IterBaseTag + } +} + +impl IterBaseKind for &Value {} + +pub struct IterOptionTag; + +impl IterOptionTag { + #[inline] + pub fn convert( + self, + py: Python<'_>, + value: Option, + ) -> PyResult<*mut ffi::PyObject> + where + Value: IntoPyCallbackOutput<*mut ffi::PyObject>, + { + match value { + Some(value) => value.convert(py), + None => Ok(null_mut()), + } + } +} + +pub trait IterOptionKind { + #[inline] + fn iter_tag(&self) -> IterOptionTag { + IterOptionTag + } +} + +impl IterOptionKind for Option {} + +pub struct IterResultOptionTag; + +impl IterResultOptionTag { + #[inline] + pub fn convert( + self, + py: Python<'_>, + value: PyResult>, + ) -> PyResult<*mut ffi::PyObject> + where + Value: IntoPyCallbackOutput<*mut ffi::PyObject>, + { + match value { + Ok(Some(value)) => value.convert(py), + Ok(None) => Ok(null_mut()), + Err(err) => Err(err), + } + } +} + +pub trait IterResultOptionKind { + #[inline] + fn iter_tag(&self) -> IterResultOptionTag { + IterResultOptionTag + } +} + +impl IterResultOptionKind for PyResult> {} diff --git a/src/lib.rs b/src/lib.rs index 985ec0aa..f787b23a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -362,6 +362,7 @@ pub mod class { /// For compatibility reasons this has not yet been removed, however will be done so /// once is resolved. pub mod iter { + #[allow(deprecated)] pub use crate::pyclass::{IterNextOutput, PyIterNextOutput}; } diff --git a/src/pyclass.rs b/src/pyclass.rs index 23affb4b..b1c49be2 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -91,6 +91,7 @@ impl CompareOp { /// Usage example: /// /// ```rust +/// # #![allow(deprecated)] /// use pyo3::prelude::*; /// use pyo3::iter::IterNextOutput; /// @@ -122,6 +123,7 @@ impl CompareOp { /// } /// } /// ``` +#[deprecated(since = "0.21.0", note = "Use `Option` or `PyStopIteration` instead.")] pub enum IterNextOutput { /// The value yielded by the iterator. Yield(T), @@ -130,38 +132,22 @@ pub enum IterNextOutput { } /// Alias of `IterNextOutput` with `PyObject` yield & return values. +#[deprecated(since = "0.21.0", note = "Use `Option` or `PyStopIteration` instead.")] +#[allow(deprecated)] pub type PyIterNextOutput = IterNextOutput; -impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterNextOutput { - fn convert(self, _py: Python<'_>) -> PyResult<*mut ffi::PyObject> { - match self { - IterNextOutput::Yield(o) => Ok(o.into_ptr()), - IterNextOutput::Return(opt) => Err(crate::exceptions::PyStopIteration::new_err((opt,))), - } - } -} - -impl IntoPyCallbackOutput for IterNextOutput +#[allow(deprecated)] +impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput where T: IntoPy, U: IntoPy, { - fn convert(self, py: Python<'_>) -> PyResult { + fn convert(self, py: Python<'_>) -> PyResult<*mut ffi::PyObject> { match self { - IterNextOutput::Yield(o) => Ok(IterNextOutput::Yield(o.into_py(py))), - IterNextOutput::Return(o) => Ok(IterNextOutput::Return(o.into_py(py))), - } - } -} - -impl IntoPyCallbackOutput for Option -where - T: IntoPy, -{ - fn convert(self, py: Python<'_>) -> PyResult { - match self { - Some(o) => Ok(PyIterNextOutput::Yield(o.into_py(py))), - None => Ok(PyIterNextOutput::Return(py.None().into())), + IterNextOutput::Yield(o) => Ok(o.into_py(py).into_ptr()), + IterNextOutput::Return(o) => { + Err(crate::exceptions::PyStopIteration::new_err(o.into_py(py))) + } } } } diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index b3503451..caaae751 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -658,10 +658,10 @@ impl OnceFuture { fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } - fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { - if !slf.polled { - slf.polled = true; - Some(slf.future.clone()) + fn __next__<'py>(&'py mut self, py: Python<'py>) -> Option<&'py PyAny> { + if !self.polled { + self.polled = true; + Some(self.future.as_ref(py)) } else { None }