diff --git a/newsfragments/3638.changed.md b/newsfragments/3638.changed.md new file mode 100644 index 00000000..6bdafde8 --- /dev/null +++ b/newsfragments/3638.changed.md @@ -0,0 +1 @@ +Values of type `bool` can now be extracted from NumPy's `bool_`. diff --git a/pytests/src/misc.rs b/pytests/src/misc.rs index 69f3b75e..029e8b16 100644 --- a/pytests/src/misc.rs +++ b/pytests/src/misc.rs @@ -12,9 +12,15 @@ fn get_type_full_name(obj: &PyAny) -> PyResult> { obj.get_type().name() } +#[pyfunction] +fn accepts_bool(val: bool) -> bool { + val +} + #[pymodule] pub fn misc(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(issue_219, m)?)?; m.add_function(wrap_pyfunction!(get_type_full_name, m)?)?; + m.add_function(wrap_pyfunction!(accepts_bool, m)?)?; Ok(()) } diff --git a/pytests/tests/test_misc.py b/pytests/tests/test_misc.py index 537ee119..06b2ce73 100644 --- a/pytests/tests/test_misc.py +++ b/pytests/tests/test_misc.py @@ -54,3 +54,13 @@ def test_type_full_name_includes_module(): numpy = pytest.importorskip("numpy") assert pyo3_pytests.misc.get_type_full_name(numpy.bool_(True)) == "numpy.bool_" + + +def test_accepts_numpy_bool(): + # binary numpy wheel not available on all platforms + numpy = pytest.importorskip("numpy") + + assert pyo3_pytests.misc.accepts_bool(True) is True + assert pyo3_pytests.misc.accepts_bool(False) is False + assert pyo3_pytests.misc.accepts_bool(numpy.bool_(True)) is True + assert pyo3_pytests.misc.accepts_bool(numpy.bool_(False)) is False diff --git a/src/types/boolobject.rs b/src/types/boolobject.rs index 7e75c424..71c91c8e 100644 --- a/src/types/boolobject.rs +++ b/src/types/boolobject.rs @@ -1,7 +1,8 @@ #[cfg(feature = "experimental-inspect")] use crate::inspect::types::TypeInfo; use crate::{ - ffi, instance::Py2, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject, + exceptions::PyTypeError, ffi, instance::Py2, FromPyObject, IntoPy, PyAny, PyObject, PyResult, + Python, ToPyObject, }; /// Represents a Python `bool`. @@ -76,7 +77,52 @@ impl IntoPy for bool { /// Fails with `TypeError` if the input is not a Python `bool`. impl<'source> FromPyObject<'source> for bool { fn extract(obj: &'source PyAny) -> PyResult { - Ok(obj.downcast::()?.is_true()) + let err = match obj.downcast::() { + Ok(obj) => return Ok(obj.is_true()), + Err(err) => err, + }; + + if obj + .get_type() + .name() + .map_or(false, |name| name == "numpy.bool_") + { + let missing_conversion = |obj: &PyAny| { + PyTypeError::new_err(format!( + "object of type '{}' does not define a '__bool__' conversion", + obj.get_type() + )) + }; + + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + unsafe { + let ptr = obj.as_ptr(); + + if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() { + if let Some(nb_bool) = tp_as_number.nb_bool { + match (nb_bool)(ptr) { + 0 => return Ok(false), + 1 => return Ok(true), + _ => return Err(crate::PyErr::fetch(obj.py())), + } + } + } + + return Err(missing_conversion(obj)); + } + + #[cfg(any(Py_LIMITED_API, PyPy))] + { + let meth = obj + .lookup_special(crate::intern!(obj.py(), "__bool__"))? + .ok_or_else(|| missing_conversion(obj))?; + + let obj = meth.call0()?.downcast::()?; + return Ok(obj.is_true()); + } + } + + Err(err.into()) } #[cfg(feature = "experimental-inspect")]