diff --git a/guide/src/class.md b/guide/src/class.md index 25eef825..8b1a4241 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -865,6 +865,7 @@ impl pyo3::class::impl_::PyClassImpl for MyClass { visitor(collector.sequence_protocol_slots()); visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); + visitor(collector.methods_protocol_slots()); } fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> { diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 4fc1c7da..18bdcfff 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -176,6 +176,7 @@ fn impl_protos( try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot); try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot); try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot); + try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot); quote! { impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty> diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index e0698262..f0fced41 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -439,6 +439,13 @@ const __INT__: SlotDef = SlotDef::new("Py_nb_int", "unaryfunc"); const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc"); const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); +const __TRUEDIV__: SlotDef = SlotDef::new("Py_nb_true_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented); +const __FLOORDIV__: SlotDef = SlotDef::new("Py_nb_floor_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented); + const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) .extract_error_mode(ExtractErrorMode::NotImplemented) @@ -516,6 +523,8 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> { "__int__" => Some(&__INT__), "__float__" => Some(&__FLOAT__), "__bool__" => Some(&__BOOL__), + "__truediv__" => Some(&__TRUEDIV__), + "__floordiv__" => Some(&__FLOORDIV__), "__iadd__" => Some(&__IADD__), "__isub__" => Some(&__ISUB__), "__imul__" => Some(&__IMUL__), @@ -898,6 +907,19 @@ binary_num_slot_fragment_def!(__RXOR__, "__rxor__"); binary_num_slot_fragment_def!(__OR__, "__or__"); binary_num_slot_fragment_def!(__ROR__, "__ror__"); +const __POW__: SlotFragmentDef = SlotFragmentDef::new( + "__pow__", + &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], +) +.extract_error_mode(ExtractErrorMode::NotImplemented) +.ret_ty(Ty::Object); +const __RPOW__: SlotFragmentDef = SlotFragmentDef::new( + "__rpow__", + &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], +) +.extract_error_mode(ExtractErrorMode::NotImplemented) +.ret_ty(Ty::Object); + fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { match method_name { "__setattr__" => Some(&__SETATTR__), @@ -928,6 +950,8 @@ fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { "__rxor__" => Some(&__RXOR__), "__or__" => Some(&__OR__), "__ror__" => Some(&__ROR__), + "__pow__" => Some(&__POW__), + "__rpow__" => Some(&__RPOW__), _ => None, } } diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 5a914868..dc38c038 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -249,7 +249,7 @@ macro_rules! define_pyclass_binary_operator_slot { slot_fragment_trait! { $lhs_trait, - /// # Safety: _slf and _attr must be valid non-null Python objects + /// # Safety: _slf and _other must be valid non-null Python objects #[inline] unsafe fn $lhs( self, @@ -265,7 +265,7 @@ macro_rules! define_pyclass_binary_operator_slot { slot_fragment_trait! { $rhs_trait, - /// # Safety: _slf and _attr must be valid non-null Python objects + /// # Safety: _slf and _other must be valid non-null Python objects #[inline] unsafe fn $rhs( self, @@ -417,6 +417,68 @@ define_pyclass_binary_operator_slot! { binaryfunc, } +slot_fragment_trait! { + PyClass__pow__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __pow__( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + _mod: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } +} + +slot_fragment_trait! { + PyClass__rpow__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __rpow__( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + _mod: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_pow_slot { + ($cls:ty) => {{ + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + _other: *mut $crate::ffi::PyObject, + _mod: *mut $crate::ffi::PyObject, + ) -> *mut $crate::ffi::PyObject { + $crate::callback::handle_panic(|py| { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let lhs_result = collector.__pow__(py, _slf, _other, _mod)?; + if lhs_result == $crate::ffi::Py_NotImplemented() { + $crate::ffi::Py_DECREF(lhs_result); + collector.__rpow__(py, _other, _slf, _mod) + } else { + ::std::result::Result::Ok(lhs_result) + } + }) + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_nb_power, + pfunc: __wrap as $crate::ffi::ternaryfunc as _, + } + }}; +} + pub trait PyClassAllocImpl { fn alloc_impl(self) -> Option; } diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 711cbfb9..c045c26b 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -177,24 +177,27 @@ fn binary_arithmetic() { py_run!(py, c, "assert c + c == 'BA + BA'"); py_run!(py, c, "assert c.__add__(c) == 'BA + BA'"); py_run!(py, c, "assert c + 1 == 'BA + 1'"); - py_run!(py, c, "assert 1 + c == '1 + BA'"); py_run!(py, c, "assert c - 1 == 'BA - 1'"); - py_run!(py, c, "assert 1 - c == '1 - BA'"); py_run!(py, c, "assert c * 1 == 'BA * 1'"); - py_run!(py, c, "assert 1 * c == '1 * BA'"); - py_run!(py, c, "assert c << 1 == 'BA << 1'"); - py_run!(py, c, "assert 1 << c == '1 << BA'"); py_run!(py, c, "assert c >> 1 == 'BA >> 1'"); - py_run!(py, c, "assert 1 >> c == '1 >> BA'"); py_run!(py, c, "assert c & 1 == 'BA & 1'"); - py_run!(py, c, "assert 1 & c == '1 & BA'"); py_run!(py, c, "assert c ^ 1 == 'BA ^ 1'"); - py_run!(py, c, "assert 1 ^ c == '1 ^ BA'"); py_run!(py, c, "assert c | 1 == 'BA | 1'"); - py_run!(py, c, "assert 1 | c == '1 | BA'"); py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'"); - py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'"); + + // Class with __add__ only should not allow the reverse op; + // this is consistent with Python classes. + + py_expect_exception!(py, c, "1 + c", PyTypeError); + py_expect_exception!(py, c, "1 - c", PyTypeError); + py_expect_exception!(py, c, "1 * c", PyTypeError); + py_expect_exception!(py, c, "1 << c", PyTypeError); + py_expect_exception!(py, c, "1 >> c", PyTypeError); + py_expect_exception!(py, c, "1 & c", PyTypeError); + py_expect_exception!(py, c, "1 ^ c", PyTypeError); + py_expect_exception!(py, c, "1 | c", PyTypeError); + py_expect_exception!(py, c, "1 ** c", PyTypeError); py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'"); } @@ -629,15 +632,13 @@ mod return_not_implemented { } #[test] - #[ignore] fn reverse_arith() { _test_binary_dunder("radd"); _test_binary_dunder("rsub"); _test_binary_dunder("rmul"); _test_binary_dunder("rmatmul"); - _test_binary_dunder("rtruediv"); - _test_binary_dunder("rfloordiv"); _test_binary_dunder("rmod"); + _test_binary_dunder("rdivmod"); _test_binary_dunder("rpow"); } diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index 8a12d2df..6d070b47 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -544,10 +544,3 @@ assert c.counter.count == 3 .map_err(|e| e.print(py)) .unwrap(); } - -// TODO: test __delete__ -// TODO: test __anext__, __aiter__ -// TODO: test __index__, __int__, __float__, __invert__ -// TODO: __floordiv__, __truediv__ -// TODO: __pow__, __rpow__ -// TODO: better argument casting errors