Do not use PyObject_RichCompareBool to detect invalid comparison
This commit is contained in:
parent
7075827a03
commit
767bf8901c
|
@ -5,7 +5,7 @@ use crate::conversion::{
|
|||
use crate::err::{PyDowncastError, PyErr, PyResult};
|
||||
use crate::exceptions::TypeError;
|
||||
use crate::types::{PyDict, PyIterator, PyList, PyString, PyTuple, PyType};
|
||||
use crate::{err, ffi, Py, PyNativeType, Python};
|
||||
use crate::{err, ffi, Py, PyNativeType, PyObject, Python};
|
||||
use libc::c_int;
|
||||
use std::cell::UnsafeCell;
|
||||
use std::cmp::Ordering;
|
||||
|
@ -146,36 +146,25 @@ impl PyAny {
|
|||
where
|
||||
O: ToPyObject,
|
||||
{
|
||||
unsafe fn do_compare(
|
||||
py: Python,
|
||||
a: *mut ffi::PyObject,
|
||||
b: *mut ffi::PyObject,
|
||||
) -> PyResult<Ordering> {
|
||||
let result = ffi::PyObject_RichCompareBool(a, b, ffi::Py_EQ);
|
||||
if result == 1 {
|
||||
return Ok(Ordering::Equal);
|
||||
} else if result < 0 {
|
||||
return Err(PyErr::fetch(py));
|
||||
let py = self.py();
|
||||
// Almost the same as ffi::PyObject_RichCompareBool, but this one doesn't try self == other.
|
||||
// See https://github.com/PyO3/pyo3/issues/985 for more.
|
||||
let do_compare = |other, op| unsafe {
|
||||
PyObject::from_owned_ptr_or_err(py, ffi::PyObject_RichCompare(self.as_ptr(), other, op))
|
||||
.and_then(|obj| obj.is_true(py))
|
||||
};
|
||||
other.with_borrowed_ptr(py, |other| {
|
||||
if do_compare(other, ffi::Py_EQ)? {
|
||||
Ok(Ordering::Equal)
|
||||
} else if do_compare(other, ffi::Py_LT)? {
|
||||
Ok(Ordering::Less)
|
||||
} else if do_compare(other, ffi::Py_GT)? {
|
||||
Ok(Ordering::Greater)
|
||||
} else {
|
||||
Err(TypeError::py_err(
|
||||
"PyAny::compare(): All comparisons returned false",
|
||||
))
|
||||
}
|
||||
let result = ffi::PyObject_RichCompareBool(a, b, ffi::Py_LT);
|
||||
if result == 1 {
|
||||
return Ok(Ordering::Less);
|
||||
} else if result < 0 {
|
||||
return Err(PyErr::fetch(py));
|
||||
}
|
||||
let result = ffi::PyObject_RichCompareBool(a, b, ffi::Py_GT);
|
||||
if result == 1 {
|
||||
return Ok(Ordering::Greater);
|
||||
} else if result < 0 {
|
||||
return Err(PyErr::fetch(py));
|
||||
}
|
||||
Err(TypeError::py_err(
|
||||
"PyAny::compare(): All comparisons returned false",
|
||||
))
|
||||
}
|
||||
|
||||
other.with_borrowed_ptr(self.py(), |other| unsafe {
|
||||
do_compare(self.py(), self.as_ptr(), other)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -515,4 +504,12 @@ mod test {
|
|||
let b = dir.into_iter().map(|x| x.extract::<String>().unwrap());
|
||||
assert!(a.eq(b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nan_eq() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let nan = py.eval("float('nan')", None, None).unwrap();
|
||||
assert!(nan.compare(nan).is_err());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue