Fix bug in default implementation of `__ne__`

This commit is contained in:
David Hewitt 2023-10-11 09:23:25 +01:00
parent b73c06948c
commit e1d4173827
4 changed files with 65 additions and 22 deletions

View File

@ -149,8 +149,8 @@ impl Number {
### Comparisons
Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`,
`__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`.
PyO3 supports the usual magic comparison methods available in Python such as `__eq__`, `__lt__`
and so on. It is also possible to support all six operations at once with `__richcmp__`.
This method will be called with a value of `CompareOp` depending on the operation.
```rust
@ -198,13 +198,10 @@ impl Number {
It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches
the given `CompareOp`.
Alternatively, if you want to leave some operations unimplemented, you can
return `py.NotImplemented()` for some of the operations:
Alternatively, you can implement just equality using `__eq__`:
```rust
use pyo3::class::basic::CompareOp;
# use pyo3::prelude::*;
#
# #[pyclass]
@ -212,14 +209,20 @@ use pyo3::class::basic::CompareOp;
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => (self.0 == other.0).into_py(py),
CompareOp::Ne => (self.0 != other.0).into_py(py),
_ => py.NotImplemented(),
}
fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}
}
# fn main() -> PyResult<()> {
# Python::with_gil(|py| {
# let x = PyCell::new(py, Number(4))?;
# let y = PyCell::new(py, Number(4))?;
# assert!(x.eq(y)?);
# assert!(!x.ne(y)?);
# Ok(())
# })
# }
```
### Truthyness

View File

@ -76,7 +76,8 @@ given signatures should be interpreted as follows:
- `__richcmp__(<self>, object, pyo3::basic::CompareOp) -> object`
Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method.
The `CompareOp` argument indicates the comparison operation being performed.
The `CompareOp` argument indicates the comparison operation being performed. You can use
[`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result to the requested comparison.
_This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._
@ -84,11 +85,32 @@ given signatures should be interpreted as follows:
<details>
<summary>Return type</summary>
The return type will normally be `PyResult<bool>`, but any Python object can be returned.
If you want to leave some operations unimplemented, you can return `py.NotImplemented()`
for some of the operations:
```rust
use pyo3::class::basic::CompareOp;
# use pyo3::prelude::*;
#
# #[pyclass]
# struct Number(i32);
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => (self.0 == other.0).into_py(py),
CompareOp::Ne => (self.0 != other.0).into_py(py),
_ => py.NotImplemented(),
}
}
}
```
If the second argument `object` is not of the type specified in the
signature, the generated code will automatically `return NotImplemented`.
You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result
to the requested comparison.
</details>
- `__getattr__(<self>, object) -> object`

View File

@ -23,10 +23,14 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
c = ty(1)
assert a == b
assert not (a != b)
assert a != c
assert not (a == c)
assert b == a
assert not (a != b)
assert b != c
assert not (b == c)
with pytest.raises(TypeError):
assert a <= b
@ -49,17 +53,21 @@ class PyEqDefaultNe:
return self.x == other.x
@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python"))
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
a = ty(0)
b = ty(0)
c = ty(1)
assert a == b
assert not (a != b)
assert a != c
assert not (a == c)
assert b == a
assert not (a != b)
assert b != c
assert not (b == c)
with pytest.raises(TypeError):
assert a <= b
@ -152,19 +160,25 @@ def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]
c = ty(1)
assert a == b
assert not (a != b)
assert a <= b
assert a >= b
assert a != c
assert not (a == c)
assert a <= c
assert b == a
assert not (b != a)
assert b <= a
assert b >= a
assert b != c
assert not (b == c)
assert b <= c
assert c != a
assert not (c == a)
assert c != b
assert not (c == b)
assert c > a
assert c >= a
assert c > b

View File

@ -6,6 +6,7 @@ use crate::{
internal_tricks::extract_c_string,
pycell::PyCellLayout,
pyclass_init::PyObjectInit,
types::PyBool,
Py, PyAny, PyCell, PyClass, PyErr, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python,
};
use std::{
@ -805,11 +806,14 @@ slot_fragment_trait! {
#[inline]
unsafe fn __ne__(
self,
_py: Python<'_>,
_slf: *mut ffi::PyObject,
_other: *mut ffi::PyObject,
py: Python<'_>,
slf: *mut ffi::PyObject,
other: *mut ffi::PyObject,
) -> PyResult<*mut ffi::PyObject> {
Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented()))
// By default `__ne__` will try `__eq__` and invert the result
let slf: &PyAny = py.from_borrowed_ptr(slf);
let other: &PyAny = py.from_borrowed_ptr(other);
slf.eq(other).map(|is_eq| PyBool::new(py, !is_eq).into_ptr())
}
}