feat: support pyclass on tuple enums (#4072)

* feat: support pyclass on tuple enums

* cargo fmt

* changelog

* ruff format

* rebase with adaptation for FnArg refactor

* fix class.md from pr comments

* add enum tuple variant getitem implementation

* fmt

* progress toward getitem and len impl on derive pyclass for complex enum tuple

* working getitem and len slots for complex tuple enum pyclass derivation

* refactor code generation

* address PR concerns
- take py from function argument on get_item
- make more general slot def implementation
- remove unnecessary function arguments
- add testcases for uncovered cases including future feature match_args

* add tracking issue

* fmt

* ruff

* remove me

* support match_args for tuple enum

* integrate FnArg now takes Cow

* fix empty and single element tuples

* use impl_py_slot_def for cimplex tuple enum slots

* reverse erroneous doc change

* Address latest comments

* formatting suggestion

* fix :
- clippy beta
- better compile error (+related doc and test)

---------

Co-authored-by: Chris Arderne <chris@translucent.app>
This commit is contained in:
newcomertv 2024-05-17 04:59:00 +02:00 committed by GitHub
parent 8de1787580
commit 88f2f6f4d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 581 additions and 45 deletions

View File

@ -52,15 +52,18 @@ enum HttpResponse {
// ...
}
// PyO3 also supports enums with non-unit variants
// PyO3 also supports enums with Struct and Tuple variants
// These complex enums have sligtly different behavior from the simple enums above
// They are meant to work with instance checks and match statement patterns
// The variants can be mixed and matched
// Struct variants have named fields while tuple enums generate generic names for fields in order _0, _1, _2, ...
// Apart from this both types are functionally identical
#[pyclass]
enum Shape {
Circle { radius: f64 },
Rectangle { width: f64, height: f64 },
RegularPolygon { side_count: u32, radius: f64 },
Nothing {},
RegularPolygon(u32, f64),
Nothing(),
}
```
@ -1180,7 +1183,7 @@ enum BadSubclass {
An enum is complex if it has any non-unit (struct or tuple) variants.
Currently PyO3 supports only struct variants in a complex enum. Support for unit and tuple variants is planned.
PyO3 supports only struct and tuple variants in a complex enum. Unit variants aren't supported at present (the recommendation is to use an empty tuple enum instead).
PyO3 adds a class attribute for each variant, which may be used to construct values and in match patterns. PyO3 also provides getter methods for all fields of each variant.
@ -1190,14 +1193,14 @@ PyO3 adds a class attribute for each variant, which may be used to construct val
enum Shape {
Circle { radius: f64 },
Rectangle { width: f64, height: f64 },
RegularPolygon { side_count: u32, radius: f64 },
RegularPolygon(u32, f64),
Nothing { },
}
# #[cfg(Py_3_10)]
Python::with_gil(|py| {
let circle = Shape::Circle { radius: 10.0 }.into_py(py);
let square = Shape::RegularPolygon { side_count: 4, radius: 10.0 }.into_py(py);
let square = Shape::RegularPolygon(4, 10.0).into_py(py);
let cls = py.get_type_bound::<Shape>();
pyo3::py_run!(py, circle square cls, r#"
assert isinstance(circle, cls)
@ -1206,8 +1209,8 @@ Python::with_gil(|py| {
assert isinstance(square, cls)
assert isinstance(square, cls.RegularPolygon)
assert square.side_count == 4
assert square.radius == 10.0
assert square[0] == 4 # Gets _0 field
assert square[1] == 10.0 # Gets _1 field
def count_vertices(cls, shape):
match shape:
@ -1215,7 +1218,7 @@ Python::with_gil(|py| {
return 0
case cls.Rectangle():
return 4
case cls.RegularPolygon(side_count=n):
case cls.RegularPolygon(n):
return n
case cls.Nothing():
return 0

View File

@ -0,0 +1 @@
Support `#[pyclass]` on enums that have tuple variants.

View File

@ -12,7 +12,7 @@ use crate::pyfunction::ConstructorAttribute;
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
SlotDef, __INT__, __REPR__, __RICHCMP__,
SlotDef, __GETITEM__, __INT__, __LEN__, __REPR__, __RICHCMP__,
};
use crate::utils::Ctx;
use crate::utils::{self, apply_renaming_rule, PythonDoc};
@ -505,7 +505,7 @@ impl<'a> PyClassComplexEnum<'a> {
Fields::Unit => {
bail_spanned!(variant.span() => format!(
"Unit variant `{ident}` is not yet supported in a complex enum\n\
= help: change to a struct variant with no fields: `{ident} {{ }}`\n\
= help: change to an empty tuple variant instead: `{ident}()`\n\
= note: the enum is complex because of non-unit variant `{witness}`",
ident=ident, witness=witness))
}
@ -526,12 +526,21 @@ impl<'a> PyClassComplexEnum<'a> {
options,
})
}
Fields::Unnamed(_) => {
bail_spanned!(variant.span() => format!(
"Tuple variant `{ident}` is not yet supported in a complex enum\n\
= help: change to a struct variant with named fields: `{ident} {{ /* fields */ }}`\n\
= note: the enum is complex because of non-unit variant `{witness}`",
ident=ident, witness=witness))
Fields::Unnamed(types) => {
let fields = types
.unnamed
.iter()
.map(|field| PyClassEnumVariantUnnamedField {
ty: &field.ty,
span: field.span(),
})
.collect();
PyClassEnumVariant::Tuple(PyClassEnumTupleVariant {
ident,
fields,
options,
})
}
};
@ -553,7 +562,7 @@ impl<'a> PyClassComplexEnum<'a> {
enum PyClassEnumVariant<'a> {
// TODO(mkovaxx): Unit(PyClassEnumUnitVariant<'a>),
Struct(PyClassEnumStructVariant<'a>),
// TODO(mkovaxx): Tuple(PyClassEnumTupleVariant<'a>),
Tuple(PyClassEnumTupleVariant<'a>),
}
trait EnumVariant {
@ -581,12 +590,14 @@ impl<'a> EnumVariant for PyClassEnumVariant<'a> {
fn get_ident(&self) -> &syn::Ident {
match self {
PyClassEnumVariant::Struct(struct_variant) => struct_variant.ident,
PyClassEnumVariant::Tuple(tuple_variant) => tuple_variant.ident,
}
}
fn get_options(&self) -> &EnumVariantPyO3Options {
match self {
PyClassEnumVariant::Struct(struct_variant) => &struct_variant.options,
PyClassEnumVariant::Tuple(tuple_variant) => &tuple_variant.options,
}
}
}
@ -614,12 +625,23 @@ struct PyClassEnumStructVariant<'a> {
options: EnumVariantPyO3Options,
}
struct PyClassEnumTupleVariant<'a> {
ident: &'a syn::Ident,
fields: Vec<PyClassEnumVariantUnnamedField<'a>>,
options: EnumVariantPyO3Options,
}
struct PyClassEnumVariantNamedField<'a> {
ident: &'a syn::Ident,
ty: &'a syn::Type,
span: Span,
}
struct PyClassEnumVariantUnnamedField<'a> {
ty: &'a syn::Type,
span: Span,
}
/// `#[pyo3()]` options for pyclass enum variants
#[derive(Default)]
struct EnumVariantPyO3Options {
@ -930,17 +952,19 @@ fn impl_complex_enum(
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx);
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?;
let (variant_cls_impl, field_getters, mut slots) =
impl_complex_enum_variant_cls(cls, &variant, ctx)?;
variant_cls_impls.push(variant_cls_impl);
let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
slots.push(variant_new);
let pyclass_impl = PyClassImplsBuilder::new(
&variant_cls,
&variant_args,
methods_type,
field_getters,
vec![variant_new],
slots,
)
.impl_all(ctx)?;
@ -970,19 +994,52 @@ fn impl_complex_enum_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumVariant<'_>,
ctx: &Ctx,
) -> Result<(TokenStream, Vec<MethodAndMethodDef>)> {
) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
match variant {
PyClassEnumVariant::Struct(struct_variant) => {
impl_complex_enum_struct_variant_cls(enum_name, struct_variant, ctx)
}
PyClassEnumVariant::Tuple(tuple_variant) => {
impl_complex_enum_tuple_variant_cls(enum_name, tuple_variant, ctx)
}
}
}
fn impl_complex_enum_variant_match_args(
ctx: &Ctx,
variant_cls_type: &syn::Type,
field_names: &mut Vec<Ident>,
) -> (MethodAndMethodDef, syn::ImplItemConst) {
let match_args_const_impl: syn::ImplItemConst = {
let args_tp = field_names.iter().map(|_| {
quote! { &'static str }
});
parse_quote! {
const __match_args__: ( #(#args_tp,)* ) = (
#(stringify!(#field_names),)*
);
}
};
let spec = ConstSpec {
rust_ident: format_ident!("__match_args__"),
attributes: ConstAttributes {
is_class_attr: true,
name: None,
deprecations: Deprecations::new(ctx),
},
};
let variant_match_args = gen_py_const(variant_cls_type, &spec, ctx);
(variant_match_args, match_args_const_impl)
}
fn impl_complex_enum_struct_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumStructVariant<'_>,
ctx: &Ctx,
) -> Result<(TokenStream, Vec<MethodAndMethodDef>)> {
) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
let Ctx { pyo3_path } = ctx;
let variant_ident = &variant.ident;
let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
@ -1015,6 +1072,11 @@ fn impl_complex_enum_struct_variant_cls(
field_getter_impls.push(field_getter_impl);
}
let (variant_match_args, match_args_const_impl) =
impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names);
field_getters.push(variant_match_args);
let cls_impl = quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
@ -1024,11 +1086,190 @@ fn impl_complex_enum_struct_variant_cls(
#pyo3_path::PyClassInitializer::from(base_value).add_subclass(#variant_cls)
}
#match_args_const_impl
#(#field_getter_impls)*
}
};
Ok((cls_impl, field_getters))
Ok((cls_impl, field_getters, Vec::new()))
}
fn impl_complex_enum_tuple_variant_field_getters(
ctx: &Ctx,
variant: &PyClassEnumTupleVariant<'_>,
enum_name: &syn::Ident,
variant_cls_type: &syn::Type,
variant_ident: &&Ident,
field_names: &mut Vec<Ident>,
fields_types: &mut Vec<syn::Type>,
) -> Result<(Vec<MethodAndMethodDef>, Vec<syn::ImplItemFn>)> {
let Ctx { pyo3_path } = ctx;
let mut field_getters = vec![];
let mut field_getter_impls = vec![];
for (index, field) in variant.fields.iter().enumerate() {
let field_name = format_ident!("_{}", index);
let field_type = field.ty;
let field_getter =
complex_enum_variant_field_getter(variant_cls_type, &field_name, field.span, ctx)?;
// Generate the match arms needed to destructure the tuple and access the specific field
let field_access_tokens: Vec<_> = (0..variant.fields.len())
.map(|i| {
if i == index {
quote! { val }
} else {
quote! { _ }
}
})
.collect();
let field_getter_impl: syn::ImplItemFn = parse_quote! {
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#field_type> {
match &*slf.into_super() {
#enum_name::#variant_ident ( #(#field_access_tokens), *) => Ok(val.clone()),
_ => unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
}
}
};
field_names.push(field_name);
fields_types.push(field_type.clone());
field_getters.push(field_getter);
field_getter_impls.push(field_getter_impl);
}
Ok((field_getters, field_getter_impls))
}
fn impl_complex_enum_tuple_variant_len(
ctx: &Ctx,
variant_cls_type: &syn::Type,
num_fields: usize,
) -> Result<(MethodAndSlotDef, syn::ImplItemFn)> {
let Ctx { pyo3_path } = ctx;
let mut len_method_impl: syn::ImplItemFn = parse_quote! {
fn __len__(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<usize> {
Ok(#num_fields)
}
};
let variant_len =
generate_default_protocol_slot(variant_cls_type, &mut len_method_impl, &__LEN__, ctx)?;
Ok((variant_len, len_method_impl))
}
fn impl_complex_enum_tuple_variant_getitem(
ctx: &Ctx,
variant_cls: &syn::Ident,
variant_cls_type: &syn::Type,
num_fields: usize,
) -> Result<(MethodAndSlotDef, syn::ImplItemFn)> {
let Ctx { pyo3_path } = ctx;
let match_arms: Vec<_> = (0..num_fields)
.map(|i| {
let field_access = format_ident!("_{}", i);
quote! {
#i => Ok(
#pyo3_path::IntoPy::into_py(
#variant_cls::#field_access(slf)?
, py)
)
}
})
.collect();
let mut get_item_method_impl: syn::ImplItemFn = parse_quote! {
fn __getitem__(slf: #pyo3_path::PyRef<Self>, idx: usize) -> #pyo3_path::PyResult< #pyo3_path::PyObject> {
let py = slf.py();
match idx {
#( #match_arms, )*
_ => Err(pyo3::exceptions::PyIndexError::new_err("tuple index out of range")),
}
}
};
let variant_getitem = generate_default_protocol_slot(
variant_cls_type,
&mut get_item_method_impl,
&__GETITEM__,
ctx,
)?;
Ok((variant_getitem, get_item_method_impl))
}
fn impl_complex_enum_tuple_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumTupleVariant<'_>,
ctx: &Ctx,
) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
let Ctx { pyo3_path } = ctx;
let variant_ident = &variant.ident;
let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
let variant_cls_type = parse_quote!(#variant_cls);
let mut slots = vec![];
// represents the index of the field
let mut field_names: Vec<Ident> = vec![];
let mut field_types: Vec<syn::Type> = vec![];
let (mut field_getters, field_getter_impls) = impl_complex_enum_tuple_variant_field_getters(
ctx,
variant,
enum_name,
&variant_cls_type,
variant_ident,
&mut field_names,
&mut field_types,
)?;
let num_fields = variant.fields.len();
let (variant_len, len_method_impl) =
impl_complex_enum_tuple_variant_len(ctx, &variant_cls_type, num_fields)?;
slots.push(variant_len);
let (variant_getitem, getitem_method_impl) =
impl_complex_enum_tuple_variant_getitem(ctx, &variant_cls, &variant_cls_type, num_fields)?;
slots.push(variant_getitem);
let (variant_match_args, match_args_method_impl) =
impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names);
field_getters.push(variant_match_args);
let cls_impl = quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
impl #variant_cls {
fn __pymethod_constructor__(py: #pyo3_path::Python<'_>, #(#field_names : #field_types,)*) -> #pyo3_path::PyClassInitializer<#variant_cls> {
let base_value = #enum_name::#variant_ident ( #(#field_names,)* );
#pyo3_path::PyClassInitializer::from(base_value).add_subclass(#variant_cls)
}
#len_method_impl
#getitem_method_impl
#match_args_method_impl
#(#field_getter_impls)*
}
};
Ok((cls_impl, field_getters, slots))
}
fn gen_complex_enum_variant_class_ident(enum_: &syn::Ident, variant: &syn::Ident) -> syn::Ident {
@ -1149,6 +1390,9 @@ fn complex_enum_variant_new<'a>(
PyClassEnumVariant::Struct(struct_variant) => {
complex_enum_struct_variant_new(cls, struct_variant, ctx)
}
PyClassEnumVariant::Tuple(tuple_variant) => {
complex_enum_tuple_variant_new(cls, tuple_variant, ctx)
}
}
}
@ -1209,6 +1453,61 @@ fn complex_enum_struct_variant_new<'a>(
crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
}
fn complex_enum_tuple_variant_new<'a>(
cls: &'a syn::Ident,
variant: PyClassEnumTupleVariant<'a>,
ctx: &Ctx,
) -> Result<MethodAndSlotDef> {
let Ctx { pyo3_path } = ctx;
let variant_cls: Ident = format_ident!("{}_{}", cls, variant.ident);
let variant_cls_type: syn::Type = parse_quote!(#variant_cls);
let arg_py_ident: syn::Ident = parse_quote!(py);
let arg_py_type: syn::Type = parse_quote!(#pyo3_path::Python<'_>);
let args = {
let mut args = vec![FnArg::Py(PyArg {
name: &arg_py_ident,
ty: &arg_py_type,
})];
for (i, field) in variant.fields.iter().enumerate() {
args.push(FnArg::Regular(RegularArg {
name: std::borrow::Cow::Owned(format_ident!("_{}", i)),
ty: field.ty,
from_py_with: None,
default_value: None,
option_wrapped_type: None,
}));
}
args
};
let signature = if let Some(constructor) = variant.options.constructor {
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(
args,
constructor.into_signature(),
)?
} else {
crate::pyfunction::FunctionSignature::from_arguments(args)?
};
let spec = FnSpec {
tp: crate::method::FnType::FnNew,
name: &format_ident!("__pymethod_constructor__"),
python_name: format_ident!("__new__"),
signature,
convention: crate::method::CallingConvention::TpNew,
text_signature: None,
asyncness: None,
unsafety: None,
deprecations: Deprecations::new(ctx),
};
crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
}
fn complex_enum_variant_field_getter<'a>(
variant_cls_type: &'a syn::Type,
field_name: &'a syn::Ident,

View File

@ -934,7 +934,7 @@ const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_speci
),
TokenGenerator(|_| quote! { async_iter_tag }),
);
const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySsizeT);
pub const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySsizeT);
const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc")
.arguments(&[Ty::Object])
.ret_ty(Ty::Int);
@ -944,7 +944,8 @@ const __INPLACE_CONCAT__: SlotDef =
SlotDef::new("Py_sq_concat", "binaryfunc").arguments(&[Ty::Object]);
const __INPLACE_REPEAT__: SlotDef =
SlotDef::new("Py_sq_repeat", "ssizeargfunc").arguments(&[Ty::PySsizeT]);
const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]);
pub const __GETITEM__: SlotDef =
SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]);
const __POS__: SlotDef = SlotDef::new("Py_nb_positive", "unaryfunc");
const __NEG__: SlotDef = SlotDef::new("Py_nb_negative", "unaryfunc");

