diff --git a/CHANGELOG.md b/CHANGELOG.md index 661fcdb2..442a2365 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Improve lifetime elision in `#[pyproto]`. [#1093](https://github.com/PyO3/pyo3/pull/1093) - Fix python configuration detection when cross-compiling. [#1095](https://github.com/PyO3/pyo3/pull/1095) - Link against libpython on android with `extension-module` set. [#1095](https://github.com/PyO3/pyo3/pull/1095) +- Fix support for both `__add__` and `__radd__` in the `+` operator when both are defined in `PyNumberProtocol` + (and similar for all other reversible operators). [#1107](https://github.com/PyO3/pyo3/pull/1107) ## [0.11.1] - 2020-06-30 ### Added diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index 6004a368..0bdfedf9 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -1,5 +1,6 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::proto_method::MethodProto; +use std::collections::HashSet; /// Predicates for `#[pyproto]`. pub struct Proto { @@ -14,7 +15,7 @@ pub struct Proto { /// All methods registered as normal methods like `#[pymethods]`. pub py_methods: &'static [PyMethod], /// All methods registered to the slot table. - pub slot_setters: &'static [SlotSetter], + slot_setters: &'static [SlotSetter], } impl Proto { @@ -30,6 +31,28 @@ impl Proto { { self.py_methods.iter().find(|m| query == m.name) } + // Since the order matters, we expose only the iterator instead of the slice. + pub(crate) fn setters( + &self, + mut implemented_protocols: HashSet, + ) -> impl Iterator { + self.slot_setters.iter().filter_map(move |setter| { + // If any required method is not implemented, we skip this setter. + if setter + .proto_names + .iter() + .any(|name| !implemented_protocols.contains(*name)) + { + return None; + } + // To use 'paired' setter in priority, we remove used protocols. + // For example, if set_add_radd is already used, we shouldn't use set_add and set_radd. + for name in setter.proto_names { + implemented_protocols.remove(*name); + } + Some(setter.set_function) + }) + } } /// Represents a method registered as a normal method like `#[pymethods]`. @@ -59,24 +82,19 @@ impl PyMethod { } /// Represents a setter used to register a method to the method table. -pub struct SlotSetter { +struct SlotSetter { /// Protocols necessary for invoking this setter. /// E.g., we need `__setattr__` and `__delattr__` for invoking `set_setdelitem`. pub proto_names: &'static [&'static str], /// The name of the setter called to the method table. pub set_function: &'static str, - /// Represents a set of setters disabled by this setter. - /// E.g., `set_setdelitem` have to disable `set_setitem` and `set_delitem`. - pub skipped_setters: &'static [&'static str], } impl SlotSetter { - const EMPTY_SETTERS: &'static [&'static str] = &[]; const fn new(names: &'static [&'static str], set_function: &'static str) -> Self { SlotSetter { proto_names: names, set_function, - skipped_setters: Self::EMPTY_SETTERS, } } } @@ -144,11 +162,7 @@ pub const OBJECT: Proto = Proto { SlotSetter::new(&["__hash__"], "set_hash"), SlotSetter::new(&["__getattr__"], "set_getattr"), SlotSetter::new(&["__richcmp__"], "set_richcompare"), - SlotSetter { - proto_names: &["__setattr__", "__delattr__"], - set_function: "set_setdelattr", - skipped_setters: &["set_setattr", "set_delattr"], - }, + SlotSetter::new(&["__setattr__", "__delattr__"], "set_setdelattr"), SlotSetter::new(&["__setattr__"], "set_setattr"), SlotSetter::new(&["__delattr__"], "set_delattr"), SlotSetter::new(&["__bool__"], "set_bool"), @@ -379,11 +393,7 @@ pub const MAPPING: Proto = Proto { slot_setters: &[ SlotSetter::new(&["__len__"], "set_length"), SlotSetter::new(&["__getitem__"], "set_getitem"), - SlotSetter { - proto_names: &["__setitem__", "__delitem__"], - set_function: "set_setdelitem", - skipped_setters: &["set_setitem", "set_delitem"], - }, + SlotSetter::new(&["__setitem__", "__delitem__"], "set_setdelitem"), SlotSetter::new(&["__setitem__"], "set_setitem"), SlotSetter::new(&["__delitem__"], "set_delitem"), ], @@ -446,11 +456,7 @@ pub const SEQ: Proto = Proto { SlotSetter::new(&["__concat__"], "set_concat"), SlotSetter::new(&["__repeat__"], "set_repeat"), SlotSetter::new(&["__getitem__"], "set_getitem"), - SlotSetter { - proto_names: &["__setitem__", "__delitem__"], - set_function: "set_setdelitem", - skipped_setters: &["set_setitem", "set_delitem"], - }, + SlotSetter::new(&["__setitem__", "__delitem__"], "set_setdelitem"), SlotSetter::new(&["__setitem__"], "set_setitem"), SlotSetter::new(&["__delitem__"], "set_delitem"), SlotSetter::new(&["__contains__"], "set_contains"), @@ -766,71 +772,40 @@ pub const NUM: Proto = Proto { ), ], slot_setters: &[ - SlotSetter { - proto_names: &["__add__"], - set_function: "set_add", - skipped_setters: &["set_radd"], - }, + SlotSetter::new(&["__add__", "__radd__"], "set_add_radd"), + SlotSetter::new(&["__add__"], "set_add"), SlotSetter::new(&["__radd__"], "set_radd"), - SlotSetter { - proto_names: &["__sub__"], - set_function: "set_sub", - skipped_setters: &["set_rsub"], - }, + SlotSetter::new(&["__sub__", "__rsub__"], "set_sub_rsub"), + SlotSetter::new(&["__sub__"], "set_sub"), SlotSetter::new(&["__rsub__"], "set_rsub"), - SlotSetter { - proto_names: &["__mul__"], - set_function: "set_mul", - skipped_setters: &["set_rmul"], - }, + SlotSetter::new(&["__mul__", "__rmul__"], "set_mul_rmul"), + SlotSetter::new(&["__mul__"], "set_mul"), SlotSetter::new(&["__rmul__"], "set_rmul"), SlotSetter::new(&["__mod__"], "set_mod"), - SlotSetter { - proto_names: &["__divmod__"], - set_function: "set_divmod", - skipped_setters: &["set_rdivmod"], - }, + SlotSetter::new(&["__divmod__", "__rdivmod__"], "set_divmod_rdivmod"), + SlotSetter::new(&["__divmod__"], "set_divmod"), SlotSetter::new(&["__rdivmod__"], "set_rdivmod"), - SlotSetter { - proto_names: &["__pow__"], - set_function: "set_pow", - skipped_setters: &["set_rpow"], - }, + SlotSetter::new(&["__pow__", "__rpow__"], "set_pow_rpow"), + SlotSetter::new(&["__pow__"], "set_pow"), SlotSetter::new(&["__rpow__"], "set_rpow"), SlotSetter::new(&["__neg__"], "set_neg"), SlotSetter::new(&["__pos__"], "set_pos"), SlotSetter::new(&["__abs__"], "set_abs"), SlotSetter::new(&["__invert__"], "set_invert"), - SlotSetter::new(&["__rdivmod__"], "set_rdivmod"), - SlotSetter { - proto_names: &["__lshift__"], - set_function: "set_lshift", - skipped_setters: &["set_rlshift"], - }, + SlotSetter::new(&["__lshift__", "__rlshift__"], "set_lshift_rlshift"), + SlotSetter::new(&["__lshift__"], "set_lshift"), SlotSetter::new(&["__rlshift__"], "set_rlshift"), - SlotSetter { - proto_names: &["__rshift__"], - set_function: "set_rshift", - skipped_setters: &["set_rrshift"], - }, + SlotSetter::new(&["__rshift__", "__rrshift__"], "set_rshift_rrshift"), + SlotSetter::new(&["__rshift__"], "set_rshift"), SlotSetter::new(&["__rrshift__"], "set_rrshift"), - SlotSetter { - proto_names: &["__and__"], - set_function: "set_and", - skipped_setters: &["set_rand"], - }, + SlotSetter::new(&["__and__", "__rand__"], "set_and_rand"), + SlotSetter::new(&["__and__"], "set_and"), SlotSetter::new(&["__rand__"], "set_rand"), - SlotSetter { - proto_names: &["__xor__"], - set_function: "set_xor", - skipped_setters: &["set_rxor"], - }, + SlotSetter::new(&["__xor__", "__rxor__"], "set_xor_rxor"), + SlotSetter::new(&["__xor__"], "set_xor"), SlotSetter::new(&["__rxor__"], "set_rxor"), - SlotSetter { - proto_names: &["__or__"], - set_function: "set_or", - skipped_setters: &["set_ror"], - }, + SlotSetter::new(&["__or__", "__ror__"], "set_or_ror"), + SlotSetter::new(&["__or__"], "set_or"), SlotSetter::new(&["__ror__"], "set_ror"), SlotSetter::new(&["__int__"], "set_int"), SlotSetter::new(&["__float__"], "set_float"), @@ -844,26 +819,17 @@ pub const NUM: Proto = Proto { SlotSetter::new(&["__iand__"], "set_iand"), SlotSetter::new(&["__ixor__"], "set_ixor"), SlotSetter::new(&["__ior__"], "set_ior"), - SlotSetter { - proto_names: &["__floordiv__"], - set_function: "set_floordiv", - skipped_setters: &["set_rfloordiv"], - }, + SlotSetter::new(&["__floordiv__", "__rfloordiv__"], "set_floordiv_rfloordiv"), + SlotSetter::new(&["__floordiv__"], "set_floordiv"), SlotSetter::new(&["__rfloordiv__"], "set_rfloordiv"), - SlotSetter { - proto_names: &["__truediv__"], - set_function: "set_truediv", - skipped_setters: &["set_rtruediv"], - }, + SlotSetter::new(&["__truediv__", "__rtruediv__"], "set_truediv_rtruediv"), + SlotSetter::new(&["__truediv__"], "set_truediv"), SlotSetter::new(&["__rtruediv__"], "set_rtruediv"), SlotSetter::new(&["__ifloordiv__"], "set_ifloordiv"), SlotSetter::new(&["__itruediv__"], "set_itruediv"), SlotSetter::new(&["__index__"], "set_index"), - SlotSetter { - proto_names: &["__matmul__"], - set_function: "set_matmul", - skipped_setters: &["set_rmatmul"], - }, + SlotSetter::new(&["__matmul__", "__rmatmul__"], "set_matmul_rmatmul"), + SlotSetter::new(&["__matmul__"], "set_matmul"), SlotSetter::new(&["__rmatmul__"], "set_rmatmul"), SlotSetter::new(&["__imatmul__"], "set_imatmul"), ], diff --git a/pyo3-derive-backend/src/pyproto.rs b/pyo3-derive-backend/src/pyproto.rs index 44fc2065..b1a19281 100644 --- a/pyo3-derive-backend/src/pyproto.rs +++ b/pyo3-derive-backend/src/pyproto.rs @@ -134,25 +134,11 @@ fn slot_initialization( ty: &syn::Type, proto: &defs::Proto, ) -> syn::Result { - // Some setters cannot coexist. - // E.g., if we have `__add__`, we need to skip `set_radd`. - let mut skipped_setters = Vec::new(); // Collect initializers let mut initializers: Vec = vec![]; - 'outer_loop: for m in proto.slot_setters { - if skipped_setters.contains(&m.set_function) { - continue; - } - for name in m.proto_names { - // If this `#[pyproto]` block doesn't provide all required methods, - // let's skip implementing this method. - if !method_names.contains(*name) { - continue 'outer_loop; - } - } - skipped_setters.extend_from_slice(m.skipped_setters); + for setter in proto.setters(method_names) { // Add slot methods to PyProtoRegistry - let set = syn::Ident::new(m.set_function, Span::call_site()); + let set = syn::Ident::new(setter, Span::call_site()); initializers.push(quote! { table.#set::<#ty>(); }); } if initializers.is_empty() { diff --git a/src/class/macros.rs b/src/class/macros.rs index 9349ed06..b72efdbd 100644 --- a/src/class/macros.rs +++ b/src/class/macros.rs @@ -1,7 +1,5 @@ // Copyright (c) 2017-present PyO3 Project and Contributors -#[macro_export] -#[doc(hidden)] macro_rules! py_unary_func { ($trait: ident, $class:ident :: $f:ident, $call:ident, $ret_type: ty) => {{ unsafe extern "C" fn wrap(slf: *mut $crate::ffi::PyObject) -> $ret_type @@ -24,8 +22,6 @@ macro_rules! py_unary_func { }; } -#[macro_export] -#[doc(hidden)] macro_rules! py_unarys_func { ($trait:ident, $class:ident :: $f:ident) => {{ unsafe extern "C" fn wrap(slf: *mut $crate::ffi::PyObject) -> *mut $crate::ffi::PyObject @@ -45,16 +41,12 @@ macro_rules! py_unarys_func { }}; } -#[macro_export] -#[doc(hidden)] macro_rules! py_len_func { ($trait:ident, $class:ident :: $f:ident) => { py_unary_func!($trait, $class::$f, $crate::ffi::Py_ssize_t) }; } -#[macro_export] -#[doc(hidden)] macro_rules! py_binary_func { // Use call_ref! by default ($trait:ident, $class:ident :: $f:ident, $return:ty, $call:ident) => {{ @@ -78,8 +70,6 @@ macro_rules! py_binary_func { }; } -#[macro_export] -#[doc(hidden)] macro_rules! py_binary_num_func { ($trait:ident, $class:ident :: $f:ident) => {{ unsafe extern "C" fn wrap( @@ -99,8 +89,6 @@ macro_rules! py_binary_num_func { }}; } -#[macro_export] -#[doc(hidden)] macro_rules! py_binary_reversed_num_func { ($trait:ident, $class:ident :: $f:ident) => {{ unsafe extern "C" fn wrap( @@ -112,10 +100,37 @@ macro_rules! py_binary_reversed_num_func { { $crate::callback_body!(py, { // Swap lhs <-> rhs - let slf = py.from_borrowed_ptr::<$crate::PyCell>(rhs); - let arg = py.from_borrowed_ptr::<$crate::PyAny>(lhs); + let slf: &$crate::PyCell = extract_or_return_not_implemented!(py, rhs); + let arg = extract_or_return_not_implemented!(py, lhs); + $class::$f(&*slf.try_borrow()?, arg).convert(py) + }) + } + Some(wrap::<$class>) + }}; +} - $class::$f(&*slf.try_borrow()?, arg.extract()?).convert(py) +macro_rules! py_binary_fallback_num_func { + ($class:ident, $lop_trait: ident :: $lop: ident, $rop_trait: ident :: $rop: ident) => {{ + unsafe extern "C" fn wrap( + lhs: *mut ffi::PyObject, + rhs: *mut ffi::PyObject, + ) -> *mut $crate::ffi::PyObject + where + T: for<'p> $lop_trait<'p> + for<'p> $rop_trait<'p>, + { + $crate::callback_body!(py, { + let lhs = py.from_borrowed_ptr::<$crate::PyAny>(lhs); + let rhs = py.from_borrowed_ptr::<$crate::PyAny>(rhs); + // First, try the left hand method (e.g., __add__) + match (lhs.extract(), rhs.extract()) { + (Ok(l), Ok(r)) => $class::$lop(l, r).convert(py), + _ => { + // Next, try the right hand method (e.g., __radd__) + let slf: &$crate::PyCell = extract_or_return_not_implemented!(rhs); + let arg = extract_or_return_not_implemented!(lhs); + $class::$rop(&*slf.try_borrow()?, arg).convert(py) + } + } }) } Some(wrap::<$class>) @@ -123,8 +138,6 @@ macro_rules! py_binary_reversed_num_func { } // NOTE(kngwyu): This macro is used only for inplace operations, so I used call_mut here. -#[macro_export] -#[doc(hidden)] macro_rules! py_binary_self_func { ($trait:ident, $class:ident :: $f:ident) => {{ unsafe extern "C" fn wrap( @@ -146,8 +159,6 @@ macro_rules! py_binary_self_func { }}; } -#[macro_export] -#[doc(hidden)] macro_rules! py_ssizearg_func { // Use call_ref! by default ($trait:ident, $class:ident :: $f:ident) => { @@ -170,8 +181,6 @@ macro_rules! py_ssizearg_func { }}; } -#[macro_export] -#[doc(hidden)] macro_rules! py_ternarys_func { ($trait:ident, $class:ident :: $f:ident, $return_type:ty) => {{ unsafe extern "C" fn wrap( @@ -205,83 +214,6 @@ macro_rules! py_ternarys_func { }; } -#[macro_export] -#[doc(hidden)] -macro_rules! py_ternary_num_func { - ($trait:ident, $class:ident :: $f:ident) => {{ - unsafe extern "C" fn wrap( - arg1: *mut $crate::ffi::PyObject, - arg2: *mut $crate::ffi::PyObject, - arg3: *mut $crate::ffi::PyObject, - ) -> *mut $crate::ffi::PyObject - where - T: for<'p> $trait<'p>, - { - $crate::callback_body!(py, { - let arg1 = py - .from_borrowed_ptr::<$crate::types::PyAny>(arg1) - .extract()?; - let arg2 = extract_or_return_not_implemented!(py, arg2); - let arg3 = extract_or_return_not_implemented!(py, arg3); - $class::$f(arg1, arg2, arg3).convert(py) - }) - } - - Some(wrap::) - }}; -} - -#[macro_export] -#[doc(hidden)] -macro_rules! py_ternary_reversed_num_func { - ($trait:ident, $class:ident :: $f:ident) => {{ - unsafe extern "C" fn wrap( - arg1: *mut $crate::ffi::PyObject, - arg2: *mut $crate::ffi::PyObject, - arg3: *mut $crate::ffi::PyObject, - ) -> *mut $crate::ffi::PyObject - where - T: for<'p> $trait<'p>, - { - $crate::callback_body!(py, { - // Swap lhs <-> rhs - let slf = py.from_borrowed_ptr::<$crate::PyCell>(arg2); - let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1); - let arg2 = py.from_borrowed_ptr::<$crate::PyAny>(arg3); - - $class::$f(&*slf.try_borrow()?, arg1.extract()?, arg2.extract()?).convert(py) - }) - } - Some(wrap::<$class>) - }}; -} - -// NOTE(kngwyu): Somehow __ipow__ causes SIGSEGV in Python < 3.8 when we extract arg2, -// so we ignore it. It's the same as what CPython does. -#[macro_export] -#[doc(hidden)] -macro_rules! py_dummy_ternary_self_func { - ($trait:ident, $class:ident :: $f:ident) => {{ - unsafe extern "C" fn wrap( - slf: *mut $crate::ffi::PyObject, - arg1: *mut $crate::ffi::PyObject, - _arg2: *mut $crate::ffi::PyObject, - ) -> *mut $crate::ffi::PyObject - where - T: for<'p> $trait<'p>, - { - $crate::callback_body!(py, { - let slf_cell = py.from_borrowed_ptr::<$crate::PyCell>(slf); - let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1); - call_operator_mut!(py, slf_cell, $f, arg1).convert(py)?; - ffi::Py_INCREF(slf); - Ok(slf) - }) - } - Some(wrap::<$class>) - }}; -} - macro_rules! py_func_set { ($trait_name:ident, $generic:ident, $fn_set:ident) => {{ unsafe extern "C" fn wrap<$generic>( @@ -370,6 +302,16 @@ macro_rules! py_func_set_del { } macro_rules! extract_or_return_not_implemented { + ($arg: ident) => { + match $arg.extract() { + Ok(value) => value, + Err(_) => { + let res = $crate::ffi::Py_NotImplemented(); + ffi::Py_INCREF(res); + return Ok(res); + } + } + }; ($py: ident, $arg: ident) => { match $py .from_borrowed_ptr::<$crate::types::PyAny>($arg) diff --git a/src/class/number.rs b/src/class/number.rs index 67ec984d..149a8b98 100644 --- a/src/class/number.rs +++ b/src/class/number.rs @@ -585,6 +585,16 @@ impl ffi::PyNumberMethods { nm.nb_bool = Some(nb_bool); Box::into_raw(Box::new(nm)) } + pub fn set_add_radd(&mut self) + where + T: for<'p> PyNumberAddProtocol<'p> + for<'p> PyNumberRAddProtocol<'p>, + { + self.nb_add = py_binary_fallback_num_func!( + T, + PyNumberAddProtocol::__add__, + PyNumberRAddProtocol::__radd__ + ); + } pub fn set_add(&mut self) where T: for<'p> PyNumberAddProtocol<'p>, @@ -597,6 +607,16 @@ impl ffi::PyNumberMethods { { self.nb_add = py_binary_reversed_num_func!(PyNumberRAddProtocol, T::__radd__); } + pub fn set_sub_rsub(&mut self) + where + T: for<'p> PyNumberSubProtocol<'p> + for<'p> PyNumberRSubProtocol<'p>, + { + self.nb_subtract = py_binary_fallback_num_func!( + T, + PyNumberSubProtocol::__sub__, + PyNumberRSubProtocol::__rsub__ + ); + } pub fn set_sub(&mut self) where T: for<'p> PyNumberSubProtocol<'p>, @@ -609,6 +629,16 @@ impl ffi::PyNumberMethods { { self.nb_subtract = py_binary_reversed_num_func!(PyNumberRSubProtocol, T::__rsub__); } + pub fn set_mul_rmul(&mut self) + where + T: for<'p> PyNumberMulProtocol<'p> + for<'p> PyNumberRMulProtocol<'p>, + { + self.nb_multiply = py_binary_fallback_num_func!( + T, + PyNumberMulProtocol::__mul__, + PyNumberRMulProtocol::__rmul__ + ); + } pub fn set_mul(&mut self) where T: for<'p> PyNumberMulProtocol<'p>, @@ -627,6 +657,16 @@ impl ffi::PyNumberMethods { { self.nb_remainder = py_binary_num_func!(PyNumberModProtocol, T::__mod__); } + pub fn set_divmod_rdivmod(&mut self) + where + T: for<'p> PyNumberDivmodProtocol<'p> + for<'p> PyNumberRDivmodProtocol<'p>, + { + self.nb_divmod = py_binary_fallback_num_func!( + T, + PyNumberDivmodProtocol::__divmod__, + PyNumberRDivmodProtocol::__rdivmod__ + ); + } pub fn set_divmod(&mut self) where T: for<'p> PyNumberDivmodProtocol<'p>, @@ -639,17 +679,78 @@ impl ffi::PyNumberMethods { { self.nb_divmod = py_binary_reversed_num_func!(PyNumberRDivmodProtocol, T::__rdivmod__); } + pub fn set_pow_rpow(&mut self) + where + T: for<'p> PyNumberPowProtocol<'p> + for<'p> PyNumberRPowProtocol<'p>, + { + unsafe extern "C" fn wrap_pow_and_rpow( + lhs: *mut crate::ffi::PyObject, + rhs: *mut crate::ffi::PyObject, + modulo: *mut crate::ffi::PyObject, + ) -> *mut crate::ffi::PyObject + where + T: for<'p> PyNumberPowProtocol<'p> + for<'p> PyNumberRPowProtocol<'p>, + { + crate::callback_body!(py, { + let lhs = py.from_borrowed_ptr::(lhs); + let rhs = py.from_borrowed_ptr::(rhs); + let modulo = py.from_borrowed_ptr::(modulo); + // First, try __pow__ + match (lhs.extract(), rhs.extract(), modulo.extract()) { + (Ok(l), Ok(r), Ok(m)) => T::__pow__(l, r, m).convert(py), + _ => { + // Then try __rpow__ + let slf: &crate::PyCell = extract_or_return_not_implemented!(rhs); + let arg = extract_or_return_not_implemented!(lhs); + let modulo = extract_or_return_not_implemented!(modulo); + slf.try_borrow()?.__rpow__(arg, modulo).convert(py) + } + } + }) + } + self.nb_power = Some(wrap_pow_and_rpow::); + } pub fn set_pow(&mut self) where T: for<'p> PyNumberPowProtocol<'p>, { - self.nb_power = py_ternary_num_func!(PyNumberPowProtocol, T::__pow__); + unsafe extern "C" fn wrap_pow( + lhs: *mut crate::ffi::PyObject, + rhs: *mut crate::ffi::PyObject, + modulo: *mut crate::ffi::PyObject, + ) -> *mut crate::ffi::PyObject + where + T: for<'p> PyNumberPowProtocol<'p>, + { + crate::callback_body!(py, { + let lhs = extract_or_return_not_implemented!(py, lhs); + let rhs = extract_or_return_not_implemented!(py, rhs); + let modulo = extract_or_return_not_implemented!(py, modulo); + T::__pow__(lhs, rhs, modulo).convert(py) + }) + } + self.nb_power = Some(wrap_pow::); } pub fn set_rpow(&mut self) where T: for<'p> PyNumberRPowProtocol<'p>, { - self.nb_power = py_ternary_reversed_num_func!(PyNumberRPowProtocol, T::__rpow__); + unsafe extern "C" fn wrap_rpow( + arg: *mut crate::ffi::PyObject, + slf: *mut crate::ffi::PyObject, + modulo: *mut crate::ffi::PyObject, + ) -> *mut crate::ffi::PyObject + where + T: for<'p> PyNumberRPowProtocol<'p>, + { + crate::callback_body!(py, { + let slf: &crate::PyCell = extract_or_return_not_implemented!(py, slf); + let arg = extract_or_return_not_implemented!(py, arg); + let modulo = extract_or_return_not_implemented!(py, modulo); + slf.try_borrow()?.__rpow__(arg, modulo).convert(py) + }) + } + self.nb_power = Some(wrap_rpow::); } pub fn set_neg(&mut self) where @@ -675,6 +776,16 @@ impl ffi::PyNumberMethods { { self.nb_invert = py_unary_func!(PyNumberInvertProtocol, T::__invert__); } + pub fn set_lshift_rlshift(&mut self) + where + T: for<'p> PyNumberLShiftProtocol<'p> + for<'p> PyNumberRLShiftProtocol<'p>, + { + self.nb_lshift = py_binary_fallback_num_func!( + T, + PyNumberLShiftProtocol::__lshift__, + PyNumberRLShiftProtocol::__rlshift__ + ); + } pub fn set_lshift(&mut self) where T: for<'p> PyNumberLShiftProtocol<'p>, @@ -687,6 +798,16 @@ impl ffi::PyNumberMethods { { self.nb_lshift = py_binary_reversed_num_func!(PyNumberRLShiftProtocol, T::__rlshift__); } + pub fn set_rshift_rrshift(&mut self) + where + T: for<'p> PyNumberRShiftProtocol<'p> + for<'p> PyNumberRRShiftProtocol<'p>, + { + self.nb_rshift = py_binary_fallback_num_func!( + T, + PyNumberRShiftProtocol::__rshift__, + PyNumberRRShiftProtocol::__rrshift__ + ); + } pub fn set_rshift(&mut self) where T: for<'p> PyNumberRShiftProtocol<'p>, @@ -699,6 +820,16 @@ impl ffi::PyNumberMethods { { self.nb_rshift = py_binary_reversed_num_func!(PyNumberRRShiftProtocol, T::__rrshift__); } + pub fn set_and_rand(&mut self) + where + T: for<'p> PyNumberAndProtocol<'p> + for<'p> PyNumberRAndProtocol<'p>, + { + self.nb_and = py_binary_fallback_num_func!( + T, + PyNumberAndProtocol::__and__, + PyNumberRAndProtocol::__rand__ + ); + } pub fn set_and(&mut self) where T: for<'p> PyNumberAndProtocol<'p>, @@ -711,6 +842,16 @@ impl ffi::PyNumberMethods { { self.nb_and = py_binary_reversed_num_func!(PyNumberRAndProtocol, T::__rand__); } + pub fn set_xor_rxor(&mut self) + where + T: for<'p> PyNumberXorProtocol<'p> + for<'p> PyNumberRXorProtocol<'p>, + { + self.nb_xor = py_binary_fallback_num_func!( + T, + PyNumberXorProtocol::__xor__, + PyNumberRXorProtocol::__rxor__ + ); + } pub fn set_xor(&mut self) where T: for<'p> PyNumberXorProtocol<'p>, @@ -723,6 +864,16 @@ impl ffi::PyNumberMethods { { self.nb_xor = py_binary_reversed_num_func!(PyNumberRXorProtocol, T::__rxor__); } + pub fn set_or_ror(&mut self) + where + T: for<'p> PyNumberOrProtocol<'p> + for<'p> PyNumberROrProtocol<'p>, + { + self.nb_or = py_binary_fallback_num_func!( + T, + PyNumberOrProtocol::__or__, + PyNumberROrProtocol::__ror__ + ); + } pub fn set_or(&mut self) where T: for<'p> PyNumberOrProtocol<'p>, @@ -775,7 +926,25 @@ impl ffi::PyNumberMethods { where T: for<'p> PyNumberIPowProtocol<'p>, { - self.nb_inplace_power = py_dummy_ternary_self_func!(PyNumberIPowProtocol, T::__ipow__) + // NOTE: Somehow __ipow__ causes SIGSEGV in Python < 3.8 when we extract, + // so we ignore it. It's the same as what CPython does. + unsafe extern "C" fn wrap_ipow( + slf: *mut crate::ffi::PyObject, + other: *mut crate::ffi::PyObject, + _modulo: *mut crate::ffi::PyObject, + ) -> *mut crate::ffi::PyObject + where + T: for<'p> PyNumberIPowProtocol<'p>, + { + crate::callback_body!(py, { + let slf_cell = py.from_borrowed_ptr::>(slf); + let other = py.from_borrowed_ptr::(other); + call_operator_mut!(py, slf_cell, __ipow__, other).convert(py)?; + ffi::Py_INCREF(slf); + Ok(slf) + }) + } + self.nb_inplace_power = Some(wrap_ipow::); } pub fn set_ilshift(&mut self) where @@ -807,6 +976,16 @@ impl ffi::PyNumberMethods { { self.nb_inplace_or = py_binary_self_func!(PyNumberIOrProtocol, T::__ior__); } + pub fn set_floordiv_rfloordiv(&mut self) + where + T: for<'p> PyNumberFloordivProtocol<'p> + for<'p> PyNumberRFloordivProtocol<'p>, + { + self.nb_floor_divide = py_binary_fallback_num_func!( + T, + PyNumberFloordivProtocol::__floordiv__, + PyNumberRFloordivProtocol::__rfloordiv__ + ); + } pub fn set_floordiv(&mut self) where T: for<'p> PyNumberFloordivProtocol<'p>, @@ -820,6 +999,16 @@ impl ffi::PyNumberMethods { self.nb_floor_divide = py_binary_reversed_num_func!(PyNumberRFloordivProtocol, T::__rfloordiv__); } + pub fn set_truediv_rtruediv(&mut self) + where + T: for<'p> PyNumberTruedivProtocol<'p> + for<'p> PyNumberRTruedivProtocol<'p>, + { + self.nb_true_divide = py_binary_fallback_num_func!( + T, + PyNumberTruedivProtocol::__truediv__, + PyNumberRTruedivProtocol::__rtruediv__ + ); + } pub fn set_truediv(&mut self) where T: for<'p> PyNumberTruedivProtocol<'p>, @@ -853,6 +1042,16 @@ impl ffi::PyNumberMethods { { self.nb_index = py_unary_func!(PyNumberIndexProtocol, T::__index__); } + pub fn set_matmul_rmatmul(&mut self) + where + T: for<'p> PyNumberMatmulProtocol<'p> + for<'p> PyNumberRMatmulProtocol<'p>, + { + self.nb_matrix_multiply = py_binary_fallback_num_func!( + T, + PyNumberMatmulProtocol::__matmul__, + PyNumberRMatmulProtocol::__rmatmul__ + ); + } pub fn set_matmul(&mut self) where T: for<'p> PyNumberMatmulProtocol<'p>, diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 98b931eb..f0c38724 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -280,10 +280,56 @@ fn rhs_arithmetic() { } #[pyclass] -struct LhsAndRhsArithmetic {} +struct LhsAndRhs {} + +impl std::fmt::Debug for LhsAndRhs { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "LR") + } +} #[pyproto] -impl PyNumberProtocol for LhsAndRhsArithmetic { +impl PyNumberProtocol for LhsAndRhs { + fn __add__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} + {:?}", lhs, rhs) + } + + fn __sub__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} - {:?}", lhs, rhs) + } + + fn __mul__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} * {:?}", lhs, rhs) + } + + fn __lshift__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} << {:?}", lhs, rhs) + } + + fn __rshift__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} >> {:?}", lhs, rhs) + } + + fn __and__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} & {:?}", lhs, rhs) + } + + fn __xor__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} ^ {:?}", lhs, rhs) + } + + fn __or__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} | {:?}", lhs, rhs) + } + + fn __pow__(lhs: PyRef, rhs: &PyAny, _mod: Option) -> String { + format!("{:?} ** {:?}", lhs, rhs) + } + + fn __matmul__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} @ {:?}", lhs, rhs) + } + fn __radd__(&self, other: &PyAny) -> String { format!("{:?} + RA", other) } @@ -292,44 +338,74 @@ impl PyNumberProtocol for LhsAndRhsArithmetic { format!("{:?} - RA", other) } + fn __rmul__(&self, other: &PyAny) -> String { + format!("{:?} * RA", other) + } + + fn __rlshift__(&self, other: &PyAny) -> String { + format!("{:?} << RA", other) + } + + fn __rrshift__(&self, other: &PyAny) -> String { + format!("{:?} >> RA", other) + } + + fn __rand__(&self, other: &PyAny) -> String { + format!("{:?} & RA", other) + } + + fn __rxor__(&self, other: &PyAny) -> String { + format!("{:?} ^ RA", other) + } + + fn __ror__(&self, other: &PyAny) -> String { + format!("{:?} | RA", other) + } + fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { format!("{:?} ** RA", other) } - fn __add__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} + {:?}", lhs, rhs) - } - - fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} - {:?}", lhs, rhs) - } - - fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option) -> String { - format!("{:?} ** {:?}", lhs, rhs) + fn __rmatmul__(&self, other: &PyAny) -> String { + format!("{:?} @ RA", other) } } #[pyproto] -impl PyObjectProtocol for LhsAndRhsArithmetic { +impl PyObjectProtocol for LhsAndRhs { fn __repr__(&self) -> &'static str { "BA" } } #[test] -fn lhs_override_rhs() { +fn lhs_fellback_to_rhs() { let gil = Python::acquire_gil(); let py = gil.python(); - let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap(); - // Not overrided - py_run!(py, c, "assert c.__radd__(1) == '1 + RA'"); - py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'"); - py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'"); - // Overrided - py_run!(py, c, "assert 1 + c == '1 + BA'"); - py_run!(py, c, "assert 1 - c == '1 - BA'"); - py_run!(py, c, "assert 1 ** c == '1 ** BA'"); + let c = PyCell::new(py, LhsAndRhs {}).unwrap(); + // If the light hand value is `LhsAndRhs`, LHS is used. + py_run!(py, c, "assert c + 1 == 'LR + 1'"); + py_run!(py, c, "assert c - 1 == 'LR - 1'"); + py_run!(py, c, "assert c * 1 == 'LR * 1'"); + py_run!(py, c, "assert c << 1 == 'LR << 1'"); + py_run!(py, c, "assert c >> 1 == 'LR >> 1'"); + py_run!(py, c, "assert c & 1 == 'LR & 1'"); + py_run!(py, c, "assert c ^ 1 == 'LR ^ 1'"); + py_run!(py, c, "assert c | 1 == 'LR | 1'"); + py_run!(py, c, "assert c ** 1 == 'LR ** 1'"); + py_run!(py, c, "assert c @ 1 == 'LR @ 1'"); + // Fellback to RHS because of type mismatching + py_run!(py, c, "assert 1 + c == '1 + RA'"); + py_run!(py, c, "assert 1 - c == '1 - RA'"); + py_run!(py, c, "assert 1 * c == '1 * RA'"); + py_run!(py, c, "assert 1 << c == '1 << RA'"); + py_run!(py, c, "assert 1 >> c == '1 >> RA'"); + py_run!(py, c, "assert 1 & c == '1 & RA'"); + py_run!(py, c, "assert 1 ^ c == '1 ^ RA'"); + py_run!(py, c, "assert 1 | c == '1 | RA'"); + py_run!(py, c, "assert 1 ** c == '1 ** RA'"); + py_run!(py, c, "assert 1 @ c == '1 @ RA'"); } #[pyclass]