From 7beb64a8cac4873b151d87c8099c56c69c8f602e Mon Sep 17 00:00:00 2001
From: Icxolu <10486322+Icxolu@users.noreply.github.com>
Date: Thu, 9 May 2024 23:08:23 +0200
Subject: [PATCH] allow constructor customization of complex enum variants
(#4158)
* allow `#[pyo3(signature = ...)]` on complex enum variants to specify constructor signature
* rename keyword to `constructor`
* review feedback
* add docs in guide
* add newsfragment
---
guide/pyclass-parameters.md | 2 +
guide/src/class.md | 40 +++++++++++
newsfragments/4158.added.md | 1 +
pyo3-macros-backend/src/attributes.rs | 1 +
pyo3-macros-backend/src/pyclass.rs | 68 +++++++++++++------
pyo3-macros-backend/src/pyfunction.rs | 2 +-
.../src/pyfunction/signature.rs | 10 +++
pytests/src/enums.rs | 27 ++++++--
pytests/tests/test_enums.py | 23 +++++++
tests/ui/invalid_pyclass_enum.rs | 7 ++
tests/ui/invalid_pyclass_enum.stderr | 6 ++
11 files changed, 163 insertions(+), 24 deletions(-)
create mode 100644 newsfragments/4158.added.md
diff --git a/guide/pyclass-parameters.md b/guide/pyclass-parameters.md
index 6951a5b5..9bd0534e 100644
--- a/guide/pyclass-parameters.md
+++ b/guide/pyclass-parameters.md
@@ -2,6 +2,7 @@
| Parameter | Description |
| :- | :- |
+| `constructor` | This is currently only allowed on [variants of complex enums][params-constructor]. It allows customization of the generated class constructor for each variant. It uses the same syntax and supports the same options as the `signature` attribute of functions and methods. |
| `crate = "some::path"` | Path to import the `pyo3` crate, if it's not accessible at `::pyo3`. |
| `dict` | Gives instances of this class an empty `__dict__` to store custom attributes. |
| `extends = BaseType` | Use a custom baseclass. Defaults to [`PyAny`][params-1] |
@@ -39,5 +40,6 @@ struct MyClass {}
[params-4]: https://doc.rust-lang.org/std/rc/struct.Rc.html
[params-5]: https://doc.rust-lang.org/std/sync/struct.Arc.html
[params-6]: https://docs.python.org/3/library/weakref.html
+[params-constructor]: https://pyo3.rs/latest/class.html#complex-enums
[params-mapping]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types
[params-sequence]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types
diff --git a/guide/src/class.md b/guide/src/class.md
index b5ef95cb..3fcfaca4 100644
--- a/guide/src/class.md
+++ b/guide/src/class.md
@@ -1243,6 +1243,46 @@ Python::with_gil(|py| {
})
```
+The constructor of each generated class can be customized using the `#[pyo3(constructor = (...))]` attribute. This uses the same syntax as the [`#[pyo3(signature = (...))]`](function/signature.md)
+attribute on function and methods and supports the same options. To apply this attribute simply place it on top of a variant in a `#[pyclass]` complex enum as shown below:
+
+```rust
+# use pyo3::prelude::*;
+#[pyclass]
+enum Shape {
+ #[pyo3(constructor = (radius=1.0))]
+ Circle { radius: f64 },
+ #[pyo3(constructor = (*, width, height))]
+ Rectangle { width: f64, height: f64 },
+ #[pyo3(constructor = (side_count, radius=1.0))]
+ RegularPolygon { side_count: u32, radius: f64 },
+ Nothing { },
+}
+
+# #[cfg(Py_3_10)]
+Python::with_gil(|py| {
+ let cls = py.get_type_bound::();
+ pyo3::py_run!(py, cls, r#"
+ circle = cls.Circle()
+ assert isinstance(circle, cls)
+ assert isinstance(circle, cls.Circle)
+ assert circle.radius == 1.0
+
+ square = cls.Rectangle(width = 1, height = 1)
+ assert isinstance(square, cls)
+ assert isinstance(square, cls.Rectangle)
+ assert square.width == 1
+ assert square.height == 1
+
+ hexagon = cls.RegularPolygon(6)
+ assert isinstance(hexagon, cls)
+ assert isinstance(hexagon, cls.RegularPolygon)
+ assert hexagon.side_count == 6
+ assert hexagon.radius == 1
+ "#)
+})
+```
+
## Implementation details
The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block.
diff --git a/newsfragments/4158.added.md b/newsfragments/4158.added.md
new file mode 100644
index 00000000..42e6d3ff
--- /dev/null
+++ b/newsfragments/4158.added.md
@@ -0,0 +1 @@
+Added `#[pyo3(constructor = (...))]` to customize the generated constructors for complex enum variants
diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs
index e91b3b8d..d9c805aa 100644
--- a/pyo3-macros-backend/src/attributes.rs
+++ b/pyo3-macros-backend/src/attributes.rs
@@ -12,6 +12,7 @@ pub mod kw {
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
+ syn::custom_keyword!(constructor);
syn::custom_keyword!(dict);
syn::custom_keyword!(extends);
syn::custom_keyword!(freelist);
diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs
index f8bfa164..3023f897 100644
--- a/pyo3-macros-backend/src/pyclass.rs
+++ b/pyo3-macros-backend/src/pyclass.rs
@@ -8,6 +8,7 @@ use crate::attributes::{
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
+use crate::pyfunction::ConstructorAttribute;
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
@@ -620,12 +621,15 @@ struct PyClassEnumVariantNamedField<'a> {
}
/// `#[pyo3()]` options for pyclass enum variants
+#[derive(Default)]
struct EnumVariantPyO3Options {
name: Option,
+ constructor: Option,
}
enum EnumVariantPyO3Option {
Name(NameAttribute),
+ Constructor(ConstructorAttribute),
}
impl Parse for EnumVariantPyO3Option {
@@ -633,6 +637,8 @@ impl Parse for EnumVariantPyO3Option {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(EnumVariantPyO3Option::Name)
+ } else if lookahead.peek(attributes::kw::constructor) {
+ input.parse().map(EnumVariantPyO3Option::Constructor)
} else {
Err(lookahead.error())
}
@@ -641,22 +647,34 @@ impl Parse for EnumVariantPyO3Option {
impl EnumVariantPyO3Options {
fn take_pyo3_options(attrs: &mut Vec) -> Result {
- let mut options = EnumVariantPyO3Options { name: None };
+ let mut options = EnumVariantPyO3Options::default();
- for option in take_pyo3_options(attrs)? {
- match option {
- EnumVariantPyO3Option::Name(name) => {
- ensure_spanned!(
- options.name.is_none(),
- name.span() => "`name` may only be specified once"
- );
- options.name = Some(name);
- }
- }
- }
+ take_pyo3_options(attrs)?
+ .into_iter()
+ .try_for_each(|option| options.set_option(option))?;
Ok(options)
}
+
+ fn set_option(&mut self, option: EnumVariantPyO3Option) -> syn::Result<()> {
+ macro_rules! set_option {
+ ($key:ident) => {
+ {
+ ensure_spanned!(
+ self.$key.is_none(),
+ $key.span() => concat!("`", stringify!($key), "` may only be specified once")
+ );
+ self.$key = Some($key);
+ }
+ };
+ }
+
+ match option {
+ EnumVariantPyO3Option::Constructor(constructor) => set_option!(constructor),
+ EnumVariantPyO3Option::Name(name) => set_option!(name),
+ }
+ Ok(())
+ }
}
fn impl_enum(
@@ -689,6 +707,10 @@ fn impl_simple_enum(
let variants = simple_enum.variants;
let pytypeinfo = impl_pytypeinfo(cls, args, None, ctx);
+ for variant in &variants {
+ ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant");
+ }
+
let (default_repr, default_repr_slot) = {
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
@@ -889,7 +911,7 @@ fn impl_complex_enum(
let mut variant_cls_pytypeinfos = vec![];
let mut variant_cls_pyclass_impls = vec![];
let mut variant_cls_impls = vec![];
- for variant in &variants {
+ for variant in variants {
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
let variant_cls_zst = quote! {
@@ -908,11 +930,11 @@ fn impl_complex_enum(
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx);
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
- let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
-
- let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?;
+ let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?;
variant_cls_impls.push(variant_cls_impl);
+ let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
+
let pyclass_impl = PyClassImplsBuilder::new(
&variant_cls,
&variant_args,
@@ -1120,7 +1142,7 @@ pub fn gen_complex_enum_variant_attr(
fn complex_enum_variant_new<'a>(
cls: &'a syn::Ident,
- variant: &'a PyClassEnumVariant<'a>,
+ variant: PyClassEnumVariant<'a>,
ctx: &Ctx,
) -> Result {
match variant {
@@ -1132,7 +1154,7 @@ fn complex_enum_variant_new<'a>(
fn complex_enum_struct_variant_new<'a>(
cls: &'a syn::Ident,
- variant: &'a PyClassEnumStructVariant<'a>,
+ variant: PyClassEnumStructVariant<'a>,
ctx: &Ctx,
) -> Result {
let Ctx { pyo3_path } = ctx;
@@ -1162,7 +1184,15 @@ fn complex_enum_struct_variant_new<'a>(
}
args
};
- let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;
+
+ let signature = if let Some(constructor) = variant.options.constructor {
+ crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
+ args,
+ constructor.into_signature(),
+ )?
+ } else {
+ crate::pyfunction::FunctionSignature::from_arguments(args)?
+ };
let spec = FnSpec {
tp: crate::method::FnType::FnNew,
diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs
index 7c355533..e259f0e2 100644
--- a/pyo3-macros-backend/src/pyfunction.rs
+++ b/pyo3-macros-backend/src/pyfunction.rs
@@ -18,7 +18,7 @@ use syn::{
mod signature;
-pub use self::signature::{FunctionSignature, SignatureAttribute};
+pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
diff --git a/pyo3-macros-backend/src/pyfunction/signature.rs b/pyo3-macros-backend/src/pyfunction/signature.rs
index 3daa79c8..b73b96a3 100644
--- a/pyo3-macros-backend/src/pyfunction/signature.rs
+++ b/pyo3-macros-backend/src/pyfunction/signature.rs
@@ -195,6 +195,16 @@ impl ToTokens for SignatureItemPosargsSep {
}
pub type SignatureAttribute = KeywordAttribute;
+pub type ConstructorAttribute = KeywordAttribute;
+
+impl ConstructorAttribute {
+ pub fn into_signature(self) -> SignatureAttribute {
+ SignatureAttribute {
+ kw: kw::signature(self.kw.span),
+ value: self.value,
+ }
+ }
+}
#[derive(Default)]
pub struct PythonSignature {
diff --git a/pytests/src/enums.rs b/pytests/src/enums.rs
index 0a1bc49b..68a5fc93 100644
--- a/pytests/src/enums.rs
+++ b/pytests/src/enums.rs
@@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {
#[pyclass]
pub enum ComplexEnum {
- Int { i: i32 },
- Float { f: f64 },
- Str { s: String },
+ Int {
+ i: i32,
+ },
+ Float {
+ f: f64,
+ },
+ Str {
+ s: String,
+ },
EmptyStruct {},
- MultiFieldStruct { a: i32, b: f64, c: bool },
+ MultiFieldStruct {
+ a: i32,
+ b: f64,
+ c: bool,
+ },
+ #[pyo3(constructor = (a = 42, b = None))]
+ VariantWithDefault {
+ a: i32,
+ b: Option,
+ },
}
#[pyfunction]
@@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
b: *b,
c: *c,
},
+ ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault {
+ a: 2 * a,
+ b: b.as_ref().map(|s| s.to_uppercase()),
+ },
}
}
diff --git a/pytests/tests/test_enums.py b/pytests/tests/test_enums.py
index 04b0cdca..cd1d7aed 100644
--- a/pytests/tests/test_enums.py
+++ b/pytests/tests/test_enums.py
@@ -18,6 +18,12 @@ def test_complex_enum_variant_constructors():
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)
+ variant_with_default_1 = enums.ComplexEnum.VariantWithDefault()
+ assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault)
+
+ variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello")
+ assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault)
+
@pytest.mark.parametrize(
"variant",
@@ -27,6 +33,7 @@ def test_complex_enum_variant_constructors():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
+ enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
@@ -48,6 +55,10 @@ def test_complex_enum_field_getters():
assert multi_field_struct_variant.b == 3.14
assert multi_field_struct_variant.c is True
+ variant_with_default = enums.ComplexEnum.VariantWithDefault()
+ assert variant_with_default.a == 42
+ assert variant_with_default.b is None
+
@pytest.mark.parametrize(
"variant",
@@ -57,6 +68,7 @@ def test_complex_enum_field_getters():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
+ enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
@@ -78,6 +90,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
assert x == 42
assert y == 3.14
assert z is True
+ elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
+ x = variant.a
+ y = variant.b
+ assert x == 42
+ assert y is None
else:
assert False
@@ -90,6 +107,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
+ enums.ComplexEnum.VariantWithDefault(b="hello"),
],
)
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
@@ -112,5 +130,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn
assert x == 42
assert y == 3.14
assert z is True
+ elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
+ x = variant.a
+ y = variant.b
+ assert x == 84
+ assert y == "HELLO"
else:
assert False
diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs
index 95879c2f..116b8968 100644
--- a/tests/ui/invalid_pyclass_enum.rs
+++ b/tests/ui/invalid_pyclass_enum.rs
@@ -27,4 +27,11 @@ enum NoTupleVariants {
TupleVariant(i32),
}
+#[pyclass]
+enum SimpleNoSignature {
+ #[pyo3(constructor = (a, b))]
+ A,
+ B,
+}
+
fn main() {}
diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr
index a03a0ae2..e9ba9806 100644
--- a/tests/ui/invalid_pyclass_enum.stderr
+++ b/tests/ui/invalid_pyclass_enum.stderr
@@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum
|
27 | TupleVariant(i32),
| ^^^^^^^^^^^^
+
+error: `constructor` can't be used on a simple enum variant
+ --> tests/ui/invalid_pyclass_enum.rs:32:12
+ |
+32 | #[pyo3(constructor = (a, b))]
+ | ^^^^^^^^^^^