From 8408328cb3905a132bddbd26c3ea7c69d8af8bd9 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 9 Sep 2021 09:11:41 +0100 Subject: [PATCH] pymethods: add support for protocol methods --- pyo3-macros-backend/src/pyclass.rs | 9 + pyo3-macros-backend/src/pyimpl.rs | 43 +- pyo3-macros-backend/src/pymethod.rs | 366 ++++++++++++- src/class/basic.rs | 27 +- src/class/impl_.rs | 183 ++++++- tests/test_proto_methods.rs | 765 ++++++++++++++++++++++++++++ 6 files changed, 1376 insertions(+), 17 deletions(-) create mode 100644 tests/test_proto_methods.rs diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 034d1784..cf8b43ba 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -591,6 +591,15 @@ fn impl_class( visitor(collector.sequence_protocol_slots()); visitor(collector.async_protocol_slots()); visitor(collector.buffer_protocol_slots()); + visitor(collector.methods_protocol_slots()); + let mut generated_slots = Vec::new(); + if let ::std::option::Option::Some(setattr) = ::pyo3::generate_pyclass_setattr_slot!(#cls) { + generated_slots.push(setattr); + } + if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setdescr_slot!(#cls) { + generated_slots.push(setdescr); + } + visitor(&generated_slots); } fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> { diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 61d20267..5d2dd881 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -37,8 +37,8 @@ pub fn impl_methods( impls: &mut Vec, methods_type: PyClassMethodsType, ) -> syn::Result { - let mut new_impls = Vec::new(); - let mut call_impls = Vec::new(); + let mut trait_impls = Vec::new(); + let mut proto_impls = Vec::new(); let mut methods = Vec::new(); for iimpl in impls.iter_mut() { match iimpl { @@ -49,13 +49,13 @@ pub fn impl_methods( let attrs = get_cfg_attributes(&meth.attrs); methods.push(quote!(#(#attrs)* #token_stream)); } - GeneratedPyMethod::New(token_stream) => { + GeneratedPyMethod::TraitImpl(token_stream) => { let attrs = get_cfg_attributes(&meth.attrs); - new_impls.push(quote!(#(#attrs)* #token_stream)); + trait_impls.push(quote!(#(#attrs)* #token_stream)); } - GeneratedPyMethod::Call(token_stream) => { + GeneratedPyMethod::Proto(token_stream) => { let attrs = get_cfg_attributes(&meth.attrs); - call_impls.push(quote!(#(#attrs)* #token_stream)); + proto_impls.push(quote!(#(#attrs)* #token_stream)) } } } @@ -80,10 +80,23 @@ pub fn impl_methods( PyClassMethodsType::Inventory => submit_methods_inventory(ty, methods), }; - Ok(quote! { - #(#new_impls)* + let protos_registration = match methods_type { + PyClassMethodsType::Specialization => Some(impl_protos(ty, proto_impls)), + PyClassMethodsType::Inventory => { + if proto_impls.is_empty() { + None + } else { + panic!( + "cannot implement protos in #[pymethods] using `multiple-pymethods` feature" + ); + } + } + }; - #(#call_impls)* + Ok(quote! { + #(#trait_impls)* + + #protos_registration #methods_registration }) @@ -122,6 +135,18 @@ fn impl_py_methods(ty: &syn::Type, methods: Vec) -> TokenStream { } } +fn impl_protos(ty: &syn::Type, proto_impls: Vec) -> TokenStream { + quote! { + impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty> + for ::pyo3::class::impl_::PyClassImplCollector<#ty> + { + fn methods_protocol_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] { + &[#(#proto_impls),*] + } + } + } +} + fn submit_methods_inventory(ty: &syn::Type, methods: Vec) -> TokenStream { if methods.is_empty() { return TokenStream::default(); diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 93f771ab..28de756a 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use crate::attributes::NameAttribute; -use crate::utils::{ensure_not_async_fn, PythonDoc}; +use crate::utils::{ensure_not_async_fn, unwrap_ty_group, PythonDoc}; use crate::{deprecations::Deprecations, utils}; use crate::{ method::{FnArg, FnSpec, FnType, SelfType}, @@ -11,12 +11,13 @@ use crate::{ }; use proc_macro2::{Span, TokenStream}; use quote::quote; +use syn::Ident; use syn::{ext::IdentExt, spanned::Spanned, Result}; pub enum GeneratedPyMethod { Method(TokenStream), - New(TokenStream), - Call(TokenStream), + Proto(TokenStream), + TraitImpl(TokenStream), } pub fn gen_py_method( @@ -30,6 +31,14 @@ pub fn gen_py_method( ensure_function_options_valid(&options)?; let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?; + if let Some(proto) = pyproto(cls, &spec) { + return Ok(GeneratedPyMethod::Proto(proto)); + } + + if let Some(proto) = pyproto_fragment(cls, &spec)? { + return Ok(GeneratedPyMethod::TraitImpl(proto)); + } + Ok(match &spec.tp { // ordinary functions (with some specialties) FnType::Fn(_) => GeneratedPyMethod::Method(impl_py_method_def(cls, &spec, None)?), @@ -44,8 +53,8 @@ pub fn gen_py_method( Some(quote!(::pyo3::ffi::METH_STATIC)), )?), // special prototypes - FnType::FnNew => GeneratedPyMethod::New(impl_py_method_def_new(cls, &spec)?), - FnType::FnCall(_) => GeneratedPyMethod::Call(impl_py_method_def_call(cls, &spec)?), + FnType::FnNew => GeneratedPyMethod::TraitImpl(impl_py_method_def_new(cls, &spec)?), + FnType::FnCall(_) => GeneratedPyMethod::TraitImpl(impl_py_method_def_call(cls, &spec)?), FnType::ClassAttribute => GeneratedPyMethod::Method(impl_py_class_attribute(cls, &spec)), FnType::Getter(self_type) => GeneratedPyMethod::Method(impl_py_getter_def( cls, @@ -364,3 +373,350 @@ impl PropertyType<'_> { } } } + +fn pyproto(cls: &syn::Type, spec: &FnSpec) -> Option { + match spec.python_name.to_string().as_str() { + "__getattr__" => Some( + SlotDef::new("Py_tp_getattro", "getattrofunc") + .arguments(&[Ty::Object]) + .before_call_method(quote! { + // Behave like python's __getattr__ (as opposed to __getattribute__) and check + // for existing fields and methods first + let existing = ::pyo3::ffi::PyObject_GenericGetAttr(_slf, arg0); + if existing.is_null() { + // PyObject_HasAttr also tries to get an object and clears the error if it fails + ::pyo3::ffi::PyErr_Clear(); + } else { + return existing; + } + }) + .generate_type_slot(cls, spec), + ), + "__str__" => Some(SlotDef::new("Py_tp_str", "reprfunc").generate_type_slot(cls, spec)), + "__repr__" => Some(SlotDef::new("Py_tp_repr", "reprfunc").generate_type_slot(cls, spec)), + "__hash__" => Some( + SlotDef::new("Py_tp_hash", "hashfunc") + .ret_ty(Ty::PyHashT) + .return_conversion(quote! { ::pyo3::callback::HashCallbackOutput }) + .generate_type_slot(cls, spec), + ), + "__richcmp__" => Some( + SlotDef::new("Py_tp_richcompare", "richcmpfunc") + .arguments(&[Ty::Object, Ty::CompareOp]) + .generate_type_slot(cls, spec), + ), + "__bool__" => Some( + SlotDef::new("Py_nb_bool", "inquiry") + .ret_ty(Ty::Int) + .generate_type_slot(cls, spec), + ), + "__get__" => Some( + SlotDef::new("Py_tp_descr_get", "descrgetfunc") + .arguments(&[Ty::Object, Ty::Object]) + .generate_type_slot(cls, spec), + ), + _ => None, + } +} + +#[derive(Clone, Copy)] +enum Ty { + Object, + NonNullObject, + CompareOp, + Int, + PyHashT, +} + +impl Ty { + fn ffi_type(self) -> TokenStream { + match self { + Ty::Object => quote! { *mut ::pyo3::ffi::PyObject }, + Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> }, + Ty::Int => quote! { ::std::os::raw::c_int }, + Ty::CompareOp => quote! { ::std::os::raw::c_int }, + Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t }, + } + } + + fn extract( + self, + cls: &syn::Type, + py: &syn::Ident, + ident: &syn::Ident, + target: &syn::Type, + ) -> TokenStream { + match self { + Ty::Object => { + let extract = extract_from_any(cls, target, ident); + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident); + #extract + } + } + Ty::NonNullObject => { + let extract = extract_from_any(cls, target, ident); + quote! { + let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident.as_ptr()); + #extract + } + } + Ty::Int => todo!(), + Ty::PyHashT => todo!(), + Ty::CompareOp => quote! { + let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident) + .ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?; + }, + } + } +} + +fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -> TokenStream { + return if let syn::Type::Reference(tref) = unwrap_ty_group(target) { + let (tref, mut_) = preprocess_tref(tref, self_); + quote! { + let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = #ident.extract()?; + let #ident = &#mut_ *#ident; + } + } else { + quote! { + let #ident = #ident.extract()?; + } + }; + + /// Replace `Self`, remove lifetime and get mutability from the type + fn preprocess_tref( + tref: &syn::TypeReference, + self_: &syn::Type, + ) -> (syn::TypeReference, Option) { + let mut tref = tref.to_owned(); + if let syn::Type::Path(tpath) = self_ { + replace_self(&mut tref, &tpath.path); + } + tref.lifetime = None; + let mut_ = tref.mutability; + (tref, mut_) + } + + /// Replace `Self` with the exact type name since it is used out of the impl block + fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { + match &mut *tref.elem { + syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), + syn::Type::Path(tpath) => { + if let Some(ident) = tpath.path.get_ident() { + if ident == "Self" { + tpath.path = self_path.to_owned(); + } + } + } + _ => {} + } + } +} + +struct SlotDef { + slot: syn::Ident, + func_ty: syn::Ident, + arguments: &'static [Ty], + ret_ty: Ty, + before_call_method: Option, + return_conversion: Option, +} + +impl SlotDef { + fn new(slot: &str, func_ty: &str) -> Self { + SlotDef { + slot: syn::Ident::new(slot, Span::call_site()), + func_ty: syn::Ident::new(func_ty, Span::call_site()), + arguments: &[], + ret_ty: Ty::Object, + before_call_method: None, + return_conversion: None, + } + } + + fn arguments(mut self, arguments: &'static [Ty]) -> Self { + self.arguments = arguments; + self + } + + fn ret_ty(mut self, ret_ty: Ty) -> Self { + self.ret_ty = ret_ty; + self + } + + fn before_call_method(mut self, before_call_method: TokenStream) -> Self { + self.before_call_method = Some(before_call_method); + self + } + + fn return_conversion(mut self, return_conversion: TokenStream) -> Self { + self.return_conversion = Some(return_conversion); + self + } + + fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> TokenStream { + let SlotDef { + slot, + func_ty, + before_call_method, + arguments, + ret_ty, + return_conversion, + } = self; + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let arguments = arguments.into_iter().enumerate().map(|(i, arg)| { + let ident = syn::Ident::new(&format!("arg{}", i), Span::call_site()); + let ffi_type = arg.ffi_type(); + quote! { + #ident: #ffi_type + } + }); + let ret_ty = ret_ty.ffi_type(); + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &self.arguments); + let call = + quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) }; + let body = if let Some(return_conversion) = return_conversion { + quote! { + let _result: PyResult<#return_conversion> = #call; + ::pyo3::callback::convert(#py, _result) + } + } else { + call + }; + quote!({ + unsafe extern "C" fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, #(#arguments),*) -> #ret_ty { + #before_call_method + ::pyo3::callback::handle_panic(|#py| { + #self_conversion + #conversions + #body + }) + } + ::pyo3::ffi::PyType_Slot { + slot: ::pyo3::ffi::#slot, + pfunc: __wrap as ::pyo3::ffi::#func_ty as _ + } + }) + } +} + +fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result> { + Ok(match spec.python_name.to_string().as_str() { + "__setattr__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object, Ty::NonNullObject]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassSetattrSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + #[inline] + fn setattr_implemented(self) -> bool { true } + + #[inline] + unsafe fn setattr( + self, + _slf: *mut ::pyo3::ffi::PyObject, + arg0: *mut ::pyo3::ffi::PyObject, + arg1: ::std::ptr::NonNull<::pyo3::ffi::PyObject> + ) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + } + }) + } + "__delattr__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassDelattrSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + fn delattr_impl(self) -> ::std::option::Option ::pyo3::PyResult<()>> { + unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + Some(__wrap) + } + } + }) + } + "__set__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object, Ty::NonNullObject]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassSetSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + fn set_impl(self) -> ::std::option::Option) -> ::pyo3::PyResult<()>> { + unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject, arg1: ::std::ptr::NonNull<::pyo3::ffi::PyObject>) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + Some(__wrap) + } + } + }) + } + "__delete__" => { + let py = syn::Ident::new("_py", Span::call_site()); + let self_conversion = spec.tp.self_conversion(Some(cls)); + let rust_name = spec.name; + let (arg_idents, conversions) = + extract_proto_arguments(cls, &py, &spec.args, &[Ty::Object]); + Some(quote! { + impl ::pyo3::class::impl_::PyClassDeleteSlotFragment<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + fn delete_impl(self) -> ::std::option::Option ::pyo3::PyResult<()>> { + unsafe fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, arg0: *mut ::pyo3::ffi::PyObject) -> ::pyo3::PyResult<()> { + let #py = ::pyo3::Python::assume_gil_acquired(); + #self_conversion + #conversions + ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) + } + Some(__wrap) + } + } + }) + } + _ => None, + }) +} + +fn extract_proto_arguments( + cls: &syn::Type, + py: &syn::Ident, + method_args: &[FnArg], + proto_args: &[Ty], +) -> (Vec, TokenStream) { + let mut arg_idents = Vec::with_capacity(method_args.len()); + let mut non_python_args = 0; + + let args_conversion = method_args.into_iter().filter_map(|arg| { + if arg.py { + arg_idents.push(py.clone()); + None + } else { + let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site()); + let conversions = proto_args[non_python_args].extract(cls, py, &ident, arg.ty); + non_python_args += 1; + arg_idents.push(ident); + Some(conversions) + } + }); + let conversions = quote!(#(#args_conversion)*); + (arg_idents, conversions) +} diff --git a/src/class/basic.rs b/src/class/basic.rs index affa7ee6..602096c5 100644 --- a/src/class/basic.rs +++ b/src/class/basic.rs @@ -13,7 +13,7 @@ use crate::{exceptions, ffi, FromPyObject, PyAny, PyCell, PyClass, PyObject}; use std::os::raw::c_int; /// Operators for the `__richcmp__` method -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum CompareOp { /// The *less than* operator. Lt = ffi::Py_LT as isize, @@ -29,6 +29,31 @@ pub enum CompareOp { Ge = ffi::Py_GE as isize, } +impl CompareOp { + pub fn from_raw(op: c_int) -> Option { + match op { + ffi::Py_LT => Some(CompareOp::Lt), + ffi::Py_LE => Some(CompareOp::Le), + ffi::Py_EQ => Some(CompareOp::Eq), + ffi::Py_NE => Some(CompareOp::Ne), + ffi::Py_GT => Some(CompareOp::Gt), + ffi::Py_GE => Some(CompareOp::Ge), + _ => None, + } + } + + pub fn matches_ordering(self, ordering: std::cmp::Ordering) -> bool { + match self { + CompareOp::Lt => ordering.is_lt(), + CompareOp::Le => ordering.is_le(), + CompareOp::Eq => ordering.is_eq(), + CompareOp::Ne => ordering.is_ne(), + CompareOp::Gt => ordering.is_gt(), + CompareOp::Ge => ordering.is_ge(), + } + } +} + /// Basic Python class customization #[allow(unused_variables)] pub trait PyObjectProtocol<'p>: PyClass { diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 86bf8496..9015ad18 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -3,12 +3,13 @@ use crate::{ ffi, impl_::freelist::FreeList, + exceptions::PyAttributeError, pycell::PyCellLayout, pyclass_init::PyObjectInit, type_object::{PyLayout, PyTypeObject}, - PyClass, PyMethodDefType, PyNativeType, PyTypeInfo, Python, + PyClass, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python, }; -use std::{marker::PhantomData, os::raw::c_void, thread}; +use std::{marker::PhantomData, os::raw::c_void, ptr::NonNull, thread}; /// This type is used as a "dummy" type on which dtolnay specializations are /// applied to apply implementations from `#[pymethods]` & `#[pyproto]` @@ -107,6 +108,181 @@ impl PyClassCallImpl for &'_ PyClassImplCollector { } } +pub trait PyClassSetattrSlotFragment: Sized { + #[inline] + fn setattr_implemented(self) -> bool { + false + } + + unsafe fn setattr( + self, + _slf: *mut ffi::PyObject, + attr: *mut ffi::PyObject, + value: NonNull, + ) -> PyResult<()>; +} + +impl PyClassSetattrSlotFragment for &'_ PyClassImplCollector { + #[inline] + unsafe fn setattr( + self, + _slf: *mut ffi::PyObject, + _attr: *mut ffi::PyObject, + _value: NonNull, + ) -> PyResult<()> { + Err(PyAttributeError::new_err("can't set attribute")) + } +} + +pub trait PyClassDelattrSlotFragment { + fn delattr_impl( + self, + ) -> Option PyResult<()>>; +} + +impl PyClassDelattrSlotFragment for &'_ PyClassImplCollector { + fn delattr_impl( + self, + ) -> Option PyResult<()>> { + None + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_setattr_slot { + ($cls:ty) => {{ + use ::std::option::Option::*; + use $crate::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let delattr = collector.delattr_impl(); + if collector.setattr_implemented() || delattr.is_some() { + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + $crate::callback::handle_panic::<_, ::std::os::raw::c_int>(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + $crate::callback::convert(py, { + if let Some(value) = ::std::ptr::NonNull::new(value) { + collector.setattr(_slf, attr, value) + } else { + if let Some(del) = collector.delattr_impl() { + del(_slf, attr) + } else { + ::std::result::Result::Err( + $crate::exceptions::PyAttributeError::new_err( + "can't delete attribute", + ), + ) + } + } + }) + }) + } + Some($crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_tp_setattro, + pfunc: __wrap as $crate::ffi::setattrofunc as _, + }) + } else { + None + } + }}; +} + +pub trait PyClassSetSlotFragment { + fn set_impl( + self, + ) -> Option< + unsafe fn( + _slf: *mut ffi::PyObject, + attr: *mut ffi::PyObject, + value: NonNull, + ) -> PyResult<()>, + >; +} + +impl PyClassSetSlotFragment for &'_ PyClassImplCollector { + fn set_impl( + self, + ) -> Option< + unsafe fn( + _slf: *mut ffi::PyObject, + attr: *mut ffi::PyObject, + value: NonNull, + ) -> PyResult<()>, + > { + None + } +} + +pub trait PyClassDeleteSlotFragment { + fn delete_impl( + self, + ) -> Option PyResult<()>>; +} + +impl PyClassDeleteSlotFragment for &'_ PyClassImplCollector { + fn delete_impl( + self, + ) -> Option PyResult<()>> { + None + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_setdescr_slot { + ($cls:ty) => {{ + use ::std::option::Option::*; + use $crate::class::impl_::*; + let collector = PyClassImplCollector::<$cls>::new(); + let set = collector.set_impl(); + let delete = collector.delete_impl(); + if set.is_some() || delete.is_some() { + unsafe extern "C" fn __wrap( + _slf: *mut $crate::ffi::PyObject, + attr: *mut $crate::ffi::PyObject, + value: *mut $crate::ffi::PyObject, + ) -> ::std::os::raw::c_int { + $crate::callback::handle_panic::<_, ::std::os::raw::c_int>(|py| { + let collector = PyClassImplCollector::<$cls>::new(); + $crate::callback::convert(py, { + if let Some(value) = ::std::ptr::NonNull::new(value) { + if let Some(set) = collector.set_impl() { + set(_slf, attr, value) + } else { + ::std::result::Result::Err( + $crate::exceptions::PyTypeError::new_err( + "can't set descriptor", + ), + ) + } + } else { + if let Some(del) = collector.delete_impl() { + del(_slf, attr) + } else { + ::std::result::Result::Err( + $crate::exceptions::PyTypeError::new_err( + "can't delete descriptor", + ), + ) + } + } + }) + }) + } + Some($crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_tp_descr_set, + pfunc: __wrap as $crate::ffi::descrsetfunc as _, + }) + } else { + None + } + }}; +} + pub trait PyClassAllocImpl { fn alloc_impl(self) -> Option; } @@ -288,6 +464,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots); slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots); slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots); +#[cfg(not(feature = "multiple-pymethods"))] +slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots); + methods_trait!(PyObjectProtocolMethods, object_protocol_methods); methods_trait!(PyAsyncProtocolMethods, async_protocol_methods); methods_trait!(PyContextProtocolMethods, context_protocol_methods); diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs new file mode 100644 index 00000000..5de12f0e --- /dev/null +++ b/tests/test_proto_methods.rs @@ -0,0 +1,765 @@ +use pyo3::{basic::CompareOp, exceptions::PyAttributeError, prelude::*}; +use pyo3::exceptions::{PyIndexError, PyValueError}; +use pyo3::types::{PySlice, PyType}; +use pyo3::{ffi, py_run, AsPyPointer, PyCell}; +use std::convert::TryFrom; +use std::{isize, iter}; + +mod common; + +#[pyclass] +struct ExampleClass { + #[pyo3(get, set)] + value: i32, + _custom_attr: Option, +} + +#[pymethods] +impl ExampleClass { + fn __getattr__(&self, py: Python, attr: &str) -> PyResult { + if attr == "special_custom_attr" { + Ok(self._custom_attr.into_py(py)) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __setattr__(&mut self, attr: &str, value: &PyAny) -> PyResult<()> { + if attr == "special_custom_attr" { + self._custom_attr = Some(value.extract()?); + Ok(()) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __delattr__(&mut self, attr: &str) -> PyResult<()> { + if attr == "special_custom_attr" { + self._custom_attr = None; + Ok(()) + } else { + Err(PyAttributeError::new_err(attr.to_string())) + } + } + + fn __str__(&self) -> String { + self.value.to_string() + } + + fn __repr__(&self) -> String { + format!("ExampleClass(value={})", self.value) + } + + fn __hash__(&self) -> u64 { + let i64_value: i64 = self.value.into(); + i64_value as u64 + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches_ordering(self.value.cmp(&other.value)) + } + + fn __bool__(&self) -> bool { + self.value != 0 + } +} + +fn make_example(py: Python) -> &PyCell { + Py::new( + py, + ExampleClass { + value: 5, + _custom_attr: Some(20), + }, + ) + .unwrap() + .into_ref(py) +} + +#[test] +fn test_getattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py + .getattr("value") + .unwrap() + .extract::() + .unwrap(), + 5, + ); + assert_eq!( + example_py + .getattr("special_custom_attr") + .unwrap() + .extract::() + .unwrap(), + 20, + ); + assert!(example_py + .getattr("other_attr") + .unwrap_err() + .is_instance::(py)); + }) +} + +#[test] +fn test_setattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + example_py.setattr("special_custom_attr", 15).unwrap(); + assert_eq!( + example_py + .getattr("special_custom_attr") + .unwrap() + .extract::() + .unwrap(), + 15, + ); + }) +} + +#[test] +fn test_delattr() { + Python::with_gil(|py| { + let example_py = make_example(py); + example_py.delattr("special_custom_attr").unwrap(); + assert!(example_py.getattr("special_custom_attr").unwrap().is_none()); + }) +} + +#[test] +fn test_str() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!(example_py.str().unwrap().to_str().unwrap(), "5"); + }) +} + +#[test] +fn test_repr() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py.repr().unwrap().to_str().unwrap(), + "ExampleClass(value=5)" + ); + }) +} + +#[test] +fn test_hash() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!(example_py.hash().unwrap(), 5); + }) +} + +#[test] +fn test_richcmp() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert_eq!( + example_py + .rich_compare(example_py, CompareOp::Eq) + .unwrap() + .is_true() + .unwrap(), + true + ); + }) +} + +#[test] +fn test_bool() { + Python::with_gil(|py| { + let example_py = make_example(py); + assert!(example_py.is_true().unwrap()); + example_py.borrow_mut().value = 0; + assert!(!example_py.is_true().unwrap()); + }) +} + +#[pyclass] +pub struct Len { + l: usize, +} + +#[pymethods] +impl Len { + fn __len__(&self) -> usize { + self.l + } +} + +#[test] +fn len() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let inst = Py::new(py, Len { l: 10 }).unwrap(); + py_assert!(py, inst, "len(inst) == 10"); + unsafe { + assert_eq!(ffi::PyObject_Size(inst.as_ptr()), 10); + assert_eq!(ffi::PyMapping_Size(inst.as_ptr()), 10); + } + + let inst = Py::new( + py, + Len { + l: (isize::MAX as usize) + 1, + }, + ) + .unwrap(); + py_expect_exception!(py, inst, "len(inst)", PyOverflowError); +} + +#[pyclass] +struct Iterator { + iter: Box + Send>, +} + +#[pymethods] +impl Iterator { + fn __iter__(slf: PyRef) -> PyRef { + slf + } + + fn __next__(mut slf: PyRefMut) -> Option { + slf.iter.next() + } +} + +#[test] +fn iterator() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let inst = Py::new( + py, + Iterator { + iter: Box::new(5..8), + }, + ) + .unwrap(); + py_assert!(py, inst, "iter(inst) is inst"); + py_assert!(py, inst, "list(inst) == [5, 6, 7]"); +} + +#[pyclass] +struct StringMethods {} + +#[pymethods] +impl StringMethods { + fn __str__(&self) -> &'static str { + "str" + } + + fn __repr__(&self) -> &'static str { + "repr" + } + + fn __format__(&self, format_spec: String) -> String { + format!("format({})", format_spec) + } + + fn __bytes__(&self) -> &'static [u8] { + b"bytes" + } +} + +#[test] +fn string_methods() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let obj = Py::new(py, StringMethods {}).unwrap(); + py_assert!(py, obj, "str(obj) == 'str'"); + py_assert!(py, obj, "repr(obj) == 'repr'"); + py_assert!(py, obj, "'{0:x}'.format(obj) == 'format(x)'"); + py_assert!(py, obj, "bytes(obj) == b'bytes'"); + + // Test that `__bytes__` takes no arguments (should be METH_NOARGS) + py_assert!(py, obj, "obj.__bytes__() == b'bytes'"); + py_expect_exception!(py, obj, "obj.__bytes__('unexpected argument')", PyTypeError); +} + +#[pyclass] +struct Comparisons { + val: i32, +} + +#[pymethods] +impl Comparisons { + fn __hash__(&self) -> isize { + self.val as isize + } + fn __bool__(&self) -> bool { + self.val != 0 + } +} + +#[test] +fn comparisons() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let zero = Py::new(py, Comparisons { val: 0 }).unwrap(); + let one = Py::new(py, Comparisons { val: 1 }).unwrap(); + let ten = Py::new(py, Comparisons { val: 10 }).unwrap(); + let minus_one = Py::new(py, Comparisons { val: -1 }).unwrap(); + py_assert!(py, one, "hash(one) == 1"); + py_assert!(py, ten, "hash(ten) == 10"); + py_assert!(py, minus_one, "hash(minus_one) == -2"); + + py_assert!(py, one, "bool(one) is True"); + py_assert!(py, zero, "not zero"); +} + +#[pyclass] +#[derive(Debug)] +struct Sequence { + fields: Vec, +} + +impl Default for Sequence { + fn default() -> Sequence { + let mut fields = vec![]; + for &s in &["A", "B", "C", "D", "E", "F", "G"] { + fields.push(s.to_string()); + } + Sequence { fields } + } +} + +#[pymethods] +impl Sequence { + fn __len__(&self) -> usize { + self.fields.len() + } + + fn __getitem__(&self, key: isize) -> PyResult { + let idx = usize::try_from(key)?; + if let Some(s) = self.fields.get(idx) { + Ok(s.clone()) + } else { + Err(PyIndexError::new_err(())) + } + } + + fn __setitem__(&mut self, idx: isize, value: String) -> PyResult<()> { + let idx = usize::try_from(idx)?; + if let Some(elem) = self.fields.get_mut(idx) { + *elem = value; + Ok(()) + } else { + Err(PyIndexError::new_err(())) + } + } +} + +#[test] +fn sequence() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Sequence::default()).unwrap(); + py_assert!(py, c, "list(c) == ['A', 'B', 'C', 'D', 'E', 'F', 'G']"); + py_assert!(py, c, "c[-1] == 'G'"); + py_run!( + py, + c, + r#" + c[0] = 'H' + assert c[0] == 'H' +"# + ); + py_expect_exception!(py, c, "c['abc']", PyTypeError); +} + +#[pyclass] +struct Callable {} + +#[pymethods] +impl Callable { + #[__call__] + fn __call__(&self, arg: i32) -> i32 { + arg * 6 + } +} + +#[test] +fn callable() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Callable {}).unwrap(); + py_assert!(py, c, "callable(c)"); + py_assert!(py, c, "c(7) == 42"); + + let nc = Py::new(py, Comparisons { val: 0 }).unwrap(); + py_assert!(py, nc, "not callable(nc)"); +} + +#[pyclass] +#[derive(Debug)] +struct SetItem { + key: i32, + val: i32, +} + +#[pymethods] +impl SetItem { + fn __setitem__(&mut self, key: i32, val: i32) { + self.key = key; + self.val = val; + } +} + +#[test] +fn setitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, SetItem { key: 0, val: 0 }).unwrap(); + py_run!(py, c, "c[1] = 2"); + { + let c = c.borrow(); + assert_eq!(c.key, 1); + assert_eq!(c.val, 2); + } + py_expect_exception!(py, c, "del c[1]", PyNotImplementedError); +} + +#[pyclass] +struct DelItem { + key: i32, +} + +#[pymethods] +impl DelItem { + fn __delitem__(&mut self, key: i32) { + self.key = key; + } +} + +#[test] +fn delitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, DelItem { key: 0 }).unwrap(); + py_run!(py, c, "del c[1]"); + { + let c = c.borrow(); + assert_eq!(c.key, 1); + } + py_expect_exception!(py, c, "c[1] = 2", PyNotImplementedError); +} + +#[pyclass] +struct SetDelItem { + val: Option, +} + +#[pymethods] +impl SetDelItem { + fn __setitem__(&mut self, _key: i32, val: i32) { + self.val = Some(val); + } + + fn __delitem__(&mut self, _key: i32) { + self.val = None; + } +} + +#[test] +fn setdelitem() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, SetDelItem { val: None }).unwrap(); + py_run!(py, c, "c[1] = 2"); + { + let c = c.borrow(); + assert_eq!(c.val, Some(2)); + } + py_run!(py, c, "del c[1]"); + let c = c.borrow(); + assert_eq!(c.val, None); +} + +#[pyclass] +struct Reversed {} + +#[pymethods] +impl Reversed { + fn __reversed__(&self) -> &'static str { + "I am reversed" + } +} + +#[test] +fn reversed() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Reversed {}).unwrap(); + py_run!(py, c, "assert reversed(c) == 'I am reversed'"); +} + +#[pyclass] +struct Contains {} + +#[pymethods] +impl Contains { + fn __contains__(&self, item: i32) -> bool { + item >= 0 + } +} + +#[test] +fn contains() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = Py::new(py, Contains {}).unwrap(); + py_run!(py, c, "assert 1 in c"); + py_run!(py, c, "assert -1 not in c"); + py_expect_exception!(py, c, "assert 'wrong type' not in c", PyTypeError); +} + +#[pyclass] +struct ContextManager { + exit_called: bool, +} + +#[pymethods] +impl ContextManager { + fn __enter__(&mut self) -> i32 { + 42 + } + + fn __exit__( + &mut self, + ty: Option<&PyType>, + _value: Option<&PyAny>, + _traceback: Option<&PyAny>, + ) -> bool { + let gil = Python::acquire_gil(); + self.exit_called = true; + ty == Some(gil.python().get_type::()) + } +} + +#[test] +fn context_manager() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = PyCell::new(py, ContextManager { exit_called: false }).unwrap(); + py_run!(py, c, "with c as x: assert x == 42"); + { + let mut c = c.borrow_mut(); + assert!(c.exit_called); + c.exit_called = false; + } + py_run!(py, c, "with c as x: raise ValueError"); + { + let mut c = c.borrow_mut(); + assert!(c.exit_called); + c.exit_called = false; + } + py_expect_exception!( + py, + c, + "with c as x: raise NotImplementedError", + PyNotImplementedError + ); + let c = c.borrow(); + assert!(c.exit_called); +} + +#[test] +fn test_basics() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let v = PySlice::new(py, 1, 10, 2); + let indices = v.indices(100).unwrap(); + assert_eq!(1, indices.start); + assert_eq!(10, indices.stop); + assert_eq!(2, indices.step); + assert_eq!(5, indices.slicelength); +} + +#[pyclass] +struct Test {} + +#[pymethods] +impl Test { + fn __getitem__(&self, idx: &PyAny) -> PyResult<&'static str> { + if let Ok(slice) = idx.cast_as::() { + let indices = slice.indices(1000)?; + if indices.start == 100 && indices.stop == 200 && indices.step == 1 { + return Ok("slice"); + } + } else if let Ok(idx) = idx.extract::() { + if idx == 1 { + return Ok("int"); + } + } + Err(PyValueError::new_err("error")) + } +} + +#[test] +fn test_cls_impl() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let ob = Py::new(py, Test {}).unwrap(); + + py_assert!(py, ob, "ob[1] == 'int'"); + py_assert!(py, ob, "ob[100:200:1] == 'slice'"); +} + +#[pyclass] +struct ClassWithGetAttr { + #[pyo3(get, set)] + data: u32, +} + +#[pymethods] +impl ClassWithGetAttr { + fn __getattr__(&self, _name: &str) -> u32 { + self.data * 2 + } +} + +#[test] +fn getattr_doesnt_override_member() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let inst = PyCell::new(py, ClassWithGetAttr { data: 4 }).unwrap(); + py_assert!(py, inst, "inst.data == 4"); + py_assert!(py, inst, "inst.a == 8"); +} + +/// Wraps a Python future and yield it once. +#[pyclass] +struct OnceFuture { + future: PyObject, + polled: bool, +} + +#[pymethods] +impl OnceFuture { + #[new] + fn new(future: PyObject) -> Self { + OnceFuture { + future, + polled: false, + } + } + + fn __await__(slf: PyRef) -> PyRef { + slf + } + + fn __iter__(slf: PyRef) -> PyRef { + slf + } + fn __next__(mut slf: PyRefMut) -> Option { + if !slf.polled { + slf.polled = true; + Some(slf.future.clone()) + } else { + None + } + } +} + +#[test] +fn test_await() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let once = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +import asyncio +import sys + +async def main(): + res = await Once(await asyncio.sleep(0.1)) + return res +# For an odd error similar to https://bugs.python.org/issue38563 +if sys.platform == "win32" and sys.version_info >= (3, 8, 0): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +# get_event_loop can raise an error: https://github.com/PyO3/pyo3/pull/961#issuecomment-645238579 +loop = asyncio.new_event_loop() +asyncio.set_event_loop(loop) +assert loop.run_until_complete(main()) is None +loop.close() +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("Once", once).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +} + +/// Increment the count when `__get__` is called. +#[pyclass] +struct DescrCounter { + #[pyo3(get)] + count: usize, +} + +#[pymethods] +impl DescrCounter { + #[new] + fn new() -> Self { + DescrCounter { count: 0 } + } + + fn __get__<'a>( + mut slf: PyRefMut<'a, Self>, + _instance: &PyAny, + _owner: Option<&PyType>, + ) -> PyRefMut<'a, Self> { + slf.count += 1; + slf + } + fn __set__(_slf: PyRef, _instance: &PyAny, mut new_value: PyRefMut) { + new_value.count = _slf.count; + } +} + +#[test] +fn descr_getset() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let counter = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +class Class: + counter = Counter() +c = Class() +c.counter # count += 1 +assert c.counter.count == 2 +c.counter = Counter() +assert c.counter.count == 3 +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("Counter", counter).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +} + + +// TODO: test __delete__ +// TODO: better argument casting errors