From 1834ec4a33d468485e5ba50a3426c20e8d440a2f Mon Sep 17 00:00:00 2001 From: messense Date: Sun, 23 Jul 2017 22:43:56 +0800 Subject: [PATCH] Add PyDict::iter --- src/objects/dict.rs | 54 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/src/objects/dict.rs b/src/objects/dict.rs index 98db8483..33058f63 100644 --- a/src/objects/dict.rs +++ b/src/objects/dict.rs @@ -120,9 +120,6 @@ impl PyDict { /// Returns the list of (key, value) pairs in this dictionary. pub fn items_vec(&self) -> Vec<(PyObject, PyObject)> { - // Note that we don't provide an iterator because - // PyDict_Next() is unsafe to use when the dictionary might be changed - // by other python code. let mut vec = Vec::with_capacity(self.len()); unsafe { let mut pos = 0; @@ -135,6 +132,37 @@ impl PyDict { } vec } + + /// Returns a iterator of (key, value) pairs in this dictionary + /// Note that it's unsafe to use when the dictionary might be changed + /// by other python code. + #[inline] + pub fn iter(&self) -> PyDictIterator { + PyDictIterator { dict: self, pos: 0 } + } +} + +pub struct PyDictIterator<'a> { + dict: &'a PyDict, + pos: isize +} + +impl<'a> Iterator for PyDictIterator<'a> { + type Item = (&'a PyObjectRef, &'a PyObjectRef); + + #[inline] + fn next(&mut self) -> Option { + unsafe { + let mut key: *mut ffi::PyObject = mem::uninitialized(); + let mut value: *mut ffi::PyObject = mem::uninitialized(); + if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.pos, &mut key, &mut value) != 0 { + let py = self.dict.py(); + Some((py.cast_from_borrowed_ptr(key), py.cast_from_borrowed_ptr(value))) + } else { + None + } + } + } } impl ToPyObject for collections::HashMap @@ -382,6 +410,26 @@ mod test { assert_eq!(32 + 42 + 123, value_sum); } + #[test] + fn test_dict_iter() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let dict = PyDict::downcast_from(ob.as_ref(py)).unwrap(); + let mut key_sum = 0; + let mut value_sum = 0; + for (key, value) in dict.iter() { + key_sum += key.extract::().unwrap(); + value_sum += value.extract::().unwrap(); + } + assert_eq!(7 + 8 + 9, key_sum); + assert_eq!(32 + 42 + 123, value_sum); + } + #[test] fn test_hashmap_to_python() { let gil = Python::acquire_gil();