diff --git a/guide/src/class.md b/guide/src/class.md index 25eef825..8b1a4241 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -865,6 +865,7 @@ impl pyo3::class::impl_::PyClassImpl for MyClass { visitor(collector.sequence_protocol_slots()); visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); + visitor(collector.methods_protocol_slots()); } fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> { diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 71b4c5bc..27b9c612 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -93,11 +93,17 @@ pub enum FnType { } impl FnType { - pub fn self_conversion(&self, cls: Option<&syn::Type>) -> TokenStream { + pub fn self_conversion( + &self, + cls: Option<&syn::Type>, + error_mode: ExtractErrorMode, + ) -> TokenStream { match self { - FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => { - st.receiver(cls.expect("no class given for Fn with a \"self\" receiver")) - } + FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => st + .receiver( + cls.expect("no class given for Fn with a \"self\" receiver"), + error_mode, + ), FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => { quote!() } @@ -128,26 +134,45 @@ pub enum SelfType { TryFromPyCell(Span), } +#[derive(Clone, Copy)] +pub enum ExtractErrorMode { + NotImplemented, + Raise, +} + impl SelfType { - pub fn receiver(&self, cls: &syn::Type) -> TokenStream { + pub fn receiver(&self, cls: &syn::Type, error_mode: ExtractErrorMode) -> TokenStream { + let cell = match error_mode { + ExtractErrorMode::Raise => { + quote! { _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>()? } + } + ExtractErrorMode::NotImplemented => { + quote! { + match _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>() { + ::std::result::Result::Ok(cell) => cell, + ::std::result::Result::Err(_) => return ::pyo3::callback::convert(_py, _py.NotImplemented()), + } + } + } + }; match self { SelfType::Receiver { mutable: false } => { quote! { - let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf); + let _cell = #cell; let _ref = _cell.try_borrow()?; let _slf = &_ref; } } SelfType::Receiver { mutable: true } => { quote! { - let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf); + let _cell = #cell; let mut _ref = _cell.try_borrow_mut()?; let _slf = &mut _ref; } } SelfType::TryFromPyCell(span) => { quote_spanned! { *span => - let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf); + let _cell = #cell; #[allow(clippy::useless_conversion)] // In case _slf is PyCell let _slf = std::convert::TryFrom::try_from(_cell)?; } @@ -442,7 +467,7 @@ impl<'a> FnSpec<'a> { cls: Option<&syn::Type>, ) -> Result { let deprecations = &self.deprecations; - let self_conversion = self.tp.self_conversion(cls); + let self_conversion = self.tp.self_conversion(cls, ExtractErrorMode::Raise); let self_arg = self.tp.self_arg(); let arg_names = (0..self.args.len()) .map(|pos| syn::Ident::new(&format!("arg{}", pos), Span::call_site())) diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 034d1784..c68a67c4 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -464,6 +464,13 @@ fn impl_class( ), }; + let methods_protos = match methods_type { + PyClassMethodsType::Specialization => { + Some(quote! { visitor(collector.methods_protocol_slots()); }) + } + PyClassMethodsType::Inventory => None, + }; + let base = &attr.base; let base_nativetype = if attr.has_extends { quote! { ::BaseNativeType } @@ -591,6 +598,7 @@ fn impl_class( visitor(collector.sequence_protocol_slots()); visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); + #methods_protos } fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> { diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 61d20267..18bdcfff 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -1,5 +1,7 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use std::collections::HashSet; + use crate::{ konst::{ConstAttributes, ConstSpec}, pyfunction::PyFunctionOptions, @@ -37,9 +39,12 @@ pub fn impl_methods( impls: &mut Vec, methods_type: PyClassMethodsType, ) -> syn::Result { - let mut new_impls = Vec::new(); - let mut call_impls = Vec::new(); + let mut trait_impls = Vec::new(); + let mut proto_impls = Vec::new(); let mut methods = Vec::new(); + + let mut implemented_proto_fragments = HashSet::new(); + for iimpl in impls.iter_mut() { match iimpl { syn::ImplItem::Method(meth) => { @@ -49,13 +54,18 @@ pub fn impl_methods( let attrs = get_cfg_attributes(&meth.attrs); methods.push(quote!(#(#attrs)* #token_stream)); } - GeneratedPyMethod::New(token_stream) => { + GeneratedPyMethod::TraitImpl(token_stream) => { let attrs = get_cfg_attributes(&meth.attrs); - new_impls.push(quote!(#(#attrs)* #token_stream)); + trait_impls.push(quote!(#(#attrs)* #token_stream)); } - GeneratedPyMethod::Call(token_stream) => { + GeneratedPyMethod::SlotTraitImpl(method_name, token_stream) => { + implemented_proto_fragments.insert(method_name); let attrs = get_cfg_attributes(&meth.attrs); - call_impls.push(quote!(#(#attrs)* #token_stream)); + trait_impls.push(quote!(#(#attrs)* #token_stream)); + } + GeneratedPyMethod::Proto(token_stream) => { + let attrs = get_cfg_attributes(&meth.attrs); + proto_impls.push(quote!(#(#attrs)* #token_stream)) } } } @@ -80,10 +90,25 @@ pub fn impl_methods( PyClassMethodsType::Inventory => submit_methods_inventory(ty, methods), }; - Ok(quote! { - #(#new_impls)* + let protos_registration = match methods_type { + PyClassMethodsType::Specialization => { + Some(impl_protos(ty, proto_impls, implemented_proto_fragments)) + } + PyClassMethodsType::Inventory => { + if proto_impls.is_empty() { + None + } else { + panic!( + "cannot implement protos in #[pymethods] using `multiple-pymethods` feature" + ); + } + } + }; - #(#call_impls)* + Ok(quote! { + #(#trait_impls)* + + #protos_registration #methods_registration }) @@ -122,6 +147,48 @@ fn impl_py_methods(ty: &syn::Type, methods: Vec) -> TokenStream { } } +fn impl_protos( + ty: &syn::Type, + mut proto_impls: Vec, + mut implemented_proto_fragments: HashSet, +) -> TokenStream { + macro_rules! try_add_shared_slot { + ($first:literal, $second:literal, $slot:ident) => {{ + let first_implemented = implemented_proto_fragments.remove($first); + let second_implemented = implemented_proto_fragments.remove($second); + if first_implemented || second_implemented { + proto_impls.push(quote! { ::pyo3::$slot!(#ty) }) + } + }}; + } + + try_add_shared_slot!("__setattr__", "__delattr__", generate_pyclass_setattr_slot); + try_add_shared_slot!("__set__", "__delete__", generate_pyclass_setdescr_slot); + try_add_shared_slot!("__setitem__", "__delitem__", generate_pyclass_setitem_slot); + try_add_shared_slot!("__add__", "__radd__", generate_pyclass_add_slot); + try_add_shared_slot!("__sub__", "__rsub__", generate_pyclass_sub_slot); + try_add_shared_slot!("__mul__", "__rmul__", generate_pyclass_mul_slot); + try_add_shared_slot!("__mod__", "__rmod__", generate_pyclass_mod_slot); + try_add_shared_slot!("__divmod__", "__rdivmod__", generate_pyclass_divmod_slot); + try_add_shared_slot!("__lshift__", "__rlshift__", generate_pyclass_lshift_slot); + try_add_shared_slot!("__rshift__", "__rrshift__", generate_pyclass_rshift_slot); + try_add_shared_slot!("__and__", "__rand__", generate_pyclass_and_slot); + try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot); + try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot); + try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot); + try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot); + + quote! { + impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty> + for ::pyo3::class::impl_::PyClassImplCollector<#ty> + { + fn methods_protocol_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] { + &[#(#proto_impls),*] + } + } + } +} + fn submit_methods_inventory(ty: &syn::Type, methods: Vec) -> TokenStream { if methods.is_empty() { return TokenStream::default(); diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 93f771ab..d49898dd 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -3,20 +3,23 @@ use std::borrow::Cow; use crate::attributes::NameAttribute; -use crate::utils::{ensure_not_async_fn, PythonDoc}; +use crate::method::ExtractErrorMode; +use crate::utils::{ensure_not_async_fn, unwrap_ty_group, PythonDoc}; use crate::{deprecations::Deprecations, utils}; use crate::{ method::{FnArg, FnSpec, FnType, SelfType}, pyfunction::PyFunctionOptions, }; use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote, ToTokens}; +use syn::Ident; use syn::{ext::IdentExt, spanned::Spanned, Result}; pub enum GeneratedPyMethod { Method(TokenStream), - New(TokenStream), - Call(TokenStream), + Proto(TokenStream), + TraitImpl(TokenStream), + SlotTraitImpl(String, TokenStream), } pub fn gen_py_method( @@ -30,6 +33,18 @@ pub fn gen_py_method( ensure_function_options_valid(&options)?; let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?; + let method_name = spec.python_name.to_string(); + + if let Some(slot_def) = pyproto(&method_name) { + let slot = slot_def.generate_type_slot(cls, &spec)?; + return Ok(GeneratedPyMethod::Proto(slot)); + } + + if let Some(slot_fragment_def) = pyproto_fragment(&method_name) { + let proto = slot_fragment_def.generate_pyproto_fragment(cls, &spec)?; + return Ok(GeneratedPyMethod::SlotTraitImpl(method_name, proto)); + } + Ok(match &spec.tp { // ordinary functions (with some specialties) FnType::Fn(_) => GeneratedPyMethod::Method(impl_py_method_def(cls, &spec, None)?), @@ -44,8 +59,8 @@ pub fn gen_py_method( Some(quote!(::pyo3::ffi::METH_STATIC)), )?), // special prototypes - FnType::FnNew => GeneratedPyMethod::New(impl_py_method_def_new(cls, &spec)?), - FnType::FnCall(_) => GeneratedPyMethod::Call(impl_py_method_def_call(cls, &spec)?), + FnType::FnNew => GeneratedPyMethod::TraitImpl(impl_py_method_def_new(cls, &spec)?), + FnType::FnCall(_) => GeneratedPyMethod::TraitImpl(impl_py_method_def_call(cls, &spec)?), FnType::ClassAttribute => GeneratedPyMethod::Method(impl_py_class_attribute(cls, &spec)), FnType::Getter(self_type) => GeneratedPyMethod::Method(impl_py_getter_def( cls, @@ -202,8 +217,12 @@ pub fn impl_py_setter_def(cls: &syn::Type, property_type: PropertyType) -> Resul }; let slf = match property_type { - PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: true }.receiver(cls), - PropertyType::Function { self_type, .. } => self_type.receiver(cls), + PropertyType::Descriptor { .. } => { + SelfType::Receiver { mutable: true }.receiver(cls, ExtractErrorMode::Raise) + } + PropertyType::Function { self_type, .. } => { + self_type.receiver(cls, ExtractErrorMode::Raise) + } }; Ok(quote! { ::pyo3::class::PyMethodDefType::Setter({ @@ -278,8 +297,12 @@ pub fn impl_py_getter_def(cls: &syn::Type, property_type: PropertyType) -> Resul }; let slf = match property_type { - PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: false }.receiver(cls), - PropertyType::Function { self_type, .. } => self_type.receiver(cls), + PropertyType::Descriptor { .. } => { + SelfType::Receiver { mutable: false }.receiver(cls, ExtractErrorMode::Raise) + } + PropertyType::Function { self_type, .. } => { + self_type.receiver(cls, ExtractErrorMode::Raise) + } }; Ok(quote! { ::pyo3::class::PyMethodDefType::Getter({ @@ -364,3 +387,618 @@ impl PropertyType<'_> { } } } + +const __GETATTR__: SlotDef = SlotDef::new("Py_tp_getattro", "getattrofunc") + .arguments(&[Ty::Object]) + .before_call_method(TokenGenerator(|| { + quote! { + // Behave like python's __getattr__ (as opposed to __getattribute__) and check + // for existing fields and methods first + let existing = ::pyo3::ffi::PyObject_GenericGetAttr(_slf, arg0); + if existing.is_null() { + // PyObject_HasAttr also tries to get an object and clears the error if it fails + ::pyo3::ffi::PyErr_Clear(); + } else { + return existing; + } + } + })); +const __STR__: SlotDef = SlotDef::new("Py_tp_str", "reprfunc"); +const __REPR__: SlotDef = SlotDef::new("Py_tp_repr", "reprfunc"); +const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc") + .ret_ty(Ty::PyHashT) + .return_conversion(TokenGenerator( + || quote! { ::pyo3::callback::HashCallbackOutput }, + )); +const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc") + .extract_error_mode(ExtractErrorMode::NotImplemented) + .arguments(&[Ty::ObjectOrNotImplemented, Ty::CompareOp]); +const __GET__: SlotDef = + SlotDef::new("Py_tp_descr_get", "descrgetfunc").arguments(&[Ty::Object, Ty::Object]); +const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc"); +const __NEXT__: SlotDef = SlotDef::new("Py_tp_iternext", "iternextfunc").return_conversion( + TokenGenerator(|| quote! { ::pyo3::class::iter::IterNextOutput::<_, _> }), +); +const __AWAIT__: SlotDef = SlotDef::new("Py_am_await", "unaryfunc"); +const __AITER__: SlotDef = SlotDef::new("Py_am_aiter", "unaryfunc"); +const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_conversion( + TokenGenerator(|| quote! { ::pyo3::class::pyasync::IterANextOutput::<_, _> }), +); +const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySsizeT); +const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc") + .arguments(&[Ty::Object]) + .ret_ty(Ty::Int); +const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]); + +const __POS__: SlotDef = SlotDef::new("Py_nb_positive", "unaryfunc"); +const __NEG__: SlotDef = SlotDef::new("Py_nb_negative", "unaryfunc"); +const __ABS__: SlotDef = SlotDef::new("Py_nb_absolute", "unaryfunc"); +const __INVERT__: SlotDef = SlotDef::new("Py_nb_invert", "unaryfunc"); +const __INDEX__: SlotDef = SlotDef::new("Py_nb_index", "unaryfunc"); +const __INT__: SlotDef = SlotDef::new("Py_nb_int", "unaryfunc"); +const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc"); +const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int); + +const __TRUEDIV__: SlotDef = SlotDef::new("Py_nb_true_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented); +const __FLOORDIV__: SlotDef = SlotDef::new("Py_nb_floor_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented); + +const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __ISUB__: SlotDef = SlotDef::new("Py_nb_inplace_subtract", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IMUL__: SlotDef = SlotDef::new("Py_nb_inplace_multiply", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IMATMUL__: SlotDef = SlotDef::new("Py_nb_inplace_matrix_multiply", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __ITRUEDIV__: SlotDef = SlotDef::new("Py_nb_inplace_true_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IFLOORDIV__: SlotDef = SlotDef::new("Py_nb_inplace_floor_divide", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IMOD__: SlotDef = SlotDef::new("Py_nb_inplace_remainder", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "ternaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __ILSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_lshift", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IRSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_rshift", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IAND__: SlotDef = SlotDef::new("Py_nb_inplace_and", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IXOR__: SlotDef = SlotDef::new("Py_nb_inplace_xor", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); +const __IOR__: SlotDef = SlotDef::new("Py_nb_inplace_or", "binaryfunc") + .arguments(&[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .return_self(); + +fn pyproto(method_name: &str) -> Option<&'static SlotDef> { + match method_name { + "__getattr__" => Some(&__GETATTR__), + "__str__" => Some(&__STR__), + "__repr__" => Some(&__REPR__), + "__hash__" => Some(&__HASH__), + "__richcmp__" => Some(&__RICHCMP__), + "__get__" => Some(&__GET__), + "__iter__" => Some(&__ITER__), + "__next__" => Some(&__NEXT__), + "__await__" => Some(&__AWAIT__), + "__aiter__" => Some(&__AITER__), + "__anext__" => Some(&__ANEXT__), + "__len__" => Some(&__LEN__), + "__contains__" => Some(&__CONTAINS__), + "__getitem__" => Some(&__GETITEM__), + "__pos__" => Some(&__POS__), + "__neg__" => Some(&__NEG__), + "__abs__" => Some(&__ABS__), + "__invert__" => Some(&__INVERT__), + "__index__" => Some(&__INDEX__), + "__int__" => Some(&__INT__), + "__float__" => Some(&__FLOAT__), + "__bool__" => Some(&__BOOL__), + "__truediv__" => Some(&__TRUEDIV__), + "__floordiv__" => Some(&__FLOORDIV__), + "__iadd__" => Some(&__IADD__), + "__isub__" => Some(&__ISUB__), + "__imul__" => Some(&__IMUL__), + "__imatmul__" => Some(&__IMATMUL__), + "__itruediv__" => Some(&__ITRUEDIV__), + "__ifloordiv__" => Some(&__IFLOORDIV__), + "__imod__" => Some(&__IMOD__), + "__ipow__" => Some(&__IPOW__), + "__ilshift__" => Some(&__ILSHIFT__), + "__irshift__" => Some(&__IRSHIFT__), + "__iand__" => Some(&__IAND__), + "__ixor__" => Some(&__IXOR__), + "__ior__" => Some(&__IOR__), + _ => None, + } +} + +#[derive(Clone, Copy)] +enum Ty { + Object, + ObjectOrNotImplemented, + NonNullObject, + CompareOp, + Int, + PyHashT, + PySsizeT, + Void, +} + +impl Ty { + fn ffi_type(self) -> TokenStream { + match self { + Ty::Object | Ty::ObjectOrNotImplemented => quote! { *mut ::pyo3::ffi::PyObject }, + Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> }, + Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int }, + Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, + Ty::PySsizeT => quote! { ::pyo3::ffi::Py_ssize_t }, + Ty::Void => quote! { () }, + } + } + + fn extract( + self, + cls: &syn::Type, + py: &syn::Ident, + ident: &syn::Ident, + target: &syn::Type, + ) -> TokenStream { + match self { + Ty::Object => { + let extract = extract_from_any(cls, target, ident); + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident); + #extract + } + } + Ty::ObjectOrNotImplemented => { + let extract = if let syn::Type::Reference(tref) = unwrap_ty_group(target) { + let (tref, mut_) = preprocess_tref(tref, cls); + quote! { + let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = match #ident.extract() { + Ok(#ident) => #ident, + Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()), + }; + let #ident = &#mut_ *#ident; + } + } else { + quote! { + let #ident = match #ident.extract() { + Ok(#ident) => #ident, + Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()), + }; + } + }; + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident); + #extract + } + } + Ty::NonNullObject => { + let extract = extract_from_any(cls, target, ident); + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident.as_ptr()); + #extract + } + } + Ty::CompareOp => quote! { + let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident) + .ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?; + }, + Ty::Int | Ty::PyHashT | Ty::PySsizeT | Ty::Void => todo!(), + } + } +} + +fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -> TokenStream { + return if let syn::Type::Reference(tref) = unwrap_ty_group(target) { + let (tref, mut_) = preprocess_tref(tref, self_); + quote! { + let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = #ident.extract()?; + let #ident = &#mut_ *#ident; + } + } else { + quote! { + let #ident = #ident.extract()?; + } + }; +} + +/// Replace `Self`, remove lifetime and get mutability from the type +fn preprocess_tref( + tref: &syn::TypeReference, + self_: &syn::Type, +) -> (syn::TypeReference, Option) { + let mut tref = tref.to_owned(); + if let syn::Type::Path(tpath) = self_ { + replace_self(&mut tref, &tpath.path); + } + tref.lifetime = None; + let mut_ = tref.mutability; + (tref, mut_) +} + +/// Replace `Self` with the exact type name since it is used out of the impl block +fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { + match &mut *tref.elem { + syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), + syn::Type::Path(tpath) => { + if let Some(ident) = tpath.path.get_ident() { + if ident == "Self" { + tpath.path = self_path.to_owned(); + } + } + } + _ => {} + } +} + +enum ReturnMode { + ReturnSelf, + Conversion(TokenGenerator), +} + +impl ReturnMode { + fn return_call_output(&self, py: &syn::Ident, call: TokenStream) -> TokenStream { + match self { + ReturnMode::Conversion(conversion) => quote! { + let _result: PyResult<#conversion> = #call; + ::pyo3::callback::convert(#py, _result) + }, + ReturnMode::ReturnSelf => quote! { + let _result: PyResult<()> = #call; + _result?; + ::pyo3::ffi::Py_XINCREF(_raw_slf); + Ok(_raw_slf) + }, + } + } +} + +struct SlotDef { + slot: StaticIdent, + func_ty: StaticIdent, + arguments: &'static [Ty], + ret_ty: Ty, + before_call_method: Option, + extract_error_mode: ExtractErrorMode, + return_mode: Option, +} + +const NO_ARGUMENTS: &[Ty] = &[]; + +impl SlotDef { + const fn new(slot: &'static str, func_ty: &'static str) -> Self { + SlotDef { + slot: StaticIdent(slot), + func_ty: StaticIdent(func_ty), + arguments: NO_ARGUMENTS, + ret_ty: Ty::Object, + before_call_method: None, + extract_error_mode: ExtractErrorMode::Raise, + return_mode: None, + } + } + + const fn arguments(mut self, arguments: &'static [Ty]) -> Self { + self.arguments = arguments; + self + } + + const fn ret_ty(mut self, ret_ty: Ty) -> Self { + self.ret_ty = ret_ty; + self + } + + const fn before_call_method(mut self, before_call_method: TokenGenerator) -> Self { + self.before_call_method = Some(before_call_method); + self + } + + const fn return_conversion(mut self, return_conversion: TokenGenerator) -> Self { + self.return_mode = Some(ReturnMode::Conversion(return_conversion)); + self + } + + const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { + self.extract_error_mode = extract_error_mode; + self + } + + const fn return_self(mut self) -> Self { + self.return_mode = Some(ReturnMode::ReturnSelf); + self + } + + fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> Result { + let SlotDef { + slot, + func_ty, + before_call_method, + arguments, + extract_error_mode, + ret_ty, + return_mode, + } = self; + let py = syn::Ident::new("_py", Span::call_site()); + let method_arguments = generate_method_arguments(arguments); + let ret_ty = ret_ty.ffi_type(); + let body = generate_method_body( + cls, + spec, + &py, + arguments, + *extract_error_mode, + return_mode.as_ref(), + )?; + Ok(quote!({ + unsafe extern "C" fn __wrap(_raw_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty { + let _slf = _raw_slf; + #before_call_method + ::pyo3::callback::handle_panic(|#py| { + #body + }) + } + ::pyo3::ffi::PyType_Slot { + slot: ::pyo3::ffi::#slot, + pfunc: __wrap as ::pyo3::ffi::#func_ty as _ + } + })) + } +} + +fn generate_method_arguments(arguments: &[Ty]) -> impl Iterator + '_ { + arguments.iter().enumerate().map(|(i, arg)| { + let ident = syn::Ident::new(&format!("arg{}", i), Span::call_site()); + let ffi_type = arg.ffi_type(); + quote! { + #ident: #ffi_type + } + }) +} + +fn generate_method_body( + cls: &syn::Type, + spec: &FnSpec, + py: &syn::Ident, + arguments: &[Ty], + extract_error_mode: ExtractErrorMode, + return_mode: Option<&ReturnMode>, +) -> Result { + let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode); + let rust_name = spec.name; + let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments)?; + let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; + let body = if let Some(return_mode) = return_mode { + return_mode.return_call_output(py, call) + } else { + call + }; + Ok(quote! { + #self_conversion + #conversions + #body + }) +} + +struct SlotFragmentDef { + fragment: &'static str, + arguments: &'static [Ty], + extract_error_mode: ExtractErrorMode, + ret_ty: Ty, +} + +impl SlotFragmentDef { + const fn new(fragment: &'static str, arguments: &'static [Ty]) -> Self { + SlotFragmentDef { + fragment, + arguments, + extract_error_mode: ExtractErrorMode::Raise, + ret_ty: Ty::Void, + } + } + + const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { + self.extract_error_mode = extract_error_mode; + self + } + + const fn ret_ty(mut self, ret_ty: Ty) -> Self { + self.ret_ty = ret_ty; + self + } + + fn generate_pyproto_fragment(&self, cls: &syn::Type, spec: &FnSpec) -> Result { + let SlotFragmentDef { + fragment, + arguments, + extract_error_mode, + ret_ty, + } = self; + let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment); + let method = syn::Ident::new(fragment, Span::call_site()); + let py = syn::Ident::new("_py", Span::call_site()); + let method_arguments = generate_method_arguments(arguments); + let body = generate_method_body(cls, spec, &py, arguments, *extract_error_mode, None)?; + let ret_ty = ret_ty.ffi_type(); + Ok(quote! { + impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + + #[inline] + unsafe fn #method( + self, + #py: ::pyo3::Python, + _raw_slf: *mut ::pyo3::ffi::PyObject, + #(#method_arguments),* + ) -> ::pyo3::PyResult<#ret_ty> { + let _slf = _raw_slf; + #body + } + } + }) + } +} + +const __SETATTR__: SlotFragmentDef = + SlotFragmentDef::new("__setattr__", &[Ty::Object, Ty::NonNullObject]); +const __DELATTR__: SlotFragmentDef = SlotFragmentDef::new("__delattr__", &[Ty::Object]); +const __SET__: SlotFragmentDef = SlotFragmentDef::new("__set__", &[Ty::Object, Ty::NonNullObject]); +const __DELETE__: SlotFragmentDef = SlotFragmentDef::new("__delete__", &[Ty::Object]); +const __SETITEM__: SlotFragmentDef = + SlotFragmentDef::new("__setitem__", &[Ty::Object, Ty::NonNullObject]); +const __DELITEM__: SlotFragmentDef = SlotFragmentDef::new("__delitem__", &[Ty::Object]); + +macro_rules! binary_num_slot_fragment_def { + ($ident:ident, $name:literal) => { + const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::ObjectOrNotImplemented]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); + }; +} + +binary_num_slot_fragment_def!(__ADD__, "__add__"); +binary_num_slot_fragment_def!(__RADD__, "__radd__"); +binary_num_slot_fragment_def!(__SUB__, "__sub__"); +binary_num_slot_fragment_def!(__RSUB__, "__rsub__"); +binary_num_slot_fragment_def!(__MUL__, "__mul__"); +binary_num_slot_fragment_def!(__RMUL__, "__rmul__"); +binary_num_slot_fragment_def!(__MATMUL__, "__matmul__"); +binary_num_slot_fragment_def!(__RMATMUL__, "__rmatmul__"); +binary_num_slot_fragment_def!(__DIVMOD__, "__divmod__"); +binary_num_slot_fragment_def!(__RDIVMOD__, "__rdivmod__"); +binary_num_slot_fragment_def!(__MOD__, "__mod__"); +binary_num_slot_fragment_def!(__RMOD__, "__rmod__"); +binary_num_slot_fragment_def!(__LSHIFT__, "__lshift__"); +binary_num_slot_fragment_def!(__RLSHIFT__, "__rlshift__"); +binary_num_slot_fragment_def!(__RSHIFT__, "__rshift__"); +binary_num_slot_fragment_def!(__RRSHIFT__, "__rrshift__"); +binary_num_slot_fragment_def!(__AND__, "__and__"); +binary_num_slot_fragment_def!(__RAND__, "__rand__"); +binary_num_slot_fragment_def!(__XOR__, "__xor__"); +binary_num_slot_fragment_def!(__RXOR__, "__rxor__"); +binary_num_slot_fragment_def!(__OR__, "__or__"); +binary_num_slot_fragment_def!(__ROR__, "__ror__"); + +const __POW__: SlotFragmentDef = SlotFragmentDef::new( + "__pow__", + &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], +) +.extract_error_mode(ExtractErrorMode::NotImplemented) +.ret_ty(Ty::Object); +const __RPOW__: SlotFragmentDef = SlotFragmentDef::new( + "__rpow__", + &[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented], +) +.extract_error_mode(ExtractErrorMode::NotImplemented) +.ret_ty(Ty::Object); + +fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { + match method_name { + "__setattr__" => Some(&__SETATTR__), + "__delattr__" => Some(&__DELATTR__), + "__set__" => Some(&__SET__), + "__delete__" => Some(&__DELETE__), + "__setitem__" => Some(&__SETITEM__), + "__delitem__" => Some(&__DELITEM__), + "__add__" => Some(&__ADD__), + "__radd__" => Some(&__RADD__), + "__sub__" => Some(&__SUB__), + "__rsub__" => Some(&__RSUB__), + "__mul__" => Some(&__MUL__), + "__rmul__" => Some(&__RMUL__), + "__matmul__" => Some(&__MATMUL__), + "__rmatmul__" => Some(&__RMATMUL__), + "__divmod__" => Some(&__DIVMOD__), + "__rdivmod__" => Some(&__RDIVMOD__), + "__mod__" => Some(&__MOD__), + "__rmod__" => Some(&__RMOD__), + "__lshift__" => Some(&__LSHIFT__), + "__rlshift__" => Some(&__RLSHIFT__), + "__rshift__" => Some(&__RSHIFT__), + "__rrshift__" => Some(&__RRSHIFT__), + "__and__" => Some(&__AND__), + "__rand__" => Some(&__RAND__), + "__xor__" => Some(&__XOR__), + "__rxor__" => Some(&__RXOR__), + "__or__" => Some(&__OR__), + "__ror__" => Some(&__ROR__), + "__pow__" => Some(&__POW__), + "__rpow__" => Some(&__RPOW__), + _ => None, + } +} + +fn extract_proto_arguments( + cls: &syn::Type, + py: &syn::Ident, + method_args: &[FnArg], + proto_args: &[Ty], +) -> Result<(Vec, TokenStream)> { + let mut arg_idents = Vec::with_capacity(method_args.len()); + let mut non_python_args = 0; + + let mut args_conversions = Vec::with_capacity(proto_args.len()); + + for arg in method_args { + if arg.py { + arg_idents.push(py.clone()); + } else { + let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site()); + let conversions = proto_args.get(non_python_args) + .ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))? + .extract(cls, py, &ident, arg.ty); + non_python_args += 1; + args_conversions.push(conversions); + arg_idents.push(ident); + } + } + + let conversions = quote!(#(#args_conversions)*); + Ok((arg_idents, conversions)) +} + +struct StaticIdent(&'static str); + +impl ToTokens for StaticIdent { + fn to_tokens(&self, tokens: &mut TokenStream) { + syn::Ident::new(self.0, Span::call_site()).to_tokens(tokens) + } +} + +struct TokenGenerator(fn() -> TokenStream); + +impl ToTokens for TokenGenerator { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.0().to_tokens(tokens) + } +} diff --git a/src/class/basic.rs b/src/class/basic.rs index affa7ee6..1e34d493 100644 --- a/src/class/basic.rs +++ b/src/class/basic.rs @@ -13,7 +13,7 @@ use crate::{exceptions, ffi, FromPyObject, PyAny, PyCell, PyClass, PyObject}; use std::os::raw::c_int; /// Operators for the `__richcmp__` method -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum CompareOp { /// The *less than* operator. Lt = ffi::Py_LT as isize, @@ -29,6 +29,20 @@ pub enum CompareOp { Ge = ffi::Py_GE as isize, } +impl CompareOp { + pub fn from_raw(op: c_int) -> Option { + match op { + ffi::Py_LT => Some(CompareOp::Lt), + ffi::Py_LE => Some(CompareOp::Le), + ffi::Py_EQ => Some(CompareOp::Eq), + ffi::Py_NE => Some(CompareOp::Ne), + ffi::Py_GT => Some(CompareOp::Gt), + ffi::Py_GE => Some(CompareOp::Ge), + _ => None, + } + } +} + /// Basic Python class customization #[allow(unused_variables)] pub trait PyObjectProtocol<'p>: PyClass { diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 86bf8496..dc38c038 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -1,14 +1,15 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::{ + exceptions::{PyAttributeError, PyNotImplementedError}, ffi, impl_::freelist::FreeList, pycell::PyCellLayout, pyclass_init::PyObjectInit, type_object::{PyLayout, PyTypeObject}, - PyClass, PyMethodDefType, PyNativeType, PyTypeInfo, Python, + PyClass, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python, }; -use std::{marker::PhantomData, os::raw::c_void, thread}; +use std::{marker::PhantomData, os::raw::c_void, ptr::NonNull, thread}; /// This type is used as a "dummy" type on which dtolnay specializations are /// applied to apply implementations from `#[pymethods]` & `#[pyproto]` @@ -107,6 +108,377 @@ impl PyClassCallImpl for &'_ PyClassImplCollector { } } +macro_rules! slot_fragment_trait { + ($trait_name:ident, $($default_method:tt)*) => { + #[allow(non_camel_case_types)] + pub trait $trait_name: Sized { + $($default_method)* + } + + impl $trait_name for &'_ PyClassImplCollector {} + } +} + +/// Macro which expands to three items +/// - Trait for a __setitem__ dunder +/// - Trait for the corresponding __delitem__ dunder +/// - A macro which will use dtolnay specialisation to generate the shared slot for the two dunders +macro_rules! define_pyclass_setattr_slot { + ( + $set_trait:ident, + $del_trait:ident, + $set:ident, + $del:ident, + $set_error:expr, + $del_error:expr, + $generate_macro:ident, + $slot:ident, + $func_ty:ident, + ) => { + slot_fragment_trait! { + $set_trait, + + /// # Safety: _slf and _attr must be valid non-null Python objects + #[inline] + unsafe fn $set( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + _value: NonNull, + ) -> PyResult<()> { + $set_error + } + } + + slot_fragment_trait! { + $del_trait, + + /// # Safety: _slf and _attr must be valid non-null Python objects + #[inline] + unsafe fn $del( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + ) -> PyResult<()> { + $del_error + } + } + + #[doc(hidden)] + #[macro_export] + macro_rules! $generate_macro { + ($cls:ty) => {{ + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + use ::std::option::Option::*; + use $crate::callback::IntoPyCallbackOutput; + use $crate::class::impl_::*; + $crate::callback::handle_panic(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + if let Some(value) = ::std::ptr::NonNull::new(value) { + collector.$set(py, _slf, attr, value).convert(py) + } else { + collector.$del(py, _slf, attr).convert(py) + } + }) + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::$slot, + pfunc: __wrap as $crate::ffi::$func_ty as _, + } + }}; + } + }; +} + +define_pyclass_setattr_slot! { + PyClass__setattr__SlotFragment, + PyClass__delattr__SlotFragment, + __setattr__, + __delattr__, + Err(PyAttributeError::new_err("can't set attribute")), + Err(PyAttributeError::new_err("can't delete attribute")), + generate_pyclass_setattr_slot, + Py_tp_setattro, + setattrofunc, +} + +define_pyclass_setattr_slot! { + PyClass__set__SlotFragment, + PyClass__delete__SlotFragment, + __set__, + __delete__, + Err(PyNotImplementedError::new_err("can't set descriptor")), + Err(PyNotImplementedError::new_err("can't delete descriptor")), + generate_pyclass_setdescr_slot, + Py_tp_descr_set, + descrsetfunc, +} + +define_pyclass_setattr_slot! { + PyClass__setitem__SlotFragment, + PyClass__delitem__SlotFragment, + __setitem__, + __delitem__, + Err(PyNotImplementedError::new_err("can't set item")), + Err(PyNotImplementedError::new_err("can't delete item")), + generate_pyclass_setitem_slot, + Py_mp_ass_subscript, + objobjargproc, +} + +/// Macro which expands to three items +/// - Trait for a lhs dunder e.g. __add__ +/// - Trait for the corresponding rhs e.g. __radd__ +/// - A macro which will use dtolnay specialisation to generate the shared slot for the two dunders +macro_rules! define_pyclass_binary_operator_slot { + ( + $lhs_trait:ident, + $rhs_trait:ident, + $lhs:ident, + $rhs:ident, + $generate_macro:ident, + $slot:ident, + $func_ty:ident, + ) => { + slot_fragment_trait! { + $lhs_trait, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn $lhs( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } + } + + slot_fragment_trait! { + $rhs_trait, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn $rhs( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } + } + + #[doc(hidden)] + #[macro_export] + macro_rules! $generate_macro { + ($cls:ty) => {{ + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + _other: *mut $crate::ffi::PyObject, + ) -> *mut $crate::ffi::PyObject { + $crate::callback::handle_panic(|py| { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let lhs_result = collector.$lhs(py, _slf, _other)?; + if lhs_result == $crate::ffi::Py_NotImplemented() { + $crate::ffi::Py_DECREF(lhs_result); + collector.$rhs(py, _other, _slf) + } else { + ::std::result::Result::Ok(lhs_result) + } + }) + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::$slot, + pfunc: __wrap as $crate::ffi::$func_ty as _, + } + }}; + } + }; +} + +define_pyclass_binary_operator_slot! { + PyClass__add__SlotFragment, + PyClass__radd__SlotFragment, + __add__, + __radd__, + generate_pyclass_add_slot, + Py_nb_add, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__sub__SlotFragment, + PyClass__rsub__SlotFragment, + __sub__, + __rsub__, + generate_pyclass_sub_slot, + Py_nb_subtract, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__mul__SlotFragment, + PyClass__rmul__SlotFragment, + __mul__, + __rmul__, + generate_pyclass_mul_slot, + Py_nb_multiply, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__mod__SlotFragment, + PyClass__rmod__SlotFragment, + __mod__, + __rmod__, + generate_pyclass_mod_slot, + Py_nb_remainder, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__divmod__SlotFragment, + PyClass__rdivmod__SlotFragment, + __divmod__, + __rdivmod__, + generate_pyclass_divmod_slot, + Py_nb_divmod, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__lshift__SlotFragment, + PyClass__rlshift__SlotFragment, + __lshift__, + __rlshift__, + generate_pyclass_lshift_slot, + Py_nb_lshift, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__rshift__SlotFragment, + PyClass__rrshift__SlotFragment, + __rshift__, + __rrshift__, + generate_pyclass_rshift_slot, + Py_nb_rshift, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__and__SlotFragment, + PyClass__rand__SlotFragment, + __and__, + __rand__, + generate_pyclass_and_slot, + Py_nb_and, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__or__SlotFragment, + PyClass__ror__SlotFragment, + __or__, + __ror__, + generate_pyclass_or_slot, + Py_nb_or, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__xor__SlotFragment, + PyClass__rxor__SlotFragment, + __xor__, + __rxor__, + generate_pyclass_xor_slot, + Py_nb_xor, + binaryfunc, +} + +define_pyclass_binary_operator_slot! { + PyClass__matmul__SlotFragment, + PyClass__rmatmul__SlotFragment, + __matmul__, + __rmatmul__, + generate_pyclass_matmul_slot, + Py_nb_matrix_multiply, + binaryfunc, +} + +slot_fragment_trait! { + PyClass__pow__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __pow__( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + _mod: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } +} + +slot_fragment_trait! { + PyClass__rpow__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __rpow__( + self, + _py: Python, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + _mod: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + ffi::Py_INCREF(ffi::Py_NotImplemented()); + Ok(ffi::Py_NotImplemented()) + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_pow_slot { + ($cls:ty) => {{ + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + _other: *mut $crate::ffi::PyObject, + _mod: *mut $crate::ffi::PyObject, + ) -> *mut $crate::ffi::PyObject { + $crate::callback::handle_panic(|py| { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let lhs_result = collector.__pow__(py, _slf, _other, _mod)?; + if lhs_result == $crate::ffi::Py_NotImplemented() { + $crate::ffi::Py_DECREF(lhs_result); + collector.__rpow__(py, _other, _slf, _mod) + } else { + ::std::result::Result::Ok(lhs_result) + } + }) + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_nb_power, + pfunc: __wrap as $crate::ffi::ternaryfunc as _, + } + }}; +} + pub trait PyClassAllocImpl { fn alloc_impl(self) -> Option; } @@ -288,6 +660,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots); slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots); slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots); +#[cfg(not(feature = "multiple-pymethods"))] +slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots); + methods_trait!(PyObjectProtocolMethods, object_protocol_methods); methods_trait!(PyAsyncProtocolMethods, async_protocol_methods); methods_trait!(PyContextProtocolMethods, context_protocol_methods); diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index dc64155e..f648ee2e 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -1,7 +1,6 @@ -#![allow(deprecated)] // for deprecated protocol methods +#![cfg(not(feature = "multiple-pymethods"))] use pyo3::class::basic::CompareOp; -use pyo3::class::*; use pyo3::prelude::*; use pyo3::py_run; @@ -12,21 +11,17 @@ struct UnaryArithmetic { inner: f64, } +#[pymethods] impl UnaryArithmetic { + #[new] fn new(value: f64) -> Self { UnaryArithmetic { inner: value } } -} -#[pyproto] -impl PyObjectProtocol for UnaryArithmetic { fn __repr__(&self) -> String { format!("UA({})", self.inner) } -} -#[pyproto] -impl PyNumberProtocol for UnaryArithmetic { fn __neg__(&self) -> Self { Self::new(-self.inner) } @@ -57,30 +52,17 @@ fn unary_arithmetic() { py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'"); } -#[pyclass] -struct BinaryArithmetic {} - -#[pyproto] -impl PyObjectProtocol for BinaryArithmetic { - fn __repr__(&self) -> &'static str { - "BA" - } -} - #[pyclass] struct InPlaceOperations { value: u32, } -#[pyproto] -impl PyObjectProtocol for InPlaceOperations { +#[pymethods] +impl InPlaceOperations { fn __repr__(&self) -> String { format!("IPO({:?})", self.value) } -} -#[pyproto] -impl PyNumberProtocol for InPlaceOperations { fn __iadd__(&mut self, other: u32) { self.value += other; } @@ -142,42 +124,49 @@ fn inplace_operations() { ); } -#[pyproto] -impl PyNumberProtocol for BinaryArithmetic { - fn __add__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} + {:?}", lhs, rhs) +#[pyclass] +struct BinaryArithmetic {} + +#[pymethods] +impl BinaryArithmetic { + fn __repr__(&self) -> &'static str { + "BA" } - fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} - {:?}", lhs, rhs) + fn __add__(&self, rhs: &PyAny) -> String { + format!("BA + {:?}", rhs) } - fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} * {:?}", lhs, rhs) + fn __sub__(&self, rhs: &PyAny) -> String { + format!("BA - {:?}", rhs) } - fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} << {:?}", lhs, rhs) + fn __mul__(&self, rhs: &PyAny) -> String { + format!("BA * {:?}", rhs) } - fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} >> {:?}", lhs, rhs) + fn __lshift__(&self, rhs: &PyAny) -> String { + format!("BA << {:?}", rhs) } - fn __and__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} & {:?}", lhs, rhs) + fn __rshift__(&self, rhs: &PyAny) -> String { + format!("BA >> {:?}", rhs) } - fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} ^ {:?}", lhs, rhs) + fn __and__(&self, rhs: &PyAny) -> String { + format!("BA & {:?}", rhs) } - fn __or__(lhs: &PyAny, rhs: &PyAny) -> String { - format!("{:?} | {:?}", lhs, rhs) + fn __xor__(&self, rhs: &PyAny) -> String { + format!("BA ^ {:?}", rhs) } - fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option) -> String { - format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_) + fn __or__(&self, rhs: &PyAny) -> String { + format!("BA | {:?}", rhs) + } + + fn __pow__(&self, rhs: &PyAny, mod_: Option) -> String { + format!("BA ** {:?} (mod: {:?})", rhs, mod_) } } @@ -190,24 +179,27 @@ fn binary_arithmetic() { py_run!(py, c, "assert c + c == 'BA + BA'"); py_run!(py, c, "assert c.__add__(c) == 'BA + BA'"); py_run!(py, c, "assert c + 1 == 'BA + 1'"); - py_run!(py, c, "assert 1 + c == '1 + BA'"); py_run!(py, c, "assert c - 1 == 'BA - 1'"); - py_run!(py, c, "assert 1 - c == '1 - BA'"); py_run!(py, c, "assert c * 1 == 'BA * 1'"); - py_run!(py, c, "assert 1 * c == '1 * BA'"); - py_run!(py, c, "assert c << 1 == 'BA << 1'"); - py_run!(py, c, "assert 1 << c == '1 << BA'"); py_run!(py, c, "assert c >> 1 == 'BA >> 1'"); - py_run!(py, c, "assert 1 >> c == '1 >> BA'"); py_run!(py, c, "assert c & 1 == 'BA & 1'"); - py_run!(py, c, "assert 1 & c == '1 & BA'"); py_run!(py, c, "assert c ^ 1 == 'BA ^ 1'"); - py_run!(py, c, "assert 1 ^ c == '1 ^ BA'"); py_run!(py, c, "assert c | 1 == 'BA | 1'"); - py_run!(py, c, "assert 1 | c == '1 | BA'"); py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'"); - py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'"); + + // Class with __add__ only should not allow the reverse op; + // this is consistent with Python classes. + + py_expect_exception!(py, c, "1 + c", PyTypeError); + py_expect_exception!(py, c, "1 - c", PyTypeError); + py_expect_exception!(py, c, "1 * c", PyTypeError); + py_expect_exception!(py, c, "1 << c", PyTypeError); + py_expect_exception!(py, c, "1 >> c", PyTypeError); + py_expect_exception!(py, c, "1 & c", PyTypeError); + py_expect_exception!(py, c, "1 ^ c", PyTypeError); + py_expect_exception!(py, c, "1 | c", PyTypeError); + py_expect_exception!(py, c, "1 ** c", PyTypeError); py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'"); } @@ -215,8 +207,8 @@ fn binary_arithmetic() { #[pyclass] struct RhsArithmetic {} -#[pyproto] -impl PyNumberProtocol for RhsArithmetic { +#[pymethods] +impl RhsArithmetic { fn __radd__(&self, other: &PyAny) -> String { format!("{:?} + RA", other) } @@ -249,7 +241,7 @@ impl PyNumberProtocol for RhsArithmetic { format!("{:?} | RA", other) } - fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + fn __rpow__(&self, other: &PyAny, _mod: Option<&PyAny>) -> String { format!("{:?} ** RA", other) } } @@ -289,8 +281,12 @@ impl std::fmt::Debug for LhsAndRhs { } } -#[pyproto] -impl PyNumberProtocol for LhsAndRhs { +#[pymethods] +impl LhsAndRhs { + // fn __repr__(&self) -> &'static str { + // "BA" + // } + fn __add__(lhs: PyRef, rhs: &PyAny) -> String { format!("{:?} + {:?}", lhs, rhs) } @@ -363,7 +359,7 @@ impl PyNumberProtocol for LhsAndRhs { format!("{:?} | RA", other) } - fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + fn __rpow__(&self, other: &PyAny, _mod: Option<&PyAny>) -> String { format!("{:?} ** RA", other) } @@ -372,13 +368,6 @@ impl PyNumberProtocol for LhsAndRhs { } } -#[pyproto] -impl PyObjectProtocol for LhsAndRhs { - fn __repr__(&self) -> &'static str { - "BA" - } -} - #[test] fn lhs_fellback_to_rhs() { let gil = Python::acquire_gil(); @@ -412,8 +401,8 @@ fn lhs_fellback_to_rhs() { #[pyclass] struct RichComparisons {} -#[pyproto] -impl PyObjectProtocol for RichComparisons { +#[pymethods] +impl RichComparisons { fn __repr__(&self) -> &'static str { "RC" } @@ -433,8 +422,8 @@ impl PyObjectProtocol for RichComparisons { #[pyclass] struct RichComparisons2 {} -#[pyproto] -impl PyObjectProtocol for RichComparisons2 { +#[pymethods] +impl RichComparisons2 { fn __repr__(&self) -> &'static str { "RC2" } @@ -508,76 +497,73 @@ mod return_not_implemented { #[pyclass] struct RichComparisonToSelf {} - #[pyproto] - impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf { + #[pymethods] + impl RichComparisonToSelf { fn __repr__(&self) -> &'static str { "RC_Self" } - fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject { + fn __richcmp__(&self, other: PyRef, _op: CompareOp) -> PyObject { other.py().None() } - } - #[pyproto] - impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf { - fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __add__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __sub__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __mul__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __matmul__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __truediv__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __floordiv__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __mod__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option) -> &'p PyAny { - lhs + fn __pow__(slf: PyRef, _other: u8, _modulo: Option) -> PyRef { + slf } - fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __lshift__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __rshift__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __divmod__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __and__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __or__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } - fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { - lhs + fn __xor__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> { + slf } // Inplace assignments - fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {} - fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __iadd__(&mut self, _other: PyRef) {} + fn __isub__(&mut self, _other: PyRef) {} + fn __imul__(&mut self, _other: PyRef) {} + fn __imatmul__(&mut self, _other: PyRef) {} + fn __itruediv__(&mut self, _other: PyRef) {} + fn __ifloordiv__(&mut self, _other: PyRef) {} + fn __imod__(&mut self, _other: PyRef) {} + fn __ipow__(&mut self, _other: PyRef) {} + fn __ilshift__(&mut self, _other: PyRef) {} + fn __irshift__(&mut self, _other: PyRef) {} + fn __iand__(&mut self, _other: PyRef) {} + fn __ior__(&mut self, _other: PyRef) {} + fn __ixor__(&mut self, _other: PyRef) {} } fn _test_binary_dunder(dunder: &str) { @@ -648,15 +634,13 @@ mod return_not_implemented { } #[test] - #[ignore] fn reverse_arith() { _test_binary_dunder("radd"); _test_binary_dunder("rsub"); _test_binary_dunder("rmul"); _test_binary_dunder("rmatmul"); - _test_binary_dunder("rtruediv"); - _test_binary_dunder("rfloordiv"); _test_binary_dunder("rmod"); + _test_binary_dunder("rdivmod"); _test_binary_dunder("rpow"); } diff --git a/tests/test_arithmetics_protos.rs b/tests/test_arithmetics_protos.rs new file mode 100644 index 00000000..dc64155e --- /dev/null +++ b/tests/test_arithmetics_protos.rs @@ -0,0 +1,683 @@ +#![allow(deprecated)] // for deprecated protocol methods + +use pyo3::class::basic::CompareOp; +use pyo3::class::*; +use pyo3::prelude::*; +use pyo3::py_run; + +mod common; + +#[pyclass] +struct UnaryArithmetic { + inner: f64, +} + +impl UnaryArithmetic { + fn new(value: f64) -> Self { + UnaryArithmetic { inner: value } + } +} + +#[pyproto] +impl PyObjectProtocol for UnaryArithmetic { + fn __repr__(&self) -> String { + format!("UA({})", self.inner) + } +} + +#[pyproto] +impl PyNumberProtocol for UnaryArithmetic { + fn __neg__(&self) -> Self { + Self::new(-self.inner) + } + + fn __pos__(&self) -> Self { + Self::new(self.inner) + } + + fn __abs__(&self) -> Self { + Self::new(self.inner.abs()) + } + + fn __round__(&self, _ndigits: Option) -> Self { + Self::new(self.inner.round()) + } +} + +#[test] +fn unary_arithmetic() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, UnaryArithmetic::new(2.7)).unwrap(); + py_run!(py, c, "assert repr(-c) == 'UA(-2.7)'"); + py_run!(py, c, "assert repr(+c) == 'UA(2.7)'"); + py_run!(py, c, "assert repr(abs(c)) == 'UA(2.7)'"); + py_run!(py, c, "assert repr(round(c)) == 'UA(3)'"); + py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'"); +} + +#[pyclass] +struct BinaryArithmetic {} + +#[pyproto] +impl PyObjectProtocol for BinaryArithmetic { + fn __repr__(&self) -> &'static str { + "BA" + } +} + +#[pyclass] +struct InPlaceOperations { + value: u32, +} + +#[pyproto] +impl PyObjectProtocol for InPlaceOperations { + fn __repr__(&self) -> String { + format!("IPO({:?})", self.value) + } +} + +#[pyproto] +impl PyNumberProtocol for InPlaceOperations { + fn __iadd__(&mut self, other: u32) { + self.value += other; + } + + fn __isub__(&mut self, other: u32) { + self.value -= other; + } + + fn __imul__(&mut self, other: u32) { + self.value *= other; + } + + fn __ilshift__(&mut self, other: u32) { + self.value <<= other; + } + + fn __irshift__(&mut self, other: u32) { + self.value >>= other; + } + + fn __iand__(&mut self, other: u32) { + self.value &= other; + } + + fn __ixor__(&mut self, other: u32) { + self.value ^= other; + } + + fn __ior__(&mut self, other: u32) { + self.value |= other; + } + + fn __ipow__(&mut self, other: u32) { + self.value = self.value.pow(other); + } +} + +#[test] +fn inplace_operations() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let init = |value, code| { + let c = PyCell::new(py, InPlaceOperations { value }).unwrap(); + py_run!(py, c, code); + }; + + init(0, "d = c; c += 1; assert repr(c) == repr(d) == 'IPO(1)'"); + init(10, "d = c; c -= 1; assert repr(c) == repr(d) == 'IPO(9)'"); + init(3, "d = c; c *= 3; assert repr(c) == repr(d) == 'IPO(9)'"); + init(3, "d = c; c <<= 2; assert repr(c) == repr(d) == 'IPO(12)'"); + init(12, "d = c; c >>= 2; assert repr(c) == repr(d) == 'IPO(3)'"); + init(12, "d = c; c &= 10; assert repr(c) == repr(d) == 'IPO(8)'"); + init(12, "d = c; c |= 3; assert repr(c) == repr(d) == 'IPO(15)'"); + init(12, "d = c; c ^= 5; assert repr(c) == repr(d) == 'IPO(9)'"); + init(3, "d = c; c **= 4; assert repr(c) == repr(d) == 'IPO(81)'"); + init( + 3, + "d = c; c.__ipow__(4); assert repr(c) == repr(d) == 'IPO(81)'", + ); +} + +#[pyproto] +impl PyNumberProtocol for BinaryArithmetic { + fn __add__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} + {:?}", lhs, rhs) + } + + fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} - {:?}", lhs, rhs) + } + + fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} * {:?}", lhs, rhs) + } + + fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} << {:?}", lhs, rhs) + } + + fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} >> {:?}", lhs, rhs) + } + + fn __and__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} & {:?}", lhs, rhs) + } + + fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} ^ {:?}", lhs, rhs) + } + + fn __or__(lhs: &PyAny, rhs: &PyAny) -> String { + format!("{:?} | {:?}", lhs, rhs) + } + + fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option) -> String { + format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_) + } +} + +#[test] +fn binary_arithmetic() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, BinaryArithmetic {}).unwrap(); + py_run!(py, c, "assert c + c == 'BA + BA'"); + py_run!(py, c, "assert c.__add__(c) == 'BA + BA'"); + py_run!(py, c, "assert c + 1 == 'BA + 1'"); + py_run!(py, c, "assert 1 + c == '1 + BA'"); + py_run!(py, c, "assert c - 1 == 'BA - 1'"); + py_run!(py, c, "assert 1 - c == '1 - BA'"); + py_run!(py, c, "assert c * 1 == 'BA * 1'"); + py_run!(py, c, "assert 1 * c == '1 * BA'"); + + py_run!(py, c, "assert c << 1 == 'BA << 1'"); + py_run!(py, c, "assert 1 << c == '1 << BA'"); + py_run!(py, c, "assert c >> 1 == 'BA >> 1'"); + py_run!(py, c, "assert 1 >> c == '1 >> BA'"); + py_run!(py, c, "assert c & 1 == 'BA & 1'"); + py_run!(py, c, "assert 1 & c == '1 & BA'"); + py_run!(py, c, "assert c ^ 1 == 'BA ^ 1'"); + py_run!(py, c, "assert 1 ^ c == '1 ^ BA'"); + py_run!(py, c, "assert c | 1 == 'BA | 1'"); + py_run!(py, c, "assert 1 | c == '1 | BA'"); + py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'"); + py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'"); + + py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'"); +} + +#[pyclass] +struct RhsArithmetic {} + +#[pyproto] +impl PyNumberProtocol for RhsArithmetic { + fn __radd__(&self, other: &PyAny) -> String { + format!("{:?} + RA", other) + } + + fn __rsub__(&self, other: &PyAny) -> String { + format!("{:?} - RA", other) + } + + fn __rmul__(&self, other: &PyAny) -> String { + format!("{:?} * RA", other) + } + + fn __rlshift__(&self, other: &PyAny) -> String { + format!("{:?} << RA", other) + } + + fn __rrshift__(&self, other: &PyAny) -> String { + format!("{:?} >> RA", other) + } + + fn __rand__(&self, other: &PyAny) -> String { + format!("{:?} & RA", other) + } + + fn __rxor__(&self, other: &PyAny) -> String { + format!("{:?} ^ RA", other) + } + + fn __ror__(&self, other: &PyAny) -> String { + format!("{:?} | RA", other) + } + + fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + format!("{:?} ** RA", other) + } +} + +#[test] +fn rhs_arithmetic() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, RhsArithmetic {}).unwrap(); + py_run!(py, c, "assert c.__radd__(1) == '1 + RA'"); + py_run!(py, c, "assert 1 + c == '1 + RA'"); + py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'"); + py_run!(py, c, "assert 1 - c == '1 - RA'"); + py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'"); + py_run!(py, c, "assert 1 * c == '1 * RA'"); + py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'"); + py_run!(py, c, "assert 1 << c == '1 << RA'"); + py_run!(py, c, "assert c.__rrshift__(1) == '1 >> RA'"); + py_run!(py, c, "assert 1 >> c == '1 >> RA'"); + py_run!(py, c, "assert c.__rand__(1) == '1 & RA'"); + py_run!(py, c, "assert 1 & c == '1 & RA'"); + py_run!(py, c, "assert c.__rxor__(1) == '1 ^ RA'"); + py_run!(py, c, "assert 1 ^ c == '1 ^ RA'"); + py_run!(py, c, "assert c.__ror__(1) == '1 | RA'"); + py_run!(py, c, "assert 1 | c == '1 | RA'"); + py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'"); + py_run!(py, c, "assert 1 ** c == '1 ** RA'"); +} + +#[pyclass] +struct LhsAndRhs {} + +impl std::fmt::Debug for LhsAndRhs { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "LR") + } +} + +#[pyproto] +impl PyNumberProtocol for LhsAndRhs { + fn __add__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} + {:?}", lhs, rhs) + } + + fn __sub__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} - {:?}", lhs, rhs) + } + + fn __mul__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} * {:?}", lhs, rhs) + } + + fn __lshift__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} << {:?}", lhs, rhs) + } + + fn __rshift__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} >> {:?}", lhs, rhs) + } + + fn __and__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} & {:?}", lhs, rhs) + } + + fn __xor__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} ^ {:?}", lhs, rhs) + } + + fn __or__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} | {:?}", lhs, rhs) + } + + fn __pow__(lhs: PyRef, rhs: &PyAny, _mod: Option) -> String { + format!("{:?} ** {:?}", lhs, rhs) + } + + fn __matmul__(lhs: PyRef, rhs: &PyAny) -> String { + format!("{:?} @ {:?}", lhs, rhs) + } + + fn __radd__(&self, other: &PyAny) -> String { + format!("{:?} + RA", other) + } + + fn __rsub__(&self, other: &PyAny) -> String { + format!("{:?} - RA", other) + } + + fn __rmul__(&self, other: &PyAny) -> String { + format!("{:?} * RA", other) + } + + fn __rlshift__(&self, other: &PyAny) -> String { + format!("{:?} << RA", other) + } + + fn __rrshift__(&self, other: &PyAny) -> String { + format!("{:?} >> RA", other) + } + + fn __rand__(&self, other: &PyAny) -> String { + format!("{:?} & RA", other) + } + + fn __rxor__(&self, other: &PyAny) -> String { + format!("{:?} ^ RA", other) + } + + fn __ror__(&self, other: &PyAny) -> String { + format!("{:?} | RA", other) + } + + fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String { + format!("{:?} ** RA", other) + } + + fn __rmatmul__(&self, other: &PyAny) -> String { + format!("{:?} @ RA", other) + } +} + +#[pyproto] +impl PyObjectProtocol for LhsAndRhs { + fn __repr__(&self) -> &'static str { + "BA" + } +} + +#[test] +fn lhs_fellback_to_rhs() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, LhsAndRhs {}).unwrap(); + // If the light hand value is `LhsAndRhs`, LHS is used. + py_run!(py, c, "assert c + 1 == 'LR + 1'"); + py_run!(py, c, "assert c - 1 == 'LR - 1'"); + py_run!(py, c, "assert c * 1 == 'LR * 1'"); + py_run!(py, c, "assert c << 1 == 'LR << 1'"); + py_run!(py, c, "assert c >> 1 == 'LR >> 1'"); + py_run!(py, c, "assert c & 1 == 'LR & 1'"); + py_run!(py, c, "assert c ^ 1 == 'LR ^ 1'"); + py_run!(py, c, "assert c | 1 == 'LR | 1'"); + py_run!(py, c, "assert c ** 1 == 'LR ** 1'"); + py_run!(py, c, "assert c @ 1 == 'LR @ 1'"); + // Fellback to RHS because of type mismatching + py_run!(py, c, "assert 1 + c == '1 + RA'"); + py_run!(py, c, "assert 1 - c == '1 - RA'"); + py_run!(py, c, "assert 1 * c == '1 * RA'"); + py_run!(py, c, "assert 1 << c == '1 << RA'"); + py_run!(py, c, "assert 1 >> c == '1 >> RA'"); + py_run!(py, c, "assert 1 & c == '1 & RA'"); + py_run!(py, c, "assert 1 ^ c == '1 ^ RA'"); + py_run!(py, c, "assert 1 | c == '1 | RA'"); + py_run!(py, c, "assert 1 ** c == '1 ** RA'"); + py_run!(py, c, "assert 1 @ c == '1 @ RA'"); +} + +#[pyclass] +struct RichComparisons {} + +#[pyproto] +impl PyObjectProtocol for RichComparisons { + fn __repr__(&self) -> &'static str { + "RC" + } + + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> String { + match op { + CompareOp::Lt => format!("{} < {:?}", self.__repr__(), other), + CompareOp::Le => format!("{} <= {:?}", self.__repr__(), other), + CompareOp::Eq => format!("{} == {:?}", self.__repr__(), other), + CompareOp::Ne => format!("{} != {:?}", self.__repr__(), other), + CompareOp::Gt => format!("{} > {:?}", self.__repr__(), other), + CompareOp::Ge => format!("{} >= {:?}", self.__repr__(), other), + } + } +} + +#[pyclass] +struct RichComparisons2 {} + +#[pyproto] +impl PyObjectProtocol for RichComparisons2 { + fn __repr__(&self) -> &'static str { + "RC2" + } + + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyObject { + match op { + CompareOp::Eq => true.into_py(other.py()), + CompareOp::Ne => false.into_py(other.py()), + _ => other.py().NotImplemented(), + } + } +} + +#[test] +fn rich_comparisons() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, RichComparisons {}).unwrap(); + py_run!(py, c, "assert (c < c) == 'RC < RC'"); + py_run!(py, c, "assert (c < 1) == 'RC < 1'"); + py_run!(py, c, "assert (1 < c) == 'RC > 1'"); + py_run!(py, c, "assert (c <= c) == 'RC <= RC'"); + py_run!(py, c, "assert (c <= 1) == 'RC <= 1'"); + py_run!(py, c, "assert (1 <= c) == 'RC >= 1'"); + py_run!(py, c, "assert (c == c) == 'RC == RC'"); + py_run!(py, c, "assert (c == 1) == 'RC == 1'"); + py_run!(py, c, "assert (1 == c) == 'RC == 1'"); + py_run!(py, c, "assert (c != c) == 'RC != RC'"); + py_run!(py, c, "assert (c != 1) == 'RC != 1'"); + py_run!(py, c, "assert (1 != c) == 'RC != 1'"); + py_run!(py, c, "assert (c > c) == 'RC > RC'"); + py_run!(py, c, "assert (c > 1) == 'RC > 1'"); + py_run!(py, c, "assert (1 > c) == 'RC < 1'"); + py_run!(py, c, "assert (c >= c) == 'RC >= RC'"); + py_run!(py, c, "assert (c >= 1) == 'RC >= 1'"); + py_run!(py, c, "assert (1 >= c) == 'RC <= 1'"); +} + +#[test] +fn rich_comparisons_python_3_type_error() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c2 = PyCell::new(py, RichComparisons2 {}).unwrap(); + py_expect_exception!(py, c2, "c2 < c2", PyTypeError); + py_expect_exception!(py, c2, "c2 < 1", PyTypeError); + py_expect_exception!(py, c2, "1 < c2", PyTypeError); + py_expect_exception!(py, c2, "c2 <= c2", PyTypeError); + py_expect_exception!(py, c2, "c2 <= 1", PyTypeError); + py_expect_exception!(py, c2, "1 <= c2", PyTypeError); + py_run!(py, c2, "assert (c2 == c2) == True"); + py_run!(py, c2, "assert (c2 == 1) == True"); + py_run!(py, c2, "assert (1 == c2) == True"); + py_run!(py, c2, "assert (c2 != c2) == False"); + py_run!(py, c2, "assert (c2 != 1) == False"); + py_run!(py, c2, "assert (1 != c2) == False"); + py_expect_exception!(py, c2, "c2 > c2", PyTypeError); + py_expect_exception!(py, c2, "c2 > 1", PyTypeError); + py_expect_exception!(py, c2, "1 > c2", PyTypeError); + py_expect_exception!(py, c2, "c2 >= c2", PyTypeError); + py_expect_exception!(py, c2, "c2 >= 1", PyTypeError); + py_expect_exception!(py, c2, "1 >= c2", PyTypeError); +} + +// Checks that binary operations for which the arguments don't match the +// required type, return NotImplemented. +mod return_not_implemented { + use super::*; + + #[pyclass] + struct RichComparisonToSelf {} + + #[pyproto] + impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf { + fn __repr__(&self) -> &'static str { + "RC_Self" + } + + fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject { + other.py().None() + } + } + + #[pyproto] + impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf { + fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option) -> &'p PyAny { + lhs + } + fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny { + lhs + } + + // Inplace assignments + fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {} + fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {} + } + + fn _test_binary_dunder(dunder: &str) { + let gil = Python::acquire_gil(); + let py = gil.python(); + let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap(); + py_run!( + py, + c2, + &format!( + "class Other: pass\nassert c2.__{}__(Other()) is NotImplemented", + dunder + ) + ); + } + + fn _test_binary_operator(operator: &str, dunder: &str) { + _test_binary_dunder(dunder); + + let gil = Python::acquire_gil(); + let py = gil.python(); + let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap(); + py_expect_exception!( + py, + c2, + &format!("class Other: pass\nc2 {} Other()", operator), + PyTypeError + ); + } + + fn _test_inplace_binary_operator(operator: &str, dunder: &str) { + _test_binary_operator(operator, dunder); + } + + #[test] + fn equality() { + _test_binary_dunder("eq"); + _test_binary_dunder("ne"); + } + + #[test] + fn ordering() { + _test_binary_operator("<", "lt"); + _test_binary_operator("<=", "le"); + _test_binary_operator(">", "gt"); + _test_binary_operator(">=", "ge"); + } + + #[test] + fn bitwise() { + _test_binary_operator("&", "and"); + _test_binary_operator("|", "or"); + _test_binary_operator("^", "xor"); + _test_binary_operator("<<", "lshift"); + _test_binary_operator(">>", "rshift"); + } + + #[test] + fn arith() { + _test_binary_operator("+", "add"); + _test_binary_operator("-", "sub"); + _test_binary_operator("*", "mul"); + _test_binary_operator("@", "matmul"); + _test_binary_operator("/", "truediv"); + _test_binary_operator("//", "floordiv"); + _test_binary_operator("%", "mod"); + _test_binary_operator("**", "pow"); + } + + #[test] + #[ignore] + fn reverse_arith() { + _test_binary_dunder("radd"); + _test_binary_dunder("rsub"); + _test_binary_dunder("rmul"); + _test_binary_dunder("rmatmul"); + _test_binary_dunder("rtruediv"); + _test_binary_dunder("rfloordiv"); + _test_binary_dunder("rmod"); + _test_binary_dunder("rpow"); + } + + #[test] + fn inplace_bitwise() { + _test_inplace_binary_operator("&=", "iand"); + _test_inplace_binary_operator("|=", "ior"); + _test_inplace_binary_operator("^=", "ixor"); + _test_inplace_binary_operator("<<=", "ilshift"); + _test_inplace_binary_operator(">>=", "irshift"); + } + + #[test] + fn inplace_arith() { + _test_inplace_binary_operator("+=", "iadd"); + _test_inplace_binary_operator("-=", "isub"); + _test_inplace_binary_operator("*=", "imul"); + _test_inplace_binary_operator("@=", "imatmul"); + _test_inplace_binary_operator("/=", "itruediv"); + _test_inplace_binary_operator("//=", "ifloordiv"); + _test_inplace_binary_operator("%=", "imod"); + _test_inplace_binary_operator("**=", "ipow"); + } +} diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs new file mode 100644 index 00000000..1cfafefd --- /dev/null +++ b/tests/test_proto_methods.rs @@ -0,0 +1,532 @@ +#![cfg(not(feature = "multiple-pymethods"))] + +use pyo3::exceptions::PyValueError; +use pyo3::types::{PySlice, PyType}; +use pyo3::{exceptions::PyAttributeError, prelude::*}; +use pyo3::{ffi, py_run, AsPyPointer, PyCell}; +use std::{isize, iter}; + +mod common; + +#[pyclass] +struct ExampleClass { + #[pyo3(get, set)] + value: i32, + _custom_attr: Option, +} + +#[pymethods] +impl ExampleClass { + fn __getattr__(&self, py: Python, attr: &str) -> PyResult { + if attr == "special_custom_attr" { + Ok(self._custom_attr.into_py(py)) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __setattr__(&mut self, attr: &str, value: &PyAny) -> PyResult<()> { + if attr == "special_custom_attr" { + self._custom_attr = Some(value.extract()?); + Ok(()) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __delattr__(&mut self, attr: &str) -> PyResult<()> { + if attr == "special_custom_attr" { + self._custom_attr = None; + Ok(()) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __str__(&self) -> String { + self.value.to_string() + } + + fn __repr__(&self) -> String { + format!("ExampleClass(value={})", self.value) + } + + fn __hash__(&self) -> u64 { + let i64_value: i64 = self.value.into(); + i64_value as u64 + } + + fn __bool__(&self) -> bool { + self.value != 0 + } +} + +fn make_example(py: Python) -> &PyCell { + Py::new( + py, + ExampleClass { + value: 5, + _custom_attr: Some(20), + }, + ) + .unwrap() + .into_ref(py) +} + +#[test] +fn test_getattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py + .getattr("value") + .unwrap() + .extract::() + .unwrap(), + 5, + ); + assert_eq!( + example_py + .getattr("special_custom_attr") + .unwrap() + .extract::() + .unwrap(), + 20, + ); + assert!(example_py + .getattr("other_attr") + .unwrap_err() + .is_instance::(py)); + }) +} + +#[test] +fn test_setattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + example_py.setattr("special_custom_attr", 15).unwrap(); + assert_eq!( + example_py + .getattr("special_custom_attr") + .unwrap() + .extract::() + .unwrap(), + 15, + ); + }) +} + +#[test] +fn test_delattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + example_py.delattr("special_custom_attr").unwrap(); + assert!(example_py.getattr("special_custom_attr").unwrap().is_none()); + }) +} + +#[test] +fn test_str() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!(example_py.str().unwrap().to_str().unwrap(), "5"); + }) +} + +#[test] +fn test_repr() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py.repr().unwrap().to_str().unwrap(), + "ExampleClass(value=5)" + ); + }) +} + +#[test] +fn test_hash() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!(example_py.hash().unwrap(), 5); + }) +} + +#[test] +fn test_bool() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert!(example_py.is_true().unwrap()); + example_py.borrow_mut().value = 0; + assert!(!example_py.is_true().unwrap()); + }) +} + +#[pyclass] +pub struct Len { + l: usize, +} + +#[pymethods] +impl Len { + fn __len__(&self) -> usize { + self.l + } +} + +#[test] +fn len() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let inst = Py::new(py, Len { l: 10 }).unwrap(); + py_assert!(py, inst, "len(inst) == 10"); + unsafe { + assert_eq!(ffi::PyObject_Size(inst.as_ptr()), 10); + assert_eq!(ffi::PyMapping_Size(inst.as_ptr()), 10); + } + + let inst = Py::new( + py, + Len { + l: (isize::MAX as usize) + 1, + }, + ) + .unwrap(); + py_expect_exception!(py, inst, "len(inst)", PyOverflowError); +} + +#[pyclass] +struct Iterator { + iter: Box + Send>, +} + +#[pymethods] +impl Iterator { + fn __iter__(slf: PyRef) -> PyRef { + slf + } + + fn __next__(mut slf: PyRefMut) -> Option { + slf.iter.next() + } +} + +#[test] +fn iterator() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let inst = Py::new( + py, + Iterator { + iter: Box::new(5..8), + }, + ) + .unwrap(); + py_assert!(py, inst, "iter(inst) is inst"); + py_assert!(py, inst, "list(inst) == [5, 6, 7]"); +} + +#[pyclass] +struct Callable {} + +#[pymethods] +impl Callable { + #[__call__] + fn __call__(&self, arg: i32) -> i32 { + arg * 6 + } +} + +#[pyclass] +struct EmptyClass; + +#[test] +fn callable() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Callable {}).unwrap(); + py_assert!(py, c, "callable(c)"); + py_assert!(py, c, "c(7) == 42"); + + let nc = Py::new(py, EmptyClass).unwrap(); + py_assert!(py, nc, "not callable(nc)"); +} + +#[pyclass] +#[derive(Debug)] +struct SetItem { + key: i32, + val: i32, +} + +#[pymethods] +impl SetItem { + fn __setitem__(&mut self, key: i32, val: i32) { + self.key = key; + self.val = val; + } +} + +#[test] +fn setitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, SetItem { key: 0, val: 0 }).unwrap(); + py_run!(py, c, "c[1] = 2"); + { + let c = c.borrow(); + assert_eq!(c.key, 1); + assert_eq!(c.val, 2); + } + py_expect_exception!(py, c, "del c[1]", PyNotImplementedError); +} + +#[pyclass] +struct DelItem { + key: i32, +} + +#[pymethods] +impl DelItem { + fn __delitem__(&mut self, key: i32) { + self.key = key; + } +} + +#[test] +fn delitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, DelItem { key: 0 }).unwrap(); + py_run!(py, c, "del c[1]"); + { + let c = c.borrow(); + assert_eq!(c.key, 1); + } + py_expect_exception!(py, c, "c[1] = 2", PyNotImplementedError); +} + +#[pyclass] +struct SetDelItem { + val: Option, +} + +#[pymethods] +impl SetDelItem { + fn __setitem__(&mut self, _key: i32, val: i32) { + self.val = Some(val); + } + + fn __delitem__(&mut self, _key: i32) { + self.val = None; + } +} + +#[test] +fn setdelitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, SetDelItem { val: None }).unwrap(); + py_run!(py, c, "c[1] = 2"); + { + let c = c.borrow(); + assert_eq!(c.val, Some(2)); + } + py_run!(py, c, "del c[1]"); + let c = c.borrow(); + assert_eq!(c.val, None); +} + +#[pyclass] +struct Contains {} + +#[pymethods] +impl Contains { + fn __contains__(&self, item: i32) -> bool { + item >= 0 + } +} + +#[test] +fn contains() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Contains {}).unwrap(); + py_run!(py, c, "assert 1 in c"); + py_run!(py, c, "assert -1 not in c"); + py_expect_exception!(py, c, "assert 'wrong type' not in c", PyTypeError); +} + +#[pyclass] +struct GetItem {} + +#[pymethods] +impl GetItem { + fn __getitem__(&self, idx: &PyAny) -> PyResult<&'static str> { + if let Ok(slice) = idx.cast_as::() { + let indices = slice.indices(1000)?; + if indices.start == 100 && indices.stop == 200 && indices.step == 1 { + return Ok("slice"); + } + } else if let Ok(idx) = idx.extract::() { + if idx == 1 { + return Ok("int"); + } + } + Err(PyValueError::new_err("error")) + } +} + +#[test] +fn test_getitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let ob = Py::new(py, GetItem {}).unwrap(); + + py_assert!(py, ob, "ob[1] == 'int'"); + py_assert!(py, ob, "ob[100:200:1] == 'slice'"); +} + +#[pyclass] +struct ClassWithGetAttr { + #[pyo3(get, set)] + data: u32, +} + +#[pymethods] +impl ClassWithGetAttr { + fn __getattr__(&self, _name: &str) -> u32 { + self.data * 2 + } +} + +#[test] +fn getattr_doesnt_override_member() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let inst = PyCell::new(py, ClassWithGetAttr { data: 4 }).unwrap(); + py_assert!(py, inst, "inst.data == 4"); + py_assert!(py, inst, "inst.a == 8"); +} + +/// Wraps a Python future and yield it once. +#[pyclass] +struct OnceFuture { + future: PyObject, + polled: bool, +} + +#[pymethods] +impl OnceFuture { + #[new] + fn new(future: PyObject) -> Self { + OnceFuture { + future, + polled: false, + } + } + + fn __await__(slf: PyRef) -> PyRef { + slf + } + + fn __iter__(slf: PyRef) -> PyRef { + slf + } + fn __next__(mut slf: PyRefMut) -> Option { + if !slf.polled { + slf.polled = true; + Some(slf.future.clone()) + } else { + None + } + } +} + +#[test] +fn test_await() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let once = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +import asyncio +import sys + +async def main(): + res = await Once(await asyncio.sleep(0.1)) + return res +# For an odd error similar to https://bugs.python.org/issue38563 +if sys.platform == "win32" and sys.version_info >= (3, 8, 0): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +# get_event_loop can raise an error: https://github.com/PyO3/pyo3/pull/961#issuecomment-645238579 +loop = asyncio.new_event_loop() +asyncio.set_event_loop(loop) +assert loop.run_until_complete(main()) is None +loop.close() +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("Once", once).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +} + +/// Increment the count when `__get__` is called. +#[pyclass] +struct DescrCounter { + #[pyo3(get)] + count: usize, +} + +#[pymethods] +impl DescrCounter { + #[new] + fn new() -> Self { + DescrCounter { count: 0 } + } + + fn __get__<'a>( + mut slf: PyRefMut<'a, Self>, + _instance: &PyAny, + _owner: Option<&PyType>, + ) -> PyRefMut<'a, Self> { + slf.count += 1; + slf + } + fn __set__(_slf: PyRef, _instance: &PyAny, mut new_value: PyRefMut) { + new_value.count = _slf.count; + } +} + +#[test] +fn descr_getset() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let counter = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +class Class: + counter = Counter() +c = Class() +c.counter # count += 1 +assert c.counter.count == 2 +c.counter = Counter() +assert c.counter.count == 3 +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("Counter", counter).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +}