Make __r*__ methods work by slot fallback

This commit is contained in:
kngwyu 2020-03-28 17:19:36 +09:00
parent 77b4b9e67d
commit 970e393bb9
3 changed files with 486 additions and 21 deletions

View File

@ -127,6 +127,29 @@ macro_rules! py_binary_num_func {
}};
}
#[macro_export]
#[doc(hidden)]
macro_rules! py_binary_reverse_num_func {
($trait:ident, $class:ident :: $f:ident, $conv:expr) => {{
unsafe extern "C" fn wrap<T>(
lhs: *mut ffi::PyObject,
rhs: *mut ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $trait<'p>,
{
use $crate::ObjectProtocol;
let py = $crate::Python::assume_gil_acquired();
let _pool = $crate::GILPool::new(py);
// Swap lhs <-> rhs
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(rhs);
let arg = py.from_borrowed_ptr::<$crate::PyAny>(lhs);
call_ref_with_converter!(slf, $conv, py, $f, arg)
}
Some(wrap::<$class>)
}};
}
// NOTE(kngwyu): This macro is used only for inplace operations, so I used call_mut here.
#[macro_export]
#[doc(hidden)]
@ -254,6 +277,31 @@ macro_rules! py_ternary_num_func {
}};
}
#[macro_export]
#[doc(hidden)]
macro_rules! py_ternary_reverse_num_func {
($trait:ident, $class:ident :: $f:ident, $conv:expr) => {{
unsafe extern "C" fn wrap<T>(
arg1: *mut $crate::ffi::PyObject,
arg2: *mut $crate::ffi::PyObject,
arg3: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $trait<'p>,
{
use $crate::ObjectProtocol;
let py = $crate::Python::assume_gil_acquired();
let _pool = $crate::GILPool::new(py);
// Swap lhs <-> rhs
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(arg2);
let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1);
let arg2 = py.from_borrowed_ptr::<$crate::PyAny>(arg3);
call_ref_with_converter!(slf, $conv, py, $f, arg1, arg2)
}
Some(wrap::<$class>)
}};
}
#[macro_export]
#[doc(hidden)]
macro_rules! py_ternary_self_func {

View File

@ -145,7 +145,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __rpow__(&'p self, other: Self::Other) -> Self::Result
fn __rpow__(&'p self, other: Self::Other, module: Self::Modulo) -> Self::Result
where
Self: PyNumberRPowProtocol<'p>,
{
@ -647,22 +647,22 @@ where
{
fn tp_as_number() -> Option<ffi::PyNumberMethods> {
Some(ffi::PyNumberMethods {
nb_add: Self::nb_add(),
nb_subtract: Self::nb_subtract(),
nb_multiply: Self::nb_multiply(),
nb_add: Self::nb_add().or_else(Self::nb_add_fallback),
nb_subtract: Self::nb_subtract().or_else(Self::nb_sub_fallback),
nb_multiply: Self::nb_multiply().or_else(Self::nb_mul_fallback),
nb_remainder: Self::nb_remainder(),
nb_divmod: Self::nb_divmod(),
nb_power: Self::nb_power(),
nb_divmod: Self::nb_divmod().or_else(Self::nb_divmod_fallback),
nb_power: Self::nb_power().or_else(Self::nb_pow_fallback),
nb_negative: Self::nb_negative(),
nb_positive: Self::nb_positive(),
nb_absolute: Self::nb_absolute(),
nb_bool: <Self as PyObjectProtocolImpl>::nb_bool_fn(),
nb_invert: Self::nb_invert(),
nb_lshift: Self::nb_lshift(),
nb_rshift: Self::nb_rshift(),
nb_and: Self::nb_and(),
nb_xor: Self::nb_xor(),
nb_or: Self::nb_or(),
nb_lshift: Self::nb_lshift().or_else(Self::nb_lshift_fallback),
nb_rshift: Self::nb_rshift().or_else(Self::nb_rshift_fallback),
nb_and: Self::nb_and().or_else(Self::nb_and_fallback),
nb_xor: Self::nb_xor().or_else(Self::nb_xor_fallback),
nb_or: Self::nb_or().or_else(Self::nb_or_fallback),
nb_int: Self::nb_int(),
nb_reserved: ::std::ptr::null_mut(),
nb_float: Self::nb_float(),
@ -676,12 +676,12 @@ where
nb_inplace_and: Self::nb_inplace_and(),
nb_inplace_xor: Self::nb_inplace_xor(),
nb_inplace_or: Self::nb_inplace_or(),
nb_floor_divide: Self::nb_floor_divide(),
nb_true_divide: Self::nb_true_divide(),
nb_floor_divide: Self::nb_floor_divide().or_else(Self::nb_floordiv_fallback),
nb_true_divide: Self::nb_true_divide().or_else(Self::nb_truediv_fallback),
nb_inplace_floor_divide: Self::nb_inplace_floor_divide(),
nb_inplace_true_divide: Self::nb_inplace_true_divide(),
nb_index: Self::nb_index(),
nb_matrix_multiply: Self::nb_matrix_multiply(),
nb_matrix_multiply: Self::nb_matrix_multiply().or_else(Self::nb_matmul_fallback),
nb_inplace_matrix_multiply: Self::nb_inplace_matrix_multiply(),
})
}
@ -1407,14 +1407,72 @@ where
}
}
#[doc(hidden)]
pub trait PyNumberRSubProtocolImpl {
fn __rsub__() -> Option<PyMethodDef> {
// Fallback trait for nb_add
trait PyNumberAddFallback {
fn nb_add_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberAddFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_add_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<'p, T> PyNumberRSubProtocolImpl for T where T: PyNumberProtocol<'p> {}
impl<T> PyNumberAddFallback for T
where
T: for<'p> PyNumberRAddProtocol<'p>,
{
fn nb_add_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRAddProtocol,
T::__radd__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRSubProtocolImpl {
fn __rsub__() -> Option<PyMethodDef>;
}
impl<'p, T> PyNumberRSubProtocolImpl for T
where
T: PyNumberProtocol<'p>,
{
default fn __rsub__() -> Option<PyMethodDef> {
None
}
}
trait PyNumberSubFallback {
fn nb_sub_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberSubFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_sub_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberSubFallback for T
where
T: for<'p> PyNumberRSubProtocol<'p>,
{
fn nb_sub_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRSubProtocol,
T::__rsub__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRMulProtocolImpl {
@ -1430,6 +1488,32 @@ where
}
}
trait PyNumberMulFallback {
fn nb_mul_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberMulFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_mul_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberMulFallback for T
where
T: for<'p> PyNumberRMulProtocol<'p>,
{
fn nb_mul_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRMulProtocol,
T::__rmul__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRMatmulProtocolImpl {
fn __rmatmul__() -> Option<PyMethodDef>;
@ -1444,6 +1528,32 @@ where
}
}
trait PyNumberMatmulFallback {
fn nb_matmul_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberMatmulFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_matmul_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberMatmulFallback for T
where
T: for<'p> PyNumberRMatmulProtocol<'p>,
{
fn nb_matmul_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRMatmulProtocol,
T::__rmatmul__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRTruedivProtocolImpl {
fn __rtruediv__() -> Option<PyMethodDef>;
@ -1458,6 +1568,32 @@ where
}
}
trait PyNumberTruedivFallback {
fn nb_truediv_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberTruedivFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_truediv_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberTruedivFallback for T
where
T: for<'p> PyNumberRTruedivProtocol<'p>,
{
fn nb_truediv_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRTruedivProtocol,
T::__rtruediv__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRFloordivProtocolImpl {
fn __rfloordiv__() -> Option<PyMethodDef>;
@ -1472,6 +1608,32 @@ where
}
}
trait PyNumberFloordivFallback {
fn nb_floordiv_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberFloordivFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_floordiv_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberFloordivFallback for T
where
T: for<'p> PyNumberRFloordivProtocol<'p>,
{
fn nb_floordiv_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRFloordivProtocol,
T::__rfloordiv__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRModProtocolImpl {
fn __rmod__() -> Option<PyMethodDef>;
@ -1486,6 +1648,32 @@ where
}
}
trait PyNumberModFallback {
fn nb_mod_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberModFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_mod_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberModFallback for T
where
T: for<'p> PyNumberRModProtocol<'p>,
{
fn nb_mod_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRModProtocol,
T::__rmod__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRDivmodProtocolImpl {
fn __rdivmod__() -> Option<PyMethodDef>;
@ -1500,6 +1688,32 @@ where
}
}
trait PyNumberDivmodFallback {
fn nb_divmod_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberDivmodFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_divmod_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberDivmodFallback for T
where
T: for<'p> PyNumberRDivmodProtocol<'p>,
{
fn nb_divmod_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRDivmodProtocol,
T::__rdivmod__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRPowProtocolImpl {
fn __rpow__() -> Option<PyMethodDef>;
@ -1514,6 +1728,32 @@ where
}
}
trait PyNumberPowFallback {
fn nb_pow_fallback() -> Option<ffi::ternaryfunc>;
}
impl<'p, T> PyNumberPowFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_pow_fallback() -> Option<ffi::ternaryfunc> {
None
}
}
impl<T> PyNumberPowFallback for T
where
T: for<'p> PyNumberRPowProtocol<'p>,
{
fn nb_pow_fallback() -> Option<ffi::ternaryfunc> {
py_ternary_reverse_num_func!(
PyNumberRPowProtocol,
T::__rpow__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRLShiftProtocolImpl {
fn __rlshift__() -> Option<PyMethodDef>;
@ -1528,6 +1768,32 @@ where
}
}
trait PyNumberLShiftFallback {
fn nb_lshift_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberLShiftFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_lshift_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberLShiftFallback for T
where
T: for<'p> PyNumberRLShiftProtocol<'p>,
{
fn nb_lshift_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRLShiftProtocol,
T::__rlshift__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRRShiftProtocolImpl {
fn __rrshift__() -> Option<PyMethodDef>;
@ -1542,6 +1808,32 @@ where
}
}
trait PyNumberRRshiftFallback {
fn nb_rshift_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberRRshiftFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_rshift_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberRRshiftFallback for T
where
T: for<'p> PyNumberRRShiftProtocol<'p>,
{
fn nb_rshift_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRRShiftProtocol,
T::__rrshift__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRAndProtocolImpl {
fn __rand__() -> Option<PyMethodDef>;
@ -1556,6 +1848,32 @@ where
}
}
trait PyNumberAndFallback {
fn nb_and_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberAndFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_and_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberAndFallback for T
where
T: for<'p> PyNumberRAndProtocol<'p>,
{
fn nb_and_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRAndProtocol,
T::__rand__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberRXorProtocolImpl {
fn __rxor__() -> Option<PyMethodDef>;
@ -1570,6 +1888,32 @@ where
}
}
trait PyNumberXorFallback {
fn nb_xor_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberXorFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_xor_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberXorFallback for T
where
T: for<'p> PyNumberRXorProtocol<'p>,
{
fn nb_xor_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberRXorProtocol,
T::__rxor__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
#[doc(hidden)]
pub trait PyNumberROrProtocolImpl {
fn __ror__() -> Option<PyMethodDef>;
@ -1584,6 +1928,32 @@ where
}
}
trait PyNumberOrFallback {
fn nb_or_fallback() -> Option<ffi::binaryfunc>;
}
impl<'p, T> PyNumberOrFallback for T
where
T: PyNumberProtocol<'p>,
{
default fn nb_or_fallback() -> Option<ffi::binaryfunc> {
None
}
}
impl<T> PyNumberOrFallback for T
where
T: for<'p> PyNumberROrProtocol<'p>,
{
fn nb_or_fallback() -> Option<ffi::binaryfunc> {
py_binary_reverse_num_func!(
PyNumberROrProtocol,
T::__ror__,
PyObjectCallbackConverter::<T::Success>(std::marker::PhantomData)
)
}
}
trait PyNumberNegProtocolImpl {
fn nb_negative() -> Option<ffi::unaryfunc>;
}

53
tests/test_arithmetics.rs Normal file → Executable file
View File

@ -168,6 +168,7 @@ fn binary_arithmetic() {
let c = Py::new(py, BinaryArithmetic {}).unwrap();
py_run!(py, c, "assert c + c == 'BA + BA'");
py_run!(py, c, "assert c.__add__(c) == 'BA + BA'");
py_run!(py, c, "assert c + 1 == 'BA + 1'");
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert c - 1 == 'BA - 1'");
@ -195,6 +196,38 @@ impl PyNumberProtocol for RhsArithmetic {
fn __radd__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + RA", other))
}
fn __rsub__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - RA", other))
}
fn __rmul__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} * RA", other))
}
fn __rlshift__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} << RA", other))
}
fn __rrshift__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} >> RA", other))
}
fn __rand__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} & RA", other))
}
fn __rxor__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} ^ RA", other))
}
fn __ror__(&self, other: &PyAny) -> PyResult<String> {
Ok(format!("{:?} | RA", other))
}
fn __rpow__(&self, other: &PyAny, _module: &PyAny) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
}
}
#[test]
@ -204,9 +237,23 @@ fn rhs_arithmetic() {
let c = Py::new(py, RhsArithmetic {}).unwrap();
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
// TODO: commented out for now until reflected arithemtics gets fixed.
// see discussion here: https://github.com/PyO3/pyo3/pull/550
// py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'");
py_run!(py, c, "assert 1 * c == '1 * RA'");
py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'");
py_run!(py, c, "assert 1 << c == '1 << RA'");
py_run!(py, c, "assert c.__rrshift__(1) == '1 >> RA'");
py_run!(py, c, "assert 1 >> c == '1 >> RA'");
py_run!(py, c, "assert c.__rand__(1) == '1 & RA'");
py_run!(py, c, "assert 1 & c == '1 & RA'");
py_run!(py, c, "assert c.__rxor__(1) == '1 ^ RA'");
py_run!(py, c, "assert 1 ^ c == '1 ^ RA'");
py_run!(py, c, "assert c.__ror__(1) == '1 | RA'");
py_run!(py, c, "assert 1 | c == '1 | RA'");
py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'");
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
}
#[pyclass]