Refactor #[pyclass] and now it supports enum.

There's no functionality since it does not generate __richcmp__.

Also it only works on enums with only variants, and does not support
C-like enums.
This commit is contained in:
b05902132 2021-11-17 03:31:30 +08:00
parent e9b46f76da
commit b7419b5278
10 changed files with 596 additions and 275 deletions

View File

@ -25,7 +25,7 @@ mod pyproto;
pub use from_pyobject::build_derive_from_pyobject;
pub use module::{process_functions_in_module, py_init, PyModuleOptions};
pub use pyclass::{build_py_class, PyClassArgs};
pub use pyclass::{build_py_class, build_py_enum, PyClassArgs};
pub use pyfunction::{build_py_function, PyFunctionOptions};
pub use pyimpl::{build_py_methods, PyClassMethodsType};
pub use pyproto::build_py_proto;

View File

@ -2,7 +2,8 @@
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
use crate::deprecations::Deprecations;
use crate::pyimpl::PyClassMethodsType;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
use crate::utils::{self, unwrap_group, PythonDoc};
use proc_macro2::{Span, TokenStream};
@ -10,7 +11,14 @@ use quote::quote;
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{parse_quote, spanned::Spanned, Expr, Result, Token};
use syn::{parse_quote, spanned::Spanned, Expr, Result, Token}; //unraw
/// If the class is derived from a Rust `struct` or `enum`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PyClassKind {
Struct,
Enum,
}
/// The parsed arguments of the pyclass macro
pub struct PyClassArgs {
@ -24,22 +32,28 @@ pub struct PyClassArgs {
pub has_extends: bool,
pub has_unsendable: bool,
pub module: Option<syn::LitStr>,
pub class_kind: PyClassKind,
}
impl Parse for PyClassArgs {
fn parse(input: ParseStream) -> Result<Self> {
let mut slf = PyClassArgs::default();
impl PyClassArgs {
fn parse(input: ParseStream, kind: PyClassKind) -> Result<Self> {
let mut slf = PyClassArgs::new(kind);
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
for expr in vars {
slf.add_expr(&expr)?;
}
Ok(slf)
}
}
impl Default for PyClassArgs {
fn default() -> Self {
pub fn parse_stuct_args(input: ParseStream) -> syn::Result<Self> {
Self::parse(input, PyClassKind::Struct)
}
pub fn parse_enum_args(input: ParseStream) -> syn::Result<Self> {
Self::parse(input, PyClassKind::Enum)
}
fn new(class_kind: PyClassKind) -> Self {
PyClassArgs {
freelist: None,
name: None,
@ -51,11 +65,10 @@ impl Default for PyClassArgs {
is_basetype: false,
has_extends: false,
has_unsendable: false,
class_kind,
}
}
}
impl PyClassArgs {
/// Adda single expression from the comma separated list in the attribute, which is
/// either a single word or an assignment expression
fn add_expr(&mut self, expr: &Expr) -> Result<()> {
@ -113,6 +126,9 @@ impl PyClassArgs {
},
"extends" => match unwrap_group(&**right) {
syn::Expr::Path(exp) => {
if self.class_kind == PyClassKind::Enum {
bail_spanned!( assign.span() => "enums cannot extend from other classes" );
}
self.base = syn::TypePath {
path: exp.path.clone(),
qself: None,
@ -147,6 +163,9 @@ impl PyClassArgs {
self.has_weaklist = true;
}
"subclass" => {
if self.class_kind == PyClassKind::Enum {
bail_spanned!(exp.span() => "enums can't be inherited by other classes");
}
self.is_basetype = true;
}
"dict" => {
@ -328,41 +347,6 @@ impl FieldPyO3Options {
}
}
/// To allow multiple #[pymethods] block, we define inventory types.
fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream {
// Try to build a unique type for better error messages
let name = format!("Pyo3MethodsInventoryFor{}", cls.unraw());
let inventory_cls = syn::Ident::new(&name, Span::call_site());
quote! {
#[doc(hidden)]
pub struct #inventory_cls {
methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>,
slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>,
}
impl ::pyo3::class::impl_::PyMethodsInventory for #inventory_cls {
fn new(
methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>,
slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>,
) -> Self {
Self { methods, slots }
}
fn methods(&'static self) -> &'static [::pyo3::class::PyMethodDefType] {
&self.methods
}
fn slots(&'static self) -> &'static [::pyo3::ffi::PyType_Slot] {
&self.slots
}
}
impl ::pyo3::class::impl_::HasMethodsInventory for #cls {
type Methods = #inventory_cls;
}
::pyo3::inventory::collect!(#inventory_cls);
}
}
fn get_class_python_name<'a>(cls: &'a syn::Ident, attr: &'a PyClassArgs) -> &'a syn::Ident {
attr.name.as_ref().unwrap_or(cls)
}
@ -375,243 +359,121 @@ fn impl_class(
methods_type: PyClassMethodsType,
deprecations: Deprecations,
) -> syn::Result<TokenStream> {
let cls_name = get_class_python_name(cls, attr).to_string();
let pytypeinfo_impl = impl_pytypeinfo(cls, attr, Some(&deprecations));
let alloc = attr.freelist.as_ref().map(|freelist| {
quote! {
impl ::pyo3::class::impl_::PyClassWithFreeList for #cls {
#[inline]
fn get_free_list(_py: ::pyo3::Python<'_>) -> &mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> {
static mut FREELIST: *mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> = 0 as *mut _;
unsafe {
if FREELIST.is_null() {
FREELIST = ::std::boxed::Box::into_raw(::std::boxed::Box::new(
::pyo3::impl_::freelist::FreeList::with_capacity(#freelist)));
}
&mut *FREELIST
}
}
}
impl ::pyo3::class::impl_::PyClassAllocImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
#[inline]
fn alloc_impl(self) -> ::std::option::Option<::pyo3::ffi::allocfunc> {
::std::option::Option::Some(::pyo3::class::impl_::alloc_with_freelist::<#cls>)
}
}
impl ::pyo3::class::impl_::PyClassFreeImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
#[inline]
fn free_impl(self) -> ::std::option::Option<::pyo3::ffi::freefunc> {
::std::option::Option::Some(::pyo3::class::impl_::free_with_freelist::<#cls>)
}
}
}
});
let py_class_impl = PyClassImplsBuilder::new(cls, attr, methods_type)
.doc(doc)
.impl_all();
let descriptors = impl_descriptors(cls, field_options)?;
// insert space for weak ref
let weakref = if attr.has_weaklist {
quote! { ::pyo3::pyclass_slots::PyClassWeakRefSlot }
} else if attr.has_extends {
quote! { <Self::BaseType as ::pyo3::class::impl_::PyClassBaseType>::WeakRef }
} else {
quote! { ::pyo3::pyclass_slots::PyClassDummySlot }
};
let dict = if attr.has_dict {
quote! { ::pyo3::pyclass_slots::PyClassDictSlot }
} else if attr.has_extends {
quote! { <Self::BaseType as ::pyo3::class::impl_::PyClassBaseType>::Dict }
} else {
quote! { ::pyo3::pyclass_slots::PyClassDummySlot }
};
let module = if let Some(m) = &attr.module {
quote! { ::std::option::Option::Some(#m) }
} else {
quote! { ::std::option::Option::None }
};
Ok(quote! {
#pytypeinfo_impl
// Enforce at compile time that PyGCProtocol is implemented
let gc_impl = if attr.is_gc {
let closure_name = format!("__assertion_closure_{}", cls);
let closure_token = syn::Ident::new(&closure_name, Span::call_site());
quote! {
fn #closure_token() {
use ::pyo3::class;
#py_class_impl
fn _assert_implements_protocol<'p, T: ::pyo3::class::PyGCProtocol<'p>>() {}
_assert_implements_protocol::<#cls>();
}
}
} else {
quote! {}
};
#descriptors
})
}
let (impl_inventory, for_each_py_method) = match methods_type {
PyClassMethodsType::Specialization => (None, quote! { visitor(collector.py_methods()); }),
PyClassMethodsType::Inventory => (
Some(impl_methods_inventory(cls)),
quote! {
for inventory in ::pyo3::inventory::iter::<<Self as ::pyo3::class::impl_::HasMethodsInventory>::Methods>() {
visitor(::pyo3::class::impl_::PyMethodsInventory::methods(inventory));
}
},
),
};
struct PyClassEnumVariant<'a> {
ident: &'a syn::Ident,
/* currently have no more options */
}
let methods_protos = match methods_type {
PyClassMethodsType::Specialization => {
quote! { visitor(collector.methods_protocol_slots()); }
}
PyClassMethodsType::Inventory => {
quote! {
for inventory in ::pyo3::inventory::iter::<<Self as ::pyo3::class::impl_::HasMethodsInventory>::Methods>() {
visitor(::pyo3::class::impl_::PyMethodsInventory::slots(inventory));
}
}
}
};
pub fn build_py_enum(
enum_: &syn::ItemEnum,
args: PyClassArgs,
method_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let variants: Vec<PyClassEnumVariant> = enum_
.variants
.iter()
.map(|v| extract_variant_data(v))
.collect::<syn::Result<_>>()?;
impl_enum(enum_, args, variants, method_type)
}
let base = &attr.base;
let base_nativetype = if attr.has_extends {
quote! { <Self::BaseType as ::pyo3::class::impl_::PyClassBaseType>::BaseNativeType }
} else {
quote! { ::pyo3::PyAny }
};
// If #cls is not extended type, we allow Self->PyObject conversion
let into_pyobject = if !attr.has_extends {
quote! {
impl ::pyo3::IntoPy<::pyo3::PyObject> for #cls {
fn into_py(self, py: ::pyo3::Python) -> ::pyo3::PyObject {
::pyo3::IntoPy::into_py(::pyo3::Py::new(py, self).unwrap(), py)
}
}
}
} else {
quote! {}
};
let thread_checker = if attr.has_unsendable {
quote! { ::pyo3::class::impl_::ThreadCheckerImpl<#cls> }
} else if attr.has_extends {
quote! {
::pyo3::class::impl_::ThreadCheckerInherited<#cls, <#cls as ::pyo3::class::impl_::PyClassImpl>::BaseType>
}
} else {
quote! { ::pyo3::class::impl_::ThreadCheckerStub<#cls> }
};
let is_gc = attr.is_gc;
let is_basetype = attr.is_basetype;
let is_subclass = attr.has_extends;
fn impl_enum(
enum_: &syn::ItemEnum,
attrs: PyClassArgs,
variants: Vec<PyClassEnumVariant>,
methods_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let enum_name = &enum_.ident;
let doc = utils::get_doc(&enum_.attrs, None);
let enum_cls = impl_enum_class(enum_name, &attrs, variants, doc, methods_type)?;
Ok(quote! {
unsafe impl ::pyo3::type_object::PyTypeInfo for #cls {
type AsRefTarget = ::pyo3::PyCell<Self>;
#enum_cls
})
}
const NAME: &'static str = #cls_name;
const MODULE: ::std::option::Option<&'static str> = #module;
fn impl_enum_class(
cls: &syn::Ident,
attr: &PyClassArgs,
variants: Vec<PyClassEnumVariant>,
doc: PythonDoc,
methods_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let pytypeinfo = impl_pytypeinfo(cls, attr, None);
let pyclass_impls = PyClassImplsBuilder::new(cls, attr, methods_type)
.doc(doc)
.impl_all();
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));
#[inline]
fn type_object_raw(py: ::pyo3::Python<'_>) -> *mut ::pyo3::ffi::PyTypeObject {
#deprecations
Ok(quote! {
use ::pyo3::type_object::LazyStaticType;
static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
TYPE_OBJECT.get_or_init::<Self>(py)
}
}
#pytypeinfo
impl ::pyo3::PyClass for #cls {
type Dict = #dict;
type WeakRef = #weakref;
type BaseNativeType = #base_nativetype;
}
impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a #cls
{
type Target = ::pyo3::PyRef<'a, #cls>;
}
impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
{
type Target = ::pyo3::PyRefMut<'a, #cls>;
}
#into_pyobject
#impl_inventory
impl ::pyo3::class::impl_::PyClassImpl for #cls {
const DOC: &'static str = #doc;
const IS_GC: bool = #is_gc;
const IS_BASETYPE: bool = #is_basetype;
const IS_SUBCLASS: bool = #is_subclass;
type Layout = ::pyo3::PyCell<Self>;
type BaseType = #base;
type ThreadChecker = #thread_checker;
fn for_each_method_def(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::class::PyMethodDefType])) {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
#for_each_py_method;
visitor(collector.py_class_descriptors());
visitor(collector.object_protocol_methods());
visitor(collector.async_protocol_methods());
visitor(collector.descr_protocol_methods());
visitor(collector.mapping_protocol_methods());
visitor(collector.number_protocol_methods());
}
fn get_new() -> ::std::option::Option<::pyo3::ffi::newfunc> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.new_impl()
}
fn get_alloc() -> ::std::option::Option<::pyo3::ffi::allocfunc> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.alloc_impl()
}
fn get_free() -> ::std::option::Option<::pyo3::ffi::freefunc> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.free_impl()
}
fn for_each_proto_slot(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::ffi::PyType_Slot])) {
// Implementation which uses dtolnay specialization to load all slots.
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
visitor(collector.object_protocol_slots());
visitor(collector.number_protocol_slots());
visitor(collector.iter_protocol_slots());
visitor(collector.gc_protocol_slots());
visitor(collector.descr_protocol_slots());
visitor(collector.mapping_protocol_slots());
visitor(collector.sequence_protocol_slots());
visitor(collector.async_protocol_slots());
visitor(collector.buffer_protocol_slots());
#methods_protos
}
fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.buffer_procs()
}
}
#alloc
#pyclass_impls
#descriptors
#gc_impl
})
}
fn unit_variants_as_descriptors<'a>(
cls: &'a syn::Ident,
variant_names: impl IntoIterator<Item = &'a syn::Ident>,
) -> TokenStream {
let cls_type = syn::parse_quote!(#cls);
let variant_to_attribute = |ident: &syn::Ident| ConstSpec {
rust_ident: ident.clone(),
attributes: ConstAttributes {
is_class_attr: true,
name: Some(NameAttribute(ident.clone())),
deprecations: Default::default(),
},
};
let py_methods = variant_names
.into_iter()
.map(|var| gen_py_const(&cls_type, &variant_to_attribute(var)));
quote! {
impl ::pyo3::class::impl_::PyClassDescriptors<#cls>
for ::pyo3::class::impl_::PyClassImplCollector<#cls>
{
fn py_class_descriptors(self) -> &'static [::pyo3::class::methods::PyMethodDefType] {
static METHODS: &[::pyo3::class::methods::PyMethodDefType] = &[#(#py_methods),*];
METHODS
}
}
}
}
fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVariant> {
use syn::Fields;
let ident = match variant.fields {
Fields::Unit => &variant.ident,
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
};
if let Some(discriminant) = variant.discriminant.as_ref() {
bail_spanned!(discriminant.0.span() => "Currently does not support discriminats.")
};
Ok(PyClassEnumVariant { ident })
}
fn impl_descriptors(
cls: &syn::Ident,
field_options: Vec<(&syn::Field, FieldPyO3Options)>,
@ -662,3 +524,341 @@ fn impl_descriptors(
}
})
}
fn impl_pytypeinfo(
cls: &syn::Ident,
attr: &PyClassArgs,
deprecations: Option<&Deprecations>,
) -> TokenStream {
let cls_name = get_class_python_name(cls, attr).to_string();
let module = if let Some(m) = &attr.module {
quote! { ::core::option::Option::Some(#m) }
} else {
quote! { ::core::option::Option::None }
};
quote! {
unsafe impl ::pyo3::type_object::PyTypeInfo for #cls {
type AsRefTarget = ::pyo3::PyCell<Self>;
const NAME: &'static str = #cls_name;
const MODULE: ::std::option::Option<&'static str> = #module;
#[inline]
fn type_object_raw(py: ::pyo3::Python<'_>) -> *mut ::pyo3::ffi::PyTypeObject {
#deprecations
use ::pyo3::type_object::LazyStaticType;
static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
TYPE_OBJECT.get_or_init::<Self>(py)
}
}
}
}
/// Implements most traits used by `#[pyclass]`.
///
/// Specifically, it implements traits that only depend on class name,
/// and attributes of `#[pyclass]`, and docstrings.
/// Therefore it doesn't implement traits that depends on struct fields and enum variants.
struct PyClassImplsBuilder<'a> {
cls: &'a syn::Ident,
attr: &'a PyClassArgs,
methods_type: PyClassMethodsType,
doc: Option<PythonDoc>,
}
impl<'a> PyClassImplsBuilder<'a> {
fn new(cls: &'a syn::Ident, attr: &'a PyClassArgs, methods_type: PyClassMethodsType) -> Self {
Self {
cls,
attr,
methods_type,
doc: None,
}
}
fn doc(self, doc: PythonDoc) -> Self {
Self {
doc: Some(doc),
..self
}
}
fn impl_all(&self) -> TokenStream {
vec![
self.impl_pyclass(),
self.impl_extractext(),
self.impl_into_py(),
self.impl_methods_inventory(),
self.impl_pyclassimpl(),
self.impl_freelist(),
self.impl_gc(),
]
.into_iter()
.collect()
}
fn impl_pyclass(&self) -> TokenStream {
let cls = self.cls;
let attr = self.attr;
let dict = if attr.has_dict {
quote! { ::pyo3::pyclass_slots::PyClassDictSlot }
} else if attr.has_extends {
quote! { <Self::BaseType as ::pyo3::class::impl_::PyClassBaseType>::Dict }
} else {
quote! { ::pyo3::pyclass_slots::PyClassDummySlot }
};
// insert space for weak ref
let weakref = if attr.has_weaklist {
quote! { ::pyo3::pyclass_slots::PyClassWeakRefSlot }
} else if attr.has_extends {
quote! { <Self::BaseType as ::pyo3::class::impl_::PyClassBaseType>::WeakRef }
} else {
quote! { ::pyo3::pyclass_slots::PyClassDummySlot }
};
let base_nativetype = if attr.has_extends {
quote! { <Self::BaseType as ::pyo3::class::impl_::PyClassBaseType>::BaseNativeType }
} else {
quote! { ::pyo3::PyAny }
};
quote! {
impl ::pyo3::PyClass for #cls {
type Dict = #dict;
type WeakRef = #weakref;
type BaseNativeType = #base_nativetype;
}
}
}
fn impl_extractext(&self) -> TokenStream {
let cls = self.cls;
quote! {
impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a #cls
{
type Target = ::pyo3::PyRef<'a, #cls>;
}
impl<'a> ::pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
{
type Target = ::pyo3::PyRefMut<'a, #cls>;
}
}
}
fn impl_into_py(&self) -> TokenStream {
let cls = self.cls;
let attr = self.attr;
// If #cls is not extended type, we allow Self->PyObject conversion
if !attr.has_extends {
quote! {
impl ::pyo3::IntoPy<::pyo3::PyObject> for #cls {
fn into_py(self, py: ::pyo3::Python) -> ::pyo3::PyObject {
::pyo3::IntoPy::into_py(::pyo3::Py::new(py, self).unwrap(), py)
}
}
}
} else {
quote! {}
}
}
/// To allow multiple #[pymethods] block, we define inventory types.
fn impl_methods_inventory(&self) -> TokenStream {
let cls = self.cls;
let methods_type = self.methods_type;
match methods_type {
PyClassMethodsType::Specialization => quote! {},
PyClassMethodsType::Inventory => {
// Try to build a unique type for better error messages
let name = format!("Pyo3MethodsInventoryFor{}", cls.unraw());
let inventory_cls = syn::Ident::new(&name, Span::call_site());
quote! {
#[doc(hidden)]
pub struct #inventory_cls {
methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>,
slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>,
}
impl ::pyo3::class::impl_::PyMethodsInventory for #inventory_cls {
fn new(
methods: ::std::vec::Vec<::pyo3::class::PyMethodDefType>,
slots: ::std::vec::Vec<::pyo3::ffi::PyType_Slot>,
) -> Self {
Self { methods, slots }
}
fn methods(&'static self) -> &'static [::pyo3::class::PyMethodDefType] {
&self.methods
}
fn slots(&'static self) -> &'static [::pyo3::ffi::PyType_Slot] {
&self.slots
}
}
impl ::pyo3::class::impl_::HasMethodsInventory for #cls {
type Methods = #inventory_cls;
}
::pyo3::inventory::collect!(#inventory_cls);
}
}
}
}
fn impl_pyclassimpl(&self) -> TokenStream {
let cls = self.cls;
let doc = self.doc.as_ref().map_or(quote! {"\0"}, |doc| quote! {#doc});
let is_gc = self.attr.is_gc;
let is_basetype = self.attr.is_basetype;
let base = &self.attr.base;
let is_subclass = self.attr.has_extends;
let thread_checker = if self.attr.has_unsendable {
quote! { ::pyo3::class::impl_::ThreadCheckerImpl<#cls> }
} else if self.attr.has_extends {
quote! {
::pyo3::class::impl_::ThreadCheckerInherited<#cls, <#cls as ::pyo3::class::impl_::PyClassImpl>::BaseType>
}
} else {
quote! { ::pyo3::class::impl_::ThreadCheckerStub<#cls> }
};
let methods_protos = match self.methods_type {
PyClassMethodsType::Specialization => {
quote! { visitor(collector.methods_protocol_slots()); }
}
PyClassMethodsType::Inventory => {
quote! {
for inventory in ::pyo3::inventory::iter::<<Self as ::pyo3::class::impl_::HasMethodsInventory>::Methods>() {
visitor(::pyo3::class::impl_::PyMethodsInventory::slots(inventory));
}
}
}
};
let for_each_py_method = match self.methods_type {
PyClassMethodsType::Specialization => quote! { visitor(collector.py_methods()); },
PyClassMethodsType::Inventory => quote! {
for inventory in ::pyo3::inventory::iter::<<Self as ::pyo3::class::impl_::HasMethodsInventory>::Methods>() {
visitor(::pyo3::class::impl_::PyMethodsInventory::methods(inventory));
}
},
};
quote! {
impl ::pyo3::class::impl_::PyClassImpl for #cls {
const DOC: &'static str = #doc;
const IS_GC: bool = #is_gc;
const IS_BASETYPE: bool = #is_basetype;
const IS_SUBCLASS: bool = #is_subclass;
type Layout = ::pyo3::PyCell<Self>;
type BaseType = #base;
type ThreadChecker = #thread_checker;
fn for_each_method_def(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::class::PyMethodDefType])) {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
#for_each_py_method;
visitor(collector.py_class_descriptors());
visitor(collector.object_protocol_methods());
visitor(collector.async_protocol_methods());
visitor(collector.descr_protocol_methods());
visitor(collector.mapping_protocol_methods());
visitor(collector.number_protocol_methods());
}
fn get_new() -> ::std::option::Option<::pyo3::ffi::newfunc> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.new_impl()
}
fn get_alloc() -> ::std::option::Option<::pyo3::ffi::allocfunc> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.alloc_impl()
}
fn get_free() -> ::std::option::Option<::pyo3::ffi::freefunc> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.free_impl()
}
fn for_each_proto_slot(visitor: &mut dyn ::std::ops::FnMut(&[::pyo3::ffi::PyType_Slot])) {
// Implementation which uses dtolnay specialization to load all slots.
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
visitor(collector.object_protocol_slots());
visitor(collector.number_protocol_slots());
visitor(collector.iter_protocol_slots());
visitor(collector.gc_protocol_slots());
visitor(collector.descr_protocol_slots());
visitor(collector.mapping_protocol_slots());
visitor(collector.sequence_protocol_slots());
visitor(collector.async_protocol_slots());
visitor(collector.buffer_protocol_slots());
#methods_protos
}
fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> {
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
collector.buffer_procs()
}
}
}
}
fn impl_freelist(&self) -> TokenStream {
let cls = self.cls;
self.attr.freelist.as_ref().map_or(quote!{}, |freelist| {
quote! {
impl ::pyo3::class::impl_::PyClassWithFreeList for #cls {
#[inline]
fn get_free_list(_py: ::pyo3::Python<'_>) -> &mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> {
static mut FREELIST: *mut ::pyo3::impl_::freelist::FreeList<*mut ::pyo3::ffi::PyObject> = 0 as *mut _;
unsafe {
if FREELIST.is_null() {
FREELIST = ::std::boxed::Box::into_raw(::std::boxed::Box::new(
::pyo3::impl_::freelist::FreeList::with_capacity(#freelist)));
}
&mut *FREELIST
}
}
}
impl ::pyo3::class::impl_::PyClassAllocImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
#[inline]
fn alloc_impl(self) -> ::std::option::Option<::pyo3::ffi::allocfunc> {
::std::option::Option::Some(::pyo3::class::impl_::alloc_with_freelist::<#cls>)
}
}
impl ::pyo3::class::impl_::PyClassFreeImpl<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
#[inline]
fn free_impl(self) -> ::std::option::Option<::pyo3::ffi::freefunc> {
::std::option::Option::Some(::pyo3::class::impl_::free_with_freelist::<#cls>)
}
}
}
})
}
/// Enforce at compile time that PyGCProtocol is implemented
fn impl_gc(&self) -> TokenStream {
let cls = self.cls;
let attr = self.attr;
if attr.is_gc {
let closure_name = format!("__assertion_closure_{}", cls);
let closure_token = syn::Ident::new(&closure_name, Span::call_site());
quote! {
fn #closure_token() {
use ::pyo3::class;
fn _assert_implements_protocol<'p, T: ::pyo3::class::PyGCProtocol<'p>>() {}
_assert_implements_protocol::<#cls>();
}
}
} else {
quote! {}
}
}
}

