From 88f2f6f4d56f2bac1220d1f0d0ac912b8c160c4a Mon Sep 17 00:00:00 2001 From: newcomertv Date: Fri, 17 May 2024 04:59:00 +0200 Subject: [PATCH] feat: support pyclass on tuple enums (#4072) * feat: support pyclass on tuple enums * cargo fmt * changelog * ruff format * rebase with adaptation for FnArg refactor * fix class.md from pr comments * add enum tuple variant getitem implementation * fmt * progress toward getitem and len impl on derive pyclass for complex enum tuple * working getitem and len slots for complex tuple enum pyclass derivation * refactor code generation * address PR concerns - take py from function argument on get_item - make more general slot def implementation - remove unnecessary function arguments - add testcases for uncovered cases including future feature match_args * add tracking issue * fmt * ruff * remove me * support match_args for tuple enum * integrate FnArg now takes Cow * fix empty and single element tuples * use impl_py_slot_def for cimplex tuple enum slots * reverse erroneous doc change * Address latest comments * formatting suggestion * fix : - clippy beta - better compile error (+related doc and test) --------- Co-authored-by: Chris Arderne --- guide/src/class.md | 21 +- newsfragments/4072.added.md | 1 + pyo3-macros-backend/src/pyclass.rs | 333 ++++++++++++++++++++++++-- pyo3-macros-backend/src/pymethod.rs | 5 +- pytests/src/enums.rs | 42 ++++ pytests/tests/test_enums.py | 64 +++++ pytests/tests/test_enums_match.py | 99 ++++++++ tests/ui/invalid_pyclass_enum.rs | 6 - tests/ui/invalid_pyclass_enum.stderr | 14 +- tests/ui/invalid_pymethod_enum.rs | 16 ++ tests/ui/invalid_pymethod_enum.stderr | 25 ++ 11 files changed, 581 insertions(+), 45 deletions(-) create mode 100644 newsfragments/4072.added.md diff --git a/guide/src/class.md b/guide/src/class.md index ce86ec40..57a5cf6d 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -52,15 +52,18 @@ enum HttpResponse { // ... } -// PyO3 also supports enums with non-unit variants +// PyO3 also supports enums with Struct and Tuple variants // These complex enums have sligtly different behavior from the simple enums above // They are meant to work with instance checks and match statement patterns +// The variants can be mixed and matched +// Struct variants have named fields while tuple enums generate generic names for fields in order _0, _1, _2, ... +// Apart from this both types are functionally identical #[pyclass] enum Shape { Circle { radius: f64 }, Rectangle { width: f64, height: f64 }, - RegularPolygon { side_count: u32, radius: f64 }, - Nothing {}, + RegularPolygon(u32, f64), + Nothing(), } ``` @@ -1180,7 +1183,7 @@ enum BadSubclass { An enum is complex if it has any non-unit (struct or tuple) variants. -Currently PyO3 supports only struct variants in a complex enum. Support for unit and tuple variants is planned. +PyO3 supports only struct and tuple variants in a complex enum. Unit variants aren't supported at present (the recommendation is to use an empty tuple enum instead). PyO3 adds a class attribute for each variant, which may be used to construct values and in match patterns. PyO3 also provides getter methods for all fields of each variant. @@ -1190,14 +1193,14 @@ PyO3 adds a class attribute for each variant, which may be used to construct val enum Shape { Circle { radius: f64 }, Rectangle { width: f64, height: f64 }, - RegularPolygon { side_count: u32, radius: f64 }, + RegularPolygon(u32, f64), Nothing { }, } # #[cfg(Py_3_10)] Python::with_gil(|py| { let circle = Shape::Circle { radius: 10.0 }.into_py(py); - let square = Shape::RegularPolygon { side_count: 4, radius: 10.0 }.into_py(py); + let square = Shape::RegularPolygon(4, 10.0).into_py(py); let cls = py.get_type_bound::(); pyo3::py_run!(py, circle square cls, r#" assert isinstance(circle, cls) @@ -1206,8 +1209,8 @@ Python::with_gil(|py| { assert isinstance(square, cls) assert isinstance(square, cls.RegularPolygon) - assert square.side_count == 4 - assert square.radius == 10.0 + assert square[0] == 4 # Gets _0 field + assert square[1] == 10.0 # Gets _1 field def count_vertices(cls, shape): match shape: @@ -1215,7 +1218,7 @@ Python::with_gil(|py| { return 0 case cls.Rectangle(): return 4 - case cls.RegularPolygon(side_count=n): + case cls.RegularPolygon(n): return n case cls.Nothing(): return 0 diff --git a/newsfragments/4072.added.md b/newsfragments/4072.added.md new file mode 100644 index 00000000..23207c84 --- /dev/null +++ b/newsfragments/4072.added.md @@ -0,0 +1 @@ +Support `#[pyclass]` on enums that have tuple variants. \ No newline at end of file diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 4e71a711..47c52c84 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -12,7 +12,7 @@ use crate::pyfunction::ConstructorAttribute; use crate::pyimpl::{gen_py_const, PyClassMethodsType}; use crate::pymethod::{ impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType, - SlotDef, __INT__, __REPR__, __RICHCMP__, + SlotDef, __GETITEM__, __INT__, __LEN__, __REPR__, __RICHCMP__, }; use crate::utils::Ctx; use crate::utils::{self, apply_renaming_rule, PythonDoc}; @@ -504,10 +504,10 @@ impl<'a> PyClassComplexEnum<'a> { let variant = match &variant.fields { Fields::Unit => { bail_spanned!(variant.span() => format!( - "Unit variant `{ident}` is not yet supported in a complex enum\n\ - = help: change to a struct variant with no fields: `{ident} {{ }}`\n\ - = note: the enum is complex because of non-unit variant `{witness}`", - ident=ident, witness=witness)) + "Unit variant `{ident}` is not yet supported in a complex enum\n\ + = help: change to an empty tuple variant instead: `{ident}()`\n\ + = note: the enum is complex because of non-unit variant `{witness}`", + ident=ident, witness=witness)) } Fields::Named(fields) => { let fields = fields @@ -526,12 +526,21 @@ impl<'a> PyClassComplexEnum<'a> { options, }) } - Fields::Unnamed(_) => { - bail_spanned!(variant.span() => format!( - "Tuple variant `{ident}` is not yet supported in a complex enum\n\ - = help: change to a struct variant with named fields: `{ident} {{ /* fields */ }}`\n\ - = note: the enum is complex because of non-unit variant `{witness}`", - ident=ident, witness=witness)) + Fields::Unnamed(types) => { + let fields = types + .unnamed + .iter() + .map(|field| PyClassEnumVariantUnnamedField { + ty: &field.ty, + span: field.span(), + }) + .collect(); + + PyClassEnumVariant::Tuple(PyClassEnumTupleVariant { + ident, + fields, + options, + }) } }; @@ -553,7 +562,7 @@ impl<'a> PyClassComplexEnum<'a> { enum PyClassEnumVariant<'a> { // TODO(mkovaxx): Unit(PyClassEnumUnitVariant<'a>), Struct(PyClassEnumStructVariant<'a>), - // TODO(mkovaxx): Tuple(PyClassEnumTupleVariant<'a>), + Tuple(PyClassEnumTupleVariant<'a>), } trait EnumVariant { @@ -581,12 +590,14 @@ impl<'a> EnumVariant for PyClassEnumVariant<'a> { fn get_ident(&self) -> &syn::Ident { match self { PyClassEnumVariant::Struct(struct_variant) => struct_variant.ident, + PyClassEnumVariant::Tuple(tuple_variant) => tuple_variant.ident, } } fn get_options(&self) -> &EnumVariantPyO3Options { match self { PyClassEnumVariant::Struct(struct_variant) => &struct_variant.options, + PyClassEnumVariant::Tuple(tuple_variant) => &tuple_variant.options, } } } @@ -614,12 +625,23 @@ struct PyClassEnumStructVariant<'a> { options: EnumVariantPyO3Options, } +struct PyClassEnumTupleVariant<'a> { + ident: &'a syn::Ident, + fields: Vec>, + options: EnumVariantPyO3Options, +} + struct PyClassEnumVariantNamedField<'a> { ident: &'a syn::Ident, ty: &'a syn::Type, span: Span, } +struct PyClassEnumVariantUnnamedField<'a> { + ty: &'a syn::Type, + span: Span, +} + /// `#[pyo3()]` options for pyclass enum variants #[derive(Default)] struct EnumVariantPyO3Options { @@ -930,17 +952,19 @@ fn impl_complex_enum( let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx); variant_cls_pytypeinfos.push(variant_cls_pytypeinfo); - let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?; + let (variant_cls_impl, field_getters, mut slots) = + impl_complex_enum_variant_cls(cls, &variant, ctx)?; variant_cls_impls.push(variant_cls_impl); let variant_new = complex_enum_variant_new(cls, variant, ctx)?; + slots.push(variant_new); let pyclass_impl = PyClassImplsBuilder::new( &variant_cls, &variant_args, methods_type, field_getters, - vec![variant_new], + slots, ) .impl_all(ctx)?; @@ -970,19 +994,52 @@ fn impl_complex_enum_variant_cls( enum_name: &syn::Ident, variant: &PyClassEnumVariant<'_>, ctx: &Ctx, -) -> Result<(TokenStream, Vec)> { +) -> Result<(TokenStream, Vec, Vec)> { match variant { PyClassEnumVariant::Struct(struct_variant) => { impl_complex_enum_struct_variant_cls(enum_name, struct_variant, ctx) } + PyClassEnumVariant::Tuple(tuple_variant) => { + impl_complex_enum_tuple_variant_cls(enum_name, tuple_variant, ctx) + } } } +fn impl_complex_enum_variant_match_args( + ctx: &Ctx, + variant_cls_type: &syn::Type, + field_names: &mut Vec, +) -> (MethodAndMethodDef, syn::ImplItemConst) { + let match_args_const_impl: syn::ImplItemConst = { + let args_tp = field_names.iter().map(|_| { + quote! { &'static str } + }); + parse_quote! { + const __match_args__: ( #(#args_tp,)* ) = ( + #(stringify!(#field_names),)* + ); + } + }; + + let spec = ConstSpec { + rust_ident: format_ident!("__match_args__"), + attributes: ConstAttributes { + is_class_attr: true, + name: None, + deprecations: Deprecations::new(ctx), + }, + }; + + let variant_match_args = gen_py_const(variant_cls_type, &spec, ctx); + + (variant_match_args, match_args_const_impl) +} + fn impl_complex_enum_struct_variant_cls( enum_name: &syn::Ident, variant: &PyClassEnumStructVariant<'_>, ctx: &Ctx, -) -> Result<(TokenStream, Vec)> { +) -> Result<(TokenStream, Vec, Vec)> { let Ctx { pyo3_path } = ctx; let variant_ident = &variant.ident; let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident); @@ -1015,6 +1072,11 @@ fn impl_complex_enum_struct_variant_cls( field_getter_impls.push(field_getter_impl); } + let (variant_match_args, match_args_const_impl) = + impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names); + + field_getters.push(variant_match_args); + let cls_impl = quote! { #[doc(hidden)] #[allow(non_snake_case)] @@ -1024,11 +1086,190 @@ fn impl_complex_enum_struct_variant_cls( #pyo3_path::PyClassInitializer::from(base_value).add_subclass(#variant_cls) } + #match_args_const_impl + #(#field_getter_impls)* } }; - Ok((cls_impl, field_getters)) + Ok((cls_impl, field_getters, Vec::new())) +} + +fn impl_complex_enum_tuple_variant_field_getters( + ctx: &Ctx, + variant: &PyClassEnumTupleVariant<'_>, + enum_name: &syn::Ident, + variant_cls_type: &syn::Type, + variant_ident: &&Ident, + field_names: &mut Vec, + fields_types: &mut Vec, +) -> Result<(Vec, Vec)> { + let Ctx { pyo3_path } = ctx; + + let mut field_getters = vec![]; + let mut field_getter_impls = vec![]; + + for (index, field) in variant.fields.iter().enumerate() { + let field_name = format_ident!("_{}", index); + let field_type = field.ty; + + let field_getter = + complex_enum_variant_field_getter(variant_cls_type, &field_name, field.span, ctx)?; + + // Generate the match arms needed to destructure the tuple and access the specific field + let field_access_tokens: Vec<_> = (0..variant.fields.len()) + .map(|i| { + if i == index { + quote! { val } + } else { + quote! { _ } + } + }) + .collect(); + + let field_getter_impl: syn::ImplItemFn = parse_quote! { + fn #field_name(slf: #pyo3_path::PyRef) -> #pyo3_path::PyResult<#field_type> { + match &*slf.into_super() { + #enum_name::#variant_ident ( #(#field_access_tokens), *) => Ok(val.clone()), + _ => unreachable!("Wrong complex enum variant found in variant wrapper PyClass"), + } + } + }; + + field_names.push(field_name); + fields_types.push(field_type.clone()); + field_getters.push(field_getter); + field_getter_impls.push(field_getter_impl); + } + + Ok((field_getters, field_getter_impls)) +} + +fn impl_complex_enum_tuple_variant_len( + ctx: &Ctx, + + variant_cls_type: &syn::Type, + num_fields: usize, +) -> Result<(MethodAndSlotDef, syn::ImplItemFn)> { + let Ctx { pyo3_path } = ctx; + + let mut len_method_impl: syn::ImplItemFn = parse_quote! { + fn __len__(slf: #pyo3_path::PyRef) -> #pyo3_path::PyResult { + Ok(#num_fields) + } + }; + + let variant_len = + generate_default_protocol_slot(variant_cls_type, &mut len_method_impl, &__LEN__, ctx)?; + + Ok((variant_len, len_method_impl)) +} + +fn impl_complex_enum_tuple_variant_getitem( + ctx: &Ctx, + variant_cls: &syn::Ident, + variant_cls_type: &syn::Type, + num_fields: usize, +) -> Result<(MethodAndSlotDef, syn::ImplItemFn)> { + let Ctx { pyo3_path } = ctx; + + let match_arms: Vec<_> = (0..num_fields) + .map(|i| { + let field_access = format_ident!("_{}", i); + quote! { + #i => Ok( + #pyo3_path::IntoPy::into_py( + #variant_cls::#field_access(slf)? + , py) + ) + + } + }) + .collect(); + + let mut get_item_method_impl: syn::ImplItemFn = parse_quote! { + fn __getitem__(slf: #pyo3_path::PyRef, idx: usize) -> #pyo3_path::PyResult< #pyo3_path::PyObject> { + let py = slf.py(); + match idx { + #( #match_arms, )* + _ => Err(pyo3::exceptions::PyIndexError::new_err("tuple index out of range")), + } + } + }; + + let variant_getitem = generate_default_protocol_slot( + variant_cls_type, + &mut get_item_method_impl, + &__GETITEM__, + ctx, + )?; + + Ok((variant_getitem, get_item_method_impl)) +} + +fn impl_complex_enum_tuple_variant_cls( + enum_name: &syn::Ident, + variant: &PyClassEnumTupleVariant<'_>, + ctx: &Ctx, +) -> Result<(TokenStream, Vec, Vec)> { + let Ctx { pyo3_path } = ctx; + let variant_ident = &variant.ident; + let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident); + let variant_cls_type = parse_quote!(#variant_cls); + + let mut slots = vec![]; + + // represents the index of the field + let mut field_names: Vec = vec![]; + let mut field_types: Vec = vec![]; + + let (mut field_getters, field_getter_impls) = impl_complex_enum_tuple_variant_field_getters( + ctx, + variant, + enum_name, + &variant_cls_type, + variant_ident, + &mut field_names, + &mut field_types, + )?; + + let num_fields = variant.fields.len(); + + let (variant_len, len_method_impl) = + impl_complex_enum_tuple_variant_len(ctx, &variant_cls_type, num_fields)?; + + slots.push(variant_len); + + let (variant_getitem, getitem_method_impl) = + impl_complex_enum_tuple_variant_getitem(ctx, &variant_cls, &variant_cls_type, num_fields)?; + + slots.push(variant_getitem); + + let (variant_match_args, match_args_method_impl) = + impl_complex_enum_variant_match_args(ctx, &variant_cls_type, &mut field_names); + + field_getters.push(variant_match_args); + + let cls_impl = quote! { + #[doc(hidden)] + #[allow(non_snake_case)] + impl #variant_cls { + fn __pymethod_constructor__(py: #pyo3_path::Python<'_>, #(#field_names : #field_types,)*) -> #pyo3_path::PyClassInitializer<#variant_cls> { + let base_value = #enum_name::#variant_ident ( #(#field_names,)* ); + #pyo3_path::PyClassInitializer::from(base_value).add_subclass(#variant_cls) + } + + #len_method_impl + + #getitem_method_impl + + #match_args_method_impl + + #(#field_getter_impls)* + } + }; + + Ok((cls_impl, field_getters, slots)) } fn gen_complex_enum_variant_class_ident(enum_: &syn::Ident, variant: &syn::Ident) -> syn::Ident { @@ -1149,6 +1390,9 @@ fn complex_enum_variant_new<'a>( PyClassEnumVariant::Struct(struct_variant) => { complex_enum_struct_variant_new(cls, struct_variant, ctx) } + PyClassEnumVariant::Tuple(tuple_variant) => { + complex_enum_tuple_variant_new(cls, tuple_variant, ctx) + } } } @@ -1209,6 +1453,61 @@ fn complex_enum_struct_variant_new<'a>( crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx) } +fn complex_enum_tuple_variant_new<'a>( + cls: &'a syn::Ident, + variant: PyClassEnumTupleVariant<'a>, + ctx: &Ctx, +) -> Result { + let Ctx { pyo3_path } = ctx; + + let variant_cls: Ident = format_ident!("{}_{}", cls, variant.ident); + let variant_cls_type: syn::Type = parse_quote!(#variant_cls); + + let arg_py_ident: syn::Ident = parse_quote!(py); + let arg_py_type: syn::Type = parse_quote!(#pyo3_path::Python<'_>); + + let args = { + let mut args = vec![FnArg::Py(PyArg { + name: &arg_py_ident, + ty: &arg_py_type, + })]; + + for (i, field) in variant.fields.iter().enumerate() { + args.push(FnArg::Regular(RegularArg { + name: std::borrow::Cow::Owned(format_ident!("_{}", i)), + ty: field.ty, + from_py_with: None, + default_value: None, + option_wrapped_type: None, + })); + } + args + }; + + let signature = if let Some(constructor) = variant.options.constructor { + crate::pyfunction::FunctionSignature::from_arguments_and_attribute( + args, + constructor.into_signature(), + )? + } else { + crate::pyfunction::FunctionSignature::from_arguments(args)? + }; + + let spec = FnSpec { + tp: crate::method::FnType::FnNew, + name: &format_ident!("__pymethod_constructor__"), + python_name: format_ident!("__new__"), + signature, + convention: crate::method::CallingConvention::TpNew, + text_signature: None, + asyncness: None, + unsafety: None, + deprecations: Deprecations::new(ctx), + }; + + crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx) +} + fn complex_enum_variant_field_getter<'a>( variant_cls_type: &'a syn::Type, field_name: &'a syn::Ident, diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 208735f2..f5b11af3 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -934,7 +934,7 @@ const __ANEXT__: SlotDef = SlotDef::new("Py_am_anext", "unaryfunc").return_speci ), TokenGenerator(|_| quote! { async_iter_tag }), ); -const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySsizeT); +pub const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySsizeT); const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc") .arguments(&[Ty::Object]) .ret_ty(Ty::Int); @@ -944,7 +944,8 @@ const __INPLACE_CONCAT__: SlotDef = SlotDef::new("Py_sq_concat", "binaryfunc").arguments(&[Ty::Object]); const __INPLACE_REPEAT__: SlotDef = SlotDef::new("Py_sq_repeat", "ssizeargfunc").arguments(&[Ty::PySsizeT]); -const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]); +pub const __GETITEM__: SlotDef = + SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]); const __POS__: SlotDef = SlotDef::new("Py_nb_positive", "unaryfunc"); const __NEG__: SlotDef = SlotDef::new("Py_nb_negative", "unaryfunc"); diff --git a/pytests/src/enums.rs b/pytests/src/enums.rs index 68a5fc93..964f0d43 100644 --- a/pytests/src/enums.rs +++ b/pytests/src/enums.rs @@ -8,8 +8,13 @@ use pyo3::{ pub fn enums(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pyfunction_bound!(do_simple_stuff))?; m.add_wrapped(wrap_pyfunction_bound!(do_complex_stuff))?; + m.add_wrapped(wrap_pyfunction_bound!(do_tuple_stuff))?; + m.add_wrapped(wrap_pyfunction_bound!(do_mixed_complex_stuff))?; Ok(()) } @@ -79,3 +84,40 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum { }, } } + +#[pyclass] +enum SimpleTupleEnum { + Int(i32), + Str(String), +} + +#[pyclass] +pub enum TupleEnum { + #[pyo3(constructor = (_0 = 1, _1 = 1.0, _2 = true))] + FullWithDefault(i32, f64, bool), + Full(i32, f64, bool), + EmptyTuple(), +} + +#[pyfunction] +pub fn do_tuple_stuff(thing: &TupleEnum) -> TupleEnum { + match thing { + TupleEnum::FullWithDefault(a, b, c) => TupleEnum::FullWithDefault(*a, *b, *c), + TupleEnum::Full(a, b, c) => TupleEnum::Full(*a, *b, *c), + TupleEnum::EmptyTuple() => TupleEnum::EmptyTuple(), + } +} + +#[pyclass] +pub enum MixedComplexEnum { + Nothing {}, + Empty(), +} + +#[pyfunction] +pub fn do_mixed_complex_stuff(thing: &MixedComplexEnum) -> MixedComplexEnum { + match thing { + MixedComplexEnum::Nothing {} => MixedComplexEnum::Empty(), + MixedComplexEnum::Empty() => MixedComplexEnum::Nothing {}, + } +} diff --git a/pytests/tests/test_enums.py b/pytests/tests/test_enums.py index cd1d7aed..cd4f7e12 100644 --- a/pytests/tests/test_enums.py +++ b/pytests/tests/test_enums.py @@ -137,3 +137,67 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn assert y == "HELLO" else: assert False + + +def test_tuple_enum_variant_constructors(): + tuple_variant = enums.TupleEnum.Full(42, 3.14, False) + assert isinstance(tuple_variant, enums.TupleEnum.Full) + + empty_tuple_variant = enums.TupleEnum.EmptyTuple() + assert isinstance(empty_tuple_variant, enums.TupleEnum.EmptyTuple) + + +@pytest.mark.parametrize( + "variant", + [ + enums.TupleEnum.FullWithDefault(), + enums.TupleEnum.Full(42, 3.14, False), + enums.TupleEnum.EmptyTuple(), + ], +) +def test_tuple_enum_variant_subclasses(variant: enums.TupleEnum): + assert isinstance(variant, enums.TupleEnum) + + +def test_tuple_enum_defaults(): + variant = enums.TupleEnum.FullWithDefault() + assert variant._0 == 1 + assert variant._1 == 1.0 + assert variant._2 is True + + +def test_tuple_enum_field_getters(): + tuple_variant = enums.TupleEnum.Full(42, 3.14, False) + assert tuple_variant._0 == 42 + assert tuple_variant._1 == 3.14 + assert tuple_variant._2 is False + + +def test_tuple_enum_index_getter(): + tuple_variant = enums.TupleEnum.Full(42, 3.14, False) + assert len(tuple_variant) == 3 + assert tuple_variant[0] == 42 + + +@pytest.mark.parametrize( + "variant", + [enums.MixedComplexEnum.Nothing()], +) +def test_mixed_complex_enum_pyfunction_instance_nothing( + variant: enums.MixedComplexEnum, +): + assert isinstance(variant, enums.MixedComplexEnum.Nothing) + assert isinstance( + enums.do_mixed_complex_stuff(variant), enums.MixedComplexEnum.Empty + ) + + +@pytest.mark.parametrize( + "variant", + [enums.MixedComplexEnum.Empty()], +) +def test_mixed_complex_enum_pyfunction_instance_empty(variant: enums.MixedComplexEnum): + assert isinstance(variant, enums.MixedComplexEnum.Empty) + assert isinstance( + enums.do_mixed_complex_stuff(variant), enums.MixedComplexEnum.Nothing + ) diff --git a/pytests/tests/test_enums_match.py b/pytests/tests/test_enums_match.py index 4d55bbbe..6c4b5f6a 100644 --- a/pytests/tests/test_enums_match.py +++ b/pytests/tests/test_enums_match.py @@ -57,3 +57,102 @@ def test_complex_enum_pyfunction_in_out(variant: enums.ComplexEnum): assert z is True case _: assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.ComplexEnum.MultiFieldStruct(42, 3.14, True), + ], +) +def test_complex_enum_partial_match(variant: enums.ComplexEnum): + match variant: + case enums.ComplexEnum.MultiFieldStruct(a): + assert a == 42 + case _: + assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.TupleEnum.Full(42, 3.14, True), + enums.TupleEnum.EmptyTuple(), + ], +) +def test_tuple_enum_match_statement(variant: enums.TupleEnum): + match variant: + case enums.TupleEnum.Full(_0=x, _1=y, _2=z): + assert x == 42 + assert y == 3.14 + assert z is True + case enums.TupleEnum.EmptyTuple(): + assert True + case _: + print(variant) + assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.SimpleTupleEnum.Int(42), + enums.SimpleTupleEnum.Str("hello"), + ], +) +def test_simple_tuple_enum_match_statement(variant: enums.SimpleTupleEnum): + match variant: + case enums.SimpleTupleEnum.Int(x): + assert x == 42 + case enums.SimpleTupleEnum.Str(x): + assert x == "hello" + case _: + assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.TupleEnum.Full(42, 3.14, True), + ], +) +def test_tuple_enum_match_match_args(variant: enums.TupleEnum): + match variant: + case enums.TupleEnum.Full(x, y, z): + assert x == 42 + assert y == 3.14 + assert z is True + assert True + case _: + assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.TupleEnum.Full(42, 3.14, True), + ], +) +def test_tuple_enum_partial_match(variant: enums.TupleEnum): + match variant: + case enums.TupleEnum.Full(a): + assert a == 42 + case _: + assert False + + +@pytest.mark.parametrize( + "variant", + [ + enums.MixedComplexEnum.Nothing(), + enums.MixedComplexEnum.Empty(), + ], +) +def test_mixed_complex_enum_match_statement(variant: enums.MixedComplexEnum): + match variant: + case enums.MixedComplexEnum.Nothing(): + assert True + case enums.MixedComplexEnum.Empty(): + assert True + case _: + assert False diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs index 116b8968..e98010fe 100644 --- a/tests/ui/invalid_pyclass_enum.rs +++ b/tests/ui/invalid_pyclass_enum.rs @@ -21,12 +21,6 @@ enum NoUnitVariants { UnitVariant, } -#[pyclass] -enum NoTupleVariants { - StructVariant { field: i32 }, - TupleVariant(i32), -} - #[pyclass] enum SimpleNoSignature { #[pyo3(constructor = (a, b))] diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr index e9ba9806..7e3b6ffa 100644 --- a/tests/ui/invalid_pyclass_enum.stderr +++ b/tests/ui/invalid_pyclass_enum.stderr @@ -17,23 +17,15 @@ error: #[pyclass] can't be used on enums without any variants | ^^ error: Unit variant `UnitVariant` is not yet supported in a complex enum - = help: change to a struct variant with no fields: `UnitVariant { }` + = help: change to an empty tuple variant instead: `UnitVariant()` = note: the enum is complex because of non-unit variant `StructVariant` --> tests/ui/invalid_pyclass_enum.rs:21:5 | 21 | UnitVariant, | ^^^^^^^^^^^ -error: Tuple variant `TupleVariant` is not yet supported in a complex enum - = help: change to a struct variant with named fields: `TupleVariant { /* fields */ }` - = note: the enum is complex because of non-unit variant `StructVariant` - --> tests/ui/invalid_pyclass_enum.rs:27:5 - | -27 | TupleVariant(i32), - | ^^^^^^^^^^^^ - error: `constructor` can't be used on a simple enum variant - --> tests/ui/invalid_pyclass_enum.rs:32:12 + --> tests/ui/invalid_pyclass_enum.rs:26:12 | -32 | #[pyo3(constructor = (a, b))] +26 | #[pyo3(constructor = (a, b))] | ^^^^^^^^^^^ diff --git a/tests/ui/invalid_pymethod_enum.rs b/tests/ui/invalid_pymethod_enum.rs index 9b596e08..5c41d19d 100644 --- a/tests/ui/invalid_pymethod_enum.rs +++ b/tests/ui/invalid_pymethod_enum.rs @@ -16,4 +16,20 @@ impl ComplexEnum { } } +#[pyclass] +enum TupleEnum { + Int(i32), + Str(String), +} + +#[pymethods] +impl TupleEnum { + fn mutate_in_place(&mut self) { + *self = match self { + TupleEnum::Int(int) => TupleEnum::Str(int.to_string()), + TupleEnum::Str(string) => TupleEnum::Int(string.len() as i32), + } + } +} + fn main() {} diff --git a/tests/ui/invalid_pymethod_enum.stderr b/tests/ui/invalid_pymethod_enum.stderr index 6cf6fe89..bc377d2a 100644 --- a/tests/ui/invalid_pymethod_enum.stderr +++ b/tests/ui/invalid_pymethod_enum.stderr @@ -22,3 +22,28 @@ note: required by a bound in `PyRefMut` | pub struct PyRefMut<'p, T: PyClass> { | ^^^^^^^^^^^^^^ required by this bound in `PyRefMut` = note: this error originates in the attribute macro `pymethods` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0271]: type mismatch resolving `::Frozen == False` + --> tests/ui/invalid_pymethod_enum.rs:27:24 + | +27 | fn mutate_in_place(&mut self) { + | ^ expected `False`, found `True` + | +note: required by a bound in `extract_pyclass_ref_mut` + --> src/impl_/extract_argument.rs + | + | pub fn extract_pyclass_ref_mut<'a, 'py: 'a, T: PyClass>( + | ^^^^^^^^^^^^^^ required by this bound in `extract_pyclass_ref_mut` + +error[E0271]: type mismatch resolving `::Frozen == False` + --> tests/ui/invalid_pymethod_enum.rs:25:1 + | +25 | #[pymethods] + | ^^^^^^^^^^^^ expected `False`, found `True` + | +note: required by a bound in `PyRefMut` + --> src/pycell.rs + | + | pub struct PyRefMut<'p, T: PyClass> { + | ^^^^^^^^^^^^^^ required by this bound in `PyRefMut` + = note: this error originates in the attribute macro `pymethods` (in Nightly builds, run with -Z macro-backtrace for more info)