From e1d4173827f1881411baf20cd379b1ce77b88acb Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 11 Oct 2023 09:23:25 +0100 Subject: [PATCH] Fix bug in default implementation of `__ne__` --- guide/src/class/object.md | 29 ++++++++++++++++------------- guide/src/class/protocols.md | 30 ++++++++++++++++++++++++++---- pytests/tests/test_comparisons.py | 16 +++++++++++++++- src/impl_/pyclass.rs | 12 ++++++++---- 4 files changed, 65 insertions(+), 22 deletions(-) diff --git a/guide/src/class/object.md b/guide/src/class/object.md index 7613e296..c6bf0483 100644 --- a/guide/src/class/object.md +++ b/guide/src/class/object.md @@ -73,7 +73,7 @@ impl Number { In the `__repr__`, we used a hard-coded class name. This is sometimes not ideal, because if the class is subclassed in Python, we would like the repr to reflect -the subclass name. This is typically done in Python code by accessing +the subclass name. This is typically done in Python code by accessing `self.__class__.__name__`. In order to be able to access the Python type information *and* the Rust struct, we need to use a `PyCell` as the `self` argument. @@ -149,8 +149,8 @@ impl Number { ### Comparisons -Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`, - `__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`. +PyO3 supports the usual magic comparison methods available in Python such as `__eq__`, `__lt__` +and so on. It is also possible to support all six operations at once with `__richcmp__`. This method will be called with a value of `CompareOp` depending on the operation. ```rust @@ -198,13 +198,10 @@ impl Number { It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches the given `CompareOp`. -Alternatively, if you want to leave some operations unimplemented, you can -return `py.NotImplemented()` for some of the operations: +Alternatively, you can implement just equality using `__eq__`: ```rust -use pyo3::class::basic::CompareOp; - # use pyo3::prelude::*; # # #[pyclass] @@ -212,14 +209,20 @@ use pyo3::class::basic::CompareOp; # #[pymethods] impl Number { - fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject { - match op { - CompareOp::Eq => (self.0 == other.0).into_py(py), - CompareOp::Ne => (self.0 != other.0).into_py(py), - _ => py.NotImplemented(), - } + fn __eq__(&self, other: &Self) -> bool { + self.0 == other.0 } } + +# fn main() -> PyResult<()> { +# Python::with_gil(|py| { +# let x = PyCell::new(py, Number(4))?; +# let y = PyCell::new(py, Number(4))?; +# assert!(x.eq(y)?); +# assert!(!x.ne(y)?); +# Ok(()) +# }) +# } ``` ### Truthyness diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index 891635c0..411978f0 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -76,7 +76,8 @@ given signatures should be interpreted as follows: - `__richcmp__(, object, pyo3::basic::CompareOp) -> object` Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method. - The `CompareOp` argument indicates the comparison operation being performed. + The `CompareOp` argument indicates the comparison operation being performed. You can use + [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result to the requested comparison. _This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._ @@ -84,11 +85,32 @@ given signatures should be interpreted as follows:
Return type The return type will normally be `PyResult`, but any Python object can be returned. + + If you want to leave some operations unimplemented, you can return `py.NotImplemented()` + for some of the operations: + + ```rust + use pyo3::class::basic::CompareOp; + + # use pyo3::prelude::*; + # + # #[pyclass] + # struct Number(i32); + # + #[pymethods] + impl Number { + fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject { + match op { + CompareOp::Eq => (self.0 == other.0).into_py(py), + CompareOp::Ne => (self.0 != other.0).into_py(py), + _ => py.NotImplemented(), + } + } + } + ``` + If the second argument `object` is not of the type specified in the signature, the generated code will automatically `return NotImplemented`. - - You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result - to the requested comparison.
- `__getattr__(, object) -> object` diff --git a/pytests/tests/test_comparisons.py b/pytests/tests/test_comparisons.py index 54bb7aaf..508cdeb2 100644 --- a/pytests/tests/test_comparisons.py +++ b/pytests/tests/test_comparisons.py @@ -23,10 +23,14 @@ def test_eq(ty: Type[Union[Eq, PyEq]]): c = ty(1) assert a == b + assert not (a != b) assert a != c + assert not (a == c) assert b == a + assert not (a != b) assert b != c + assert not (b == c) with pytest.raises(TypeError): assert a <= b @@ -49,17 +53,21 @@ class PyEqDefaultNe: return self.x == other.x -@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python")) +@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python")) def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]): a = ty(0) b = ty(0) c = ty(1) assert a == b + assert not (a != b) assert a != c + assert not (a == c) assert b == a + assert not (a != b) assert b != c + assert not (b == c) with pytest.raises(TypeError): assert a <= b @@ -152,19 +160,25 @@ def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe] c = ty(1) assert a == b + assert not (a != b) assert a <= b assert a >= b assert a != c + assert not (a == c) assert a <= c assert b == a + assert not (b != a) assert b <= a assert b >= a assert b != c + assert not (b == c) assert b <= c assert c != a + assert not (c == a) assert c != b + assert not (c == b) assert c > a assert c >= a assert c > b diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs index 25e15fb3..3941dfcb 100644 --- a/src/impl_/pyclass.rs +++ b/src/impl_/pyclass.rs @@ -6,6 +6,7 @@ use crate::{ internal_tricks::extract_c_string, pycell::PyCellLayout, pyclass_init::PyObjectInit, + types::PyBool, Py, PyAny, PyCell, PyClass, PyErr, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python, }; use std::{ @@ -805,11 +806,14 @@ slot_fragment_trait! { #[inline] unsafe fn __ne__( self, - _py: Python<'_>, - _slf: *mut ffi::PyObject, - _other: *mut ffi::PyObject, + py: Python<'_>, + slf: *mut ffi::PyObject, + other: *mut ffi::PyObject, ) -> PyResult<*mut ffi::PyObject> { - Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented())) + // By default `__ne__` will try `__eq__` and invert the result + let slf: &PyAny = py.from_borrowed_ptr(slf); + let other: &PyAny = py.from_borrowed_ptr(other); + slf.eq(other).map(|is_eq| PyBool::new(py, !is_eq).into_ptr()) } }