Add a test that shows __add__ overrides __radd__

This commit is contained in:
kngwyu 2020-03-29 00:26:11 +09:00
parent 970e393bb9
commit a76bd7c4e3

View file

@ -34,7 +34,7 @@ fn unary_arithmetic() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = Py::new(py, UnaryArithmetic {}).unwrap();
let c = PyCell::new(py, UnaryArithmetic {}).unwrap();
py_run!(py, c, "assert -c == 'neg'");
py_run!(py, c, "assert +c == 'pos'");
py_run!(py, c, "assert abs(c) == 'abs'");
@ -112,7 +112,7 @@ fn inplace_operations() {
let py = gil.python();
let init = |value, code| {
let c = Py::new(py, InPlaceOperations { value }).unwrap();
let c = PyCell::new(py, InPlaceOperations { value }).unwrap();
py_run!(py, c, code);
};
@ -166,7 +166,7 @@ fn binary_arithmetic() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = Py::new(py, BinaryArithmetic {}).unwrap();
let c = PyCell::new(py, BinaryArithmetic {}).unwrap();
py_run!(py, c, "assert c + c == 'BA + BA'");
py_run!(py, c, "assert c.__add__(c) == 'BA + BA'");
py_run!(py, c, "assert c + 1 == 'BA + 1'");
@ -235,7 +235,7 @@ fn rhs_arithmetic() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = Py::new(py, RhsArithmetic {}).unwrap();
let c = PyCell::new(py, RhsArithmetic {}).unwrap();
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
@ -256,6 +256,47 @@ fn rhs_arithmetic() {
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
}
#[pyclass]
struct LhsAndRhsArithmetic {}
#[pyproto]
impl PyNumberProtocol for LhsAndRhsArithmetic {
fn __radd__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + RA", other))
}
fn __rsub__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - RA", other))
}
fn __add__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + {:?}", lhs, rhs))
}
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - {:?}", lhs, rhs))
}
}
#[pyproto]
impl PyObjectProtocol for LhsAndRhsArithmetic {
fn __repr__(&self) -> PyResult<&'static str> {
Ok("BA")
}
}
#[test]
fn lhs_override_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap();
py_run!(py, c, "assert c.__radd__(1) == '1 + BA'");
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - BA'");
py_run!(py, c, "assert 1 - c == '1 - BA'");
}
#[pyclass]
struct RichComparisons {}
@ -301,7 +342,7 @@ fn rich_comparisons() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = Py::new(py, RichComparisons {}).unwrap();
let c = PyCell::new(py, RichComparisons {}).unwrap();
py_run!(py, c, "assert (c < c) == 'RC < RC'");
py_run!(py, c, "assert (c < 1) == 'RC < 1'");
py_run!(py, c, "assert (1 < c) == 'RC > 1'");
@ -327,7 +368,7 @@ fn rich_comparisons_python_3_type_error() {
let gil = Python::acquire_gil();
let py = gil.python();
let c2 = Py::new(py, RichComparisons2 {}).unwrap();
let c2 = PyCell::new(py, RichComparisons2 {}).unwrap();
py_expect_exception!(py, c2, "c2 < c2", TypeError);
py_expect_exception!(py, c2, "c2 < 1", TypeError);
py_expect_exception!(py, c2, "1 < c2", TypeError);