Left-hand operands are fellback to RH ones for type mismatching

This commit is contained in:
kngwyu 2020-08-15 21:57:00 +09:00
parent 629efd94e2
commit f086f48499
5 changed files with 316 additions and 153 deletions

View File

@ -65,18 +65,13 @@ pub struct SlotSetter {
pub proto_names: &'static [&'static str], pub proto_names: &'static [&'static str],
/// The name of the setter called to the method table. /// The name of the setter called to the method table.
pub set_function: &'static str, pub set_function: &'static str,
/// Represents a set of setters disabled by this setter.
/// E.g., `set_setdelitem` have to disable `set_setitem` and `set_delitem`.
pub skipped_setters: &'static [&'static str],
} }
impl SlotSetter { impl SlotSetter {
const EMPTY_SETTERS: &'static [&'static str] = &[];
const fn new(names: &'static [&'static str], set_function: &'static str) -> Self { const fn new(names: &'static [&'static str], set_function: &'static str) -> Self {
SlotSetter { SlotSetter {
proto_names: names, proto_names: names,
set_function, set_function,
skipped_setters: Self::EMPTY_SETTERS,
} }
} }
} }
@ -144,11 +139,7 @@ pub const OBJECT: Proto = Proto {
SlotSetter::new(&["__hash__"], "set_hash"), SlotSetter::new(&["__hash__"], "set_hash"),
SlotSetter::new(&["__getattr__"], "set_getattr"), SlotSetter::new(&["__getattr__"], "set_getattr"),
SlotSetter::new(&["__richcmp__"], "set_richcompare"), SlotSetter::new(&["__richcmp__"], "set_richcompare"),
SlotSetter { SlotSetter::new(&["__setattr__", "__delattr__"], "set_setdelattr"),
proto_names: &["__setattr__", "__delattr__"],
set_function: "set_setdelattr",
skipped_setters: &["set_setattr", "set_delattr"],
},
SlotSetter::new(&["__setattr__"], "set_setattr"), SlotSetter::new(&["__setattr__"], "set_setattr"),
SlotSetter::new(&["__delattr__"], "set_delattr"), SlotSetter::new(&["__delattr__"], "set_delattr"),
SlotSetter::new(&["__bool__"], "set_bool"), SlotSetter::new(&["__bool__"], "set_bool"),
@ -379,11 +370,7 @@ pub const MAPPING: Proto = Proto {
slot_setters: &[ slot_setters: &[
SlotSetter::new(&["__len__"], "set_length"), SlotSetter::new(&["__len__"], "set_length"),
SlotSetter::new(&["__getitem__"], "set_getitem"), SlotSetter::new(&["__getitem__"], "set_getitem"),
SlotSetter { SlotSetter::new(&["__setitem__", "__delitem__"], "set_setdelitem"),
proto_names: &["__setitem__", "__delitem__"],
set_function: "set_setdelitem",
skipped_setters: &["set_setitem", "set_delitem"],
},
SlotSetter::new(&["__setitem__"], "set_setitem"), SlotSetter::new(&["__setitem__"], "set_setitem"),
SlotSetter::new(&["__delitem__"], "set_delitem"), SlotSetter::new(&["__delitem__"], "set_delitem"),
], ],
@ -446,11 +433,7 @@ pub const SEQ: Proto = Proto {
SlotSetter::new(&["__concat__"], "set_concat"), SlotSetter::new(&["__concat__"], "set_concat"),
SlotSetter::new(&["__repeat__"], "set_repeat"), SlotSetter::new(&["__repeat__"], "set_repeat"),
SlotSetter::new(&["__getitem__"], "set_getitem"), SlotSetter::new(&["__getitem__"], "set_getitem"),
SlotSetter { SlotSetter::new(&["__setitem__", "__delitem__"], "set_setdelitem"),
proto_names: &["__setitem__", "__delitem__"],
set_function: "set_setdelitem",
skipped_setters: &["set_setitem", "set_delitem"],
},
SlotSetter::new(&["__setitem__"], "set_setitem"), SlotSetter::new(&["__setitem__"], "set_setitem"),
SlotSetter::new(&["__delitem__"], "set_delitem"), SlotSetter::new(&["__delitem__"], "set_delitem"),
SlotSetter::new(&["__contains__"], "set_contains"), SlotSetter::new(&["__contains__"], "set_contains"),
@ -766,71 +749,40 @@ pub const NUM: Proto = Proto {
), ),
], ],
slot_setters: &[ slot_setters: &[
SlotSetter { SlotSetter::new(&["__add__", "__radd__"], "set_add_radd"),
proto_names: &["__add__"], SlotSetter::new(&["__add__"], "set_add"),
set_function: "set_add",
skipped_setters: &["set_radd"],
},
SlotSetter::new(&["__radd__"], "set_radd"), SlotSetter::new(&["__radd__"], "set_radd"),
SlotSetter { SlotSetter::new(&["__sub__", "__rsub__"], "set_sub_rsub"),
proto_names: &["__sub__"], SlotSetter::new(&["__sub__"], "set_sub"),
set_function: "set_sub",
skipped_setters: &["set_rsub"],
},
SlotSetter::new(&["__rsub__"], "set_rsub"), SlotSetter::new(&["__rsub__"], "set_rsub"),
SlotSetter { SlotSetter::new(&["__mul__", "__rmul__"], "set_mul_rmul"),
proto_names: &["__mul__"], SlotSetter::new(&["__mul__"], "set_mul"),
set_function: "set_mul",
skipped_setters: &["set_rmul"],
},
SlotSetter::new(&["__rmul__"], "set_rmul"), SlotSetter::new(&["__rmul__"], "set_rmul"),
SlotSetter::new(&["__mod__"], "set_mod"), SlotSetter::new(&["__mod__"], "set_mod"),
SlotSetter { SlotSetter::new(&["__divmod__", "__rdivmod__"], "set_divmod_rdivmod"),
proto_names: &["__divmod__"], SlotSetter::new(&["__divmod__"], "set_divmod"),
set_function: "set_divmod",
skipped_setters: &["set_rdivmod"],
},
SlotSetter::new(&["__rdivmod__"], "set_rdivmod"), SlotSetter::new(&["__rdivmod__"], "set_rdivmod"),
SlotSetter { SlotSetter::new(&["__pow__", "__rpow__"], "set_pow_rpow"),
proto_names: &["__pow__"], SlotSetter::new(&["__pow__"], "set_pow"),
set_function: "set_pow",
skipped_setters: &["set_rpow"],
},
SlotSetter::new(&["__rpow__"], "set_rpow"), SlotSetter::new(&["__rpow__"], "set_rpow"),
SlotSetter::new(&["__neg__"], "set_neg"), SlotSetter::new(&["__neg__"], "set_neg"),
SlotSetter::new(&["__pos__"], "set_pos"), SlotSetter::new(&["__pos__"], "set_pos"),
SlotSetter::new(&["__abs__"], "set_abs"), SlotSetter::new(&["__abs__"], "set_abs"),
SlotSetter::new(&["__invert__"], "set_invert"), SlotSetter::new(&["__invert__"], "set_invert"),
SlotSetter::new(&["__rdivmod__"], "set_rdivmod"), SlotSetter::new(&["__lshift__", "__rlshift__"], "set_lshift_rlshift"),
SlotSetter { SlotSetter::new(&["__lshift__"], "set_lshift"),
proto_names: &["__lshift__"],
set_function: "set_lshift",
skipped_setters: &["set_rlshift"],
},
SlotSetter::new(&["__rlshift__"], "set_rlshift"), SlotSetter::new(&["__rlshift__"], "set_rlshift"),
SlotSetter { SlotSetter::new(&["__rshift__", "__rrshift__"], "set_rshift_rrshift"),
proto_names: &["__rshift__"], SlotSetter::new(&["__rshift__"], "set_rshift"),
set_function: "set_rshift",
skipped_setters: &["set_rrshift"],
},
SlotSetter::new(&["__rrshift__"], "set_rrshift"), SlotSetter::new(&["__rrshift__"], "set_rrshift"),
SlotSetter { SlotSetter::new(&["__and__", "__rand__"], "set_and_rand"),
proto_names: &["__and__"], SlotSetter::new(&["__and__"], "set_and"),
set_function: "set_and",
skipped_setters: &["set_rand"],
},
SlotSetter::new(&["__rand__"], "set_rand"), SlotSetter::new(&["__rand__"], "set_rand"),
SlotSetter { SlotSetter::new(&["__xor__", "__rxor__"], "set_xor_rxor"),
proto_names: &["__xor__"], SlotSetter::new(&["__xor__"], "set_xor"),
set_function: "set_xor",
skipped_setters: &["set_rxor"],
},
SlotSetter::new(&["__rxor__"], "set_rxor"), SlotSetter::new(&["__rxor__"], "set_rxor"),
SlotSetter { SlotSetter::new(&["__or__", "__ror__"], "set_or_ror"),
proto_names: &["__or__"], SlotSetter::new(&["__or__"], "set_or"),
set_function: "set_or",
skipped_setters: &["set_ror"],
},
SlotSetter::new(&["__ror__"], "set_ror"), SlotSetter::new(&["__ror__"], "set_ror"),
SlotSetter::new(&["__int__"], "set_int"), SlotSetter::new(&["__int__"], "set_int"),
SlotSetter::new(&["__float__"], "set_float"), SlotSetter::new(&["__float__"], "set_float"),
@ -844,26 +796,17 @@ pub const NUM: Proto = Proto {
SlotSetter::new(&["__iand__"], "set_iand"), SlotSetter::new(&["__iand__"], "set_iand"),
SlotSetter::new(&["__ixor__"], "set_ixor"), SlotSetter::new(&["__ixor__"], "set_ixor"),
SlotSetter::new(&["__ior__"], "set_ior"), SlotSetter::new(&["__ior__"], "set_ior"),
SlotSetter { SlotSetter::new(&["__floordiv__", "__rfloordiv__"], "set_floordiv_rfloordiv"),
proto_names: &["__floordiv__"], SlotSetter::new(&["__floordiv__"], "set_floordiv"),
set_function: "set_floordiv",
skipped_setters: &["set_rfloordiv"],
},
SlotSetter::new(&["__rfloordiv__"], "set_rfloordiv"), SlotSetter::new(&["__rfloordiv__"], "set_rfloordiv"),
SlotSetter { SlotSetter::new(&["__truediv__", "__rtruediv__"], "set_truediv_rtruediv"),
proto_names: &["__truediv__"], SlotSetter::new(&["__truediv__"], "set_truediv"),
set_function: "set_truediv",
skipped_setters: &["set_rtruediv"],
},
SlotSetter::new(&["__rtruediv__"], "set_rtruediv"), SlotSetter::new(&["__rtruediv__"], "set_rtruediv"),
SlotSetter::new(&["__ifloordiv__"], "set_ifloordiv"), SlotSetter::new(&["__ifloordiv__"], "set_ifloordiv"),
SlotSetter::new(&["__itruediv__"], "set_itruediv"), SlotSetter::new(&["__itruediv__"], "set_itruediv"),
SlotSetter::new(&["__index__"], "set_index"), SlotSetter::new(&["__index__"], "set_index"),
SlotSetter { SlotSetter::new(&["__matmul__", "__rmatmul__"], "set_matmul_rmatmul"),
proto_names: &["__matmul__"], SlotSetter::new(&["__matmul__"], "set_matmul"),
set_function: "set_matmul",
skipped_setters: &["set_rmatmul"],
},
SlotSetter::new(&["__rmatmul__"], "set_rmatmul"), SlotSetter::new(&["__rmatmul__"], "set_rmatmul"),
SlotSetter::new(&["__imatmul__"], "set_imatmul"), SlotSetter::new(&["__imatmul__"], "set_imatmul"),
], ],

View File

@ -134,23 +134,22 @@ fn slot_initialization(
ty: &syn::Type, ty: &syn::Type,
proto: &defs::Proto, proto: &defs::Proto,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
// Some setters cannot coexist. // To skip used protocols.
// E.g., if we have `__add__`, we need to skip `set_radd`. // E.g., if __set__ and __del__ exist, we use set_set_del and skip set_set.
let mut skipped_setters = Vec::new(); let mut used_protocols = HashSet::new();
// Collect initializers // Collect initializers
let mut initializers: Vec<TokenStream> = vec![]; let mut initializers: Vec<TokenStream> = vec![];
'outer_loop: for m in proto.slot_setters { for m in proto.slot_setters {
if skipped_setters.contains(&m.set_function) { // Skip if any required protocol are not implemented or already used.
if m.proto_names
.iter()
.any(|name| !method_names.contains(*name) || used_protocols.contains(name))
{
continue; continue;
} }
for name in m.proto_names { for name in m.proto_names {
// If this `#[pyproto]` block doesn't provide all required methods, used_protocols.insert(name);
// let's skip implementing this method.
if !method_names.contains(*name) {
continue 'outer_loop;
}
} }
skipped_setters.extend_from_slice(m.skipped_setters);
// Add slot methods to PyProtoRegistry // Add slot methods to PyProtoRegistry
let set = syn::Ident::new(m.set_function, Span::call_site()); let set = syn::Ident::new(m.set_function, Span::call_site());
initializers.push(quote! { table.#set::<#ty>(); }); initializers.push(quote! { table.#set::<#ty>(); });

View File

@ -112,10 +112,39 @@ macro_rules! py_binary_reversed_num_func {
{ {
$crate::callback_body!(py, { $crate::callback_body!(py, {
// Swap lhs <-> rhs // Swap lhs <-> rhs
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(rhs); let slf: &$crate::PyCell<T> = extract_or_return_not_implemented!(py, rhs);
let arg = py.from_borrowed_ptr::<$crate::PyAny>(lhs); let arg = extract_or_return_not_implemented!(py, lhs);
$class::$f(&*slf.try_borrow()?, arg).convert(py)
})
}
Some(wrap::<$class>)
}};
}
$class::$f(&*slf.try_borrow()?, arg.extract()?).convert(py) #[macro_export]
#[doc(hidden)]
macro_rules! py_binary_fallbacked_num_func {
($class:ident, $lop_trait: ident :: $lop: ident, $rop_trait: ident :: $rop: ident) => {{
unsafe extern "C" fn wrap<T>(
lhs: *mut ffi::PyObject,
rhs: *mut ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $lop_trait<'p> + for<'p> $rop_trait<'p>,
{
$crate::callback_body!(py, {
let lhs = py.from_borrowed_ptr::<$crate::PyAny>(lhs);
let rhs = py.from_borrowed_ptr::<$crate::PyAny>(rhs);
// First, try the left hand method (e.g., __add__)
match (lhs.extract(), rhs.extract()) {
(Ok(l), Ok(r)) => $class::$lop(l, r).convert(py),
_ => {
// Next, try the right hand method (e.g., __add__)
let slf: &$crate::PyCell<T> = extract_or_return_not_implemented!(rhs);
let arg = extract_or_return_not_implemented!(lhs);
$class::$rop(&*slf.try_borrow()?, arg).convert(py)
}
}
}) })
} }
Some(wrap::<$class>) Some(wrap::<$class>)
@ -205,57 +234,6 @@ macro_rules! py_ternarys_func {
}; };
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_ternary_num_func {
($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>(
arg1: *mut $crate::ffi::PyObject,
arg2: *mut $crate::ffi::PyObject,
arg3: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $trait<'p>,
{
$crate::callback_body!(py, {
let arg1 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg1)
.extract()?;
let arg2 = extract_or_return_not_implemented!(py, arg2);
let arg3 = extract_or_return_not_implemented!(py, arg3);
$class::$f(arg1, arg2, arg3).convert(py)
})
}
Some(wrap::<T>)
}};
}
#[macro_export]
#[doc(hidden)]
macro_rules! py_ternary_reversed_num_func {
($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>(
arg1: *mut $crate::ffi::PyObject,
arg2: *mut $crate::ffi::PyObject,
arg3: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $trait<'p>,
{
$crate::callback_body!(py, {
// Swap lhs <-> rhs
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(arg2);
let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1);
let arg2 = py.from_borrowed_ptr::<$crate::PyAny>(arg3);
$class::$f(&*slf.try_borrow()?, arg1.extract()?, arg2.extract()?).convert(py)
})
}
Some(wrap::<$class>)
}};
}
// NOTE(kngwyu): Somehow __ipow__ causes SIGSEGV in Python < 3.8 when we extract arg2, // NOTE(kngwyu): Somehow __ipow__ causes SIGSEGV in Python < 3.8 when we extract arg2,
// so we ignore it. It's the same as what CPython does. // so we ignore it. It's the same as what CPython does.
#[macro_export] #[macro_export]
@ -370,6 +348,16 @@ macro_rules! py_func_set_del {
} }
macro_rules! extract_or_return_not_implemented { macro_rules! extract_or_return_not_implemented {
($arg: ident) => {
match $arg.extract() {
Ok(value) => value,
Err(_) => {
let res = $crate::ffi::Py_NotImplemented();
ffi::Py_INCREF(res);
return Ok(res);
}
}
};
($py: ident, $arg: ident) => { ($py: ident, $arg: ident) => {
match $py match $py
.from_borrowed_ptr::<$crate::types::PyAny>($arg) .from_borrowed_ptr::<$crate::types::PyAny>($arg)

View File

@ -585,6 +585,16 @@ impl ffi::PyNumberMethods {
nm.nb_bool = Some(nb_bool); nm.nb_bool = Some(nb_bool);
Box::into_raw(Box::new(nm)) Box::into_raw(Box::new(nm))
} }
pub fn set_add_radd<T>(&mut self)
where
T: for<'p> PyNumberAddProtocol<'p> + for<'p> PyNumberRAddProtocol<'p>,
{
self.nb_add = py_binary_fallbacked_num_func!(
T,
PyNumberAddProtocol::__add__,
PyNumberRAddProtocol::__radd__
);
}
pub fn set_add<T>(&mut self) pub fn set_add<T>(&mut self)
where where
T: for<'p> PyNumberAddProtocol<'p>, T: for<'p> PyNumberAddProtocol<'p>,
@ -597,6 +607,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_add = py_binary_reversed_num_func!(PyNumberRAddProtocol, T::__radd__); self.nb_add = py_binary_reversed_num_func!(PyNumberRAddProtocol, T::__radd__);
} }
pub fn set_sub_rsub<T>(&mut self)
where
T: for<'p> PyNumberSubProtocol<'p> + for<'p> PyNumberRSubProtocol<'p>,
{
self.nb_subtract = py_binary_fallbacked_num_func!(
T,
PyNumberSubProtocol::__sub__,
PyNumberRSubProtocol::__rsub__
);
}
pub fn set_sub<T>(&mut self) pub fn set_sub<T>(&mut self)
where where
T: for<'p> PyNumberSubProtocol<'p>, T: for<'p> PyNumberSubProtocol<'p>,
@ -609,6 +629,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_subtract = py_binary_reversed_num_func!(PyNumberRSubProtocol, T::__rsub__); self.nb_subtract = py_binary_reversed_num_func!(PyNumberRSubProtocol, T::__rsub__);
} }
pub fn set_mul_rmul<T>(&mut self)
where
T: for<'p> PyNumberMulProtocol<'p> + for<'p> PyNumberRMulProtocol<'p>,
{
self.nb_multiply = py_binary_fallbacked_num_func!(
T,
PyNumberMulProtocol::__mul__,
PyNumberRMulProtocol::__rmul__
);
}
pub fn set_mul<T>(&mut self) pub fn set_mul<T>(&mut self)
where where
T: for<'p> PyNumberMulProtocol<'p>, T: for<'p> PyNumberMulProtocol<'p>,
@ -627,6 +657,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_remainder = py_binary_num_func!(PyNumberModProtocol, T::__mod__); self.nb_remainder = py_binary_num_func!(PyNumberModProtocol, T::__mod__);
} }
pub fn set_divmod_rdivmod<T>(&mut self)
where
T: for<'p> PyNumberDivmodProtocol<'p> + for<'p> PyNumberRDivmodProtocol<'p>,
{
self.nb_divmod = py_binary_fallbacked_num_func!(
T,
PyNumberDivmodProtocol::__divmod__,
PyNumberRDivmodProtocol::__rdivmod__
);
}
pub fn set_divmod<T>(&mut self) pub fn set_divmod<T>(&mut self)
where where
T: for<'p> PyNumberDivmodProtocol<'p>, T: for<'p> PyNumberDivmodProtocol<'p>,
@ -639,17 +679,78 @@ impl ffi::PyNumberMethods {
{ {
self.nb_divmod = py_binary_reversed_num_func!(PyNumberRDivmodProtocol, T::__rdivmod__); self.nb_divmod = py_binary_reversed_num_func!(PyNumberRDivmodProtocol, T::__rdivmod__);
} }
pub fn set_pow_rpow<T>(&mut self)
where
T: for<'p> PyNumberPowProtocol<'p> + for<'p> PyNumberRPowProtocol<'p>,
{
unsafe extern "C" fn wrap_pow_and_rpow<T>(
lhs: *mut crate::ffi::PyObject,
rhs: *mut crate::ffi::PyObject,
modulo: *mut crate::ffi::PyObject,
) -> *mut crate::ffi::PyObject
where
T: for<'p> PyNumberPowProtocol<'p> + for<'p> PyNumberRPowProtocol<'p>,
{
crate::callback_body!(py, {
let lhs = py.from_borrowed_ptr::<crate::PyAny>(lhs);
let rhs = py.from_borrowed_ptr::<crate::PyAny>(rhs);
let modulo = py.from_borrowed_ptr::<crate::PyAny>(modulo);
// First, try __pow__
match (lhs.extract(), rhs.extract(), modulo.extract()) {
(Ok(l), Ok(r), Ok(m)) => T::__pow__(l, r, m).convert(py),
_ => {
// Then try __rpow__
let slf: &crate::PyCell<T> = extract_or_return_not_implemented!(rhs);
let arg = extract_or_return_not_implemented!(lhs);
let modulo = extract_or_return_not_implemented!(modulo);
slf.try_borrow()?.__rpow__(arg, modulo).convert(py)
}
}
})
}
self.nb_power = Some(wrap_pow_and_rpow::<T>);
}
pub fn set_pow<T>(&mut self) pub fn set_pow<T>(&mut self)
where where
T: for<'p> PyNumberPowProtocol<'p>, T: for<'p> PyNumberPowProtocol<'p>,
{ {
self.nb_power = py_ternary_num_func!(PyNumberPowProtocol, T::__pow__); unsafe extern "C" fn wrap_pow<T>(
lhs: *mut crate::ffi::PyObject,
rhs: *mut crate::ffi::PyObject,
modulo: *mut crate::ffi::PyObject,
) -> *mut crate::ffi::PyObject
where
T: for<'p> PyNumberPowProtocol<'p>,
{
crate::callback_body!(py, {
let lhs = extract_or_return_not_implemented!(py, lhs);
let rhs = extract_or_return_not_implemented!(py, rhs);
let modulo = extract_or_return_not_implemented!(py, modulo);
T::__pow__(lhs, rhs, modulo).convert(py)
})
}
self.nb_power = Some(wrap_pow::<T>);
} }
pub fn set_rpow<T>(&mut self) pub fn set_rpow<T>(&mut self)
where where
T: for<'p> PyNumberRPowProtocol<'p>, T: for<'p> PyNumberRPowProtocol<'p>,
{ {
self.nb_power = py_ternary_reversed_num_func!(PyNumberRPowProtocol, T::__rpow__); unsafe extern "C" fn wrap_rpow<T>(
arg: *mut crate::ffi::PyObject,
slf: *mut crate::ffi::PyObject,
modulo: *mut crate::ffi::PyObject,
) -> *mut crate::ffi::PyObject
where
T: for<'p> PyNumberRPowProtocol<'p>,
{
crate::callback_body!(py, {
let slf: &crate::PyCell<T> = extract_or_return_not_implemented!(py, slf);
let arg = extract_or_return_not_implemented!(py, arg);
let modulo = extract_or_return_not_implemented!(py, modulo);
slf.try_borrow()?.__rpow__(arg, modulo).convert(py)
})
}
self.nb_power = Some(wrap_rpow::<T>);
} }
pub fn set_neg<T>(&mut self) pub fn set_neg<T>(&mut self)
where where
@ -675,6 +776,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_invert = py_unary_func!(PyNumberInvertProtocol, T::__invert__); self.nb_invert = py_unary_func!(PyNumberInvertProtocol, T::__invert__);
} }
pub fn set_lshift_rlshift<T>(&mut self)
where
T: for<'p> PyNumberLShiftProtocol<'p> + for<'p> PyNumberRLShiftProtocol<'p>,
{
self.nb_lshift = py_binary_fallbacked_num_func!(
T,
PyNumberLShiftProtocol::__lshift__,
PyNumberRLShiftProtocol::__rlshift__
);
}
pub fn set_lshift<T>(&mut self) pub fn set_lshift<T>(&mut self)
where where
T: for<'p> PyNumberLShiftProtocol<'p>, T: for<'p> PyNumberLShiftProtocol<'p>,
@ -687,6 +798,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_lshift = py_binary_reversed_num_func!(PyNumberRLShiftProtocol, T::__rlshift__); self.nb_lshift = py_binary_reversed_num_func!(PyNumberRLShiftProtocol, T::__rlshift__);
} }
pub fn set_rshift_rrshift<T>(&mut self)
where
T: for<'p> PyNumberRShiftProtocol<'p> + for<'p> PyNumberRRShiftProtocol<'p>,
{
self.nb_rshift = py_binary_fallbacked_num_func!(
T,
PyNumberRShiftProtocol::__rshift__,
PyNumberRRShiftProtocol::__rrshift__
);
}
pub fn set_rshift<T>(&mut self) pub fn set_rshift<T>(&mut self)
where where
T: for<'p> PyNumberRShiftProtocol<'p>, T: for<'p> PyNumberRShiftProtocol<'p>,
@ -699,6 +820,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_rshift = py_binary_reversed_num_func!(PyNumberRRShiftProtocol, T::__rrshift__); self.nb_rshift = py_binary_reversed_num_func!(PyNumberRRShiftProtocol, T::__rrshift__);
} }
pub fn set_and_rand<T>(&mut self)
where
T: for<'p> PyNumberAndProtocol<'p> + for<'p> PyNumberRAndProtocol<'p>,
{
self.nb_and = py_binary_fallbacked_num_func!(
T,
PyNumberAndProtocol::__and__,
PyNumberRAndProtocol::__rand__
);
}
pub fn set_and<T>(&mut self) pub fn set_and<T>(&mut self)
where where
T: for<'p> PyNumberAndProtocol<'p>, T: for<'p> PyNumberAndProtocol<'p>,
@ -711,6 +842,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_and = py_binary_reversed_num_func!(PyNumberRAndProtocol, T::__rand__); self.nb_and = py_binary_reversed_num_func!(PyNumberRAndProtocol, T::__rand__);
} }
pub fn set_xor_rxor<T>(&mut self)
where
T: for<'p> PyNumberXorProtocol<'p> + for<'p> PyNumberRXorProtocol<'p>,
{
self.nb_xor = py_binary_fallbacked_num_func!(
T,
PyNumberXorProtocol::__xor__,
PyNumberRXorProtocol::__rxor__
);
}
pub fn set_xor<T>(&mut self) pub fn set_xor<T>(&mut self)
where where
T: for<'p> PyNumberXorProtocol<'p>, T: for<'p> PyNumberXorProtocol<'p>,
@ -723,6 +864,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_xor = py_binary_reversed_num_func!(PyNumberRXorProtocol, T::__rxor__); self.nb_xor = py_binary_reversed_num_func!(PyNumberRXorProtocol, T::__rxor__);
} }
pub fn set_or_ror<T>(&mut self)
where
T: for<'p> PyNumberOrProtocol<'p> + for<'p> PyNumberROrProtocol<'p>,
{
self.nb_or = py_binary_fallbacked_num_func!(
T,
PyNumberOrProtocol::__or__,
PyNumberROrProtocol::__ror__
);
}
pub fn set_or<T>(&mut self) pub fn set_or<T>(&mut self)
where where
T: for<'p> PyNumberOrProtocol<'p>, T: for<'p> PyNumberOrProtocol<'p>,
@ -807,6 +958,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_inplace_or = py_binary_self_func!(PyNumberIOrProtocol, T::__ior__); self.nb_inplace_or = py_binary_self_func!(PyNumberIOrProtocol, T::__ior__);
} }
pub fn set_floordiv_rfloordiv<T>(&mut self)
where
T: for<'p> PyNumberFloordivProtocol<'p> + for<'p> PyNumberRFloordivProtocol<'p>,
{
self.nb_floor_divide = py_binary_fallbacked_num_func!(
T,
PyNumberFloordivProtocol::__floordiv__,
PyNumberRFloordivProtocol::__rfloordiv__
);
}
pub fn set_floordiv<T>(&mut self) pub fn set_floordiv<T>(&mut self)
where where
T: for<'p> PyNumberFloordivProtocol<'p>, T: for<'p> PyNumberFloordivProtocol<'p>,
@ -820,6 +981,16 @@ impl ffi::PyNumberMethods {
self.nb_floor_divide = self.nb_floor_divide =
py_binary_reversed_num_func!(PyNumberRFloordivProtocol, T::__rfloordiv__); py_binary_reversed_num_func!(PyNumberRFloordivProtocol, T::__rfloordiv__);
} }
pub fn set_truediv_rtruediv<T>(&mut self)
where
T: for<'p> PyNumberTruedivProtocol<'p> + for<'p> PyNumberRTruedivProtocol<'p>,
{
self.nb_true_divide = py_binary_fallbacked_num_func!(
T,
PyNumberTruedivProtocol::__truediv__,
PyNumberRTruedivProtocol::__rtruediv__
);
}
pub fn set_truediv<T>(&mut self) pub fn set_truediv<T>(&mut self)
where where
T: for<'p> PyNumberTruedivProtocol<'p>, T: for<'p> PyNumberTruedivProtocol<'p>,
@ -853,6 +1024,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_index = py_unary_func!(PyNumberIndexProtocol, T::__index__); self.nb_index = py_unary_func!(PyNumberIndexProtocol, T::__index__);
} }
pub fn set_matmul_rmatmul<T>(&mut self)
where
T: for<'p> PyNumberMatmulProtocol<'p> + for<'p> PyNumberRMatmulProtocol<'p>,
{
self.nb_matrix_multiply = py_binary_fallbacked_num_func!(
T,
PyNumberMatmulProtocol::__matmul__,
PyNumberRMatmulProtocol::__rmatmul__
);
}
pub fn set_matmul<T>(&mut self) pub fn set_matmul<T>(&mut self)
where where
T: for<'p> PyNumberMatmulProtocol<'p>, T: for<'p> PyNumberMatmulProtocol<'p>,

