More tests for RHS

This commit is contained in:
kngwyu 2020-08-16 18:55:36 +09:00
parent 554ccb9bee
commit 71a7a76227
3 changed files with 166 additions and 46 deletions

View File

@ -109,7 +109,7 @@ macro_rules! py_binary_reversed_num_func {
}};
}
macro_rules! py_binary_fallbacked_num_func {
macro_rules! py_binary_fallback_num_func {
($class:ident, $lop_trait: ident :: $lop: ident, $rop_trait: ident :: $rop: ident) => {{
unsafe extern "C" fn wrap<T>(
lhs: *mut ffi::PyObject,

View File

@ -589,7 +589,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberAddProtocol<'p> + for<'p> PyNumberRAddProtocol<'p>,
{
self.nb_add = py_binary_fallbacked_num_func!(
self.nb_add = py_binary_fallback_num_func!(
T,
PyNumberAddProtocol::__add__,
PyNumberRAddProtocol::__radd__
@ -611,7 +611,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberSubProtocol<'p> + for<'p> PyNumberRSubProtocol<'p>,
{
self.nb_subtract = py_binary_fallbacked_num_func!(
self.nb_subtract = py_binary_fallback_num_func!(
T,
PyNumberSubProtocol::__sub__,
PyNumberRSubProtocol::__rsub__
@ -633,7 +633,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberMulProtocol<'p> + for<'p> PyNumberRMulProtocol<'p>,
{
self.nb_multiply = py_binary_fallbacked_num_func!(
self.nb_multiply = py_binary_fallback_num_func!(
T,
PyNumberMulProtocol::__mul__,
PyNumberRMulProtocol::__rmul__
@ -661,7 +661,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberDivmodProtocol<'p> + for<'p> PyNumberRDivmodProtocol<'p>,
{
self.nb_divmod = py_binary_fallbacked_num_func!(
self.nb_divmod = py_binary_fallback_num_func!(
T,
PyNumberDivmodProtocol::__divmod__,
PyNumberRDivmodProtocol::__rdivmod__
@ -780,7 +780,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberLShiftProtocol<'p> + for<'p> PyNumberRLShiftProtocol<'p>,
{
self.nb_lshift = py_binary_fallbacked_num_func!(
self.nb_lshift = py_binary_fallback_num_func!(
T,
PyNumberLShiftProtocol::__lshift__,
PyNumberRLShiftProtocol::__rlshift__
@ -802,7 +802,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberRShiftProtocol<'p> + for<'p> PyNumberRRShiftProtocol<'p>,
{
self.nb_rshift = py_binary_fallbacked_num_func!(
self.nb_rshift = py_binary_fallback_num_func!(
T,
PyNumberRShiftProtocol::__rshift__,
PyNumberRRShiftProtocol::__rrshift__
@ -824,7 +824,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberAndProtocol<'p> + for<'p> PyNumberRAndProtocol<'p>,
{
self.nb_and = py_binary_fallbacked_num_func!(
self.nb_and = py_binary_fallback_num_func!(
T,
PyNumberAndProtocol::__and__,
PyNumberRAndProtocol::__rand__
@ -846,7 +846,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberXorProtocol<'p> + for<'p> PyNumberRXorProtocol<'p>,
{
self.nb_xor = py_binary_fallbacked_num_func!(
self.nb_xor = py_binary_fallback_num_func!(
T,
PyNumberXorProtocol::__xor__,
PyNumberRXorProtocol::__rxor__
@ -868,7 +868,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberOrProtocol<'p> + for<'p> PyNumberROrProtocol<'p>,
{
self.nb_or = py_binary_fallbacked_num_func!(
self.nb_or = py_binary_fallback_num_func!(
T,
PyNumberOrProtocol::__or__,
PyNumberROrProtocol::__ror__
@ -980,7 +980,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberFloordivProtocol<'p> + for<'p> PyNumberRFloordivProtocol<'p>,
{
self.nb_floor_divide = py_binary_fallbacked_num_func!(
self.nb_floor_divide = py_binary_fallback_num_func!(
T,
PyNumberFloordivProtocol::__floordiv__,
PyNumberRFloordivProtocol::__rfloordiv__
@ -1003,7 +1003,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberTruedivProtocol<'p> + for<'p> PyNumberRTruedivProtocol<'p>,
{
self.nb_true_divide = py_binary_fallbacked_num_func!(
self.nb_true_divide = py_binary_fallback_num_func!(
T,
PyNumberTruedivProtocol::__truediv__,
PyNumberRTruedivProtocol::__rtruediv__
@ -1046,7 +1046,7 @@ impl ffi::PyNumberMethods {
where
T: for<'p> PyNumberMatmulProtocol<'p> + for<'p> PyNumberRMatmulProtocol<'p>,
{
self.nb_matrix_multiply = py_binary_fallbacked_num_func!(
self.nb_matrix_multiply = py_binary_fallback_num_func!(
T,
PyNumberMatmulProtocol::__matmul__,
PyNumberRMatmulProtocol::__rmatmul__

View File

@ -280,22 +280,10 @@ fn rhs_arithmetic() {
}
#[pyclass]
struct LhsAndRhsArithmetic {}
struct LhsOverridesRhs {}
#[pyproto]
impl PyNumberProtocol for LhsAndRhsArithmetic {
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}
fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}
impl PyNumberProtocol for LhsOverridesRhs {
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}
@ -304,83 +292,215 @@ impl PyNumberProtocol for LhsAndRhsArithmetic {
format!("{:?} - {:?}", lhs, rhs)
}
fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> String {
fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}
fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
}
fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}
fn __and__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}
fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}
fn __or__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}
fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<&PyAny>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}
fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}
fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}
fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}
fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}
fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}
fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}
fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}
}
#[pyproto]
impl PyObjectProtocol for LhsAndRhsArithmetic {
impl PyObjectProtocol for LhsOverridesRhs {
fn __repr__(&self) -> &'static str {
"BA"
}
}
#[test]
fn lhs_override_rhs() {
fn lhs_overrides_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap();
let c = PyCell::new(py, LhsOverridesRhs {}).unwrap();
// Not overrided
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'");
py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'");
py_run!(py, c, "assert c.__rrshift__(1) == '1 >> RA'");
py_run!(py, c, "assert c.__rand__(1) == '1 & RA'");
py_run!(py, c, "assert c.__rxor__(1) == '1 ^ RA'");
py_run!(py, c, "assert c.__ror__(1) == '1 | RA'");
py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'");
// Overrided
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert 1 - c == '1 - BA'");
py_run!(py, c, "assert 1 * c == '1 * BA'");
py_run!(py, c, "assert 1 << c == '1 << BA'");
py_run!(py, c, "assert 1 >> c == '1 >> BA'");
py_run!(py, c, "assert 1 & c == '1 & BA'");
py_run!(py, c, "assert 1 ^ c == '1 ^ BA'");
py_run!(py, c, "assert 1 | c == '1 | BA'");
py_run!(py, c, "assert 1 ** c == '1 ** BA'");
}
#[pyclass]
#[derive(Debug)]
struct Lhs2Rhs {}
struct LhsFellbackToRhs {}
#[pyproto]
impl PyNumberProtocol for Lhs2Rhs {
fn __add__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {
impl PyNumberProtocol for LhsFellbackToRhs {
fn __add__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}
fn __sub__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {
fn __sub__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}
fn __pow__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny, _mod: Option<usize>) -> String {
fn __mul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}
fn __lshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
}
fn __rshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}
fn __and__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}
fn __xor__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}
fn __or__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}
fn __pow__(lhs: PyRef<Self>, rhs: &PyAny, _mod: Option<usize>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}
fn __matmul__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {
fn __matmul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} @ {:?}", lhs, rhs)
}
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}
fn __rsub__(&self, rhs: &PyAny) -> String {
format!("{:?} - RA", rhs)
fn __rsub__(&self, other: &PyAny) -> String {
format!("{:?} - RA", other)
}
fn __rpow__(&self, rhs: &PyAny, _mod: Option<usize>) -> String {
format!("{:?} ** RA", rhs)
fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}
fn __rmatmul__(&self, rhs: &PyAny) -> String {
format!("{:?} @ RA", rhs)
fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}
fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}
fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}
fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}
fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other)
}
fn __rmatmul__(&self, other: &PyAny) -> String {
format!("{:?} @ RA", other)
}
}
#[pyproto]
impl PyObjectProtocol for Lhs2Rhs {
impl PyObjectProtocol for LhsFellbackToRhs {
fn __repr__(&self) -> &'static str {
"BA"
}
}
#[test]
fn lhs_fallbacked_to_rhs() {
fn lhs_fellback_to_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = PyCell::new(py, Lhs2Rhs {}).unwrap();
let c = PyCell::new(py, LhsFellbackToRhs {}).unwrap();
// Fallbacked to RHS because of type mismatching
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert 1 * c == '1 * RA'");
py_run!(py, c, "assert 1 << c == '1 << RA'");
py_run!(py, c, "assert 1 >> c == '1 >> RA'");
py_run!(py, c, "assert 1 & c == '1 & RA'");
py_run!(py, c, "assert 1 ^ c == '1 ^ RA'");
py_run!(py, c, "assert 1 | c == '1 | RA'");
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
}