diff --git a/guide/src/class.md b/guide/src/class.md index 25df536c..356e7c71 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -2,7 +2,7 @@ PyO3 exposes a group of attributes powered by Rust's proc macro system for defining Python classes as Rust structs. -The main attribute is `#[pyclass]`, which is placed upon a Rust `struct` or a fieldless `enum` (a.k.a. C-like enum) to generate a Python type for it. They will usually also have *one* `#[pymethods]`-annotated `impl` block for the struct, which is used to define Python methods and constants for the generated Python type. (If the [`multiple-pymethods`] feature is enabled, each `#[pyclass]` is allowed to have multiple `#[pymethods]` blocks.) `#[pymethods]` may also have implementations for Python magic methods such as `__str__`. +The main attribute is `#[pyclass]`, which is placed upon a Rust `struct` or `enum` to generate a Python type for it. They will usually also have *one* `#[pymethods]`-annotated `impl` block for the struct, which is used to define Python methods and constants for the generated Python type. (If the [`multiple-pymethods`] feature is enabled, each `#[pyclass]` is allowed to have multiple `#[pymethods]` blocks.) `#[pymethods]` may also have implementations for Python magic methods such as `__str__`. This chapter will discuss the functionality and configuration these attributes offer. Below is a list of links to the relevant section of this chapter for each: @@ -21,13 +21,13 @@ This chapter will discuss the functionality and configuration these attributes o ## Defining a new class -To define a custom Python class, add the `#[pyclass]` attribute to a Rust struct or a fieldless enum. +To define a custom Python class, add the `#[pyclass]` attribute to a Rust struct or enum. ```rust # #![allow(dead_code)] use pyo3::prelude::*; #[pyclass] -struct Integer { +struct MyClass { inner: i32, } @@ -35,7 +35,15 @@ struct Integer { #[pyclass] struct Number(i32); -// PyO3 supports custom discriminants in enums +// PyO3 supports unit-only enums (which contain only unit variants) +// These simple enums behave similarly to Python's enumerations (enum.Enum) +#[pyclass] +enum MyEnum { + Variant, + OtherVariant = 30, // PyO3 supports custom discriminants. +} + +// PyO3 supports custom discriminants in unit-only enums #[pyclass] enum HttpResponse { Ok = 200, @@ -44,14 +52,19 @@ enum HttpResponse { // ... } +// PyO3 also supports enums with non-unit variants +// These complex enums have sligtly different behavior from the simple enums above +// They are meant to work with instance checks and match statement patterns #[pyclass] -enum MyEnum { - Variant, - OtherVariant = 30, // PyO3 supports custom discriminants. +enum Shape { + Circle { radius: f64 }, + Rectangle { width: f64, height: f64 }, + RegularPolygon { side_count: u32, radius: f64 }, + Nothing { }, } ``` -The above example generates implementations for [`PyTypeInfo`] and [`PyClass`] for `MyClass` and `MyEnum`. To see these generated implementations, refer to the [implementation details](#implementation-details) at the end of this chapter. +The above example generates implementations for [`PyTypeInfo`] and [`PyClass`] for `MyClass`, `Number`, `MyEnum`, `HttpResponse`, and `Shape`. To see these generated implementations, refer to the [implementation details](#implementation-details) at the end of this chapter. ### Restrictions @@ -964,7 +977,13 @@ Note that `text_signature` on `#[new]` is not compatible with compilation in ## #[pyclass] enums -Currently PyO3 only supports fieldless enums. PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`. PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`: +Enum support in PyO3 comes in two flavors, depending on what kind of variants the enum has: simple and complex. + +### Simple enums + +A simple enum (a.k.a. C-like enum) has only unit variants. + +PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`. PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`: ```rust # use pyo3::prelude::*; @@ -986,7 +1005,7 @@ Python::with_gil(|py| { }) ``` -You can also convert your enums into `int`: +You can also convert your simple enums into `int`: ```rust # use pyo3::prelude::*; @@ -1094,6 +1113,90 @@ enum BadSubclass { `#[pyclass]` enums are currently not interoperable with `IntEnum` in Python. +### Complex enums + +An enum is complex if it has any non-unit (struct or tuple) variants. + +Currently PyO3 supports only struct variants in a complex enum. Support for unit and tuple variants is planned. + +PyO3 adds a class attribute for each variant, which may be used to construct values and in match patterns. PyO3 also provides getter methods for all fields of each variant. + +```rust +# use pyo3::prelude::*; +#[pyclass] +enum Shape { + Circle { radius: f64 }, + Rectangle { width: f64, height: f64 }, + RegularPolygon { side_count: u32, radius: f64 }, + Nothing { }, +} + +Python::with_gil(|py| { + let def_count_vertices = if py.version_info() >= (3, 10) { r#" + def count_vertices(cls, shape): + match shape: + case cls.Circle(): + return 0 + case cls.Rectangle(): + return 4 + case cls.RegularPolygon(side_count=n): + return n + case cls.Nothing(): + return 0 + "# } else { r#" + def count_vertices(cls, shape): + if isinstance(shape, cls.Circle): + return 0 + elif isinstance(shape, cls.Rectangle): + return 4 + elif isinstance(shape, cls.RegularPolygon): + n = shape.side_count + return n + elif isinstance(shape, cls.Nothing): + return 0 + "# }; + + let circle = Shape::Circle { radius: 10.0 }.into_py(py); + let square = Shape::RegularPolygon { side_count: 4, radius: 10.0 }.into_py(py); + let cls = py.get_type::(); + + pyo3::py_run!(py, circle square cls, &format!(r#" + assert isinstance(circle, cls) + assert isinstance(circle, cls.Circle) + assert circle.radius == 10.0 + + assert isinstance(square, cls) + assert isinstance(square, cls.RegularPolygon) + assert square.side_count == 4 + assert square.radius == 10.0 + + {} + + assert count_vertices(cls, circle) == 0 + assert count_vertices(cls, square) == 4 + "#, def_count_vertices)) +}) +``` + +WARNING: `Py::new` and `.into_py` are currently inconsistent. Note how the constructed value is _not_ an instance of the specific variant. For this reason, constructing values is only recommended using `.into_py`. + +```rust +# use pyo3::prelude::*; +#[pyclass] +enum MyEnum { + Variant { i: i32 }, +} + +Python::with_gil(|py| { + let x = Py::new(py, MyEnum::Variant { i: 42 }).unwrap(); + let cls = py.get_type::(); + pyo3::py_run!(py, x cls, r#" + assert isinstance(x, cls) + assert not isinstance(x, cls.Variant) + "#) +}) +``` + ## Implementation details The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block. diff --git a/newsfragments/3582.added.md b/newsfragments/3582.added.md new file mode 100644 index 00000000..59659a88 --- /dev/null +++ b/newsfragments/3582.added.md @@ -0,0 +1 @@ +Support `#[pyclass]` on enums that have non-unit variants. diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 66f41504..e2f84949 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -7,7 +7,7 @@ use crate::attributes::{ }; use crate::deprecations::Deprecations; use crate::konst::{ConstAttributes, ConstSpec}; -use crate::method::FnSpec; +use crate::method::{FnArg, FnSpec}; use crate::pyimpl::{gen_py_const, PyClassMethodsType}; use crate::pymethod::{ impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType, @@ -16,7 +16,7 @@ use crate::pymethod::{ use crate::utils::{self, apply_renaming_rule, get_pyo3_crate, PythonDoc}; use crate::PyFunctionOptions; use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote}; use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; @@ -30,6 +30,7 @@ pub enum PyClassKind { } /// The parsed arguments of the pyclass macro +#[derive(Clone)] pub struct PyClassArgs { pub class_kind: PyClassKind, pub options: PyClassPyO3Options, @@ -52,7 +53,7 @@ impl PyClassArgs { } } -#[derive(Default)] +#[derive(Clone, Default)] pub struct PyClassPyO3Options { pub krate: Option, pub dict: Option, @@ -128,7 +129,7 @@ impl Parse for PyClassPyO3Option { } } -impl PyClassPyO3Options { +impl Parse for PyClassPyO3Options { fn parse(input: ParseStream<'_>) -> syn::Result { let mut options: PyClassPyO3Options = Default::default(); @@ -138,7 +139,9 @@ impl PyClassPyO3Options { Ok(options) } +} +impl PyClassPyO3Options { pub fn take_pyo3_options(&mut self, attrs: &mut Vec) -> syn::Result<()> { take_pyo3_options(attrs)? .into_iter() @@ -369,73 +372,24 @@ fn impl_class( }) } -struct PyClassEnumVariant<'a> { - ident: &'a syn::Ident, - options: EnumVariantPyO3Options, -} - -impl<'a> PyClassEnumVariant<'a> { - fn python_name(&self, args: &PyClassArgs) -> Cow<'_, syn::Ident> { - self.options - .name - .as_ref() - .map(|name_attr| Cow::Borrowed(&name_attr.value.0)) - .unwrap_or_else(|| { - let name = self.ident.unraw(); - if let Some(attr) = &args.options.rename_all { - let new_name = apply_renaming_rule(attr.value.rule, &name.to_string()); - Cow::Owned(Ident::new(&new_name, Span::call_site())) - } else { - Cow::Owned(name) - } - }) - } -} - -struct PyClassEnum<'a> { - ident: &'a syn::Ident, - // The underlying #[repr] of the enum, used to implement __int__ and __richcmp__. - // This matters when the underlying representation may not fit in `isize`. - repr_type: syn::Ident, - variants: Vec>, +enum PyClassEnum<'a> { + Simple(PyClassSimpleEnum<'a>), + Complex(PyClassComplexEnum<'a>), } impl<'a> PyClassEnum<'a> { fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result { - fn is_numeric_type(t: &syn::Ident) -> bool { - [ - "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize", - "isize", - ] - .iter() - .any(|&s| t == s) - } - let ident = &enum_.ident; - // According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html), - // "Under the default representation, the specified discriminant is interpreted as an isize - // value", so `isize` should be enough by default. - let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site()); - if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path().is_ident("repr")) { - let args = - attr.parse_args_with(Punctuated::::parse_terminated)?; - if let Some(ident) = args - .into_iter() - .filter_map(|ts| syn::parse2::(ts).ok()) - .find(is_numeric_type) - { - repr_type = ident; - } - } - - let variants = enum_ + let has_only_unit_variants = enum_ .variants - .iter_mut() - .map(extract_variant_data) - .collect::>()?; - Ok(Self { - ident, - repr_type, - variants, + .iter() + .all(|variant| matches!(variant.fields, syn::Fields::Unit)); + + Ok(if has_only_unit_variants { + let simple_enum = PyClassSimpleEnum::new(enum_)?; + Self::Simple(simple_enum) + } else { + let complex_enum = PyClassComplexEnum::new(enum_)?; + Self::Complex(complex_enum) }) } } @@ -460,6 +414,208 @@ pub fn build_py_enum( impl_enum(enum_, &args, doc, method_type) } +struct PyClassSimpleEnum<'a> { + ident: &'a syn::Ident, + // The underlying #[repr] of the enum, used to implement __int__ and __richcmp__. + // This matters when the underlying representation may not fit in `isize`. + repr_type: syn::Ident, + variants: Vec>, +} + +impl<'a> PyClassSimpleEnum<'a> { + fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result { + fn is_numeric_type(t: &syn::Ident) -> bool { + [ + "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize", + "isize", + ] + .iter() + .any(|&s| t == s) + } + + fn extract_unit_variant_data( + variant: &mut syn::Variant, + ) -> syn::Result> { + use syn::Fields; + let ident = match &variant.fields { + Fields::Unit => &variant.ident, + _ => bail_spanned!(variant.span() => "Must be a unit variant."), + }; + let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?; + Ok(PyClassEnumUnitVariant { ident, options }) + } + + let ident = &enum_.ident; + + // According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html), + // "Under the default representation, the specified discriminant is interpreted as an isize + // value", so `isize` should be enough by default. + let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site()); + if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path().is_ident("repr")) { + let args = + attr.parse_args_with(Punctuated::::parse_terminated)?; + if let Some(ident) = args + .into_iter() + .filter_map(|ts| syn::parse2::(ts).ok()) + .find(is_numeric_type) + { + repr_type = ident; + } + } + + let variants: Vec<_> = enum_ + .variants + .iter_mut() + .map(extract_unit_variant_data) + .collect::>()?; + Ok(Self { + ident, + repr_type, + variants, + }) + } +} + +struct PyClassComplexEnum<'a> { + ident: &'a syn::Ident, + variants: Vec>, +} + +impl<'a> PyClassComplexEnum<'a> { + fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result { + let witness = enum_ + .variants + .iter() + .find(|variant| !matches!(variant.fields, syn::Fields::Unit)) + .expect("complex enum has a non-unit variant") + .ident + .to_owned(); + + let extract_variant_data = + |variant: &'a mut syn::Variant| -> syn::Result> { + use syn::Fields; + let ident = &variant.ident; + let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?; + + let variant = match &variant.fields { + Fields::Unit => { + bail_spanned!(variant.span() => format!( + "Unit variant `{ident}` is not yet supported in a complex enum\n\ + = help: change to a struct variant with no fields: `{ident} {{ }}`\n\ + = note: the enum is complex because of non-unit variant `{witness}`", + ident=ident, witness=witness)) + } + Fields::Named(fields) => { + let fields = fields + .named + .iter() + .map(|field| PyClassEnumVariantNamedField { + ident: field.ident.as_ref().expect("named field has an identifier"), + ty: &field.ty, + span: field.span(), + }) + .collect(); + + PyClassEnumVariant::Struct(PyClassEnumStructVariant { + ident, + fields, + options, + }) + } + Fields::Unnamed(_) => { + bail_spanned!(variant.span() => format!( + "Tuple variant `{ident}` is not yet supported in a complex enum\n\ + = help: change to a struct variant with named fields: `{ident} {{ /* fields */ }}`\n\ + = note: the enum is complex because of non-unit variant `{witness}`", + ident=ident, witness=witness)) + } + }; + + Ok(variant) + }; + + let ident = &enum_.ident; + + let variants: Vec<_> = enum_ + .variants + .iter_mut() + .map(extract_variant_data) + .collect::>()?; + + Ok(Self { ident, variants }) + } +} + +enum PyClassEnumVariant<'a> { + // TODO(mkovaxx): Unit(PyClassEnumUnitVariant<'a>), + Struct(PyClassEnumStructVariant<'a>), + // TODO(mkovaxx): Tuple(PyClassEnumTupleVariant<'a>), +} + +trait EnumVariant { + fn get_ident(&self) -> &syn::Ident; + fn get_options(&self) -> &EnumVariantPyO3Options; + + fn get_python_name(&self, args: &PyClassArgs) -> Cow<'_, syn::Ident> { + self.get_options() + .name + .as_ref() + .map(|name_attr| Cow::Borrowed(&name_attr.value.0)) + .unwrap_or_else(|| { + let name = self.get_ident().unraw(); + if let Some(attr) = &args.options.rename_all { + let new_name = apply_renaming_rule(attr.value.rule, &name.to_string()); + Cow::Owned(Ident::new(&new_name, Span::call_site())) + } else { + Cow::Owned(name) + } + }) + } +} + +impl<'a> EnumVariant for PyClassEnumVariant<'a> { + fn get_ident(&self) -> &syn::Ident { + match self { + PyClassEnumVariant::Struct(struct_variant) => struct_variant.ident, + } + } + + fn get_options(&self) -> &EnumVariantPyO3Options { + match self { + PyClassEnumVariant::Struct(struct_variant) => &struct_variant.options, + } + } +} + +/// A unit variant has no fields +struct PyClassEnumUnitVariant<'a> { + ident: &'a syn::Ident, + options: EnumVariantPyO3Options, +} + +impl<'a> EnumVariant for PyClassEnumUnitVariant<'a> { + fn get_ident(&self) -> &syn::Ident { + self.ident + } + + fn get_options(&self) -> &EnumVariantPyO3Options { + &self.options + } +} + +/// A struct variant has named fields +struct PyClassEnumStructVariant<'a> { + ident: &'a syn::Ident, + fields: Vec>, + options: EnumVariantPyO3Options, +} + +struct PyClassEnumVariantNamedField<'a> { + ident: &'a syn::Ident, + ty: &'a syn::Type, + span: Span, +} + /// `#[pyo3()]` options for pyclass enum variants struct EnumVariantPyO3Options { name: Option, @@ -505,11 +661,25 @@ fn impl_enum( args: &PyClassArgs, doc: PythonDoc, methods_type: PyClassMethodsType, +) -> Result { + match enum_ { + PyClassEnum::Simple(simple_enum) => impl_simple_enum(simple_enum, args, doc, methods_type), + PyClassEnum::Complex(complex_enum) => { + impl_complex_enum(complex_enum, args, doc, methods_type) + } + } +} + +fn impl_simple_enum( + simple_enum: PyClassSimpleEnum<'_>, + args: &PyClassArgs, + doc: PythonDoc, + methods_type: PyClassMethodsType, ) -> Result { let krate = get_pyo3_crate(&args.options.krate); - let cls = enum_.ident; + let cls = simple_enum.ident; let ty: syn::Type = syn::parse_quote!(#cls); - let variants = enum_.variants; + let variants = simple_enum.variants; let pytypeinfo = impl_pytypeinfo(cls, args, None); let (default_repr, default_repr_slot) = { @@ -519,7 +689,7 @@ fn impl_enum( let repr = format!( "{}.{}", get_class_python_name(cls, args), - variant.python_name(args), + variant.get_python_name(args), ); quote! { #cls::#variant_name => #repr, } }); @@ -534,7 +704,7 @@ fn impl_enum( (repr_impl, repr_slot) }; - let repr_type = &enum_.repr_type; + let repr_type = &simple_enum.repr_type; let (default_int, default_int_slot) = { // This implementation allows us to convert &T to #repr_type without implementing `Copy` @@ -601,7 +771,10 @@ fn impl_enum( cls, args, methods_type, - enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name(args)))), + simple_enum_default_methods( + cls, + variants.iter().map(|v| (v.ident, v.get_python_name(args))), + ), default_slots, ) .doc(doc) @@ -626,6 +799,214 @@ fn impl_enum( }) } +fn impl_complex_enum( + complex_enum: PyClassComplexEnum<'_>, + args: &PyClassArgs, + doc: PythonDoc, + methods_type: PyClassMethodsType, +) -> Result { + // Need to rig the enum PyClass options + let args = { + let mut rigged_args = args.clone(); + // Needs to be frozen to disallow `&mut self` methods, which could break a runtime invariant + rigged_args.options.frozen = parse_quote!(frozen); + // Needs to be subclassable by the variant PyClasses + rigged_args.options.subclass = parse_quote!(subclass); + rigged_args + }; + + let krate = get_pyo3_crate(&args.options.krate); + let cls = complex_enum.ident; + let variants = complex_enum.variants; + let pytypeinfo = impl_pytypeinfo(cls, &args, None); + + let default_slots = vec![]; + + let impl_builder = PyClassImplsBuilder::new( + cls, + &args, + methods_type, + complex_enum_default_methods( + cls, + variants + .iter() + .map(|v| (v.get_ident(), v.get_python_name(&args))), + ), + default_slots, + ) + .doc(doc); + + // Need to customize the into_py impl so that it returns the variant PyClass + let enum_into_py_impl = { + let match_arms: Vec = variants + .iter() + .map(|variant| { + let variant_ident = variant.get_ident(); + let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident()); + quote! { + #cls::#variant_ident { .. } => { + let pyclass_init = _pyo3::PyClassInitializer::from(self).add_subclass(#variant_cls); + let variant_value = _pyo3::Py::new(py, pyclass_init).unwrap(); + _pyo3::IntoPy::into_py(variant_value, py) + } + } + }) + .collect(); + + quote! { + impl _pyo3::IntoPy<_pyo3::PyObject> for #cls { + fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject { + match self { + #(#match_arms)* + } + } + } + } + }; + + let pyclass_impls: TokenStream = vec![ + impl_builder.impl_pyclass(), + impl_builder.impl_extractext(), + enum_into_py_impl, + impl_builder.impl_pyclassimpl()?, + impl_builder.impl_freelist(), + ] + .into_iter() + .collect(); + + let mut variant_cls_zsts = vec![]; + let mut variant_cls_pytypeinfos = vec![]; + let mut variant_cls_pyclass_impls = vec![]; + let mut variant_cls_impls = vec![]; + for variant in &variants { + let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident()); + + let variant_cls_zst = quote! { + #[doc(hidden)] + #[allow(non_camel_case_types)] + struct #variant_cls; + }; + variant_cls_zsts.push(variant_cls_zst); + + let variant_args = PyClassArgs { + class_kind: PyClassKind::Struct, + // TODO(mkovaxx): propagate variant.options + options: parse_quote!(extends = #cls, frozen), + }; + + let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None); + variant_cls_pytypeinfos.push(variant_cls_pytypeinfo); + + let variant_new = complex_enum_variant_new(cls, variant)?; + + let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant)?; + variant_cls_impls.push(variant_cls_impl); + + let pyclass_impl = PyClassImplsBuilder::new( + &variant_cls, + &variant_args, + methods_type, + field_getters, + vec![variant_new], + ) + .impl_all()?; + + variant_cls_pyclass_impls.push(pyclass_impl); + } + + Ok(quote! { + const _: () = { + use #krate as _pyo3; + + #pytypeinfo + + #pyclass_impls + + #[doc(hidden)] + #[allow(non_snake_case)] + impl #cls {} + + #(#variant_cls_zsts)* + + #(#variant_cls_pytypeinfos)* + + #(#variant_cls_pyclass_impls)* + + #(#variant_cls_impls)* + }; + }) +} + +fn impl_complex_enum_variant_cls( + enum_name: &syn::Ident, + variant: &PyClassEnumVariant<'_>, +) -> Result<(TokenStream, Vec)> { + match variant { + PyClassEnumVariant::Struct(struct_variant) => { + impl_complex_enum_struct_variant_cls(enum_name, struct_variant) + } + } +} + +fn impl_complex_enum_struct_variant_cls( + enum_name: &syn::Ident, + variant: &PyClassEnumStructVariant<'_>, +) -> Result<(TokenStream, Vec)> { + let variant_ident = &variant.ident; + let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident); + let variant_cls_type = parse_quote!(#variant_cls); + + let mut field_names: Vec = vec![]; + let mut fields_with_types: Vec = vec![]; + let mut field_getters = vec![]; + let mut field_getter_impls: Vec = vec![]; + for field in &variant.fields { + let field_name = field.ident; + let field_type = field.ty; + let field_with_type = quote! { #field_name: #field_type }; + + let field_getter = complex_enum_variant_field_getter( + &variant_cls_type, + field_name, + field_type, + field.span, + )?; + + let field_getter_impl = quote! { + fn #field_name(slf: _pyo3::PyRef) -> _pyo3::PyResult<#field_type> { + match &*slf.into_super() { + #enum_name::#variant_ident { #field_name, .. } => Ok(#field_name.clone()), + _ => unreachable!("Wrong complex enum variant found in variant wrapper PyClass"), + } + } + }; + + field_names.push(field_name.clone()); + fields_with_types.push(field_with_type); + field_getters.push(field_getter); + field_getter_impls.push(field_getter_impl); + } + + let cls_impl = quote! { + #[doc(hidden)] + #[allow(non_snake_case)] + impl #variant_cls { + fn __pymethod_constructor__(py: _pyo3::Python<'_>, #(#fields_with_types,)*) -> _pyo3::PyClassInitializer<#variant_cls> { + let base_value = #enum_name::#variant_ident { #(#field_names,)* }; + _pyo3::PyClassInitializer::from(base_value).add_subclass(#variant_cls) + } + + #(#field_getter_impls)* + } + }; + + Ok((cls_impl, field_getters)) +} + +fn gen_complex_enum_variant_class_ident(enum_: &syn::Ident, variant: &syn::Ident) -> syn::Ident { + format_ident!("{}_{}", enum_, variant) +} + fn generate_default_protocol_slot( cls: &syn::Type, method: &mut syn::ImplItemFn, @@ -645,7 +1026,7 @@ fn generate_default_protocol_slot( ) } -fn enum_default_methods<'a>( +fn simple_enum_default_methods<'a>( cls: &'a syn::Ident, unit_variant_names: impl IntoIterator)>, ) -> Vec { @@ -667,14 +1048,167 @@ fn enum_default_methods<'a>( .collect() } -fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result> { - use syn::Fields; - let ident = match variant.fields { - Fields::Unit => &variant.ident, - _ => bail_spanned!(variant.span() => "Currently only support unit variants."), +fn complex_enum_default_methods<'a>( + cls: &'a syn::Ident, + variant_names: impl IntoIterator)>, +) -> Vec { + let cls_type = syn::parse_quote!(#cls); + let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec { + rust_ident: var_ident.clone(), + attributes: ConstAttributes { + is_class_attr: true, + name: Some(NameAttribute { + kw: syn::parse_quote! { name }, + value: NameLitStr(py_ident.clone()), + }), + deprecations: Default::default(), + }, }; - let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?; - Ok(PyClassEnumVariant { ident, options }) + variant_names + .into_iter() + .map(|(var, py_name)| { + gen_complex_enum_variant_attr(cls, &cls_type, &variant_to_attribute(var, &py_name)) + }) + .collect() +} + +pub fn gen_complex_enum_variant_attr( + cls: &syn::Ident, + cls_type: &syn::Type, + spec: &ConstSpec, +) -> MethodAndMethodDef { + let member = &spec.rust_ident; + let wrapper_ident = format_ident!("__pymethod_variant_cls_{}__", member); + let deprecations = &spec.attributes.deprecations; + let python_name = &spec.null_terminated_python_name(); + + let variant_cls = format_ident!("{}_{}", cls, member); + let associated_method = quote! { + fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> { + #deprecations + ::std::result::Result::Ok(py.get_type::<#variant_cls>().into()) + } + }; + + let method_def = quote! { + _pyo3::class::PyMethodDefType::ClassAttribute({ + _pyo3::class::PyClassAttributeDef::new( + #python_name, + _pyo3::impl_::pymethods::PyClassAttributeFactory(#cls_type::#wrapper_ident) + ) + }) + }; + + MethodAndMethodDef { + associated_method, + method_def, + } +} + +fn complex_enum_variant_new<'a>( + cls: &'a syn::Ident, + variant: &'a PyClassEnumVariant<'a>, +) -> Result { + match variant { + PyClassEnumVariant::Struct(struct_variant) => { + complex_enum_struct_variant_new(cls, struct_variant) + } + } +} + +fn complex_enum_struct_variant_new<'a>( + cls: &'a syn::Ident, + variant: &'a PyClassEnumStructVariant<'a>, +) -> Result { + let variant_cls = format_ident!("{}_{}", cls, variant.ident); + let variant_cls_type: syn::Type = parse_quote!(#variant_cls); + + let arg_py_ident: syn::Ident = parse_quote!(py); + let arg_py_type: syn::Type = parse_quote!(_pyo3::Python<'_>); + + let args = { + let mut no_pyo3_attrs = vec![]; + let attrs = crate::pyfunction::PyFunctionArgPyO3Attributes::from_attrs(&mut no_pyo3_attrs)?; + + let mut args = vec![ + // py: Python<'_> + FnArg { + name: &arg_py_ident, + ty: &arg_py_type, + optional: None, + default: None, + py: true, + attrs: attrs.clone(), + is_varargs: false, + is_kwargs: false, + is_cancel_handle: false, + }, + ]; + + for field in &variant.fields { + args.push(FnArg { + name: field.ident, + ty: field.ty, + optional: None, + default: None, + py: false, + attrs: attrs.clone(), + is_varargs: false, + is_kwargs: false, + is_cancel_handle: false, + }); + } + args + }; + let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?; + + let spec = FnSpec { + tp: crate::method::FnType::FnNew, + name: &format_ident!("__pymethod_constructor__"), + python_name: format_ident!("__new__"), + signature, + output: variant_cls_type.clone(), + convention: crate::method::CallingConvention::TpNew, + text_signature: None, + asyncness: None, + unsafety: None, + deprecations: Deprecations::default(), + }; + + crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec) +} + +fn complex_enum_variant_field_getter<'a>( + variant_cls_type: &'a syn::Type, + field_name: &'a syn::Ident, + field_type: &'a syn::Type, + field_span: Span, +) -> Result { + let signature = crate::pyfunction::FunctionSignature::from_arguments(vec![])?; + + let self_type = crate::method::SelfType::TryFromPyCell(field_span); + + let spec = FnSpec { + tp: crate::method::FnType::Getter(self_type.clone()), + name: field_name, + python_name: field_name.clone(), + signature, + output: field_type.clone(), + convention: crate::method::CallingConvention::Noargs, + text_signature: None, + asyncness: None, + unsafety: None, + deprecations: Deprecations::default(), + }; + + let property_type = crate::pymethod::PropertyType::Function { + self_type: &self_type, + spec: &spec, + doc: crate::get_doc(&[], None), + }; + + let getter = crate::pymethod::impl_py_getter_def(variant_cls_type, property_type)?; + Ok(getter) } fn descriptors_to_items( diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 8a97c712..d45d2e12 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -324,7 +324,8 @@ pub fn impl_py_method_def( }) } -fn impl_py_method_def_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result { +/// Also used by pyclass. +pub fn impl_py_method_def_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result { let wrapper_ident = syn::Ident::new("__pymethod___new____", Span::call_site()); let associated_method = spec.get_wrapper_function(&wrapper_ident, Some(cls))?; // Use just the text_signature_call_signature() because the class' Python name diff --git a/pytests/noxfile.py b/pytests/noxfile.py index 57d9d63a..7c681ab1 100644 --- a/pytests/noxfile.py +++ b/pytests/noxfile.py @@ -1,4 +1,5 @@ import nox +import sys from nox.command import CommandFailed nox.options.sessions = ["test"] @@ -13,7 +14,12 @@ def test(session: nox.Session): except CommandFailed: # No binary wheel for numpy available on this platform pass - session.run("pytest", *session.posargs) + ignored_paths = [] + if sys.version_info < (3, 10): + # Match syntax is only available in Python >= 3.10 + ignored_paths.append("tests/test_enums_match.py") + ignore_args = [f"--ignore={path}" for path in ignored_paths] + session.run("pytest", *ignore_args, *session.posargs) @nox.session diff --git a/pytests/src/enums.rs b/pytests/src/enums.rs new file mode 100644 index 00000000..11b592d3 --- /dev/null +++ b/pytests/src/enums.rs @@ -0,0 +1,58 @@ +use pyo3::{pyclass, pyfunction, pymodule, types::PyModule, wrap_pyfunction, PyResult, Python}; + +#[pymodule] +pub fn enums(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_wrapped(wrap_pyfunction!(do_simple_stuff))?; + m.add_wrapped(wrap_pyfunction!(do_complex_stuff))?; + Ok(()) +} + +#[pyclass] +pub enum SimpleEnum { + Sunday, + Monday, + Tuesday, + Wednesday, + Thursday, + Friday, + Saturday, +} + +#[pyfunction] +pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum { + match thing { + SimpleEnum::Sunday => SimpleEnum::Monday, + SimpleEnum::Monday => SimpleEnum::Tuesday, + SimpleEnum::Tuesday => SimpleEnum::Wednesday, + SimpleEnum::Wednesday => SimpleEnum::Thursday, + SimpleEnum::Thursday => SimpleEnum::Friday, + SimpleEnum::Friday => SimpleEnum::Saturday, + SimpleEnum::Saturday => SimpleEnum::Sunday, + } +} + +#[pyclass] +pub enum ComplexEnum { + Int { i: i32 }, + Float { f: f64 }, + Str { s: String }, + EmptyStruct {}, + MultiFieldStruct { a: i32, b: f64, c: bool }, +} + +#[pyfunction] +pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum { + match thing { + ComplexEnum::Int { i } => ComplexEnum::Str { s: i.to_string() }, + ComplexEnum::Float { f } => ComplexEnum::Float { f: f * f }, + ComplexEnum::Str { s } => ComplexEnum::Int { i: s.len() as i32 }, + ComplexEnum::EmptyStruct {} => ComplexEnum::EmptyStruct {}, + ComplexEnum::MultiFieldStruct { a, b, c } => ComplexEnum::MultiFieldStruct { + a: *a, + b: *b, + c: *c, + }, + } +} diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index dbcd3ca4..e65385bf 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -7,6 +7,7 @@ pub mod buf_and_str; pub mod comparisons; pub mod datetime; pub mod dict_iter; +pub mod enums; pub mod misc; pub mod objstore; pub mod othermod; @@ -25,6 +26,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> { #[cfg(not(Py_LIMITED_API))] m.add_wrapped(wrap_pymodule!(datetime::datetime))?; m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?; + m.add_wrapped(wrap_pymodule!(enums::enums))?; m.add_wrapped(wrap_pymodule!(misc::misc))?; m.add_wrapped(wrap_pymodule!(objstore::objstore))?; m.add_wrapped(wrap_pymodule!(othermod::othermod))?; @@ -44,6 +46,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> { sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?; sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?; sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?; + sys_modules.set_item("pyo3_pytests.enums", m.getattr("enums")?)?; sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?; sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?; sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?; diff --git a/pytests/tests/test_enums.py b/pytests/tests/test_enums.py new file mode 100644 index 00000000..04b0cdca --- /dev/null +++ b/pytests/tests/test_enums.py @@ -0,0 +1,116 @@ +import pytest +from pyo3_pytests import enums + + +def test_complex_enum_variant_constructors(): + int_variant = enums.ComplexEnum.Int(42) + assert isinstance(int_variant, enums.ComplexEnum.Int) + + float_variant = enums.ComplexEnum.Float(3.14) + assert isinstance(float_variant, enums.ComplexEnum.Float) + + str_variant = enums.ComplexEnum.Str("hello") + assert isinstance(str_variant, enums.ComplexEnum.Str) + + empty_struct_variant = enums.ComplexEnum.EmptyStruct() + assert isinstance(empty_struct_variant, enums.ComplexEnum.EmptyStruct) + + multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True) + assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct) + + +@pytest.mark.parametrize( + "variant", + [ + enums.ComplexEnum.Int(42), + enums.ComplexEnum.Float(3.14), + enums.ComplexEnum.Str("hello"), + enums.ComplexEnum.EmptyStruct(), + enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + ], +) +def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum): + assert isinstance(variant, enums.ComplexEnum) + + +def test_complex_enum_field_getters(): + int_variant = enums.ComplexEnum.Int(42) + assert int_variant.i == 42 + + float_variant = enums.ComplexEnum.Float(3.14) + assert float_variant.f == 3.14 + + str_variant = enums.ComplexEnum.Str("hello") + assert str_variant.s == "hello" + + multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True) + assert multi_field_struct_variant.a == 42 + assert multi_field_struct_variant.b == 3.14 + assert multi_field_struct_variant.c is True + + +@pytest.mark.parametrize( + "variant", + [ + enums.ComplexEnum.Int(42), + enums.ComplexEnum.Float(3.14), + enums.ComplexEnum.Str("hello"), + enums.ComplexEnum.EmptyStruct(), + enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + ], +) +def test_complex_enum_desugared_match(variant: enums.ComplexEnum): + if isinstance(variant, enums.ComplexEnum.Int): + x = variant.i + assert x == 42 + elif isinstance(variant, enums.ComplexEnum.Float): + x = variant.f + assert x == 3.14 + elif isinstance(variant, enums.ComplexEnum.Str): + x = variant.s + assert x == "hello" + elif isinstance(variant, enums.ComplexEnum.EmptyStruct): + assert True + elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct): + x = variant.a + y = variant.b + z = variant.c + assert x == 42 + assert y == 3.14 + assert z is True + else: + assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.ComplexEnum.Int(42), + enums.ComplexEnum.Float(3.14), + enums.ComplexEnum.Str("hello"), + enums.ComplexEnum.EmptyStruct(), + enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + ], +) +def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum): + variant = enums.do_complex_stuff(variant) + if isinstance(variant, enums.ComplexEnum.Int): + x = variant.i + assert x == 5 + elif isinstance(variant, enums.ComplexEnum.Float): + x = variant.f + assert x == 9.8596 + elif isinstance(variant, enums.ComplexEnum.Str): + x = variant.s + assert x == "42" + elif isinstance(variant, enums.ComplexEnum.EmptyStruct): + assert True + elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct): + x = variant.a + y = variant.b + z = variant.c + assert x == 42 + assert y == 3.14 + assert z is True + else: + assert False diff --git a/pytests/tests/test_enums_match.py b/pytests/tests/test_enums_match.py new file mode 100644 index 00000000..4d55bbbe --- /dev/null +++ b/pytests/tests/test_enums_match.py @@ -0,0 +1,59 @@ +# This file is only collected when Python >= 3.10, because it tests match syntax. +import pytest +from pyo3_pytests import enums + + +@pytest.mark.parametrize( + "variant", + [ + enums.ComplexEnum.Int(42), + enums.ComplexEnum.Float(3.14), + enums.ComplexEnum.Str("hello"), + enums.ComplexEnum.EmptyStruct(), + enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + ], +) +def test_complex_enum_match_statement(variant: enums.ComplexEnum): + match variant: + case enums.ComplexEnum.Int(i=x): + assert x == 42 + case enums.ComplexEnum.Float(f=x): + assert x == 3.14 + case enums.ComplexEnum.Str(s=x): + assert x == "hello" + case enums.ComplexEnum.EmptyStruct(): + assert True + case enums.ComplexEnum.MultiFieldStruct(a=x, b=y, c=z): + assert x == 42 + assert y == 3.14 + assert z is True + case _: + assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.ComplexEnum.Int(42), + enums.ComplexEnum.Float(3.14), + enums.ComplexEnum.Str("hello"), + enums.ComplexEnum.EmptyStruct(), + enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + ], +) +def test_complex_enum_pyfunction_in_out(variant: enums.ComplexEnum): + match enums.do_complex_stuff(variant): + case enums.ComplexEnum.Int(i=x): + assert x == 5 + case enums.ComplexEnum.Float(f=x): + assert x == 9.8596 + case enums.ComplexEnum.Str(s=x): + assert x == "42" + case enums.ComplexEnum.EmptyStruct(): + assert True + case enums.ComplexEnum.MultiFieldStruct(a=x, b=y, c=z): + assert x == 42 + assert y == 3.14 + assert z is True + case _: + assert False diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 0154f3f1..adcef887 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -14,6 +14,7 @@ fn test_compile_errors() { #[cfg(any(not(Py_LIMITED_API), Py_3_11))] t.compile_fail("tests/ui/invalid_pymethods_buffer.rs"); t.compile_fail("tests/ui/invalid_pymethods_duplicates.rs"); + t.compile_fail("tests/ui/invalid_pymethod_enum.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); t.compile_fail("tests/ui/invalid_pymodule_args.rs"); t.compile_fail("tests/ui/reject_generics.rs"); diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs index 4bc53238..95879c2f 100644 --- a/tests/ui/invalid_pyclass_enum.rs +++ b/tests/ui/invalid_pyclass_enum.rs @@ -15,4 +15,16 @@ enum NotDrivedClass { #[pyclass] enum NoEmptyEnum {} +#[pyclass] +enum NoUnitVariants { + StructVariant { field: i32 }, + UnitVariant, +} + +#[pyclass] +enum NoTupleVariants { + StructVariant { field: i32 }, + TupleVariant(i32), +} + fn main() {} diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr index 8f340a76..a03a0ae2 100644 --- a/tests/ui/invalid_pyclass_enum.stderr +++ b/tests/ui/invalid_pyclass_enum.stderr @@ -15,3 +15,19 @@ error: #[pyclass] can't be used on enums without any variants | 16 | enum NoEmptyEnum {} | ^^ + +error: Unit variant `UnitVariant` is not yet supported in a complex enum + = help: change to a struct variant with no fields: `UnitVariant { }` + = note: the enum is complex because of non-unit variant `StructVariant` + --> tests/ui/invalid_pyclass_enum.rs:21:5 + | +21 | UnitVariant, + | ^^^^^^^^^^^ + +error: Tuple variant `TupleVariant` is not yet supported in a complex enum + = help: change to a struct variant with named fields: `TupleVariant { /* fields */ }` + = note: the enum is complex because of non-unit variant `StructVariant` + --> tests/ui/invalid_pyclass_enum.rs:27:5 + | +27 | TupleVariant(i32), + | ^^^^^^^^^^^^ diff --git a/tests/ui/invalid_pymethod_enum.rs b/tests/ui/invalid_pymethod_enum.rs new file mode 100644 index 00000000..9b596e08 --- /dev/null +++ b/tests/ui/invalid_pymethod_enum.rs @@ -0,0 +1,19 @@ +use pyo3::prelude::*; + +#[pyclass] +enum ComplexEnum { + Int { int: i32 }, + Str { string: String }, +} + +#[pymethods] +impl ComplexEnum { + fn mutate_in_place(&mut self) { + *self = match self { + ComplexEnum::Int { int } => ComplexEnum::Str { string: int.to_string() }, + ComplexEnum::Str { string } => ComplexEnum::Int { int: string.len() as i32 }, + } + } +} + +fn main() {} diff --git a/tests/ui/invalid_pymethod_enum.stderr b/tests/ui/invalid_pymethod_enum.stderr new file mode 100644 index 00000000..bb327dcc --- /dev/null +++ b/tests/ui/invalid_pymethod_enum.stderr @@ -0,0 +1,11 @@ +error[E0271]: type mismatch resolving `::Frozen == False` + --> tests/ui/invalid_pymethod_enum.rs:11:24 + | +11 | fn mutate_in_place(&mut self) { + | ^ expected `False`, found `True` + | +note: required by a bound in `extract_pyclass_ref_mut` + --> src/impl_/extract_argument.rs + | + | pub fn extract_pyclass_ref_mut<'a, 'py: 'a, T: PyClass>( + | ^^^^^^^^^^^^^^ required by this bound in `extract_pyclass_ref_mut`