View File

@ -13,6 +13,7 @@ use quote::quote;
use syn::spanned::Spanned;
/// The mechanism used to collect `#[pymethods]` into the type object
#[derive(Copy, Clone)]
pub enum PyClassMethodsType {
Specialization,
Inventory,
@ -118,7 +119,7 @@ pub fn impl_methods(
})
}
fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
let member = &spec.rust_ident;
let deprecations = &spec.attributes.deprecations;
let python_name = &spec.null_terminated_python_name();

View File

@ -7,7 +7,7 @@ extern crate proc_macro;
use proc_macro::TokenStream;
use pyo3_macros_backend::{
build_derive_from_pyobject, build_py_class, build_py_function, build_py_methods,
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyClassMethodsType,
PyFunctionOptions, PyModuleOptions,
};
@ -107,12 +107,17 @@ pub fn pyproto(_: TokenStream, input: TokenStream) -> TokenStream {
/// [10]: https://en.wikipedia.org/wiki/Free_list
#[proc_macro_attribute]
pub fn pyclass(attr: TokenStream, input: TokenStream) -> TokenStream {
let methods_type = if cfg!(feature = "multiple-pymethods") {
PyClassMethodsType::Inventory
} else {
PyClassMethodsType::Specialization
};
pyclass_impl(attr, input, methods_type)
use syn::Item;
let item = parse_macro_input!(input as Item);
match item {
Item::Struct(struct_) => pyclass_impl(attr, struct_, methods_type()),
Item::Enum(enum_) => pyclass_enum_impl(attr, enum_, methods_type()),
unsupported => {
syn::Error::new_spanned(unsupported, "#[pyclass] only supports structs and enums.")
.to_compile_error()
.into()
}
}
}
/// A proc macro used to expose methods to Python.
@ -195,12 +200,11 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
}
fn pyclass_impl(
attr: TokenStream,
input: TokenStream,
attrs: TokenStream,
mut ast: syn::ItemStruct,
methods_type: PyClassMethodsType,
) -> TokenStream {
let mut ast = parse_macro_input!(input as syn::ItemStruct);
let args = parse_macro_input!(attr as PyClassArgs);
let args = parse_macro_input!(attrs with PyClassArgs::parse_stuct_args);
let expanded =
build_py_class(&mut ast, &args, methods_type).unwrap_or_else(|e| e.to_compile_error());
@ -211,6 +215,22 @@ fn pyclass_impl(
.into()
}
fn pyclass_enum_impl(
attr: TokenStream,
enum_: syn::ItemEnum,
methods_type: PyClassMethodsType,
) -> TokenStream {
let args = parse_macro_input!(attr with PyClassArgs::parse_enum_args);
let expanded =
build_py_enum(&enum_, args, methods_type).unwrap_or_else(|e| e.into_compile_error());
quote!(
#enum_
#expanded
)
.into()
}
fn pymethods_impl(input: TokenStream, methods_type: PyClassMethodsType) -> TokenStream {
let mut ast = parse_macro_input!(input as syn::ItemImpl);
let expanded =
@ -222,3 +242,11 @@ fn pymethods_impl(input: TokenStream, methods_type: PyClassMethodsType) -> Token
)
.into()
}
fn methods_type() -> PyClassMethodsType {
if cfg!(feature = "multiple-pymethods") {
PyClassMethodsType::Inventory
} else {
PyClassMethodsType::Specialization
}
}

