Merge pull request #1107 from kngwyu/radd-fallback

Left-hand operands are fellback to right-hand ones for type mismatching
This commit is contained in:
Yuji Kanagawa 2020-08-21 17:10:37 +09:00 committed by GitHub
commit 9d73e0b1a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 399 additions and 228 deletions

View file

@ -44,6 +44,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Improve lifetime elision in `#[pyproto]`. [#1093](https://github.com/PyO3/pyo3/pull/1093) - Improve lifetime elision in `#[pyproto]`. [#1093](https://github.com/PyO3/pyo3/pull/1093)
- Fix python configuration detection when cross-compiling. [#1095](https://github.com/PyO3/pyo3/pull/1095) - Fix python configuration detection when cross-compiling. [#1095](https://github.com/PyO3/pyo3/pull/1095)
- Link against libpython on android with `extension-module` set. [#1095](https://github.com/PyO3/pyo3/pull/1095) - Link against libpython on android with `extension-module` set. [#1095](https://github.com/PyO3/pyo3/pull/1095)
- Fix support for both `__add__` and `__radd__` in the `+` operator when both are defined in `PyNumberProtocol`
(and similar for all other reversible operators). [#1107](https://github.com/PyO3/pyo3/pull/1107)
## [0.11.1] - 2020-06-30 ## [0.11.1] - 2020-06-30
### Added ### Added

View file

@ -1,5 +1,6 @@
// Copyright (c) 2017-present PyO3 Project and Contributors // Copyright (c) 2017-present PyO3 Project and Contributors
use crate::proto_method::MethodProto; use crate::proto_method::MethodProto;
use std::collections::HashSet;
/// Predicates for `#[pyproto]`. /// Predicates for `#[pyproto]`.
pub struct Proto { pub struct Proto {
@ -14,7 +15,7 @@ pub struct Proto {
/// All methods registered as normal methods like `#[pymethods]`. /// All methods registered as normal methods like `#[pymethods]`.
pub py_methods: &'static [PyMethod], pub py_methods: &'static [PyMethod],
/// All methods registered to the slot table. /// All methods registered to the slot table.
pub slot_setters: &'static [SlotSetter], slot_setters: &'static [SlotSetter],
} }
impl Proto { impl Proto {
@ -30,6 +31,28 @@ impl Proto {
{ {
self.py_methods.iter().find(|m| query == m.name) self.py_methods.iter().find(|m| query == m.name)
} }
// Since the order matters, we expose only the iterator instead of the slice.
pub(crate) fn setters(
&self,
mut implemented_protocols: HashSet<String>,
) -> impl Iterator<Item = &'static str> {
self.slot_setters.iter().filter_map(move |setter| {
// If any required method is not implemented, we skip this setter.
if setter
.proto_names
.iter()
.any(|name| !implemented_protocols.contains(*name))
{
return None;
}
// To use 'paired' setter in priority, we remove used protocols.
// For example, if set_add_radd is already used, we shouldn't use set_add and set_radd.
for name in setter.proto_names {
implemented_protocols.remove(*name);
}
Some(setter.set_function)
})
}
} }
/// Represents a method registered as a normal method like `#[pymethods]`. /// Represents a method registered as a normal method like `#[pymethods]`.
@ -59,24 +82,19 @@ impl PyMethod {
} }
/// Represents a setter used to register a method to the method table. /// Represents a setter used to register a method to the method table.
pub struct SlotSetter { struct SlotSetter {
/// Protocols necessary for invoking this setter. /// Protocols necessary for invoking this setter.
/// E.g., we need `__setattr__` and `__delattr__` for invoking `set_setdelitem`. /// E.g., we need `__setattr__` and `__delattr__` for invoking `set_setdelitem`.
pub proto_names: &'static [&'static str], pub proto_names: &'static [&'static str],
/// The name of the setter called to the method table. /// The name of the setter called to the method table.
pub set_function: &'static str, pub set_function: &'static str,
/// Represents a set of setters disabled by this setter.
/// E.g., `set_setdelitem` have to disable `set_setitem` and `set_delitem`.
pub skipped_setters: &'static [&'static str],
} }
impl SlotSetter { impl SlotSetter {
const EMPTY_SETTERS: &'static [&'static str] = &[];
const fn new(names: &'static [&'static str], set_function: &'static str) -> Self { const fn new(names: &'static [&'static str], set_function: &'static str) -> Self {
SlotSetter { SlotSetter {
proto_names: names, proto_names: names,
set_function, set_function,
skipped_setters: Self::EMPTY_SETTERS,
} }
} }
} }
@ -144,11 +162,7 @@ pub const OBJECT: Proto = Proto {
SlotSetter::new(&["__hash__"], "set_hash"), SlotSetter::new(&["__hash__"], "set_hash"),
SlotSetter::new(&["__getattr__"], "set_getattr"), SlotSetter::new(&["__getattr__"], "set_getattr"),
SlotSetter::new(&["__richcmp__"], "set_richcompare"), SlotSetter::new(&["__richcmp__"], "set_richcompare"),
SlotSetter { SlotSetter::new(&["__setattr__", "__delattr__"], "set_setdelattr"),
proto_names: &["__setattr__", "__delattr__"],
set_function: "set_setdelattr",
skipped_setters: &["set_setattr", "set_delattr"],
},
SlotSetter::new(&["__setattr__"], "set_setattr"), SlotSetter::new(&["__setattr__"], "set_setattr"),
SlotSetter::new(&["__delattr__"], "set_delattr"), SlotSetter::new(&["__delattr__"], "set_delattr"),
SlotSetter::new(&["__bool__"], "set_bool"), SlotSetter::new(&["__bool__"], "set_bool"),
@ -379,11 +393,7 @@ pub const MAPPING: Proto = Proto {
slot_setters: &[ slot_setters: &[
SlotSetter::new(&["__len__"], "set_length"), SlotSetter::new(&["__len__"], "set_length"),
SlotSetter::new(&["__getitem__"], "set_getitem"), SlotSetter::new(&["__getitem__"], "set_getitem"),
SlotSetter { SlotSetter::new(&["__setitem__", "__delitem__"], "set_setdelitem"),
proto_names: &["__setitem__", "__delitem__"],
set_function: "set_setdelitem",
skipped_setters: &["set_setitem", "set_delitem"],
},
SlotSetter::new(&["__setitem__"], "set_setitem"), SlotSetter::new(&["__setitem__"], "set_setitem"),
SlotSetter::new(&["__delitem__"], "set_delitem"), SlotSetter::new(&["__delitem__"], "set_delitem"),
], ],
@ -446,11 +456,7 @@ pub const SEQ: Proto = Proto {
SlotSetter::new(&["__concat__"], "set_concat"), SlotSetter::new(&["__concat__"], "set_concat"),
SlotSetter::new(&["__repeat__"], "set_repeat"), SlotSetter::new(&["__repeat__"], "set_repeat"),
SlotSetter::new(&["__getitem__"], "set_getitem"), SlotSetter::new(&["__getitem__"], "set_getitem"),
SlotSetter { SlotSetter::new(&["__setitem__", "__delitem__"], "set_setdelitem"),
proto_names: &["__setitem__", "__delitem__"],
set_function: "set_setdelitem",
skipped_setters: &["set_setitem", "set_delitem"],
},
SlotSetter::new(&["__setitem__"], "set_setitem"), SlotSetter::new(&["__setitem__"], "set_setitem"),
SlotSetter::new(&["__delitem__"], "set_delitem"), SlotSetter::new(&["__delitem__"], "set_delitem"),
SlotSetter::new(&["__contains__"], "set_contains"), SlotSetter::new(&["__contains__"], "set_contains"),
@ -766,71 +772,40 @@ pub const NUM: Proto = Proto {
), ),
], ],
slot_setters: &[ slot_setters: &[
SlotSetter { SlotSetter::new(&["__add__", "__radd__"], "set_add_radd"),
proto_names: &["__add__"], SlotSetter::new(&["__add__"], "set_add"),
set_function: "set_add",
skipped_setters: &["set_radd"],
},
SlotSetter::new(&["__radd__"], "set_radd"), SlotSetter::new(&["__radd__"], "set_radd"),
SlotSetter { SlotSetter::new(&["__sub__", "__rsub__"], "set_sub_rsub"),
proto_names: &["__sub__"], SlotSetter::new(&["__sub__"], "set_sub"),
set_function: "set_sub",
skipped_setters: &["set_rsub"],
},
SlotSetter::new(&["__rsub__"], "set_rsub"), SlotSetter::new(&["__rsub__"], "set_rsub"),
SlotSetter { SlotSetter::new(&["__mul__", "__rmul__"], "set_mul_rmul"),
proto_names: &["__mul__"], SlotSetter::new(&["__mul__"], "set_mul"),
set_function: "set_mul",
skipped_setters: &["set_rmul"],
},
SlotSetter::new(&["__rmul__"], "set_rmul"), SlotSetter::new(&["__rmul__"], "set_rmul"),
SlotSetter::new(&["__mod__"], "set_mod"), SlotSetter::new(&["__mod__"], "set_mod"),
SlotSetter { SlotSetter::new(&["__divmod__", "__rdivmod__"], "set_divmod_rdivmod"),
proto_names: &["__divmod__"], SlotSetter::new(&["__divmod__"], "set_divmod"),
set_function: "set_divmod",
skipped_setters: &["set_rdivmod"],
},
SlotSetter::new(&["__rdivmod__"], "set_rdivmod"), SlotSetter::new(&["__rdivmod__"], "set_rdivmod"),
SlotSetter { SlotSetter::new(&["__pow__", "__rpow__"], "set_pow_rpow"),
proto_names: &["__pow__"], SlotSetter::new(&["__pow__"], "set_pow"),
set_function: "set_pow",
skipped_setters: &["set_rpow"],
},
SlotSetter::new(&["__rpow__"], "set_rpow"), SlotSetter::new(&["__rpow__"], "set_rpow"),
SlotSetter::new(&["__neg__"], "set_neg"), SlotSetter::new(&["__neg__"], "set_neg"),
SlotSetter::new(&["__pos__"], "set_pos"), SlotSetter::new(&["__pos__"], "set_pos"),
SlotSetter::new(&["__abs__"], "set_abs"), SlotSetter::new(&["__abs__"], "set_abs"),
SlotSetter::new(&["__invert__"], "set_invert"), SlotSetter::new(&["__invert__"], "set_invert"),
SlotSetter::new(&["__rdivmod__"], "set_rdivmod"), SlotSetter::new(&["__lshift__", "__rlshift__"], "set_lshift_rlshift"),
SlotSetter { SlotSetter::new(&["__lshift__"], "set_lshift"),
proto_names: &["__lshift__"],
set_function: "set_lshift",
skipped_setters: &["set_rlshift"],
},
SlotSetter::new(&["__rlshift__"], "set_rlshift"), SlotSetter::new(&["__rlshift__"], "set_rlshift"),
SlotSetter { SlotSetter::new(&["__rshift__", "__rrshift__"], "set_rshift_rrshift"),
proto_names: &["__rshift__"], SlotSetter::new(&["__rshift__"], "set_rshift"),
set_function: "set_rshift",
skipped_setters: &["set_rrshift"],
},
SlotSetter::new(&["__rrshift__"], "set_rrshift"), SlotSetter::new(&["__rrshift__"], "set_rrshift"),
SlotSetter { SlotSetter::new(&["__and__", "__rand__"], "set_and_rand"),
proto_names: &["__and__"], SlotSetter::new(&["__and__"], "set_and"),
set_function: "set_and",
skipped_setters: &["set_rand"],
},
SlotSetter::new(&["__rand__"], "set_rand"), SlotSetter::new(&["__rand__"], "set_rand"),
SlotSetter { SlotSetter::new(&["__xor__", "__rxor__"], "set_xor_rxor"),
proto_names: &["__xor__"], SlotSetter::new(&["__xor__"], "set_xor"),
set_function: "set_xor",
skipped_setters: &["set_rxor"],
},
SlotSetter::new(&["__rxor__"], "set_rxor"), SlotSetter::new(&["__rxor__"], "set_rxor"),
SlotSetter { SlotSetter::new(&["__or__", "__ror__"], "set_or_ror"),
proto_names: &["__or__"], SlotSetter::new(&["__or__"], "set_or"),
set_function: "set_or",
skipped_setters: &["set_ror"],
},
SlotSetter::new(&["__ror__"], "set_ror"), SlotSetter::new(&["__ror__"], "set_ror"),
SlotSetter::new(&["__int__"], "set_int"), SlotSetter::new(&["__int__"], "set_int"),
SlotSetter::new(&["__float__"], "set_float"), SlotSetter::new(&["__float__"], "set_float"),
@ -844,26 +819,17 @@ pub const NUM: Proto = Proto {
SlotSetter::new(&["__iand__"], "set_iand"), SlotSetter::new(&["__iand__"], "set_iand"),
SlotSetter::new(&["__ixor__"], "set_ixor"), SlotSetter::new(&["__ixor__"], "set_ixor"),
SlotSetter::new(&["__ior__"], "set_ior"), SlotSetter::new(&["__ior__"], "set_ior"),
SlotSetter { SlotSetter::new(&["__floordiv__", "__rfloordiv__"], "set_floordiv_rfloordiv"),
proto_names: &["__floordiv__"], SlotSetter::new(&["__floordiv__"], "set_floordiv"),
set_function: "set_floordiv",
skipped_setters: &["set_rfloordiv"],
},
SlotSetter::new(&["__rfloordiv__"], "set_rfloordiv"), SlotSetter::new(&["__rfloordiv__"], "set_rfloordiv"),
SlotSetter { SlotSetter::new(&["__truediv__", "__rtruediv__"], "set_truediv_rtruediv"),
proto_names: &["__truediv__"], SlotSetter::new(&["__truediv__"], "set_truediv"),
set_function: "set_truediv",
skipped_setters: &["set_rtruediv"],
},
SlotSetter::new(&["__rtruediv__"], "set_rtruediv"), SlotSetter::new(&["__rtruediv__"], "set_rtruediv"),
SlotSetter::new(&["__ifloordiv__"], "set_ifloordiv"), SlotSetter::new(&["__ifloordiv__"], "set_ifloordiv"),
SlotSetter::new(&["__itruediv__"], "set_itruediv"), SlotSetter::new(&["__itruediv__"], "set_itruediv"),
SlotSetter::new(&["__index__"], "set_index"), SlotSetter::new(&["__index__"], "set_index"),
SlotSetter { SlotSetter::new(&["__matmul__", "__rmatmul__"], "set_matmul_rmatmul"),
proto_names: &["__matmul__"], SlotSetter::new(&["__matmul__"], "set_matmul"),
set_function: "set_matmul",
skipped_setters: &["set_rmatmul"],
},
SlotSetter::new(&["__rmatmul__"], "set_rmatmul"), SlotSetter::new(&["__rmatmul__"], "set_rmatmul"),
SlotSetter::new(&["__imatmul__"], "set_imatmul"), SlotSetter::new(&["__imatmul__"], "set_imatmul"),
], ],

View file

@ -134,25 +134,11 @@ fn slot_initialization(
ty: &syn::Type, ty: &syn::Type,
proto: &defs::Proto, proto: &defs::Proto,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
// Some setters cannot coexist.
// E.g., if we have `__add__`, we need to skip `set_radd`.
let mut skipped_setters = Vec::new();
// Collect initializers // Collect initializers
let mut initializers: Vec<TokenStream> = vec![]; let mut initializers: Vec<TokenStream> = vec![];
'outer_loop: for m in proto.slot_setters { for setter in proto.setters(method_names) {
if skipped_setters.contains(&m.set_function) {
continue;
}
for name in m.proto_names {
// If this `#[pyproto]` block doesn't provide all required methods,
// let's skip implementing this method.
if !method_names.contains(*name) {
continue 'outer_loop;
}
}
skipped_setters.extend_from_slice(m.skipped_setters);
// Add slot methods to PyProtoRegistry // Add slot methods to PyProtoRegistry
let set = syn::Ident::new(m.set_function, Span::call_site()); let set = syn::Ident::new(setter, Span::call_site());
initializers.push(quote! { table.#set::<#ty>(); }); initializers.push(quote! { table.#set::<#ty>(); });
} }
if initializers.is_empty() { if initializers.is_empty() {

View file

@ -1,7 +1,5 @@
// Copyright (c) 2017-present PyO3 Project and Contributors // Copyright (c) 2017-present PyO3 Project and Contributors
#[macro_export]
#[doc(hidden)]
macro_rules! py_unary_func { macro_rules! py_unary_func {
($trait: ident, $class:ident :: $f:ident, $call:ident, $ret_type: ty) => {{ ($trait: ident, $class:ident :: $f:ident, $call:ident, $ret_type: ty) => {{
unsafe extern "C" fn wrap<T>(slf: *mut $crate::ffi::PyObject) -> $ret_type unsafe extern "C" fn wrap<T>(slf: *mut $crate::ffi::PyObject) -> $ret_type
@ -24,8 +22,6 @@ macro_rules! py_unary_func {
}; };
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_unarys_func { macro_rules! py_unarys_func {
($trait:ident, $class:ident :: $f:ident) => {{ ($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>(slf: *mut $crate::ffi::PyObject) -> *mut $crate::ffi::PyObject unsafe extern "C" fn wrap<T>(slf: *mut $crate::ffi::PyObject) -> *mut $crate::ffi::PyObject
@ -45,16 +41,12 @@ macro_rules! py_unarys_func {
}}; }};
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_len_func { macro_rules! py_len_func {
($trait:ident, $class:ident :: $f:ident) => { ($trait:ident, $class:ident :: $f:ident) => {
py_unary_func!($trait, $class::$f, $crate::ffi::Py_ssize_t) py_unary_func!($trait, $class::$f, $crate::ffi::Py_ssize_t)
}; };
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_binary_func { macro_rules! py_binary_func {
// Use call_ref! by default // Use call_ref! by default
($trait:ident, $class:ident :: $f:ident, $return:ty, $call:ident) => {{ ($trait:ident, $class:ident :: $f:ident, $return:ty, $call:ident) => {{
@ -78,8 +70,6 @@ macro_rules! py_binary_func {
}; };
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_binary_num_func { macro_rules! py_binary_num_func {
($trait:ident, $class:ident :: $f:ident) => {{ ($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>( unsafe extern "C" fn wrap<T>(
@ -99,8 +89,6 @@ macro_rules! py_binary_num_func {
}}; }};
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_binary_reversed_num_func { macro_rules! py_binary_reversed_num_func {
($trait:ident, $class:ident :: $f:ident) => {{ ($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>( unsafe extern "C" fn wrap<T>(
@ -112,10 +100,37 @@ macro_rules! py_binary_reversed_num_func {
{ {
$crate::callback_body!(py, { $crate::callback_body!(py, {
// Swap lhs <-> rhs // Swap lhs <-> rhs
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(rhs); let slf: &$crate::PyCell<T> = extract_or_return_not_implemented!(py, rhs);
let arg = py.from_borrowed_ptr::<$crate::PyAny>(lhs); let arg = extract_or_return_not_implemented!(py, lhs);
$class::$f(&*slf.try_borrow()?, arg).convert(py)
})
}
Some(wrap::<$class>)
}};
}
$class::$f(&*slf.try_borrow()?, arg.extract()?).convert(py) macro_rules! py_binary_fallback_num_func {
($class:ident, $lop_trait: ident :: $lop: ident, $rop_trait: ident :: $rop: ident) => {{
unsafe extern "C" fn wrap<T>(
lhs: *mut ffi::PyObject,
rhs: *mut ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $lop_trait<'p> + for<'p> $rop_trait<'p>,
{
$crate::callback_body!(py, {
let lhs = py.from_borrowed_ptr::<$crate::PyAny>(lhs);
let rhs = py.from_borrowed_ptr::<$crate::PyAny>(rhs);
// First, try the left hand method (e.g., __add__)
match (lhs.extract(), rhs.extract()) {
(Ok(l), Ok(r)) => $class::$lop(l, r).convert(py),
_ => {
// Next, try the right hand method (e.g., __radd__)
let slf: &$crate::PyCell<T> = extract_or_return_not_implemented!(rhs);
let arg = extract_or_return_not_implemented!(lhs);
$class::$rop(&*slf.try_borrow()?, arg).convert(py)
}
}
}) })
} }
Some(wrap::<$class>) Some(wrap::<$class>)
@ -123,8 +138,6 @@ macro_rules! py_binary_reversed_num_func {
} }
// NOTE(kngwyu): This macro is used only for inplace operations, so I used call_mut here. // NOTE(kngwyu): This macro is used only for inplace operations, so I used call_mut here.
#[macro_export]
#[doc(hidden)]
macro_rules! py_binary_self_func { macro_rules! py_binary_self_func {
($trait:ident, $class:ident :: $f:ident) => {{ ($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>( unsafe extern "C" fn wrap<T>(
@ -146,8 +159,6 @@ macro_rules! py_binary_self_func {
}}; }};
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_ssizearg_func { macro_rules! py_ssizearg_func {
// Use call_ref! by default // Use call_ref! by default
($trait:ident, $class:ident :: $f:ident) => { ($trait:ident, $class:ident :: $f:ident) => {
@ -170,8 +181,6 @@ macro_rules! py_ssizearg_func {
}}; }};
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_ternarys_func { macro_rules! py_ternarys_func {
($trait:ident, $class:ident :: $f:ident, $return_type:ty) => {{ ($trait:ident, $class:ident :: $f:ident, $return_type:ty) => {{
unsafe extern "C" fn wrap<T>( unsafe extern "C" fn wrap<T>(
@ -205,83 +214,6 @@ macro_rules! py_ternarys_func {
}; };
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_ternary_num_func {
($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>(
arg1: *mut $crate::ffi::PyObject,
arg2: *mut $crate::ffi::PyObject,
arg3: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $trait<'p>,
{
$crate::callback_body!(py, {
let arg1 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg1)
.extract()?;
let arg2 = extract_or_return_not_implemented!(py, arg2);
let arg3 = extract_or_return_not_implemented!(py, arg3);
$class::$f(arg1, arg2, arg3).convert(py)
})
}
Some(wrap::<T>)
}};
}
#[macro_export]
#[doc(hidden)]
macro_rules! py_ternary_reversed_num_func {
($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>(
arg1: *mut $crate::ffi::PyObject,
arg2: *mut $crate::ffi::PyObject,
arg3: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $trait<'p>,
{
$crate::callback_body!(py, {
// Swap lhs <-> rhs
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(arg2);
let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1);
let arg2 = py.from_borrowed_ptr::<$crate::PyAny>(arg3);
$class::$f(&*slf.try_borrow()?, arg1.extract()?, arg2.extract()?).convert(py)
})
}
Some(wrap::<$class>)
}};
}
// NOTE(kngwyu): Somehow __ipow__ causes SIGSEGV in Python < 3.8 when we extract arg2,
// so we ignore it. It's the same as what CPython does.
#[macro_export]
#[doc(hidden)]
macro_rules! py_dummy_ternary_self_func {
($trait:ident, $class:ident :: $f:ident) => {{
unsafe extern "C" fn wrap<T>(
slf: *mut $crate::ffi::PyObject,
arg1: *mut $crate::ffi::PyObject,
_arg2: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject
where
T: for<'p> $trait<'p>,
{
$crate::callback_body!(py, {
let slf_cell = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1);
call_operator_mut!(py, slf_cell, $f, arg1).convert(py)?;
ffi::Py_INCREF(slf);
Ok(slf)
})
}
Some(wrap::<$class>)
}};
}
macro_rules! py_func_set { macro_rules! py_func_set {
($trait_name:ident, $generic:ident, $fn_set:ident) => {{ ($trait_name:ident, $generic:ident, $fn_set:ident) => {{
unsafe extern "C" fn wrap<$generic>( unsafe extern "C" fn wrap<$generic>(
@ -370,6 +302,16 @@ macro_rules! py_func_set_del {
} }
macro_rules! extract_or_return_not_implemented { macro_rules! extract_or_return_not_implemented {
($arg: ident) => {
match $arg.extract() {
Ok(value) => value,
Err(_) => {
let res = $crate::ffi::Py_NotImplemented();
ffi::Py_INCREF(res);
return Ok(res);
}
}
};
($py: ident, $arg: ident) => { ($py: ident, $arg: ident) => {
match $py match $py
.from_borrowed_ptr::<$crate::types::PyAny>($arg) .from_borrowed_ptr::<$crate::types::PyAny>($arg)

View file

@ -585,6 +585,16 @@ impl ffi::PyNumberMethods {
nm.nb_bool = Some(nb_bool); nm.nb_bool = Some(nb_bool);
Box::into_raw(Box::new(nm)) Box::into_raw(Box::new(nm))
} }
pub fn set_add_radd<T>(&mut self)
where
T: for<'p> PyNumberAddProtocol<'p> + for<'p> PyNumberRAddProtocol<'p>,
{
self.nb_add = py_binary_fallback_num_func!(
T,
PyNumberAddProtocol::__add__,
PyNumberRAddProtocol::__radd__
);
}
pub fn set_add<T>(&mut self) pub fn set_add<T>(&mut self)
where where
T: for<'p> PyNumberAddProtocol<'p>, T: for<'p> PyNumberAddProtocol<'p>,
@ -597,6 +607,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_add = py_binary_reversed_num_func!(PyNumberRAddProtocol, T::__radd__); self.nb_add = py_binary_reversed_num_func!(PyNumberRAddProtocol, T::__radd__);
} }
pub fn set_sub_rsub<T>(&mut self)
where
T: for<'p> PyNumberSubProtocol<'p> + for<'p> PyNumberRSubProtocol<'p>,
{
self.nb_subtract = py_binary_fallback_num_func!(
T,
PyNumberSubProtocol::__sub__,
PyNumberRSubProtocol::__rsub__
);
}
pub fn set_sub<T>(&mut self) pub fn set_sub<T>(&mut self)
where where
T: for<'p> PyNumberSubProtocol<'p>, T: for<'p> PyNumberSubProtocol<'p>,
@ -609,6 +629,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_subtract = py_binary_reversed_num_func!(PyNumberRSubProtocol, T::__rsub__); self.nb_subtract = py_binary_reversed_num_func!(PyNumberRSubProtocol, T::__rsub__);
} }
pub fn set_mul_rmul<T>(&mut self)
where
T: for<'p> PyNumberMulProtocol<'p> + for<'p> PyNumberRMulProtocol<'p>,
{
self.nb_multiply = py_binary_fallback_num_func!(
T,
PyNumberMulProtocol::__mul__,
PyNumberRMulProtocol::__rmul__
);
}
pub fn set_mul<T>(&mut self) pub fn set_mul<T>(&mut self)
where where
T: for<'p> PyNumberMulProtocol<'p>, T: for<'p> PyNumberMulProtocol<'p>,
@ -627,6 +657,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_remainder = py_binary_num_func!(PyNumberModProtocol, T::__mod__); self.nb_remainder = py_binary_num_func!(PyNumberModProtocol, T::__mod__);
} }
pub fn set_divmod_rdivmod<T>(&mut self)
where
T: for<'p> PyNumberDivmodProtocol<'p> + for<'p> PyNumberRDivmodProtocol<'p>,
{
self.nb_divmod = py_binary_fallback_num_func!(
T,
PyNumberDivmodProtocol::__divmod__,
PyNumberRDivmodProtocol::__rdivmod__
);
}
pub fn set_divmod<T>(&mut self) pub fn set_divmod<T>(&mut self)
where where
T: for<'p> PyNumberDivmodProtocol<'p>, T: for<'p> PyNumberDivmodProtocol<'p>,
@ -639,17 +679,78 @@ impl ffi::PyNumberMethods {
{ {
self.nb_divmod = py_binary_reversed_num_func!(PyNumberRDivmodProtocol, T::__rdivmod__); self.nb_divmod = py_binary_reversed_num_func!(PyNumberRDivmodProtocol, T::__rdivmod__);
} }
pub fn set_pow_rpow<T>(&mut self)
where
T: for<'p> PyNumberPowProtocol<'p> + for<'p> PyNumberRPowProtocol<'p>,
{
unsafe extern "C" fn wrap_pow_and_rpow<T>(
lhs: *mut crate::ffi::PyObject,
rhs: *mut crate::ffi::PyObject,
modulo: *mut crate::ffi::PyObject,
) -> *mut crate::ffi::PyObject
where
T: for<'p> PyNumberPowProtocol<'p> + for<'p> PyNumberRPowProtocol<'p>,
{
crate::callback_body!(py, {
let lhs = py.from_borrowed_ptr::<crate::PyAny>(lhs);
let rhs = py.from_borrowed_ptr::<crate::PyAny>(rhs);
let modulo = py.from_borrowed_ptr::<crate::PyAny>(modulo);
// First, try __pow__
match (lhs.extract(), rhs.extract(), modulo.extract()) {
(Ok(l), Ok(r), Ok(m)) => T::__pow__(l, r, m).convert(py),
_ => {
// Then try __rpow__
let slf: &crate::PyCell<T> = extract_or_return_not_implemented!(rhs);
let arg = extract_or_return_not_implemented!(lhs);
let modulo = extract_or_return_not_implemented!(modulo);
slf.try_borrow()?.__rpow__(arg, modulo).convert(py)
}
}
})
}
self.nb_power = Some(wrap_pow_and_rpow::<T>);
}
pub fn set_pow<T>(&mut self) pub fn set_pow<T>(&mut self)
where where
T: for<'p> PyNumberPowProtocol<'p>, T: for<'p> PyNumberPowProtocol<'p>,
{ {
self.nb_power = py_ternary_num_func!(PyNumberPowProtocol, T::__pow__); unsafe extern "C" fn wrap_pow<T>(
lhs: *mut crate::ffi::PyObject,
rhs: *mut crate::ffi::PyObject,
modulo: *mut crate::ffi::PyObject,
) -> *mut crate::ffi::PyObject
where
T: for<'p> PyNumberPowProtocol<'p>,
{
crate::callback_body!(py, {
let lhs = extract_or_return_not_implemented!(py, lhs);
let rhs = extract_or_return_not_implemented!(py, rhs);
let modulo = extract_or_return_not_implemented!(py, modulo);
T::__pow__(lhs, rhs, modulo).convert(py)
})
}
self.nb_power = Some(wrap_pow::<T>);
} }
pub fn set_rpow<T>(&mut self) pub fn set_rpow<T>(&mut self)
where where
T: for<'p> PyNumberRPowProtocol<'p>, T: for<'p> PyNumberRPowProtocol<'p>,
{ {
self.nb_power = py_ternary_reversed_num_func!(PyNumberRPowProtocol, T::__rpow__); unsafe extern "C" fn wrap_rpow<T>(
arg: *mut crate::ffi::PyObject,
slf: *mut crate::ffi::PyObject,
modulo: *mut crate::ffi::PyObject,
) -> *mut crate::ffi::PyObject
where
T: for<'p> PyNumberRPowProtocol<'p>,
{
crate::callback_body!(py, {
let slf: &crate::PyCell<T> = extract_or_return_not_implemented!(py, slf);
let arg = extract_or_return_not_implemented!(py, arg);
let modulo = extract_or_return_not_implemented!(py, modulo);
slf.try_borrow()?.__rpow__(arg, modulo).convert(py)
})
}
self.nb_power = Some(wrap_rpow::<T>);
} }
pub fn set_neg<T>(&mut self) pub fn set_neg<T>(&mut self)
where where
@ -675,6 +776,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_invert = py_unary_func!(PyNumberInvertProtocol, T::__invert__); self.nb_invert = py_unary_func!(PyNumberInvertProtocol, T::__invert__);
} }
pub fn set_lshift_rlshift<T>(&mut self)
where
T: for<'p> PyNumberLShiftProtocol<'p> + for<'p> PyNumberRLShiftProtocol<'p>,
{
self.nb_lshift = py_binary_fallback_num_func!(
T,
PyNumberLShiftProtocol::__lshift__,
PyNumberRLShiftProtocol::__rlshift__
);
}
pub fn set_lshift<T>(&mut self) pub fn set_lshift<T>(&mut self)
where where
T: for<'p> PyNumberLShiftProtocol<'p>, T: for<'p> PyNumberLShiftProtocol<'p>,
@ -687,6 +798,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_lshift = py_binary_reversed_num_func!(PyNumberRLShiftProtocol, T::__rlshift__); self.nb_lshift = py_binary_reversed_num_func!(PyNumberRLShiftProtocol, T::__rlshift__);
} }
pub fn set_rshift_rrshift<T>(&mut self)
where
T: for<'p> PyNumberRShiftProtocol<'p> + for<'p> PyNumberRRShiftProtocol<'p>,
{
self.nb_rshift = py_binary_fallback_num_func!(
T,
PyNumberRShiftProtocol::__rshift__,
PyNumberRRShiftProtocol::__rrshift__
);
}
pub fn set_rshift<T>(&mut self) pub fn set_rshift<T>(&mut self)
where where
T: for<'p> PyNumberRShiftProtocol<'p>, T: for<'p> PyNumberRShiftProtocol<'p>,
@ -699,6 +820,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_rshift = py_binary_reversed_num_func!(PyNumberRRShiftProtocol, T::__rrshift__); self.nb_rshift = py_binary_reversed_num_func!(PyNumberRRShiftProtocol, T::__rrshift__);
} }
pub fn set_and_rand<T>(&mut self)
where
T: for<'p> PyNumberAndProtocol<'p> + for<'p> PyNumberRAndProtocol<'p>,
{
self.nb_and = py_binary_fallback_num_func!(
T,
PyNumberAndProtocol::__and__,
PyNumberRAndProtocol::__rand__
);
}
pub fn set_and<T>(&mut self) pub fn set_and<T>(&mut self)
where where
T: for<'p> PyNumberAndProtocol<'p>, T: for<'p> PyNumberAndProtocol<'p>,
@ -711,6 +842,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_and = py_binary_reversed_num_func!(PyNumberRAndProtocol, T::__rand__); self.nb_and = py_binary_reversed_num_func!(PyNumberRAndProtocol, T::__rand__);
} }
pub fn set_xor_rxor<T>(&mut self)
where
T: for<'p> PyNumberXorProtocol<'p> + for<'p> PyNumberRXorProtocol<'p>,
{
self.nb_xor = py_binary_fallback_num_func!(
T,
PyNumberXorProtocol::__xor__,
PyNumberRXorProtocol::__rxor__
);
}
pub fn set_xor<T>(&mut self) pub fn set_xor<T>(&mut self)
where where
T: for<'p> PyNumberXorProtocol<'p>, T: for<'p> PyNumberXorProtocol<'p>,
@ -723,6 +864,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_xor = py_binary_reversed_num_func!(PyNumberRXorProtocol, T::__rxor__); self.nb_xor = py_binary_reversed_num_func!(PyNumberRXorProtocol, T::__rxor__);
} }
pub fn set_or_ror<T>(&mut self)
where
T: for<'p> PyNumberOrProtocol<'p> + for<'p> PyNumberROrProtocol<'p>,
{
self.nb_or = py_binary_fallback_num_func!(
T,
PyNumberOrProtocol::__or__,
PyNumberROrProtocol::__ror__
);
}
pub fn set_or<T>(&mut self) pub fn set_or<T>(&mut self)
where where
T: for<'p> PyNumberOrProtocol<'p>, T: for<'p> PyNumberOrProtocol<'p>,
@ -775,7 +926,25 @@ impl ffi::PyNumberMethods {
where where
T: for<'p> PyNumberIPowProtocol<'p>, T: for<'p> PyNumberIPowProtocol<'p>,
{ {
self.nb_inplace_power = py_dummy_ternary_self_func!(PyNumberIPowProtocol, T::__ipow__) // NOTE: Somehow __ipow__ causes SIGSEGV in Python < 3.8 when we extract,
// so we ignore it. It's the same as what CPython does.
unsafe extern "C" fn wrap_ipow<T>(
slf: *mut crate::ffi::PyObject,
other: *mut crate::ffi::PyObject,
_modulo: *mut crate::ffi::PyObject,
) -> *mut crate::ffi::PyObject
where
T: for<'p> PyNumberIPowProtocol<'p>,
{
crate::callback_body!(py, {
let slf_cell = py.from_borrowed_ptr::<crate::PyCell<T>>(slf);
let other = py.from_borrowed_ptr::<crate::PyAny>(other);
call_operator_mut!(py, slf_cell, __ipow__, other).convert(py)?;
ffi::Py_INCREF(slf);
Ok(slf)
})
}
self.nb_inplace_power = Some(wrap_ipow::<T>);
} }
pub fn set_ilshift<T>(&mut self) pub fn set_ilshift<T>(&mut self)
where where
@ -807,6 +976,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_inplace_or = py_binary_self_func!(PyNumberIOrProtocol, T::__ior__); self.nb_inplace_or = py_binary_self_func!(PyNumberIOrProtocol, T::__ior__);
} }
pub fn set_floordiv_rfloordiv<T>(&mut self)
where
T: for<'p> PyNumberFloordivProtocol<'p> + for<'p> PyNumberRFloordivProtocol<'p>,
{
self.nb_floor_divide = py_binary_fallback_num_func!(
T,
PyNumberFloordivProtocol::__floordiv__,
PyNumberRFloordivProtocol::__rfloordiv__
);
}
pub fn set_floordiv<T>(&mut self) pub fn set_floordiv<T>(&mut self)
where where
T: for<'p> PyNumberFloordivProtocol<'p>, T: for<'p> PyNumberFloordivProtocol<'p>,
@ -820,6 +999,16 @@ impl ffi::PyNumberMethods {
self.nb_floor_divide = self.nb_floor_divide =
py_binary_reversed_num_func!(PyNumberRFloordivProtocol, T::__rfloordiv__); py_binary_reversed_num_func!(PyNumberRFloordivProtocol, T::__rfloordiv__);
} }
pub fn set_truediv_rtruediv<T>(&mut self)
where
T: for<'p> PyNumberTruedivProtocol<'p> + for<'p> PyNumberRTruedivProtocol<'p>,
{
self.nb_true_divide = py_binary_fallback_num_func!(
T,
PyNumberTruedivProtocol::__truediv__,
PyNumberRTruedivProtocol::__rtruediv__
);
}
pub fn set_truediv<T>(&mut self) pub fn set_truediv<T>(&mut self)
where where
T: for<'p> PyNumberTruedivProtocol<'p>, T: for<'p> PyNumberTruedivProtocol<'p>,
@ -853,6 +1042,16 @@ impl ffi::PyNumberMethods {
{ {
self.nb_index = py_unary_func!(PyNumberIndexProtocol, T::__index__); self.nb_index = py_unary_func!(PyNumberIndexProtocol, T::__index__);
} }
pub fn set_matmul_rmatmul<T>(&mut self)
where
T: for<'p> PyNumberMatmulProtocol<'p> + for<'p> PyNumberRMatmulProtocol<'p>,
{
self.nb_matrix_multiply = py_binary_fallback_num_func!(
T,
PyNumberMatmulProtocol::__matmul__,
PyNumberRMatmulProtocol::__rmatmul__
);
}
pub fn set_matmul<T>(&mut self) pub fn set_matmul<T>(&mut self)
where where
T: for<'p> PyNumberMatmulProtocol<'p>, T: for<'p> PyNumberMatmulProtocol<'p>,

View file

@ -280,10 +280,56 @@ fn rhs_arithmetic() {
} }
#[pyclass] #[pyclass]
struct LhsAndRhsArithmetic {} struct LhsAndRhs {}
impl std::fmt::Debug for LhsAndRhs {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "LR")
}
}
#[pyproto] #[pyproto]
impl PyNumberProtocol for LhsAndRhsArithmetic { impl PyNumberProtocol for LhsAndRhs {
fn __add__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs)
}
fn __sub__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}
fn __mul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}
fn __lshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} << {:?}", lhs, rhs)
}
fn __rshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} >> {:?}", lhs, rhs)
}
fn __and__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} & {:?}", lhs, rhs)
}
fn __xor__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} ^ {:?}", lhs, rhs)
}
fn __or__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} | {:?}", lhs, rhs)
}
fn __pow__(lhs: PyRef<Self>, rhs: &PyAny, _mod: Option<usize>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
}
fn __matmul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} @ {:?}", lhs, rhs)
}
fn __radd__(&self, other: &PyAny) -> String { fn __radd__(&self, other: &PyAny) -> String {
format!("{:?} + RA", other) format!("{:?} + RA", other)
} }
@ -292,44 +338,74 @@ impl PyNumberProtocol for LhsAndRhsArithmetic {
format!("{:?} - RA", other) format!("{:?} - RA", other)
} }
fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}
fn __rlshift__(&self, other: &PyAny) -> String {
format!("{:?} << RA", other)
}
fn __rrshift__(&self, other: &PyAny) -> String {
format!("{:?} >> RA", other)
}
fn __rand__(&self, other: &PyAny) -> String {
format!("{:?} & RA", other)
}
fn __rxor__(&self, other: &PyAny) -> String {
format!("{:?} ^ RA", other)
}
fn __ror__(&self, other: &PyAny) -> String {
format!("{:?} | RA", other)
}
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
format!("{:?} ** RA", other) format!("{:?} ** RA", other)
} }
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String { fn __rmatmul__(&self, other: &PyAny) -> String {
format!("{:?} + {:?}", lhs, rhs) format!("{:?} @ RA", other)
}
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} - {:?}", lhs, rhs)
}
fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> String {
format!("{:?} ** {:?}", lhs, rhs)
} }
} }
#[pyproto] #[pyproto]
impl PyObjectProtocol for LhsAndRhsArithmetic { impl PyObjectProtocol for LhsAndRhs {
fn __repr__(&self) -> &'static str { fn __repr__(&self) -> &'static str {
"BA" "BA"
} }
} }
#[test] #[test]
fn lhs_override_rhs() { fn lhs_fellback_to_rhs() {
let gil = Python::acquire_gil(); let gil = Python::acquire_gil();
let py = gil.python(); let py = gil.python();
let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap(); let c = PyCell::new(py, LhsAndRhs {}).unwrap();
// Not overrided // If the light hand value is `LhsAndRhs`, LHS is used.
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'"); py_run!(py, c, "assert c + 1 == 'LR + 1'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'"); py_run!(py, c, "assert c - 1 == 'LR - 1'");
py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'"); py_run!(py, c, "assert c * 1 == 'LR * 1'");
// Overrided py_run!(py, c, "assert c << 1 == 'LR << 1'");
py_run!(py, c, "assert 1 + c == '1 + BA'"); py_run!(py, c, "assert c >> 1 == 'LR >> 1'");
py_run!(py, c, "assert 1 - c == '1 - BA'"); py_run!(py, c, "assert c & 1 == 'LR & 1'");
py_run!(py, c, "assert 1 ** c == '1 ** BA'"); py_run!(py, c, "assert c ^ 1 == 'LR ^ 1'");
py_run!(py, c, "assert c | 1 == 'LR | 1'");
py_run!(py, c, "assert c ** 1 == 'LR ** 1'");
py_run!(py, c, "assert c @ 1 == 'LR @ 1'");
// Fellback to RHS because of type mismatching
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert 1 * c == '1 * RA'");
py_run!(py, c, "assert 1 << c == '1 << RA'");
py_run!(py, c, "assert 1 >> c == '1 >> RA'");
py_run!(py, c, "assert 1 & c == '1 & RA'");
py_run!(py, c, "assert 1 ^ c == '1 ^ RA'");
py_run!(py, c, "assert 1 | c == '1 | RA'");
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
py_run!(py, c, "assert 1 @ c == '1 @ RA'");
} }
#[pyclass] #[pyclass]