From 4177dfcc81bd5052c3d4c01088e9da00b2e5f32e Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 17 Dec 2023 14:45:05 +0100 Subject: [PATCH] Apply __bool__ conversion only to numpy.bool_ to avoid false positives. --- newsfragments/3638.changed.md | 2 +- pytests/src/misc.rs | 6 ++ pytests/tests/test_misc.py | 10 +++ src/types/boolobject.rs | 111 ++++++++++++---------------------- 4 files changed, 55 insertions(+), 74 deletions(-) diff --git a/newsfragments/3638.changed.md b/newsfragments/3638.changed.md index 83f0bd74..6bdafde8 100644 --- a/newsfragments/3638.changed.md +++ b/newsfragments/3638.changed.md @@ -1 +1 @@ -Values of type `bool` can now be extracted from all Python values defining a `__bool__` magic method. +Values of type `bool` can now be extracted from NumPy's `bool_`. diff --git a/pytests/src/misc.rs b/pytests/src/misc.rs index 69f3b75e..029e8b16 100644 --- a/pytests/src/misc.rs +++ b/pytests/src/misc.rs @@ -12,9 +12,15 @@ fn get_type_full_name(obj: &PyAny) -> PyResult> { obj.get_type().name() } +#[pyfunction] +fn accepts_bool(val: bool) -> bool { + val +} + #[pymodule] pub fn misc(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(issue_219, m)?)?; m.add_function(wrap_pyfunction!(get_type_full_name, m)?)?; + m.add_function(wrap_pyfunction!(accepts_bool, m)?)?; Ok(()) } diff --git a/pytests/tests/test_misc.py b/pytests/tests/test_misc.py index 537ee119..06b2ce73 100644 --- a/pytests/tests/test_misc.py +++ b/pytests/tests/test_misc.py @@ -54,3 +54,13 @@ def test_type_full_name_includes_module(): numpy = pytest.importorskip("numpy") assert pyo3_pytests.misc.get_type_full_name(numpy.bool_(True)) == "numpy.bool_" + + +def test_accepts_numpy_bool(): + # binary numpy wheel not available on all platforms + numpy = pytest.importorskip("numpy") + + assert pyo3_pytests.misc.accepts_bool(True) is True + assert pyo3_pytests.misc.accepts_bool(False) is False + assert pyo3_pytests.misc.accepts_bool(numpy.bool_(True)) is True + assert pyo3_pytests.misc.accepts_bool(numpy.bool_(False)) is False diff --git a/src/types/boolobject.rs b/src/types/boolobject.rs index 8609872b..71c91c8e 100644 --- a/src/types/boolobject.rs +++ b/src/types/boolobject.rs @@ -77,43 +77,52 @@ impl IntoPy for bool { /// Fails with `TypeError` if the input is not a Python `bool`. impl<'source> FromPyObject<'source> for bool { fn extract(obj: &'source PyAny) -> PyResult { - if let Ok(obj) = obj.downcast::() { - return Ok(obj.is_true()); - } - - let missing_conversion = |obj: &PyAny| { - PyTypeError::new_err(format!( - "object of type '{}' does not define a '__bool__' conversion", - obj.get_type() - )) + let err = match obj.downcast::() { + Ok(obj) => return Ok(obj.is_true()), + Err(err) => err, }; - #[cfg(not(any(Py_LIMITED_API, PyPy)))] - unsafe { - let ptr = obj.as_ptr(); + if obj + .get_type() + .name() + .map_or(false, |name| name == "numpy.bool_") + { + let missing_conversion = |obj: &PyAny| { + PyTypeError::new_err(format!( + "object of type '{}' does not define a '__bool__' conversion", + obj.get_type() + )) + }; - if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() { - if let Some(nb_bool) = tp_as_number.nb_bool { - match (nb_bool)(ptr) { - 0 => return Ok(false), - 1 => return Ok(true), - _ => return Err(crate::PyErr::fetch(obj.py())), + #[cfg(not(any(Py_LIMITED_API, PyPy)))] + unsafe { + let ptr = obj.as_ptr(); + + if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() { + if let Some(nb_bool) = tp_as_number.nb_bool { + match (nb_bool)(ptr) { + 0 => return Ok(false), + 1 => return Ok(true), + _ => return Err(crate::PyErr::fetch(obj.py())), + } } } + + return Err(missing_conversion(obj)); } - Err(missing_conversion(obj)) + #[cfg(any(Py_LIMITED_API, PyPy))] + { + let meth = obj + .lookup_special(crate::intern!(obj.py(), "__bool__"))? + .ok_or_else(|| missing_conversion(obj))?; + + let obj = meth.call0()?.downcast::()?; + return Ok(obj.is_true()); + } } - #[cfg(any(Py_LIMITED_API, PyPy))] - { - let meth = obj - .lookup_special(crate::intern!(obj.py(), "__bool__"))? - .ok_or_else(|| missing_conversion(obj))?; - - let obj = meth.call0()?.downcast::()?; - Ok(obj.is_true()) - } + Err(err.into()) } #[cfg(feature = "experimental-inspect")] @@ -124,7 +133,7 @@ impl<'source> FromPyObject<'source> for bool { #[cfg(test)] mod tests { - use crate::types::{PyAny, PyBool, PyModule}; + use crate::types::{PyAny, PyBool}; use crate::Python; use crate::ToPyObject; @@ -147,48 +156,4 @@ mod tests { assert!(false.to_object(py).is(PyBool::new(py, false))); }); } - - #[test] - fn test_magic_method() { - Python::with_gil(|py| { - let module = PyModule::from_code( - py, - r#" -class A: - def __bool__(self): return True -class B: - def __bool__(self): return "not a bool" -class C: - def __len__(self): return 23 -class D: - pass - "#, - "test.py", - "test", - ) - .unwrap(); - - let a = module.getattr("A").unwrap().call0().unwrap(); - assert!(a.extract::().unwrap()); - - let b = module.getattr("B").unwrap().call0().unwrap(); - assert!(matches!( - &*b.extract::().unwrap_err().to_string(), - "TypeError: 'str' object cannot be converted to 'PyBool'" - | "TypeError: __bool__ should return bool, returned str" - )); - - let c = module.getattr("C").unwrap().call0().unwrap(); - assert_eq!( - c.extract::().unwrap_err().to_string(), - "TypeError: object of type '' does not define a '__bool__' conversion", - ); - - let d = module.getattr("D").unwrap().call0().unwrap(); - assert_eq!( - d.extract::().unwrap_err().to_string(), - "TypeError: object of type '' does not define a '__bool__' conversion", - ); - }); - } }