From fcffcdfae5eba6c21a65c845a0212eefeec7713d Mon Sep 17 00:00:00 2001 From: Ohad Ravid Date: Sun, 27 Jun 2021 09:45:15 +0300 Subject: [PATCH] Added `size_hint` impls for `{PyDict,PyList,PySet,PyTuple}Iterator`s --- CHANGELOG.md | 1 + src/types/dict.rs | 31 +++++++++++++++++++++++++++++++ src/types/list.rs | 29 +++++++++++++++++++++++++++++ src/types/set.rs | 27 +++++++++++++++++++++++++++ src/types/tuple.rs | 16 ++++++++++++++++ 5 files changed, 104 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56da0781..3683f8da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -115,6 +115,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add `serde` feature which provides implementations of `Serialize` and `Deserialize` for `Py`. [#1366](https://github.com/PyO3/pyo3/pull/1366) - Add FFI definition `_PyCFunctionFastWithKeywords` on Python 3.7 and up. [#1384](https://github.com/PyO3/pyo3/pull/1384) - Add `PyDateTime::new_with_fold()` method. [#1398](https://github.com/PyO3/pyo3/pull/1398) +- Add `size_hint` impls for `{PyDict,PyList,PySet,PyTuple}Iterator`s. [#1699](https://github.com/PyO3/pyo3/pull/1699) ### Changed diff --git a/src/types/dict.rs b/src/types/dict.rs index e58afccb..40cb1917 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -204,6 +204,15 @@ impl<'py> Iterator for PyDictIterator<'py> { } } } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.dict.len().unwrap_or_default(); + ( + len.saturating_sub(self.pos as usize), + Some(len.saturating_sub(self.pos as usize)), + ) + } } impl<'a> std::iter::IntoIterator for &'a PyDict { @@ -698,6 +707,28 @@ mod test { assert_eq!(32 + 42 + 123, value_sum); } + #[test] + fn test_iter_size_hint() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let dict = ::try_from(ob.as_ref(py)).unwrap(); + + let mut iter = dict.iter(); + assert_eq!(iter.size_hint(), (v.len(), Some(v.len()))); + iter.next(); + assert_eq!(iter.size_hint(), (v.len() - 1, Some(v.len() - 1))); + + // Exhust iterator. + while let Some(_) = iter.next() {} + + assert_eq!(iter.size_hint(), (0, Some(0))); + } + #[test] fn test_into_iter() { let gil = Python::acquire_gil(); diff --git a/src/types/list.rs b/src/types/list.rs index 5d981c1a..a078dd0b 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -158,6 +158,16 @@ impl<'a> Iterator for PyListIterator<'a> { None } } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.list.len(); + + ( + len.saturating_sub(self.index as usize), + Some(len.saturating_sub(self.index as usize)), + ) + } } impl<'a> std::iter::IntoIterator for &'a PyList { @@ -342,6 +352,25 @@ mod test { assert_eq!(idx, v.len()); } + #[test] + fn test_iter_size_hint() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let v = vec![2, 3, 5, 7]; + let ob = v.to_object(py); + let list = ::try_from(ob.as_ref(py)).unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.size_hint(), (v.len(), Some(v.len()))); + iter.next(); + assert_eq!(iter.size_hint(), (v.len() - 1, Some(v.len() - 1))); + + // Exhust iterator. + while let Some(_) = iter.next() {} + + assert_eq!(iter.size_hint(), (0, Some(0))); + } + #[test] fn test_into_iter() { let gil = Python::acquire_gil(); diff --git a/src/types/set.rs b/src/types/set.rs index dc524641..e96672cd 100644 --- a/src/types/set.rs +++ b/src/types/set.rs @@ -172,6 +172,15 @@ impl<'py> Iterator for PySetIterator<'py> { } } } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.set.len().unwrap_or_default(); + ( + len.saturating_sub(self.pos as usize), + Some(len.saturating_sub(self.pos as usize)), + ) + } } impl<'a> std::iter::IntoIterator for &'a PySet { @@ -505,6 +514,24 @@ mod test { } } + #[test] + fn test_set_iter_size_hint() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let set = PySet::new(py, &[1]).unwrap(); + + let mut iter = set.iter(); + + if cfg!(Py_LIMITED_API) { + assert_eq!(iter.size_hint(), (0, None)); + } else { + assert_eq!(iter.size_hint(), (1, Some(1))); + iter.next(); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + } + #[test] fn test_frozenset_new_and_len() { let gil = Python::acquire_gil(); diff --git a/src/types/tuple.rs b/src/types/tuple.rs index 5aed29e8..f48b7ce0 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -131,6 +131,14 @@ impl<'a> Iterator for PyTupleIterator<'a> { None } } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + ( + self.length.saturating_sub(self.index as usize), + Some(self.length.saturating_sub(self.index as usize)), + ) + } } impl<'a> ExactSizeIterator for PyTupleIterator<'a> { @@ -346,9 +354,17 @@ mod test { let tuple = ::try_from(ob.as_ref(py)).unwrap(); assert_eq!(3, tuple.len()); let mut iter = tuple.iter(); + + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!(1, iter.next().unwrap().extract().unwrap()); + assert_eq!(iter.size_hint(), (2, Some(2))); + assert_eq!(2, iter.next().unwrap().extract().unwrap()); + assert_eq!(iter.size_hint(), (1, Some(1))); + assert_eq!(3, iter.next().unwrap().extract().unwrap()); + assert_eq!(iter.size_hint(), (0, Some(0))); } #[test]