From 7967874177631ff8f5adac6f0617c1d04cc79a21 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Mon, 1 Jun 2020 22:35:18 +0900 Subject: [PATCH] Remove specialization from basic/buffer/descr/iter protocols --- Cargo.toml | 3 +- pyo3-derive-backend/src/defs.rs | 91 ++++++++- pyo3-derive-backend/src/pyclass.rs | 16 ++ pyo3-derive-backend/src/pyproto.rs | 77 ++++++- src/class/basic.rs | 317 +++++++---------------------- src/class/buffer.rs | 84 ++------ src/class/descr.rs | 90 ++------ src/class/iter.rs | 77 ++----- src/class/mod.rs | 1 + src/class/number.rs | 46 +++-- src/class/proto_methods.rs | 73 +++++++ src/ffi/object.rs | 7 +- src/instance.rs | 5 +- src/lib.rs | 1 + src/pycell.rs | 2 +- src/pyclass.rs | 17 +- tests/test_dunder.rs | 4 + 17 files changed, 429 insertions(+), 482 deletions(-) create mode 100644 src/class/proto_methods.rs diff --git a/Cargo.toml b/Cargo.toml index d7365a99..51b13c1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ travis-ci = { repository = "PyO3/pyo3", branch = "master" } appveyor = { repository = "fafhrd91/pyo3" } [dependencies] +ctor = { version = "0.1", optional = true } indoc = { version = "0.3.4", optional = true } inventory = { version = "0.1.4", optional = true } libc = "0.2.62" @@ -38,7 +39,7 @@ version_check = "0.9.1" [features] default = ["macros"] -macros = ["indoc", "inventory", "paste", "pyo3cls", "unindent"] +macros = ["ctor", "indoc", "inventory", "paste", "pyo3cls", "unindent"] # this is no longer needed internally, but setuptools-rust assumes this feature python3 = [] diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index 01240d44..f5e80521 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -3,8 +3,11 @@ use crate::func::MethodProto; pub struct Proto { pub name: &'static str, + pub slot_table: &'static str, + pub set_slot_table: &'static str, pub methods: &'static [MethodProto], pub py_methods: &'static [PyMethod], + pub slot_setters: &'static [SlotSetter], } impl Proto { @@ -47,8 +50,27 @@ impl PyMethod { } } +pub struct SlotSetter { + pub proto_names: &'static [&'static str], + pub set_function: &'static str, + pub exclude_indices: &'static [usize], +} + +impl SlotSetter { + const EMPTY_INDICES: &'static [usize] = &[]; + const fn new(names: &'static [&'static str], set_function: &'static str) -> Self { + SlotSetter { + proto_names: names, + set_function, + exclude_indices: Self::EMPTY_INDICES, + } + } +} + pub const OBJECT: Proto = Proto { name: "Object", + slot_table: "pyo3::class::basic::PyObjectMethods", + set_slot_table: "set_basic_methods", methods: &[ MethodProto::Binary { name: "__getattr__", @@ -95,11 +117,6 @@ pub const OBJECT: Proto = Proto { pyres: true, proto: "pyo3::class::basic::PyObjectBytesProtocol", }, - MethodProto::Unary { - name: "__bool__", - pyres: false, - proto: "pyo3::class::basic::PyObjectBoolProtocol", - }, MethodProto::Binary { name: "__richcmp__", arg: "Other", @@ -112,10 +129,27 @@ pub const OBJECT: Proto = Proto { PyMethod::new("__bytes__", "pyo3::class::basic::BytesProtocolImpl"), PyMethod::new("__unicode__", "pyo3::class::basic::UnicodeProtocolImpl"), ], + slot_setters: &[ + SlotSetter::new(&["__str__"], "set_str"), + SlotSetter::new(&["__repr__"], "set_repr"), + SlotSetter::new(&["__hash__"], "set_hash"), + SlotSetter::new(&["__getattr__"], "set_getattr"), + SlotSetter::new(&["__richcmp__"], "set_richcompare"), + SlotSetter { + proto_names: &["__setattr__", "__delattr__"], + set_function: "set_setdelattr", + // exclude set and del + exclude_indices: &[6, 7], + }, + SlotSetter::new(&["__setattr__"], "set_setattr"), + SlotSetter::new(&["__delattr__"], "set_setattr"), + ], }; pub const ASYNC: Proto = Proto { name: "Async", + slot_table: "", + set_slot_table: "", methods: &[ MethodProto::Unary { name: "__await__", @@ -155,10 +189,13 @@ pub const ASYNC: Proto = Proto { "pyo3::class::pyasync::PyAsyncAexitProtocolImpl", ), ], + slot_setters: &[], }; pub const BUFFER: Proto = Proto { name: "Buffer", + slot_table: "pyo3::ffi::PyBufferProcs", + set_slot_table: "set_buffer_methods", methods: &[ MethodProto::Unary { name: "bf_getbuffer", @@ -172,10 +209,16 @@ pub const BUFFER: Proto = Proto { }, ], py_methods: &[], + slot_setters: &[ + SlotSetter::new(&["bf_getbuffer"], "set_getbuffer"), + SlotSetter::new(&["bf_releasebuffer"], "set_releasebuffer"), + ], }; pub const CONTEXT: Proto = Proto { name: "Context", + slot_table: "", + set_slot_table: "", methods: &[ MethodProto::Unary { name: "__enter__", @@ -200,10 +243,13 @@ pub const CONTEXT: Proto = Proto { "pyo3::class::context::PyContextExitProtocolImpl", ), ], + slot_setters: &[], }; pub const GC: Proto = Proto { name: "GC", + slot_table: "", + set_slot_table: "", methods: &[ MethodProto::Free { name: "__traverse__", @@ -215,10 +261,13 @@ pub const GC: Proto = Proto { }, ], py_methods: &[], + slot_setters: &[], }; pub const DESCR: Proto = Proto { name: "Descriptor", + slot_table: "pyo3::class::descr::PyDescrMethods", + set_slot_table: "set_descr_methods", methods: &[ MethodProto::Ternary { name: "__get__", @@ -254,10 +303,16 @@ pub const DESCR: Proto = Proto { "pyo3::class::context::PyDescrNameProtocolImpl", ), ], + slot_setters: &[ + SlotSetter::new(&["__get__"], "set_descr_get"), + SlotSetter::new(&["__set__"], "set_descr_set"), + ], }; pub const ITER: Proto = Proto { name: "Iter", + slot_table: "pyo3::class::iter::PyIterMethods", + set_slot_table: "set_iter_methods", py_methods: &[], methods: &[ MethodProto::UnaryS { @@ -273,10 +328,24 @@ pub const ITER: Proto = Proto { proto: "pyo3::class::iter::PyIterNextProtocol", }, ], + slot_setters: &[ + SlotSetter { + proto_names: &["__iter__"], + set_function: "set_iter", + exclude_indices: &[], + }, + SlotSetter { + proto_names: &["__next__"], + set_function: "set_iternext", + exclude_indices: &[], + }, + ], }; pub const MAPPING: Proto = Proto { name: "Mapping", + slot_table: "", + set_slot_table: "", methods: &[ MethodProto::Unary { name: "__len__", @@ -312,10 +381,13 @@ pub const MAPPING: Proto = Proto { "__reversed__", "pyo3::class::mapping::PyMappingReversedProtocolImpl", )], + slot_setters: &[], }; pub const SEQ: Proto = Proto { name: "Sequence", + slot_table: "", + set_slot_table: "", methods: &[ MethodProto::Unary { name: "__len__", @@ -373,10 +445,13 @@ pub const SEQ: Proto = Proto { }, ], py_methods: &[], + slot_setters: &[], }; pub const NUM: Proto = Proto { name: "Number", + slot_table: "", + set_slot_table: "", methods: &[ MethodProto::BinaryS { name: "__add__", @@ -686,6 +761,11 @@ pub const NUM: Proto = Proto { pyres: true, proto: "pyo3::class::number::PyNumberRoundProtocol", }, + MethodProto::Unary { + name: "__bool__", + pyres: false, + proto: "pyo3::class::number::PyNumberBoolProtocol", + }, ], py_methods: &[ PyMethod::coexist("__radd__", "pyo3::class::number::PyNumberRAddProtocolImpl"), @@ -729,4 +809,5 @@ pub const NUM: Proto = Proto { "pyo3::class::number::PyNumberRoundProtocolImpl", ), ], + slot_setters: &[], }; diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index 33a45ba6..166d03ef 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -236,6 +236,19 @@ fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream { } } +/// TODO(kngwyu): doc +fn impl_proto_registory(cls: &syn::Ident) -> TokenStream { + quote! { + impl pyo3::class::proto_methods::HasPyProtoRegistry for #cls { + fn registory() -> &'static pyo3::class::proto_methods::PyProtoRegistry { + static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry + = pyo3::class::proto_methods::PyProtoRegistry::new(); + ®ISTRY + } + } + } +} + fn get_class_python_name(cls: &syn::Ident, attr: &PyClassArgs) -> TokenStream { match &attr.name { Some(name) => quote! { #name }, @@ -340,6 +353,7 @@ fn impl_class( }; let impl_inventory = impl_methods_inventory(&cls); + let impl_proto_registory = impl_proto_registory(&cls); let base = &attr.base; let flags = &attr.flags; @@ -414,6 +428,8 @@ fn impl_class( #impl_inventory + #impl_proto_registory + #extra #gc_impl diff --git a/pyo3-derive-backend/src/pyproto.rs b/pyo3-derive-backend/src/pyproto.rs index 09459556..bea1d123 100644 --- a/pyo3-derive-backend/src/pyproto.rs +++ b/pyo3-derive-backend/src/pyproto.rs @@ -4,9 +4,10 @@ use crate::defs; use crate::func::impl_method_proto; use crate::method::FnSpec; use crate::pymethod; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::quote; use quote::ToTokens; +use std::collections::HashSet; pub fn build_py_proto(ast: &mut syn::ItemImpl) -> syn::Result { if let Some((_, ref mut path, _)) = ast.trait_ { @@ -60,12 +61,17 @@ fn impl_proto_impl( ) -> syn::Result { let mut trait_impls = TokenStream::new(); let mut py_methods = Vec::new(); + let mut method_names = HashSet::new(); for iimpl in impls.iter_mut() { if let syn::ImplItem::Method(ref mut met) = iimpl { + // impl Py~Protocol<'p> { type = ... } if let Some(m) = proto.get_proto(&met.sig.ident) { impl_method_proto(ty, &mut met.sig, m).to_tokens(&mut trait_impls); + // Insert the method to the HashSet + method_names.insert(met.sig.ident.to_string()); } + // Add non-slot methods to inventory slots if let Some(m) = proto.get_method(&met.sig.ident) { let name = &met.sig.ident; let fn_spec = FnSpec::parse(&met.sig, &mut met.attrs, false)?; @@ -76,7 +82,7 @@ fn impl_proto_impl( } else { quote!(0) }; - // TODO(kngwyu): doc + // TODO(kngwyu): ml_doc py_methods.push(quote! { pyo3::class::PyMethodDefType::Method({ #method @@ -91,20 +97,77 @@ fn impl_proto_impl( } } } + let inventory_submission = inventory_submission(py_methods, ty); + let slot_initialization = slot_initialization(method_names, ty, proto)?; + Ok(quote! { + #trait_impls + #inventory_submission + #slot_initialization + }) +} +fn inventory_submission(py_methods: Vec, ty: &syn::Type) -> TokenStream { if py_methods.is_empty() { - return Ok(quote! { #trait_impls }); + return quote! {}; } - let inventory_submission = quote! { + quote! { pyo3::inventory::submit! { #![crate = pyo3] { type Inventory = <#ty as pyo3::class::methods::HasMethodsInventory>::Methods; ::new(&[#(#py_methods),*]) } } - }; + } +} + +fn slot_initialization( + method_names: HashSet, + ty: &syn::Type, + proto: &defs::Proto, +) -> syn::Result { + let mut initializers: Vec = vec![]; + // This is for setters. + // If we can use set_setdelattr, skip set_setattr and set_setdelattr. + let mut skip_indices = vec![]; + 'setter_loop: for (i, m) in proto.slot_setters.iter().enumerate() { + if skip_indices.contains(&i) { + continue; + } + for name in m.proto_names { + if !method_names.contains(*name) { + // This `#[pyproto] impl` doesn't have all required methods, + // let's skip implementation. + continue 'setter_loop; + } + } + skip_indices.extend_from_slice(m.exclude_indices); + // Add slot methods to PyProtoRegistry + let set = syn::Ident::new(m.set_function, Span::call_site()); + initializers.push(quote! { table.#set::<#ty>(); }); + } + if initializers.is_empty() { + return Ok(quote! {}); + } + let table: syn::Path = syn::parse_str(proto.slot_table)?; + let set = syn::Ident::new(proto.set_slot_table, Span::call_site()); + let ty_hash = typename_hash(ty); + let init = syn::Ident::new( + &format!("__init_{}_{}", proto.name, ty_hash), + Span::call_site(), + ); Ok(quote! { - #trait_impls - #inventory_submission + #[pyo3::ctor::ctor] + fn #init() { + let mut table = #table::default(); + #(#initializers)* + <#ty as pyo3::class::proto_methods::HasPyProtoRegistry>::registory().#set(table); + } }) } + +fn typename_hash(ty: &syn::Type) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + ty.hash(&mut hasher); + hasher.finish() +} diff --git a/src/class/basic.rs b/src/class/basic.rs index 30183cb8..145ce69c 100644 --- a/src/class/basic.rs +++ b/src/class/basic.rs @@ -77,13 +77,6 @@ pub trait PyObjectProtocol<'p>: PyClass { unimplemented!() } - fn __bool__(&'p self) -> Self::Result - where - Self: PyObjectBoolProtocol<'p>, - { - unimplemented!() - } - fn __bytes__(&'p self) -> Self::Result where Self: PyObjectBytesProtocol<'p>, @@ -142,55 +135,54 @@ pub trait PyObjectRichcmpProtocol<'p>: PyObjectProtocol<'p> { type Result: Into>; } -#[doc(hidden)] -pub trait PyObjectProtocolImpl { - fn tp_as_object(_type_object: &mut ffi::PyTypeObject); - fn nb_bool_fn() -> Option; +/// All functions necessary for basic protocols. +#[derive(Default)] +pub struct PyObjectMethods { + pub tp_str: Option, + pub tp_repr: Option, + pub tp_hash: Option, + pub tp_getattro: Option, + pub tp_richcompare: Option, + pub tp_setattro: Option, } -impl PyObjectProtocolImpl for T { - default fn tp_as_object(_type_object: &mut ffi::PyTypeObject) {} - default fn nb_bool_fn() -> Option { - None +impl PyObjectMethods { + pub(crate) fn prepare_type_obj(&self, type_object: &mut ffi::PyTypeObject) { + type_object.tp_str = self.tp_str; + type_object.tp_repr = self.tp_repr; + type_object.tp_hash = self.tp_hash; + type_object.tp_getattro = self.tp_getattro; + type_object.tp_richcompare = self.tp_richcompare; + type_object.tp_setattro = self.tp_setattro; } -} -impl<'p, T> PyObjectProtocolImpl for T -where - T: PyObjectProtocol<'p>, -{ - fn tp_as_object(type_object: &mut ffi::PyTypeObject) { - type_object.tp_str = Self::tp_str(); - type_object.tp_repr = Self::tp_repr(); - type_object.tp_hash = Self::tp_hash(); - type_object.tp_getattro = Self::tp_getattro(); - type_object.tp_richcompare = Self::tp_richcompare(); - type_object.tp_setattro = tp_setattro_impl::tp_setattro::(); + pub fn set_str(&mut self) + where + T: for<'p> PyObjectStrProtocol<'p>, + { + self.tp_str = py_unary_func!(PyObjectStrProtocol, T::__str__); } - fn nb_bool_fn() -> Option { - Self::nb_bool() + pub fn set_repr(&mut self) + where + T: for<'p> PyObjectReprProtocol<'p>, + { + self.tp_repr = py_unary_func!(PyObjectReprProtocol, T::__repr__); } -} - -trait GetAttrProtocolImpl { - fn tp_getattro() -> Option; -} - -impl<'p, T> GetAttrProtocolImpl for T -where - T: PyObjectProtocol<'p>, -{ - default fn tp_getattro() -> Option { - None + pub fn set_hash(&mut self) + where + T: for<'p> PyObjectHashProtocol<'p>, + { + self.tp_hash = py_unary_func!( + PyObjectHashProtocol, + T::__hash__, + ffi::Py_hash_t, + HashCallbackOutput + ); } -} - -impl GetAttrProtocolImpl for T -where - T: for<'p> PyObjectGetAttrProtocol<'p>, -{ - fn tp_getattro() -> Option { - #[allow(unused_mut)] + pub fn set_getattr(&mut self) + where + T: for<'p> PyObjectGetAttrProtocol<'p>, + { unsafe extern "C" fn wrap( slf: *mut ffi::PyObject, arg: *mut ffi::PyObject, @@ -214,207 +206,12 @@ where call_ref!(slf, __getattr__, arg) }) } - Some(wrap::) + self.tp_getattro = Some(wrap::); } -} - -/// An object may support setting attributes (by implementing PyObjectSetAttrProtocol) -/// and may support deleting attributes (by implementing PyObjectDelAttrProtocol). -/// We need to generate a single "extern C" function that supports only setting, only deleting -/// or both, and return None in case none of the two is supported. -mod tp_setattro_impl { - use super::*; - - /// setattrofunc PyTypeObject.tp_setattro - /// - /// An optional pointer to the function for setting and deleting attributes. - /// - /// The signature is the same as for PyObject_SetAttr(), but setting v to NULL to delete an - /// attribute must be supported. It is usually convenient to set this field to - /// PyObject_GenericSetAttr(), which implements the normal way of setting object attributes. - pub(super) fn tp_setattro<'p, T: PyObjectProtocol<'p>>() -> Option { - if let Some(set_del) = T::set_del_attr() { - Some(set_del) - } else if let Some(set) = T::set_attr() { - Some(set) - } else if let Some(del) = T::del_attr() { - Some(del) - } else { - None - } - } - - trait SetAttr { - fn set_attr() -> Option; - } - - impl<'p, T: PyObjectProtocol<'p>> SetAttr for T { - default fn set_attr() -> Option { - None - } - } - - impl SetAttr for T + pub fn set_richcompare(&mut self) where - T: for<'p> PyObjectSetAttrProtocol<'p>, + T: for<'p> PyObjectRichcmpProtocol<'p>, { - fn set_attr() -> Option { - py_func_set!(PyObjectSetAttrProtocol, T, __setattr__) - } - } - - trait DelAttr { - fn del_attr() -> Option; - } - - impl<'p, T> DelAttr for T - where - T: PyObjectProtocol<'p>, - { - default fn del_attr() -> Option { - None - } - } - - impl DelAttr for T - where - T: for<'p> PyObjectDelAttrProtocol<'p>, - { - fn del_attr() -> Option { - py_func_del!(PyObjectDelAttrProtocol, T, __delattr__) - } - } - - trait SetDelAttr { - fn set_del_attr() -> Option; - } - - impl<'p, T> SetDelAttr for T - where - T: PyObjectProtocol<'p>, - { - default fn set_del_attr() -> Option { - None - } - } - - impl SetDelAttr for T - where - T: for<'p> PyObjectSetAttrProtocol<'p> + for<'p> PyObjectDelAttrProtocol<'p>, - { - fn set_del_attr() -> Option { - py_func_set_del!( - PyObjectSetAttrProtocol, - PyObjectDelAttrProtocol, - T, - __setattr__, - __delattr__ - ) - } - } -} - -trait StrProtocolImpl { - fn tp_str() -> Option; -} -impl<'p, T> StrProtocolImpl for T -where - T: PyObjectProtocol<'p>, -{ - default fn tp_str() -> Option { - None - } -} -impl StrProtocolImpl for T -where - T: for<'p> PyObjectStrProtocol<'p>, -{ - fn tp_str() -> Option { - py_unary_func!(PyObjectStrProtocol, T::__str__) - } -} - -trait ReprProtocolImpl { - fn tp_repr() -> Option; -} -impl<'p, T> ReprProtocolImpl for T -where - T: PyObjectProtocol<'p>, -{ - default fn tp_repr() -> Option { - None - } -} -impl ReprProtocolImpl for T -where - T: for<'p> PyObjectReprProtocol<'p>, -{ - fn tp_repr() -> Option { - py_unary_func!(PyObjectReprProtocol, T::__repr__) - } -} - -trait HashProtocolImpl { - fn tp_hash() -> Option; -} -impl<'p, T> HashProtocolImpl for T -where - T: PyObjectProtocol<'p>, -{ - default fn tp_hash() -> Option { - None - } -} -impl HashProtocolImpl for T -where - T: for<'p> PyObjectHashProtocol<'p>, -{ - fn tp_hash() -> Option { - py_unary_func!( - PyObjectHashProtocol, - T::__hash__, - ffi::Py_hash_t, - HashCallbackOutput - ) - } -} - -trait BoolProtocolImpl { - fn nb_bool() -> Option; -} -impl<'p, T> BoolProtocolImpl for T -where - T: PyObjectProtocol<'p>, -{ - default fn nb_bool() -> Option { - None - } -} -impl BoolProtocolImpl for T -where - T: for<'p> PyObjectBoolProtocol<'p>, -{ - fn nb_bool() -> Option { - py_unary_func!(PyObjectBoolProtocol, T::__bool__, c_int) - } -} - -trait RichcmpProtocolImpl { - fn tp_richcompare() -> Option; -} -impl<'p, T> RichcmpProtocolImpl for T -where - T: PyObjectProtocol<'p>, -{ - default fn tp_richcompare() -> Option { - None - } -} -impl RichcmpProtocolImpl for T -where - T: for<'p> PyObjectRichcmpProtocol<'p>, -{ - fn tp_richcompare() -> Option { unsafe extern "C" fn wrap( slf: *mut ffi::PyObject, arg: *mut ffi::PyObject, @@ -433,7 +230,31 @@ where slf.try_borrow()?.__richcmp__(arg, op).into() }) } - Some(wrap::) + self.tp_richcompare = Some(wrap::); + } + pub fn set_setattr(&mut self) + where + T: for<'p> PyObjectSetAttrProtocol<'p>, + { + self.tp_setattro = py_func_set!(PyObjectSetAttrProtocol, T, __setattr__); + } + pub fn set_delattr(&mut self) + where + T: for<'p> PyObjectDelAttrProtocol<'p>, + { + self.tp_setattro = py_func_del!(PyObjectDelAttrProtocol, T, __delattr__); + } + pub fn set_setdelattr(&mut self) + where + T: for<'p> PyObjectSetAttrProtocol<'p> + for<'p> PyObjectDelAttrProtocol<'p>, + { + self.tp_setattro = py_func_set_del!( + PyObjectSetAttrProtocol, + PyObjectDelAttrProtocol, + T, + __setattr__, + __delattr__ + ) } } diff --git a/src/class/buffer.rs b/src/class/buffer.rs index 7b27ef15..7fd366ee 100644 --- a/src/class/buffer.rs +++ b/src/class/buffer.rs @@ -5,7 +5,10 @@ //! For more information check [buffer protocol](https://docs.python.org/3/c-api/buffer.html) //! c-api use crate::err::PyResult; -use crate::{ffi, PyCell, PyClass, PyRefMut}; +use crate::{ + ffi::{self, PyBufferProcs}, + PyCell, PyClass, PyRefMut, +}; use std::os::raw::c_int; /// Buffer protocol interface @@ -37,51 +40,11 @@ pub trait PyBufferReleaseBufferProtocol<'p>: PyBufferProtocol<'p> { type Result: Into>; } -#[doc(hidden)] -pub trait PyBufferProtocolImpl { - fn tp_as_buffer() -> Option; -} - -impl PyBufferProtocolImpl for T { - default fn tp_as_buffer() -> Option { - None - } -} - -impl<'p, T> PyBufferProtocolImpl for T -where - T: PyBufferProtocol<'p>, -{ - #[inline] - #[allow(clippy::needless_update)] // For python 2 it's not useless - fn tp_as_buffer() -> Option { - Some(ffi::PyBufferProcs { - bf_getbuffer: Self::cb_bf_getbuffer(), - bf_releasebuffer: Self::cb_bf_releasebuffer(), - ..ffi::PyBufferProcs_INIT - }) - } -} - -trait PyBufferGetBufferProtocolImpl { - fn cb_bf_getbuffer() -> Option; -} - -impl<'p, T> PyBufferGetBufferProtocolImpl for T -where - T: PyBufferProtocol<'p>, -{ - default fn cb_bf_getbuffer() -> Option { - None - } -} - -impl PyBufferGetBufferProtocolImpl for T -where - T: for<'p> PyBufferGetBufferProtocol<'p>, -{ - #[inline] - fn cb_bf_getbuffer() -> Option { +impl PyBufferProcs { + pub fn set_getbuffer(&mut self) + where + T: for<'p> PyBufferGetBufferProtocol<'p>, + { unsafe extern "C" fn wrap( slf: *mut ffi::PyObject, arg1: *mut ffi::Py_buffer, @@ -95,29 +58,12 @@ where T::bf_getbuffer(slf.try_borrow_mut()?, arg1, arg2).into() }) } - Some(wrap::) + self.bf_getbuffer = Some(wrap::); } -} - -trait PyBufferReleaseBufferProtocolImpl { - fn cb_bf_releasebuffer() -> Option; -} - -impl<'p, T> PyBufferReleaseBufferProtocolImpl for T -where - T: PyBufferProtocol<'p>, -{ - default fn cb_bf_releasebuffer() -> Option { - None - } -} - -impl PyBufferReleaseBufferProtocolImpl for T -where - T: for<'p> PyBufferReleaseBufferProtocol<'p>, -{ - #[inline] - fn cb_bf_releasebuffer() -> Option { + pub fn set_releasebuffer(&mut self) + where + T: for<'p> PyBufferReleaseBufferProtocol<'p>, + { unsafe extern "C" fn wrap(slf: *mut ffi::PyObject, arg1: *mut ffi::Py_buffer) where T: for<'p> PyBufferReleaseBufferProtocol<'p>, @@ -127,6 +73,6 @@ where T::bf_releasebuffer(slf.try_borrow_mut()?, arg1).into() }) } - Some(wrap::) + self.bf_releasebuffer = Some(wrap::); } } diff --git a/src/class/descr.rs b/src/class/descr.rs index 1349f7a7..7ddb2712 100644 --- a/src/class/descr.rs +++ b/src/class/descr.rs @@ -5,7 +5,6 @@ //! [Python information]( //! https://docs.python.org/3/reference/datamodel.html#implementing-descriptors) -use crate::class::methods::PyMethodDef; use crate::err::PyResult; use crate::types::{PyAny, PyType}; use crate::{ffi, FromPyObject, IntoPy, PyClass, PyObject}; @@ -66,76 +65,27 @@ pub trait PyDescrSetNameProtocol<'p>: PyDescrProtocol<'p> { type Result: Into>; } -trait PyDescrGetProtocolImpl { - fn tp_descr_get() -> Option; -} -impl<'p, T> PyDescrGetProtocolImpl for T -where - T: PyDescrProtocol<'p>, -{ - default fn tp_descr_get() -> Option { - None - } +#[derive(Default)] +pub struct PyDescrMethods { + pub tp_descr_get: Option, + pub tp_descr_set: Option, } -impl PyDescrGetProtocolImpl for T -where - T: for<'p> PyDescrGetProtocol<'p>, -{ - fn tp_descr_get() -> Option { - py_ternary_func!(PyDescrGetProtocol, T::__get__) - } -} - -trait PyDescrSetProtocolImpl { - fn tp_descr_set() -> Option; -} -impl<'p, T> PyDescrSetProtocolImpl for T -where - T: PyDescrProtocol<'p>, -{ - default fn tp_descr_set() -> Option { - None - } -} -impl PyDescrSetProtocolImpl for T -where - T: for<'p> PyDescrSetProtocol<'p>, -{ - fn tp_descr_set() -> Option { - py_ternary_func!(PyDescrSetProtocol, T::__set__, c_int) - } -} - -trait PyDescrDelProtocolImpl { - fn __del__() -> Option { - None - } -} -impl<'p, T> PyDescrDelProtocolImpl for T where T: PyDescrProtocol<'p> {} - -trait PyDescrSetNameProtocolImpl { - fn __set_name__() -> Option { - None - } -} -impl<'p, T> PyDescrSetNameProtocolImpl for T where T: PyDescrProtocol<'p> {} - -#[doc(hidden)] -pub trait PyDescrProtocolImpl { - fn tp_as_descr(_type_object: &mut ffi::PyTypeObject); -} - -impl PyDescrProtocolImpl for T { - default fn tp_as_descr(_type_object: &mut ffi::PyTypeObject) {} -} - -impl<'p, T> PyDescrProtocolImpl for T -where - T: PyDescrProtocol<'p>, -{ - fn tp_as_descr(type_object: &mut ffi::PyTypeObject) { - type_object.tp_descr_get = Self::tp_descr_get(); - type_object.tp_descr_set = Self::tp_descr_set(); +impl PyDescrMethods { + pub(crate) fn prepare_type_obj(&self, type_object: &mut ffi::PyTypeObject) { + type_object.tp_descr_get = self.tp_descr_get; + type_object.tp_descr_set = self.tp_descr_set; + } + pub fn set_descr_get(&mut self) + where + T: for<'p> PyDescrGetProtocol<'p>, + { + self.tp_descr_get = py_ternary_func!(PyDescrGetProtocol, T::__get__); + } + pub fn set_descr_set(&mut self) + where + T: for<'p> PyDescrSetProtocol<'p>, + { + self.tp_descr_set = py_ternary_func!(PyDescrSetProtocol, T::__set__, c_int); } } diff --git a/src/class/iter.rs b/src/class/iter.rs index 08c12963..6329e528 100644 --- a/src/class/iter.rs +++ b/src/class/iter.rs @@ -40,69 +40,28 @@ pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> { type Result: Into>>; } -#[doc(hidden)] -pub trait PyIterProtocolImpl { - fn tp_as_iter(_typeob: &mut ffi::PyTypeObject); +#[derive(Default)] +pub struct PyIterMethods { + pub tp_iter: Option, + pub tp_iternext: Option, } -impl PyIterProtocolImpl for T { - default fn tp_as_iter(_typeob: &mut ffi::PyTypeObject) {} -} - -impl<'p, T> PyIterProtocolImpl for T -where - T: PyIterProtocol<'p>, -{ - #[inline] - fn tp_as_iter(typeob: &mut ffi::PyTypeObject) { - typeob.tp_iter = Self::tp_iter(); - typeob.tp_iternext = Self::tp_iternext(); +impl PyIterMethods { + pub(crate) fn prepare_type_obj(&self, type_object: &mut ffi::PyTypeObject) { + type_object.tp_iter = self.tp_iter; + type_object.tp_iternext = self.tp_iternext; } -} - -trait PyIterIterProtocolImpl { - fn tp_iter() -> Option; -} - -impl<'p, T> PyIterIterProtocolImpl for T -where - T: PyIterProtocol<'p>, -{ - default fn tp_iter() -> Option { - None + pub fn set_iter(&mut self) + where + T: for<'p> PyIterIterProtocol<'p>, + { + self.tp_iter = py_unarys_func!(PyIterIterProtocol, T::__iter__); } -} - -impl PyIterIterProtocolImpl for T -where - T: for<'p> PyIterIterProtocol<'p>, -{ - #[inline] - fn tp_iter() -> Option { - py_unarys_func!(PyIterIterProtocol, T::__iter__) - } -} - -trait PyIterNextProtocolImpl { - fn tp_iternext() -> Option; -} - -impl<'p, T> PyIterNextProtocolImpl for T -where - T: PyIterProtocol<'p>, -{ - default fn tp_iternext() -> Option { - None - } -} - -impl PyIterNextProtocolImpl for T -where - T: for<'p> PyIterNextProtocol<'p>, -{ - #[inline] - fn tp_iternext() -> Option { - py_unarys_func!(PyIterNextProtocol, T::__next__, IterNextConverter) + pub fn set_iternext(&mut self) + where + T: for<'p> PyIterNextProtocol<'p>, + { + self.tp_iternext = py_unarys_func!(PyIterNextProtocol, T::__next__, IterNextConverter); } } diff --git a/src/class/mod.rs b/src/class/mod.rs index df828ecb..fd37a51f 100644 --- a/src/class/mod.rs +++ b/src/class/mod.rs @@ -14,6 +14,7 @@ pub mod iter; pub mod mapping; pub mod methods; pub mod number; +pub mod proto_methods; pub mod pyasync; pub mod sequence; diff --git a/src/class/number.rs b/src/class/number.rs index ab5270b4..9e73b52d 100644 --- a/src/class/number.rs +++ b/src/class/number.rs @@ -3,9 +3,9 @@ //! Python Number Interface //! Trait and support implementation for implementing number protocol -use crate::class::basic::PyObjectProtocolImpl; use crate::err::PyResult; use crate::{ffi, FromPyObject, IntoPy, PyClass, PyObject}; +use std::os::raw::c_int; /// Number interface #[allow(unused_variables)] @@ -314,6 +314,12 @@ pub trait PyNumberProtocol<'p>: PyClass { { unimplemented!() } + fn __bool__(&'p self) -> Self::Result + where + Self: PyNumberBoolProtocol<'p>, + { + unimplemented!() + } } pub trait PyNumberAddProtocol<'p>: PyNumberProtocol<'p> { @@ -616,22 +622,18 @@ pub trait PyNumberIndexProtocol<'p>: PyNumberProtocol<'p> { type Result: Into>; } +pub trait PyNumberBoolProtocol<'p>: PyNumberProtocol<'p> { + type Result: Into>; +} + #[doc(hidden)] -pub trait PyNumberProtocolImpl: PyObjectProtocolImpl { +pub trait PyNumberProtocolImpl { fn tp_as_number() -> Option; } impl<'p, T> PyNumberProtocolImpl for T { default fn tp_as_number() -> Option { - if let Some(nb_bool) = ::nb_bool_fn() { - let meth = ffi::PyNumberMethods { - nb_bool: Some(nb_bool), - ..ffi::PyNumberMethods_INIT - }; - Some(meth) - } else { - None - } + None } } @@ -650,7 +652,7 @@ where nb_negative: Self::nb_negative(), nb_positive: Self::nb_positive(), nb_absolute: Self::nb_absolute(), - nb_bool: ::nb_bool_fn(), + nb_bool: Self::nb_bool(), nb_invert: Self::nb_invert(), nb_lshift: Self::nb_lshift().or_else(Self::nb_lshift_fallback), nb_rshift: Self::nb_rshift().or_else(Self::nb_rshift_fallback), @@ -1738,3 +1740,23 @@ where py_unary_func!(PyNumberIndexProtocol, T::__index__) } } + +trait PyNumberBoolProtocolImpl { + fn nb_bool() -> Option; +} +impl<'p, T> PyNumberBoolProtocolImpl for T +where + T: PyNumberProtocol<'p>, +{ + default fn nb_bool() -> Option { + None + } +} +impl PyNumberBoolProtocolImpl for T +where + T: for<'p> PyNumberBoolProtocol<'p>, +{ + fn nb_bool() -> Option { + py_unary_func!(PyNumberBoolProtocol, T::__bool__, c_int) + } +} diff --git a/src/class/proto_methods.rs b/src/class/proto_methods.rs new file mode 100644 index 00000000..f847dd06 --- /dev/null +++ b/src/class/proto_methods.rs @@ -0,0 +1,73 @@ +use crate::class::{basic::PyObjectMethods, descr::PyDescrMethods, iter::PyIterMethods}; +use crate::ffi::PyBufferProcs; +use std::{ + ptr::{self, NonNull}, + sync::atomic::{AtomicPtr, Ordering}, +}; + +/// For rust-numpy, we need a stub implementation. +pub trait PyProtoMethods { + fn basic_methods() -> Option>; + fn buffer_methods() -> Option>; + fn descr_methods() -> Option>; + fn iter_methods() -> Option>; +} + +#[doc(hidden)] +pub trait HasPyProtoRegistry: Sized + 'static { + fn registory() -> &'static PyProtoRegistry; +} + +impl PyProtoMethods for T { + fn basic_methods() -> Option> { + NonNull::new(Self::registory().basic_methods.load(Ordering::SeqCst)) + } + fn buffer_methods() -> Option> { + NonNull::new(Self::registory().buffer_methods.load(Ordering::SeqCst)) + } + fn descr_methods() -> Option> { + NonNull::new(Self::registory().descr_methods.load(Ordering::SeqCst)) + } + fn iter_methods() -> Option> { + NonNull::new(Self::registory().iter_methods.load(Ordering::SeqCst)) + } +} + +#[doc(hidden)] +pub struct PyProtoRegistry { + // Basic Protocols + basic_methods: AtomicPtr, + // Buffer Protocols + buffer_methods: AtomicPtr, + // Descr Protocols + descr_methods: AtomicPtr, + // Iterator Protocols + iter_methods: AtomicPtr, +} + +impl PyProtoRegistry { + pub const fn new() -> Self { + PyProtoRegistry { + basic_methods: AtomicPtr::new(ptr::null_mut()), + buffer_methods: AtomicPtr::new(ptr::null_mut()), + descr_methods: AtomicPtr::new(ptr::null_mut()), + iter_methods: AtomicPtr::new(ptr::null_mut()), + } + } + pub fn set_basic_methods(&self, methods: PyObjectMethods) { + self.basic_methods + .store(Box::into_raw(Box::new(methods)), Ordering::SeqCst) + } + pub fn set_buffer_methods(&self, methods: PyBufferProcs) { + self.buffer_methods + .store(Box::into_raw(Box::new(methods)), Ordering::SeqCst) + } + pub fn set_descr_methods(&self, methods: PyDescrMethods) { + self.descr_methods + .store(Box::into_raw(Box::new(methods)), Ordering::SeqCst) + } + pub fn set_iter_methods(&self, methods: PyIterMethods) { + self.iter_methods + .store(Box::into_raw(Box::new(methods)), Ordering::SeqCst) + } +} diff --git a/src/ffi/object.rs b/src/ffi/object.rs index 175105b8..ef43db5b 100644 --- a/src/ffi/object.rs +++ b/src/ffi/object.rs @@ -438,14 +438,16 @@ mod typeobject { impl Default for PyAsyncMethods { #[inline] fn default() -> Self { - unsafe { mem::zeroed() } + PyAsyncMethods_INIT } } + pub const PyAsyncMethods_INIT: PyAsyncMethods = PyAsyncMethods { am_await: None, am_aiter: None, am_anext: None, }; + #[repr(C)] #[derive(Copy, Clone, Debug)] pub struct PyBufferProcs { @@ -456,9 +458,10 @@ mod typeobject { impl Default for PyBufferProcs { #[inline] fn default() -> Self { - unsafe { mem::zeroed() } + PyBufferProcs_INIT } } + pub const PyBufferProcs_INIT: PyBufferProcs = PyBufferProcs { bf_getbuffer: None, bf_releasebuffer: None, diff --git a/src/instance.rs b/src/instance.rs index 8cc6ddfa..749d1042 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -45,9 +45,8 @@ pub unsafe trait PyNativeType: Sized { #[repr(transparent)] pub struct Py(NonNull, PhantomData); -unsafe impl Send for Py {} - -unsafe impl Sync for Py {} +unsafe impl Send for Py {} +unsafe impl Sync for Py {} impl Py { /// Create a new instance `Py`. diff --git a/src/lib.rs b/src/lib.rs index 329ba0ea..f9580377 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -153,6 +153,7 @@ pub use crate::types::PyAny; #[cfg(feature = "macros")] #[doc(hidden)] pub use { + ctor, // Re-exported for pyproto indoc, // Re-exported for py_run inventory, // Re-exported for pymethods paste, // Re-exported for wrap_function diff --git a/src/pycell.rs b/src/pycell.rs index e6249824..094b0bf1 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -372,7 +372,7 @@ impl AsPyPointer for PyCell { } } -impl ToPyObject for &PyCell { +impl ToPyObject for &PyCell { fn to_object(&self, py: Python<'_>) -> PyObject { unsafe { PyObject::from_borrowed_ptr(py, self.as_ptr()) } } diff --git a/src/pyclass.rs b/src/pyclass.rs index 9d7bc8de..a1e8dab4 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -1,5 +1,6 @@ //! `PyClass` trait use crate::class::methods::{PyClassAttributeDef, PyMethodDefType, PyMethods}; +use crate::class::proto_methods::PyProtoMethods; use crate::conversion::{IntoPyPointer, ToPyObject}; use crate::pyclass_slots::{PyClassDict, PyClassWeakRef}; use crate::type_object::{type_flags, PyLayout}; @@ -73,7 +74,7 @@ pub unsafe fn tp_free_fallback(obj: *mut ffi::PyObject) { /// The `#[pyclass]` attribute automatically implements this trait for your Rust struct, /// so you don't have to use this trait directly. pub trait PyClass: - PyTypeInfo> + Sized + PyClassAlloc + PyMethods + Send + PyTypeInfo> + Sized + PyClassAlloc + PyMethods + PyProtoMethods + Send { /// Specify this class has `#[pyclass(dict)]` or not. type Dict: PyClassDict; @@ -140,13 +141,19 @@ where ::update_type_object(type_object); // descriptor protocol - ::tp_as_descr(type_object); + if let Some(descr) = T::descr_methods() { + unsafe { descr.as_ref() }.prepare_type_obj(type_object); + } // iterator methods - ::tp_as_iter(type_object); + if let Some(iter) = T::iter_methods() { + unsafe { iter.as_ref() }.prepare_type_obj(type_object); + } // basic methods - ::tp_as_object(type_object); + if let Some(basic) = T::basic_methods() { + unsafe { basic.as_ref() }.prepare_type_obj(type_object); + } fn to_ptr(value: Option) -> *mut T { value @@ -165,7 +172,7 @@ where // async methods type_object.tp_as_async = to_ptr(::tp_as_async()); // buffer protocol - type_object.tp_as_buffer = to_ptr(::tp_as_buffer()); + type_object.tp_as_buffer = T::buffer_methods().map_or_else(ptr::null_mut, |p| p.as_ptr()); let (new, call, mut methods, attrs) = py_class_method_defs::(); diff --git a/tests/test_dunder.rs b/tests/test_dunder.rs index 421873f0..8ecc0101 100644 --- a/tests/test_dunder.rs +++ b/tests/test_dunder.rs @@ -121,6 +121,10 @@ impl PyObjectProtocol for Comparisons { fn __hash__(&self) -> PyResult { Ok(self.val as isize) } +} + +#[pyproto] +impl pyo3::class::PyNumberProtocol for Comparisons { fn __bool__(&self) -> PyResult { Ok(self.val != 0) }