Merge pull request #3506 from davidhewitt/default-ne

Fix bug in default implementation of `__ne__`
This commit is contained in:
David Hewitt 2023-10-11 10:04:55 +00:00 committed by GitHub
commit b03c4cb33c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 22 deletions

View file

@ -73,7 +73,7 @@ impl Number {
In the `__repr__`, we used a hard-coded class name. This is sometimes not ideal,
because if the class is subclassed in Python, we would like the repr to reflect
the subclass name. This is typically done in Python code by accessing
the subclass name. This is typically done in Python code by accessing
`self.__class__.__name__`. In order to be able to access the Python type information
*and* the Rust struct, we need to use a `PyCell` as the `self` argument.
@ -149,8 +149,8 @@ impl Number {
### Comparisons
Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`,
`__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`.
PyO3 supports the usual magic comparison methods available in Python such as `__eq__`, `__lt__`
and so on. It is also possible to support all six operations at once with `__richcmp__`.
This method will be called with a value of `CompareOp` depending on the operation.
```rust
@ -198,13 +198,10 @@ impl Number {
It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches
the given `CompareOp`.
Alternatively, if you want to leave some operations unimplemented, you can
return `py.NotImplemented()` for some of the operations:
Alternatively, you can implement just equality using `__eq__`:
```rust
use pyo3::class::basic::CompareOp;
# use pyo3::prelude::*;
#
# #[pyclass]
@ -212,14 +209,20 @@ use pyo3::class::basic::CompareOp;
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => (self.0 == other.0).into_py(py),
CompareOp::Ne => (self.0 != other.0).into_py(py),
_ => py.NotImplemented(),
}
fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}
}
# fn main() -> PyResult<()> {
# Python::with_gil(|py| {
# let x = PyCell::new(py, Number(4))?;
# let y = PyCell::new(py, Number(4))?;
# assert!(x.eq(y)?);
# assert!(!x.ne(y)?);
# Ok(())
# })
# }
```
### Truthyness

View file

@ -76,7 +76,8 @@ given signatures should be interpreted as follows:
- `__richcmp__(<self>, object, pyo3::basic::CompareOp) -> object`
Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method.
The `CompareOp` argument indicates the comparison operation being performed.
The `CompareOp` argument indicates the comparison operation being performed. You can use
[`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result to the requested comparison.
_This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._
@ -84,11 +85,32 @@ given signatures should be interpreted as follows:
<details>
<summary>Return type</summary>
The return type will normally be `PyResult<bool>`, but any Python object can be returned.
If you want to leave some operations unimplemented, you can return `py.NotImplemented()`
for some of the operations:
```rust
use pyo3::class::basic::CompareOp;
# use pyo3::prelude::*;
#
# #[pyclass]
# struct Number(i32);
#
#[pymethods]
impl Number {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => (self.0 == other.0).into_py(py),
CompareOp::Ne => (self.0 != other.0).into_py(py),
_ => py.NotImplemented(),
}
}
}
```
If the second argument `object` is not of the type specified in the
signature, the generated code will automatically `return NotImplemented`.
You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result
to the requested comparison.
</details>
- `__getattr__(<self>, object) -> object`

View file

@ -23,10 +23,14 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
c = ty(1)
assert a == b
assert not (a != b)
assert a != c
assert not (a == c)
assert b == a
assert not (a != b)
assert b != c
assert not (b == c)
with pytest.raises(TypeError):
assert a <= b
@ -49,17 +53,21 @@ class PyEqDefaultNe:
return self.x == other.x
@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python"))
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
a = ty(0)
b = ty(0)
c = ty(1)
assert a == b
assert not (a != b)
assert a != c
assert not (a == c)
assert b == a
assert not (a != b)
assert b != c
assert not (b == c)
with pytest.raises(TypeError):
assert a <= b
@ -152,19 +160,25 @@ def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]
c = ty(1)
assert a == b
assert not (a != b)
assert a <= b
assert a >= b
assert a != c
assert not (a == c)
assert a <= c
assert b == a
assert not (b != a)
assert b <= a
assert b >= a
assert b != c
assert not (b == c)
assert b <= c
assert c != a
assert not (c == a)
assert c != b
assert not (c == b)
assert c > a
assert c >= a
assert c > b

View file

@ -6,6 +6,7 @@ use crate::{
internal_tricks::extract_c_string,
pycell::PyCellLayout,
pyclass_init::PyObjectInit,
types::PyBool,
Py, PyAny, PyCell, PyClass, PyErr, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python,
};
use std::{
@ -805,11 +806,14 @@ slot_fragment_trait! {
#[inline]
unsafe fn __ne__(
self,
_py: Python<'_>,
_slf: *mut ffi::PyObject,
_other: *mut ffi::PyObject,
py: Python<'_>,
slf: *mut ffi::PyObject,
other: *mut ffi::PyObject,
) -> PyResult<*mut ffi::PyObject> {
Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented()))
// By default `__ne__` will try `__eq__` and invert the result
let slf: &PyAny = py.from_borrowed_ptr(slf);
let other: &PyAny = py.from_borrowed_ptr(other);
slf.eq(other).map(|is_eq| PyBool::new(py, !is_eq).into_ptr())
}
}