Don't raise TypeError from generated equality method (#4287)

* Don't raise TypeError in derived equality method

* Add newsfragment
This commit is contained in:
jatoben 2024-06-25 22:41:42 -07:00 committed by GitHub
parent 2e2d4404a6
commit 7c2f5e80de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 48 additions and 9 deletions

View file

@ -0,0 +1 @@
Return `NotImplemented` from generated equality method when comparing different types.

View file

@ -1844,9 +1844,13 @@ fn pyclass_richcmp(
op: #pyo3_path::pyclass::CompareOp
) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
let self_val = self;
let other = &*#pyo3_path::types::PyAnyMethods::downcast::<Self>(other)?.borrow();
match op {
#arms
if let Ok(other) = #pyo3_path::types::PyAnyMethods::downcast::<Self>(other) {
let other = &*other.borrow();
match op {
#arms
}
} else {
::std::result::Result::Ok(py.NotImplemented())
}
}
};

View file

@ -34,6 +34,18 @@ impl EqDefaultNe {
}
}
#[pyclass(eq)]
#[derive(PartialEq, Eq)]
struct EqDerived(i64);
#[pymethods]
impl EqDerived {
#[new]
fn new(value: i64) -> Self {
Self(value)
}
}
#[pyclass]
struct Ordered(i64);
@ -104,6 +116,7 @@ impl OrderedDefaultNe {
pub fn comparisons(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Eq>()?;
m.add_class::<EqDefaultNe>()?;
m.add_class::<EqDerived>()?;
m.add_class::<Ordered>()?;
m.add_class::<OrderedDefaultNe>()?;
Ok(())

View file

@ -1,7 +1,13 @@
from typing import Type, Union
import pytest
from pyo3_pytests.comparisons import Eq, EqDefaultNe, Ordered, OrderedDefaultNe
from pyo3_pytests.comparisons import (
Eq,
EqDefaultNe,
EqDerived,
Ordered,
OrderedDefaultNe,
)
from typing_extensions import Self
@ -9,15 +15,23 @@ class PyEq:
def __init__(self, x: int) -> None:
self.x = x
def __eq__(self, other: Self) -> bool:
return self.x == other.x
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.x == other.x
else:
return NotImplemented
def __ne__(self, other: Self) -> bool:
return self.x != other.x
if isinstance(other, self.__class__):
return self.x != other.x
else:
return NotImplemented
@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
def test_eq(ty: Type[Union[Eq, PyEq]]):
@pytest.mark.parametrize(
"ty", (Eq, EqDerived, PyEq), ids=("rust", "rust-derived", "python")
)
def test_eq(ty: Type[Union[Eq, EqDerived, PyEq]]):
a = ty(0)
b = ty(0)
c = ty(1)
@ -32,6 +46,13 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
assert b != c
assert not (b == c)
assert not a == 0
assert a != 0
assert not b == 0
assert b != 1
assert not c == 1
assert c != 1
with pytest.raises(TypeError):
assert a <= b