From c21a84d9996f94c037cce78fd696caf45c56bdac Mon Sep 17 00:00:00 2001 From: messense Date: Thu, 7 Dec 2023 20:19:51 +0800 Subject: [PATCH] Add support for extracting Rust set types from `frozenset` --- newsfragments/3632.added.md | 1 + src/conversions/hashbrown.rs | 18 ++++++++++++++--- src/conversions/std/set.rs | 39 +++++++++++++++++++++++++++++------- 3 files changed, 48 insertions(+), 10 deletions(-) create mode 100644 newsfragments/3632.added.md diff --git a/newsfragments/3632.added.md b/newsfragments/3632.added.md new file mode 100644 index 00000000..d9c954fa --- /dev/null +++ b/newsfragments/3632.added.md @@ -0,0 +1 @@ +Add support for extracting Rust set types from `frozenset`. diff --git a/src/conversions/hashbrown.rs b/src/conversions/hashbrown.rs index 6e20db39..62a7e87b 100644 --- a/src/conversions/hashbrown.rs +++ b/src/conversions/hashbrown.rs @@ -18,7 +18,7 @@ //! The required hashbrown version may vary based on the version of PyO3. use crate::{ types::set::new_from_iter, - types::{IntoPyDict, PyDict, PySet}, + types::{IntoPyDict, PyDict, PyFrozenSet, PySet}, FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject, }; use std::{cmp, hash}; @@ -93,8 +93,16 @@ where S: hash::BuildHasher + Default, { fn extract(ob: &'source PyAny) -> PyResult { - let set: &PySet = ob.downcast()?; - set.iter().map(K::extract).collect() + match ob.downcast::() { + Ok(set) => set.iter().map(K::extract).collect(), + Err(err) => { + if let Ok(frozen_set) = ob.downcast::() { + frozen_set.iter().map(K::extract).collect() + } else { + Err(PyErr::from(err)) + } + } + } } } @@ -173,6 +181,10 @@ mod tests { let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap(); let hash_set: hashbrown::HashSet = set.extract().unwrap(); assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect()); + + let set = PyFrozenSet::new(py, &[1, 2, 3, 4, 5]).unwrap(); + let hash_set: hashbrown::HashSet = set.extract().unwrap(); + assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect()); }); } diff --git a/src/conversions/std/set.rs b/src/conversions/std/set.rs index 6329f9d1..4221c1ff 100644 --- a/src/conversions/std/set.rs +++ b/src/conversions/std/set.rs @@ -3,8 +3,9 @@ use std::{cmp, collections, hash}; #[cfg(feature = "experimental-inspect")] use crate::inspect::types::TypeInfo; use crate::{ - types::set::new_from_iter, types::PySet, FromPyObject, IntoPy, PyAny, PyObject, PyResult, - Python, ToPyObject, + types::set::new_from_iter, + types::{PyFrozenSet, PySet}, + FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject, }; impl ToPyObject for collections::HashSet @@ -53,8 +54,16 @@ where S: hash::BuildHasher + Default, { fn extract(ob: &'source PyAny) -> PyResult { - let set: &PySet = ob.downcast()?; - set.iter().map(K::extract).collect() + match ob.downcast::() { + Ok(set) => set.iter().map(K::extract).collect(), + Err(err) => { + if let Ok(frozen_set) = ob.downcast::() { + frozen_set.iter().map(K::extract).collect() + } else { + Err(PyErr::from(err)) + } + } + } } #[cfg(feature = "experimental-inspect")] @@ -84,8 +93,16 @@ where K: FromPyObject<'source> + cmp::Ord, { fn extract(ob: &'source PyAny) -> PyResult { - let set: &PySet = ob.downcast()?; - set.iter().map(K::extract).collect() + match ob.downcast::() { + Ok(set) => set.iter().map(K::extract).collect(), + Err(err) => { + if let Ok(frozen_set) = ob.downcast::() { + frozen_set.iter().map(K::extract).collect() + } else { + Err(PyErr::from(err)) + } + } + } } #[cfg(feature = "experimental-inspect")] @@ -96,7 +113,7 @@ where #[cfg(test)] mod tests { - use super::PySet; + use super::{PyFrozenSet, PySet}; use crate::{IntoPy, PyObject, Python, ToPyObject}; use std::collections::{BTreeSet, HashSet}; @@ -106,6 +123,10 @@ mod tests { let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap(); let hash_set: HashSet = set.extract().unwrap(); assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect()); + + let set = PyFrozenSet::new(py, &[1, 2, 3, 4, 5]).unwrap(); + let hash_set: HashSet = set.extract().unwrap(); + assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect()); }); } @@ -115,6 +136,10 @@ mod tests { let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap(); let hash_set: BTreeSet = set.extract().unwrap(); assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect()); + + let set = PyFrozenSet::new(py, &[1, 2, 3, 4, 5]).unwrap(); + let hash_set: BTreeSet = set.extract().unwrap(); + assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect()); }); }