View File

@ -19,6 +19,8 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
t.compile_fail("tests/ui/invalid_property_args.rs");
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
t.compile_fail("tests/ui/invalid_pyclass_enum.rs");
t.compile_fail("tests/ui/invalid_pyclass_item.rs");
t.compile_fail("tests/ui/invalid_pyfunctions.rs");
t.compile_fail("tests/ui/invalid_pymethods.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");

53
tests/test_enum.rs Normal file
View File

@ -0,0 +1,53 @@
use pyo3::prelude::*;
use pyo3::{py_run, wrap_pyfunction};
mod common;
#[pyclass]
#[derive(Debug, PartialEq, Clone)]
pub enum MyEnum {
Variant,
OtherVariant,
}
#[test]
fn test_enum_class_attr() {
let gil = Python::acquire_gil();
let py = gil.python();
let my_enum = py.get_type::<MyEnum>();
py_assert!(py, my_enum, "getattr(my_enum, 'Variant', None) is not None");
py_assert!(py, my_enum, "getattr(my_enum, 'foobar', None) is None");
py_run!(py, my_enum, "my_enum.Variant = None");
}
#[pyfunction]
fn return_enum() -> MyEnum {
MyEnum::Variant
}
#[test]
#[ignore] // need to implement __eq__
fn test_return_enum() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(return_enum)(py).unwrap();
let mynum = py.get_type::<MyEnum>();
py_run!(py, f mynum, "assert f() == mynum.Variant")
}
#[pyfunction]
fn enum_arg(e: MyEnum) {
assert_eq!(MyEnum::OtherVariant, e)
}
#[test]
#[ignore] // need to implement __eq__
fn test_enum_arg() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(enum_arg)(py).unwrap();
let mynum = py.get_type::<MyEnum>();
py_run!(py, f mynum, "f(mynum.Variant)")
}

View File

@ -0,0 +1,15 @@
use pyo3::prelude::*;
#[pyclass(subclass)]
enum NotBaseClass {
x,
y,
}
#[pyclass(extends = PyList)]
enum NotDrivedClass {
x,
y,
}
fn main() {}

View File

@ -0,0 +1,11 @@
error: enums can't be inherited by other classes
--> tests/ui/invalid_pyclass_enum.rs:3:11
|
3 | #[pyclass(subclass)]
| ^^^^^^^^
error: enums cannot extend from other classes
--> tests/ui/invalid_pyclass_enum.rs:9:11
|
9 | #[pyclass(extends = PyList)]
| ^^^^^^^

View File

@ -0,0 +1,6 @@
use pyo3::prelude::*;
#[pyclass]
fn foo() {}
fn main() {}

View File

@ -0,0 +1,5 @@
error: #[pyclass] only supports structs and enums.
--> tests/ui/invalid_pyclass_item.rs:4:1
|
4 | fn foo() {}
| ^^^^^^^^^^^