From 8408328cb3905a132bddbd26c3ea7c69d8af8bd9 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 9 Sep 2021 09:11:41 +0100 Subject: [PATCH 01/12] pymethods: add support for protocol methods --- pyo3-macros-backend/src/pyclass.rs | 9 + pyo3-macros-backend/src/pyimpl.rs | 43 +- pyo3-macros-backend/src/pymethod.rs | 366 ++++++++++++- src/class/basic.rs | 27 +- src/class/impl_.rs | 183 ++++++- tests/test_proto_methods.rs | 765 ++++++++++++++++++++++++++++ 6 files changed, 1376 insertions(+), 17 deletions(-) create mode 100644 tests/test_proto_methods.rs diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 034d1784..cf8b43ba 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -591,6 +591,15 @@ fn impl_class( visitor(collector.sequence_protocol_slots()); visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); + visitor(collector.methods_protocol_slots()); + let mut generated_slots = Vec::new(); + if let ::std::option::Option::Some(setattr) = ::pyo3::generate_pyclass_setattr_slot!(#cls) { + generated_slots.push(setattr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setdescr_slot!(#cls) { + generated_slots.push(setdescr); + } + visitor(&generated_slots); } fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> { diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 61d20267..5d2dd881 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -37,8 +37,8 @@ pub fn impl_methods( impls: &mut Vec, methods_type: PyClassMethodsType, ) -> syn::Result { - let mut new_impls = Vec::new(); - let mut call_impls = Vec::new(); + let mut trait_impls = Vec::new(); + let mut proto_impls = Vec::new(); let mut methods = Vec::new(); for iimpl in impls.iter_mut() { match iimpl { @@ -49,13 +49,13 @@ pub fn impl_methods( let attrs = get_cfg_attributes(&meth.attrs); methods.push(quote!(#(#attrs)* #token_stream)); } - GeneratedPyMethod::New(token_stream) => { + GeneratedPyMethod::TraitImpl(token_stream) => { let attrs = get_cfg_attributes(&meth.attrs); - new_impls.push(quote!(#(#attrs)* #token_stream)); + trait_impls.push(quote!(#(#attrs)* #token_stream)); } - GeneratedPyMethod::Call(token_stream) => { + GeneratedPyMethod::Proto(token_stream) => { let attrs = get_cfg_attributes(&meth.attrs); - call_impls.push(quote!(#(#attrs)* #token_stream)); + proto_impls.push(quote!(#(#attrs)* #token_stream)) } } } @@ -80,10 +80,23 @@ pub fn impl_methods( PyClassMethodsType::Inventory => submit_methods_inventory(ty, methods), }; - Ok(quote! { - #(#new_impls)* + let protos_registration = match methods_type { + PyClassMethodsType::Specialization => Some(impl_protos(ty, proto_impls)), + PyClassMethodsType::Inventory => { + if proto_impls.is_empty() { + None + } else { + panic!( + "cannot implement protos in #[pymethods] using `multiple-pymethods` feature" + ); + } + } + }; - #(#call_impls)* + Ok(quote! { + #(#trait_impls)* + + #protos_registration #methods_registration }) @@ -122,6 +135,18 @@ fn impl_py_methods(ty: &syn::Type, methods: Vec) -> TokenStream { } } +fn impl_protos(ty: &syn::Type, proto_impls: Vec) -> TokenStream { + quote! { + impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty> + for ::pyo3::class::impl_::PyClassImplCollector<#ty> + { + fn methods_protocol_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] { + &[#(#proto_impls),*] + } + } + } +} + fn submit_methods_inventory(ty: &syn::Type, methods: Vec) -> TokenStream { if methods.is_empty() { return TokenStream::default(); diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 93f771ab..28de756a 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use crate::attributes::NameAttribute; -use crate::utils::{ensure_not_async_fn, PythonDoc}; +use crate::utils::{ensure_not_async_fn, unwrap_ty_group, PythonDoc}; use crate::{deprecations::Deprecations, utils}; use crate::{ method::{FnArg, FnSpec, FnType, SelfType}, @@ -11,12 +11,13 @@ use crate::{ }; use proc_macro2::{Span, TokenStream}; use quote::quote; +use syn::Ident; use syn::{ext::IdentExt, spanned::Spanned, Result}; pub enum GeneratedPyMethod { Method(TokenStream), - New(TokenStream), - Call(TokenStream), + Proto(TokenStream), + TraitImpl(TokenStream), } pub fn gen_py_method( @@ -30,6 +31,14 @@ pub fn gen_py_method( ensure_function_options_valid(&options)?; let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?; + if let Some(proto) = pyproto(cls, &spec) { + return Ok(GeneratedPyMethod::Proto(proto)); + } + + if let Some(proto) = pyproto_fragment(cls, &spec)? { + return Ok(GeneratedPyMethod::TraitImpl(proto)); + } + Ok(match &spec.tp { // ordinary functions (with some specialties) FnType::Fn(_) => GeneratedPyMethod::Method(impl_py_method_def(cls, &spec, None)?), @@ -44,8 +53,8 @@ pub fn gen_py_method( Some(quote!(::pyo3::ffi::METH_STATIC)), )?), // special prototypes - FnType::FnNew => GeneratedPyMethod::New(impl_py_method_def_new(cls, &spec)?), - FnType::FnCall(_) => GeneratedPyMethod::Call(impl_py_method_def_call(cls, &spec)?), + FnType::FnNew => GeneratedPyMethod::TraitImpl(impl_py_method_def_new(cls, &spec)?), + FnType::FnCall(_) => GeneratedPyMethod::TraitImpl(impl_py_method_def_call(cls, &spec)?), FnType::ClassAttribute => GeneratedPyMethod::Method(impl_py_class_attribute(cls, &spec)), FnType::Getter(self_type) => GeneratedPyMethod::Method(impl_py_getter_def( cls, @@ -364,3 +373,350 @@ impl PropertyType<'_> { } } } + +fn pyproto(cls: &syn::Type, spec: &FnSpec) -> Option { + match spec.python_name.to_string().as_str() { + "__getattr__" => Some( + SlotDef::new("Py_tp_getattro", "getattrofunc") + .arguments(&[Ty::Object]) + .before_call_method(quote! { + // Behave like python's __getattr__ (as opposed to __getattribute__) and check + // for existing fields and methods first + let existing = ::pyo3::ffi::PyObject_GenericGetAttr(_slf, arg0); + if existing.is_null() { + // PyObject_HasAttr also tries to get an object and clears the error if it fails + ::pyo3::ffi::PyErr_Clear(); + } else { + return existing; + } + }) + .generate_type_slot(cls, spec), + ), + "__str__" => Some(SlotDef::new("Py_tp_str", "reprfunc").generate_type_slot(cls, spec)), + "__repr__" => Some(SlotDef::new("Py_tp_repr", "reprfunc").generate_type_slot(cls, spec)), + "__hash__" => Some( + SlotDef::new("Py_tp_hash", "hashfunc") + .ret_ty(Ty::PyHashT) + .return_conversion(quote! { ::pyo3::callback::HashCallbackOutput }) + .generate_type_slot(cls, spec), + ), + "__richcmp__" => Some( + SlotDef::new("Py_tp_richcompare", "richcmpfunc") + .arguments(&[Ty::Object, Ty::CompareOp]) + .generate_type_slot(cls, spec), + ), + "__bool__" => Some( + SlotDef::new("Py_nb_bool", "inquiry") + .ret_ty(Ty::Int) + .generate_type_slot(cls, spec), + ), + "__get__" => Some( + SlotDef::new("Py_tp_descr_get", "descrgetfunc") + .arguments(&[Ty::Object, Ty::Object]) + .generate_type_slot(cls, spec), + ), + _ => None, + } +} + +#[derive(Clone, Copy)] +enum Ty { + Object, + NonNullObject, + CompareOp, + Int, + PyHashT, +} + +impl Ty { + fn ffi_type(self) -> TokenStream { + match self { + Ty::Object => quote! { *mut ::pyo3::ffi::PyObject }, + Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> }, + Ty::Int => quote! { ::std::os::raw::c_int }, + Ty::CompareOp => quote! { ::std::os::raw::c_int }, + Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, + } + } + + fn extract( + self, + cls: &syn::Type, + py: &syn::Ident, + ident: &syn::Ident, + target: &syn::Type, + ) -> TokenStream { + match self { + Ty::Object => { + let extract = extract_from_any(cls, target, ident); + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident); + #extract + } + } + Ty::NonNullObject => { + let extract = extract_from_any(cls, target, ident); + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident.as_ptr()); + #extract + } + } + Ty::Int => todo!(), + Ty::PyHashT => todo!(), + Ty::CompareOp => quote! { + let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident) + .ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?; + }, + } + } +} + +fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -> TokenStream { + return if let syn::Type::Reference(tref) = unwrap_ty_group(target) { + let (tref, mut_) = preprocess_tref(tref, self_); + quote! { + let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = #ident.extract()?; + let #ident = &#mut_ *#ident; + } + } else { + quote! { + let #ident = #ident.extract()?; + } + }; + + /// Replace `Self`, remove lifetime and get mutability from the type + fn preprocess_tref( + tref: &syn::TypeReference, + self_: &syn::Type, + ) -> (syn::TypeReference, Option) { + let mut tref = tref.to_owned(); + if let syn::Type::Path(tpath) = self_ { + replace_self(&mut tref, &tpath.path); + } + tref.lifetime = None; + let mut_ = tref.mutability; + (tref, mut_) + } + + /// Replace `Self` with the exact type name since it is used out of the impl block + fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { + match &mut *tref.elem { + syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), + syn::Type::Path(tpath) => { + if let Some(ident) = tpath.path.get_ident() { + if ident == "Self" { + tpath.path = self_path.to_owned(); + } + } + } + _ => {} + } + } +} + +struct SlotDef { + slot: syn::Ident, + func_ty: syn::Ident, + arguments: &'static [Ty], + ret_ty: Ty, + before_call_method: Option, + return_conversion: Option, +} + +impl SlotDef { + fn new(slot: &str, func_ty: &str) -> Self { + SlotDef { + slot: syn::Ident::new(slot, Span::call_site()), + func_ty: syn::Ident::new(func_ty, Span::call_site()), + arguments: &[], + ret_ty: Ty::Object, + before_call_method: None, + return_conversion: None, + } + } + + fn arguments(mut self, arguments: &'static [Ty]) -> Self { + self.arguments = arguments; + self + } + + fn ret_ty(mut self, ret_ty: Ty) -> Self { + self.ret_ty = ret_ty; + self + } + + fn before_call_method(mut self, before_call_method: TokenStream) -> Self { + self.before_call_method = Some(before_call_method); + self + } + + fn return_conversion(mut self, return_conversion: TokenStream) -> Self { + self.return_conversion = Some(return_conversion); + self + } + + fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> TokenStream { + let SlotDef { + slot, + func_ty, + before_call_method, + arguments, + ret_ty, + return_conversion, + } = self; + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let arguments = arguments.into_iter().enumerate().map(|(i, arg)| { + let ident = syn::Ident::new(&format!("arg{}", i), Span::call_site()); + let ffi_type = arg.ffi_type(); + quote! { + #ident: #ffi_type + } + }); + let ret_ty = ret_ty.ffi_type(); + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &self.arguments); + let call = + quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; + let body = if let Some(return_conversion) = return_conversion { + quote! { + let _result: PyResult<#return_conversion> = #call; + ::pyo3::callback::convert(#py, _result) + } + } else { + call + }; + quote!({ + unsafe extern "C" fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, #(#arguments),*) -> #ret_ty { + #before_call_method + ::pyo3::callback::handle_panic(|#py| { + #self_conversion + #conversions + #body + }) + } + ::pyo3::ffi::PyType_Slot { + slot: ::pyo3::ffi::#slot, + pfunc: __wrap as ::pyo3::ffi::#func_ty as _ + } + }) + } +} + +fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result> { + Ok(match spec.python_name.to_string().as_str() { + "__setattr__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object, Ty::NonNullObject]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassSetattrSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + #[inline] + fn setattr_implemented(self) -> bool { true } + + #[inline] + unsafe fn setattr( + self, + _slf: *mut ::pyo3::ffi::PyObject, + arg0: *mut ::pyo3::ffi::PyObject, + arg1: ::std::ptr::NonNull<::pyo3::ffi::PyObject> + ) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + } + }) + } + "__delattr__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassDelattrSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + fn delattr_impl(self) -> ::std::option::Option ::pyo3::PyResult<()>> { + unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + Some(__wrap) + } + } + }) + } + "__set__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object, Ty::NonNullObject]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassSetSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + fn set_impl(self) -> ::std::option::Option) -> ::pyo3::PyResult<()>> { + unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject, arg1: ::std::ptr::NonNull<::pyo3::ffi::PyObject>) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + Some(__wrap) + } + } + }) + } + "__delete__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassDeleteSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + fn delete_impl(self) -> ::std::option::Option ::pyo3::PyResult<()>> { + unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + Some(__wrap) + } + } + }) + } + _ => None, + }) +} + +fn extract_proto_arguments( + cls: &syn::Type, + py: &syn::Ident, + method_args: &[FnArg], + proto_args: &[Ty], +) -> (Vec, TokenStream) { + let mut arg_idents = Vec::with_capacity(method_args.len()); + let mut non_python_args = 0; + + let args_conversion = method_args.into_iter().filter_map(|arg| { + if arg.py { + arg_idents.push(py.clone()); + None + } else { + let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site()); + let conversions = proto_args[non_python_args].extract(cls, py, &ident, arg.ty); + non_python_args += 1; + arg_idents.push(ident); + Some(conversions) + } + }); + let conversions = quote!(#(#args_conversion)*); + (arg_idents, conversions) +} diff --git a/src/class/basic.rs b/src/class/basic.rs index affa7ee6..602096c5 100644 --- a/src/class/basic.rs +++ b/src/class/basic.rs @@ -13,7 +13,7 @@ use crate::{exceptions, ffi, FromPyObject, PyAny, PyCell, PyClass, PyObject}; use std::os::raw::c_int; /// Operators for the `__richcmp__` method -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum CompareOp { /// The *less than* operator. Lt = ffi::Py_LT as isize, @@ -29,6 +29,31 @@ pub enum CompareOp { Ge = ffi::Py_GE as isize, } +impl CompareOp { + pub fn from_raw(op: c_int) -> Option { + match op { + ffi::Py_LT => Some(CompareOp::Lt), + ffi::Py_LE => Some(CompareOp::Le), + ffi::Py_EQ => Some(CompareOp::Eq), + ffi::Py_NE => Some(CompareOp::Ne), + ffi::Py_GT => Some(CompareOp::Gt), + ffi::Py_GE => Some(CompareOp::Ge), + _ => None, + } + } + + pub fn matches_ordering(self, ordering: std::cmp::Ordering) -> bool { + match self { + CompareOp::Lt => ordering.is_lt(), + CompareOp::Le => ordering.is_le(), + CompareOp::Eq => ordering.is_eq(), + CompareOp::Ne => ordering.is_ne(), + CompareOp::Gt => ordering.is_gt(), + CompareOp::Ge => ordering.is_ge(), + } + } +} + /// Basic Python class customization #[allow(unused_variables)] pub trait PyObjectProtocol<'p>: PyClass { diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 86bf8496..9015ad18 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -3,12 +3,13 @@ use crate::{ ffi, impl_::freelist::FreeList, + exceptions::PyAttributeError, pycell::PyCellLayout, pyclass_init::PyObjectInit, type_object::{PyLayout, PyTypeObject}, - PyClass, PyMethodDefType, PyNativeType, PyTypeInfo, Python, + PyClass, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python, }; -use std::{marker::PhantomData, os::raw::c_void, thread}; +use std::{marker::PhantomData, os::raw::c_void, ptr::NonNull, thread}; /// This type is used as a "dummy" type on which dtolnay specializations are /// applied to apply implementations from `#[pymethods]` & `#[pyproto]` @@ -107,6 +108,181 @@ impl PyClassCallImpl for &'_ PyClassImplCollector { } } +pub trait PyClassSetattrSlotFragment: Sized { + #[inline] + fn setattr_implemented(self) -> bool { + false + } + + unsafe fn setattr( + self, + _slf: *mut ffi::PyObject, + attr: *mut ffi::PyObject, + value: NonNull, + ) -> PyResult<()>; +} + +impl PyClassSetattrSlotFragment for &'_ PyClassImplCollector { + #[inline] + unsafe fn setattr( + self, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + _value: NonNull, + ) -> PyResult<()> { + Err(PyAttributeError::new_err("can't set attribute")) + } +} + +pub trait PyClassDelattrSlotFragment { + fn delattr_impl( + self, + ) -> Option PyResult<()>>; +} + +impl PyClassDelattrSlotFragment for &'_ PyClassImplCollector { + fn delattr_impl( + self, + ) -> Option PyResult<()>> { + None + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_setattr_slot { + ($cls:ty) => {{ + use ::std::option::Option::*; + use $crate::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let delattr = collector.delattr_impl(); + if collector.setattr_implemented() || delattr.is_some() { + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + $crate::callback::handle_panic::<_, ::std::os::raw::c_int>(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + $crate::callback::convert(py, { + if let Some(value) = ::std::ptr::NonNull::new(value) { + collector.setattr(_slf, attr, value) + } else { + if let Some(del) = collector.delattr_impl() { + del(_slf, attr) + } else { + ::std::result::Result::Err( + $crate::exceptions::PyAttributeError::new_err( + "can't delete attribute", + ), + ) + } + } + }) + }) + } + Some($crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_tp_setattro, + pfunc: __wrap as $crate::ffi::setattrofunc as _, + }) + } else { + None + } + }}; +} + +pub trait PyClassSetSlotFragment { + fn set_impl( + self, + ) -> Option< + unsafe fn( + _slf: *mut ffi::PyObject, + attr: *mut ffi::PyObject, + value: NonNull, + ) -> PyResult<()>, + >; +} + +impl PyClassSetSlotFragment for &'_ PyClassImplCollector { + fn set_impl( + self, + ) -> Option< + unsafe fn( + _slf: *mut ffi::PyObject, + attr: *mut ffi::PyObject, + value: NonNull, + ) -> PyResult<()>, + > { + None + } +} + +pub trait PyClassDeleteSlotFragment { + fn delete_impl( + self, + ) -> Option PyResult<()>>; +} + +impl PyClassDeleteSlotFragment for &'_ PyClassImplCollector { + fn delete_impl( + self, + ) -> Option PyResult<()>> { + None + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_setdescr_slot { + ($cls:ty) => {{ + use ::std::option::Option::*; + use $crate::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let set = collector.set_impl(); + let delete = collector.delete_impl(); + if set.is_some() || delete.is_some() { + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + $crate::callback::handle_panic::<_, ::std::os::raw::c_int>(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + $crate::callback::convert(py, { + if let Some(value) = ::std::ptr::NonNull::new(value) { + if let Some(set) = collector.set_impl() { + set(_slf, attr, value) + } else { + ::std::result::Result::Err( + $crate::exceptions::PyTypeError::new_err( + "can't set descriptor", + ), + ) + } + } else { + if let Some(del) = collector.delete_impl() { + del(_slf, attr) + } else { + ::std::result::Result::Err( + $crate::exceptions::PyTypeError::new_err( + "can't delete descriptor", + ), + ) + } + } + }) + }) + } + Some($crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_tp_descr_set, + pfunc: __wrap as $crate::ffi::descrsetfunc as _, + }) + } else { + None + } + }}; +} + pub trait PyClassAllocImpl { fn alloc_impl(self) -> Option; } @@ -288,6 +464,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots); slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots); slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots); +#[cfg(not(feature = "multiple-pymethods"))] +slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots); + methods_trait!(PyObjectProtocolMethods, object_protocol_methods); methods_trait!(PyAsyncProtocolMethods, async_protocol_methods); methods_trait!(PyContextProtocolMethods, context_protocol_methods); diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs new file mode 100644 index 00000000..5de12f0e --- /dev/null +++ b/tests/test_proto_methods.rs @@ -0,0 +1,765 @@ +use pyo3::{basic::CompareOp, exceptions::PyAttributeError, prelude::*}; +use pyo3::exceptions::{PyIndexError, PyValueError}; +use pyo3::types::{PySlice, PyType}; +use pyo3::{ffi, py_run, AsPyPointer, PyCell}; +use std::convert::TryFrom; +use std::{isize, iter}; + +mod common; + +#[pyclass] +struct ExampleClass { + #[pyo3(get, set)] + value: i32, + _custom_attr: Option, +} + +#[pymethods] +impl ExampleClass { + fn __getattr__(&self, py: Python, attr: &str) -> PyResult { + if attr == "special_custom_attr" { + Ok(self._custom_attr.into_py(py)) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __setattr__(&mut self, attr: &str, value: &PyAny) -> PyResult<()> { + if attr == "special_custom_attr" { + self._custom_attr = Some(value.extract()?); + Ok(()) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __delattr__(&mut self, attr: &str) -> PyResult<()> { + if attr == "special_custom_attr" { + self._custom_attr = None; + Ok(()) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __str__(&self) -> String { + self.value.to_string() + } + + fn __repr__(&self) -> String { + format!("ExampleClass(value={})", self.value) + } + + fn __hash__(&self) -> u64 { + let i64_value: i64 = self.value.into(); + i64_value as u64 + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches_ordering(self.value.cmp(&other.value)) + } + + fn __bool__(&self) -> bool { + self.value != 0 + } +} + +fn make_example(py: Python) -> &PyCell { + Py::new( + py, + ExampleClass { + value: 5, + _custom_attr: Some(20), + }, + ) + .unwrap() + .into_ref(py) +} + +#[test] +fn test_getattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py + .getattr("value") + .unwrap() + .extract::() + .unwrap(), + 5, + ); + assert_eq!( + example_py + .getattr("special_custom_attr") + .unwrap() + .extract::() + .unwrap(), + 20, + ); + assert!(example_py + .getattr("other_attr") + .unwrap_err() + .is_instance::(py)); + }) +} + +#[test] +fn test_setattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + example_py.setattr("special_custom_attr", 15).unwrap(); + assert_eq!( + example_py + .getattr("special_custom_attr") + .unwrap() + .extract::() + .unwrap(), + 15, + ); + }) +} + +#[test] +fn test_delattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + example_py.delattr("special_custom_attr").unwrap(); + assert!(example_py.getattr("special_custom_attr").unwrap().is_none()); + }) +} + +#[test] +fn test_str() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!(example_py.str().unwrap().to_str().unwrap(), "5"); + }) +} + +#[test] +fn test_repr() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py.repr().unwrap().to_str().unwrap(), + "ExampleClass(value=5)" + ); + }) +} + +#[test] +fn test_hash() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!(example_py.hash().unwrap(), 5); + }) +} + +#[test] +fn test_richcmp() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py + .rich_compare(example_py, CompareOp::Eq) + .unwrap() + .is_true() + .unwrap(), + true + ); + }) +} + +#[test] +fn test_bool() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert!(example_py.is_true().unwrap()); + example_py.borrow_mut().value = 0; + assert!(!example_py.is_true().unwrap()); + }) +} + +#[pyclass] +pub struct Len { + l: usize, +} + +#[pymethods] +impl Len { + fn __len__(&self) -> usize { + self.l + } +} + +#[test] +fn len() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let inst = Py::new(py, Len { l: 10 }).unwrap(); + py_assert!(py, inst, "len(inst) == 10"); + unsafe { + assert_eq!(ffi::PyObject_Size(inst.as_ptr()), 10); + assert_eq!(ffi::PyMapping_Size(inst.as_ptr()), 10); + } + + let inst = Py::new( + py, + Len { + l: (isize::MAX as usize) + 1, + }, + ) + .unwrap(); + py_expect_exception!(py, inst, "len(inst)", PyOverflowError); +} + +#[pyclass] +struct Iterator { + iter: Box + Send>, +} + +#[pymethods] +impl Iterator { + fn __iter__(slf: PyRef) -> PyRef { + slf + } + + fn __next__(mut slf: PyRefMut) -> Option { + slf.iter.next() + } +} + +#[test] +fn iterator() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let inst = Py::new( + py, + Iterator { + iter: Box::new(5..8), + }, + ) + .unwrap(); + py_assert!(py, inst, "iter(inst) is inst"); + py_assert!(py, inst, "list(inst) == [5, 6, 7]"); +} + +#[pyclass] +struct StringMethods {} + +#[pymethods] +impl StringMethods { + fn __str__(&self) -> &'static str { + "str" + } + + fn __repr__(&self) -> &'static str { + "repr" + } + + fn __format__(&self, format_spec: String) -> String { + format!("format({})", format_spec) + } + + fn __bytes__(&self) -> &'static [u8] { + b"bytes" + } +} + +#[test] +fn string_methods() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let obj = Py::new(py, StringMethods {}).unwrap(); + py_assert!(py, obj, "str(obj) == 'str'"); + py_assert!(py, obj, "repr(obj) == 'repr'"); + py_assert!(py, obj, "'{0:x}'.format(obj) == 'format(x)'"); + py_assert!(py, obj, "bytes(obj) == b'bytes'"); + + // Test that `__bytes__` takes no arguments (should be METH_NOARGS) + py_assert!(py, obj, "obj.__bytes__() == b'bytes'"); + py_expect_exception!(py, obj, "obj.__bytes__('unexpected argument')", PyTypeError); +} + +#[pyclass] +struct Comparisons { + val: i32, +} + +#[pymethods] +impl Comparisons { + fn __hash__(&self) -> isize { + self.val as isize + } + fn __bool__(&self) -> bool { + self.val != 0 + } +} + +#[test] +fn comparisons() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let zero = Py::new(py, Comparisons { val: 0 }).unwrap(); + let one = Py::new(py, Comparisons { val: 1 }).unwrap(); + let ten = Py::new(py, Comparisons { val: 10 }).unwrap(); + let minus_one = Py::new(py, Comparisons { val: -1 }).unwrap(); + py_assert!(py, one, "hash(one) == 1"); + py_assert!(py, ten, "hash(ten) == 10"); + py_assert!(py, minus_one, "hash(minus_one) == -2"); + + py_assert!(py, one, "bool(one) is True"); + py_assert!(py, zero, "not zero"); +} + +#[pyclass] +#[derive(Debug)] +struct Sequence { + fields: Vec, +} + +impl Default for Sequence { + fn default() -> Sequence { + let mut fields = vec![]; + for &s in &["A", "B", "C", "D", "E", "F", "G"] { + fields.push(s.to_string()); + } + Sequence { fields } + } +} + +#[pymethods] +impl Sequence { + fn __len__(&self) -> usize { + self.fields.len() + } + + fn __getitem__(&self, key: isize) -> PyResult { + let idx = usize::try_from(key)?; + if let Some(s) = self.fields.get(idx) { + Ok(s.clone()) + } else { + Err(PyIndexError::new_err(())) + } + } + + fn __setitem__(&mut self, idx: isize, value: String) -> PyResult<()> { + let idx = usize::try_from(idx)?; + if let Some(elem) = self.fields.get_mut(idx) { + *elem = value; + Ok(()) + } else { + Err(PyIndexError::new_err(())) + } + } +} + +#[test] +fn sequence() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Sequence::default()).unwrap(); + py_assert!(py, c, "list(c) == ['A', 'B', 'C', 'D', 'E', 'F', 'G']"); + py_assert!(py, c, "c[-1] == 'G'"); + py_run!( + py, + c, + r#" + c[0] = 'H' + assert c[0] == 'H' +"# + ); + py_expect_exception!(py, c, "c['abc']", PyTypeError); +} + +#[pyclass] +struct Callable {} + +#[pymethods] +impl Callable { + #[__call__] + fn __call__(&self, arg: i32) -> i32 { + arg * 6 + } +} + +#[test] +fn callable() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Callable {}).unwrap(); + py_assert!(py, c, "callable(c)"); + py_assert!(py, c, "c(7) == 42"); + + let nc = Py::new(py, Comparisons { val: 0 }).unwrap(); + py_assert!(py, nc, "not callable(nc)"); +} + +#[pyclass] +#[derive(Debug)] +struct SetItem { + key: i32, + val: i32, +} + +#[pymethods] +impl SetItem { + fn __setitem__(&mut self, key: i32, val: i32) { + self.key = key; + self.val = val; + } +} + +#[test] +fn setitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, SetItem { key: 0, val: 0 }).unwrap(); + py_run!(py, c, "c[1] = 2"); + { + let c = c.borrow(); + assert_eq!(c.key, 1); + assert_eq!(c.val, 2); + } + py_expect_exception!(py, c, "del c[1]", PyNotImplementedError); +} + +#[pyclass] +struct DelItem { + key: i32, +} + +#[pymethods] +impl DelItem { + fn __delitem__(&mut self, key: i32) { + self.key = key; + } +} + +#[test] +fn delitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, DelItem { key: 0 }).unwrap(); + py_run!(py, c, "del c[1]"); + { + let c = c.borrow(); + assert_eq!(c.key, 1); + } + py_expect_exception!(py, c, "c[1] = 2", PyNotImplementedError); +} + +#[pyclass] +struct SetDelItem { + val: Option, +} + +#[pymethods] +impl SetDelItem { + fn __setitem__(&mut self, _key: i32, val: i32) { + self.val = Some(val); + } + + fn __delitem__(&mut self, _key: i32) { + self.val = None; + } +} + +#[test] +fn setdelitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, SetDelItem { val: None }).unwrap(); + py_run!(py, c, "c[1] = 2"); + { + let c = c.borrow(); + assert_eq!(c.val, Some(2)); + } + py_run!(py, c, "del c[1]"); + let c = c.borrow(); + assert_eq!(c.val, None); +} + +#[pyclass] +struct Reversed {} + +#[pymethods] +impl Reversed { + fn __reversed__(&self) -> &'static str { + "I am reversed" + } +} + +#[test] +fn reversed() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Reversed {}).unwrap(); + py_run!(py, c, "assert reversed(c) == 'I am reversed'"); +} + +#[pyclass] +struct Contains {} + +#[pymethods] +impl Contains { + fn __contains__(&self, item: i32) -> bool { + item >= 0 + } +} + +#[test] +fn contains() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Contains {}).unwrap(); + py_run!(py, c, "assert 1 in c"); + py_run!(py, c, "assert -1 not in c"); + py_expect_exception!(py, c, "assert 'wrong type' not in c", PyTypeError); +} + +#[pyclass] +struct ContextManager { + exit_called: bool, +} + +#[pymethods] +impl ContextManager { + fn __enter__(&mut self) -> i32 { + 42 + } + + fn __exit__( + &mut self, + ty: Option<&PyType>, + _value: Option<&PyAny>, + _traceback: Option<&PyAny>, + ) -> bool { + let gil = Python::acquire_gil(); + self.exit_called = true; + ty == Some(gil.python().get_type::()) + } +} + +#[test] +fn context_manager() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, ContextManager { exit_called: false }).unwrap(); + py_run!(py, c, "with c as x: assert x == 42"); + { + let mut c = c.borrow_mut(); + assert!(c.exit_called); + c.exit_called = false; + } + py_run!(py, c, "with c as x: raise ValueError"); + { + let mut c = c.borrow_mut(); + assert!(c.exit_called); + c.exit_called = false; + } + py_expect_exception!( + py, + c, + "with c as x: raise NotImplementedError", + PyNotImplementedError + ); + let c = c.borrow(); + assert!(c.exit_called); +} + +#[test] +fn test_basics() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let v = PySlice::new(py, 1, 10, 2); + let indices = v.indices(100).unwrap(); + assert_eq!(1, indices.start); + assert_eq!(10, indices.stop); + assert_eq!(2, indices.step); + assert_eq!(5, indices.slicelength); +} + +#[pyclass] +struct Test {} + +#[pymethods] +impl Test { + fn __getitem__(&self, idx: &PyAny) -> PyResult<&'static str> { + if let Ok(slice) = idx.cast_as::() { + let indices = slice.indices(1000)?; + if indices.start == 100 && indices.stop == 200 && indices.step == 1 { + return Ok("slice"); + } + } else if let Ok(idx) = idx.extract::() { + if idx == 1 { + return Ok("int"); + } + } + Err(PyValueError::new_err("error")) + } +} + +#[test] +fn test_cls_impl() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let ob = Py::new(py, Test {}).unwrap(); + + py_assert!(py, ob, "ob[1] == 'int'"); + py_assert!(py, ob, "ob[100:200:1] == 'slice'"); +} + +#[pyclass] +struct ClassWithGetAttr { + #[pyo3(get, set)] + data: u32, +} + +#[pymethods] +impl ClassWithGetAttr { + fn __getattr__(&self, _name: &str) -> u32 { + self.data * 2 + } +} + +#[test] +fn getattr_doesnt_override_member() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let inst = PyCell::new(py, ClassWithGetAttr { data: 4 }).unwrap(); + py_assert!(py, inst, "inst.data == 4"); + py_assert!(py, inst, "inst.a == 8"); +} + +/// Wraps a Python future and yield it once. +#[pyclass] +struct OnceFuture { + future: PyObject, + polled: bool, +} + +#[pymethods] +impl OnceFuture { + #[new] + fn new(future: PyObject) -> Self { + OnceFuture { + future, + polled: false, + } + } + + fn __await__(slf: PyRef) -> PyRef { + slf + } + + fn __iter__(slf: PyRef) -> PyRef { + slf + } + fn __next__(mut slf: PyRefMut) -> Option { + if !slf.polled { + slf.polled = true; + Some(slf.future.clone()) + } else { + None + } + } +} + +#[test] +fn test_await() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let once = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +import asyncio +import sys + +async def main(): + res = await Once(await asyncio.sleep(0.1)) + return res +# For an odd error similar to https://bugs.python.org/issue38563 +if sys.platform == "win32" and sys.version_info >= (3, 8, 0): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +# get_event_loop can raise an error: https://github.com/PyO3/pyo3/pull/961#issuecomment-645238579 +loop = asyncio.new_event_loop() +asyncio.set_event_loop(loop) +assert loop.run_until_complete(main()) is None +loop.close() +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("Once", once).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +} + +/// Increment the count when `__get__` is called. +#[pyclass] +struct DescrCounter { + #[pyo3(get)] + count: usize, +} + +#[pymethods] +impl DescrCounter { + #[new] + fn new() -> Self { + DescrCounter { count: 0 } + } + + fn __get__<'a>( + mut slf: PyRefMut<'a, Self>, + _instance: &PyAny, + _owner: Option<&PyType>, + ) -> PyRefMut<'a, Self> { + slf.count += 1; + slf + } + fn __set__(_slf: PyRef, _instance: &PyAny, mut new_value: PyRefMut) { + new_value.count = _slf.count; + } +} + +#[test] +fn descr_getset() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let counter = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +class Class: + counter = Counter() +c = Class() +c.counter # count += 1 +assert c.counter.count == 2 +c.counter = Counter() +assert c.counter.count == 3 +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("Counter", counter).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +} + + +// TODO: test __delete__ +// TODO: better argument casting errors From b544b5a6d74d76629b95eef7d7f1f17475ff67df Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 9 Sep 2021 17:00:35 +0100 Subject: [PATCH 02/12] pymethods: support iter and async protocols --- pyo3-macros-backend/src/pymethod.rs | 16 +- src/class/impl_.rs | 2 +- tests/test_proto_methods.rs | 234 ++-------------------------- 3 files changed, 26 insertions(+), 226 deletions(-) diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 28de756a..30707922 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -415,6 +415,19 @@ fn pyproto(cls: &syn::Type, spec: &FnSpec) -> Option { .arguments(&[Ty::Object, Ty::Object]) .generate_type_slot(cls, spec), ), + "__iter__" => Some(SlotDef::new("Py_tp_iter", "getiterfunc").generate_type_slot(cls, spec)), + "__next__" => Some( + SlotDef::new("Py_tp_iternext", "iternextfunc") + .return_conversion(quote! { ::pyo3::class::iter::IterNextOutput::<_, _> }) + .generate_type_slot(cls, spec), + ), + "__await__" => Some(SlotDef::new("Py_am_await", "unaryfunc").generate_type_slot(cls, spec)), + "__aiter__" => Some(SlotDef::new("Py_am_aiter", "unaryfunc").generate_type_slot(cls, spec)), + "__anext__" => Some( + SlotDef::new("Py_am_anext", "unaryfunc") + .return_conversion(quote! { ::pyo3::class::pyasync::IterANextOutput::<_, _> }) + .generate_type_slot(cls, spec), + ), _ => None, } } @@ -433,8 +446,7 @@ impl Ty { match self { Ty::Object => quote! { *mut ::pyo3::ffi::PyObject }, Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> }, - Ty::Int => quote! { ::std::os::raw::c_int }, - Ty::CompareOp => quote! { ::std::os::raw::c_int }, + Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int }, Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, } } diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 9015ad18..c8271a5b 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -1,9 +1,9 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::{ + exceptions::PyAttributeError, ffi, impl_::freelist::FreeList, - exceptions::PyAttributeError, pycell::PyCellLayout, pyclass_init::PyObjectInit, type_object::{PyLayout, PyTypeObject}, diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index 5de12f0e..d6731d7b 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -1,8 +1,7 @@ -use pyo3::{basic::CompareOp, exceptions::PyAttributeError, prelude::*}; -use pyo3::exceptions::{PyIndexError, PyValueError}; +use pyo3::exceptions::PyValueError; use pyo3::types::{PySlice, PyType}; +use pyo3::{basic::CompareOp, exceptions::PyAttributeError, prelude::*}; use pyo3::{ffi, py_run, AsPyPointer, PyCell}; -use std::convert::TryFrom; use std::{isize, iter}; mod common; @@ -246,137 +245,6 @@ fn iterator() { py_assert!(py, inst, "list(inst) == [5, 6, 7]"); } -#[pyclass] -struct StringMethods {} - -#[pymethods] -impl StringMethods { - fn __str__(&self) -> &'static str { - "str" - } - - fn __repr__(&self) -> &'static str { - "repr" - } - - fn __format__(&self, format_spec: String) -> String { - format!("format({})", format_spec) - } - - fn __bytes__(&self) -> &'static [u8] { - b"bytes" - } -} - -#[test] -fn string_methods() { - let gil = Python::acquire_gil(); - let py = gil.python(); - - let obj = Py::new(py, StringMethods {}).unwrap(); - py_assert!(py, obj, "str(obj) == 'str'"); - py_assert!(py, obj, "repr(obj) == 'repr'"); - py_assert!(py, obj, "'{0:x}'.format(obj) == 'format(x)'"); - py_assert!(py, obj, "bytes(obj) == b'bytes'"); - - // Test that `__bytes__` takes no arguments (should be METH_NOARGS) - py_assert!(py, obj, "obj.__bytes__() == b'bytes'"); - py_expect_exception!(py, obj, "obj.__bytes__('unexpected argument')", PyTypeError); -} - -#[pyclass] -struct Comparisons { - val: i32, -} - -#[pymethods] -impl Comparisons { - fn __hash__(&self) -> isize { - self.val as isize - } - fn __bool__(&self) -> bool { - self.val != 0 - } -} - -#[test] -fn comparisons() { - let gil = Python::acquire_gil(); - let py = gil.python(); - - let zero = Py::new(py, Comparisons { val: 0 }).unwrap(); - let one = Py::new(py, Comparisons { val: 1 }).unwrap(); - let ten = Py::new(py, Comparisons { val: 10 }).unwrap(); - let minus_one = Py::new(py, Comparisons { val: -1 }).unwrap(); - py_assert!(py, one, "hash(one) == 1"); - py_assert!(py, ten, "hash(ten) == 10"); - py_assert!(py, minus_one, "hash(minus_one) == -2"); - - py_assert!(py, one, "bool(one) is True"); - py_assert!(py, zero, "not zero"); -} - -#[pyclass] -#[derive(Debug)] -struct Sequence { - fields: Vec, -} - -impl Default for Sequence { - fn default() -> Sequence { - let mut fields = vec![]; - for &s in &["A", "B", "C", "D", "E", "F", "G"] { - fields.push(s.to_string()); - } - Sequence { fields } - } -} - -#[pymethods] -impl Sequence { - fn __len__(&self) -> usize { - self.fields.len() - } - - fn __getitem__(&self, key: isize) -> PyResult { - let idx = usize::try_from(key)?; - if let Some(s) = self.fields.get(idx) { - Ok(s.clone()) - } else { - Err(PyIndexError::new_err(())) - } - } - - fn __setitem__(&mut self, idx: isize, value: String) -> PyResult<()> { - let idx = usize::try_from(idx)?; - if let Some(elem) = self.fields.get_mut(idx) { - *elem = value; - Ok(()) - } else { - Err(PyIndexError::new_err(())) - } - } -} - -#[test] -fn sequence() { - let gil = Python::acquire_gil(); - let py = gil.python(); - - let c = Py::new(py, Sequence::default()).unwrap(); - py_assert!(py, c, "list(c) == ['A', 'B', 'C', 'D', 'E', 'F', 'G']"); - py_assert!(py, c, "c[-1] == 'G'"); - py_run!( - py, - c, - r#" - c[0] = 'H' - assert c[0] == 'H' -"# - ); - py_expect_exception!(py, c, "c['abc']", PyTypeError); -} - #[pyclass] struct Callable {} @@ -388,6 +256,9 @@ impl Callable { } } +#[pyclass] +struct EmptyClass; + #[test] fn callable() { let gil = Python::acquire_gil(); @@ -397,7 +268,7 @@ fn callable() { py_assert!(py, c, "callable(c)"); py_assert!(py, c, "c(7) == 42"); - let nc = Py::new(py, Comparisons { val: 0 }).unwrap(); + let nc = Py::new(py, EmptyClass).unwrap(); py_assert!(py, nc, "not callable(nc)"); } @@ -489,25 +360,6 @@ fn setdelitem() { assert_eq!(c.val, None); } -#[pyclass] -struct Reversed {} - -#[pymethods] -impl Reversed { - fn __reversed__(&self) -> &'static str { - "I am reversed" - } -} - -#[test] -fn reversed() { - let gil = Python::acquire_gil(); - let py = gil.python(); - - let c = Py::new(py, Reversed {}).unwrap(); - py_run!(py, c, "assert reversed(c) == 'I am reversed'"); -} - #[pyclass] struct Contains {} @@ -530,74 +382,10 @@ fn contains() { } #[pyclass] -struct ContextManager { - exit_called: bool, -} +struct GetItem {} #[pymethods] -impl ContextManager { - fn __enter__(&mut self) -> i32 { - 42 - } - - fn __exit__( - &mut self, - ty: Option<&PyType>, - _value: Option<&PyAny>, - _traceback: Option<&PyAny>, - ) -> bool { - let gil = Python::acquire_gil(); - self.exit_called = true; - ty == Some(gil.python().get_type::()) - } -} - -#[test] -fn context_manager() { - let gil = Python::acquire_gil(); - let py = gil.python(); - - let c = PyCell::new(py, ContextManager { exit_called: false }).unwrap(); - py_run!(py, c, "with c as x: assert x == 42"); - { - let mut c = c.borrow_mut(); - assert!(c.exit_called); - c.exit_called = false; - } - py_run!(py, c, "with c as x: raise ValueError"); - { - let mut c = c.borrow_mut(); - assert!(c.exit_called); - c.exit_called = false; - } - py_expect_exception!( - py, - c, - "with c as x: raise NotImplementedError", - PyNotImplementedError - ); - let c = c.borrow(); - assert!(c.exit_called); -} - -#[test] -fn test_basics() { - let gil = Python::acquire_gil(); - let py = gil.python(); - - let v = PySlice::new(py, 1, 10, 2); - let indices = v.indices(100).unwrap(); - assert_eq!(1, indices.start); - assert_eq!(10, indices.stop); - assert_eq!(2, indices.step); - assert_eq!(5, indices.slicelength); -} - -#[pyclass] -struct Test {} - -#[pymethods] -impl Test { +impl GetItem { fn __getitem__(&self, idx: &PyAny) -> PyResult<&'static str> { if let Ok(slice) = idx.cast_as::() { let indices = slice.indices(1000)?; @@ -614,11 +402,11 @@ impl Test { } #[test] -fn test_cls_impl() { +fn test_getitem() { let gil = Python::acquire_gil(); let py = gil.python(); - let ob = Py::new(py, Test {}).unwrap(); + let ob = Py::new(py, GetItem {}).unwrap(); py_assert!(py, ob, "ob[1] == 'int'"); py_assert!(py, ob, "ob[100:200:1] == 'slice'"); @@ -760,6 +548,6 @@ assert c.counter.count == 3 .unwrap(); } - // TODO: test __delete__ +// TODO: test __anext__, __aiter__ // TODO: better argument casting errors From fda18b07d7999ac2e200f65ffc553288d2b34c69 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 9 Sep 2021 17:23:46 +0100 Subject: [PATCH 03/12] pymethods: implement some mapping methods --- pyo3-macros-backend/src/pymethod.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 30707922..f2d2422b 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -428,6 +428,22 @@ fn pyproto(cls: &syn::Type, spec: &FnSpec) -> Option { .return_conversion(quote! { ::pyo3::class::pyasync::IterANextOutput::<_, _> }) .generate_type_slot(cls, spec), ), + "__len__" => Some( + SlotDef::new("Py_mp_length", "lenfunc") + .ret_ty(Ty::PySsizeT) + .generate_type_slot(cls, spec), + ), + "__contains__" => Some( + SlotDef::new("Py_sq_contains", "objobjproc") + .arguments(&[Ty::Object]) + .ret_ty(Ty::Int) + .generate_type_slot(cls, spec), + ), + "__getitem__" => Some( + SlotDef::new("Py_mp_subscript", "binaryfunc") + .arguments(&[Ty::Object]) + .generate_type_slot(cls, spec), + ), _ => None, } } @@ -439,6 +455,7 @@ enum Ty { CompareOp, Int, PyHashT, + PySsizeT, } impl Ty { @@ -448,6 +465,7 @@ impl Ty { Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> }, Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int }, Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, + Ty::PySsizeT => quote! { ::pyo3::ffi::Py_ssize_t }, } } @@ -473,12 +491,11 @@ impl Ty { #extract } } - Ty::Int => todo!(), - Ty::PyHashT => todo!(), Ty::CompareOp => quote! { let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident) .ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?; }, + Ty::Int | Ty::PyHashT | Ty::PySsizeT => todo!(), } } } From 52105176952424e280b2ef5adf225467a4d6beb0 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Fri, 10 Sep 2021 12:01:25 +0100 Subject: [PATCH 04/12] pymethods: implement more mapping methods --- pyo3-macros-backend/src/pyclass.rs | 3 + pyo3-macros-backend/src/pymethod.rs | 389 ++++++++++++++-------------- src/class/impl_.rs | 224 +++++++++------- 3 files changed, 335 insertions(+), 281 deletions(-) diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index cf8b43ba..9579a9cd 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -599,6 +599,9 @@ fn impl_class( if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setdescr_slot!(#cls) { generated_slots.push(setdescr); } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setitem_slot!(#cls) { + generated_slots.push(setdescr); + } visitor(&generated_slots); } diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index f2d2422b..545c777b 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -10,7 +10,7 @@ use crate::{ pyfunction::PyFunctionOptions, }; use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote, ToTokens}; use syn::Ident; use syn::{ext::IdentExt, spanned::Spanned, Result}; @@ -31,8 +31,9 @@ pub fn gen_py_method( ensure_function_options_valid(&options)?; let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?; - if let Some(proto) = pyproto(cls, &spec) { - return Ok(GeneratedPyMethod::Proto(proto)); + if let Some(slot_def) = pyproto(&spec.python_name.to_string()) { + let slot = slot_def.generate_type_slot(cls, &spec); + return Ok(GeneratedPyMethod::Proto(slot)); } if let Some(proto) = pyproto_fragment(cls, &spec)? { @@ -374,76 +375,65 @@ impl PropertyType<'_> { } } -fn pyproto(cls: &syn::Type, spec: &FnSpec) -> Option { - match spec.python_name.to_string().as_str() { - "__getattr__" => Some( - SlotDef::new("Py_tp_getattro", "getattrofunc") - .arguments(&[Ty::Object]) - .before_call_method(quote! { - // Behave like python's __getattr__ (as opposed to __getattribute__) and check - // for existing fields and methods first - let existing = ::pyo3::ffi::PyObject_GenericGetAttr(_slf, arg0); - if existing.is_null() { - // PyObject_HasAttr also tries to get an object and clears the error if it fails - ::pyo3::ffi::PyErr_Clear(); - } else { - return existing; - } - }) - .generate_type_slot(cls, spec), - ), - "__str__" => Some(SlotDef::new("Py_tp_str", "reprfunc").generate_type_slot(cls, spec)), - "__repr__" => Some(SlotDef::new("Py_tp_repr", "reprfunc").generate_type_slot(cls, spec)), - "__hash__" => Some( - SlotDef::new("Py_tp_hash", "hashfunc") - .ret_ty(Ty::PyHashT) - .return_conversion(quote! { ::pyo3::callback::HashCallbackOutput }) - .generate_type_slot(cls, spec), - ), - "__richcmp__" => Some( - SlotDef::new("Py_tp_richcompare", "richcmpfunc") - .arguments(&[Ty::Object, Ty::CompareOp]) - .generate_type_slot(cls, spec), - ), - "__bool__" => Some( - SlotDef::new("Py_nb_bool", "inquiry") - .ret_ty(Ty::Int) - .generate_type_slot(cls, spec), - ), - "__get__" => Some( - SlotDef::new("Py_tp_descr_get", "descrgetfunc") - .arguments(&[Ty::Object, Ty::Object]) - .generate_type_slot(cls, spec), - ), - "__iter__" => Some(SlotDef::new("Py_tp_iter", "getiterfunc").generate_type_slot(cls, spec)), - "__next__" => Some( - SlotDef::new("Py_tp_iternext", "iternextfunc") - .return_conversion(quote! { ::pyo3::class::iter::IterNextOutput::<_, _> }) - .generate_type_slot(cls, spec), - ), - "__await__" => Some(SlotDef::new("Py_am_await", "unaryfunc").generate_type_slot(cls, spec)), - "__aiter__" => Some(SlotDef::new("Py_am_aiter", "unaryfunc").generate_type_slot(cls, spec)), - "__anext__" => Some( - SlotDef::new("Py_am_anext", "unaryfunc") - .return_conversion(quote! { ::pyo3::class::pyasync::IterANextOutput::<_, _> }) - .generate_type_slot(cls, spec), - ), - "__len__" => Some( - SlotDef::new("Py_mp_length", "lenfunc") - .ret_ty(Ty::PySsizeT) - .generate_type_slot(cls, spec), - ), - "__contains__" => Some( - SlotDef::new("Py_sq_contains", "objobjproc") - .arguments(&[Ty::Object]) - .ret_ty(Ty::Int) - .generate_type_slot(cls, spec), - ), - "__getitem__" => Some( - SlotDef::new("Py_mp_subscript", "binaryfunc") - .arguments(&[Ty::Object]) - .generate_type_slot(cls, spec), - ), +const __GETATTR__: SlotDef = SlotDef::new("Py_tp_getattro", "getattrofunc") + .arguments(&[Ty::Object]) + .before_call_method(TokenGenerator(|| { + quote! { + // Behave like python's __getattr__ (as opposed to __getattribute__) and check + // for existing fields and methods first + let existing = ::pyo3::ffi::PyObject_GenericGetAttr(_slf, arg0); + if existing.is_null() { + // PyObject_HasAttr also tries to get an object and clears the error if it fails + ::pyo3::ffi::PyErr_Clear(); + } else { + return existing; + } + } + })); +const __STR__: SlotDef = SlotDef::new("Py_tp_str", "reprfunc"); +const __REPR__: SlotDef = SlotDef::new("Py_tp_repr", "reprfunc"); +const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc") + .ret_ty(Ty::PyHashT) + .return_conversion(TokenGenerator( + || quote! { ::pyo3::callback::HashCallbackOutput }, + )); +const __RICHCMP__: SlotDef = + SlotDef::new("Py_tp_richcompare", "richcmpfunc").arguments(&[Ty::Object, Ty::CompareOp]); +const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); +const __GET__: SlotDef = + SlotDef::new("Py_tp_descr_get", "descrgetfunc").arguments(&[Ty::Object, Ty::Object]); +const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc"); +const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc").return_conversion( + TokenGenerator(|| quote! { ::pyo3::class::iter::IterNextOutput::<_, _> }), +); +const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc"); +const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc"); +const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion( + TokenGenerator(|| quote! { ::pyo3::class::pyasync::IterANextOutput::<_, _> }), +); +const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySsizeT); +const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc") + .arguments(&[Ty::Object]) + .ret_ty(Ty::Int); +const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]); + +fn pyproto(method_name: &str) -> Option<&'static SlotDef> { + match method_name { + "__getattr__" => Some(&__GETATTR__), + "__str__" => Some(&__STR__), + "__repr__" => Some(&__REPR__), + "__hash__" => Some(&__HASH__), + "__richcmp__" => Some(&__RICHCMP__), + "__bool__" => Some(&__BOOL__), + "__get__" => Some(&__GET__), + "__iter__" => Some(&__ITER__), + "__next__" => Some(&__NEXT__), + "__await__" => Some(&__AWAIT__), + "__aiter__" => Some(&__AITER__), + "__anext__" => Some(&__ANEXT__), + "__len__" => Some(&__LEN__), + "__contains__" => Some(&__CONTAINS__), + "__getitem__" => Some(&__GETITEM__), _ => None, } } @@ -544,19 +534,19 @@ fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) - } struct SlotDef { - slot: syn::Ident, - func_ty: syn::Ident, + slot: StaticIdent, + func_ty: StaticIdent, arguments: &'static [Ty], ret_ty: Ty, - before_call_method: Option, - return_conversion: Option, + before_call_method: Option, + return_conversion: Option, } impl SlotDef { - fn new(slot: &str, func_ty: &str) -> Self { + const fn new(slot: &'static str, func_ty: &'static str) -> Self { SlotDef { - slot: syn::Ident::new(slot, Span::call_site()), - func_ty: syn::Ident::new(func_ty, Span::call_site()), + slot: StaticIdent(slot), + func_ty: StaticIdent(func_ty), arguments: &[], ret_ty: Ty::Object, before_call_method: None, @@ -564,22 +554,22 @@ impl SlotDef { } } - fn arguments(mut self, arguments: &'static [Ty]) -> Self { + const fn arguments(mut self, arguments: &'static [Ty]) -> Self { self.arguments = arguments; self } - fn ret_ty(mut self, ret_ty: Ty) -> Self { + const fn ret_ty(mut self, ret_ty: Ty) -> Self { self.ret_ty = ret_ty; self } - fn before_call_method(mut self, before_call_method: TokenStream) -> Self { + const fn before_call_method(mut self, before_call_method: TokenGenerator) -> Self { self.before_call_method = Some(before_call_method); self } - fn return_conversion(mut self, return_conversion: TokenStream) -> Self { + const fn return_conversion(mut self, return_conversion: TokenGenerator) -> Self { self.return_conversion = Some(return_conversion); self } @@ -594,34 +584,13 @@ impl SlotDef { return_conversion, } = self; let py = syn::Ident::new("_py", Span::call_site()); - let self_conversion = spec.tp.self_conversion(Some(cls)); - let rust_name = spec.name; - let arguments = arguments.into_iter().enumerate().map(|(i, arg)| { - let ident = syn::Ident::new(&format!("arg{}", i), Span::call_site()); - let ffi_type = arg.ffi_type(); - quote! { - #ident: #ffi_type - } - }); + let method_arguments = generate_method_arguments(arguments); let ret_ty = ret_ty.ffi_type(); - let (arg_idents, conversions) = - extract_proto_arguments(cls, &py, &spec.args, &self.arguments); - let call = - quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; - let body = if let Some(return_conversion) = return_conversion { - quote! { - let _result: PyResult<#return_conversion> = #call; - ::pyo3::callback::convert(#py, _result) - } - } else { - call - }; + let body = generate_method_body(cls, spec, &py, arguments, return_conversion.as_ref()); quote!({ - unsafe extern "C" fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, #(#arguments),*) -> #ret_ty { + unsafe extern "C" fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty { #before_call_method ::pyo3::callback::handle_panic(|#py| { - #self_conversion - #conversions #body }) } @@ -633,94 +602,110 @@ impl SlotDef { } } +fn generate_method_arguments(arguments: &[Ty]) -> impl Iterator + '_ { + arguments.into_iter().enumerate().map(|(i, arg)| { + let ident = syn::Ident::new(&format!("arg{}", i), Span::call_site()); + let ffi_type = arg.ffi_type(); + quote! { + #ident: #ffi_type + } + }) +} + +fn generate_method_body( + cls: &syn::Type, + spec: &FnSpec, + py: &syn::Ident, + arguments: &[Ty], + return_conversion: Option<&TokenGenerator>, +) -> TokenStream { + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = extract_proto_arguments(cls, &py, &spec.args, arguments); + let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; + let body = if let Some(return_conversion) = return_conversion { + quote! { + let _result: PyResult<#return_conversion> = #call; + ::pyo3::callback::convert(#py, _result) + } + } else { + call + }; + quote! { + #self_conversion + #conversions + #body + } +} + +fn generate_pyproto_fragment( + cls: &syn::Type, + spec: &FnSpec, + fragment: &str, + arguments: &[Ty], +) -> TokenStream { + let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment); + let implemented = format_ident!("{}implemented", fragment); + let method = syn::Ident::new(fragment, Span::call_site()); + let py = syn::Ident::new("_py", Span::call_site()); + let method_arguments = generate_method_arguments(arguments); + let body = generate_method_body(cls, spec, &py, arguments, None); + quote! { + impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + #[inline] + fn #implemented(self) -> bool { true } + + #[inline] + unsafe fn #method( + self, + _slf: *mut ::pyo3::ffi::PyObject, + #(#method_arguments),* + ) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #body + } + } + } +} + fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result> { Ok(match spec.python_name.to_string().as_str() { - "__setattr__" => { - let py = syn::Ident::new("_py", Span::call_site()); - let self_conversion = spec.tp.self_conversion(Some(cls)); - let rust_name = spec.name; - let (arg_idents, conversions) = - extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object, Ty::NonNullObject]); - Some(quote! { - impl ::pyo3::class::impl_::PyClassSetattrSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - #[inline] - fn setattr_implemented(self) -> bool { true } - - #[inline] - unsafe fn setattr( - self, - _slf: *mut ::pyo3::ffi::PyObject, - arg0: *mut ::pyo3::ffi::PyObject, - arg1: ::std::ptr::NonNull<::pyo3::ffi::PyObject> - ) -> ::pyo3::PyResult<()> { - let #py = ::pyo3::Python::assume_gil_acquired(); - #self_conversion - #conversions - ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) - } - } - }) - } - "__delattr__" => { - let py = syn::Ident::new("_py", Span::call_site()); - let self_conversion = spec.tp.self_conversion(Some(cls)); - let rust_name = spec.name; - let (arg_idents, conversions) = - extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object]); - Some(quote! { - impl ::pyo3::class::impl_::PyClassDelattrSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - fn delattr_impl(self) -> ::std::option::Option ::pyo3::PyResult<()>> { - unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject) -> ::pyo3::PyResult<()> { - let #py = ::pyo3::Python::assume_gil_acquired(); - #self_conversion - #conversions - ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) - } - Some(__wrap) - } - } - }) - } - "__set__" => { - let py = syn::Ident::new("_py", Span::call_site()); - let self_conversion = spec.tp.self_conversion(Some(cls)); - let rust_name = spec.name; - let (arg_idents, conversions) = - extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object, Ty::NonNullObject]); - Some(quote! { - impl ::pyo3::class::impl_::PyClassSetSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - fn set_impl(self) -> ::std::option::Option) -> ::pyo3::PyResult<()>> { - unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject, arg1: ::std::ptr::NonNull<::pyo3::ffi::PyObject>) -> ::pyo3::PyResult<()> { - let #py = ::pyo3::Python::assume_gil_acquired(); - #self_conversion - #conversions - ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) - } - Some(__wrap) - } - } - }) - } - "__delete__" => { - let py = syn::Ident::new("_py", Span::call_site()); - let self_conversion = spec.tp.self_conversion(Some(cls)); - let rust_name = spec.name; - let (arg_idents, conversions) = - extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object]); - Some(quote! { - impl ::pyo3::class::impl_::PyClassDeleteSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - fn delete_impl(self) -> ::std::option::Option ::pyo3::PyResult<()>> { - unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject) -> ::pyo3::PyResult<()> { - let #py = ::pyo3::Python::assume_gil_acquired(); - #self_conversion - #conversions - ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) - } - Some(__wrap) - } - } - }) - } + "__setattr__" => Some(generate_pyproto_fragment( + cls, + spec, + "__setattr__", + &[Ty::Object, Ty::NonNullObject], + )), + "__delattr__" => Some(generate_pyproto_fragment( + cls, + spec, + "__delattr__", + &[Ty::Object], + )), + "__set__" => Some(generate_pyproto_fragment( + cls, + spec, + "__set__", + &[Ty::Object, Ty::NonNullObject], + )), + "__delete__" => Some(generate_pyproto_fragment( + cls, + spec, + "__delete__", + &[Ty::Object], + )), + "__setitem__" => Some(generate_pyproto_fragment( + cls, + spec, + "__setitem__", + &[Ty::Object, Ty::NonNullObject], + )), + "__delitem__" => Some(generate_pyproto_fragment( + cls, + spec, + "__delitem__", + &[Ty::Object], + )), _ => None, }) } @@ -749,3 +734,19 @@ fn extract_proto_arguments( let conversions = quote!(#(#args_conversion)*); (arg_idents, conversions) } + +struct StaticIdent(&'static str); + +impl ToTokens for StaticIdent { + fn to_tokens(&self, tokens: &mut TokenStream) { + syn::Ident::new(self.0, Span::call_site()).to_tokens(tokens) + } +} + +struct TokenGenerator(fn() -> TokenStream); + +impl ToTokens for TokenGenerator { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.0().to_tokens(tokens) + } +} diff --git a/src/class/impl_.rs b/src/class/impl_.rs index c8271a5b..57c86203 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -1,7 +1,7 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::{ - exceptions::PyAttributeError, + exceptions::{PyAttributeError, PyNotImplementedError}, ffi, impl_::freelist::FreeList, pycell::PyCellLayout, @@ -108,23 +108,16 @@ impl PyClassCallImpl for &'_ PyClassImplCollector { } } -pub trait PyClassSetattrSlotFragment: Sized { +#[allow(non_camel_case_types)] +pub trait PyClass__setattr__SlotFragment: Sized { #[inline] - fn setattr_implemented(self) -> bool { + #[allow(non_snake_case)] + fn __setattr__implemented(self) -> bool { false } - unsafe fn setattr( - self, - _slf: *mut ffi::PyObject, - attr: *mut ffi::PyObject, - value: NonNull, - ) -> PyResult<()>; -} - -impl PyClassSetattrSlotFragment for &'_ PyClassImplCollector { #[inline] - unsafe fn setattr( + unsafe fn __setattr__( self, _slf: *mut ffi::PyObject, _attr: *mut ffi::PyObject, @@ -134,20 +127,28 @@ impl PyClassSetattrSlotFragment for &'_ PyClassImplCollector { } } -pub trait PyClassDelattrSlotFragment { - fn delattr_impl( - self, - ) -> Option PyResult<()>>; -} +impl PyClass__setattr__SlotFragment for &'_ PyClassImplCollector {} -impl PyClassDelattrSlotFragment for &'_ PyClassImplCollector { - fn delattr_impl( +#[allow(non_camel_case_types)] +pub trait PyClass__delattr__SlotFragment: Sized { + #[inline] + #[allow(non_snake_case)] + fn __delattr__implemented(self) -> bool { + false + } + + #[inline] + unsafe fn __delattr__( self, - ) -> Option PyResult<()>> { - None + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + ) -> PyResult<()> { + Err(PyAttributeError::new_err("can't delete attribute")) } } +impl PyClass__delattr__SlotFragment for &'_ PyClassImplCollector {} + #[doc(hidden)] #[macro_export] macro_rules! generate_pyclass_setattr_slot { @@ -155,28 +156,19 @@ macro_rules! generate_pyclass_setattr_slot { use ::std::option::Option::*; use $crate::class::impl_::*; let collector = PyClassImplCollector::<$cls>::new(); - let delattr = collector.delattr_impl(); - if collector.setattr_implemented() || delattr.is_some() { + if collector.__setattr__implemented() || collector.__delattr__implemented() { unsafe extern "C" fn __wrap( _slf: *mut $crate::ffi::PyObject, attr: *mut $crate::ffi::PyObject, value: *mut $crate::ffi::PyObject, ) -> ::std::os::raw::c_int { - $crate::callback::handle_panic::<_, ::std::os::raw::c_int>(|py| { + $crate::callback::handle_panic(|py| { let collector = PyClassImplCollector::<$cls>::new(); $crate::callback::convert(py, { if let Some(value) = ::std::ptr::NonNull::new(value) { - collector.setattr(_slf, attr, value) + collector.__setattr__(_slf, attr, value) } else { - if let Some(del) = collector.delattr_impl() { - del(_slf, attr) - } else { - ::std::result::Result::Err( - $crate::exceptions::PyAttributeError::new_err( - "can't delete attribute", - ), - ) - } + collector.__delattr__(_slf, attr) } }) }) @@ -191,46 +183,47 @@ macro_rules! generate_pyclass_setattr_slot { }}; } -pub trait PyClassSetSlotFragment { - fn set_impl( - self, - ) -> Option< - unsafe fn( - _slf: *mut ffi::PyObject, - attr: *mut ffi::PyObject, - value: NonNull, - ) -> PyResult<()>, - >; -} +#[allow(non_camel_case_types)] +pub trait PyClass__set__SlotFragment: Sized { + #[inline] + #[allow(non_snake_case)] + fn __set__implemented(self) -> bool { + false + } -impl PyClassSetSlotFragment for &'_ PyClassImplCollector { - fn set_impl( + #[inline] + unsafe fn __set__( self, - ) -> Option< - unsafe fn( - _slf: *mut ffi::PyObject, - attr: *mut ffi::PyObject, - value: NonNull, - ) -> PyResult<()>, - > { - None + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + _value: NonNull, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err("can't set descriptor")) } } -pub trait PyClassDeleteSlotFragment { - fn delete_impl( - self, - ) -> Option PyResult<()>>; -} +impl PyClass__set__SlotFragment for &'_ PyClassImplCollector {} -impl PyClassDeleteSlotFragment for &'_ PyClassImplCollector { - fn delete_impl( +#[allow(non_camel_case_types)] +pub trait PyClass__delete__SlotFragment: Sized { + #[allow(non_snake_case)] + #[inline] + fn __delete__implemented(self) -> bool { + false + } + + #[inline] + unsafe fn __delete__( self, - ) -> Option PyResult<()>> { - None + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err("can't delete descriptor")) } } +impl PyClass__delete__SlotFragment for &'_ PyClassImplCollector {} + #[doc(hidden)] #[macro_export] macro_rules! generate_pyclass_setdescr_slot { @@ -238,37 +231,19 @@ macro_rules! generate_pyclass_setdescr_slot { use ::std::option::Option::*; use $crate::class::impl_::*; let collector = PyClassImplCollector::<$cls>::new(); - let set = collector.set_impl(); - let delete = collector.delete_impl(); - if set.is_some() || delete.is_some() { + if collector.__set__implemented() || collector.__delete__implemented() { unsafe extern "C" fn __wrap( _slf: *mut $crate::ffi::PyObject, attr: *mut $crate::ffi::PyObject, value: *mut $crate::ffi::PyObject, ) -> ::std::os::raw::c_int { - $crate::callback::handle_panic::<_, ::std::os::raw::c_int>(|py| { + $crate::callback::handle_panic(|py| { let collector = PyClassImplCollector::<$cls>::new(); $crate::callback::convert(py, { if let Some(value) = ::std::ptr::NonNull::new(value) { - if let Some(set) = collector.set_impl() { - set(_slf, attr, value) - } else { - ::std::result::Result::Err( - $crate::exceptions::PyTypeError::new_err( - "can't set descriptor", - ), - ) - } + collector.__set__(_slf, attr, value) } else { - if let Some(del) = collector.delete_impl() { - del(_slf, attr) - } else { - ::std::result::Result::Err( - $crate::exceptions::PyTypeError::new_err( - "can't delete descriptor", - ), - ) - } + collector.__delete__(_slf, attr) } }) }) @@ -283,6 +258,81 @@ macro_rules! generate_pyclass_setdescr_slot { }}; } +#[allow(non_camel_case_types)] +pub trait PyClass__setitem__SlotFragment: Sized { + #[inline] + #[allow(non_snake_case)] + fn __setitem__implemented(self) -> bool { + false + } + + #[inline] + unsafe fn __setitem__( + self, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + _value: NonNull, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err("can't set item")) + } +} + +impl PyClass__setitem__SlotFragment for &'_ PyClassImplCollector {} + +#[allow(non_camel_case_types)] +pub trait PyClass__delitem__SlotFragment: Sized { + #[allow(non_snake_case)] + #[inline] + fn __delitem__implemented(self) -> bool { + false + } + + #[inline] + unsafe fn __delitem__( + self, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err("can't delete item")) + } +} + +impl PyClass__delitem__SlotFragment for &'_ PyClassImplCollector {} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_setitem_slot { + ($cls:ty) => {{ + use ::std::option::Option::*; + use $crate::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + if collector.__setitem__implemented() || collector.__delitem__implemented() { + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + $crate::callback::handle_panic(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + $crate::callback::convert(py, { + if let Some(value) = ::std::ptr::NonNull::new(value) { + collector.__setitem__(_slf, attr, value) + } else { + collector.__delitem__(_slf, attr) + } + }) + }) + } + Some($crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_mp_ass_subscript, + pfunc: __wrap as $crate::ffi::objobjargproc as _, + }) + } else { + None + } + }}; +} + pub trait PyClassAllocImpl { fn alloc_impl(self) -> Option; } From c090b6581db23b96308c592026e276da7a461176 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sun, 12 Sep 2021 08:36:47 +0100 Subject: [PATCH 05/12] pymethods: fix clippy errors --- pyo3-macros-backend/src/pyclass.rs | 2 +- pyo3-macros-backend/src/pymethod.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 9579a9cd..98ec071a 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -592,7 +592,7 @@ fn impl_class( visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); visitor(collector.methods_protocol_slots()); - let mut generated_slots = Vec::new(); + let mut generated_slots = ::std::vec::Vec::new(); if let ::std::option::Option::Some(setattr) = ::pyo3::generate_pyclass_setattr_slot!(#cls) { generated_slots.push(setattr); } diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 545c777b..4d0ae404 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -603,7 +603,7 @@ impl SlotDef { } fn generate_method_arguments(arguments: &[Ty]) -> impl Iterator + '_ { - arguments.into_iter().enumerate().map(|(i, arg)| { + arguments.iter().enumerate().map(|(i, arg)| { let ident = syn::Ident::new(&format!("arg{}", i), Span::call_site()); let ffi_type = arg.ffi_type(); quote! { @@ -621,7 +621,7 @@ fn generate_method_body( ) -> TokenStream { let self_conversion = spec.tp.self_conversion(Some(cls)); let rust_name = spec.name; - let (arg_idents, conversions) = extract_proto_arguments(cls, &py, &spec.args, arguments); + let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments); let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; let body = if let Some(return_conversion) = return_conversion { quote! { @@ -719,7 +719,7 @@ fn extract_proto_arguments( let mut arg_idents = Vec::with_capacity(method_args.len()); let mut non_python_args = 0; - let args_conversion = method_args.into_iter().filter_map(|arg| { + let args_conversion = method_args.iter().filter_map(|arg| { if arg.py { arg_idents.push(py.clone()); None From 92e2156161581e8c76c49dab7c945f6d46893af1 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sun, 12 Sep 2021 10:20:21 +0100 Subject: [PATCH 06/12] pymethods: support inplace numerical operations --- pyo3-macros-backend/src/pymethod.rs | 243 +++++++--- tests/test_arithmetics.rs | 195 ++++---- tests/test_arithmetics_protos.rs | 683 ++++++++++++++++++++++++++++ tests/test_proto_methods.rs | 1 + 4 files changed, 954 insertions(+), 168 deletions(-) create mode 100644 tests/test_arithmetics_protos.rs diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 4d0ae404..e6099712 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -32,7 +32,7 @@ pub fn gen_py_method( let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?; if let Some(slot_def) = pyproto(&spec.python_name.to_string()) { - let slot = slot_def.generate_type_slot(cls, &spec); + let slot = slot_def.generate_type_slot(cls, &spec)?; return Ok(GeneratedPyMethod::Proto(slot)); } @@ -399,7 +399,6 @@ const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc") )); const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc").arguments(&[Ty::Object, Ty::CompareOp]); -const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc").arguments(&[Ty::Object, Ty::Object]); const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc"); @@ -417,6 +416,55 @@ const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc") .ret_ty(Ty::Int); const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]); +const __POS__: SlotDef = SlotDef::new("Py_nb_positive", "unaryfunc"); +const __NEG__: SlotDef = SlotDef::new("Py_nb_negative", "unaryfunc"); +const __ABS__: SlotDef = SlotDef::new("Py_nb_absolute", "unaryfunc"); +const __INVERT__: SlotDef = SlotDef::new("Py_nb_invert", "unaryfunc"); +const __INDEX__: SlotDef = SlotDef::new("Py_nb_index", "unaryfunc"); +const __INT__: SlotDef = SlotDef::new("Py_nb_int", "unaryfunc"); +const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc"); +const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); + +const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __ISUB__: SlotDef = SlotDef::new("Py_nb_inplace_subtract", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IMUL__: SlotDef = SlotDef::new("Py_nb_inplace_multiply", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IMATMUL__: SlotDef = SlotDef::new("Py_nb_inplace_matrix_multiply", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __ITRUEDIV__: SlotDef = SlotDef::new("Py_nb_inplace_true_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IFLOORDIV__: SlotDef = SlotDef::new("Py_nb_inplace_floor_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IMOD__: SlotDef = SlotDef::new("Py_nb_inplace_remainder", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "ternaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented]) + .return_self(); +const __ILSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_lshift", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IRSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_rshift", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IAND__: SlotDef = SlotDef::new("Py_nb_inplace_and", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IXOR__: SlotDef = SlotDef::new("Py_nb_inplace_xor", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); +const __IOR__: SlotDef = SlotDef::new("Py_nb_inplace_or", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .return_self(); + fn pyproto(method_name: &str) -> Option<&'static SlotDef> { match method_name { "__getattr__" => Some(&__GETATTR__), @@ -424,7 +472,6 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> { "__repr__" => Some(&__REPR__), "__hash__" => Some(&__HASH__), "__richcmp__" => Some(&__RICHCMP__), - "__bool__" => Some(&__BOOL__), "__get__" => Some(&__GET__), "__iter__" => Some(&__ITER__), "__next__" => Some(&__NEXT__), @@ -434,6 +481,27 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> { "__len__" => Some(&__LEN__), "__contains__" => Some(&__CONTAINS__), "__getitem__" => Some(&__GETITEM__), + "__pos__" => Some(&__POS__), + "__neg__" => Some(&__NEG__), + "__abs__" => Some(&__ABS__), + "__invert__" => Some(&__INVERT__), + "__index__" => Some(&__INDEX__), + "__int__" => Some(&__INT__), + "__float__" => Some(&__FLOAT__), + "__bool__" => Some(&__BOOL__), + "__iadd__" => Some(&__IADD__), + "__isub__" => Some(&__ISUB__), + "__imul__" => Some(&__IMUL__), + "__imatmul__" => Some(&__IMATMUL__), + "__itruediv__" => Some(&__ITRUEDIV__), + "__ifloordiv__" => Some(&__IFLOORDIV__), + "__imod__" => Some(&__IMOD__), + "__ipow__" => Some(&__IPOW__), + "__ilshift__" => Some(&__ILSHIFT__), + "__irshift__" => Some(&__IRSHIFT__), + "__iand__" => Some(&__IAND__), + "__ixor__" => Some(&__IXOR__), + "__ior__" => Some(&__IOR__), _ => None, } } @@ -441,6 +509,7 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> { #[derive(Clone, Copy)] enum Ty { Object, + ObjectOrNotImplemented, NonNullObject, CompareOp, Int, @@ -451,7 +520,7 @@ enum Ty { impl Ty { fn ffi_type(self) -> TokenStream { match self { - Ty::Object => quote! { *mut ::pyo3::ffi::PyObject }, + Ty::Object | Ty::ObjectOrNotImplemented => quote! { *mut ::pyo3::ffi::PyObject }, Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> }, Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int }, Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, @@ -474,6 +543,29 @@ impl Ty { #extract } } + Ty::ObjectOrNotImplemented => { + let extract = if let syn::Type::Reference(tref) = unwrap_ty_group(target) { + let (tref, mut_) = preprocess_tref(tref, cls); + quote! { + let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = match #ident.extract() { + Ok(#ident) => #ident, + Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()), + }; + let #ident = &#mut_ *#ident; + } + } else { + quote! { + let #ident = match #ident.extract() { + Ok(#ident) => #ident, + Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()), + }; + } + }; + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident); + #extract + } + } Ty::NonNullObject => { let extract = extract_from_any(cls, target, ident); quote! { @@ -502,33 +594,55 @@ fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) - let #ident = #ident.extract()?; } }; +} - /// Replace `Self`, remove lifetime and get mutability from the type - fn preprocess_tref( - tref: &syn::TypeReference, - self_: &syn::Type, - ) -> (syn::TypeReference, Option) { - let mut tref = tref.to_owned(); - if let syn::Type::Path(tpath) = self_ { - replace_self(&mut tref, &tpath.path); - } - tref.lifetime = None; - let mut_ = tref.mutability; - (tref, mut_) +/// Replace `Self`, remove lifetime and get mutability from the type +fn preprocess_tref( + tref: &syn::TypeReference, + self_: &syn::Type, +) -> (syn::TypeReference, Option) { + let mut tref = tref.to_owned(); + if let syn::Type::Path(tpath) = self_ { + replace_self(&mut tref, &tpath.path); } + tref.lifetime = None; + let mut_ = tref.mutability; + (tref, mut_) +} - /// Replace `Self` with the exact type name since it is used out of the impl block - fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { - match &mut *tref.elem { - syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), - syn::Type::Path(tpath) => { - if let Some(ident) = tpath.path.get_ident() { - if ident == "Self" { - tpath.path = self_path.to_owned(); - } +/// Replace `Self` with the exact type name since it is used out of the impl block +fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { + match &mut *tref.elem { + syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), + syn::Type::Path(tpath) => { + if let Some(ident) = tpath.path.get_ident() { + if ident == "Self" { + tpath.path = self_path.to_owned(); } } - _ => {} + } + _ => {} + } +} + +enum ReturnMode { + ReturnSelf, + Conversion(TokenGenerator), +} + +impl ReturnMode { + fn return_call_output(&self, py: &syn::Ident, call: TokenStream) -> TokenStream { + match self { + ReturnMode::Conversion(conversion) => quote! { + let _result: PyResult<#conversion> = #call; + ::pyo3::callback::convert(#py, _result) + }, + ReturnMode::ReturnSelf => quote! { + let _result: PyResult<()> = #call; + _result?; + ::pyo3::ffi::Py_XINCREF(_raw_slf); + Ok(_raw_slf) + }, } } } @@ -539,7 +653,7 @@ struct SlotDef { arguments: &'static [Ty], ret_ty: Ty, before_call_method: Option, - return_conversion: Option, + return_mode: Option, } impl SlotDef { @@ -550,7 +664,7 @@ impl SlotDef { arguments: &[], ret_ty: Ty::Object, before_call_method: None, - return_conversion: None, + return_mode: None, } } @@ -570,25 +684,31 @@ impl SlotDef { } const fn return_conversion(mut self, return_conversion: TokenGenerator) -> Self { - self.return_conversion = Some(return_conversion); + self.return_mode = Some(ReturnMode::Conversion(return_conversion)); self } - fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> TokenStream { + const fn return_self(mut self) -> Self { + self.return_mode = Some(ReturnMode::ReturnSelf); + self + } + + fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> Result { let SlotDef { slot, func_ty, before_call_method, arguments, ret_ty, - return_conversion, + return_mode, } = self; let py = syn::Ident::new("_py", Span::call_site()); let method_arguments = generate_method_arguments(arguments); let ret_ty = ret_ty.ffi_type(); - let body = generate_method_body(cls, spec, &py, arguments, return_conversion.as_ref()); - quote!({ - unsafe extern "C" fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty { + let body = generate_method_body(cls, spec, &py, arguments, return_mode.as_ref())?; + Ok(quote!({ + unsafe extern "C" fn __wrap(_raw_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty { + let _slf = _raw_slf; #before_call_method ::pyo3::callback::handle_panic(|#py| { #body @@ -598,7 +718,7 @@ impl SlotDef { slot: ::pyo3::ffi::#slot, pfunc: __wrap as ::pyo3::ffi::#func_ty as _ } - }) + })) } } @@ -617,25 +737,22 @@ fn generate_method_body( spec: &FnSpec, py: &syn::Ident, arguments: &[Ty], - return_conversion: Option<&TokenGenerator>, -) -> TokenStream { + return_mode: Option<&ReturnMode>, +) -> Result { let self_conversion = spec.tp.self_conversion(Some(cls)); let rust_name = spec.name; - let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments); + let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments)?; let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; - let body = if let Some(return_conversion) = return_conversion { - quote! { - let _result: PyResult<#return_conversion> = #call; - ::pyo3::callback::convert(#py, _result) - } + let body = if let Some(return_mode) = return_mode { + return_mode.return_call_output(py, call) } else { call }; - quote! { + Ok(quote! { #self_conversion #conversions #body - } + }) } fn generate_pyproto_fragment( @@ -643,14 +760,14 @@ fn generate_pyproto_fragment( spec: &FnSpec, fragment: &str, arguments: &[Ty], -) -> TokenStream { +) -> Result { let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment); let implemented = format_ident!("{}implemented", fragment); let method = syn::Ident::new(fragment, Span::call_site()); let py = syn::Ident::new("_py", Span::call_site()); let method_arguments = generate_method_arguments(arguments); - let body = generate_method_body(cls, spec, &py, arguments, None); - quote! { + let body = generate_method_body(cls, spec, &py, arguments, None)?; + Ok(quote! { impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { #[inline] fn #implemented(self) -> bool { true } @@ -658,18 +775,19 @@ fn generate_pyproto_fragment( #[inline] unsafe fn #method( self, - _slf: *mut ::pyo3::ffi::PyObject, + _raw_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),* ) -> ::pyo3::PyResult<()> { + let _slf = _raw_slf; let #py = ::pyo3::Python::assume_gil_acquired(); #body } } - } + }) } fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result> { - Ok(match spec.python_name.to_string().as_str() { + match spec.python_name.to_string().as_str() { "__setattr__" => Some(generate_pyproto_fragment( cls, spec, @@ -707,7 +825,8 @@ fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result None, - }) + } + .transpose() } fn extract_proto_arguments( @@ -715,24 +834,28 @@ fn extract_proto_arguments( py: &syn::Ident, method_args: &[FnArg], proto_args: &[Ty], -) -> (Vec, TokenStream) { +) -> Result<(Vec, TokenStream)> { let mut arg_idents = Vec::with_capacity(method_args.len()); let mut non_python_args = 0; - let args_conversion = method_args.iter().filter_map(|arg| { + let mut args_conversions = Vec::with_capacity(proto_args.len()); + + for arg in method_args { if arg.py { arg_idents.push(py.clone()); - None } else { let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site()); - let conversions = proto_args[non_python_args].extract(cls, py, &ident, arg.ty); + let conversions = proto_args.get(non_python_args) + .ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))? + .extract(cls, py, &ident, arg.ty); non_python_args += 1; + args_conversions.push(conversions); arg_idents.push(ident); - Some(conversions) } - }); - let conversions = quote!(#(#args_conversion)*); - (arg_idents, conversions) + } + + let conversions = quote!(#(#args_conversions)*); + Ok((arg_idents, conversions)) } struct StaticIdent(&'static str); diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index dc64155e..6f39920d 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -1,7 +1,4 @@ -#![allow(deprecated)] // for deprecated protocol methods - use pyo3::class::basic::CompareOp; -use pyo3::class::*; use pyo3::prelude::*; use pyo3::py_run; @@ -16,17 +13,11 @@ impl UnaryArithmetic { fn new(value: f64) -> Self { UnaryArithmetic { inner: value } } -} -#[pyproto] -impl PyObjectProtocol for UnaryArithmetic { fn __repr__(&self) -> String { format!("UA({})", self.inner) } -} -#[pyproto] -impl PyNumberProtocol for UnaryArithmetic { fn __neg__(&self) -> Self { Self::new(-self.inner) } @@ -57,30 +48,17 @@ fn unary_arithmetic() { py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'"); } -#[pyclass] -struct BinaryArithmetic {} - -#[pyproto] -impl PyObjectProtocol for BinaryArithmetic { - fn __repr__(&self) -> &'static str { - "BA" - } -} - #[pyclass] struct InPlaceOperations { value: u32, } -#[pyproto] -impl PyObjectProtocol for InPlaceOperations { +#[pymethods] +impl InPlaceOperations { fn __repr__(&self) -> String { format!("IPO({:?})", self.value) } -} -#[pyproto] -impl PyNumberProtocol for InPlaceOperations { fn __iadd__(&mut self, other: u32) { self.value += other; } @@ -142,42 +120,49 @@ fn inplace_operations() { ); } -#[pyproto] -impl PyNumberProtocol for BinaryArithmetic { - fn __add__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} + {:?}", lhs, rhs) +#[pyclass] +struct BinaryArithmetic {} + +#[pymethods] +impl BinaryArithmetic { + fn __repr__(&self) -> &'static str { + "BA" } - fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} - {:?}", lhs, rhs) + fn __add__(&self, rhs: &PyAny) -> String { + format!("BA + {:?}", rhs) } - fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} * {:?}", lhs, rhs) + fn __sub__(&self, rhs: &PyAny) -> String { + format!("BA - {:?}", rhs) } - fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} << {:?}", lhs, rhs) + fn __mul__(&self, rhs: &PyAny) -> String { + format!("BA * {:?}", rhs) } - fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} >> {:?}", lhs, rhs) + fn __lshift__(&self, rhs: &PyAny) -> String { + format!("BA << {:?}", rhs) } - fn __and__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} & {:?}", lhs, rhs) + fn __rshift__(&self, rhs: &PyAny) -> String { + format!("BA >> {:?}", rhs) } - fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} ^ {:?}", lhs, rhs) + fn __and__(&self, rhs: &PyAny) -> String { + format!("BA & {:?}", rhs) } - fn __or__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} | {:?}", lhs, rhs) + fn __xor__(&self, rhs: &PyAny) -> String { + format!("BA ^ {:?}", rhs) } - fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option) -> String { - format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_) + fn __or__(&self, rhs: &PyAny) -> String { + format!("BA | {:?}", rhs) + } + + fn __pow__(&self, rhs: &PyAny, mod_: Option) -> String { + format!("BA ** {:?} (mod: {:?})", rhs, mod_) } } @@ -215,8 +200,8 @@ fn binary_arithmetic() { #[pyclass] struct RhsArithmetic {} -#[pyproto] -impl PyNumberProtocol for RhsArithmetic { +#[pymethods] +impl RhsArithmetic { fn __radd__(&self, other: &PyAny) -> String { format!("{:?} + RA", other) } @@ -249,7 +234,7 @@ impl PyNumberProtocol for RhsArithmetic { format!("{:?} | RA", other) } - fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + fn __rpow__(&self, other: &PyAny, _mod: Option<&PyAny>) -> String { format!("{:?} ** RA", other) } } @@ -289,8 +274,12 @@ impl std::fmt::Debug for LhsAndRhs { } } -#[pyproto] -impl PyNumberProtocol for LhsAndRhs { +#[pymethods] +impl LhsAndRhs { + // fn __repr__(&self) -> &'static str { + // "BA" + // } + fn __add__(lhs: PyRef, rhs: &PyAny) -> String { format!("{:?} + {:?}", lhs, rhs) } @@ -363,7 +352,7 @@ impl PyNumberProtocol for LhsAndRhs { format!("{:?} | RA", other) } - fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + fn __rpow__(&self, other: &PyAny, _mod: Option<&PyAny>) -> String { format!("{:?} ** RA", other) } @@ -372,13 +361,6 @@ impl PyNumberProtocol for LhsAndRhs { } } -#[pyproto] -impl PyObjectProtocol for LhsAndRhs { - fn __repr__(&self) -> &'static str { - "BA" - } -} - #[test] fn lhs_fellback_to_rhs() { let gil = Python::acquire_gil(); @@ -412,8 +394,8 @@ fn lhs_fellback_to_rhs() { #[pyclass] struct RichComparisons {} -#[pyproto] -impl PyObjectProtocol for RichComparisons { +#[pymethods] +impl RichComparisons { fn __repr__(&self) -> &'static str { "RC" } @@ -433,8 +415,8 @@ impl PyObjectProtocol for RichComparisons { #[pyclass] struct RichComparisons2 {} -#[pyproto] -impl PyObjectProtocol for RichComparisons2 { +#[pymethods] +impl RichComparisons2 { fn __repr__(&self) -> &'static str { "RC2" } @@ -508,76 +490,73 @@ mod return_not_implemented { #[pyclass] struct RichComparisonToSelf {} - #[pyproto] - impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf { + #[pymethods] + impl RichComparisonToSelf { fn __repr__(&self) -> &'static str { "RC_Self" } - fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject { + fn __richcmp__(&self, other: PyRef, _op: CompareOp) -> PyObject { other.py().None() } - } - #[pyproto] - impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf { - fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __add__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __sub__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __mul__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __matmul__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __truediv__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __floordiv__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __mod__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option) -> &'p PyAny { - lhs + fn __pow__<'p>(slf: PyRef<'p, Self>, _other: u8, _modulo: Option) -> PyRef<'p, Self> { + slf } - fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __lshift__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __rshift__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __divmod__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __and__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __or__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __xor__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } // Inplace assignments - fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __iadd__(&mut self, _other: PyRef) {} + fn __isub__(&mut self, _other: PyRef) {} + fn __imul__(&mut self, _other: PyRef) {} + fn __imatmul__(&mut self, _other: PyRef) {} + fn __itruediv__(&mut self, _other: PyRef) {} + fn __ifloordiv__(&mut self, _other: PyRef) {} + fn __imod__(&mut self, _other: PyRef) {} + fn __ipow__(&mut self, _other: PyRef) {} + fn __ilshift__(&mut self, _other: PyRef) {} + fn __irshift__(&mut self, _other: PyRef) {} + fn __iand__(&mut self, _other: PyRef) {} + fn __ior__(&mut self, _other: PyRef) {} + fn __ixor__(&mut self, _other: PyRef) {} } fn _test_binary_dunder(dunder: &str) { diff --git a/tests/test_arithmetics_protos.rs b/tests/test_arithmetics_protos.rs new file mode 100644 index 00000000..dc64155e --- /dev/null +++ b/tests/test_arithmetics_protos.rs @@ -0,0 +1,683 @@ +#![allow(deprecated)] // for deprecated protocol methods + +use pyo3::class::basic::CompareOp; +use pyo3::class::*; +use pyo3::prelude::*; +use pyo3::py_run; + +mod common; + +#[pyclass] +struct UnaryArithmetic { + inner: f64, +} + +impl UnaryArithmetic { + fn new(value: f64) -> Self { + UnaryArithmetic { inner: value } + } +} + +#[pyproto] +impl PyObjectProtocol for UnaryArithmetic { + fn __repr__(&self) -> String { + format!("UA({})", self.inner) + } +} + +#[pyproto] +impl PyNumberProtocol for UnaryArithmetic { + fn __neg__(&self) -> Self { + Self::new(-self.inner) + } + + fn __pos__(&self) -> Self { + Self::new(self.inner) + } + + fn __abs__(&self) -> Self { + Self::new(self.inner.abs()) + } + + fn __round__(&self, _ndigits: Option) -> Self { + Self::new(self.inner.round()) + } +} + +#[test] +fn unary_arithmetic() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, UnaryArithmetic::new(2.7)).unwrap(); + py_run!(py, c, "assert repr(-c) == 'UA(-2.7)'"); + py_run!(py, c, "assert repr(+c) == 'UA(2.7)'"); + py_run!(py, c, "assert repr(abs(c)) == 'UA(2.7)'"); + py_run!(py, c, "assert repr(round(c)) == 'UA(3)'"); + py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'"); +} + +#[pyclass] +struct BinaryArithmetic {} + +#[pyproto] +impl PyObjectProtocol for BinaryArithmetic { + fn __repr__(&self) -> &'static str { + "BA" + } +} + +#[pyclass] +struct InPlaceOperations { + value: u32, +} + +#[pyproto] +impl PyObjectProtocol for InPlaceOperations { + fn __repr__(&self) -> String { + format!("IPO({:?})", self.value) + } +} + +#[pyproto] +impl PyNumberProtocol for InPlaceOperations { + fn __iadd__(&mut self, other: u32) { + self.value += other; + } + + fn __isub__(&mut self, other: u32) { + self.value -= other; + } + + fn __imul__(&mut self, other: u32) { + self.value *= other; + } + + fn __ilshift__(&mut self, other: u32) { + self.value <<= other; + } + + fn __irshift__(&mut self, other: u32) { + self.value >>= other; + } + + fn __iand__(&mut self, other: u32) { + self.value &= other; + } + + fn __ixor__(&mut self, other: u32) { + self.value ^= other; + } + + fn __ior__(&mut self, other: u32) { + self.value |= other; + } + + fn __ipow__(&mut self, other: u32) { + self.value = self.value.pow(other); + } +} + +#[test] +fn inplace_operations() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let init = |value, code| { + let c = PyCell::new(py, InPlaceOperations { value }).unwrap(); + py_run!(py, c, code); + }; + + init(0, "d = c; c += 1; assert repr(c) == repr(d) == 'IPO(1)'"); + init(10, "d = c; c -= 1; assert repr(c) == repr(d) == 'IPO(9)'"); + init(3, "d = c; c *= 3; assert repr(c) == repr(d) == 'IPO(9)'"); + init(3, "d = c; c <<= 2; assert repr(c) == repr(d) == 'IPO(12)'"); + init(12, "d = c; c >>= 2; assert repr(c) == repr(d) == 'IPO(3)'"); + init(12, "d = c; c &= 10; assert repr(c) == repr(d) == 'IPO(8)'"); + init(12, "d = c; c |= 3; assert repr(c) == repr(d) == 'IPO(15)'"); + init(12, "d = c; c ^= 5; assert repr(c) == repr(d) == 'IPO(9)'"); + init(3, "d = c; c **= 4; assert repr(c) == repr(d) == 'IPO(81)'"); + init( + 3, + "d = c; c.__ipow__(4); assert repr(c) == repr(d) == 'IPO(81)'", + ); +} + +#[pyproto] +impl PyNumberProtocol for BinaryArithmetic { + fn __add__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} + {:?}", lhs, rhs) + } + + fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} - {:?}", lhs, rhs) + } + + fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} * {:?}", lhs, rhs) + } + + fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} << {:?}", lhs, rhs) + } + + fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} >> {:?}", lhs, rhs) + } + + fn __and__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} & {:?}", lhs, rhs) + } + + fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} ^ {:?}", lhs, rhs) + } + + fn __or__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} | {:?}", lhs, rhs) + } + + fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option) -> String { + format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_) + } +} + +#[test] +fn binary_arithmetic() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, BinaryArithmetic {}).unwrap(); + py_run!(py, c, "assert c + c == 'BA + BA'"); + py_run!(py, c, "assert c.__add__(c) == 'BA + BA'"); + py_run!(py, c, "assert c + 1 == 'BA + 1'"); + py_run!(py, c, "assert 1 + c == '1 + BA'"); + py_run!(py, c, "assert c - 1 == 'BA - 1'"); + py_run!(py, c, "assert 1 - c == '1 - BA'"); + py_run!(py, c, "assert c * 1 == 'BA * 1'"); + py_run!(py, c, "assert 1 * c == '1 * BA'"); + + py_run!(py, c, "assert c << 1 == 'BA << 1'"); + py_run!(py, c, "assert 1 << c == '1 << BA'"); + py_run!(py, c, "assert c >> 1 == 'BA >> 1'"); + py_run!(py, c, "assert 1 >> c == '1 >> BA'"); + py_run!(py, c, "assert c & 1 == 'BA & 1'"); + py_run!(py, c, "assert 1 & c == '1 & BA'"); + py_run!(py, c, "assert c ^ 1 == 'BA ^ 1'"); + py_run!(py, c, "assert 1 ^ c == '1 ^ BA'"); + py_run!(py, c, "assert c | 1 == 'BA | 1'"); + py_run!(py, c, "assert 1 | c == '1 | BA'"); + py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'"); + py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'"); + + py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'"); +} + +#[pyclass] +struct RhsArithmetic {} + +#[pyproto] +impl PyNumberProtocol for RhsArithmetic { + fn __radd__(&self, other: &PyAny) -> String { + format!("{:?} + RA", other) + } + + fn __rsub__(&self, other: &PyAny) -> String { + format!("{:?} - RA", other) + } + + fn __rmul__(&self, other: &PyAny) -> String { + format!("{:?} * RA", other) + } + + fn __rlshift__(&self, other: &PyAny) -> String { + format!("{:?} << RA", other) + } + + fn __rrshift__(&self, other: &PyAny) -> String { + format!("{:?} >> RA", other) + } + + fn __rand__(&self, other: &PyAny) -> String { + format!("{:?} & RA", other) + } + + fn __rxor__(&self, other: &PyAny) -> String { + format!("{:?} ^ RA", other) + } + + fn __ror__(&self, other: &PyAny) -> String { + format!("{:?} | RA", other) + } + + fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + format!("{:?} ** RA", other) + } +} + +#[test] +fn rhs_arithmetic() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, RhsArithmetic {}).unwrap(); + py_run!(py, c, "assert c.__radd__(1) == '1 + RA'"); + py_run!(py, c, "assert 1 + c == '1 + RA'"); + py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'"); + py_run!(py, c, "assert 1 - c == '1 - RA'"); + py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'"); + py_run!(py, c, "assert 1 * c == '1 * RA'"); + py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'"); + py_run!(py, c, "assert 1 << c == '1 << RA'"); + py_run!(py, c, "assert c.__rrshift__(1) == '1 >> RA'"); + py_run!(py, c, "assert 1 >> c == '1 >> RA'"); + py_run!(py, c, "assert c.__rand__(1) == '1 & RA'"); + py_run!(py, c, "assert 1 & c == '1 & RA'"); + py_run!(py, c, "assert c.__rxor__(1) == '1 ^ RA'"); + py_run!(py, c, "assert 1 ^ c == '1 ^ RA'"); + py_run!(py, c, "assert c.__ror__(1) == '1 | RA'"); + py_run!(py, c, "assert 1 | c == '1 | RA'"); + py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'"); + py_run!(py, c, "assert 1 ** c == '1 ** RA'"); +} + +#[pyclass] +struct LhsAndRhs {} + +impl std::fmt::Debug for LhsAndRhs { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "LR") + } +} + +#[pyproto] +impl PyNumberProtocol for LhsAndRhs { + fn __add__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} + {:?}", lhs, rhs) + } + + fn __sub__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} - {:?}", lhs, rhs) + } + + fn __mul__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} * {:?}", lhs, rhs) + } + + fn __lshift__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} << {:?}", lhs, rhs) + } + + fn __rshift__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} >> {:?}", lhs, rhs) + } + + fn __and__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} & {:?}", lhs, rhs) + } + + fn __xor__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} ^ {:?}", lhs, rhs) + } + + fn __or__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} | {:?}", lhs, rhs) + } + + fn __pow__(lhs: PyRef, rhs: &PyAny, _mod: Option) -> String { + format!("{:?} ** {:?}", lhs, rhs) + } + + fn __matmul__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} @ {:?}", lhs, rhs) + } + + fn __radd__(&self, other: &PyAny) -> String { + format!("{:?} + RA", other) + } + + fn __rsub__(&self, other: &PyAny) -> String { + format!("{:?} - RA", other) + } + + fn __rmul__(&self, other: &PyAny) -> String { + format!("{:?} * RA", other) + } + + fn __rlshift__(&self, other: &PyAny) -> String { + format!("{:?} << RA", other) + } + + fn __rrshift__(&self, other: &PyAny) -> String { + format!("{:?} >> RA", other) + } + + fn __rand__(&self, other: &PyAny) -> String { + format!("{:?} & RA", other) + } + + fn __rxor__(&self, other: &PyAny) -> String { + format!("{:?} ^ RA", other) + } + + fn __ror__(&self, other: &PyAny) -> String { + format!("{:?} | RA", other) + } + + fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + format!("{:?} ** RA", other) + } + + fn __rmatmul__(&self, other: &PyAny) -> String { + format!("{:?} @ RA", other) + } +} + +#[pyproto] +impl PyObjectProtocol for LhsAndRhs { + fn __repr__(&self) -> &'static str { + "BA" + } +} + +#[test] +fn lhs_fellback_to_rhs() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, LhsAndRhs {}).unwrap(); + // If the light hand value is `LhsAndRhs`, LHS is used. + py_run!(py, c, "assert c + 1 == 'LR + 1'"); + py_run!(py, c, "assert c - 1 == 'LR - 1'"); + py_run!(py, c, "assert c * 1 == 'LR * 1'"); + py_run!(py, c, "assert c << 1 == 'LR << 1'"); + py_run!(py, c, "assert c >> 1 == 'LR >> 1'"); + py_run!(py, c, "assert c & 1 == 'LR & 1'"); + py_run!(py, c, "assert c ^ 1 == 'LR ^ 1'"); + py_run!(py, c, "assert c | 1 == 'LR | 1'"); + py_run!(py, c, "assert c ** 1 == 'LR ** 1'"); + py_run!(py, c, "assert c @ 1 == 'LR @ 1'"); + // Fellback to RHS because of type mismatching + py_run!(py, c, "assert 1 + c == '1 + RA'"); + py_run!(py, c, "assert 1 - c == '1 - RA'"); + py_run!(py, c, "assert 1 * c == '1 * RA'"); + py_run!(py, c, "assert 1 << c == '1 << RA'"); + py_run!(py, c, "assert 1 >> c == '1 >> RA'"); + py_run!(py, c, "assert 1 & c == '1 & RA'"); + py_run!(py, c, "assert 1 ^ c == '1 ^ RA'"); + py_run!(py, c, "assert 1 | c == '1 | RA'"); + py_run!(py, c, "assert 1 ** c == '1 ** RA'"); + py_run!(py, c, "assert 1 @ c == '1 @ RA'"); +} + +#[pyclass] +struct RichComparisons {} + +#[pyproto] +impl PyObjectProtocol for RichComparisons { + fn __repr__(&self) -> &'static str { + "RC" + } + + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> String { + match op { + CompareOp::Lt => format!("{} < {:?}", self.__repr__(), other), + CompareOp::Le => format!("{} <= {:?}", self.__repr__(), other), + CompareOp::Eq => format!("{} == {:?}", self.__repr__(), other), + CompareOp::Ne => format!("{} != {:?}", self.__repr__(), other), + CompareOp::Gt => format!("{} > {:?}", self.__repr__(), other), + CompareOp::Ge => format!("{} >= {:?}", self.__repr__(), other), + } + } +} + +#[pyclass] +struct RichComparisons2 {} + +#[pyproto] +impl PyObjectProtocol for RichComparisons2 { + fn __repr__(&self) -> &'static str { + "RC2" + } + + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyObject { + match op { + CompareOp::Eq => true.into_py(other.py()), + CompareOp::Ne => false.into_py(other.py()), + _ => other.py().NotImplemented(), + } + } +} + +#[test] +fn rich_comparisons() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, RichComparisons {}).unwrap(); + py_run!(py, c, "assert (c < c) == 'RC < RC'"); + py_run!(py, c, "assert (c < 1) == 'RC < 1'"); + py_run!(py, c, "assert (1 < c) == 'RC > 1'"); + py_run!(py, c, "assert (c <= c) == 'RC <= RC'"); + py_run!(py, c, "assert (c <= 1) == 'RC <= 1'"); + py_run!(py, c, "assert (1 <= c) == 'RC >= 1'"); + py_run!(py, c, "assert (c == c) == 'RC == RC'"); + py_run!(py, c, "assert (c == 1) == 'RC == 1'"); + py_run!(py, c, "assert (1 == c) == 'RC == 1'"); + py_run!(py, c, "assert (c != c) == 'RC != RC'"); + py_run!(py, c, "assert (c != 1) == 'RC != 1'"); + py_run!(py, c, "assert (1 != c) == 'RC != 1'"); + py_run!(py, c, "assert (c > c) == 'RC > RC'"); + py_run!(py, c, "assert (c > 1) == 'RC > 1'"); + py_run!(py, c, "assert (1 > c) == 'RC < 1'"); + py_run!(py, c, "assert (c >= c) == 'RC >= RC'"); + py_run!(py, c, "assert (c >= 1) == 'RC >= 1'"); + py_run!(py, c, "assert (1 >= c) == 'RC <= 1'"); +} + +#[test] +fn rich_comparisons_python_3_type_error() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c2 = PyCell::new(py, RichComparisons2 {}).unwrap(); + py_expect_exception!(py, c2, "c2 < c2", PyTypeError); + py_expect_exception!(py, c2, "c2 < 1", PyTypeError); + py_expect_exception!(py, c2, "1 < c2", PyTypeError); + py_expect_exception!(py, c2, "c2 <= c2", PyTypeError); + py_expect_exception!(py, c2, "c2 <= 1", PyTypeError); + py_expect_exception!(py, c2, "1 <= c2", PyTypeError); + py_run!(py, c2, "assert (c2 == c2) == True"); + py_run!(py, c2, "assert (c2 == 1) == True"); + py_run!(py, c2, "assert (1 == c2) == True"); + py_run!(py, c2, "assert (c2 != c2) == False"); + py_run!(py, c2, "assert (c2 != 1) == False"); + py_run!(py, c2, "assert (1 != c2) == False"); + py_expect_exception!(py, c2, "c2 > c2", PyTypeError); + py_expect_exception!(py, c2, "c2 > 1", PyTypeError); + py_expect_exception!(py, c2, "1 > c2", PyTypeError); + py_expect_exception!(py, c2, "c2 >= c2", PyTypeError); + py_expect_exception!(py, c2, "c2 >= 1", PyTypeError); + py_expect_exception!(py, c2, "1 >= c2", PyTypeError); +} + +// Checks that binary operations for which the arguments don't match the +// required type, return NotImplemented. +mod return_not_implemented { + use super::*; + + #[pyclass] + struct RichComparisonToSelf {} + + #[pyproto] + impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf { + fn __repr__(&self) -> &'static str { + "RC_Self" + } + + fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject { + other.py().None() + } + } + + #[pyproto] + impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf { + fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option) -> &'p PyAny { + lhs + } + fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + + // Inplace assignments + fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {} + } + + fn _test_binary_dunder(dunder: &str) { + let gil = Python::acquire_gil(); + let py = gil.python(); + let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap(); + py_run!( + py, + c2, + &format!( + "class Other: pass\nassert c2.__{}__(Other()) is NotImplemented", + dunder + ) + ); + } + + fn _test_binary_operator(operator: &str, dunder: &str) { + _test_binary_dunder(dunder); + + let gil = Python::acquire_gil(); + let py = gil.python(); + let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap(); + py_expect_exception!( + py, + c2, + &format!("class Other: pass\nc2 {} Other()", operator), + PyTypeError + ); + } + + fn _test_inplace_binary_operator(operator: &str, dunder: &str) { + _test_binary_operator(operator, dunder); + } + + #[test] + fn equality() { + _test_binary_dunder("eq"); + _test_binary_dunder("ne"); + } + + #[test] + fn ordering() { + _test_binary_operator("<", "lt"); + _test_binary_operator("<=", "le"); + _test_binary_operator(">", "gt"); + _test_binary_operator(">=", "ge"); + } + + #[test] + fn bitwise() { + _test_binary_operator("&", "and"); + _test_binary_operator("|", "or"); + _test_binary_operator("^", "xor"); + _test_binary_operator("<<", "lshift"); + _test_binary_operator(">>", "rshift"); + } + + #[test] + fn arith() { + _test_binary_operator("+", "add"); + _test_binary_operator("-", "sub"); + _test_binary_operator("*", "mul"); + _test_binary_operator("@", "matmul"); + _test_binary_operator("/", "truediv"); + _test_binary_operator("//", "floordiv"); + _test_binary_operator("%", "mod"); + _test_binary_operator("**", "pow"); + } + + #[test] + #[ignore] + fn reverse_arith() { + _test_binary_dunder("radd"); + _test_binary_dunder("rsub"); + _test_binary_dunder("rmul"); + _test_binary_dunder("rmatmul"); + _test_binary_dunder("rtruediv"); + _test_binary_dunder("rfloordiv"); + _test_binary_dunder("rmod"); + _test_binary_dunder("rpow"); + } + + #[test] + fn inplace_bitwise() { + _test_inplace_binary_operator("&=", "iand"); + _test_inplace_binary_operator("|=", "ior"); + _test_inplace_binary_operator("^=", "ixor"); + _test_inplace_binary_operator("<<=", "ilshift"); + _test_inplace_binary_operator(">>=", "irshift"); + } + + #[test] + fn inplace_arith() { + _test_inplace_binary_operator("+=", "iadd"); + _test_inplace_binary_operator("-=", "isub"); + _test_inplace_binary_operator("*=", "imul"); + _test_inplace_binary_operator("@=", "imatmul"); + _test_inplace_binary_operator("/=", "itruediv"); + _test_inplace_binary_operator("//=", "ifloordiv"); + _test_inplace_binary_operator("%=", "imod"); + _test_inplace_binary_operator("**=", "ipow"); + } +} diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index d6731d7b..afcbca0d 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -550,4 +550,5 @@ assert c.counter.count == 3 // TODO: test __delete__ // TODO: test __anext__, __aiter__ +// TODO: test __index__, __int__, __float__, __invert__ // TODO: better argument casting errors From 75c0116f6a41bbf95e630fc1ffaa1613bdda11f2 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Fri, 17 Sep 2021 08:43:52 +0100 Subject: [PATCH 07/12] pymethods: cleanup macros generating setdel slots --- src/class/impl_.rs | 339 +++++++++++++++++---------------------------- 1 file changed, 129 insertions(+), 210 deletions(-) diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 57c86203..ad2515c6 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -108,229 +108,148 @@ impl PyClassCallImpl for &'_ PyClassImplCollector { } } -#[allow(non_camel_case_types)] -pub trait PyClass__setattr__SlotFragment: Sized { - #[inline] - #[allow(non_snake_case)] - fn __setattr__implemented(self) -> bool { - false - } - - #[inline] - unsafe fn __setattr__( - self, - _slf: *mut ffi::PyObject, - _attr: *mut ffi::PyObject, - _value: NonNull, - ) -> PyResult<()> { - Err(PyAttributeError::new_err("can't set attribute")) - } -} - -impl PyClass__setattr__SlotFragment for &'_ PyClassImplCollector {} - -#[allow(non_camel_case_types)] -pub trait PyClass__delattr__SlotFragment: Sized { - #[inline] - #[allow(non_snake_case)] - fn __delattr__implemented(self) -> bool { - false - } - - #[inline] - unsafe fn __delattr__( - self, - _slf: *mut ffi::PyObject, - _attr: *mut ffi::PyObject, - ) -> PyResult<()> { - Err(PyAttributeError::new_err("can't delete attribute")) - } -} - -impl PyClass__delattr__SlotFragment for &'_ PyClassImplCollector {} - -#[doc(hidden)] -#[macro_export] -macro_rules! generate_pyclass_setattr_slot { - ($cls:ty) => {{ - use ::std::option::Option::*; - use $crate::class::impl_::*; - let collector = PyClassImplCollector::<$cls>::new(); - if collector.__setattr__implemented() || collector.__delattr__implemented() { - unsafe extern "C" fn __wrap( - _slf: *mut $crate::ffi::PyObject, - attr: *mut $crate::ffi::PyObject, - value: *mut $crate::ffi::PyObject, - ) -> ::std::os::raw::c_int { - $crate::callback::handle_panic(|py| { - let collector = PyClassImplCollector::<$cls>::new(); - $crate::callback::convert(py, { - if let Some(value) = ::std::ptr::NonNull::new(value) { - collector.__setattr__(_slf, attr, value) - } else { - collector.__delattr__(_slf, attr) - } - }) - }) +macro_rules! slot_fragment_trait { + ($trait_name:ident, $implemented_name:ident, $($default_method:tt)*) => { + #[allow(non_camel_case_types)] + pub trait $trait_name: Sized { + #[inline] + #[allow(non_snake_case)] + fn $implemented_name(self) -> bool { + false } - Some($crate::ffi::PyType_Slot { - slot: $crate::ffi::Py_tp_setattro, - pfunc: __wrap as $crate::ffi::setattrofunc as _, - }) - } else { - None + + $($default_method)* } - }}; -} -#[allow(non_camel_case_types)] -pub trait PyClass__set__SlotFragment: Sized { - #[inline] - #[allow(non_snake_case)] - fn __set__implemented(self) -> bool { - false - } - - #[inline] - unsafe fn __set__( - self, - _slf: *mut ffi::PyObject, - _attr: *mut ffi::PyObject, - _value: NonNull, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err("can't set descriptor")) + impl $trait_name for &'_ PyClassImplCollector {} } } -impl PyClass__set__SlotFragment for &'_ PyClassImplCollector {} +/// Macro which expands to three items +/// - Trait for a __setitem__ dunder +/// - Trait for the corresponding __delitem__ dunder +/// - A macro which will use dtolnay specialisation to generate the shared slot for the two dunders +macro_rules! define_pyclass_setattr_slot { + ( + $set_trait:ident, + $del_trait:ident, + $set_implemented:ident, + $del_implemented:ident, + $set:ident, + $del:ident, + $set_error:expr, + $del_error:expr, + $generate_macro:ident, + $slot:ident, + $func_ty:ident, + ) => { + slot_fragment_trait! { + $set_trait, + $set_implemented, -#[allow(non_camel_case_types)] -pub trait PyClass__delete__SlotFragment: Sized { - #[allow(non_snake_case)] - #[inline] - fn __delete__implemented(self) -> bool { - false - } - - #[inline] - unsafe fn __delete__( - self, - _slf: *mut ffi::PyObject, - _attr: *mut ffi::PyObject, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err("can't delete descriptor")) - } -} - -impl PyClass__delete__SlotFragment for &'_ PyClassImplCollector {} - -#[doc(hidden)] -#[macro_export] -macro_rules! generate_pyclass_setdescr_slot { - ($cls:ty) => {{ - use ::std::option::Option::*; - use $crate::class::impl_::*; - let collector = PyClassImplCollector::<$cls>::new(); - if collector.__set__implemented() || collector.__delete__implemented() { - unsafe extern "C" fn __wrap( - _slf: *mut $crate::ffi::PyObject, - attr: *mut $crate::ffi::PyObject, - value: *mut $crate::ffi::PyObject, - ) -> ::std::os::raw::c_int { - $crate::callback::handle_panic(|py| { - let collector = PyClassImplCollector::<$cls>::new(); - $crate::callback::convert(py, { - if let Some(value) = ::std::ptr::NonNull::new(value) { - collector.__set__(_slf, attr, value) - } else { - collector.__delete__(_slf, attr) - } - }) - }) + /// # Safety: _slf and _attr must be valid non-null Python objects + #[inline] + unsafe fn $set( + self, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + _value: NonNull, + ) -> PyResult<()> { + $set_error } - Some($crate::ffi::PyType_Slot { - slot: $crate::ffi::Py_tp_descr_set, - pfunc: __wrap as $crate::ffi::descrsetfunc as _, - }) - } else { - None } - }}; -} -#[allow(non_camel_case_types)] -pub trait PyClass__setitem__SlotFragment: Sized { - #[inline] - #[allow(non_snake_case)] - fn __setitem__implemented(self) -> bool { - false - } + slot_fragment_trait! { + $del_trait, + $del_implemented, - #[inline] - unsafe fn __setitem__( - self, - _slf: *mut ffi::PyObject, - _attr: *mut ffi::PyObject, - _value: NonNull, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err("can't set item")) - } -} - -impl PyClass__setitem__SlotFragment for &'_ PyClassImplCollector {} - -#[allow(non_camel_case_types)] -pub trait PyClass__delitem__SlotFragment: Sized { - #[allow(non_snake_case)] - #[inline] - fn __delitem__implemented(self) -> bool { - false - } - - #[inline] - unsafe fn __delitem__( - self, - _slf: *mut ffi::PyObject, - _attr: *mut ffi::PyObject, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err("can't delete item")) - } -} - -impl PyClass__delitem__SlotFragment for &'_ PyClassImplCollector {} - -#[doc(hidden)] -#[macro_export] -macro_rules! generate_pyclass_setitem_slot { - ($cls:ty) => {{ - use ::std::option::Option::*; - use $crate::class::impl_::*; - let collector = PyClassImplCollector::<$cls>::new(); - if collector.__setitem__implemented() || collector.__delitem__implemented() { - unsafe extern "C" fn __wrap( - _slf: *mut $crate::ffi::PyObject, - attr: *mut $crate::ffi::PyObject, - value: *mut $crate::ffi::PyObject, - ) -> ::std::os::raw::c_int { - $crate::callback::handle_panic(|py| { - let collector = PyClassImplCollector::<$cls>::new(); - $crate::callback::convert(py, { - if let Some(value) = ::std::ptr::NonNull::new(value) { - collector.__setitem__(_slf, attr, value) - } else { - collector.__delitem__(_slf, attr) - } - }) - }) + /// # Safety: _slf and _attr must be valid non-null Python objects + #[inline] + unsafe fn $del( + self, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + ) -> PyResult<()> { + $del_error } - Some($crate::ffi::PyType_Slot { - slot: $crate::ffi::Py_mp_ass_subscript, - pfunc: __wrap as $crate::ffi::objobjargproc as _, - }) - } else { - None } - }}; + + #[doc(hidden)] + #[macro_export] + macro_rules! $generate_macro { + ($cls:ty) => {{ + use ::std::option::Option::*; + use $crate::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + if collector.$set_implemented() || collector.$del_implemented() { + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + $crate::callback::handle_panic(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + $crate::callback::convert(py, { + if let Some(value) = ::std::ptr::NonNull::new(value) { + collector.$set(_slf, attr, value) + } else { + collector.$del(_slf, attr) + } + }) + }) + } + Some($crate::ffi::PyType_Slot { + slot: $crate::ffi::$slot, + pfunc: __wrap as $crate::ffi::$func_ty as _, + }) + } else { + None + } + }}; + } + }; +} + +define_pyclass_setattr_slot! { + PyClass__setattr__SlotFragment, + PyClass__delattr__SlotFragment, + __setattr__implemented, + __delattr__implemented, + __setattr__, + __delattr__, + Err(PyAttributeError::new_err("can't set attribute")), + Err(PyAttributeError::new_err("can't delete attribute")), + generate_pyclass_setattr_slot, + Py_tp_setattro, + setattrofunc, +} + +define_pyclass_setattr_slot! { + PyClass__set__SlotFragment, + PyClass__delete__SlotFragment, + __set__implemented, + __delete__implemented, + __set__, + __delete__, + Err(PyNotImplementedError::new_err("can't set descriptor")), + Err(PyNotImplementedError::new_err("can't delete descriptor")), + generate_pyclass_setdescr_slot, + Py_tp_descr_set, + descrsetfunc, +} + +define_pyclass_setattr_slot! { + PyClass__setitem__SlotFragment, + PyClass__delitem__SlotFragment, + __setitem__implemented, + __delitem__implemented, + __setitem__, + __delitem__, + Err(PyNotImplementedError::new_err("can't set item")), + Err(PyNotImplementedError::new_err("can't delete item")), + generate_pyclass_setitem_slot, + Py_mp_ass_subscript, + objobjargproc, } pub trait PyClassAllocImpl { From 43eb762346896afeb10281061b7268a6632d5df8 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 18 Sep 2021 00:31:17 +0100 Subject: [PATCH 08/12] pymethods: support most numerical methods --- pyo3-macros-backend/src/method.rs | 32 +++- pyo3-macros-backend/src/pyclass.rs | 33 ++++ pyo3-macros-backend/src/pymethod.rs | 172 ++++++++++++--------- src/class/impl_.rs | 232 +++++++++++++++++++++++++++- tests/test_proto_methods.rs | 1 + 5 files changed, 379 insertions(+), 91 deletions(-) diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 71b4c5bc..3646e360 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -93,10 +93,10 @@ pub enum FnType { } impl FnType { - pub fn self_conversion(&self, cls: Option<&syn::Type>) -> TokenStream { + pub fn self_conversion(&self, cls: Option<&syn::Type>, error_mode: ExtractErrorMode) -> TokenStream { match self { FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => { - st.receiver(cls.expect("no class given for Fn with a \"self\" receiver")) + st.receiver(cls.expect("no class given for Fn with a \"self\" receiver"), error_mode) } FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => { quote!() @@ -128,26 +128,44 @@ pub enum SelfType { TryFromPyCell(Span), } +pub enum ExtractErrorMode { + NotImplemented, + Raise, +} + impl SelfType { - pub fn receiver(&self, cls: &syn::Type) -> TokenStream { + pub fn receiver(&self, cls: &syn::Type, error_mode: ExtractErrorMode) -> TokenStream { + let cell = match error_mode { + ExtractErrorMode::Raise => { + quote! { _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>()? } + }, + ExtractErrorMode::NotImplemented => { + quote! { + match _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>() { + ::std::result::Result::Ok(cell) => cell, + ::std::result::Result::Err(_) => return ::pyo3::callback::convert(_py, _py.NotImplemented()), + } + } + }, + }; match self { SelfType::Receiver { mutable: false } => { quote! { - let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf); + let _cell = #cell; let _ref = _cell.try_borrow()?; let _slf = &_ref; } } SelfType::Receiver { mutable: true } => { quote! { - let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf); + let _cell = #cell; let mut _ref = _cell.try_borrow_mut()?; let _slf = &mut _ref; } } SelfType::TryFromPyCell(span) => { quote_spanned! { *span => - let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf); + let _cell = #cell; #[allow(clippy::useless_conversion)] // In case _slf is PyCell let _slf = std::convert::TryFrom::try_from(_cell)?; } @@ -442,7 +460,7 @@ impl<'a> FnSpec<'a> { cls: Option<&syn::Type>, ) -> Result { let deprecations = &self.deprecations; - let self_conversion = self.tp.self_conversion(cls); + let self_conversion = self.tp.self_conversion(cls, ExtractErrorMode::Raise); let self_arg = self.tp.self_arg(); let arg_names = (0..self.args.len()) .map(|pos| syn::Ident::new(&format!("arg{}", pos), Span::call_site())) diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 98ec071a..adb87ab1 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -602,6 +602,39 @@ fn impl_class( if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setitem_slot!(#cls) { generated_slots.push(setdescr); } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_add_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_sub_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_mul_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_mod_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_divmod_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_lshift_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_rshift_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_and_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_or_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_xor_slot!(#cls) { + generated_slots.push(setdescr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_matmul_slot!(#cls) { + generated_slots.push(setdescr); + } visitor(&generated_slots); } diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index e6099712..b6508e58 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -3,6 +3,7 @@ use std::borrow::Cow; use crate::attributes::NameAttribute; +use crate::method::ExtractErrorMode; use crate::utils::{ensure_not_async_fn, unwrap_ty_group, PythonDoc}; use crate::{deprecations::Deprecations, utils}; use crate::{ @@ -31,12 +32,15 @@ pub fn gen_py_method( ensure_function_options_valid(&options)?; let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?; - if let Some(slot_def) = pyproto(&spec.python_name.to_string()) { + let method_name = spec.python_name.to_string(); + + if let Some(slot_def) = pyproto(&method_name) { let slot = slot_def.generate_type_slot(cls, &spec)?; return Ok(GeneratedPyMethod::Proto(slot)); } - if let Some(proto) = pyproto_fragment(cls, &spec)? { + if let Some(slot_fragment_def) = pyproto_fragment(&method_name) { + let proto = slot_fragment_def.generate_pyproto_fragment(cls, &spec)?; return Ok(GeneratedPyMethod::TraitImpl(proto)); } @@ -212,8 +216,8 @@ pub fn impl_py_setter_def(cls: &syn::Type, property_type: PropertyType) -> Resul }; let slf = match property_type { - PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: true }.receiver(cls), - PropertyType::Function { self_type, .. } => self_type.receiver(cls), + PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: true }.receiver(cls, ExtractErrorMode::Raise), + PropertyType::Function { self_type, .. } => self_type.receiver(cls, ExtractErrorMode::Raise), }; Ok(quote! { ::pyo3::class::PyMethodDefType::Setter({ @@ -288,8 +292,8 @@ pub fn impl_py_getter_def(cls: &syn::Type, property_type: PropertyType) -> Resul }; let slf = match property_type { - PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: false }.receiver(cls), - PropertyType::Function { self_type, .. } => self_type.receiver(cls), + PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: false }.receiver(cls, ExtractErrorMode::Raise), + PropertyType::Function { self_type, .. } => self_type.receiver(cls, ExtractErrorMode::Raise), }; Ok(quote! { ::pyo3::class::PyMethodDefType::Getter({ @@ -515,6 +519,7 @@ enum Ty { Int, PyHashT, PySsizeT, + Void, } impl Ty { @@ -525,6 +530,7 @@ impl Ty { Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int }, Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, Ty::PySsizeT => quote! { ::pyo3::ffi::Py_ssize_t }, + Ty::Void => quote! { () }, } } @@ -577,7 +583,7 @@ impl Ty { let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident) .ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?; }, - Ty::Int | Ty::PyHashT | Ty::PySsizeT => todo!(), + Ty::Int | Ty::PyHashT | Ty::PySsizeT | Ty::Void => todo!(), } } } @@ -705,7 +711,7 @@ impl SlotDef { let py = syn::Ident::new("_py", Span::call_site()); let method_arguments = generate_method_arguments(arguments); let ret_ty = ret_ty.ffi_type(); - let body = generate_method_body(cls, spec, &py, arguments, return_mode.as_ref())?; + let body = generate_method_body(cls, spec, &py, arguments, ExtractErrorMode::Raise, return_mode.as_ref())?; Ok(quote!({ unsafe extern "C" fn __wrap(_raw_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty { let _slf = _raw_slf; @@ -737,9 +743,10 @@ fn generate_method_body( spec: &FnSpec, py: &syn::Ident, arguments: &[Ty], + extract_error_mode: ExtractErrorMode, return_mode: Option<&ReturnMode>, ) -> Result { - let self_conversion = spec.tp.self_conversion(Some(cls)); + let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode); let rust_name = spec.name; let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments)?; let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; @@ -755,78 +762,89 @@ fn generate_method_body( }) } -fn generate_pyproto_fragment( - cls: &syn::Type, - spec: &FnSpec, - fragment: &str, - arguments: &[Ty], -) -> Result { - let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment); - let implemented = format_ident!("{}implemented", fragment); - let method = syn::Ident::new(fragment, Span::call_site()); - let py = syn::Ident::new("_py", Span::call_site()); - let method_arguments = generate_method_arguments(arguments); - let body = generate_method_body(cls, spec, &py, arguments, None)?; - Ok(quote! { - impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - #[inline] - fn #implemented(self) -> bool { true } - - #[inline] - unsafe fn #method( - self, - _raw_slf: *mut ::pyo3::ffi::PyObject, - #(#method_arguments),* - ) -> ::pyo3::PyResult<()> { - let _slf = _raw_slf; - let #py = ::pyo3::Python::assume_gil_acquired(); - #body - } - } - }) +struct SlotFragmentDef { + fragment: &'static str, + arguments: &'static [Ty], + ret_ty: Ty, } -fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result> { - match spec.python_name.to_string().as_str() { - "__setattr__" => Some(generate_pyproto_fragment( - cls, - spec, - "__setattr__", - &[Ty::Object, Ty::NonNullObject], - )), - "__delattr__" => Some(generate_pyproto_fragment( - cls, - spec, - "__delattr__", - &[Ty::Object], - )), - "__set__" => Some(generate_pyproto_fragment( - cls, - spec, - "__set__", - &[Ty::Object, Ty::NonNullObject], - )), - "__delete__" => Some(generate_pyproto_fragment( - cls, - spec, - "__delete__", - &[Ty::Object], - )), - "__setitem__" => Some(generate_pyproto_fragment( - cls, - spec, - "__setitem__", - &[Ty::Object, Ty::NonNullObject], - )), - "__delitem__" => Some(generate_pyproto_fragment( - cls, - spec, - "__delitem__", - &[Ty::Object], - )), +impl SlotFragmentDef { + const fn new(fragment: &'static str, arguments: &'static [Ty]) -> Self { + SlotFragmentDef { + fragment, + arguments, + ret_ty: Ty::Void, + } + } + + const fn ret_ty(mut self, ret_ty: Ty) -> Self { + self.ret_ty = ret_ty; + self + } + + fn generate_pyproto_fragment(&self, cls: &syn::Type, spec: &FnSpec) -> Result { + let SlotFragmentDef { + fragment, + arguments, + ret_ty, + } = self; + let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment); + let implemented = format_ident!("{}implemented", fragment); + let method = syn::Ident::new(fragment, Span::call_site()); + let py = syn::Ident::new("_py", Span::call_site()); + let method_arguments = generate_method_arguments(arguments); + let body = generate_method_body(cls, spec, &py, arguments, ExtractErrorMode::NotImplemented, None)?; + let ret_ty = ret_ty.ffi_type(); + Ok(quote! { + impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + #[inline] + fn #implemented(self) -> bool { true } + + #[inline] + unsafe fn #method( + self, + #py: ::pyo3::Python, + _raw_slf: *mut ::pyo3::ffi::PyObject, + #(#method_arguments),* + ) -> ::pyo3::PyResult<#ret_ty> { + let _slf = _raw_slf; + #body + } + } + }) + } +} + +const __SETATTR__: SlotFragmentDef = + SlotFragmentDef::new("__setattr__", &[Ty::Object, Ty::NonNullObject]); +const __DELATTR__: SlotFragmentDef = + SlotFragmentDef::new("__delattr__", &[Ty::Object]); +const __SET__: SlotFragmentDef = + SlotFragmentDef::new("__set__", &[Ty::Object, Ty::NonNullObject]); +const __DELETE__: SlotFragmentDef = + SlotFragmentDef::new("__delete__", &[Ty::Object]); +const __SETITEM__: SlotFragmentDef = + SlotFragmentDef::new("__setitem__", &[Ty::Object, Ty::NonNullObject]); +const __DELITEM__: SlotFragmentDef = + SlotFragmentDef::new("__delitem__", &[Ty::Object]); + +const __ADD__: SlotFragmentDef = + SlotFragmentDef::new("__add__", &[Ty::ObjectOrNotImplemented]).ret_ty(Ty::Object); +const __RADD__: SlotFragmentDef = + SlotFragmentDef::new("__radd__", &[Ty::ObjectOrNotImplemented]).ret_ty(Ty::Object); + +fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { + match method_name { + "__setattr__" => Some(&__SETATTR__), + "__delattr__" => Some(&__DELATTR__), + "__set__" => Some(&__SET__), + "__delete__" => Some(&__DELETE__), + "__setitem__" => Some(&__SETITEM__), + "__delitem__" => Some(&__DELITEM__), + "__add__" => Some(&__ADD__), + "__radd__" => Some(&__RADD__), _ => None, } - .transpose() } fn extract_proto_arguments( diff --git a/src/class/impl_.rs b/src/class/impl_.rs index ad2515c6..d2fe3814 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -151,6 +151,7 @@ macro_rules! define_pyclass_setattr_slot { #[inline] unsafe fn $set( self, + _py: Python, _slf: *mut ffi::PyObject, _attr: *mut ffi::PyObject, _value: NonNull, @@ -167,6 +168,7 @@ macro_rules! define_pyclass_setattr_slot { #[inline] unsafe fn $del( self, + _py: Python, _slf: *mut ffi::PyObject, _attr: *mut ffi::PyObject, ) -> PyResult<()> { @@ -187,15 +189,14 @@ macro_rules! define_pyclass_setattr_slot { attr: *mut $crate::ffi::PyObject, value: *mut $crate::ffi::PyObject, ) -> ::std::os::raw::c_int { + use $crate::callback::IntoPyCallbackOutput; $crate::callback::handle_panic(|py| { let collector = PyClassImplCollector::<$cls>::new(); - $crate::callback::convert(py, { - if let Some(value) = ::std::ptr::NonNull::new(value) { - collector.$set(_slf, attr, value) - } else { - collector.$del(_slf, attr) - } - }) + if let Some(value) = ::std::ptr::NonNull::new(value) { + collector.$set(py, _slf, attr, value).convert(py) + } else { + collector.$del(py, _slf, attr).convert(py) + } }) } Some($crate::ffi::PyType_Slot { @@ -252,6 +253,223 @@ define_pyclass_setattr_slot! { objobjargproc, } +/// Macro which expands to three items +/// - Trait for a lhs dunder e.g. __add__ +/// - Trait for the corresponding rhs e.g. __radd__ +/// - A macro which will use dtolnay specialisation to generate the shared slot for the two dunders +macro_rules! define_pyclass_binary_operator_slot { + ( + $lhs_trait:ident, + $rhs_trait:ident, + $lhs_implemented:ident, + $rhs_implemented:ident, + $lhs:ident, + $rhs:ident, + $generate_macro:ident, + $slot:ident, + $func_ty:ident, + ) => { + slot_fragment_trait! { + $lhs_trait, + $lhs_implemented, + + /// # Safety: _slf and _attr must be valid non-null Python objects + #[inline] + unsafe fn $lhs( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } + } + + slot_fragment_trait! { + $rhs_trait, + $rhs_implemented, + + /// # Safety: _slf and _attr must be valid non-null Python objects + #[inline] + unsafe fn $rhs( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } + } + + #[doc(hidden)] + #[macro_export] + macro_rules! $generate_macro { + ($cls:ty) => {{ + use ::std::option::Option::*; + use $crate::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + if collector.$lhs_implemented() || collector.$rhs_implemented() { + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + _other: *mut $crate::ffi::PyObject, + ) -> *mut $crate::ffi::PyObject { + $crate::callback::handle_panic(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + let lhs_result = collector.$lhs(py, _slf, _other)?; + if lhs_result == $crate::ffi::Py_NotImplemented() { + $crate::ffi::Py_DECREF(lhs_result); + collector.$rhs(py, _other, _slf) + } else { + ::std::result::Result::Ok(lhs_result) + } + }) + } + Some($crate::ffi::PyType_Slot { + slot: $crate::ffi::$slot, + pfunc: __wrap as $crate::ffi::$func_ty as _, + }) + } else { + None + } + }}; + } + }; +} + +define_pyclass_binary_operator_slot! { + PyClass__add__SlotFragment, + PyClass__radd__SlotFragment, + __add__implemented, + __radd__implemented, + __add__, + __radd__, + generate_pyclass_add_slot, + Py_nb_add, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__sub__SlotFragment, + PyClass__rsub__SlotFragment, + __sub__implemented, + __rsub__implemented, + __sub__, + __rsub__, + generate_pyclass_sub_slot, + Py_nb_subtract, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__mul__SlotFragment, + PyClass__rmul__SlotFragment, + __mul__implemented, + __rmul__implemented, + __mul__, + __rmul__, + generate_pyclass_mul_slot, + Py_nb_multiply, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__mod__SlotFragment, + PyClass__rmod__SlotFragment, + __mod__implemented, + __rmod__implemented, + __mod__, + __rmod__, + generate_pyclass_mod_slot, + Py_nb_remainder, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__divmod__SlotFragment, + PyClass__rdivmod__SlotFragment, + __divmod__implemented, + __rdivmod__implemented, + __divmod__, + __rdivmod__, + generate_pyclass_divmod_slot, + Py_nb_divmod, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__lshift__SlotFragment, + PyClass__rlshift__SlotFragment, + __lshift__implemented, + __rlshift__implemented, + __lshift__, + __rlshift__, + generate_pyclass_lshift_slot, + Py_nb_lshift, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__rshift__SlotFragment, + PyClass__rrshift__SlotFragment, + __rshift__implemented, + __rrshift__implemented, + __rshift__, + __rrshift__, + generate_pyclass_rshift_slot, + Py_nb_rshift, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__and__SlotFragment, + PyClass__rand__SlotFragment, + __and__implemented, + __rand__implemented, + __and__, + __rand__, + generate_pyclass_and_slot, + Py_nb_and, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__or__SlotFragment, + PyClass__ror__SlotFragment, + __or__implemented, + __ror__implemented, + __or__, + __ror__, + generate_pyclass_or_slot, + Py_nb_or, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__xor__SlotFragment, + PyClass__rxor__SlotFragment, + __xor__implemented, + __rxor__implemented, + __xor__, + __rxor__, + generate_pyclass_xor_slot, + Py_nb_xor, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__matmul__SlotFragment, + PyClass__rmatmul__SlotFragment, + __matmul__implemented, + __rmatmul__implemented, + __matmul__, + __rmatmul__, + generate_pyclass_matmul_slot, + Py_nb_matrix_multiply, + binaryfunc, +} + pub trait PyClassAllocImpl { fn alloc_impl(self) -> Option; } diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index afcbca0d..50f1e4a4 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -551,4 +551,5 @@ assert c.counter.count == 3 // TODO: test __delete__ // TODO: test __anext__, __aiter__ // TODO: test __index__, __int__, __float__, __invert__ +// TODO: __floordiv__, __truediv__ // TODO: better argument casting errors From c2d78ca76e139f919f1f390483d74f22cde3ca82 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 18 Sep 2021 09:49:05 +0100 Subject: [PATCH 09/12] pymethods: faster compilation for protos, tidy ups --- pyo3-macros-backend/src/method.rs | 19 ++-- pyo3-macros-backend/src/pyclass.rs | 44 --------- pyo3-macros-backend/src/pyimpl.rs | 45 +++++++++- pyo3-macros-backend/src/pymethod.rs | 133 +++++++++++++++++++++++----- src/class/impl_.rs | 129 ++++++++------------------- tests/test_arithmetics.rs | 4 +- tests/test_proto_methods.rs | 14 ++- 7 files changed, 212 insertions(+), 176 deletions(-) diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 3646e360..27b9c612 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -93,11 +93,17 @@ pub enum FnType { } impl FnType { - pub fn self_conversion(&self, cls: Option<&syn::Type>, error_mode: ExtractErrorMode) -> TokenStream { + pub fn self_conversion( + &self, + cls: Option<&syn::Type>, + error_mode: ExtractErrorMode, + ) -> TokenStream { match self { - FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => { - st.receiver(cls.expect("no class given for Fn with a \"self\" receiver"), error_mode) - } + FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => st + .receiver( + cls.expect("no class given for Fn with a \"self\" receiver"), + error_mode, + ), FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => { quote!() } @@ -128,6 +134,7 @@ pub enum SelfType { TryFromPyCell(Span), } +#[derive(Clone, Copy)] pub enum ExtractErrorMode { NotImplemented, Raise, @@ -138,7 +145,7 @@ impl SelfType { let cell = match error_mode { ExtractErrorMode::Raise => { quote! { _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>()? } - }, + } ExtractErrorMode::NotImplemented => { quote! { match _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>() { @@ -146,7 +153,7 @@ impl SelfType { ::std::result::Result::Err(_) => return ::pyo3::callback::convert(_py, _py.NotImplemented()), } } - }, + } }; match self { SelfType::Receiver { mutable: false } => { diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index adb87ab1..baf0041a 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -592,50 +592,6 @@ fn impl_class( visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); visitor(collector.methods_protocol_slots()); - let mut generated_slots = ::std::vec::Vec::new(); - if let ::std::option::Option::Some(setattr) = ::pyo3::generate_pyclass_setattr_slot!(#cls) { - generated_slots.push(setattr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setdescr_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setitem_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_add_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_sub_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_mul_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_mod_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_divmod_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_lshift_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_rshift_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_and_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_or_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_xor_slot!(#cls) { - generated_slots.push(setdescr); - } - if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_matmul_slot!(#cls) { - generated_slots.push(setdescr); - } - visitor(&generated_slots); } fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> { diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 5d2dd881..4fc1c7da 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -1,5 +1,7 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use std::collections::HashSet; + use crate::{ konst::{ConstAttributes, ConstSpec}, pyfunction::PyFunctionOptions, @@ -40,6 +42,9 @@ pub fn impl_methods( let mut trait_impls = Vec::new(); let mut proto_impls = Vec::new(); let mut methods = Vec::new(); + + let mut implemented_proto_fragments = HashSet::new(); + for iimpl in impls.iter_mut() { match iimpl { syn::ImplItem::Method(meth) => { @@ -53,6 +58,11 @@ pub fn impl_methods( let attrs = get_cfg_attributes(&meth.attrs); trait_impls.push(quote!(#(#attrs)* #token_stream)); } + GeneratedPyMethod::SlotTraitImpl(method_name, token_stream) => { + implemented_proto_fragments.insert(method_name); + let attrs = get_cfg_attributes(&meth.attrs); + trait_impls.push(quote!(#(#attrs)* #token_stream)); + } GeneratedPyMethod::Proto(token_stream) => { let attrs = get_cfg_attributes(&meth.attrs); proto_impls.push(quote!(#(#attrs)* #token_stream)) @@ -81,7 +91,9 @@ pub fn impl_methods( }; let protos_registration = match methods_type { - PyClassMethodsType::Specialization => Some(impl_protos(ty, proto_impls)), + PyClassMethodsType::Specialization => { + Some(impl_protos(ty, proto_impls, implemented_proto_fragments)) + } PyClassMethodsType::Inventory => { if proto_impls.is_empty() { None @@ -135,7 +147,36 @@ fn impl_py_methods(ty: &syn::Type, methods: Vec) -> TokenStream { } } -fn impl_protos(ty: &syn::Type, proto_impls: Vec) -> TokenStream { +fn impl_protos( + ty: &syn::Type, + mut proto_impls: Vec, + mut implemented_proto_fragments: HashSet, +) -> TokenStream { + macro_rules! try_add_shared_slot { + ($first:literal, $second:literal, $slot:ident) => {{ + let first_implemented = implemented_proto_fragments.remove($first); + let second_implemented = implemented_proto_fragments.remove($second); + if first_implemented || second_implemented { + proto_impls.push(quote! { ::pyo3::$slot!(#ty) }) + } + }}; + } + + try_add_shared_slot!("__setattr__", "__delattr__", generate_pyclass_setattr_slot); + try_add_shared_slot!("__set__", "__delete__", generate_pyclass_setdescr_slot); + try_add_shared_slot!("__setitem__", "__delitem__", generate_pyclass_setitem_slot); + try_add_shared_slot!("__add__", "__radd__", generate_pyclass_add_slot); + try_add_shared_slot!("__sub__", "__rsub__", generate_pyclass_sub_slot); + try_add_shared_slot!("__mul__", "__rmul__", generate_pyclass_mul_slot); + try_add_shared_slot!("__mod__", "__rmod__", generate_pyclass_mod_slot); + try_add_shared_slot!("__divmod__", "__rdivmod__", generate_pyclass_divmod_slot); + try_add_shared_slot!("__lshift__", "__rlshift__", generate_pyclass_lshift_slot); + try_add_shared_slot!("__rshift__", "__rrshift__", generate_pyclass_rshift_slot); + try_add_shared_slot!("__and__", "__rand__", generate_pyclass_and_slot); + try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot); + try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot); + try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot); + quote! { impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty> for ::pyo3::class::impl_::PyClassImplCollector<#ty> diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index b6508e58..e0698262 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -19,6 +19,7 @@ pub enum GeneratedPyMethod { Method(TokenStream), Proto(TokenStream), TraitImpl(TokenStream), + SlotTraitImpl(String, TokenStream), } pub fn gen_py_method( @@ -41,7 +42,7 @@ pub fn gen_py_method( if let Some(slot_fragment_def) = pyproto_fragment(&method_name) { let proto = slot_fragment_def.generate_pyproto_fragment(cls, &spec)?; - return Ok(GeneratedPyMethod::TraitImpl(proto)); + return Ok(GeneratedPyMethod::SlotTraitImpl(method_name, proto)); } Ok(match &spec.tp { @@ -216,8 +217,12 @@ pub fn impl_py_setter_def(cls: &syn::Type, property_type: PropertyType) -> Resul }; let slf = match property_type { - PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: true }.receiver(cls, ExtractErrorMode::Raise), - PropertyType::Function { self_type, .. } => self_type.receiver(cls, ExtractErrorMode::Raise), + PropertyType::Descriptor { .. } => { + SelfType::Receiver { mutable: true }.receiver(cls, ExtractErrorMode::Raise) + } + PropertyType::Function { self_type, .. } => { + self_type.receiver(cls, ExtractErrorMode::Raise) + } }; Ok(quote! { ::pyo3::class::PyMethodDefType::Setter({ @@ -292,8 +297,12 @@ pub fn impl_py_getter_def(cls: &syn::Type, property_type: PropertyType) -> Resul }; let slf = match property_type { - PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: false }.receiver(cls, ExtractErrorMode::Raise), - PropertyType::Function { self_type, .. } => self_type.receiver(cls, ExtractErrorMode::Raise), + PropertyType::Descriptor { .. } => { + SelfType::Receiver { mutable: false }.receiver(cls, ExtractErrorMode::Raise) + } + PropertyType::Function { self_type, .. } => { + self_type.receiver(cls, ExtractErrorMode::Raise) + } }; Ok(quote! { ::pyo3::class::PyMethodDefType::Getter({ @@ -401,8 +410,9 @@ const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc") .return_conversion(TokenGenerator( || quote! { ::pyo3::callback::HashCallbackOutput }, )); -const __RICHCMP__: SlotDef = - SlotDef::new("Py_tp_richcompare", "richcmpfunc").arguments(&[Ty::Object, Ty::CompareOp]); +const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc") + .extract_error_mode(ExtractErrorMode::NotImplemented) + .arguments(&[Ty::ObjectOrNotImplemented, Ty::CompareOp]); const __GET__: SlotDef = SlotDef::new("Py_tp_descr_get", "descrgetfunc").arguments(&[Ty::Object, Ty::Object]); const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc"); @@ -431,42 +441,55 @@ const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __ISUB__: SlotDef = SlotDef::new("Py_nb_inplace_subtract", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IMUL__: SlotDef = SlotDef::new("Py_nb_inplace_multiply", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IMATMUL__: SlotDef = SlotDef::new("Py_nb_inplace_matrix_multiply", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __ITRUEDIV__: SlotDef = SlotDef::new("Py_nb_inplace_true_divide", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IFLOORDIV__: SlotDef = SlotDef::new("Py_nb_inplace_floor_divide", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IMOD__: SlotDef = SlotDef::new("Py_nb_inplace_remainder", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "ternaryfunc") .arguments(&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __ILSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_lshift", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IRSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_rshift", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IAND__: SlotDef = SlotDef::new("Py_nb_inplace_and", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IXOR__: SlotDef = SlotDef::new("Py_nb_inplace_xor", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); const __IOR__: SlotDef = SlotDef::new("Py_nb_inplace_or", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); fn pyproto(method_name: &str) -> Option<&'static SlotDef> { @@ -659,6 +682,7 @@ struct SlotDef { arguments: &'static [Ty], ret_ty: Ty, before_call_method: Option, + extract_error_mode: ExtractErrorMode, return_mode: Option, } @@ -670,6 +694,7 @@ impl SlotDef { arguments: &[], ret_ty: Ty::Object, before_call_method: None, + extract_error_mode: ExtractErrorMode::Raise, return_mode: None, } } @@ -694,6 +719,11 @@ impl SlotDef { self } + const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { + self.extract_error_mode = extract_error_mode; + self + } + const fn return_self(mut self) -> Self { self.return_mode = Some(ReturnMode::ReturnSelf); self @@ -705,13 +735,21 @@ impl SlotDef { func_ty, before_call_method, arguments, + extract_error_mode, ret_ty, return_mode, } = self; let py = syn::Ident::new("_py", Span::call_site()); let method_arguments = generate_method_arguments(arguments); let ret_ty = ret_ty.ffi_type(); - let body = generate_method_body(cls, spec, &py, arguments, ExtractErrorMode::Raise, return_mode.as_ref())?; + let body = generate_method_body( + cls, + spec, + &py, + arguments, + *extract_error_mode, + return_mode.as_ref(), + )?; Ok(quote!({ unsafe extern "C" fn __wrap(_raw_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty { let _slf = _raw_slf; @@ -765,6 +803,7 @@ fn generate_method_body( struct SlotFragmentDef { fragment: &'static str, arguments: &'static [Ty], + extract_error_mode: ExtractErrorMode, ret_ty: Ty, } @@ -773,10 +812,16 @@ impl SlotFragmentDef { SlotFragmentDef { fragment, arguments, + extract_error_mode: ExtractErrorMode::Raise, ret_ty: Ty::Void, } } + const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { + self.extract_error_mode = extract_error_mode; + self + } + const fn ret_ty(mut self, ret_ty: Ty) -> Self { self.ret_ty = ret_ty; self @@ -786,19 +831,17 @@ impl SlotFragmentDef { let SlotFragmentDef { fragment, arguments, + extract_error_mode, ret_ty, } = self; let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment); - let implemented = format_ident!("{}implemented", fragment); let method = syn::Ident::new(fragment, Span::call_site()); let py = syn::Ident::new("_py", Span::call_site()); let method_arguments = generate_method_arguments(arguments); - let body = generate_method_body(cls, spec, &py, arguments, ExtractErrorMode::NotImplemented, None)?; + let body = generate_method_body(cls, spec, &py, arguments, *extract_error_mode, None)?; let ret_ty = ret_ty.ffi_type(); Ok(quote! { impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - #[inline] - fn #implemented(self) -> bool { true } #[inline] unsafe fn #method( @@ -817,21 +860,43 @@ impl SlotFragmentDef { const __SETATTR__: SlotFragmentDef = SlotFragmentDef::new("__setattr__", &[Ty::Object, Ty::NonNullObject]); -const __DELATTR__: SlotFragmentDef = - SlotFragmentDef::new("__delattr__", &[Ty::Object]); -const __SET__: SlotFragmentDef = - SlotFragmentDef::new("__set__", &[Ty::Object, Ty::NonNullObject]); -const __DELETE__: SlotFragmentDef = - SlotFragmentDef::new("__delete__", &[Ty::Object]); +const __DELATTR__: SlotFragmentDef = SlotFragmentDef::new("__delattr__", &[Ty::Object]); +const __SET__: SlotFragmentDef = SlotFragmentDef::new("__set__", &[Ty::Object, Ty::NonNullObject]); +const __DELETE__: SlotFragmentDef = SlotFragmentDef::new("__delete__", &[Ty::Object]); const __SETITEM__: SlotFragmentDef = SlotFragmentDef::new("__setitem__", &[Ty::Object, Ty::NonNullObject]); -const __DELITEM__: SlotFragmentDef = - SlotFragmentDef::new("__delitem__", &[Ty::Object]); +const __DELITEM__: SlotFragmentDef = SlotFragmentDef::new("__delitem__", &[Ty::Object]); -const __ADD__: SlotFragmentDef = - SlotFragmentDef::new("__add__", &[Ty::ObjectOrNotImplemented]).ret_ty(Ty::Object); -const __RADD__: SlotFragmentDef = - SlotFragmentDef::new("__radd__", &[Ty::ObjectOrNotImplemented]).ret_ty(Ty::Object); +macro_rules! binary_num_slot_fragment_def { + ($ident:ident, $name:literal) => { + const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); + }; +} + +binary_num_slot_fragment_def!(__ADD__, "__add__"); +binary_num_slot_fragment_def!(__RADD__, "__radd__"); +binary_num_slot_fragment_def!(__SUB__, "__sub__"); +binary_num_slot_fragment_def!(__RSUB__, "__rsub__"); +binary_num_slot_fragment_def!(__MUL__, "__mul__"); +binary_num_slot_fragment_def!(__RMUL__, "__rmul__"); +binary_num_slot_fragment_def!(__MATMUL__, "__matmul__"); +binary_num_slot_fragment_def!(__RMATMUL__, "__rmatmul__"); +binary_num_slot_fragment_def!(__DIVMOD__, "__divmod__"); +binary_num_slot_fragment_def!(__RDIVMOD__, "__rdivmod__"); +binary_num_slot_fragment_def!(__MOD__, "__mod__"); +binary_num_slot_fragment_def!(__RMOD__, "__rmod__"); +binary_num_slot_fragment_def!(__LSHIFT__, "__lshift__"); +binary_num_slot_fragment_def!(__RLSHIFT__, "__rlshift__"); +binary_num_slot_fragment_def!(__RSHIFT__, "__rshift__"); +binary_num_slot_fragment_def!(__RRSHIFT__, "__rrshift__"); +binary_num_slot_fragment_def!(__AND__, "__and__"); +binary_num_slot_fragment_def!(__RAND__, "__rand__"); +binary_num_slot_fragment_def!(__XOR__, "__xor__"); +binary_num_slot_fragment_def!(__RXOR__, "__rxor__"); +binary_num_slot_fragment_def!(__OR__, "__or__"); +binary_num_slot_fragment_def!(__ROR__, "__ror__"); fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { match method_name { @@ -843,6 +908,26 @@ fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { "__delitem__" => Some(&__DELITEM__), "__add__" => Some(&__ADD__), "__radd__" => Some(&__RADD__), + "__sub__" => Some(&__SUB__), + "__rsub__" => Some(&__RSUB__), + "__mul__" => Some(&__MUL__), + "__rmul__" => Some(&__RMUL__), + "__matmul__" => Some(&__MATMUL__), + "__rmatmul__" => Some(&__RMATMUL__), + "__divmod__" => Some(&__DIVMOD__), + "__rdivmod__" => Some(&__RDIVMOD__), + "__mod__" => Some(&__MOD__), + "__rmod__" => Some(&__RMOD__), + "__lshift__" => Some(&__LSHIFT__), + "__rlshift__" => Some(&__RLSHIFT__), + "__rshift__" => Some(&__RSHIFT__), + "__rrshift__" => Some(&__RRSHIFT__), + "__and__" => Some(&__AND__), + "__rand__" => Some(&__RAND__), + "__xor__" => Some(&__XOR__), + "__rxor__" => Some(&__RXOR__), + "__or__" => Some(&__OR__), + "__ror__" => Some(&__ROR__), _ => None, } } diff --git a/src/class/impl_.rs b/src/class/impl_.rs index d2fe3814..5a914868 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -109,15 +109,9 @@ impl PyClassCallImpl for &'_ PyClassImplCollector { } macro_rules! slot_fragment_trait { - ($trait_name:ident, $implemented_name:ident, $($default_method:tt)*) => { + ($trait_name:ident, $($default_method:tt)*) => { #[allow(non_camel_case_types)] pub trait $trait_name: Sized { - #[inline] - #[allow(non_snake_case)] - fn $implemented_name(self) -> bool { - false - } - $($default_method)* } @@ -133,8 +127,6 @@ macro_rules! define_pyclass_setattr_slot { ( $set_trait:ident, $del_trait:ident, - $set_implemented:ident, - $del_implemented:ident, $set:ident, $del:ident, $set_error:expr, @@ -145,7 +137,6 @@ macro_rules! define_pyclass_setattr_slot { ) => { slot_fragment_trait! { $set_trait, - $set_implemented, /// # Safety: _slf and _attr must be valid non-null Python objects #[inline] @@ -162,7 +153,6 @@ macro_rules! define_pyclass_setattr_slot { slot_fragment_trait! { $del_trait, - $del_implemented, /// # Safety: _slf and _attr must be valid non-null Python objects #[inline] @@ -180,31 +170,26 @@ macro_rules! define_pyclass_setattr_slot { #[macro_export] macro_rules! $generate_macro { ($cls:ty) => {{ - use ::std::option::Option::*; - use $crate::class::impl_::*; - let collector = PyClassImplCollector::<$cls>::new(); - if collector.$set_implemented() || collector.$del_implemented() { - unsafe extern "C" fn __wrap( - _slf: *mut $crate::ffi::PyObject, - attr: *mut $crate::ffi::PyObject, - value: *mut $crate::ffi::PyObject, - ) -> ::std::os::raw::c_int { - use $crate::callback::IntoPyCallbackOutput; - $crate::callback::handle_panic(|py| { - let collector = PyClassImplCollector::<$cls>::new(); - if let Some(value) = ::std::ptr::NonNull::new(value) { - collector.$set(py, _slf, attr, value).convert(py) - } else { - collector.$del(py, _slf, attr).convert(py) - } - }) - } - Some($crate::ffi::PyType_Slot { - slot: $crate::ffi::$slot, - pfunc: __wrap as $crate::ffi::$func_ty as _, + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + use ::std::option::Option::*; + use $crate::callback::IntoPyCallbackOutput; + use $crate::class::impl_::*; + $crate::callback::handle_panic(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + if let Some(value) = ::std::ptr::NonNull::new(value) { + collector.$set(py, _slf, attr, value).convert(py) + } else { + collector.$del(py, _slf, attr).convert(py) + } }) - } else { - None + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::$slot, + pfunc: __wrap as $crate::ffi::$func_ty as _, } }}; } @@ -214,8 +199,6 @@ macro_rules! define_pyclass_setattr_slot { define_pyclass_setattr_slot! { PyClass__setattr__SlotFragment, PyClass__delattr__SlotFragment, - __setattr__implemented, - __delattr__implemented, __setattr__, __delattr__, Err(PyAttributeError::new_err("can't set attribute")), @@ -228,8 +211,6 @@ define_pyclass_setattr_slot! { define_pyclass_setattr_slot! { PyClass__set__SlotFragment, PyClass__delete__SlotFragment, - __set__implemented, - __delete__implemented, __set__, __delete__, Err(PyNotImplementedError::new_err("can't set descriptor")), @@ -242,8 +223,6 @@ define_pyclass_setattr_slot! { define_pyclass_setattr_slot! { PyClass__setitem__SlotFragment, PyClass__delitem__SlotFragment, - __setitem__implemented, - __delitem__implemented, __setitem__, __delitem__, Err(PyNotImplementedError::new_err("can't set item")), @@ -261,8 +240,6 @@ macro_rules! define_pyclass_binary_operator_slot { ( $lhs_trait:ident, $rhs_trait:ident, - $lhs_implemented:ident, - $rhs_implemented:ident, $lhs:ident, $rhs:ident, $generate_macro:ident, @@ -271,7 +248,6 @@ macro_rules! define_pyclass_binary_operator_slot { ) => { slot_fragment_trait! { $lhs_trait, - $lhs_implemented, /// # Safety: _slf and _attr must be valid non-null Python objects #[inline] @@ -288,7 +264,6 @@ macro_rules! define_pyclass_binary_operator_slot { slot_fragment_trait! { $rhs_trait, - $rhs_implemented, /// # Safety: _slf and _attr must be valid non-null Python objects #[inline] @@ -307,31 +282,25 @@ macro_rules! define_pyclass_binary_operator_slot { #[macro_export] macro_rules! $generate_macro { ($cls:ty) => {{ - use ::std::option::Option::*; - use $crate::class::impl_::*; - let collector = PyClassImplCollector::<$cls>::new(); - if collector.$lhs_implemented() || collector.$rhs_implemented() { - unsafe extern "C" fn __wrap( - _slf: *mut $crate::ffi::PyObject, - _other: *mut $crate::ffi::PyObject, - ) -> *mut $crate::ffi::PyObject { - $crate::callback::handle_panic(|py| { - let collector = PyClassImplCollector::<$cls>::new(); - let lhs_result = collector.$lhs(py, _slf, _other)?; - if lhs_result == $crate::ffi::Py_NotImplemented() { - $crate::ffi::Py_DECREF(lhs_result); - collector.$rhs(py, _other, _slf) - } else { - ::std::result::Result::Ok(lhs_result) - } - }) - } - Some($crate::ffi::PyType_Slot { - slot: $crate::ffi::$slot, - pfunc: __wrap as $crate::ffi::$func_ty as _, + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + _other: *mut $crate::ffi::PyObject, + ) -> *mut $crate::ffi::PyObject { + $crate::callback::handle_panic(|py| { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let lhs_result = collector.$lhs(py, _slf, _other)?; + if lhs_result == $crate::ffi::Py_NotImplemented() { + $crate::ffi::Py_DECREF(lhs_result); + collector.$rhs(py, _other, _slf) + } else { + ::std::result::Result::Ok(lhs_result) + } }) - } else { - None + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::$slot, + pfunc: __wrap as $crate::ffi::$func_ty as _, } }}; } @@ -341,8 +310,6 @@ macro_rules! define_pyclass_binary_operator_slot { define_pyclass_binary_operator_slot! { PyClass__add__SlotFragment, PyClass__radd__SlotFragment, - __add__implemented, - __radd__implemented, __add__, __radd__, generate_pyclass_add_slot, @@ -353,8 +320,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__sub__SlotFragment, PyClass__rsub__SlotFragment, - __sub__implemented, - __rsub__implemented, __sub__, __rsub__, generate_pyclass_sub_slot, @@ -365,8 +330,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__mul__SlotFragment, PyClass__rmul__SlotFragment, - __mul__implemented, - __rmul__implemented, __mul__, __rmul__, generate_pyclass_mul_slot, @@ -377,8 +340,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__mod__SlotFragment, PyClass__rmod__SlotFragment, - __mod__implemented, - __rmod__implemented, __mod__, __rmod__, generate_pyclass_mod_slot, @@ -389,8 +350,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__divmod__SlotFragment, PyClass__rdivmod__SlotFragment, - __divmod__implemented, - __rdivmod__implemented, __divmod__, __rdivmod__, generate_pyclass_divmod_slot, @@ -401,8 +360,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__lshift__SlotFragment, PyClass__rlshift__SlotFragment, - __lshift__implemented, - __rlshift__implemented, __lshift__, __rlshift__, generate_pyclass_lshift_slot, @@ -413,8 +370,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__rshift__SlotFragment, PyClass__rrshift__SlotFragment, - __rshift__implemented, - __rrshift__implemented, __rshift__, __rrshift__, generate_pyclass_rshift_slot, @@ -425,8 +380,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__and__SlotFragment, PyClass__rand__SlotFragment, - __and__implemented, - __rand__implemented, __and__, __rand__, generate_pyclass_and_slot, @@ -437,8 +390,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__or__SlotFragment, PyClass__ror__SlotFragment, - __or__implemented, - __ror__implemented, __or__, __ror__, generate_pyclass_or_slot, @@ -449,8 +400,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__xor__SlotFragment, PyClass__rxor__SlotFragment, - __xor__implemented, - __rxor__implemented, __xor__, __rxor__, generate_pyclass_xor_slot, @@ -461,8 +410,6 @@ define_pyclass_binary_operator_slot! { define_pyclass_binary_operator_slot! { PyClass__matmul__SlotFragment, PyClass__rmatmul__SlotFragment, - __matmul__implemented, - __rmatmul__implemented, __matmul__, __rmatmul__, generate_pyclass_matmul_slot, diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 6f39920d..711cbfb9 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -9,7 +9,9 @@ struct UnaryArithmetic { inner: f64, } +#[pymethods] impl UnaryArithmetic { + #[new] fn new(value: f64) -> Self { UnaryArithmetic { inner: value } } @@ -521,7 +523,7 @@ mod return_not_implemented { fn __mod__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { slf } - fn __pow__<'p>(slf: PyRef<'p, Self>, _other: u8, _modulo: Option) -> PyRef<'p, Self> { + fn __pow__(slf: PyRef, _other: u8, _modulo: Option) -> PyRef { slf } fn __lshift__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index 50f1e4a4..8a12d2df 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -158,14 +158,11 @@ fn test_hash() { fn test_richcmp() { Python::with_gil(|py| { let example_py = make_example(py); - assert_eq!( - example_py - .rich_compare(example_py, CompareOp::Eq) - .unwrap() - .is_true() - .unwrap(), - true - ); + assert!(example_py + .rich_compare(example_py, CompareOp::Eq) + .unwrap() + .is_true() + .unwrap()); }) } @@ -552,4 +549,5 @@ assert c.counter.count == 3 // TODO: test __anext__, __aiter__ // TODO: test __index__, __int__, __float__, __invert__ // TODO: __floordiv__, __truediv__ +// TODO: __pow__, __rpow__ // TODO: better argument casting errors From a551b005b424fd69ce8b3285aa66e67c7175fb7c Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 18 Sep 2021 12:59:25 +0100 Subject: [PATCH 10/12] pymethods: finish support for number protocol --- guide/src/class.md | 1 + pyo3-macros-backend/src/pyimpl.rs | 1 + pyo3-macros-backend/src/pymethod.rs | 24 +++++++++++ src/class/impl_.rs | 66 ++++++++++++++++++++++++++++- tests/test_arithmetics.rs | 27 ++++++------ tests/test_proto_methods.rs | 7 --- 6 files changed, 104 insertions(+), 22 deletions(-) diff --git a/guide/src/class.md b/guide/src/class.md index 25eef825..8b1a4241 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -865,6 +865,7 @@ impl pyo3::class::impl_::PyClassImpl for MyClass { visitor(collector.sequence_protocol_slots()); visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); + visitor(collector.methods_protocol_slots()); } fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> { diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 4fc1c7da..18bdcfff 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -176,6 +176,7 @@ fn impl_protos( try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot); try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot); try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot); + try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot); quote! { impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty> diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index e0698262..f0fced41 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -439,6 +439,13 @@ const __INT__: SlotDef = SlotDef::new("Py_nb_int", "unaryfunc"); const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc"); const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); +const __TRUEDIV__: SlotDef = SlotDef::new("Py_nb_true_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented); +const __FLOORDIV__: SlotDef = SlotDef::new("Py_nb_floor_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented); + const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc") .arguments(&[Ty::ObjectOrNotImplemented]) .extract_error_mode(ExtractErrorMode::NotImplemented) @@ -516,6 +523,8 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> { "__int__" => Some(&__INT__), "__float__" => Some(&__FLOAT__), "__bool__" => Some(&__BOOL__), + "__truediv__" => Some(&__TRUEDIV__), + "__floordiv__" => Some(&__FLOORDIV__), "__iadd__" => Some(&__IADD__), "__isub__" => Some(&__ISUB__), "__imul__" => Some(&__IMUL__), @@ -898,6 +907,19 @@ binary_num_slot_fragment_def!(__RXOR__, "__rxor__"); binary_num_slot_fragment_def!(__OR__, "__or__"); binary_num_slot_fragment_def!(__ROR__, "__ror__"); +const __POW__: SlotFragmentDef = SlotFragmentDef::new( + "__pow__", + &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], +) +.extract_error_mode(ExtractErrorMode::NotImplemented) +.ret_ty(Ty::Object); +const __RPOW__: SlotFragmentDef = SlotFragmentDef::new( + "__rpow__", + &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], +) +.extract_error_mode(ExtractErrorMode::NotImplemented) +.ret_ty(Ty::Object); + fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { match method_name { "__setattr__" => Some(&__SETATTR__), @@ -928,6 +950,8 @@ fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { "__rxor__" => Some(&__RXOR__), "__or__" => Some(&__OR__), "__ror__" => Some(&__ROR__), + "__pow__" => Some(&__POW__), + "__rpow__" => Some(&__RPOW__), _ => None, } } diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 5a914868..dc38c038 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -249,7 +249,7 @@ macro_rules! define_pyclass_binary_operator_slot { slot_fragment_trait! { $lhs_trait, - /// # Safety: _slf and _attr must be valid non-null Python objects + /// # Safety: _slf and _other must be valid non-null Python objects #[inline] unsafe fn $lhs( self, @@ -265,7 +265,7 @@ macro_rules! define_pyclass_binary_operator_slot { slot_fragment_trait! { $rhs_trait, - /// # Safety: _slf and _attr must be valid non-null Python objects + /// # Safety: _slf and _other must be valid non-null Python objects #[inline] unsafe fn $rhs( self, @@ -417,6 +417,68 @@ define_pyclass_binary_operator_slot! { binaryfunc, } +slot_fragment_trait! { + PyClass__pow__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __pow__( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + _mod: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } +} + +slot_fragment_trait! { + PyClass__rpow__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __rpow__( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + _mod: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_pow_slot { + ($cls:ty) => {{ + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + _other: *mut $crate::ffi::PyObject, + _mod: *mut $crate::ffi::PyObject, + ) -> *mut $crate::ffi::PyObject { + $crate::callback::handle_panic(|py| { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let lhs_result = collector.__pow__(py, _slf, _other, _mod)?; + if lhs_result == $crate::ffi::Py_NotImplemented() { + $crate::ffi::Py_DECREF(lhs_result); + collector.__rpow__(py, _other, _slf, _mod) + } else { + ::std::result::Result::Ok(lhs_result) + } + }) + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_nb_power, + pfunc: __wrap as $crate::ffi::ternaryfunc as _, + } + }}; +} + pub trait PyClassAllocImpl { fn alloc_impl(self) -> Option; } diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 711cbfb9..c045c26b 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -177,24 +177,27 @@ fn binary_arithmetic() { py_run!(py, c, "assert c + c == 'BA + BA'"); py_run!(py, c, "assert c.__add__(c) == 'BA + BA'"); py_run!(py, c, "assert c + 1 == 'BA + 1'"); - py_run!(py, c, "assert 1 + c == '1 + BA'"); py_run!(py, c, "assert c - 1 == 'BA - 1'"); - py_run!(py, c, "assert 1 - c == '1 - BA'"); py_run!(py, c, "assert c * 1 == 'BA * 1'"); - py_run!(py, c, "assert 1 * c == '1 * BA'"); - py_run!(py, c, "assert c << 1 == 'BA << 1'"); - py_run!(py, c, "assert 1 << c == '1 << BA'"); py_run!(py, c, "assert c >> 1 == 'BA >> 1'"); - py_run!(py, c, "assert 1 >> c == '1 >> BA'"); py_run!(py, c, "assert c & 1 == 'BA & 1'"); - py_run!(py, c, "assert 1 & c == '1 & BA'"); py_run!(py, c, "assert c ^ 1 == 'BA ^ 1'"); - py_run!(py, c, "assert 1 ^ c == '1 ^ BA'"); py_run!(py, c, "assert c | 1 == 'BA | 1'"); - py_run!(py, c, "assert 1 | c == '1 | BA'"); py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'"); - py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'"); + + // Class with __add__ only should not allow the reverse op; + // this is consistent with Python classes. + + py_expect_exception!(py, c, "1 + c", PyTypeError); + py_expect_exception!(py, c, "1 - c", PyTypeError); + py_expect_exception!(py, c, "1 * c", PyTypeError); + py_expect_exception!(py, c, "1 << c", PyTypeError); + py_expect_exception!(py, c, "1 >> c", PyTypeError); + py_expect_exception!(py, c, "1 & c", PyTypeError); + py_expect_exception!(py, c, "1 ^ c", PyTypeError); + py_expect_exception!(py, c, "1 | c", PyTypeError); + py_expect_exception!(py, c, "1 ** c", PyTypeError); py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'"); } @@ -629,15 +632,13 @@ mod return_not_implemented { } #[test] - #[ignore] fn reverse_arith() { _test_binary_dunder("radd"); _test_binary_dunder("rsub"); _test_binary_dunder("rmul"); _test_binary_dunder("rmatmul"); - _test_binary_dunder("rtruediv"); - _test_binary_dunder("rfloordiv"); _test_binary_dunder("rmod"); + _test_binary_dunder("rdivmod"); _test_binary_dunder("rpow"); } diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index 8a12d2df..6d070b47 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -544,10 +544,3 @@ assert c.counter.count == 3 .map_err(|e| e.print(py)) .unwrap(); } - -// TODO: test __delete__ -// TODO: test __anext__, __aiter__ -// TODO: test __index__, __int__, __float__, __invert__ -// TODO: __floordiv__, __truediv__ -// TODO: __pow__, __rpow__ -// TODO: better argument casting errors From 592c98c722f85f2f47287dd8a3bbe0402d37ed24 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 18 Sep 2021 13:08:24 +0100 Subject: [PATCH 11/12] pymethods: disable protocols with multiple-pymethods for now --- pyo3-macros-backend/src/pyclass.rs | 9 ++++++++- tests/test_arithmetics.rs | 2 ++ tests/test_proto_methods.rs | 2 ++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index baf0041a..c68a67c4 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -464,6 +464,13 @@ fn impl_class( ), }; + let methods_protos = match methods_type { + PyClassMethodsType::Specialization => { + Some(quote! { visitor(collector.methods_protocol_slots()); }) + } + PyClassMethodsType::Inventory => None, + }; + let base = &attr.base; let base_nativetype = if attr.has_extends { quote! { ::BaseNativeType } @@ -591,7 +598,7 @@ fn impl_class( visitor(collector.sequence_protocol_slots()); visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); - visitor(collector.methods_protocol_slots()); + #methods_protos } fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> { diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index c045c26b..f648ee2e 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -1,3 +1,5 @@ +#![cfg(not(feature = "multiple-pymethods"))] + use pyo3::class::basic::CompareOp; use pyo3::prelude::*; use pyo3::py_run; diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index 6d070b47..8e405197 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -1,3 +1,5 @@ +#![cfg(not(feature = "multiple-pymethods"))] + use pyo3::exceptions::PyValueError; use pyo3::types::{PySlice, PyType}; use pyo3::{basic::CompareOp, exceptions::PyAttributeError, prelude::*}; From 179b5d1f47adcc7c596113cdc79180e994e00552 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 18 Sep 2021 13:18:16 +0100 Subject: [PATCH 12/12] pymethods: fix support for MSRV --- pyo3-macros-backend/src/pymethod.rs | 4 +++- src/class/basic.rs | 11 ----------- tests/test_proto_methods.rs | 18 +----------------- 3 files changed, 4 insertions(+), 29 deletions(-) diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index f0fced41..d49898dd 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -695,12 +695,14 @@ struct SlotDef { return_mode: Option, } +const NO_ARGUMENTS: &[Ty] = &[]; + impl SlotDef { const fn new(slot: &'static str, func_ty: &'static str) -> Self { SlotDef { slot: StaticIdent(slot), func_ty: StaticIdent(func_ty), - arguments: &[], + arguments: NO_ARGUMENTS, ret_ty: Ty::Object, before_call_method: None, extract_error_mode: ExtractErrorMode::Raise, diff --git a/src/class/basic.rs b/src/class/basic.rs index 602096c5..1e34d493 100644 --- a/src/class/basic.rs +++ b/src/class/basic.rs @@ -41,17 +41,6 @@ impl CompareOp { _ => None, } } - - pub fn matches_ordering(self, ordering: std::cmp::Ordering) -> bool { - match self { - CompareOp::Lt => ordering.is_lt(), - CompareOp::Le => ordering.is_le(), - CompareOp::Eq => ordering.is_eq(), - CompareOp::Ne => ordering.is_ne(), - CompareOp::Gt => ordering.is_gt(), - CompareOp::Ge => ordering.is_ge(), - } - } } /// Basic Python class customization diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index 8e405197..1cfafefd 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -2,7 +2,7 @@ use pyo3::exceptions::PyValueError; use pyo3::types::{PySlice, PyType}; -use pyo3::{basic::CompareOp, exceptions::PyAttributeError, prelude::*}; +use pyo3::{exceptions::PyAttributeError, prelude::*}; use pyo3::{ffi, py_run, AsPyPointer, PyCell}; use std::{isize, iter}; @@ -56,10 +56,6 @@ impl ExampleClass { i64_value as u64 } - fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { - op.matches_ordering(self.value.cmp(&other.value)) - } - fn __bool__(&self) -> bool { self.value != 0 } @@ -156,18 +152,6 @@ fn test_hash() { }) } -#[test] -fn test_richcmp() { - Python::with_gil(|py| { - let example_py = make_example(py); - assert!(example_py - .rich_compare(example_py, CompareOp::Eq) - .unwrap() - .is_true() - .unwrap()); - }) -} - #[test] fn test_bool() { Python::with_gil(|py| {