View File

@ -8,8 +8,13 @@ use pyo3::{
pub fn enums(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<SimpleEnum>()?;
m.add_class::<ComplexEnum>()?;
m.add_class::<SimpleTupleEnum>()?;
m.add_class::<TupleEnum>()?;
m.add_class::<MixedComplexEnum>()?;
m.add_wrapped(wrap_pyfunction_bound!(do_simple_stuff))?;
m.add_wrapped(wrap_pyfunction_bound!(do_complex_stuff))?;
m.add_wrapped(wrap_pyfunction_bound!(do_tuple_stuff))?;
m.add_wrapped(wrap_pyfunction_bound!(do_mixed_complex_stuff))?;
Ok(())
}
@ -79,3 +84,40 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
},
}
}
#[pyclass]
enum SimpleTupleEnum {
Int(i32),
Str(String),
}
#[pyclass]
pub enum TupleEnum {
#[pyo3(constructor = (_0 = 1, _1 = 1.0, _2 = true))]
FullWithDefault(i32, f64, bool),
Full(i32, f64, bool),
EmptyTuple(),
}
#[pyfunction]
pub fn do_tuple_stuff(thing: &TupleEnum) -> TupleEnum {
match thing {
TupleEnum::FullWithDefault(a, b, c) => TupleEnum::FullWithDefault(*a, *b, *c),
TupleEnum::Full(a, b, c) => TupleEnum::Full(*a, *b, *c),
TupleEnum::EmptyTuple() => TupleEnum::EmptyTuple(),
}
}
#[pyclass]
pub enum MixedComplexEnum {
Nothing {},
Empty(),
}
#[pyfunction]
pub fn do_mixed_complex_stuff(thing: &MixedComplexEnum) -> MixedComplexEnum {
match thing {
MixedComplexEnum::Nothing {} => MixedComplexEnum::Empty(),
MixedComplexEnum::Empty() => MixedComplexEnum::Nothing {},
}
}

