From 7beb64a8cac4873b151d87c8099c56c69c8f602e Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Thu, 9 May 2024 23:08:23 +0200 Subject: [PATCH] allow constructor customization of complex enum variants (#4158) * allow `#[pyo3(signature = ...)]` on complex enum variants to specify constructor signature * rename keyword to `constructor` * review feedback * add docs in guide * add newsfragment --- guide/pyclass-parameters.md | 2 + guide/src/class.md | 40 +++++++++++ newsfragments/4158.added.md | 1 + pyo3-macros-backend/src/attributes.rs | 1 + pyo3-macros-backend/src/pyclass.rs | 68 +++++++++++++------ pyo3-macros-backend/src/pyfunction.rs | 2 +- .../src/pyfunction/signature.rs | 10 +++ pytests/src/enums.rs | 27 ++++++-- pytests/tests/test_enums.py | 23 +++++++ tests/ui/invalid_pyclass_enum.rs | 7 ++ tests/ui/invalid_pyclass_enum.stderr | 6 ++ 11 files changed, 163 insertions(+), 24 deletions(-) create mode 100644 newsfragments/4158.added.md diff --git a/guide/pyclass-parameters.md b/guide/pyclass-parameters.md index 6951a5b5..9bd0534e 100644 --- a/guide/pyclass-parameters.md +++ b/guide/pyclass-parameters.md @@ -2,6 +2,7 @@ | Parameter | Description | | :- | :- | +| `constructor` | This is currently only allowed on [variants of complex enums][params-constructor]. It allows customization of the generated class constructor for each variant. It uses the same syntax and supports the same options as the `signature` attribute of functions and methods. | | `crate = "some::path"` | Path to import the `pyo3` crate, if it's not accessible at `::pyo3`. | | `dict` | Gives instances of this class an empty `__dict__` to store custom attributes. | | `extends = BaseType` | Use a custom baseclass. Defaults to [`PyAny`][params-1] | @@ -39,5 +40,6 @@ struct MyClass {} [params-4]: https://doc.rust-lang.org/std/rc/struct.Rc.html [params-5]: https://doc.rust-lang.org/std/sync/struct.Arc.html [params-6]: https://docs.python.org/3/library/weakref.html +[params-constructor]: https://pyo3.rs/latest/class.html#complex-enums [params-mapping]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types [params-sequence]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types diff --git a/guide/src/class.md b/guide/src/class.md index b5ef95cb..3fcfaca4 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -1243,6 +1243,46 @@ Python::with_gil(|py| { }) ``` +The constructor of each generated class can be customized using the `#[pyo3(constructor = (...))]` attribute. This uses the same syntax as the [`#[pyo3(signature = (...))]`](function/signature.md) +attribute on function and methods and supports the same options. To apply this attribute simply place it on top of a variant in a `#[pyclass]` complex enum as shown below: + +```rust +# use pyo3::prelude::*; +#[pyclass] +enum Shape { + #[pyo3(constructor = (radius=1.0))] + Circle { radius: f64 }, + #[pyo3(constructor = (*, width, height))] + Rectangle { width: f64, height: f64 }, + #[pyo3(constructor = (side_count, radius=1.0))] + RegularPolygon { side_count: u32, radius: f64 }, + Nothing { }, +} + +# #[cfg(Py_3_10)] +Python::with_gil(|py| { + let cls = py.get_type_bound::(); + pyo3::py_run!(py, cls, r#" + circle = cls.Circle() + assert isinstance(circle, cls) + assert isinstance(circle, cls.Circle) + assert circle.radius == 1.0 + + square = cls.Rectangle(width = 1, height = 1) + assert isinstance(square, cls) + assert isinstance(square, cls.Rectangle) + assert square.width == 1 + assert square.height == 1 + + hexagon = cls.RegularPolygon(6) + assert isinstance(hexagon, cls) + assert isinstance(hexagon, cls.RegularPolygon) + assert hexagon.side_count == 6 + assert hexagon.radius == 1 + "#) +}) +``` + ## Implementation details The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block. diff --git a/newsfragments/4158.added.md b/newsfragments/4158.added.md new file mode 100644 index 00000000..42e6d3ff --- /dev/null +++ b/newsfragments/4158.added.md @@ -0,0 +1 @@ +Added `#[pyo3(constructor = (...))]` to customize the generated constructors for complex enum variants diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index e91b3b8d..d9c805aa 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -12,6 +12,7 @@ pub mod kw { syn::custom_keyword!(annotation); syn::custom_keyword!(attribute); syn::custom_keyword!(cancel_handle); + syn::custom_keyword!(constructor); syn::custom_keyword!(dict); syn::custom_keyword!(extends); syn::custom_keyword!(freelist); diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index f8bfa164..3023f897 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -8,6 +8,7 @@ use crate::attributes::{ use crate::deprecations::Deprecations; use crate::konst::{ConstAttributes, ConstSpec}; use crate::method::{FnArg, FnSpec, PyArg, RegularArg}; +use crate::pyfunction::ConstructorAttribute; use crate::pyimpl::{gen_py_const, PyClassMethodsType}; use crate::pymethod::{ impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType, @@ -620,12 +621,15 @@ struct PyClassEnumVariantNamedField<'a> { } /// `#[pyo3()]` options for pyclass enum variants +#[derive(Default)] struct EnumVariantPyO3Options { name: Option, + constructor: Option, } enum EnumVariantPyO3Option { Name(NameAttribute), + Constructor(ConstructorAttribute), } impl Parse for EnumVariantPyO3Option { @@ -633,6 +637,8 @@ impl Parse for EnumVariantPyO3Option { let lookahead = input.lookahead1(); if lookahead.peek(attributes::kw::name) { input.parse().map(EnumVariantPyO3Option::Name) + } else if lookahead.peek(attributes::kw::constructor) { + input.parse().map(EnumVariantPyO3Option::Constructor) } else { Err(lookahead.error()) } @@ -641,22 +647,34 @@ impl Parse for EnumVariantPyO3Option { impl EnumVariantPyO3Options { fn take_pyo3_options(attrs: &mut Vec) -> Result { - let mut options = EnumVariantPyO3Options { name: None }; + let mut options = EnumVariantPyO3Options::default(); - for option in take_pyo3_options(attrs)? { - match option { - EnumVariantPyO3Option::Name(name) => { - ensure_spanned!( - options.name.is_none(), - name.span() => "`name` may only be specified once" - ); - options.name = Some(name); - } - } - } + take_pyo3_options(attrs)? + .into_iter() + .try_for_each(|option| options.set_option(option))?; Ok(options) } + + fn set_option(&mut self, option: EnumVariantPyO3Option) -> syn::Result<()> { + macro_rules! set_option { + ($key:ident) => { + { + ensure_spanned!( + self.$key.is_none(), + $key.span() => concat!("`", stringify!($key), "` may only be specified once") + ); + self.$key = Some($key); + } + }; + } + + match option { + EnumVariantPyO3Option::Constructor(constructor) => set_option!(constructor), + EnumVariantPyO3Option::Name(name) => set_option!(name), + } + Ok(()) + } } fn impl_enum( @@ -689,6 +707,10 @@ fn impl_simple_enum( let variants = simple_enum.variants; let pytypeinfo = impl_pytypeinfo(cls, args, None, ctx); + for variant in &variants { + ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant"); + } + let (default_repr, default_repr_slot) = { let variants_repr = variants.iter().map(|variant| { let variant_name = variant.ident; @@ -889,7 +911,7 @@ fn impl_complex_enum( let mut variant_cls_pytypeinfos = vec![]; let mut variant_cls_pyclass_impls = vec![]; let mut variant_cls_impls = vec![]; - for variant in &variants { + for variant in variants { let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident()); let variant_cls_zst = quote! { @@ -908,11 +930,11 @@ fn impl_complex_enum( let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx); variant_cls_pytypeinfos.push(variant_cls_pytypeinfo); - let variant_new = complex_enum_variant_new(cls, variant, ctx)?; - - let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?; + let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?; variant_cls_impls.push(variant_cls_impl); + let variant_new = complex_enum_variant_new(cls, variant, ctx)?; + let pyclass_impl = PyClassImplsBuilder::new( &variant_cls, &variant_args, @@ -1120,7 +1142,7 @@ pub fn gen_complex_enum_variant_attr( fn complex_enum_variant_new<'a>( cls: &'a syn::Ident, - variant: &'a PyClassEnumVariant<'a>, + variant: PyClassEnumVariant<'a>, ctx: &Ctx, ) -> Result { match variant { @@ -1132,7 +1154,7 @@ fn complex_enum_variant_new<'a>( fn complex_enum_struct_variant_new<'a>( cls: &'a syn::Ident, - variant: &'a PyClassEnumStructVariant<'a>, + variant: PyClassEnumStructVariant<'a>, ctx: &Ctx, ) -> Result { let Ctx { pyo3_path } = ctx; @@ -1162,7 +1184,15 @@ fn complex_enum_struct_variant_new<'a>( } args }; - let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?; + + let signature = if let Some(constructor) = variant.options.constructor { + crate::pyfunction::FunctionSignature::from_arguments_and_attribute( + args, + constructor.into_signature(), + )? + } else { + crate::pyfunction::FunctionSignature::from_arguments(args)? + }; let spec = FnSpec { tp: crate::method::FnType::FnNew, diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 7c355533..e259f0e2 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -18,7 +18,7 @@ use syn::{ mod signature; -pub use self::signature::{FunctionSignature, SignatureAttribute}; +pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute}; #[derive(Clone, Debug)] pub struct PyFunctionArgPyO3Attributes { diff --git a/pyo3-macros-backend/src/pyfunction/signature.rs b/pyo3-macros-backend/src/pyfunction/signature.rs index 3daa79c8..b73b96a3 100644 --- a/pyo3-macros-backend/src/pyfunction/signature.rs +++ b/pyo3-macros-backend/src/pyfunction/signature.rs @@ -195,6 +195,16 @@ impl ToTokens for SignatureItemPosargsSep { } pub type SignatureAttribute = KeywordAttribute; +pub type ConstructorAttribute = KeywordAttribute; + +impl ConstructorAttribute { + pub fn into_signature(self) -> SignatureAttribute { + SignatureAttribute { + kw: kw::signature(self.kw.span), + value: self.value, + } + } +} #[derive(Default)] pub struct PythonSignature { diff --git a/pytests/src/enums.rs b/pytests/src/enums.rs index 0a1bc49b..68a5fc93 100644 --- a/pytests/src/enums.rs +++ b/pytests/src/enums.rs @@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum { #[pyclass] pub enum ComplexEnum { - Int { i: i32 }, - Float { f: f64 }, - Str { s: String }, + Int { + i: i32, + }, + Float { + f: f64, + }, + Str { + s: String, + }, EmptyStruct {}, - MultiFieldStruct { a: i32, b: f64, c: bool }, + MultiFieldStruct { + a: i32, + b: f64, + c: bool, + }, + #[pyo3(constructor = (a = 42, b = None))] + VariantWithDefault { + a: i32, + b: Option, + }, } #[pyfunction] @@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum { b: *b, c: *c, }, + ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault { + a: 2 * a, + b: b.as_ref().map(|s| s.to_uppercase()), + }, } } diff --git a/pytests/tests/test_enums.py b/pytests/tests/test_enums.py index 04b0cdca..cd1d7aed 100644 --- a/pytests/tests/test_enums.py +++ b/pytests/tests/test_enums.py @@ -18,6 +18,12 @@ def test_complex_enum_variant_constructors(): multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True) assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct) + variant_with_default_1 = enums.ComplexEnum.VariantWithDefault() + assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault) + + variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello") + assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault) + @pytest.mark.parametrize( "variant", @@ -27,6 +33,7 @@ def test_complex_enum_variant_constructors(): enums.ComplexEnum.Str("hello"), enums.ComplexEnum.EmptyStruct(), enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + enums.ComplexEnum.VariantWithDefault(), ], ) def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum): @@ -48,6 +55,10 @@ def test_complex_enum_field_getters(): assert multi_field_struct_variant.b == 3.14 assert multi_field_struct_variant.c is True + variant_with_default = enums.ComplexEnum.VariantWithDefault() + assert variant_with_default.a == 42 + assert variant_with_default.b is None + @pytest.mark.parametrize( "variant", @@ -57,6 +68,7 @@ def test_complex_enum_field_getters(): enums.ComplexEnum.Str("hello"), enums.ComplexEnum.EmptyStruct(), enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + enums.ComplexEnum.VariantWithDefault(), ], ) def test_complex_enum_desugared_match(variant: enums.ComplexEnum): @@ -78,6 +90,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum): assert x == 42 assert y == 3.14 assert z is True + elif isinstance(variant, enums.ComplexEnum.VariantWithDefault): + x = variant.a + y = variant.b + assert x == 42 + assert y is None else: assert False @@ -90,6 +107,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum): enums.ComplexEnum.Str("hello"), enums.ComplexEnum.EmptyStruct(), enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + enums.ComplexEnum.VariantWithDefault(b="hello"), ], ) def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum): @@ -112,5 +130,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn assert x == 42 assert y == 3.14 assert z is True + elif isinstance(variant, enums.ComplexEnum.VariantWithDefault): + x = variant.a + y = variant.b + assert x == 84 + assert y == "HELLO" else: assert False diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs index 95879c2f..116b8968 100644 --- a/tests/ui/invalid_pyclass_enum.rs +++ b/tests/ui/invalid_pyclass_enum.rs @@ -27,4 +27,11 @@ enum NoTupleVariants { TupleVariant(i32), } +#[pyclass] +enum SimpleNoSignature { + #[pyo3(constructor = (a, b))] + A, + B, +} + fn main() {} diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr index a03a0ae2..e9ba9806 100644 --- a/tests/ui/invalid_pyclass_enum.stderr +++ b/tests/ui/invalid_pyclass_enum.stderr @@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum | 27 | TupleVariant(i32), | ^^^^^^^^^^^^ + +error: `constructor` can't be used on a simple enum variant + --> tests/ui/invalid_pyclass_enum.rs:32:12 + | +32 | #[pyo3(constructor = (a, b))] + | ^^^^^^^^^^^