View File

@ -332,6 +332,58 @@ fn lhs_override_rhs() {
py_run!(py, c, "assert 1 ** c == '1 ** BA'"); py_run!(py, c, "assert 1 ** c == '1 ** BA'");
} }
#[pyclass]
#[derive(Debug)]
struct Lhs2Rhs {}
#[pyproto]
impl PyNumberProtocol for Lhs2Rhs {
fn __add__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}
fn __sub__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}
fn __pow__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny, _mod: Option<usize>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}
fn __matmul__(lhs: PyRef<Lhs2Rhs>, rhs: &PyAny) -> String {
format!("{:?} @ {:?}", lhs, rhs)
}
fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other)
}
fn __rsub__(&self, rhs: &PyAny) -> String {
format!("{:?} - RA", rhs)
}
fn __rpow__(&self, rhs: &PyAny, _mod: Option<usize>) -> String {
format!("{:?} ** RA", rhs)
}
fn __rmatmul__(&self, rhs: &PyAny) -> String {
format!("{:?} @ RA", rhs)
}
}
#[pyproto]
impl PyObjectProtocol for Lhs2Rhs {
fn __repr__(&self) -> &'static str {
"BA"
}
}
#[test]
fn lhs_fallbacked_to_rhs() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = PyCell::new(py, Lhs2Rhs {}).unwrap();
// Fallbacked to RHS because of type mismatching
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
}
#[pyclass] #[pyclass]
struct RichComparisons {} struct RichComparisons {}