allow constructor customization of complex enum variants (#4158)

* allow `#[pyo3(signature = ...)]` on complex enum variants to specify constructor signature

* rename keyword to `constructor`

* review feedback

* add docs in guide

* add newsfragment
This commit is contained in:
Icxolu 2024-05-09 23:08:23 +02:00 committed by GitHub
parent 2d19b7e2a7
commit 7beb64a8ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 163 additions and 24 deletions

View File

@ -2,6 +2,7 @@
| Parameter | Description |
| :- | :- |
| `constructor` | This is currently only allowed on [variants of complex enums][params-constructor]. It allows customization of the generated class constructor for each variant. It uses the same syntax and supports the same options as the `signature` attribute of functions and methods. |
| <span style="white-space: pre">`crate = "some::path"`</span> | Path to import the `pyo3` crate, if it's not accessible at `::pyo3`. |
| `dict` | Gives instances of this class an empty `__dict__` to store custom attributes. |
| <span style="white-space: pre">`extends = BaseType`</span> | Use a custom baseclass. Defaults to [`PyAny`][params-1] |
@ -39,5 +40,6 @@ struct MyClass {}
[params-4]: https://doc.rust-lang.org/std/rc/struct.Rc.html
[params-5]: https://doc.rust-lang.org/std/sync/struct.Arc.html
[params-6]: https://docs.python.org/3/library/weakref.html
[params-constructor]: https://pyo3.rs/latest/class.html#complex-enums
[params-mapping]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types
[params-sequence]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types

View File

@ -1243,6 +1243,46 @@ Python::with_gil(|py| {
})
```
The constructor of each generated class can be customized using the `#[pyo3(constructor = (...))]` attribute. This uses the same syntax as the [`#[pyo3(signature = (...))]`](function/signature.md)
attribute on function and methods and supports the same options. To apply this attribute simply place it on top of a variant in a `#[pyclass]` complex enum as shown below:
```rust
# use pyo3::prelude::*;
#[pyclass]
enum Shape {
#[pyo3(constructor = (radius=1.0))]
Circle { radius: f64 },
#[pyo3(constructor = (*, width, height))]
Rectangle { width: f64, height: f64 },
#[pyo3(constructor = (side_count, radius=1.0))]
RegularPolygon { side_count: u32, radius: f64 },
Nothing { },
}
# #[cfg(Py_3_10)]
Python::with_gil(|py| {
let cls = py.get_type_bound::<Shape>();
pyo3::py_run!(py, cls, r#"
circle = cls.Circle()
assert isinstance(circle, cls)
assert isinstance(circle, cls.Circle)
assert circle.radius == 1.0
square = cls.Rectangle(width = 1, height = 1)
assert isinstance(square, cls)
assert isinstance(square, cls.Rectangle)
assert square.width == 1
assert square.height == 1
hexagon = cls.RegularPolygon(6)
assert isinstance(hexagon, cls)
assert isinstance(hexagon, cls.RegularPolygon)
assert hexagon.side_count == 6
assert hexagon.radius == 1
"#)
})
```
## Implementation details
The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block.

View File

@ -0,0 +1 @@
Added `#[pyo3(constructor = (...))]` to customize the generated constructors for complex enum variants

View File

@ -12,6 +12,7 @@ pub mod kw {
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
syn::custom_keyword!(constructor);
syn::custom_keyword!(dict);
syn::custom_keyword!(extends);
syn::custom_keyword!(freelist);

View File

@ -8,6 +8,7 @@ use crate::attributes::{
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
use crate::pyfunction::ConstructorAttribute;
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
@ -620,12 +621,15 @@ struct PyClassEnumVariantNamedField<'a> {
}
/// `#[pyo3()]` options for pyclass enum variants
#[derive(Default)]
struct EnumVariantPyO3Options {
name: Option<NameAttribute>,
constructor: Option<ConstructorAttribute>,
}
enum EnumVariantPyO3Option {
Name(NameAttribute),
Constructor(ConstructorAttribute),
}
impl Parse for EnumVariantPyO3Option {
@ -633,6 +637,8 @@ impl Parse for EnumVariantPyO3Option {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(EnumVariantPyO3Option::Name)
} else if lookahead.peek(attributes::kw::constructor) {
input.parse().map(EnumVariantPyO3Option::Constructor)
} else {
Err(lookahead.error())
}
@ -641,22 +647,34 @@ impl Parse for EnumVariantPyO3Option {
impl EnumVariantPyO3Options {
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
let mut options = EnumVariantPyO3Options { name: None };
let mut options = EnumVariantPyO3Options::default();
for option in take_pyo3_options(attrs)? {
match option {
EnumVariantPyO3Option::Name(name) => {
ensure_spanned!(
options.name.is_none(),
name.span() => "`name` may only be specified once"
);
options.name = Some(name);
}
}
}
take_pyo3_options(attrs)?
.into_iter()
.try_for_each(|option| options.set_option(option))?;
Ok(options)
}
fn set_option(&mut self, option: EnumVariantPyO3Option) -> syn::Result<()> {
macro_rules! set_option {
($key:ident) => {
{
ensure_spanned!(
self.$key.is_none(),
$key.span() => concat!("`", stringify!($key), "` may only be specified once")
);
self.$key = Some($key);
}
};
}
match option {
EnumVariantPyO3Option::Constructor(constructor) => set_option!(constructor),
EnumVariantPyO3Option::Name(name) => set_option!(name),
}
Ok(())
}
}
fn impl_enum(
@ -689,6 +707,10 @@ fn impl_simple_enum(
let variants = simple_enum.variants;
let pytypeinfo = impl_pytypeinfo(cls, args, None, ctx);
for variant in &variants {
ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant");
}
let (default_repr, default_repr_slot) = {
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
@ -889,7 +911,7 @@ fn impl_complex_enum(
let mut variant_cls_pytypeinfos = vec![];
let mut variant_cls_pyclass_impls = vec![];
let mut variant_cls_impls = vec![];
for variant in &variants {
for variant in variants {
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
let variant_cls_zst = quote! {
@ -908,11 +930,11 @@ fn impl_complex_enum(
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx);
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?;
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?;
variant_cls_impls.push(variant_cls_impl);
let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
let pyclass_impl = PyClassImplsBuilder::new(
&variant_cls,
&variant_args,
@ -1120,7 +1142,7 @@ pub fn gen_complex_enum_variant_attr(
fn complex_enum_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumVariant<'a>,
variant: PyClassEnumVariant<'a>,
ctx: &Ctx,
) -> Result<MethodAndSlotDef> {
match variant {
@ -1132,7 +1154,7 @@ fn complex_enum_variant_new<'a>(
fn complex_enum_struct_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumStructVariant<'a>,
variant: PyClassEnumStructVariant<'a>,
ctx: &Ctx,
) -> Result<MethodAndSlotDef> {
let Ctx { pyo3_path } = ctx;
@ -1162,7 +1184,15 @@ fn complex_enum_struct_variant_new<'a>(
}
args
};
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;
let signature = if let Some(constructor) = variant.options.constructor {
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
args,
constructor.into_signature(),
)?
} else {
crate::pyfunction::FunctionSignature::from_arguments(args)?
};
let spec = FnSpec {
tp: crate::method::FnType::FnNew,

View File

@ -18,7 +18,7 @@ use syn::{
mod signature;
pub use self::signature::{FunctionSignature, SignatureAttribute};
pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {

View File

@ -195,6 +195,16 @@ impl ToTokens for SignatureItemPosargsSep {
}
pub type SignatureAttribute = KeywordAttribute<kw::signature, Signature>;
pub type ConstructorAttribute = KeywordAttribute<kw::constructor, Signature>;
impl ConstructorAttribute {
pub fn into_signature(self) -> SignatureAttribute {
SignatureAttribute {
kw: kw::signature(self.kw.span),
value: self.value,
}
}
}
#[derive(Default)]
pub struct PythonSignature {

View File

@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {
#[pyclass]
pub enum ComplexEnum {
Int { i: i32 },
Float { f: f64 },
Str { s: String },
Int {
i: i32,
},
Float {
f: f64,
},
Str {
s: String,
},
EmptyStruct {},
MultiFieldStruct { a: i32, b: f64, c: bool },
MultiFieldStruct {
a: i32,
b: f64,
c: bool,
},
#[pyo3(constructor = (a = 42, b = None))]
VariantWithDefault {
a: i32,
b: Option<String>,
},
}
#[pyfunction]
@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
b: *b,
c: *c,
},
ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault {
a: 2 * a,
b: b.as_ref().map(|s| s.to_uppercase()),
},
}
}

View File

@ -18,6 +18,12 @@ def test_complex_enum_variant_constructors():
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)
variant_with_default_1 = enums.ComplexEnum.VariantWithDefault()
assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault)
variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello")
assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault)
@pytest.mark.parametrize(
"variant",
@ -27,6 +33,7 @@ def test_complex_enum_variant_constructors():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
@ -48,6 +55,10 @@ def test_complex_enum_field_getters():
assert multi_field_struct_variant.b == 3.14
assert multi_field_struct_variant.c is True
variant_with_default = enums.ComplexEnum.VariantWithDefault()
assert variant_with_default.a == 42
assert variant_with_default.b is None
@pytest.mark.parametrize(
"variant",
@ -57,6 +68,7 @@ def test_complex_enum_field_getters():
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(),
],
)
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
@ -78,6 +90,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
assert x == 42
assert y == 3.14
assert z is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 42
assert y is None
else:
assert False
@ -90,6 +107,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
enums.ComplexEnum.VariantWithDefault(b="hello"),
],
)
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
@ -112,5 +130,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn
assert x == 42
assert y == 3.14
assert z is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 84
assert y == "HELLO"
else:
assert False

View File

@ -27,4 +27,11 @@ enum NoTupleVariants {
TupleVariant(i32),
}
#[pyclass]
enum SimpleNoSignature {
#[pyo3(constructor = (a, b))]
A,
B,
}
fn main() {}

View File

@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum
|
27 | TupleVariant(i32),
| ^^^^^^^^^^^^
error: `constructor` can't be used on a simple enum variant
--> tests/ui/invalid_pyclass_enum.rs:32:12
|
32 | #[pyo3(constructor = (a, b))]
| ^^^^^^^^^^^