View File

@ -137,3 +137,67 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn
assert y == "HELLO"
else:
assert False
def test_tuple_enum_variant_constructors():
tuple_variant = enums.TupleEnum.Full(42, 3.14, False)
assert isinstance(tuple_variant, enums.TupleEnum.Full)
empty_tuple_variant = enums.TupleEnum.EmptyTuple()
assert isinstance(empty_tuple_variant, enums.TupleEnum.EmptyTuple)
@pytest.mark.parametrize(
"variant",
[
enums.TupleEnum.FullWithDefault(),
enums.TupleEnum.Full(42, 3.14, False),
enums.TupleEnum.EmptyTuple(),
],
)
def test_tuple_enum_variant_subclasses(variant: enums.TupleEnum):
assert isinstance(variant, enums.TupleEnum)
def test_tuple_enum_defaults():
variant = enums.TupleEnum.FullWithDefault()
assert variant._0 == 1
assert variant._1 == 1.0
assert variant._2 is True
def test_tuple_enum_field_getters():
tuple_variant = enums.TupleEnum.Full(42, 3.14, False)
assert tuple_variant._0 == 42
assert tuple_variant._1 == 3.14
assert tuple_variant._2 is False
def test_tuple_enum_index_getter():
tuple_variant = enums.TupleEnum.Full(42, 3.14, False)
assert len(tuple_variant) == 3
assert tuple_variant[0] == 42
@pytest.mark.parametrize(
"variant",
[enums.MixedComplexEnum.Nothing()],
)
def test_mixed_complex_enum_pyfunction_instance_nothing(
variant: enums.MixedComplexEnum,
):
assert isinstance(variant, enums.MixedComplexEnum.Nothing)
assert isinstance(
enums.do_mixed_complex_stuff(variant), enums.MixedComplexEnum.Empty
)
@pytest.mark.parametrize(
"variant",
[enums.MixedComplexEnum.Empty()],
)
def test_mixed_complex_enum_pyfunction_instance_empty(variant: enums.MixedComplexEnum):
assert isinstance(variant, enums.MixedComplexEnum.Empty)
assert isinstance(
enums.do_mixed_complex_stuff(variant), enums.MixedComplexEnum.Nothing
)

