From f4953224d8ee71e4ee02570e8627306a607ba1af Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 25 Jan 2023 22:15:43 +0000 Subject: [PATCH] correct ffi definition of PyIter_Check --- newsfragments/2914.fixed.md | 1 + pyo3-ffi/src/abstract_.rs | 11 ++++-- src/types/iterator.rs | 74 +++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 newsfragments/2914.fixed.md diff --git a/newsfragments/2914.fixed.md b/newsfragments/2914.fixed.md new file mode 100644 index 00000000..9ea78ee8 --- /dev/null +++ b/newsfragments/2914.fixed.md @@ -0,0 +1 @@ +Fix downcast to `PyIterator` succeeding for Python classes which did not implement `__next__`. diff --git a/pyo3-ffi/src/abstract_.rs b/pyo3-ffi/src/abstract_.rs index b8306ca9..6f954bc1 100644 --- a/pyo3-ffi/src/abstract_.rs +++ b/pyo3-ffi/src/abstract_.rs @@ -91,9 +91,12 @@ extern "C" { pub fn PyObject_GetIter(arg1: *mut PyObject) -> *mut PyObject; } -// Defined as this macro in Python limited API, but relies on -// non-limited PyTypeObject. Don't expose this since it cannot be used. -#[cfg(not(any(Py_LIMITED_API, PyPy)))] +// Before 3.8 PyIter_Check was defined in CPython as a macro, +// which uses Py_TYPE so cannot work on the limited ABI. +// +// From 3.10 onwards CPython removed the macro completely, +// so PyO3 only uses this on 3.7 unlimited API. +#[cfg(not(any(Py_3_8, Py_LIMITED_API, PyPy)))] #[inline] pub unsafe fn PyIter_Check(o: *mut PyObject) -> c_int { (match (*crate::Py_TYPE(o)).tp_iternext { @@ -105,7 +108,7 @@ pub unsafe fn PyIter_Check(o: *mut PyObject) -> c_int { } extern "C" { - #[cfg(any(all(Py_3_8, Py_LIMITED_API), PyPy))] + #[cfg(any(Py_3_8, PyPy))] #[cfg_attr(PyPy, link_name = "PyPyIter_Check")] pub fn PyIter_Check(obj: *mut PyObject) -> c_int; diff --git a/src/types/iterator.rs b/src/types/iterator.rs index 5b51c067..9d68ca4f 100644 --- a/src/types/iterator.rs +++ b/src/types/iterator.rs @@ -248,4 +248,78 @@ def fibonacci(target): assert_eq!(iter_ref.get_refcnt(), 2); }) } + + #[test] + #[cfg(any(not(Py_LIMITED_API), Py_3_8))] + #[cfg(feature = "macros")] + fn python_class_not_iterator() { + use crate::PyErr; + + #[crate::pyclass(crate = "crate")] + struct Downcaster { + failed: Option, + } + + #[crate::pymethods(crate = "crate")] + impl Downcaster { + fn downcast_iterator(&mut self, obj: &PyAny) { + self.failed = Some(obj.downcast::().unwrap_err().into()); + } + } + + // Regression test for 2913 + Python::with_gil(|py| { + let downcaster = Py::new(py, Downcaster { failed: None }).unwrap(); + crate::py_run!( + py, + downcaster, + r#" + from collections.abc import Sequence + + class MySequence(Sequence): + def __init__(self): + self._data = [1, 2, 3] + + def __getitem__(self, index): + return self._data[index] + + def __len__(self): + return len(self._data) + + downcaster.downcast_iterator(MySequence()) + "# + ); + + assert_eq!( + downcaster.borrow_mut(py).failed.take().unwrap().to_string(), + "TypeError: 'MySequence' object cannot be converted to 'Iterator'" + ); + }); + } + + #[test] + #[cfg(any(not(Py_LIMITED_API), Py_3_8))] + #[cfg(feature = "macros")] + fn python_class_iterator() { + #[crate::pyfunction(crate = "crate")] + fn assert_iterator(obj: &PyAny) { + assert!(obj.downcast::().is_ok()) + } + + // Regression test for 2913 + Python::with_gil(|py| { + let assert_iterator = crate::wrap_pyfunction!(assert_iterator, py).unwrap(); + crate::py_run!( + py, + assert_iterator, + r#" + class MyIter: + def __next__(self): + raise StopIteration + + assert_iterator(MyIter()) + "# + ); + }); + } }