Merge pull request #3506 from davidhewitt/default-ne
Fix bug in default implementation of `__ne__`
This commit is contained in:
commit
b03c4cb33c
|
@ -73,7 +73,7 @@ impl Number {
|
|||
|
||||
In the `__repr__`, we used a hard-coded class name. This is sometimes not ideal,
|
||||
because if the class is subclassed in Python, we would like the repr to reflect
|
||||
the subclass name. This is typically done in Python code by accessing
|
||||
the subclass name. This is typically done in Python code by accessing
|
||||
`self.__class__.__name__`. In order to be able to access the Python type information
|
||||
*and* the Rust struct, we need to use a `PyCell` as the `self` argument.
|
||||
|
||||
|
@ -149,8 +149,8 @@ impl Number {
|
|||
|
||||
### Comparisons
|
||||
|
||||
Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`,
|
||||
`__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`.
|
||||
PyO3 supports the usual magic comparison methods available in Python such as `__eq__`, `__lt__`
|
||||
and so on. It is also possible to support all six operations at once with `__richcmp__`.
|
||||
This method will be called with a value of `CompareOp` depending on the operation.
|
||||
|
||||
```rust
|
||||
|
@ -198,13 +198,10 @@ impl Number {
|
|||
It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches
|
||||
the given `CompareOp`.
|
||||
|
||||
Alternatively, if you want to leave some operations unimplemented, you can
|
||||
return `py.NotImplemented()` for some of the operations:
|
||||
Alternatively, you can implement just equality using `__eq__`:
|
||||
|
||||
|
||||
```rust
|
||||
use pyo3::class::basic::CompareOp;
|
||||
|
||||
# use pyo3::prelude::*;
|
||||
#
|
||||
# #[pyclass]
|
||||
|
@ -212,14 +209,20 @@ use pyo3::class::basic::CompareOp;
|
|||
#
|
||||
#[pymethods]
|
||||
impl Number {
|
||||
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
|
||||
match op {
|
||||
CompareOp::Eq => (self.0 == other.0).into_py(py),
|
||||
CompareOp::Ne => (self.0 != other.0).into_py(py),
|
||||
_ => py.NotImplemented(),
|
||||
}
|
||||
fn __eq__(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
# fn main() -> PyResult<()> {
|
||||
# Python::with_gil(|py| {
|
||||
# let x = PyCell::new(py, Number(4))?;
|
||||
# let y = PyCell::new(py, Number(4))?;
|
||||
# assert!(x.eq(y)?);
|
||||
# assert!(!x.ne(y)?);
|
||||
# Ok(())
|
||||
# })
|
||||
# }
|
||||
```
|
||||
|
||||
### Truthyness
|
||||
|
|
|
@ -76,7 +76,8 @@ given signatures should be interpreted as follows:
|
|||
- `__richcmp__(<self>, object, pyo3::basic::CompareOp) -> object`
|
||||
|
||||
Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method.
|
||||
The `CompareOp` argument indicates the comparison operation being performed.
|
||||
The `CompareOp` argument indicates the comparison operation being performed. You can use
|
||||
[`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result to the requested comparison.
|
||||
|
||||
_This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._
|
||||
|
||||
|
@ -84,11 +85,32 @@ given signatures should be interpreted as follows:
|
|||
<details>
|
||||
<summary>Return type</summary>
|
||||
The return type will normally be `PyResult<bool>`, but any Python object can be returned.
|
||||
|
||||
If you want to leave some operations unimplemented, you can return `py.NotImplemented()`
|
||||
for some of the operations:
|
||||
|
||||
```rust
|
||||
use pyo3::class::basic::CompareOp;
|
||||
|
||||
# use pyo3::prelude::*;
|
||||
#
|
||||
# #[pyclass]
|
||||
# struct Number(i32);
|
||||
#
|
||||
#[pymethods]
|
||||
impl Number {
|
||||
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
|
||||
match op {
|
||||
CompareOp::Eq => (self.0 == other.0).into_py(py),
|
||||
CompareOp::Ne => (self.0 != other.0).into_py(py),
|
||||
_ => py.NotImplemented(),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
If the second argument `object` is not of the type specified in the
|
||||
signature, the generated code will automatically `return NotImplemented`.
|
||||
|
||||
You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result
|
||||
to the requested comparison.
|
||||
</details>
|
||||
|
||||
- `__getattr__(<self>, object) -> object`
|
||||
|
|
|
@ -23,10 +23,14 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
|
|||
c = ty(1)
|
||||
|
||||
assert a == b
|
||||
assert not (a != b)
|
||||
assert a != c
|
||||
assert not (a == c)
|
||||
|
||||
assert b == a
|
||||
assert not (a != b)
|
||||
assert b != c
|
||||
assert not (b == c)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
assert a <= b
|
||||
|
@ -49,17 +53,21 @@ class PyEqDefaultNe:
|
|||
return self.x == other.x
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
|
||||
@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python"))
|
||||
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
|
||||
a = ty(0)
|
||||
b = ty(0)
|
||||
c = ty(1)
|
||||
|
||||
assert a == b
|
||||
assert not (a != b)
|
||||
assert a != c
|
||||
assert not (a == c)
|
||||
|
||||
assert b == a
|
||||
assert not (a != b)
|
||||
assert b != c
|
||||
assert not (b == c)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
assert a <= b
|
||||
|
@ -152,19 +160,25 @@ def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]
|
|||
c = ty(1)
|
||||
|
||||
assert a == b
|
||||
assert not (a != b)
|
||||
assert a <= b
|
||||
assert a >= b
|
||||
assert a != c
|
||||
assert not (a == c)
|
||||
assert a <= c
|
||||
|
||||
assert b == a
|
||||
assert not (b != a)
|
||||
assert b <= a
|
||||
assert b >= a
|
||||
assert b != c
|
||||
assert not (b == c)
|
||||
assert b <= c
|
||||
|
||||
assert c != a
|
||||
assert not (c == a)
|
||||
assert c != b
|
||||
assert not (c == b)
|
||||
assert c > a
|
||||
assert c >= a
|
||||
assert c > b
|
||||
|
|
|
@ -6,6 +6,7 @@ use crate::{
|
|||
internal_tricks::extract_c_string,
|
||||
pycell::PyCellLayout,
|
||||
pyclass_init::PyObjectInit,
|
||||
types::PyBool,
|
||||
Py, PyAny, PyCell, PyClass, PyErr, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python,
|
||||
};
|
||||
use std::{
|
||||
|
@ -805,11 +806,14 @@ slot_fragment_trait! {
|
|||
#[inline]
|
||||
unsafe fn __ne__(
|
||||
self,
|
||||
_py: Python<'_>,
|
||||
_slf: *mut ffi::PyObject,
|
||||
_other: *mut ffi::PyObject,
|
||||
py: Python<'_>,
|
||||
slf: *mut ffi::PyObject,
|
||||
other: *mut ffi::PyObject,
|
||||
) -> PyResult<*mut ffi::PyObject> {
|
||||
Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented()))
|
||||
// By default `__ne__` will try `__eq__` and invert the result
|
||||
let slf: &PyAny = py.from_borrowed_ptr(slf);
|
||||
let other: &PyAny = py.from_borrowed_ptr(other);
|
||||
slf.eq(other).map(|is_eq| PyBool::new(py, !is_eq).into_ptr())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue