From b7419b5278e18ac9b99680ecb12fc109ddd56320 Mon Sep 17 00:00:00 2001 From: b05902132 Date: Wed, 17 Nov 2021 03:31:30 +0800 Subject: [PATCH] Refactor #[pyclass] and now it supports enum. There's no functionality since it does not generate __richcmp__. Also it only works on enums with only variants, and does not support C-like enums. --- pyo3-macros-backend/src/lib.rs | 2 +- pyo3-macros-backend/src/pyclass.rs | 724 +++++++++++++++++---------- pyo3-macros-backend/src/pyimpl.rs | 3 +- pyo3-macros/src/lib.rs | 50 +- tests/test_compile_error.rs | 2 + tests/test_enum.rs | 53 ++ tests/ui/invalid_pyclass_enum.rs | 15 + tests/ui/invalid_pyclass_enum.stderr | 11 + tests/ui/invalid_pyclass_item.rs | 6 + tests/ui/invalid_pyclass_item.stderr | 5 + 10 files changed, 596 insertions(+), 275 deletions(-) create mode 100644 tests/test_enum.rs create mode 100644 tests/ui/invalid_pyclass_enum.rs create mode 100644 tests/ui/invalid_pyclass_enum.stderr create mode 100644 tests/ui/invalid_pyclass_item.rs create mode 100644 tests/ui/invalid_pyclass_item.stderr diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index f42bd3b9..69fc24d2 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -25,7 +25,7 @@ mod pyproto; pub use from_pyobject::build_derive_from_pyobject; pub use module::{process_functions_in_module, py_init, PyModuleOptions}; -pub use pyclass::{build_py_class, PyClassArgs}; +pub use pyclass::{build_py_class, build_py_enum, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionOptions}; pub use pyimpl::{build_py_methods, PyClassMethodsType}; pub use pyproto::build_py_proto; diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index bafe9910..3e59a960 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -2,7 +2,8 @@ use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute}; use crate::deprecations::Deprecations; -use crate::pyimpl::PyClassMethodsType; +use crate::konst::{ConstAttributes, ConstSpec}; +use crate::pyimpl::{gen_py_const, PyClassMethodsType}; use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType}; use crate::utils::{self, unwrap_group, PythonDoc}; use proc_macro2::{Span, TokenStream}; @@ -10,7 +11,14 @@ use quote::quote; use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::{parse_quote, spanned::Spanned, Expr, Result, Token}; +use syn::{parse_quote, spanned::Spanned, Expr, Result, Token}; //unraw + +/// If the class is derived from a Rust `struct` or `enum`. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PyClassKind { + Struct, + Enum, +} /// The parsed arguments of the pyclass macro pub struct PyClassArgs { @@ -24,22 +32,28 @@ pub struct PyClassArgs { pub has_extends: bool, pub has_unsendable: bool, pub module: Option, + pub class_kind: PyClassKind, } -impl Parse for PyClassArgs { - fn parse(input: ParseStream) -> Result { - let mut slf = PyClassArgs::default(); - +impl PyClassArgs { + fn parse(input: ParseStream, kind: PyClassKind) -> Result { + let mut slf = PyClassArgs::new(kind); let vars = Punctuated::::parse_terminated(input)?; for expr in vars { slf.add_expr(&expr)?; } Ok(slf) } -} -impl Default for PyClassArgs { - fn default() -> Self { + pub fn parse_stuct_args(input: ParseStream) -> syn::Result { + Self::parse(input, PyClassKind::Struct) + } + + pub fn parse_enum_args(input: ParseStream) -> syn::Result { + Self::parse(input, PyClassKind::Enum) + } + + fn new(class_kind: PyClassKind) -> Self { PyClassArgs { freelist: None, name: None, @@ -51,11 +65,10 @@ impl Default for PyClassArgs { is_basetype: false, has_extends: false, has_unsendable: false, + class_kind, } } -} -impl PyClassArgs { /// Adda single expression from the comma separated list in the attribute, which is /// either a single word or an assignment expression fn add_expr(&mut self, expr: &Expr) -> Result<()> { @@ -113,6 +126,9 @@ impl PyClassArgs { }, "extends" => match unwrap_group(&**right) { syn::Expr::Path(exp) => { + if self.class_kind == PyClassKind::Enum { + bail_spanned!( assign.span() => "enums cannot extend from other classes" ); + } self.base = syn::TypePath { path: exp.path.clone(), qself: None, @@ -147,6 +163,9 @@ impl PyClassArgs { self.has_weaklist = true; } "subclass" => { + if self.class_kind == PyClassKind::Enum { + bail_spanned!(exp.span() => "enums can't be inherited by other classes"); + } self.is_basetype = true; } "dict" => { @@ -328,41 +347,6 @@ impl FieldPyO3Options { } } -/// To allow multiple #[pymethods] block, we define inventory types. -fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream { - // Try to build a unique type for better error messages - let name = format!("Pyo3MethodsInventoryFor{}", cls.unraw()); - let inventory_cls = syn::Ident::new(&name, Span::call_site()); - - quote! { - #[doc(hidden)] - pub struct #inventory_cls { - methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>, - slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>, - } - impl ::pyo3::class::impl_::PyMethodsInventory for #inventory_cls { - fn new( - methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>, - slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>, - ) -> Self { - Self { methods, slots } - } - fn methods(&'static self) -> &'static [::pyo3::class::PyMethodDefType] { - &self.methods - } - fn slots(&'static self) -> &'static [::pyo3::ffi::PyType_Slot] { - &self.slots - } - } - - impl ::pyo3::class::impl_::HasMethodsInventory for #cls { - type Methods = #inventory_cls; - } - - ::pyo3::inventory::collect!(#inventory_cls); - } -} - fn get_class_python_name<'a>(cls: &'a syn::Ident, attr: &'a PyClassArgs) -> &'a syn::Ident { attr.name.as_ref().unwrap_or(cls) } @@ -375,243 +359,121 @@ fn impl_class( methods_type: PyClassMethodsType, deprecations: Deprecations, ) -> syn::Result { - let cls_name = get_class_python_name(cls, attr).to_string(); + let pytypeinfo_impl = impl_pytypeinfo(cls, attr, Some(&deprecations)); - let alloc = attr.freelist.as_ref().map(|freelist| { - quote! { - impl ::pyo3::class::impl_::PyClassWithFreeList for #cls { - #[inline] - fn get_free_list(_py: ::pyo3::Python<'_>) -> &mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> { - static mut FREELIST: *mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> = 0 as *mut _; - unsafe { - if FREELIST.is_null() { - FREELIST = ::std::boxed::Box::into_raw(::std::boxed::Box::new( - ::pyo3::impl_::freelist::FreeList::with_capacity(#freelist))); - } - &mut *FREELIST - } - } - } - - impl ::pyo3::class::impl_::PyClassAllocImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - #[inline] - fn alloc_impl(self) -> ::std::option::Option<::pyo3::ffi::allocfunc> { - ::std::option::Option::Some(::pyo3::class::impl_::alloc_with_freelist::<#cls>) - } - } - - impl ::pyo3::class::impl_::PyClassFreeImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { - #[inline] - fn free_impl(self) -> ::std::option::Option<::pyo3::ffi::freefunc> { - ::std::option::Option::Some(::pyo3::class::impl_::free_with_freelist::<#cls>) - } - } - } - }); + let py_class_impl = PyClassImplsBuilder::new(cls, attr, methods_type) + .doc(doc) + .impl_all(); let descriptors = impl_descriptors(cls, field_options)?; - // insert space for weak ref - let weakref = if attr.has_weaklist { - quote! { ::pyo3::pyclass_slots::PyClassWeakRefSlot } - } else if attr.has_extends { - quote! { ::WeakRef } - } else { - quote! { ::pyo3::pyclass_slots::PyClassDummySlot } - }; - let dict = if attr.has_dict { - quote! { ::pyo3::pyclass_slots::PyClassDictSlot } - } else if attr.has_extends { - quote! { ::Dict } - } else { - quote! { ::pyo3::pyclass_slots::PyClassDummySlot } - }; - let module = if let Some(m) = &attr.module { - quote! { ::std::option::Option::Some(#m) } - } else { - quote! { ::std::option::Option::None } - }; + Ok(quote! { + #pytypeinfo_impl - // Enforce at compile time that PyGCProtocol is implemented - let gc_impl = if attr.is_gc { - let closure_name = format!("__assertion_closure_{}", cls); - let closure_token = syn::Ident::new(&closure_name, Span::call_site()); - quote! { - fn #closure_token() { - use ::pyo3::class; + #py_class_impl - fn _assert_implements_protocol<'p, T: ::pyo3::class::PyGCProtocol<'p>>() {} - _assert_implements_protocol::<#cls>(); - } - } - } else { - quote! {} - }; + #descriptors + }) +} - let (impl_inventory, for_each_py_method) = match methods_type { - PyClassMethodsType::Specialization => (None, quote! { visitor(collector.py_methods()); }), - PyClassMethodsType::Inventory => ( - Some(impl_methods_inventory(cls)), - quote! { - for inventory in ::pyo3::inventory::iter::<::Methods>() { - visitor(::pyo3::class::impl_::PyMethodsInventory::methods(inventory)); - } - }, - ), - }; +struct PyClassEnumVariant<'a> { + ident: &'a syn::Ident, + /* currently have no more options */ +} - let methods_protos = match methods_type { - PyClassMethodsType::Specialization => { - quote! { visitor(collector.methods_protocol_slots()); } - } - PyClassMethodsType::Inventory => { - quote! { - for inventory in ::pyo3::inventory::iter::<::Methods>() { - visitor(::pyo3::class::impl_::PyMethodsInventory::slots(inventory)); - } - } - } - }; +pub fn build_py_enum( + enum_: &syn::ItemEnum, + args: PyClassArgs, + method_type: PyClassMethodsType, +) -> syn::Result { + let variants: Vec = enum_ + .variants + .iter() + .map(|v| extract_variant_data(v)) + .collect::>()?; + impl_enum(enum_, args, variants, method_type) +} - let base = &attr.base; - let base_nativetype = if attr.has_extends { - quote! { ::BaseNativeType } - } else { - quote! { ::pyo3::PyAny } - }; - - // If #cls is not extended type, we allow Self->PyObject conversion - let into_pyobject = if !attr.has_extends { - quote! { - impl ::pyo3::IntoPy<::pyo3::PyObject> for #cls { - fn into_py(self, py: ::pyo3::Python) -> ::pyo3::PyObject { - ::pyo3::IntoPy::into_py(::pyo3::Py::new(py, self).unwrap(), py) - } - } - } - } else { - quote! {} - }; - - let thread_checker = if attr.has_unsendable { - quote! { ::pyo3::class::impl_::ThreadCheckerImpl<#cls> } - } else if attr.has_extends { - quote! { - ::pyo3::class::impl_::ThreadCheckerInherited<#cls, <#cls as ::pyo3::class::impl_::PyClassImpl>::BaseType> - } - } else { - quote! { ::pyo3::class::impl_::ThreadCheckerStub<#cls> } - }; - - let is_gc = attr.is_gc; - let is_basetype = attr.is_basetype; - let is_subclass = attr.has_extends; +fn impl_enum( + enum_: &syn::ItemEnum, + attrs: PyClassArgs, + variants: Vec, + methods_type: PyClassMethodsType, +) -> syn::Result { + let enum_name = &enum_.ident; + let doc = utils::get_doc(&enum_.attrs, None); + let enum_cls = impl_enum_class(enum_name, &attrs, variants, doc, methods_type)?; Ok(quote! { - unsafe impl ::pyo3::type_object::PyTypeInfo for #cls { - type AsRefTarget = ::pyo3::PyCell; + #enum_cls + }) +} - const NAME: &'static str = #cls_name; - const MODULE: ::std::option::Option<&'static str> = #module; +fn impl_enum_class( + cls: &syn::Ident, + attr: &PyClassArgs, + variants: Vec, + doc: PythonDoc, + methods_type: PyClassMethodsType, +) -> syn::Result { + let pytypeinfo = impl_pytypeinfo(cls, attr, None); + let pyclass_impls = PyClassImplsBuilder::new(cls, attr, methods_type) + .doc(doc) + .impl_all(); + let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident)); - #[inline] - fn type_object_raw(py: ::pyo3::Python<'_>) -> *mut ::pyo3::ffi::PyTypeObject { - #deprecations + Ok(quote! { - use ::pyo3::type_object::LazyStaticType; - static TYPE_OBJECT: LazyStaticType = LazyStaticType::new(); - TYPE_OBJECT.get_or_init::(py) - } - } + #pytypeinfo - impl ::pyo3::PyClass for #cls { - type Dict = #dict; - type WeakRef = #weakref; - type BaseNativeType = #base_nativetype; - } - - impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a #cls - { - type Target = ::pyo3::PyRef<'a, #cls>; - } - - impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls - { - type Target = ::pyo3::PyRefMut<'a, #cls>; - } - - #into_pyobject - - #impl_inventory - - impl ::pyo3::class::impl_::PyClassImpl for #cls { - const DOC: &'static str = #doc; - const IS_GC: bool = #is_gc; - const IS_BASETYPE: bool = #is_basetype; - const IS_SUBCLASS: bool = #is_subclass; - - type Layout = ::pyo3::PyCell; - type BaseType = #base; - type ThreadChecker = #thread_checker; - - fn for_each_method_def(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::class::PyMethodDefType])) { - use ::pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - #for_each_py_method; - visitor(collector.py_class_descriptors()); - visitor(collector.object_protocol_methods()); - visitor(collector.async_protocol_methods()); - visitor(collector.descr_protocol_methods()); - visitor(collector.mapping_protocol_methods()); - visitor(collector.number_protocol_methods()); - } - fn get_new() -> ::std::option::Option<::pyo3::ffi::newfunc> { - use ::pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - collector.new_impl() - } - fn get_alloc() -> ::std::option::Option<::pyo3::ffi::allocfunc> { - use ::pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - collector.alloc_impl() - } - fn get_free() -> ::std::option::Option<::pyo3::ffi::freefunc> { - use ::pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - collector.free_impl() - } - - fn for_each_proto_slot(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::ffi::PyType_Slot])) { - // Implementation which uses dtolnay specialization to load all slots. - use ::pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - visitor(collector.object_protocol_slots()); - visitor(collector.number_protocol_slots()); - visitor(collector.iter_protocol_slots()); - visitor(collector.gc_protocol_slots()); - visitor(collector.descr_protocol_slots()); - visitor(collector.mapping_protocol_slots()); - visitor(collector.sequence_protocol_slots()); - visitor(collector.async_protocol_slots()); - visitor(collector.buffer_protocol_slots()); - #methods_protos - } - - fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> { - use ::pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - collector.buffer_procs() - } - } - - #alloc + #pyclass_impls #descriptors - #gc_impl }) } +fn unit_variants_as_descriptors<'a>( + cls: &'a syn::Ident, + variant_names: impl IntoIterator, +) -> TokenStream { + let cls_type = syn::parse_quote!(#cls); + let variant_to_attribute = |ident: &syn::Ident| ConstSpec { + rust_ident: ident.clone(), + attributes: ConstAttributes { + is_class_attr: true, + name: Some(NameAttribute(ident.clone())), + deprecations: Default::default(), + }, + }; + let py_methods = variant_names + .into_iter() + .map(|var| gen_py_const(&cls_type, &variant_to_attribute(var))); + + quote! { + impl ::pyo3::class::impl_::PyClassDescriptors<#cls> + for ::pyo3::class::impl_::PyClassImplCollector<#cls> + { + fn py_class_descriptors(self) -> &'static [::pyo3::class::methods::PyMethodDefType] { + static METHODS: &[::pyo3::class::methods::PyMethodDefType] = &[#(#py_methods),*]; + METHODS + } + } + } +} + +fn extract_variant_data(variant: &syn::Variant) -> syn::Result { + use syn::Fields; + let ident = match variant.fields { + Fields::Unit => &variant.ident, + _ => bail_spanned!(variant.span() => "Currently only support unit variants."), + }; + if let Some(discriminant) = variant.discriminant.as_ref() { + bail_spanned!(discriminant.0.span() => "Currently does not support discriminats.") + }; + Ok(PyClassEnumVariant { ident }) +} + fn impl_descriptors( cls: &syn::Ident, field_options: Vec<(&syn::Field, FieldPyO3Options)>, @@ -662,3 +524,341 @@ fn impl_descriptors( } }) } + +fn impl_pytypeinfo( + cls: &syn::Ident, + attr: &PyClassArgs, + deprecations: Option<&Deprecations>, +) -> TokenStream { + let cls_name = get_class_python_name(cls, attr).to_string(); + + let module = if let Some(m) = &attr.module { + quote! { ::core::option::Option::Some(#m) } + } else { + quote! { ::core::option::Option::None } + }; + + quote! { + unsafe impl ::pyo3::type_object::PyTypeInfo for #cls { + type AsRefTarget = ::pyo3::PyCell; + + const NAME: &'static str = #cls_name; + const MODULE: ::std::option::Option<&'static str> = #module; + + #[inline] + fn type_object_raw(py: ::pyo3::Python<'_>) -> *mut ::pyo3::ffi::PyTypeObject { + #deprecations + + use ::pyo3::type_object::LazyStaticType; + static TYPE_OBJECT: LazyStaticType = LazyStaticType::new(); + TYPE_OBJECT.get_or_init::(py) + } + } + } +} + +/// Implements most traits used by `#[pyclass]`. +/// +/// Specifically, it implements traits that only depend on class name, +/// and attributes of `#[pyclass]`, and docstrings. +/// Therefore it doesn't implement traits that depends on struct fields and enum variants. +struct PyClassImplsBuilder<'a> { + cls: &'a syn::Ident, + attr: &'a PyClassArgs, + methods_type: PyClassMethodsType, + doc: Option, +} + +impl<'a> PyClassImplsBuilder<'a> { + fn new(cls: &'a syn::Ident, attr: &'a PyClassArgs, methods_type: PyClassMethodsType) -> Self { + Self { + cls, + attr, + methods_type, + doc: None, + } + } + + fn doc(self, doc: PythonDoc) -> Self { + Self { + doc: Some(doc), + ..self + } + } + + fn impl_all(&self) -> TokenStream { + vec![ + self.impl_pyclass(), + self.impl_extractext(), + self.impl_into_py(), + self.impl_methods_inventory(), + self.impl_pyclassimpl(), + self.impl_freelist(), + self.impl_gc(), + ] + .into_iter() + .collect() + } + + fn impl_pyclass(&self) -> TokenStream { + let cls = self.cls; + let attr = self.attr; + let dict = if attr.has_dict { + quote! { ::pyo3::pyclass_slots::PyClassDictSlot } + } else if attr.has_extends { + quote! { ::Dict } + } else { + quote! { ::pyo3::pyclass_slots::PyClassDummySlot } + }; + + // insert space for weak ref + let weakref = if attr.has_weaklist { + quote! { ::pyo3::pyclass_slots::PyClassWeakRefSlot } + } else if attr.has_extends { + quote! { ::WeakRef } + } else { + quote! { ::pyo3::pyclass_slots::PyClassDummySlot } + }; + + let base_nativetype = if attr.has_extends { + quote! { ::BaseNativeType } + } else { + quote! { ::pyo3::PyAny } + }; + quote! { + impl ::pyo3::PyClass for #cls { + type Dict = #dict; + type WeakRef = #weakref; + type BaseNativeType = #base_nativetype; + } + } + } + fn impl_extractext(&self) -> TokenStream { + let cls = self.cls; + quote! { + impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a #cls + { + type Target = ::pyo3::PyRef<'a, #cls>; + } + + impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls + { + type Target = ::pyo3::PyRefMut<'a, #cls>; + } + } + } + + fn impl_into_py(&self) -> TokenStream { + let cls = self.cls; + let attr = self.attr; + // If #cls is not extended type, we allow Self->PyObject conversion + if !attr.has_extends { + quote! { + impl ::pyo3::IntoPy<::pyo3::PyObject> for #cls { + fn into_py(self, py: ::pyo3::Python) -> ::pyo3::PyObject { + ::pyo3::IntoPy::into_py(::pyo3::Py::new(py, self).unwrap(), py) + } + } + } + } else { + quote! {} + } + } + + /// To allow multiple #[pymethods] block, we define inventory types. + fn impl_methods_inventory(&self) -> TokenStream { + let cls = self.cls; + let methods_type = self.methods_type; + match methods_type { + PyClassMethodsType::Specialization => quote! {}, + PyClassMethodsType::Inventory => { + // Try to build a unique type for better error messages + let name = format!("Pyo3MethodsInventoryFor{}", cls.unraw()); + let inventory_cls = syn::Ident::new(&name, Span::call_site()); + + quote! { + #[doc(hidden)] + pub struct #inventory_cls { + methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>, + slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>, + } + impl ::pyo3::class::impl_::PyMethodsInventory for #inventory_cls { + fn new( + methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>, + slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>, + ) -> Self { + Self { methods, slots } + } + fn methods(&'static self) -> &'static [::pyo3::class::PyMethodDefType] { + &self.methods + } + fn slots(&'static self) -> &'static [::pyo3::ffi::PyType_Slot] { + &self.slots + } + } + + impl ::pyo3::class::impl_::HasMethodsInventory for #cls { + type Methods = #inventory_cls; + } + + ::pyo3::inventory::collect!(#inventory_cls); + } + } + } + } + fn impl_pyclassimpl(&self) -> TokenStream { + let cls = self.cls; + let doc = self.doc.as_ref().map_or(quote! {"\0"}, |doc| quote! {#doc}); + let is_gc = self.attr.is_gc; + let is_basetype = self.attr.is_basetype; + let base = &self.attr.base; + let is_subclass = self.attr.has_extends; + + let thread_checker = if self.attr.has_unsendable { + quote! { ::pyo3::class::impl_::ThreadCheckerImpl<#cls> } + } else if self.attr.has_extends { + quote! { + ::pyo3::class::impl_::ThreadCheckerInherited<#cls, <#cls as ::pyo3::class::impl_::PyClassImpl>::BaseType> + } + } else { + quote! { ::pyo3::class::impl_::ThreadCheckerStub<#cls> } + }; + + let methods_protos = match self.methods_type { + PyClassMethodsType::Specialization => { + quote! { visitor(collector.methods_protocol_slots()); } + } + PyClassMethodsType::Inventory => { + quote! { + for inventory in ::pyo3::inventory::iter::<::Methods>() { + visitor(::pyo3::class::impl_::PyMethodsInventory::slots(inventory)); + } + } + } + }; + let for_each_py_method = match self.methods_type { + PyClassMethodsType::Specialization => quote! { visitor(collector.py_methods()); }, + PyClassMethodsType::Inventory => quote! { + for inventory in ::pyo3::inventory::iter::<::Methods>() { + visitor(::pyo3::class::impl_::PyMethodsInventory::methods(inventory)); + } + }, + }; + quote! { + impl ::pyo3::class::impl_::PyClassImpl for #cls { + const DOC: &'static str = #doc; + const IS_GC: bool = #is_gc; + const IS_BASETYPE: bool = #is_basetype; + const IS_SUBCLASS: bool = #is_subclass; + + type Layout = ::pyo3::PyCell; + type BaseType = #base; + type ThreadChecker = #thread_checker; + + fn for_each_method_def(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::class::PyMethodDefType])) { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::::new(); + #for_each_py_method; + visitor(collector.py_class_descriptors()); + visitor(collector.object_protocol_methods()); + visitor(collector.async_protocol_methods()); + visitor(collector.descr_protocol_methods()); + visitor(collector.mapping_protocol_methods()); + visitor(collector.number_protocol_methods()); + } + fn get_new() -> ::std::option::Option<::pyo3::ffi::newfunc> { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::::new(); + collector.new_impl() + } + fn get_alloc() -> ::std::option::Option<::pyo3::ffi::allocfunc> { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::::new(); + collector.alloc_impl() + } + fn get_free() -> ::std::option::Option<::pyo3::ffi::freefunc> { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::::new(); + collector.free_impl() + } + + fn for_each_proto_slot(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::ffi::PyType_Slot])) { + // Implementation which uses dtolnay specialization to load all slots. + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::::new(); + visitor(collector.object_protocol_slots()); + visitor(collector.number_protocol_slots()); + visitor(collector.iter_protocol_slots()); + visitor(collector.gc_protocol_slots()); + visitor(collector.descr_protocol_slots()); + visitor(collector.mapping_protocol_slots()); + visitor(collector.sequence_protocol_slots()); + visitor(collector.async_protocol_slots()); + visitor(collector.buffer_protocol_slots()); + #methods_protos + } + + fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> { + use ::pyo3::class::impl_::*; + let collector = PyClassImplCollector::::new(); + collector.buffer_procs() + } + } + } + } + + fn impl_freelist(&self) -> TokenStream { + let cls = self.cls; + + self.attr.freelist.as_ref().map_or(quote!{}, |freelist| { + quote! { + impl ::pyo3::class::impl_::PyClassWithFreeList for #cls { + #[inline] + fn get_free_list(_py: ::pyo3::Python<'_>) -> &mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> { + static mut FREELIST: *mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> = 0 as *mut _; + unsafe { + if FREELIST.is_null() { + FREELIST = ::std::boxed::Box::into_raw(::std::boxed::Box::new( + ::pyo3::impl_::freelist::FreeList::with_capacity(#freelist))); + } + &mut *FREELIST + } + } + } + + impl ::pyo3::class::impl_::PyClassAllocImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + #[inline] + fn alloc_impl(self) -> ::std::option::Option<::pyo3::ffi::allocfunc> { + ::std::option::Option::Some(::pyo3::class::impl_::alloc_with_freelist::<#cls>) + } + } + + impl ::pyo3::class::impl_::PyClassFreeImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> { + #[inline] + fn free_impl(self) -> ::std::option::Option<::pyo3::ffi::freefunc> { + ::std::option::Option::Some(::pyo3::class::impl_::free_with_freelist::<#cls>) + } + } + } + }) + } + /// Enforce at compile time that PyGCProtocol is implemented + fn impl_gc(&self) -> TokenStream { + let cls = self.cls; + let attr = self.attr; + if attr.is_gc { + let closure_name = format!("__assertion_closure_{}", cls); + let closure_token = syn::Ident::new(&closure_name, Span::call_site()); + quote! { + fn #closure_token() { + use ::pyo3::class; + + fn _assert_implements_protocol<'p, T: ::pyo3::class::PyGCProtocol<'p>>() {} + _assert_implements_protocol::<#cls>(); + } + } + } else { + quote! {} + } + } +} diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 54e248c9..e2fbf8d8 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -13,6 +13,7 @@ use quote::quote; use syn::spanned::Spanned; /// The mechanism used to collect `#[pymethods]` into the type object +#[derive(Copy, Clone)] pub enum PyClassMethodsType { Specialization, Inventory, @@ -118,7 +119,7 @@ pub fn impl_methods( }) } -fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream { +pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream { let member = &spec.rust_ident; let deprecations = &spec.attributes.deprecations; let python_name = &spec.null_terminated_python_name(); diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index a6612d19..f08317d1 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -7,7 +7,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use pyo3_macros_backend::{ - build_derive_from_pyobject, build_py_class, build_py_function, build_py_methods, + build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods, build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyClassMethodsType, PyFunctionOptions, PyModuleOptions, }; @@ -107,12 +107,17 @@ pub fn pyproto(_: TokenStream, input: TokenStream) -> TokenStream { /// [10]: https://en.wikipedia.org/wiki/Free_list #[proc_macro_attribute] pub fn pyclass(attr: TokenStream, input: TokenStream) -> TokenStream { - let methods_type = if cfg!(feature = "multiple-pymethods") { - PyClassMethodsType::Inventory - } else { - PyClassMethodsType::Specialization - }; - pyclass_impl(attr, input, methods_type) + use syn::Item; + let item = parse_macro_input!(input as Item); + match item { + Item::Struct(struct_) => pyclass_impl(attr, struct_, methods_type()), + Item::Enum(enum_) => pyclass_enum_impl(attr, enum_, methods_type()), + unsupported => { + syn::Error::new_spanned(unsupported, "#[pyclass] only supports structs and enums.") + .to_compile_error() + .into() + } + } } /// A proc macro used to expose methods to Python. @@ -195,12 +200,11 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream { } fn pyclass_impl( - attr: TokenStream, - input: TokenStream, + attrs: TokenStream, + mut ast: syn::ItemStruct, methods_type: PyClassMethodsType, ) -> TokenStream { - let mut ast = parse_macro_input!(input as syn::ItemStruct); - let args = parse_macro_input!(attr as PyClassArgs); + let args = parse_macro_input!(attrs with PyClassArgs::parse_stuct_args); let expanded = build_py_class(&mut ast, &args, methods_type).unwrap_or_else(|e| e.to_compile_error()); @@ -211,6 +215,22 @@ fn pyclass_impl( .into() } +fn pyclass_enum_impl( + attr: TokenStream, + enum_: syn::ItemEnum, + methods_type: PyClassMethodsType, +) -> TokenStream { + let args = parse_macro_input!(attr with PyClassArgs::parse_enum_args); + let expanded = + build_py_enum(&enum_, args, methods_type).unwrap_or_else(|e| e.into_compile_error()); + + quote!( + #enum_ + #expanded + ) + .into() +} + fn pymethods_impl(input: TokenStream, methods_type: PyClassMethodsType) -> TokenStream { let mut ast = parse_macro_input!(input as syn::ItemImpl); let expanded = @@ -222,3 +242,11 @@ fn pymethods_impl(input: TokenStream, methods_type: PyClassMethodsType) -> Token ) .into() } + +fn methods_type() -> PyClassMethodsType { + if cfg!(feature = "multiple-pymethods") { + PyClassMethodsType::Inventory + } else { + PyClassMethodsType::Specialization + } +} diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 99decdc4..9631b2a1 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -19,6 +19,8 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/invalid_need_module_arg_position.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); t.compile_fail("tests/ui/invalid_pyclass_args.rs"); + t.compile_fail("tests/ui/invalid_pyclass_enum.rs"); + t.compile_fail("tests/ui/invalid_pyclass_item.rs"); t.compile_fail("tests/ui/invalid_pyfunctions.rs"); t.compile_fail("tests/ui/invalid_pymethods.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); diff --git a/tests/test_enum.rs b/tests/test_enum.rs new file mode 100644 index 00000000..67ed9278 --- /dev/null +++ b/tests/test_enum.rs @@ -0,0 +1,53 @@ +use pyo3::prelude::*; +use pyo3::{py_run, wrap_pyfunction}; + +mod common; + +#[pyclass] +#[derive(Debug, PartialEq, Clone)] +pub enum MyEnum { + Variant, + OtherVariant, +} + +#[test] +fn test_enum_class_attr() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let my_enum = py.get_type::(); + py_assert!(py, my_enum, "getattr(my_enum, 'Variant', None) is not None"); + py_assert!(py, my_enum, "getattr(my_enum, 'foobar', None) is None"); + py_run!(py, my_enum, "my_enum.Variant = None"); +} + +#[pyfunction] +fn return_enum() -> MyEnum { + MyEnum::Variant +} + +#[test] +#[ignore] // need to implement __eq__ +fn test_return_enum() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let f = wrap_pyfunction!(return_enum)(py).unwrap(); + let mynum = py.get_type::(); + + py_run!(py, f mynum, "assert f() == mynum.Variant") +} + +#[pyfunction] +fn enum_arg(e: MyEnum) { + assert_eq!(MyEnum::OtherVariant, e) +} + +#[test] +#[ignore] // need to implement __eq__ +fn test_enum_arg() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let f = wrap_pyfunction!(enum_arg)(py).unwrap(); + let mynum = py.get_type::(); + + py_run!(py, f mynum, "f(mynum.Variant)") +} diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs new file mode 100644 index 00000000..62f2a3d6 --- /dev/null +++ b/tests/ui/invalid_pyclass_enum.rs @@ -0,0 +1,15 @@ +use pyo3::prelude::*; + +#[pyclass(subclass)] +enum NotBaseClass { + x, + y, +} + +#[pyclass(extends = PyList)] +enum NotDrivedClass { + x, + y, +} + +fn main() {} diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr new file mode 100644 index 00000000..2dd0e737 --- /dev/null +++ b/tests/ui/invalid_pyclass_enum.stderr @@ -0,0 +1,11 @@ +error: enums can't be inherited by other classes + --> tests/ui/invalid_pyclass_enum.rs:3:11 + | +3 | #[pyclass(subclass)] + | ^^^^^^^^ + +error: enums cannot extend from other classes + --> tests/ui/invalid_pyclass_enum.rs:9:11 + | +9 | #[pyclass(extends = PyList)] + | ^^^^^^^ diff --git a/tests/ui/invalid_pyclass_item.rs b/tests/ui/invalid_pyclass_item.rs new file mode 100644 index 00000000..4b348540 --- /dev/null +++ b/tests/ui/invalid_pyclass_item.rs @@ -0,0 +1,6 @@ +use pyo3::prelude::*; + +#[pyclass] +fn foo() {} + +fn main() {} diff --git a/tests/ui/invalid_pyclass_item.stderr b/tests/ui/invalid_pyclass_item.stderr new file mode 100644 index 00000000..f29756df --- /dev/null +++ b/tests/ui/invalid_pyclass_item.stderr @@ -0,0 +1,5 @@ +error: #[pyclass] only supports structs and enums. + --> tests/ui/invalid_pyclass_item.rs:4:1 + | +4 | fn foo() {} + | ^^^^^^^^^^^