feat: support pyclass on complex enums

This commit is contained in:
Mate Kovacs 2024-01-16 00:57:42 +09:00
parent 48e74b7829
commit 3ed5ddb0ec
14 changed files with 1031 additions and 91 deletions

View File

@ -2,7 +2,7 @@
PyO3 exposes a group of attributes powered by Rust's proc macro system for defining Python classes as Rust structs.
The main attribute is `#[pyclass]`, which is placed upon a Rust `struct` or a fieldless `enum` (a.k.a. C-like enum) to generate a Python type for it. They will usually also have *one* `#[pymethods]`-annotated `impl` block for the struct, which is used to define Python methods and constants for the generated Python type. (If the [`multiple-pymethods`] feature is enabled, each `#[pyclass]` is allowed to have multiple `#[pymethods]` blocks.) `#[pymethods]` may also have implementations for Python magic methods such as `__str__`.
The main attribute is `#[pyclass]`, which is placed upon a Rust `struct` or `enum` to generate a Python type for it. They will usually also have *one* `#[pymethods]`-annotated `impl` block for the struct, which is used to define Python methods and constants for the generated Python type. (If the [`multiple-pymethods`] feature is enabled, each `#[pyclass]` is allowed to have multiple `#[pymethods]` blocks.) `#[pymethods]` may also have implementations for Python magic methods such as `__str__`.
This chapter will discuss the functionality and configuration these attributes offer. Below is a list of links to the relevant section of this chapter for each:
@ -21,13 +21,13 @@ This chapter will discuss the functionality and configuration these attributes o
## Defining a new class
To define a custom Python class, add the `#[pyclass]` attribute to a Rust struct or a fieldless enum.
To define a custom Python class, add the `#[pyclass]` attribute to a Rust struct or enum.
```rust
# #![allow(dead_code)]
use pyo3::prelude::*;
#[pyclass]
struct Integer {
struct MyClass {
inner: i32,
}
@ -35,7 +35,15 @@ struct Integer {
#[pyclass]
struct Number(i32);
// PyO3 supports custom discriminants in enums
// PyO3 supports unit-only enums (which contain only unit variants)
// These simple enums behave similarly to Python's enumerations (enum.Enum)
#[pyclass]
enum MyEnum {
Variant,
OtherVariant = 30, // PyO3 supports custom discriminants.
}
// PyO3 supports custom discriminants in unit-only enums
#[pyclass]
enum HttpResponse {
Ok = 200,
@ -44,14 +52,19 @@ enum HttpResponse {
// ...
}
// PyO3 also supports enums with non-unit variants
// These complex enums have sligtly different behavior from the simple enums above
// They are meant to work with instance checks and match statement patterns
#[pyclass]
enum MyEnum {
Variant,
OtherVariant = 30, // PyO3 supports custom discriminants.
enum Shape {
Circle { radius: f64 },
Rectangle { width: f64, height: f64 },
RegularPolygon { side_count: u32, radius: f64 },
Nothing { },
}
```
The above example generates implementations for [`PyTypeInfo`] and [`PyClass`] for `MyClass` and `MyEnum`. To see these generated implementations, refer to the [implementation details](#implementation-details) at the end of this chapter.
The above example generates implementations for [`PyTypeInfo`] and [`PyClass`] for `MyClass`, `Number`, `MyEnum`, `HttpResponse`, and `Shape`. To see these generated implementations, refer to the [implementation details](#implementation-details) at the end of this chapter.
### Restrictions
@ -964,7 +977,13 @@ Note that `text_signature` on `#[new]` is not compatible with compilation in
## #[pyclass] enums
Currently PyO3 only supports fieldless enums. PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`. PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`:
Enum support in PyO3 comes in two flavors, depending on what kind of variants the enum has: simple and complex.
### Simple enums
A simple enum (a.k.a. C-like enum) has only unit variants.
PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`. PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`:
```rust
# use pyo3::prelude::*;
@ -986,7 +1005,7 @@ Python::with_gil(|py| {
})
```
You can also convert your enums into `int`:
You can also convert your simple enums into `int`:
```rust
# use pyo3::prelude::*;
@ -1094,6 +1113,90 @@ enum BadSubclass {
`#[pyclass]` enums are currently not interoperable with `IntEnum` in Python.
### Complex enums
An enum is complex if it has any non-unit (struct or tuple) variants.
Currently PyO3 supports only struct variants in a complex enum. Support for unit and tuple variants is planned.
PyO3 adds a class attribute for each variant, which may be used to construct values and in match patterns. PyO3 also provides getter methods for all fields of each variant.
```rust
# use pyo3::prelude::*;
#[pyclass]
enum Shape {
Circle { radius: f64 },
Rectangle { width: f64, height: f64 },
RegularPolygon { side_count: u32, radius: f64 },
Nothing { },
}
Python::with_gil(|py| {
let def_count_vertices = if py.version_info() >= (3, 10) { r#"
def count_vertices(cls, shape):
match shape:
case cls.Circle():
return 0
case cls.Rectangle():
return 4
case cls.RegularPolygon(side_count=n):
return n
case cls.Nothing():
return 0
"# } else { r#"
def count_vertices(cls, shape):
if isinstance(shape, cls.Circle):
return 0
elif isinstance(shape, cls.Rectangle):
return 4
elif isinstance(shape, cls.RegularPolygon):
n = shape.side_count
return n
elif isinstance(shape, cls.Nothing):
return 0
"# };
let circle = Shape::Circle { radius: 10.0 }.into_py(py);
let square = Shape::RegularPolygon { side_count: 4, radius: 10.0 }.into_py(py);
let cls = py.get_type::<Shape>();
pyo3::py_run!(py, circle square cls, &format!(r#"
assert isinstance(circle, cls)
assert isinstance(circle, cls.Circle)
assert circle.radius == 10.0
assert isinstance(square, cls)
assert isinstance(square, cls.RegularPolygon)
assert square.side_count == 4
assert square.radius == 10.0
{}
assert count_vertices(cls, circle) == 0
assert count_vertices(cls, square) == 4
"#, def_count_vertices))
})
```
WARNING: `Py::new` and `.into_py` are currently inconsistent. Note how the constructed value is _not_ an instance of the specific variant. For this reason, constructing values is only recommended using `.into_py`.
```rust
# use pyo3::prelude::*;
#[pyclass]
enum MyEnum {
Variant { i: i32 },
}
Python::with_gil(|py| {
let x = Py::new(py, MyEnum::Variant { i: 42 }).unwrap();
let cls = py.get_type::<MyEnum>();
pyo3::py_run!(py, x cls, r#"
assert isinstance(x, cls)
assert not isinstance(x, cls.Variant)
"#)
})
```
## Implementation details
The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block.

View File

@ -0,0 +1 @@
Support `#[pyclass]` on enums that have non-unit variants.

View File

@ -7,7 +7,7 @@ use crate::attributes::{
};
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::method::FnSpec;
use crate::method::{FnArg, FnSpec};
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pymethod::{
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
@ -16,7 +16,7 @@ use crate::pymethod::{
use crate::utils::{self, apply_renaming_rule, get_pyo3_crate, PythonDoc};
use crate::PyFunctionOptions;
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use quote::{format_ident, quote};
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
@ -30,6 +30,7 @@ pub enum PyClassKind {
}
/// The parsed arguments of the pyclass macro
#[derive(Clone)]
pub struct PyClassArgs {
pub class_kind: PyClassKind,
pub options: PyClassPyO3Options,
@ -52,7 +53,7 @@ impl PyClassArgs {
}
}
#[derive(Default)]
#[derive(Clone, Default)]
pub struct PyClassPyO3Options {
pub krate: Option<CrateAttribute>,
pub dict: Option<kw::dict>,
@ -128,7 +129,7 @@ impl Parse for PyClassPyO3Option {
}
}
impl PyClassPyO3Options {
impl Parse for PyClassPyO3Options {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let mut options: PyClassPyO3Options = Default::default();
@ -138,7 +139,9 @@ impl PyClassPyO3Options {
Ok(options)
}
}
impl PyClassPyO3Options {
pub fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> syn::Result<()> {
take_pyo3_options(attrs)?
.into_iter()
@ -369,73 +372,24 @@ fn impl_class(
})
}
struct PyClassEnumVariant<'a> {
ident: &'a syn::Ident,
options: EnumVariantPyO3Options,
}
impl<'a> PyClassEnumVariant<'a> {
fn python_name(&self, args: &PyClassArgs) -> Cow<'_, syn::Ident> {
self.options
.name
.as_ref()
.map(|name_attr| Cow::Borrowed(&name_attr.value.0))
.unwrap_or_else(|| {
let name = self.ident.unraw();
if let Some(attr) = &args.options.rename_all {
let new_name = apply_renaming_rule(attr.value.rule, &name.to_string());
Cow::Owned(Ident::new(&new_name, Span::call_site()))
} else {
Cow::Owned(name)
}
})
}
}
struct PyClassEnum<'a> {
ident: &'a syn::Ident,
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
// This matters when the underlying representation may not fit in `isize`.
repr_type: syn::Ident,
variants: Vec<PyClassEnumVariant<'a>>,
enum PyClassEnum<'a> {
Simple(PyClassSimpleEnum<'a>),
Complex(PyClassComplexEnum<'a>),
}
impl<'a> PyClassEnum<'a> {
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
fn is_numeric_type(t: &syn::Ident) -> bool {
[
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
"isize",
]
.iter()
.any(|&s| t == s)
}
let ident = &enum_.ident;
// According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
// "Under the default representation, the specified discriminant is interpreted as an isize
// value", so `isize` should be enough by default.
let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site());
if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path().is_ident("repr")) {
let args =
attr.parse_args_with(Punctuated::<TokenStream, Token![!]>::parse_terminated)?;
if let Some(ident) = args
.into_iter()
.filter_map(|ts| syn::parse2::<syn::Ident>(ts).ok())
.find(is_numeric_type)
{
repr_type = ident;
}
}
let variants = enum_
let has_only_unit_variants = enum_
.variants
.iter_mut()
.map(extract_variant_data)
.collect::<syn::Result<_>>()?;
Ok(Self {
ident,
repr_type,
variants,
.iter()
.all(|variant| matches!(variant.fields, syn::Fields::Unit));
Ok(if has_only_unit_variants {
let simple_enum = PyClassSimpleEnum::new(enum_)?;
Self::Simple(simple_enum)
} else {
let complex_enum = PyClassComplexEnum::new(enum_)?;
Self::Complex(complex_enum)
})
}
}
@ -460,6 +414,208 @@ pub fn build_py_enum(
impl_enum(enum_, &args, doc, method_type)
}
struct PyClassSimpleEnum<'a> {
ident: &'a syn::Ident,
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
// This matters when the underlying representation may not fit in `isize`.
repr_type: syn::Ident,
variants: Vec<PyClassEnumUnitVariant<'a>>,
}
impl<'a> PyClassSimpleEnum<'a> {
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
fn is_numeric_type(t: &syn::Ident) -> bool {
[
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
"isize",
]
.iter()
.any(|&s| t == s)
}
fn extract_unit_variant_data(
variant: &mut syn::Variant,
) -> syn::Result<PyClassEnumUnitVariant<'_>> {
use syn::Fields;
let ident = match &variant.fields {
Fields::Unit => &variant.ident,
_ => bail_spanned!(variant.span() => "Must be a unit variant."),
};
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
Ok(PyClassEnumUnitVariant { ident, options })
}
let ident = &enum_.ident;
// According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
// "Under the default representation, the specified discriminant is interpreted as an isize
// value", so `isize` should be enough by default.
let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site());
if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path().is_ident("repr")) {
let args =
attr.parse_args_with(Punctuated::<TokenStream, Token![!]>::parse_terminated)?;
if let Some(ident) = args
.into_iter()
.filter_map(|ts| syn::parse2::<syn::Ident>(ts).ok())
.find(is_numeric_type)
{
repr_type = ident;
}
}
let variants: Vec<_> = enum_
.variants
.iter_mut()
.map(extract_unit_variant_data)
.collect::<syn::Result<_>>()?;
Ok(Self {
ident,
repr_type,
variants,
})
}
}
struct PyClassComplexEnum<'a> {
ident: &'a syn::Ident,
variants: Vec<PyClassEnumVariant<'a>>,
}
impl<'a> PyClassComplexEnum<'a> {
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
let witness = enum_
.variants
.iter()
.find(|variant| !matches!(variant.fields, syn::Fields::Unit))
.expect("complex enum has a non-unit variant")
.ident
.to_owned();
let extract_variant_data =
|variant: &'a mut syn::Variant| -> syn::Result<PyClassEnumVariant<'a>> {
use syn::Fields;
let ident = &variant.ident;
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
let variant = match &variant.fields {
Fields::Unit => {
bail_spanned!(variant.span() => format!(
"Unit variant `{ident}` is not yet supported in a complex enum\n\
= help: change to a struct variant with no fields: `{ident} {{ }}`\n\
= note: the enum is complex because of non-unit variant `{witness}`",
ident=ident, witness=witness))
}
Fields::Named(fields) => {
let fields = fields
.named
.iter()
.map(|field| PyClassEnumVariantNamedField {
ident: field.ident.as_ref().expect("named field has an identifier"),
ty: &field.ty,
span: field.span(),
})
.collect();
PyClassEnumVariant::Struct(PyClassEnumStructVariant {
ident,
fields,
options,
})
}
Fields::Unnamed(_) => {
bail_spanned!(variant.span() => format!(
"Tuple variant `{ident}` is not yet supported in a complex enum\n\
= help: change to a struct variant with named fields: `{ident} {{ /* fields */ }}`\n\
= note: the enum is complex because of non-unit variant `{witness}`",
ident=ident, witness=witness))
}
};
Ok(variant)
};
let ident = &enum_.ident;
let variants: Vec<_> = enum_
.variants
.iter_mut()
.map(extract_variant_data)
.collect::<syn::Result<_>>()?;
Ok(Self { ident, variants })
}
}
enum PyClassEnumVariant<'a> {
// TODO(mkovaxx): Unit(PyClassEnumUnitVariant<'a>),
Struct(PyClassEnumStructVariant<'a>),
// TODO(mkovaxx): Tuple(PyClassEnumTupleVariant<'a>),
}
trait EnumVariant {
fn get_ident(&self) -> &syn::Ident;
fn get_options(&self) -> &EnumVariantPyO3Options;
fn get_python_name(&self, args: &PyClassArgs) -> Cow<'_, syn::Ident> {
self.get_options()
.name
.as_ref()
.map(|name_attr| Cow::Borrowed(&name_attr.value.0))
.unwrap_or_else(|| {
let name = self.get_ident().unraw();
if let Some(attr) = &args.options.rename_all {
let new_name = apply_renaming_rule(attr.value.rule, &name.to_string());
Cow::Owned(Ident::new(&new_name, Span::call_site()))
} else {
Cow::Owned(name)
}
})
}
}
impl<'a> EnumVariant for PyClassEnumVariant<'a> {
fn get_ident(&self) -> &syn::Ident {
match self {
PyClassEnumVariant::Struct(struct_variant) => struct_variant.ident,
}
}
fn get_options(&self) -> &EnumVariantPyO3Options {
match self {
PyClassEnumVariant::Struct(struct_variant) => &struct_variant.options,
}
}
}
/// A unit variant has no fields
struct PyClassEnumUnitVariant<'a> {
ident: &'a syn::Ident,
options: EnumVariantPyO3Options,
}
impl<'a> EnumVariant for PyClassEnumUnitVariant<'a> {
fn get_ident(&self) -> &syn::Ident {
self.ident
}
fn get_options(&self) -> &EnumVariantPyO3Options {
&self.options
}
}
/// A struct variant has named fields
struct PyClassEnumStructVariant<'a> {
ident: &'a syn::Ident,
fields: Vec<PyClassEnumVariantNamedField<'a>>,
options: EnumVariantPyO3Options,
}
struct PyClassEnumVariantNamedField<'a> {
ident: &'a syn::Ident,
ty: &'a syn::Type,
span: Span,
}
/// `#[pyo3()]` options for pyclass enum variants
struct EnumVariantPyO3Options {
name: Option<NameAttribute>,
@ -505,11 +661,25 @@ fn impl_enum(
args: &PyClassArgs,
doc: PythonDoc,
methods_type: PyClassMethodsType,
) -> Result<TokenStream> {
match enum_ {
PyClassEnum::Simple(simple_enum) => impl_simple_enum(simple_enum, args, doc, methods_type),
PyClassEnum::Complex(complex_enum) => {
impl_complex_enum(complex_enum, args, doc, methods_type)
}
}
}
fn impl_simple_enum(
simple_enum: PyClassSimpleEnum<'_>,
args: &PyClassArgs,
doc: PythonDoc,
methods_type: PyClassMethodsType,
) -> Result<TokenStream> {
let krate = get_pyo3_crate(&args.options.krate);
let cls = enum_.ident;
let cls = simple_enum.ident;
let ty: syn::Type = syn::parse_quote!(#cls);
let variants = enum_.variants;
let variants = simple_enum.variants;
let pytypeinfo = impl_pytypeinfo(cls, args, None);
let (default_repr, default_repr_slot) = {
@ -519,7 +689,7 @@ fn impl_enum(
let repr = format!(
"{}.{}",
get_class_python_name(cls, args),
variant.python_name(args),
variant.get_python_name(args),
);
quote! { #cls::#variant_name => #repr, }
});
@ -534,7 +704,7 @@ fn impl_enum(
(repr_impl, repr_slot)
};
let repr_type = &enum_.repr_type;
let repr_type = &simple_enum.repr_type;
let (default_int, default_int_slot) = {
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
@ -601,7 +771,10 @@ fn impl_enum(
cls,
args,
methods_type,
enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name(args)))),
simple_enum_default_methods(
cls,
variants.iter().map(|v| (v.ident, v.get_python_name(args))),
),
default_slots,
)
.doc(doc)
@ -626,6 +799,214 @@ fn impl_enum(
})
}
fn impl_complex_enum(
complex_enum: PyClassComplexEnum<'_>,
args: &PyClassArgs,
doc: PythonDoc,
methods_type: PyClassMethodsType,
) -> Result<TokenStream> {
// Need to rig the enum PyClass options
let args = {
let mut rigged_args = args.clone();
// Needs to be frozen to disallow `&mut self` methods, which could break a runtime invariant
rigged_args.options.frozen = parse_quote!(frozen);
// Needs to be subclassable by the variant PyClasses
rigged_args.options.subclass = parse_quote!(subclass);
rigged_args
};
let krate = get_pyo3_crate(&args.options.krate);
let cls = complex_enum.ident;
let variants = complex_enum.variants;
let pytypeinfo = impl_pytypeinfo(cls, &args, None);
let default_slots = vec![];
let impl_builder = PyClassImplsBuilder::new(
cls,
&args,
methods_type,
complex_enum_default_methods(
cls,
variants
.iter()
.map(|v| (v.get_ident(), v.get_python_name(&args))),
),
default_slots,
)
.doc(doc);
// Need to customize the into_py impl so that it returns the variant PyClass
let enum_into_py_impl = {
let match_arms: Vec<TokenStream> = variants
.iter()
.map(|variant| {
let variant_ident = variant.get_ident();
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
quote! {
#cls::#variant_ident { .. } => {
let pyclass_init = _pyo3::PyClassInitializer::from(self).add_subclass(#variant_cls);
let variant_value = _pyo3::Py::new(py, pyclass_init).unwrap();
_pyo3::IntoPy::into_py(variant_value, py)
}
}
})
.collect();
quote! {
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
match self {
#(#match_arms)*
}
}
}
}
};
let pyclass_impls: TokenStream = vec![
impl_builder.impl_pyclass(),
impl_builder.impl_extractext(),
enum_into_py_impl,
impl_builder.impl_pyclassimpl()?,
impl_builder.impl_freelist(),
]
.into_iter()
.collect();
let mut variant_cls_zsts = vec![];
let mut variant_cls_pytypeinfos = vec![];
let mut variant_cls_pyclass_impls = vec![];
let mut variant_cls_impls = vec![];
for variant in &variants {
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
let variant_cls_zst = quote! {
#[doc(hidden)]
#[allow(non_camel_case_types)]
struct #variant_cls;
};
variant_cls_zsts.push(variant_cls_zst);
let variant_args = PyClassArgs {
class_kind: PyClassKind::Struct,
// TODO(mkovaxx): propagate variant.options
options: parse_quote!(extends = #cls, frozen),
};
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None);
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
let variant_new = complex_enum_variant_new(cls, variant)?;
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant)?;
variant_cls_impls.push(variant_cls_impl);
let pyclass_impl = PyClassImplsBuilder::new(
&variant_cls,
&variant_args,
methods_type,
field_getters,
vec![variant_new],
)
.impl_all()?;
variant_cls_pyclass_impls.push(pyclass_impl);
}
Ok(quote! {
const _: () = {
use #krate as _pyo3;
#pytypeinfo
#pyclass_impls
#[doc(hidden)]
#[allow(non_snake_case)]
impl #cls {}
#(#variant_cls_zsts)*
#(#variant_cls_pytypeinfos)*
#(#variant_cls_pyclass_impls)*
#(#variant_cls_impls)*
};
})
}
fn impl_complex_enum_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumVariant<'_>,
) -> Result<(TokenStream, Vec<MethodAndMethodDef>)> {
match variant {
PyClassEnumVariant::Struct(struct_variant) => {
impl_complex_enum_struct_variant_cls(enum_name, struct_variant)
}
}
}
fn impl_complex_enum_struct_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumStructVariant<'_>,
) -> Result<(TokenStream, Vec<MethodAndMethodDef>)> {
let variant_ident = &variant.ident;
let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
let variant_cls_type = parse_quote!(#variant_cls);
let mut field_names: Vec<Ident> = vec![];
let mut fields_with_types: Vec<TokenStream> = vec![];
let mut field_getters = vec![];
let mut field_getter_impls: Vec<TokenStream> = vec![];
for field in &variant.fields {
let field_name = field.ident;
let field_type = field.ty;
let field_with_type = quote! { #field_name: #field_type };
let field_getter = complex_enum_variant_field_getter(
&variant_cls_type,
field_name,
field_type,
field.span,
)?;
let field_getter_impl = quote! {
fn #field_name(slf: _pyo3::PyRef<Self>) -> _pyo3::PyResult<#field_type> {
match &*slf.into_super() {
#enum_name::#variant_ident { #field_name, .. } => Ok(#field_name.clone()),
_ => unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
}
}
};
field_names.push(field_name.clone());
fields_with_types.push(field_with_type);
field_getters.push(field_getter);
field_getter_impls.push(field_getter_impl);
}
let cls_impl = quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
impl #variant_cls {
fn __pymethod_constructor__(py: _pyo3::Python<'_>, #(#fields_with_types,)*) -> _pyo3::PyClassInitializer<#variant_cls> {
let base_value = #enum_name::#variant_ident { #(#field_names,)* };
_pyo3::PyClassInitializer::from(base_value).add_subclass(#variant_cls)
}
#(#field_getter_impls)*
}
};
Ok((cls_impl, field_getters))
}
fn gen_complex_enum_variant_class_ident(enum_: &syn::Ident, variant: &syn::Ident) -> syn::Ident {
format_ident!("{}_{}", enum_, variant)
}
fn generate_default_protocol_slot(
cls: &syn::Type,
method: &mut syn::ImplItemFn,
@ -645,7 +1026,7 @@ fn generate_default_protocol_slot(
)
}
fn enum_default_methods<'a>(
fn simple_enum_default_methods<'a>(
cls: &'a syn::Ident,
unit_variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
) -> Vec<MethodAndMethodDef> {
@ -667,14 +1048,167 @@ fn enum_default_methods<'a>(
.collect()
}
fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result<PyClassEnumVariant<'_>> {
use syn::Fields;
let ident = match variant.fields {
Fields::Unit => &variant.ident,
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
fn complex_enum_default_methods<'a>(
cls: &'a syn::Ident,
variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
) -> Vec<MethodAndMethodDef> {
let cls_type = syn::parse_quote!(#cls);
let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec {
rust_ident: var_ident.clone(),
attributes: ConstAttributes {
is_class_attr: true,
name: Some(NameAttribute {
kw: syn::parse_quote! { name },
value: NameLitStr(py_ident.clone()),
}),
deprecations: Default::default(),
},
};
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
Ok(PyClassEnumVariant { ident, options })
variant_names
.into_iter()
.map(|(var, py_name)| {
gen_complex_enum_variant_attr(cls, &cls_type, &variant_to_attribute(var, &py_name))
})
.collect()
}
pub fn gen_complex_enum_variant_attr(
cls: &syn::Ident,
cls_type: &syn::Type,
spec: &ConstSpec,
) -> MethodAndMethodDef {
let member = &spec.rust_ident;
let wrapper_ident = format_ident!("__pymethod_variant_cls_{}__", member);
let deprecations = &spec.attributes.deprecations;
let python_name = &spec.null_terminated_python_name();
let variant_cls = format_ident!("{}_{}", cls, member);
let associated_method = quote! {
fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> {
#deprecations
::std::result::Result::Ok(py.get_type::<#variant_cls>().into())
}
};
let method_def = quote! {
_pyo3::class::PyMethodDefType::ClassAttribute({
_pyo3::class::PyClassAttributeDef::new(
#python_name,
_pyo3::impl_::pymethods::PyClassAttributeFactory(#cls_type::#wrapper_ident)
)
})
};
MethodAndMethodDef {
associated_method,
method_def,
}
}
fn complex_enum_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumVariant<'a>,
) -> Result<MethodAndSlotDef> {
match variant {
PyClassEnumVariant::Struct(struct_variant) => {
complex_enum_struct_variant_new(cls, struct_variant)
}
}
}
fn complex_enum_struct_variant_new<'a>(
cls: &'a syn::Ident,
variant: &'a PyClassEnumStructVariant<'a>,
) -> Result<MethodAndSlotDef> {
let variant_cls = format_ident!("{}_{}", cls, variant.ident);
let variant_cls_type: syn::Type = parse_quote!(#variant_cls);
let arg_py_ident: syn::Ident = parse_quote!(py);
let arg_py_type: syn::Type = parse_quote!(_pyo3::Python<'_>);
let args = {
let mut no_pyo3_attrs = vec![];
let attrs = crate::pyfunction::PyFunctionArgPyO3Attributes::from_attrs(&mut no_pyo3_attrs)?;
let mut args = vec![
// py: Python<'_>
FnArg {
name: &arg_py_ident,
ty: &arg_py_type,
optional: None,
default: None,
py: true,
attrs: attrs.clone(),
is_varargs: false,
is_kwargs: false,
is_cancel_handle: false,
},
];
for field in &variant.fields {
args.push(FnArg {
name: field.ident,
ty: field.ty,
optional: None,
default: None,
py: false,
attrs: attrs.clone(),
is_varargs: false,
is_kwargs: false,
is_cancel_handle: false,
});
}
args
};
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;
let spec = FnSpec {
tp: crate::method::FnType::FnNew,
name: &format_ident!("__pymethod_constructor__"),
python_name: format_ident!("__new__"),
signature,
output: variant_cls_type.clone(),
convention: crate::method::CallingConvention::TpNew,
text_signature: None,
asyncness: None,
unsafety: None,
deprecations: Deprecations::default(),
};
crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec)
}
fn complex_enum_variant_field_getter<'a>(
variant_cls_type: &'a syn::Type,
field_name: &'a syn::Ident,
field_type: &'a syn::Type,
field_span: Span,
) -> Result<MethodAndMethodDef> {
let signature = crate::pyfunction::FunctionSignature::from_arguments(vec![])?;
let self_type = crate::method::SelfType::TryFromPyCell(field_span);
let spec = FnSpec {
tp: crate::method::FnType::Getter(self_type.clone()),
name: field_name,
python_name: field_name.clone(),
signature,
output: field_type.clone(),
convention: crate::method::CallingConvention::Noargs,
text_signature: None,
asyncness: None,
unsafety: None,
deprecations: Deprecations::default(),
};
let property_type = crate::pymethod::PropertyType::Function {
self_type: &self_type,
spec: &spec,
doc: crate::get_doc(&[], None),
};
let getter = crate::pymethod::impl_py_getter_def(variant_cls_type, property_type)?;
Ok(getter)
}
fn descriptors_to_items(

View File

@ -324,7 +324,8 @@ pub fn impl_py_method_def(
})
}
fn impl_py_method_def_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<MethodAndSlotDef> {
/// Also used by pyclass.
pub fn impl_py_method_def_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<MethodAndSlotDef> {
let wrapper_ident = syn::Ident::new("__pymethod___new____", Span::call_site());
let associated_method = spec.get_wrapper_function(&wrapper_ident, Some(cls))?;
// Use just the text_signature_call_signature() because the class' Python name

View File

@ -1,4 +1,5 @@
import nox
import sys
from nox.command import CommandFailed
nox.options.sessions = ["test"]
@ -13,7 +14,12 @@ def test(session: nox.Session):
except CommandFailed:
# No binary wheel for numpy available on this platform
pass
session.run("pytest", *session.posargs)
ignored_paths = []
if sys.version_info < (3, 10):
# Match syntax is only available in Python >= 3.10
ignored_paths.append("tests/test_enums_match.py")
ignore_args = [f"--ignore={path}" for path in ignored_paths]
session.run("pytest", *ignore_args, *session.posargs)
@nox.session

58
pytests/src/enums.rs Normal file
View File

@ -0,0 +1,58 @@
use pyo3::{pyclass, pyfunction, pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
#[pymodule]
pub fn enums(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<SimpleEnum>()?;
m.add_class::<ComplexEnum>()?;
m.add_wrapped(wrap_pyfunction!(do_simple_stuff))?;
m.add_wrapped(wrap_pyfunction!(do_complex_stuff))?;
Ok(())
}
#[pyclass]
pub enum SimpleEnum {
Sunday,
Monday,
Tuesday,
Wednesday,
Thursday,
Friday,
Saturday,
}
#[pyfunction]
pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {
match thing {
SimpleEnum::Sunday => SimpleEnum::Monday,
SimpleEnum::Monday => SimpleEnum::Tuesday,
SimpleEnum::Tuesday => SimpleEnum::Wednesday,
SimpleEnum::Wednesday => SimpleEnum::Thursday,
SimpleEnum::Thursday => SimpleEnum::Friday,
SimpleEnum::Friday => SimpleEnum::Saturday,
SimpleEnum::Saturday => SimpleEnum::Sunday,
}
}
#[pyclass]
pub enum ComplexEnum {
Int { i: i32 },
Float { f: f64 },
Str { s: String },
EmptyStruct {},
MultiFieldStruct { a: i32, b: f64, c: bool },
}
#[pyfunction]
pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
match thing {
ComplexEnum::Int { i } => ComplexEnum::Str { s: i.to_string() },
ComplexEnum::Float { f } => ComplexEnum::Float { f: f * f },
ComplexEnum::Str { s } => ComplexEnum::Int { i: s.len() as i32 },
ComplexEnum::EmptyStruct {} => ComplexEnum::EmptyStruct {},
ComplexEnum::MultiFieldStruct { a, b, c } => ComplexEnum::MultiFieldStruct {
a: *a,
b: *b,
c: *c,
},
}
}

View File

@ -7,6 +7,7 @@ pub mod buf_and_str;
pub mod comparisons;
pub mod datetime;
pub mod dict_iter;
pub mod enums;
pub mod misc;
pub mod objstore;
pub mod othermod;
@ -25,6 +26,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?;
m.add_wrapped(wrap_pymodule!(enums::enums))?;
m.add_wrapped(wrap_pymodule!(misc::misc))?;
m.add_wrapped(wrap_pymodule!(objstore::objstore))?;
m.add_wrapped(wrap_pymodule!(othermod::othermod))?;
@ -44,6 +46,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
sys_modules.set_item("pyo3_pytests.enums", m.getattr("enums")?)?;
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;

116
pytests/tests/test_enums.py Normal file
View File

@ -0,0 +1,116 @@
import pytest
from pyo3_pytests import enums
def test_complex_enum_variant_constructors():
int_variant = enums.ComplexEnum.Int(42)
assert isinstance(int_variant, enums.ComplexEnum.Int)
float_variant = enums.ComplexEnum.Float(3.14)
assert isinstance(float_variant, enums.ComplexEnum.Float)
str_variant = enums.ComplexEnum.Str("hello")
assert isinstance(str_variant, enums.ComplexEnum.Str)
empty_struct_variant = enums.ComplexEnum.EmptyStruct()
assert isinstance(empty_struct_variant, enums.ComplexEnum.EmptyStruct)
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)
@pytest.mark.parametrize(
"variant",
[
enums.ComplexEnum.Int(42),
enums.ComplexEnum.Float(3.14),
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
],
)
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
assert isinstance(variant, enums.ComplexEnum)
def test_complex_enum_field_getters():
int_variant = enums.ComplexEnum.Int(42)
assert int_variant.i == 42
float_variant = enums.ComplexEnum.Float(3.14)
assert float_variant.f == 3.14
str_variant = enums.ComplexEnum.Str("hello")
assert str_variant.s == "hello"
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
assert multi_field_struct_variant.a == 42
assert multi_field_struct_variant.b == 3.14
assert multi_field_struct_variant.c is True
@pytest.mark.parametrize(
"variant",
[
enums.ComplexEnum.Int(42),
enums.ComplexEnum.Float(3.14),
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
],
)
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
if isinstance(variant, enums.ComplexEnum.Int):
x = variant.i
assert x == 42
elif isinstance(variant, enums.ComplexEnum.Float):
x = variant.f
assert x == 3.14
elif isinstance(variant, enums.ComplexEnum.Str):
x = variant.s
assert x == "hello"
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
assert True
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
x = variant.a
y = variant.b
z = variant.c
assert x == 42
assert y == 3.14
assert z is True
else:
assert False
@pytest.mark.parametrize(
"variant",
[
enums.ComplexEnum.Int(42),
enums.ComplexEnum.Float(3.14),
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
],
)
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
variant = enums.do_complex_stuff(variant)
if isinstance(variant, enums.ComplexEnum.Int):
x = variant.i
assert x == 5
elif isinstance(variant, enums.ComplexEnum.Float):
x = variant.f
assert x == 9.8596
elif isinstance(variant, enums.ComplexEnum.Str):
x = variant.s
assert x == "42"
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
assert True
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
x = variant.a
y = variant.b
z = variant.c
assert x == 42
assert y == 3.14
assert z is True
else:
assert False

View File

@ -0,0 +1,59 @@
# This file is only collected when Python >= 3.10, because it tests match syntax.
import pytest
from pyo3_pytests import enums
@pytest.mark.parametrize(
"variant",
[
enums.ComplexEnum.Int(42),
enums.ComplexEnum.Float(3.14),
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
],
)
def test_complex_enum_match_statement(variant: enums.ComplexEnum):
match variant:
case enums.ComplexEnum.Int(i=x):
assert x == 42
case enums.ComplexEnum.Float(f=x):
assert x == 3.14
case enums.ComplexEnum.Str(s=x):
assert x == "hello"
case enums.ComplexEnum.EmptyStruct():
assert True
case enums.ComplexEnum.MultiFieldStruct(a=x, b=y, c=z):
assert x == 42
assert y == 3.14
assert z is True
case _:
assert False
@pytest.mark.parametrize(
"variant",
[
enums.ComplexEnum.Int(42),
enums.ComplexEnum.Float(3.14),
enums.ComplexEnum.Str("hello"),
enums.ComplexEnum.EmptyStruct(),
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
],
)
def test_complex_enum_pyfunction_in_out(variant: enums.ComplexEnum):
match enums.do_complex_stuff(variant):
case enums.ComplexEnum.Int(i=x):
assert x == 5
case enums.ComplexEnum.Float(f=x):
assert x == 9.8596
case enums.ComplexEnum.Str(s=x):
assert x == "42"
case enums.ComplexEnum.EmptyStruct():
assert True
case enums.ComplexEnum.MultiFieldStruct(a=x, b=y, c=z):
assert x == 42
assert y == 3.14
assert z is True
case _:
assert False

View File

@ -14,6 +14,7 @@ fn test_compile_errors() {
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
t.compile_fail("tests/ui/invalid_pymethods_buffer.rs");
t.compile_fail("tests/ui/invalid_pymethods_duplicates.rs");
t.compile_fail("tests/ui/invalid_pymethod_enum.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/invalid_pymodule_args.rs");
t.compile_fail("tests/ui/reject_generics.rs");

View File

@ -15,4 +15,16 @@ enum NotDrivedClass {
#[pyclass]
enum NoEmptyEnum {}
#[pyclass]
enum NoUnitVariants {
StructVariant { field: i32 },
UnitVariant,
}
#[pyclass]
enum NoTupleVariants {
StructVariant { field: i32 },
TupleVariant(i32),
}
fn main() {}

View File

@ -15,3 +15,19 @@ error: #[pyclass] can't be used on enums without any variants
|
16 | enum NoEmptyEnum {}
| ^^
error: Unit variant `UnitVariant` is not yet supported in a complex enum
= help: change to a struct variant with no fields: `UnitVariant { }`
= note: the enum is complex because of non-unit variant `StructVariant`
--> tests/ui/invalid_pyclass_enum.rs:21:5
|
21 | UnitVariant,
| ^^^^^^^^^^^
error: Tuple variant `TupleVariant` is not yet supported in a complex enum
= help: change to a struct variant with named fields: `TupleVariant { /* fields */ }`
= note: the enum is complex because of non-unit variant `StructVariant`
--> tests/ui/invalid_pyclass_enum.rs:27:5
|
27 | TupleVariant(i32),
| ^^^^^^^^^^^^

View File

@ -0,0 +1,19 @@
use pyo3::prelude::*;
#[pyclass]
enum ComplexEnum {
Int { int: i32 },
Str { string: String },
}
#[pymethods]
impl ComplexEnum {
fn mutate_in_place(&mut self) {
*self = match self {
ComplexEnum::Int { int } => ComplexEnum::Str { string: int.to_string() },
ComplexEnum::Str { string } => ComplexEnum::Int { int: string.len() as i32 },
}
}
}
fn main() {}

View File

@ -0,0 +1,11 @@
error[E0271]: type mismatch resolving `<ComplexEnum as PyClass>::Frozen == False`
--> tests/ui/invalid_pymethod_enum.rs:11:24
|
11 | fn mutate_in_place(&mut self) {
| ^ expected `False`, found `True`
|
note: required by a bound in `extract_pyclass_ref_mut`
--> src/impl_/extract_argument.rs
|
| pub fn extract_pyclass_ref_mut<'a, 'py: 'a, T: PyClass<Frozen = False>>(
| ^^^^^^^^^^^^^^ required by this bound in `extract_pyclass_ref_mut`