View File

@ -57,3 +57,102 @@ def test_complex_enum_pyfunction_in_out(variant: enums.ComplexEnum):
assert z is True
case _:
assert False
@pytest.mark.parametrize(
"variant",
[
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
],
)
def test_complex_enum_partial_match(variant: enums.ComplexEnum):
match variant:
case enums.ComplexEnum.MultiFieldStruct(a):
assert a == 42
case _:
assert False
@pytest.mark.parametrize(
"variant",
[
enums.TupleEnum.Full(42, 3.14, True),
enums.TupleEnum.EmptyTuple(),
],
)
def test_tuple_enum_match_statement(variant: enums.TupleEnum):
match variant:
case enums.TupleEnum.Full(_0=x, _1=y, _2=z):
assert x == 42
assert y == 3.14
assert z is True
case enums.TupleEnum.EmptyTuple():
assert True
case _:
print(variant)
assert False
@pytest.mark.parametrize(
"variant",
[
enums.SimpleTupleEnum.Int(42),
enums.SimpleTupleEnum.Str("hello"),
],
)
def test_simple_tuple_enum_match_statement(variant: enums.SimpleTupleEnum):
match variant:
case enums.SimpleTupleEnum.Int(x):
assert x == 42
case enums.SimpleTupleEnum.Str(x):
assert x == "hello"
case _:
assert False
@pytest.mark.parametrize(
"variant",
[
enums.TupleEnum.Full(42, 3.14, True),
],
)
def test_tuple_enum_match_match_args(variant: enums.TupleEnum):
match variant:
case enums.TupleEnum.Full(x, y, z):
assert x == 42
assert y == 3.14
assert z is True
assert True
case _:
assert False
@pytest.mark.parametrize(
"variant",
[
enums.TupleEnum.Full(42, 3.14, True),
],
)
def test_tuple_enum_partial_match(variant: enums.TupleEnum):
match variant:
case enums.TupleEnum.Full(a):
assert a == 42
case _:
assert False
@pytest.mark.parametrize(
"variant",
[
enums.MixedComplexEnum.Nothing(),
enums.MixedComplexEnum.Empty(),
],
)
def test_mixed_complex_enum_match_statement(variant: enums.MixedComplexEnum):
match variant:
case enums.MixedComplexEnum.Nothing():
assert True
case enums.MixedComplexEnum.Empty():
assert True
case _:
assert False

