From 7349513f5b2da3baae736156baa3ab60a22b7467 Mon Sep 17 00:00:00 2001 From: Azat Ibrakov Date: Wed, 20 Oct 2021 01:14:26 +0300 Subject: [PATCH] Add fallback for `__mod__` magic method (#1934) * Add fallback for `__mod__` magic method * Add 'CHANGELOG' entry * Complete tests --- CHANGELOG.md | 1 + pyo3-macros-backend/src/defs.rs | 2 ++ src/class/number.rs | 7 +++++++ tests/test_arithmetics_protos.rs | 22 ++++++++++++++++++++++ 4 files changed, 32 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b3b5691..8755fcb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix incorrect linking to version-specific DLL instead of `python3.dll` when cross-compiling to Windows with `abi3`. [#1880](https://github.com/PyO3/pyo3/pull/1880) - Fix panic in generated `#[derive(FromPyObject)]` for enums. [#1888](https://github.com/PyO3/pyo3/pull/1888) - Fix cross-compiling to Python 3.7 builds with the "m" abi flag. [#1908](https://github.com/PyO3/pyo3/pull/1908) +- Fix `__mod__` magic method fallback to `__rmod__`. [#1934](https://github.com/PyO3/pyo3/pull/1934). ## [0.14.5] - 2021-09-05 diff --git a/pyo3-macros-backend/src/defs.rs b/pyo3-macros-backend/src/defs.rs index 0c92906d..bdda83b6 100644 --- a/pyo3-macros-backend/src/defs.rs +++ b/pyo3-macros-backend/src/defs.rs @@ -528,7 +528,9 @@ pub const NUM: Proto = Proto { SlotDef::new(&["__mul__", "__rmul__"], "Py_nb_multiply", "mul_rmul"), SlotDef::new(&["__mul__"], "Py_nb_multiply", "mul"), SlotDef::new(&["__rmul__"], "Py_nb_multiply", "rmul"), + SlotDef::new(&["__mod__", "__rmod__"], "Py_nb_remainder", "mod_rmod"), SlotDef::new(&["__mod__"], "Py_nb_remainder", "mod_"), + SlotDef::new(&["__rmod__"], "Py_nb_remainder", "rmod"), SlotDef::new( &["__divmod__", "__rdivmod__"], "Py_nb_divmod", diff --git a/src/class/number.rs b/src/class/number.rs index 20dc8c3d..903744eb 100644 --- a/src/class/number.rs +++ b/src/class/number.rs @@ -614,7 +614,14 @@ py_binary_fallback_num_func!( ); py_binary_num_func!(mul, PyNumberMulProtocol, T::__mul__); py_binary_reversed_num_func!(rmul, PyNumberRMulProtocol, T::__rmul__); +py_binary_fallback_num_func!( + mod_rmod, + T, + PyNumberModProtocol::__mod__, + PyNumberRModProtocol::__rmod__ +); py_binary_num_func!(mod_, PyNumberModProtocol, T::__mod__); +py_binary_reversed_num_func!(rmod, PyNumberRModProtocol, T::__rmod__); py_binary_fallback_num_func!( divmod_rdivmod, T, diff --git a/tests/test_arithmetics_protos.rs b/tests/test_arithmetics_protos.rs index dc64155e..9036d737 100644 --- a/tests/test_arithmetics_protos.rs +++ b/tests/test_arithmetics_protos.rs @@ -152,6 +152,10 @@ impl PyNumberProtocol for BinaryArithmetic { format!("{:?} - {:?}", lhs, rhs) } + fn __mod__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} % {:?}", lhs, rhs) + } + fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String { format!("{:?} * {:?}", lhs, rhs) } @@ -195,6 +199,8 @@ fn binary_arithmetic() { py_run!(py, c, "assert 1 - c == '1 - BA'"); py_run!(py, c, "assert c * 1 == 'BA * 1'"); py_run!(py, c, "assert 1 * c == '1 * BA'"); + py_run!(py, c, "assert c % 1 == 'BA % 1'"); + py_run!(py, c, "assert 1 % c == '1 % BA'"); py_run!(py, c, "assert c << 1 == 'BA << 1'"); py_run!(py, c, "assert 1 << c == '1 << BA'"); @@ -225,6 +231,10 @@ impl PyNumberProtocol for RhsArithmetic { format!("{:?} - RA", other) } + fn __rmod__(&self, other: &PyAny) -> String { + format!("{:?} % RA", other) + } + fn __rmul__(&self, other: &PyAny) -> String { format!("{:?} * RA", other) } @@ -264,6 +274,8 @@ fn rhs_arithmetic() { py_run!(py, c, "assert 1 + c == '1 + RA'"); py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'"); py_run!(py, c, "assert 1 - c == '1 - RA'"); + py_run!(py, c, "assert c.__rmod__(1) == '1 % RA'"); + py_run!(py, c, "assert 1 % c == '1 % RA'"); py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'"); py_run!(py, c, "assert 1 * c == '1 * RA'"); py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'"); @@ -299,6 +311,10 @@ impl PyNumberProtocol for LhsAndRhs { format!("{:?} - {:?}", lhs, rhs) } + fn __mod__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} % {:?}", lhs, rhs) + } + fn __mul__(lhs: PyRef, rhs: &PyAny) -> String { format!("{:?} * {:?}", lhs, rhs) } @@ -339,6 +355,10 @@ impl PyNumberProtocol for LhsAndRhs { format!("{:?} - RA", other) } + fn __rmod__(&self, other: &PyAny) -> String { + format!("{:?} % RA", other) + } + fn __rmul__(&self, other: &PyAny) -> String { format!("{:?} * RA", other) } @@ -388,6 +408,7 @@ fn lhs_fellback_to_rhs() { // If the light hand value is `LhsAndRhs`, LHS is used. py_run!(py, c, "assert c + 1 == 'LR + 1'"); py_run!(py, c, "assert c - 1 == 'LR - 1'"); + py_run!(py, c, "assert c % 1 == 'LR % 1'"); py_run!(py, c, "assert c * 1 == 'LR * 1'"); py_run!(py, c, "assert c << 1 == 'LR << 1'"); py_run!(py, c, "assert c >> 1 == 'LR >> 1'"); @@ -399,6 +420,7 @@ fn lhs_fellback_to_rhs() { // Fellback to RHS because of type mismatching py_run!(py, c, "assert 1 + c == '1 + RA'"); py_run!(py, c, "assert 1 - c == '1 - RA'"); + py_run!(py, c, "assert 1 % c == '1 % RA'"); py_run!(py, c, "assert 1 * c == '1 * RA'"); py_run!(py, c, "assert 1 << c == '1 << RA'"); py_run!(py, c, "assert 1 >> c == '1 >> RA'");