pymethods: finish support for number protocol

This commit is contained in:
David Hewitt 2021-09-18 12:59:25 +01:00
parent c2d78ca76e
commit a551b005b4
6 changed files with 104 additions and 22 deletions

View file

@ -865,6 +865,7 @@ impl pyo3::class::impl_::PyClassImpl for MyClass {
visitor(collector.sequence_protocol_slots());
visitor(collector.async_protocol_slots());
visitor(collector.buffer_protocol_slots());
visitor(collector.methods_protocol_slots());
}
fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> {

View file

@ -176,6 +176,7 @@ fn impl_protos(
try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot);
try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot);
try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot);
try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot);
quote! {
impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty>

View file

@ -439,6 +439,13 @@ const __INT__: SlotDef = SlotDef::new("Py_nb_int", "unaryfunc");
const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc");
const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int);
const __TRUEDIV__: SlotDef = SlotDef::new("Py_nb_true_divide", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.extract_error_mode(ExtractErrorMode::NotImplemented);
const __FLOORDIV__: SlotDef = SlotDef::new("Py_nb_floor_divide", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.extract_error_mode(ExtractErrorMode::NotImplemented);
const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.extract_error_mode(ExtractErrorMode::NotImplemented)
@ -516,6 +523,8 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> {
"__int__" => Some(&__INT__),
"__float__" => Some(&__FLOAT__),
"__bool__" => Some(&__BOOL__),
"__truediv__" => Some(&__TRUEDIV__),
"__floordiv__" => Some(&__FLOORDIV__),
"__iadd__" => Some(&__IADD__),
"__isub__" => Some(&__ISUB__),
"__imul__" => Some(&__IMUL__),
@ -898,6 +907,19 @@ binary_num_slot_fragment_def!(__RXOR__, "__rxor__");
binary_num_slot_fragment_def!(__OR__, "__or__");
binary_num_slot_fragment_def!(__ROR__, "__ror__");
const __POW__: SlotFragmentDef = SlotFragmentDef::new(
"__pow__",
&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented],
)
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __RPOW__: SlotFragmentDef = SlotFragmentDef::new(
"__rpow__",
&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented],
)
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> {
match method_name {
"__setattr__" => Some(&__SETATTR__),
@ -928,6 +950,8 @@ fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> {
"__rxor__" => Some(&__RXOR__),
"__or__" => Some(&__OR__),
"__ror__" => Some(&__ROR__),
"__pow__" => Some(&__POW__),
"__rpow__" => Some(&__RPOW__),
_ => None,
}
}

View file

@ -249,7 +249,7 @@ macro_rules! define_pyclass_binary_operator_slot {
slot_fragment_trait! {
$lhs_trait,
/// # Safety: _slf and _attr must be valid non-null Python objects
/// # Safety: _slf and _other must be valid non-null Python objects
#[inline]
unsafe fn $lhs(
self,
@ -265,7 +265,7 @@ macro_rules! define_pyclass_binary_operator_slot {
slot_fragment_trait! {
$rhs_trait,
/// # Safety: _slf and _attr must be valid non-null Python objects
/// # Safety: _slf and _other must be valid non-null Python objects
#[inline]
unsafe fn $rhs(
self,
@ -417,6 +417,68 @@ define_pyclass_binary_operator_slot! {
binaryfunc,
}
slot_fragment_trait! {
PyClass__pow__SlotFragment,
/// # Safety: _slf and _other must be valid non-null Python objects
#[inline]
unsafe fn __pow__(
self,
_py: Python,
_slf: *mut ffi::PyObject,
_other: *mut ffi::PyObject,
_mod: *mut ffi::PyObject,
) -> PyResult<*mut ffi::PyObject> {
ffi::Py_INCREF(ffi::Py_NotImplemented());
Ok(ffi::Py_NotImplemented())
}
}
slot_fragment_trait! {
PyClass__rpow__SlotFragment,
/// # Safety: _slf and _other must be valid non-null Python objects
#[inline]
unsafe fn __rpow__(
self,
_py: Python,
_slf: *mut ffi::PyObject,
_other: *mut ffi::PyObject,
_mod: *mut ffi::PyObject,
) -> PyResult<*mut ffi::PyObject> {
ffi::Py_INCREF(ffi::Py_NotImplemented());
Ok(ffi::Py_NotImplemented())
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! generate_pyclass_pow_slot {
($cls:ty) => {{
unsafe extern "C" fn __wrap(
_slf: *mut $crate::ffi::PyObject,
_other: *mut $crate::ffi::PyObject,
_mod: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject {
$crate::callback::handle_panic(|py| {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<$cls>::new();
let lhs_result = collector.__pow__(py, _slf, _other, _mod)?;
if lhs_result == $crate::ffi::Py_NotImplemented() {
$crate::ffi::Py_DECREF(lhs_result);
collector.__rpow__(py, _other, _slf, _mod)
} else {
::std::result::Result::Ok(lhs_result)
}
})
}
$crate::ffi::PyType_Slot {
slot: $crate::ffi::Py_nb_power,
pfunc: __wrap as $crate::ffi::ternaryfunc as _,
}
}};
}
pub trait PyClassAllocImpl<T> {
fn alloc_impl(self) -> Option<ffi::allocfunc>;
}

View file

@ -177,24 +177,27 @@ fn binary_arithmetic() {
py_run!(py, c, "assert c + c == 'BA + BA'");
py_run!(py, c, "assert c.__add__(c) == 'BA + BA'");
py_run!(py, c, "assert c + 1 == 'BA + 1'");
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert c - 1 == 'BA - 1'");
py_run!(py, c, "assert 1 - c == '1 - BA'");
py_run!(py, c, "assert c * 1 == 'BA * 1'");
py_run!(py, c, "assert 1 * c == '1 * BA'");
py_run!(py, c, "assert c << 1 == 'BA << 1'");
py_run!(py, c, "assert 1 << c == '1 << BA'");
py_run!(py, c, "assert c >> 1 == 'BA >> 1'");
py_run!(py, c, "assert 1 >> c == '1 >> BA'");
py_run!(py, c, "assert c & 1 == 'BA & 1'");
py_run!(py, c, "assert 1 & c == '1 & BA'");
py_run!(py, c, "assert c ^ 1 == 'BA ^ 1'");
py_run!(py, c, "assert 1 ^ c == '1 ^ BA'");
py_run!(py, c, "assert c | 1 == 'BA | 1'");
py_run!(py, c, "assert 1 | c == '1 | BA'");
py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'");
py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'");
// Class with __add__ only should not allow the reverse op;
// this is consistent with Python classes.
py_expect_exception!(py, c, "1 + c", PyTypeError);
py_expect_exception!(py, c, "1 - c", PyTypeError);
py_expect_exception!(py, c, "1 * c", PyTypeError);
py_expect_exception!(py, c, "1 << c", PyTypeError);
py_expect_exception!(py, c, "1 >> c", PyTypeError);
py_expect_exception!(py, c, "1 & c", PyTypeError);
py_expect_exception!(py, c, "1 ^ c", PyTypeError);
py_expect_exception!(py, c, "1 | c", PyTypeError);
py_expect_exception!(py, c, "1 ** c", PyTypeError);
py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'");
}
@ -629,15 +632,13 @@ mod return_not_implemented {
}
#[test]
#[ignore]
fn reverse_arith() {
_test_binary_dunder("radd");
_test_binary_dunder("rsub");
_test_binary_dunder("rmul");
_test_binary_dunder("rmatmul");
_test_binary_dunder("rtruediv");
_test_binary_dunder("rfloordiv");
_test_binary_dunder("rmod");
_test_binary_dunder("rdivmod");
_test_binary_dunder("rpow");
}

View file

@ -544,10 +544,3 @@ assert c.counter.count == 3
.map_err(|e| e.print(py))
.unwrap();
}
// TODO: test __delete__
// TODO: test __anext__, __aiter__
// TODO: test __index__, __int__, __float__, __invert__
// TODO: __floordiv__, __truediv__
// TODO: __pow__, __rpow__
// TODO: better argument casting errors