Don't raise TypeError
from generated equality method (#4287)
* Don't raise TypeError in derived equality method * Add newsfragment
This commit is contained in:
parent
2e2d4404a6
commit
7c2f5e80de
1
newsfragments/4287.changed.md
Normal file
1
newsfragments/4287.changed.md
Normal file
|
@ -0,0 +1 @@
|
|||
Return `NotImplemented` from generated equality method when comparing different types.
|
|
@ -1844,9 +1844,13 @@ fn pyclass_richcmp(
|
|||
op: #pyo3_path::pyclass::CompareOp
|
||||
) -> #pyo3_path::PyResult<#pyo3_path::PyObject> {
|
||||
let self_val = self;
|
||||
let other = &*#pyo3_path::types::PyAnyMethods::downcast::<Self>(other)?.borrow();
|
||||
match op {
|
||||
#arms
|
||||
if let Ok(other) = #pyo3_path::types::PyAnyMethods::downcast::<Self>(other) {
|
||||
let other = &*other.borrow();
|
||||
match op {
|
||||
#arms
|
||||
}
|
||||
} else {
|
||||
::std::result::Result::Ok(py.NotImplemented())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -34,6 +34,18 @@ impl EqDefaultNe {
|
|||
}
|
||||
}
|
||||
|
||||
#[pyclass(eq)]
|
||||
#[derive(PartialEq, Eq)]
|
||||
struct EqDerived(i64);
|
||||
|
||||
#[pymethods]
|
||||
impl EqDerived {
|
||||
#[new]
|
||||
fn new(value: i64) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct Ordered(i64);
|
||||
|
||||
|
@ -104,6 +116,7 @@ impl OrderedDefaultNe {
|
|||
pub fn comparisons(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<Eq>()?;
|
||||
m.add_class::<EqDefaultNe>()?;
|
||||
m.add_class::<EqDerived>()?;
|
||||
m.add_class::<Ordered>()?;
|
||||
m.add_class::<OrderedDefaultNe>()?;
|
||||
Ok(())
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
from typing import Type, Union
|
||||
|
||||
import pytest
|
||||
from pyo3_pytests.comparisons import Eq, EqDefaultNe, Ordered, OrderedDefaultNe
|
||||
from pyo3_pytests.comparisons import (
|
||||
Eq,
|
||||
EqDefaultNe,
|
||||
EqDerived,
|
||||
Ordered,
|
||||
OrderedDefaultNe,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
|
@ -9,15 +15,23 @@ class PyEq:
|
|||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
|
||||
def __eq__(self, other: Self) -> bool:
|
||||
return self.x == other.x
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return self.x == other.x
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Self) -> bool:
|
||||
return self.x != other.x
|
||||
if isinstance(other, self.__class__):
|
||||
return self.x != other.x
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
|
||||
def test_eq(ty: Type[Union[Eq, PyEq]]):
|
||||
@pytest.mark.parametrize(
|
||||
"ty", (Eq, EqDerived, PyEq), ids=("rust", "rust-derived", "python")
|
||||
)
|
||||
def test_eq(ty: Type[Union[Eq, EqDerived, PyEq]]):
|
||||
a = ty(0)
|
||||
b = ty(0)
|
||||
c = ty(1)
|
||||
|
@ -32,6 +46,13 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
|
|||
assert b != c
|
||||
assert not (b == c)
|
||||
|
||||
assert not a == 0
|
||||
assert a != 0
|
||||
assert not b == 0
|
||||
assert b != 1
|
||||
assert not c == 1
|
||||
assert c != 1
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
assert a <= b
|
||||
|
||||
|
|
Loading…
Reference in a new issue