View File

@ -21,12 +21,6 @@ enum NoUnitVariants {
UnitVariant,
}
#[pyclass]
enum NoTupleVariants {
StructVariant { field: i32 },
TupleVariant(i32),
}
#[pyclass]
enum SimpleNoSignature {
#[pyo3(constructor = (a, b))]

View File

@ -17,23 +17,15 @@ error: #[pyclass] can't be used on enums without any variants
| ^^
error: Unit variant `UnitVariant` is not yet supported in a complex enum
= help: change to a struct variant with no fields: `UnitVariant { }`
= help: change to an empty tuple variant instead: `UnitVariant()`
= note: the enum is complex because of non-unit variant `StructVariant`
--> tests/ui/invalid_pyclass_enum.rs:21:5
|
21 | UnitVariant,
| ^^^^^^^^^^^
error: Tuple variant `TupleVariant` is not yet supported in a complex enum
= help: change to a struct variant with named fields: `TupleVariant { /* fields */ }`
= note: the enum is complex because of non-unit variant `StructVariant`
--> tests/ui/invalid_pyclass_enum.rs:27:5
|
27 | TupleVariant(i32),
| ^^^^^^^^^^^^
error: `constructor` can't be used on a simple enum variant
--> tests/ui/invalid_pyclass_enum.rs:32:12
--> tests/ui/invalid_pyclass_enum.rs:26:12
|
32 | #[pyo3(constructor = (a, b))]
26 | #[pyo3(constructor = (a, b))]
| ^^^^^^^^^^^

View File

@ -16,4 +16,20 @@ impl ComplexEnum {
}
}
#[pyclass]
enum TupleEnum {
Int(i32),
Str(String),
}
#[pymethods]
impl TupleEnum {
fn mutate_in_place(&mut self) {
*self = match self {
TupleEnum::Int(int) => TupleEnum::Str(int.to_string()),
TupleEnum::Str(string) => TupleEnum::Int(string.len() as i32),
}
}
}
fn main() {}

