diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c4cc464..d7c15ab3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add `PyByteArray::data`, `PyByteArray::as_bytes`, and `PyByteArray::as_bytes_mut`. [#967](https://github.com/PyO3/pyo3/pull/967) - Add `GILOnceCell` to use in situations where `lazy_static` or `once_cell` can deadlock. [#975](https://github.com/PyO3/pyo3/pull/975) - Add `Py::borrow`, `Py::borrow_mut`, `Py::try_borrow`, and `Py::try_borrow_mut` for accessing `#[pyclass]` values. [#976](https://github.com/PyO3/pyo3/pull/976) +- Add `IterNextOutput` and `IterANextOutput` for returning from `__next__` / `__anext__`. [#997](https://github.com/PyO3/pyo3/pull/997) ### Changed - Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934) diff --git a/examples/rustapi_module/setup.py b/examples/rustapi_module/setup.py index c90755f5..f1fe9002 100644 --- a/examples/rustapi_module/setup.py +++ b/examples/rustapi_module/setup.py @@ -99,6 +99,7 @@ setup( make_rust_extension("rustapi_module.othermod"), make_rust_extension("rustapi_module.subclassing"), make_rust_extension("rustapi_module.test_dict"), + make_rust_extension("rustapi_module.pyclass_iter"), ], install_requires=install_requires, tests_require=tests_require, diff --git a/examples/rustapi_module/src/lib.rs b/examples/rustapi_module/src/lib.rs index 588ffa72..ce63565a 100644 --- a/examples/rustapi_module/src/lib.rs +++ b/examples/rustapi_module/src/lib.rs @@ -3,4 +3,5 @@ pub mod datetime; pub mod dict_iter; pub mod objstore; pub mod othermod; +pub mod pyclass_iter; pub mod subclassing; diff --git a/examples/rustapi_module/src/pyclass_iter.rs b/examples/rustapi_module/src/pyclass_iter.rs new file mode 100644 index 00000000..bb09e260 --- /dev/null +++ b/examples/rustapi_module/src/pyclass_iter.rs @@ -0,0 +1,34 @@ +use pyo3::class::iter::{IterNextOutput, PyIterProtocol}; +use pyo3::prelude::*; + +/// This is for demonstrating how to return a value from __next__ +#[pyclass] +struct PyClassIter { + count: usize, +} + +#[pymethods] +impl PyClassIter { + #[new] + pub fn new() -> Self { + PyClassIter { count: 0 } + } +} + +#[pyproto] +impl PyIterProtocol for PyClassIter { + fn __next__(mut slf: PyRefMut) -> IterNextOutput { + if slf.count < 5 { + slf.count += 1; + IterNextOutput::Yield(slf.count) + } else { + IterNextOutput::Return("Ended") + } + } +} + +#[pymodule] +pub fn pyclass_iter(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} diff --git a/examples/rustapi_module/tests/test_pyclass_iter.py b/examples/rustapi_module/tests/test_pyclass_iter.py new file mode 100644 index 00000000..f69eab36 --- /dev/null +++ b/examples/rustapi_module/tests/test_pyclass_iter.py @@ -0,0 +1,15 @@ +import pytest +from rustapi_module import pyclass_iter + + +def test_iter(): + i = pyclass_iter.PyClassIter() + assert next(i) == 1 + assert next(i) == 2 + assert next(i) == 3 + assert next(i) == 4 + assert next(i) == 5 + + with pytest.raises(StopIteration) as excinfo: + next(i) + assert excinfo.value.value == "Ended" diff --git a/guide/src/class.md b/guide/src/class.md index bd61aef4..61f19f05 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -808,11 +808,10 @@ It includes two methods `__iter__` and `__next__`: * `fn __iter__(slf: PyRefMut) -> PyResult>` * `fn __next__(slf: PyRefMut) -> PyResult>>` -Returning `Ok(None)` from `__next__` indicates that that there are no further items. +Returning `None` from `__next__` indicates that that there are no further items. These two methods can be take either `PyRef` or `PyRefMut` as their first argument, so that mutable borrow can be avoided if needed. - Example: ```rust @@ -891,6 +890,14 @@ impl PyIterProtocol for Container { For more details on Python's iteration protocols, check out [the "Iterator Types" section of the library documentation](https://docs.python.org/3/library/stdtypes.html#iterator-types). +#### Returning a value from iteration + +This guide has so far shown how to use `Option` to implement yielding values during iteration. +In Python a generator can also return a value. To express this in Rust, PyO3 provides the +[`IterNextOutput`](https://docs.rs/pyo3/latest/pyo3/class/iter/enum.IterNextOutput.html) enum to +both `Yield` values and `Return` a final value - see its docs for further details and an example. + + ## How methods are implemented Users should be able to define a `#[pyclass]` with or without `#[pymethods]`, while PyO3 needs a diff --git a/src/class/iter.rs b/src/class/iter.rs index bb3e3a9d..ccf84044 100644 --- a/src/class/iter.rs +++ b/src/class/iter.rs @@ -11,6 +11,39 @@ use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, Python}; /// /// Check [CPython doc](https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_iter) /// for more. +/// +/// # Example +/// The following example shows how to implement a simple Python iterator in Rust which yields +/// the integers 1 to 5, before raising `StopIteration("Ended")`. +/// +/// ```rust +/// use pyo3::prelude::*; +/// use pyo3::PyIterProtocol; +/// use pyo3::class::iter::IterNextOutput; +/// +/// #[pyclass] +/// struct Iter { +/// count: usize +/// } +/// +/// #[pyproto] +/// impl PyIterProtocol for Iter { +/// fn __next__(mut slf: PyRefMut) -> IterNextOutput { +/// if slf.count < 5 { +/// slf.count += 1; +/// IterNextOutput::Yield(slf.count) +/// } else { +/// IterNextOutput::Return("Ended") +/// } +/// } +/// } +/// +/// # let gil = Python::acquire_gil(); +/// # let py = gil.python(); +/// # let inst = Py::new(py, Iter { count: 0 }).unwrap(); +/// # pyo3::py_run!(py, inst, "assert next(inst) == 1"); +/// # // test of StopIteration is done in examples/rustapi_module/pyclass_iter.rs +/// ``` #[allow(unused_variables)] pub trait PyIterProtocol<'p>: PyClass { fn __iter__(slf: Self::Receiver) -> Self::Result @@ -35,7 +68,7 @@ pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> { pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> { type Receiver: TryFromPyCell<'p, Self>; - type Result: IntoPyCallbackOutput; + type Result: IntoPyCallbackOutput; } #[derive(Default)] @@ -64,22 +97,47 @@ impl PyIterMethods { } } -pub struct IterNextOutput(Option); +/// Output of `__next__` which can either `yield` the next value in the iteration, or +/// `return` a value to raise `StopIteration` in Python. +/// +/// See [`PyIterProtocol`](trait.PyIterProtocol.html) for an example. +pub enum IterNextOutput { + Yield(T), + Return(U), +} -impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput { +pub type PyIterNextOutput = IterNextOutput; + +impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterNextOutput { fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> { - match self.0 { - Some(o) => Ok(o.into_ptr()), - None => Err(crate::exceptions::StopIteration::py_err(())), + match self { + IterNextOutput::Yield(o) => Ok(o.into_ptr()), + IterNextOutput::Return(opt) => Err(crate::exceptions::StopIteration::py_err((opt,))), } } } -impl IntoPyCallbackOutput for Option +impl IntoPyCallbackOutput for IterNextOutput +where + T: IntoPy, + U: IntoPy, +{ + fn convert(self, py: Python) -> PyResult { + match self { + IterNextOutput::Yield(o) => Ok(IterNextOutput::Yield(o.into_py(py))), + IterNextOutput::Return(o) => Ok(IterNextOutput::Return(o.into_py(py))), + } + } +} + +impl IntoPyCallbackOutput for Option where T: IntoPy, { - fn convert(self, py: Python) -> PyResult { - Ok(IterNextOutput(self.map(|o| o.into_py(py)))) + fn convert(self, py: Python) -> PyResult { + match self { + Some(o) => Ok(PyIterNextOutput::Yield(o.into_py(py))), + None => Ok(PyIterNextOutput::Return(py.None())), + } } } diff --git a/src/class/pyasync.rs b/src/class/pyasync.rs index 83df1410..7986adf6 100644 --- a/src/class/pyasync.rs +++ b/src/class/pyasync.rs @@ -71,7 +71,7 @@ pub trait PyAsyncAiterProtocol<'p>: PyAsyncProtocol<'p> { pub trait PyAsyncAnextProtocol<'p>: PyAsyncProtocol<'p> { type Receiver: TryFromPyCell<'p, Self>; - type Result: IntoPyCallbackOutput; + type Result: IntoPyCallbackOutput; } pub trait PyAsyncAenterProtocol<'p>: PyAsyncProtocol<'p> { @@ -107,23 +107,46 @@ impl ffi::PyAsyncMethods { } } -pub struct IterANextOutput(Option); +pub enum IterANextOutput { + Yield(T), + Return(U), +} -impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterANextOutput { +pub type PyIterANextOutput = IterANextOutput; + +impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterANextOutput { fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> { - match self.0 { - Some(o) => Ok(o.into_ptr()), - None => Err(crate::exceptions::StopAsyncIteration::py_err(())), + match self { + IterANextOutput::Yield(o) => Ok(o.into_ptr()), + IterANextOutput::Return(opt) => { + Err(crate::exceptions::StopAsyncIteration::py_err((opt,))) + } } } } -impl IntoPyCallbackOutput for Option +impl IntoPyCallbackOutput for IterANextOutput +where + T: IntoPy, + U: IntoPy, +{ + fn convert(self, py: Python) -> PyResult { + match self { + IterANextOutput::Yield(o) => Ok(IterANextOutput::Yield(o.into_py(py))), + IterANextOutput::Return(o) => Ok(IterANextOutput::Return(o.into_py(py))), + } + } +} + +impl IntoPyCallbackOutput for Option where T: IntoPy, { - fn convert(self, py: Python) -> PyResult { - Ok(IterANextOutput(self.map(|o| o.into_py(py)))) + fn convert(self, py: Python) -> PyResult { + match self { + Some(o) => Ok(PyIterANextOutput::Yield(o.into_py(py))), + None => Ok(PyIterANextOutput::Return(py.None())), + } } }