Do not use PyObject_RichCompareBool to detect invalid comparison

This commit is contained in:
kngwyu 2020-06-20 13:34:52 +09:00
parent 7075827a03
commit 767bf8901c

View file

@ -5,7 +5,7 @@ use crate::conversion::{
use crate::err::{PyDowncastError, PyErr, PyResult};
use crate::exceptions::TypeError;
use crate::types::{PyDict, PyIterator, PyList, PyString, PyTuple, PyType};
use crate::{err, ffi, Py, PyNativeType, Python};
use crate::{err, ffi, Py, PyNativeType, PyObject, Python};
use libc::c_int;
use std::cell::UnsafeCell;
use std::cmp::Ordering;
@ -146,36 +146,25 @@ impl PyAny {
where
O: ToPyObject,
{
unsafe fn do_compare(
py: Python,
a: *mut ffi::PyObject,
b: *mut ffi::PyObject,
) -> PyResult<Ordering> {
let result = ffi::PyObject_RichCompareBool(a, b, ffi::Py_EQ);
if result == 1 {
return Ok(Ordering::Equal);
} else if result < 0 {
return Err(PyErr::fetch(py));
let py = self.py();
// Almost the same as ffi::PyObject_RichCompareBool, but this one doesn't try self == other.
// See https://github.com/PyO3/pyo3/issues/985 for more.
let do_compare = |other, op| unsafe {
PyObject::from_owned_ptr_or_err(py, ffi::PyObject_RichCompare(self.as_ptr(), other, op))
.and_then(|obj| obj.is_true(py))
};
other.with_borrowed_ptr(py, |other| {
if do_compare(other, ffi::Py_EQ)? {
Ok(Ordering::Equal)
} else if do_compare(other, ffi::Py_LT)? {
Ok(Ordering::Less)
} else if do_compare(other, ffi::Py_GT)? {
Ok(Ordering::Greater)
} else {
Err(TypeError::py_err(
"PyAny::compare(): All comparisons returned false",
))
}
let result = ffi::PyObject_RichCompareBool(a, b, ffi::Py_LT);
if result == 1 {
return Ok(Ordering::Less);
} else if result < 0 {
return Err(PyErr::fetch(py));
}
let result = ffi::PyObject_RichCompareBool(a, b, ffi::Py_GT);
if result == 1 {
return Ok(Ordering::Greater);
} else if result < 0 {
return Err(PyErr::fetch(py));
}
Err(TypeError::py_err(
"PyAny::compare(): All comparisons returned false",
))
}
other.with_borrowed_ptr(self.py(), |other| unsafe {
do_compare(self.py(), self.as_ptr(), other)
})
}
@ -515,4 +504,12 @@ mod test {
let b = dir.into_iter().map(|x| x.extract::<String>().unwrap());
assert!(a.eq(b));
}
#[test]
fn test_nan_eq() {
let gil = Python::acquire_gil();
let py = gil.python();
let nan = py.eval("float('nan')", None, None).unwrap();
assert!(nan.compare(nan).is_err());
}
}