View File

@ -22,3 +22,28 @@ note: required by a bound in `PyRefMut`
| pub struct PyRefMut<'p, T: PyClass<Frozen = False>> {
| ^^^^^^^^^^^^^^ required by this bound in `PyRefMut`
= note: this error originates in the attribute macro `pymethods` (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0271]: type mismatch resolving `<TupleEnum as PyClass>::Frozen == False`
--> tests/ui/invalid_pymethod_enum.rs:27:24
|
27 | fn mutate_in_place(&mut self) {
| ^ expected `False`, found `True`
|
note: required by a bound in `extract_pyclass_ref_mut`
--> src/impl_/extract_argument.rs
|
| pub fn extract_pyclass_ref_mut<'a, 'py: 'a, T: PyClass<Frozen = False>>(
| ^^^^^^^^^^^^^^ required by this bound in `extract_pyclass_ref_mut`
error[E0271]: type mismatch resolving `<TupleEnum as PyClass>::Frozen == False`
--> tests/ui/invalid_pymethod_enum.rs:25:1
|
25 | #[pymethods]
| ^^^^^^^^^^^^ expected `False`, found `True`
|
note: required by a bound in `PyRefMut`
--> src/pycell.rs
|
| pub struct PyRefMut<'p, T: PyClass<Frozen = False>> {
| ^^^^^^^^^^^^^^ required by this bound in `PyRefMut`
= note: this error originates in the attribute macro `pymethods` (in Nightly builds, run with -Z macro-backtrace for more info)