allow constructor customization of complex enum variants (#4158)
* allow `#[pyo3(signature = ...)]` on complex enum variants to specify constructor signature * rename keyword to `constructor` * review feedback * add docs in guide * add newsfragment
This commit is contained in:
parent
2d19b7e2a7
commit
7beb64a8ca
|
@ -2,6 +2,7 @@
|
|||
|
||||
| Parameter | Description |
|
||||
| :- | :- |
|
||||
| `constructor` | This is currently only allowed on [variants of complex enums][params-constructor]. It allows customization of the generated class constructor for each variant. It uses the same syntax and supports the same options as the `signature` attribute of functions and methods. |
|
||||
| <span style="white-space: pre">`crate = "some::path"`</span> | Path to import the `pyo3` crate, if it's not accessible at `::pyo3`. |
|
||||
| `dict` | Gives instances of this class an empty `__dict__` to store custom attributes. |
|
||||
| <span style="white-space: pre">`extends = BaseType`</span> | Use a custom baseclass. Defaults to [`PyAny`][params-1] |
|
||||
|
@ -39,5 +40,6 @@ struct MyClass {}
|
|||
[params-4]: https://doc.rust-lang.org/std/rc/struct.Rc.html
|
||||
[params-5]: https://doc.rust-lang.org/std/sync/struct.Arc.html
|
||||
[params-6]: https://docs.python.org/3/library/weakref.html
|
||||
[params-constructor]: https://pyo3.rs/latest/class.html#complex-enums
|
||||
[params-mapping]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types
|
||||
[params-sequence]: https://pyo3.rs/latest/class/protocols.html#mapping--sequence-types
|
||||
|
|
|
@ -1243,6 +1243,46 @@ Python::with_gil(|py| {
|
|||
})
|
||||
```
|
||||
|
||||
The constructor of each generated class can be customized using the `#[pyo3(constructor = (...))]` attribute. This uses the same syntax as the [`#[pyo3(signature = (...))]`](function/signature.md)
|
||||
attribute on function and methods and supports the same options. To apply this attribute simply place it on top of a variant in a `#[pyclass]` complex enum as shown below:
|
||||
|
||||
```rust
|
||||
# use pyo3::prelude::*;
|
||||
#[pyclass]
|
||||
enum Shape {
|
||||
#[pyo3(constructor = (radius=1.0))]
|
||||
Circle { radius: f64 },
|
||||
#[pyo3(constructor = (*, width, height))]
|
||||
Rectangle { width: f64, height: f64 },
|
||||
#[pyo3(constructor = (side_count, radius=1.0))]
|
||||
RegularPolygon { side_count: u32, radius: f64 },
|
||||
Nothing { },
|
||||
}
|
||||
|
||||
# #[cfg(Py_3_10)]
|
||||
Python::with_gil(|py| {
|
||||
let cls = py.get_type_bound::<Shape>();
|
||||
pyo3::py_run!(py, cls, r#"
|
||||
circle = cls.Circle()
|
||||
assert isinstance(circle, cls)
|
||||
assert isinstance(circle, cls.Circle)
|
||||
assert circle.radius == 1.0
|
||||
|
||||
square = cls.Rectangle(width = 1, height = 1)
|
||||
assert isinstance(square, cls)
|
||||
assert isinstance(square, cls.Rectangle)
|
||||
assert square.width == 1
|
||||
assert square.height == 1
|
||||
|
||||
hexagon = cls.RegularPolygon(6)
|
||||
assert isinstance(hexagon, cls)
|
||||
assert isinstance(hexagon, cls.RegularPolygon)
|
||||
assert hexagon.side_count == 6
|
||||
assert hexagon.radius == 1
|
||||
"#)
|
||||
})
|
||||
```
|
||||
|
||||
## Implementation details
|
||||
|
||||
The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block.
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Added `#[pyo3(constructor = (...))]` to customize the generated constructors for complex enum variants
|
|
@ -12,6 +12,7 @@ pub mod kw {
|
|||
syn::custom_keyword!(annotation);
|
||||
syn::custom_keyword!(attribute);
|
||||
syn::custom_keyword!(cancel_handle);
|
||||
syn::custom_keyword!(constructor);
|
||||
syn::custom_keyword!(dict);
|
||||
syn::custom_keyword!(extends);
|
||||
syn::custom_keyword!(freelist);
|
||||
|
|
|
@ -8,6 +8,7 @@ use crate::attributes::{
|
|||
use crate::deprecations::Deprecations;
|
||||
use crate::konst::{ConstAttributes, ConstSpec};
|
||||
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
|
||||
use crate::pyfunction::ConstructorAttribute;
|
||||
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
|
||||
use crate::pymethod::{
|
||||
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
|
||||
|
@ -620,12 +621,15 @@ struct PyClassEnumVariantNamedField<'a> {
|
|||
}
|
||||
|
||||
/// `#[pyo3()]` options for pyclass enum variants
|
||||
#[derive(Default)]
|
||||
struct EnumVariantPyO3Options {
|
||||
name: Option<NameAttribute>,
|
||||
constructor: Option<ConstructorAttribute>,
|
||||
}
|
||||
|
||||
enum EnumVariantPyO3Option {
|
||||
Name(NameAttribute),
|
||||
Constructor(ConstructorAttribute),
|
||||
}
|
||||
|
||||
impl Parse for EnumVariantPyO3Option {
|
||||
|
@ -633,6 +637,8 @@ impl Parse for EnumVariantPyO3Option {
|
|||
let lookahead = input.lookahead1();
|
||||
if lookahead.peek(attributes::kw::name) {
|
||||
input.parse().map(EnumVariantPyO3Option::Name)
|
||||
} else if lookahead.peek(attributes::kw::constructor) {
|
||||
input.parse().map(EnumVariantPyO3Option::Constructor)
|
||||
} else {
|
||||
Err(lookahead.error())
|
||||
}
|
||||
|
@ -641,22 +647,34 @@ impl Parse for EnumVariantPyO3Option {
|
|||
|
||||
impl EnumVariantPyO3Options {
|
||||
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
|
||||
let mut options = EnumVariantPyO3Options { name: None };
|
||||
let mut options = EnumVariantPyO3Options::default();
|
||||
|
||||
for option in take_pyo3_options(attrs)? {
|
||||
match option {
|
||||
EnumVariantPyO3Option::Name(name) => {
|
||||
ensure_spanned!(
|
||||
options.name.is_none(),
|
||||
name.span() => "`name` may only be specified once"
|
||||
);
|
||||
options.name = Some(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
take_pyo3_options(attrs)?
|
||||
.into_iter()
|
||||
.try_for_each(|option| options.set_option(option))?;
|
||||
|
||||
Ok(options)
|
||||
}
|
||||
|
||||
fn set_option(&mut self, option: EnumVariantPyO3Option) -> syn::Result<()> {
|
||||
macro_rules! set_option {
|
||||
($key:ident) => {
|
||||
{
|
||||
ensure_spanned!(
|
||||
self.$key.is_none(),
|
||||
$key.span() => concat!("`", stringify!($key), "` may only be specified once")
|
||||
);
|
||||
self.$key = Some($key);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
match option {
|
||||
EnumVariantPyO3Option::Constructor(constructor) => set_option!(constructor),
|
||||
EnumVariantPyO3Option::Name(name) => set_option!(name),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_enum(
|
||||
|
@ -689,6 +707,10 @@ fn impl_simple_enum(
|
|||
let variants = simple_enum.variants;
|
||||
let pytypeinfo = impl_pytypeinfo(cls, args, None, ctx);
|
||||
|
||||
for variant in &variants {
|
||||
ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant");
|
||||
}
|
||||
|
||||
let (default_repr, default_repr_slot) = {
|
||||
let variants_repr = variants.iter().map(|variant| {
|
||||
let variant_name = variant.ident;
|
||||
|
@ -889,7 +911,7 @@ fn impl_complex_enum(
|
|||
let mut variant_cls_pytypeinfos = vec![];
|
||||
let mut variant_cls_pyclass_impls = vec![];
|
||||
let mut variant_cls_impls = vec![];
|
||||
for variant in &variants {
|
||||
for variant in variants {
|
||||
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
|
||||
|
||||
let variant_cls_zst = quote! {
|
||||
|
@ -908,11 +930,11 @@ fn impl_complex_enum(
|
|||
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx);
|
||||
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
|
||||
|
||||
let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
|
||||
|
||||
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?;
|
||||
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?;
|
||||
variant_cls_impls.push(variant_cls_impl);
|
||||
|
||||
let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
|
||||
|
||||
let pyclass_impl = PyClassImplsBuilder::new(
|
||||
&variant_cls,
|
||||
&variant_args,
|
||||
|
@ -1120,7 +1142,7 @@ pub fn gen_complex_enum_variant_attr(
|
|||
|
||||
fn complex_enum_variant_new<'a>(
|
||||
cls: &'a syn::Ident,
|
||||
variant: &'a PyClassEnumVariant<'a>,
|
||||
variant: PyClassEnumVariant<'a>,
|
||||
ctx: &Ctx,
|
||||
) -> Result<MethodAndSlotDef> {
|
||||
match variant {
|
||||
|
@ -1132,7 +1154,7 @@ fn complex_enum_variant_new<'a>(
|
|||
|
||||
fn complex_enum_struct_variant_new<'a>(
|
||||
cls: &'a syn::Ident,
|
||||
variant: &'a PyClassEnumStructVariant<'a>,
|
||||
variant: PyClassEnumStructVariant<'a>,
|
||||
ctx: &Ctx,
|
||||
) -> Result<MethodAndSlotDef> {
|
||||
let Ctx { pyo3_path } = ctx;
|
||||
|
@ -1162,7 +1184,15 @@ fn complex_enum_struct_variant_new<'a>(
|
|||
}
|
||||
args
|
||||
};
|
||||
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;
|
||||
|
||||
let signature = if let Some(constructor) = variant.options.constructor {
|
||||
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
|
||||
args,
|
||||
constructor.into_signature(),
|
||||
)?
|
||||
} else {
|
||||
crate::pyfunction::FunctionSignature::from_arguments(args)?
|
||||
};
|
||||
|
||||
let spec = FnSpec {
|
||||
tp: crate::method::FnType::FnNew,
|
||||
|
|
|
@ -18,7 +18,7 @@ use syn::{
|
|||
|
||||
mod signature;
|
||||
|
||||
pub use self::signature::{FunctionSignature, SignatureAttribute};
|
||||
pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PyFunctionArgPyO3Attributes {
|
||||
|
|
|
@ -195,6 +195,16 @@ impl ToTokens for SignatureItemPosargsSep {
|
|||
}
|
||||
|
||||
pub type SignatureAttribute = KeywordAttribute<kw::signature, Signature>;
|
||||
pub type ConstructorAttribute = KeywordAttribute<kw::constructor, Signature>;
|
||||
|
||||
impl ConstructorAttribute {
|
||||
pub fn into_signature(self) -> SignatureAttribute {
|
||||
SignatureAttribute {
|
||||
kw: kw::signature(self.kw.span),
|
||||
value: self.value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct PythonSignature {
|
||||
|
|
|
@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {
|
|||
|
||||
#[pyclass]
|
||||
pub enum ComplexEnum {
|
||||
Int { i: i32 },
|
||||
Float { f: f64 },
|
||||
Str { s: String },
|
||||
Int {
|
||||
i: i32,
|
||||
},
|
||||
Float {
|
||||
f: f64,
|
||||
},
|
||||
Str {
|
||||
s: String,
|
||||
},
|
||||
EmptyStruct {},
|
||||
MultiFieldStruct { a: i32, b: f64, c: bool },
|
||||
MultiFieldStruct {
|
||||
a: i32,
|
||||
b: f64,
|
||||
c: bool,
|
||||
},
|
||||
#[pyo3(constructor = (a = 42, b = None))]
|
||||
VariantWithDefault {
|
||||
a: i32,
|
||||
b: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
|
@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
|
|||
b: *b,
|
||||
c: *c,
|
||||
},
|
||||
ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault {
|
||||
a: 2 * a,
|
||||
b: b.as_ref().map(|s| s.to_uppercase()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,12 @@ def test_complex_enum_variant_constructors():
|
|||
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
|
||||
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)
|
||||
|
||||
variant_with_default_1 = enums.ComplexEnum.VariantWithDefault()
|
||||
assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault)
|
||||
|
||||
variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello")
|
||||
assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
|
@ -27,6 +33,7 @@ def test_complex_enum_variant_constructors():
|
|||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
enums.ComplexEnum.VariantWithDefault(),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
|
||||
|
@ -48,6 +55,10 @@ def test_complex_enum_field_getters():
|
|||
assert multi_field_struct_variant.b == 3.14
|
||||
assert multi_field_struct_variant.c is True
|
||||
|
||||
variant_with_default = enums.ComplexEnum.VariantWithDefault()
|
||||
assert variant_with_default.a == 42
|
||||
assert variant_with_default.b is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
|
@ -57,6 +68,7 @@ def test_complex_enum_field_getters():
|
|||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
enums.ComplexEnum.VariantWithDefault(),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
|
||||
|
@ -78,6 +90,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
|
|||
assert x == 42
|
||||
assert y == 3.14
|
||||
assert z is True
|
||||
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
|
||||
x = variant.a
|
||||
y = variant.b
|
||||
assert x == 42
|
||||
assert y is None
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
@ -90,6 +107,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
|
|||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
enums.ComplexEnum.VariantWithDefault(b="hello"),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
|
||||
|
@ -112,5 +130,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn
|
|||
assert x == 42
|
||||
assert y == 3.14
|
||||
assert z is True
|
||||
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
|
||||
x = variant.a
|
||||
y = variant.b
|
||||
assert x == 84
|
||||
assert y == "HELLO"
|
||||
else:
|
||||
assert False
|
||||
|
|
|
@ -27,4 +27,11 @@ enum NoTupleVariants {
|
|||
TupleVariant(i32),
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
enum SimpleNoSignature {
|
||||
#[pyo3(constructor = (a, b))]
|
||||
A,
|
||||
B,
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
|
|
@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum
|
|||
|
|
||||
27 | TupleVariant(i32),
|
||||
| ^^^^^^^^^^^^
|
||||
|
||||
error: `constructor` can't be used on a simple enum variant
|
||||
--> tests/ui/invalid_pyclass_enum.rs:32:12
|
||||
|
|
||||
32 | #[pyo3(constructor = (a, b))]
|
||||
| ^^^^^^^^^^^
|
||||
|
|
Loading…
Reference in New Issue