Apply __bool__ conversion only to numpy.bool_ to avoid false positives.
This commit is contained in:
parent
57002d2389
commit
4177dfcc81
|
@ -1 +1 @@
|
||||||
Values of type `bool` can now be extracted from all Python values defining a `__bool__` magic method.
|
Values of type `bool` can now be extracted from NumPy's `bool_`.
|
||||||
|
|
|
@ -12,9 +12,15 @@ fn get_type_full_name(obj: &PyAny) -> PyResult<Cow<'_, str>> {
|
||||||
obj.get_type().name()
|
obj.get_type().name()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn accepts_bool(val: bool) -> bool {
|
||||||
|
val
|
||||||
|
}
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
pub fn misc(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
pub fn misc(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_function(wrap_pyfunction!(issue_219, m)?)?;
|
m.add_function(wrap_pyfunction!(issue_219, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(get_type_full_name, m)?)?;
|
m.add_function(wrap_pyfunction!(get_type_full_name, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(accepts_bool, m)?)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,3 +54,13 @@ def test_type_full_name_includes_module():
|
||||||
numpy = pytest.importorskip("numpy")
|
numpy = pytest.importorskip("numpy")
|
||||||
|
|
||||||
assert pyo3_pytests.misc.get_type_full_name(numpy.bool_(True)) == "numpy.bool_"
|
assert pyo3_pytests.misc.get_type_full_name(numpy.bool_(True)) == "numpy.bool_"
|
||||||
|
|
||||||
|
|
||||||
|
def test_accepts_numpy_bool():
|
||||||
|
# binary numpy wheel not available on all platforms
|
||||||
|
numpy = pytest.importorskip("numpy")
|
||||||
|
|
||||||
|
assert pyo3_pytests.misc.accepts_bool(True) is True
|
||||||
|
assert pyo3_pytests.misc.accepts_bool(False) is False
|
||||||
|
assert pyo3_pytests.misc.accepts_bool(numpy.bool_(True)) is True
|
||||||
|
assert pyo3_pytests.misc.accepts_bool(numpy.bool_(False)) is False
|
||||||
|
|
|
@ -77,43 +77,52 @@ impl IntoPy<PyObject> for bool {
|
||||||
/// Fails with `TypeError` if the input is not a Python `bool`.
|
/// Fails with `TypeError` if the input is not a Python `bool`.
|
||||||
impl<'source> FromPyObject<'source> for bool {
|
impl<'source> FromPyObject<'source> for bool {
|
||||||
fn extract(obj: &'source PyAny) -> PyResult<Self> {
|
fn extract(obj: &'source PyAny) -> PyResult<Self> {
|
||||||
if let Ok(obj) = obj.downcast::<PyBool>() {
|
let err = match obj.downcast::<PyBool>() {
|
||||||
return Ok(obj.is_true());
|
Ok(obj) => return Ok(obj.is_true()),
|
||||||
}
|
Err(err) => err,
|
||||||
|
|
||||||
let missing_conversion = |obj: &PyAny| {
|
|
||||||
PyTypeError::new_err(format!(
|
|
||||||
"object of type '{}' does not define a '__bool__' conversion",
|
|
||||||
obj.get_type()
|
|
||||||
))
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(not(any(Py_LIMITED_API, PyPy)))]
|
if obj
|
||||||
unsafe {
|
.get_type()
|
||||||
let ptr = obj.as_ptr();
|
.name()
|
||||||
|
.map_or(false, |name| name == "numpy.bool_")
|
||||||
|
{
|
||||||
|
let missing_conversion = |obj: &PyAny| {
|
||||||
|
PyTypeError::new_err(format!(
|
||||||
|
"object of type '{}' does not define a '__bool__' conversion",
|
||||||
|
obj.get_type()
|
||||||
|
))
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() {
|
#[cfg(not(any(Py_LIMITED_API, PyPy)))]
|
||||||
if let Some(nb_bool) = tp_as_number.nb_bool {
|
unsafe {
|
||||||
match (nb_bool)(ptr) {
|
let ptr = obj.as_ptr();
|
||||||
0 => return Ok(false),
|
|
||||||
1 => return Ok(true),
|
if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() {
|
||||||
_ => return Err(crate::PyErr::fetch(obj.py())),
|
if let Some(nb_bool) = tp_as_number.nb_bool {
|
||||||
|
match (nb_bool)(ptr) {
|
||||||
|
0 => return Ok(false),
|
||||||
|
1 => return Ok(true),
|
||||||
|
_ => return Err(crate::PyErr::fetch(obj.py())),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return Err(missing_conversion(obj));
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(missing_conversion(obj))
|
#[cfg(any(Py_LIMITED_API, PyPy))]
|
||||||
|
{
|
||||||
|
let meth = obj
|
||||||
|
.lookup_special(crate::intern!(obj.py(), "__bool__"))?
|
||||||
|
.ok_or_else(|| missing_conversion(obj))?;
|
||||||
|
|
||||||
|
let obj = meth.call0()?.downcast::<PyBool>()?;
|
||||||
|
return Ok(obj.is_true());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(Py_LIMITED_API, PyPy))]
|
Err(err.into())
|
||||||
{
|
|
||||||
let meth = obj
|
|
||||||
.lookup_special(crate::intern!(obj.py(), "__bool__"))?
|
|
||||||
.ok_or_else(|| missing_conversion(obj))?;
|
|
||||||
|
|
||||||
let obj = meth.call0()?.downcast::<PyBool>()?;
|
|
||||||
Ok(obj.is_true())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "experimental-inspect")]
|
#[cfg(feature = "experimental-inspect")]
|
||||||
|
@ -124,7 +133,7 @@ impl<'source> FromPyObject<'source> for bool {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::types::{PyAny, PyBool, PyModule};
|
use crate::types::{PyAny, PyBool};
|
||||||
use crate::Python;
|
use crate::Python;
|
||||||
use crate::ToPyObject;
|
use crate::ToPyObject;
|
||||||
|
|
||||||
|
@ -147,48 +156,4 @@ mod tests {
|
||||||
assert!(false.to_object(py).is(PyBool::new(py, false)));
|
assert!(false.to_object(py).is(PyBool::new(py, false)));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_magic_method() {
|
|
||||||
Python::with_gil(|py| {
|
|
||||||
let module = PyModule::from_code(
|
|
||||||
py,
|
|
||||||
r#"
|
|
||||||
class A:
|
|
||||||
def __bool__(self): return True
|
|
||||||
class B:
|
|
||||||
def __bool__(self): return "not a bool"
|
|
||||||
class C:
|
|
||||||
def __len__(self): return 23
|
|
||||||
class D:
|
|
||||||
pass
|
|
||||||
"#,
|
|
||||||
"test.py",
|
|
||||||
"test",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let a = module.getattr("A").unwrap().call0().unwrap();
|
|
||||||
assert!(a.extract::<bool>().unwrap());
|
|
||||||
|
|
||||||
let b = module.getattr("B").unwrap().call0().unwrap();
|
|
||||||
assert!(matches!(
|
|
||||||
&*b.extract::<bool>().unwrap_err().to_string(),
|
|
||||||
"TypeError: 'str' object cannot be converted to 'PyBool'"
|
|
||||||
| "TypeError: __bool__ should return bool, returned str"
|
|
||||||
));
|
|
||||||
|
|
||||||
let c = module.getattr("C").unwrap().call0().unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
c.extract::<bool>().unwrap_err().to_string(),
|
|
||||||
"TypeError: object of type '<class 'test.C'>' does not define a '__bool__' conversion",
|
|
||||||
);
|
|
||||||
|
|
||||||
let d = module.getattr("D").unwrap().call0().unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
d.extract::<bool>().unwrap_err().to_string(),
|
|
||||||
"TypeError: object of type '<class 'test.D'>' does not define a '__bool__' conversion",
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue