pyo3/pytests/tests/test_comparisons.py

172 lines
3.4 KiB
Python

from typing import Type, Union
import pytest
from pyo3_pytests.comparisons import Eq, EqDefaultNe, Ordered, OrderedDefaultNe
from typing_extensions import Self
class PyEq:
def __init__(self, x: int) -> None:
self.x = x
def __eq__(self, other: Self) -> bool:
return self.x == other.x
def __ne__(self, other: Self) -> bool:
return self.x != other.x
@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
def test_eq(ty: Type[Union[Eq, PyEq]]):
    """Check ``==`` / ``!=`` on ``ty`` and that ordering comparisons raise.

    Runs once for each parametrized type (ids ``rust`` and ``python``).
    """
    zero = ty(0)
    zero_twin = ty(0)
    one = ty(1)

    # Equality and inequality, checked from both equal-valued instances.
    assert zero == zero_twin
    assert zero != one

    assert zero_twin == zero
    assert zero_twin != one

    # Every ordering operator is expected to be rejected with TypeError.
    with pytest.raises(TypeError):
        assert zero <= zero_twin

    with pytest.raises(TypeError):
        assert zero >= zero_twin

    with pytest.raises(TypeError):
        assert zero < one

    with pytest.raises(TypeError):
        assert one > zero
class PyEqDefaultNe:
def __init__(self, x: int) -> None:
self.x = x
def __eq__(self, other: Self) -> bool:
return self.x == other.x
# Parametrize over the *DefaultNe* types: the previous (Eq, PyEq) tuple was a
# copy-paste of test_eq and left EqDefaultNe / PyEqDefaultNe untested.
@pytest.mark.parametrize(
    "ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python")
)
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
    """Check ``==`` and the defaulted ``!=`` on ``ty``; ordering must raise.

    Runs once for each parametrized type (ids ``rust`` and ``python``).
    """
    a = ty(0)
    b = ty(0)
    c = ty(1)

    # ``!=`` is expected to work even though only ``__eq__`` is provided.
    assert a == b
    assert a != c

    assert b == a
    assert b != c

    # Every ordering operator is expected to be rejected with TypeError.
    with pytest.raises(TypeError):
        assert a <= b

    with pytest.raises(TypeError):
        assert a >= b

    with pytest.raises(TypeError):
        assert a < c

    with pytest.raises(TypeError):
        assert c > a
class PyOrdered:
def __init__(self, x: int) -> None:
self.x = x
def __lt__(self, other: Self) -> bool:
return self.x < other.x
def __le__(self, other: Self) -> bool:
return self.x <= other.x
def __eq__(self, other: Self) -> bool:
return self.x == other.x
def __ne__(self, other: Self) -> bool:
return self.x != other.x
def __gt__(self, other: Self) -> bool:
return self.x >= other.x
def __ge__(self, other: Self) -> bool:
return self.x >= other.x
@pytest.mark.parametrize("ty", (Ordered, PyOrdered), ids=("rust", "python"))
def test_ordered(ty: Type[Union[Ordered, PyOrdered]]):
    """Check equality and ordering operators on ``ty``.

    Uses two equal-valued instances and one larger instance; runs once for
    each parametrized type (ids ``rust`` and ``python``).
    """
    low = ty(0)
    low_twin = ty(0)
    high = ty(1)

    # First equal-valued instance on the left-hand side.
    assert low == low_twin
    assert low <= low_twin
    assert low >= low_twin
    assert low != high
    assert low <= high

    # Second equal-valued instance on the left-hand side.
    assert low_twin == low
    assert low_twin <= low
    assert low_twin >= low
    assert low_twin != high
    assert low_twin <= high

    # Larger instance on the left-hand side.
    assert high != low
    assert high != low_twin
    assert high > low
    assert high >= low
    assert high > low_twin
    assert high >= low_twin
class PyOrderedDefaultNe:
def __init__(self, x: int) -> None:
self.x = x
def __lt__(self, other: Self) -> bool:
return self.x < other.x
def __le__(self, other: Self) -> bool:
return self.x <= other.x
def __eq__(self, other: Self) -> bool:
return self.x == other.x
def __gt__(self, other: Self) -> bool:
return self.x >= other.x
def __ge__(self, other: Self) -> bool:
return self.x >= other.x
@pytest.mark.parametrize(
    "ty", (OrderedDefaultNe, PyOrderedDefaultNe), ids=("rust", "python")
)
def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]]):
    """Check equality, the defaulted ``!=``, and ordering operators on ``ty``.

    Uses two equal-valued instances and one larger instance; runs once for
    each parametrized type (ids ``rust`` and ``python``).
    """
    low = ty(0)
    low_twin = ty(0)
    high = ty(1)

    # First equal-valued instance on the left-hand side.
    assert low == low_twin
    assert low <= low_twin
    assert low >= low_twin
    assert low != high
    assert low <= high

    # Second equal-valued instance on the left-hand side.
    assert low_twin == low
    assert low_twin <= low
    assert low_twin >= low
    assert low_twin != high
    assert low_twin <= high

    # Larger instance on the left-hand side.
    assert high != low
    assert high != low_twin
    assert high > low
    assert high >= low
    assert high > low_twin
    assert high >= low_twin