Merge pull request #3582 from mkovaxx/pyclass_complex_enum
Full ADT support with pyclass for complex enums
This commit is contained in:
commit
d1b072222a
|
@ -2,7 +2,7 @@
|
|||
|
||||
PyO3 exposes a group of attributes powered by Rust's proc macro system for defining Python classes as Rust structs.
|
||||
|
||||
The main attribute is `#[pyclass]`, which is placed upon a Rust `struct` or a fieldless `enum` (a.k.a. C-like enum) to generate a Python type for it. They will usually also have *one* `#[pymethods]`-annotated `impl` block for the struct, which is used to define Python methods and constants for the generated Python type. (If the [`multiple-pymethods`] feature is enabled, each `#[pyclass]` is allowed to have multiple `#[pymethods]` blocks.) `#[pymethods]` may also have implementations for Python magic methods such as `__str__`.
|
||||
The main attribute is `#[pyclass]`, which is placed upon a Rust `struct` or `enum` to generate a Python type for it. They will usually also have *one* `#[pymethods]`-annotated `impl` block for the struct, which is used to define Python methods and constants for the generated Python type. (If the [`multiple-pymethods`] feature is enabled, each `#[pyclass]` is allowed to have multiple `#[pymethods]` blocks.) `#[pymethods]` may also have implementations for Python magic methods such as `__str__`.
|
||||
|
||||
This chapter will discuss the functionality and configuration these attributes offer. Below is a list of links to the relevant section of this chapter for each:
|
||||
|
||||
|
@ -21,13 +21,13 @@ This chapter will discuss the functionality and configuration these attributes o
|
|||
|
||||
## Defining a new class
|
||||
|
||||
To define a custom Python class, add the `#[pyclass]` attribute to a Rust struct or a fieldless enum.
|
||||
To define a custom Python class, add the `#[pyclass]` attribute to a Rust struct or enum.
|
||||
```rust
|
||||
# #![allow(dead_code)]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
#[pyclass]
|
||||
struct Integer {
|
||||
struct MyClass {
|
||||
inner: i32,
|
||||
}
|
||||
|
||||
|
@ -35,7 +35,15 @@ struct Integer {
|
|||
#[pyclass]
|
||||
struct Number(i32);
|
||||
|
||||
// PyO3 supports custom discriminants in enums
|
||||
// PyO3 supports unit-only enums (which contain only unit variants)
|
||||
// These simple enums behave similarly to Python's enumerations (enum.Enum)
|
||||
#[pyclass]
|
||||
enum MyEnum {
|
||||
Variant,
|
||||
OtherVariant = 30, // PyO3 supports custom discriminants.
|
||||
}
|
||||
|
||||
// PyO3 supports custom discriminants in unit-only enums
|
||||
#[pyclass]
|
||||
enum HttpResponse {
|
||||
Ok = 200,
|
||||
|
@ -44,14 +52,19 @@ enum HttpResponse {
|
|||
// ...
|
||||
}
|
||||
|
||||
// PyO3 also supports enums with non-unit variants
|
||||
// These complex enums have sligtly different behavior from the simple enums above
|
||||
// They are meant to work with instance checks and match statement patterns
|
||||
#[pyclass]
|
||||
enum MyEnum {
|
||||
Variant,
|
||||
OtherVariant = 30, // PyO3 supports custom discriminants.
|
||||
enum Shape {
|
||||
Circle { radius: f64 },
|
||||
Rectangle { width: f64, height: f64 },
|
||||
RegularPolygon { side_count: u32, radius: f64 },
|
||||
Nothing { },
|
||||
}
|
||||
```
|
||||
|
||||
The above example generates implementations for [`PyTypeInfo`] and [`PyClass`] for `MyClass` and `MyEnum`. To see these generated implementations, refer to the [implementation details](#implementation-details) at the end of this chapter.
|
||||
The above example generates implementations for [`PyTypeInfo`] and [`PyClass`] for `MyClass`, `Number`, `MyEnum`, `HttpResponse`, and `Shape`. To see these generated implementations, refer to the [implementation details](#implementation-details) at the end of this chapter.
|
||||
|
||||
### Restrictions
|
||||
|
||||
|
@ -964,7 +977,13 @@ Note that `text_signature` on `#[new]` is not compatible with compilation in
|
|||
|
||||
## #[pyclass] enums
|
||||
|
||||
Currently PyO3 only supports fieldless enums. PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`. PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`:
|
||||
Enum support in PyO3 comes in two flavors, depending on what kind of variants the enum has: simple and complex.
|
||||
|
||||
### Simple enums
|
||||
|
||||
A simple enum (a.k.a. C-like enum) has only unit variants.
|
||||
|
||||
PyO3 adds a class attribute for each variant, so you can access them in Python without defining `#[new]`. PyO3 also provides default implementations of `__richcmp__` and `__int__`, so they can be compared using `==`:
|
||||
|
||||
```rust
|
||||
# use pyo3::prelude::*;
|
||||
|
@ -986,7 +1005,7 @@ Python::with_gil(|py| {
|
|||
})
|
||||
```
|
||||
|
||||
You can also convert your enums into `int`:
|
||||
You can also convert your simple enums into `int`:
|
||||
|
||||
```rust
|
||||
# use pyo3::prelude::*;
|
||||
|
@ -1094,6 +1113,90 @@ enum BadSubclass {
|
|||
|
||||
`#[pyclass]` enums are currently not interoperable with `IntEnum` in Python.
|
||||
|
||||
### Complex enums
|
||||
|
||||
An enum is complex if it has any non-unit (struct or tuple) variants.
|
||||
|
||||
Currently PyO3 supports only struct variants in a complex enum. Support for unit and tuple variants is planned.
|
||||
|
||||
PyO3 adds a class attribute for each variant, which may be used to construct values and in match patterns. PyO3 also provides getter methods for all fields of each variant.
|
||||
|
||||
```rust
|
||||
# use pyo3::prelude::*;
|
||||
#[pyclass]
|
||||
enum Shape {
|
||||
Circle { radius: f64 },
|
||||
Rectangle { width: f64, height: f64 },
|
||||
RegularPolygon { side_count: u32, radius: f64 },
|
||||
Nothing { },
|
||||
}
|
||||
|
||||
Python::with_gil(|py| {
|
||||
let def_count_vertices = if py.version_info() >= (3, 10) { r#"
|
||||
def count_vertices(cls, shape):
|
||||
match shape:
|
||||
case cls.Circle():
|
||||
return 0
|
||||
case cls.Rectangle():
|
||||
return 4
|
||||
case cls.RegularPolygon(side_count=n):
|
||||
return n
|
||||
case cls.Nothing():
|
||||
return 0
|
||||
"# } else { r#"
|
||||
def count_vertices(cls, shape):
|
||||
if isinstance(shape, cls.Circle):
|
||||
return 0
|
||||
elif isinstance(shape, cls.Rectangle):
|
||||
return 4
|
||||
elif isinstance(shape, cls.RegularPolygon):
|
||||
n = shape.side_count
|
||||
return n
|
||||
elif isinstance(shape, cls.Nothing):
|
||||
return 0
|
||||
"# };
|
||||
|
||||
let circle = Shape::Circle { radius: 10.0 }.into_py(py);
|
||||
let square = Shape::RegularPolygon { side_count: 4, radius: 10.0 }.into_py(py);
|
||||
let cls = py.get_type::<Shape>();
|
||||
|
||||
pyo3::py_run!(py, circle square cls, &format!(r#"
|
||||
assert isinstance(circle, cls)
|
||||
assert isinstance(circle, cls.Circle)
|
||||
assert circle.radius == 10.0
|
||||
|
||||
assert isinstance(square, cls)
|
||||
assert isinstance(square, cls.RegularPolygon)
|
||||
assert square.side_count == 4
|
||||
assert square.radius == 10.0
|
||||
|
||||
{}
|
||||
|
||||
assert count_vertices(cls, circle) == 0
|
||||
assert count_vertices(cls, square) == 4
|
||||
"#, def_count_vertices))
|
||||
})
|
||||
```
|
||||
|
||||
WARNING: `Py::new` and `.into_py` are currently inconsistent. Note how the constructed value is _not_ an instance of the specific variant. For this reason, constructing values is only recommended using `.into_py`.
|
||||
|
||||
```rust
|
||||
# use pyo3::prelude::*;
|
||||
#[pyclass]
|
||||
enum MyEnum {
|
||||
Variant { i: i32 },
|
||||
}
|
||||
|
||||
Python::with_gil(|py| {
|
||||
let x = Py::new(py, MyEnum::Variant { i: 42 }).unwrap();
|
||||
let cls = py.get_type::<MyEnum>();
|
||||
pyo3::py_run!(py, x cls, r#"
|
||||
assert isinstance(x, cls)
|
||||
assert not isinstance(x, cls.Variant)
|
||||
"#)
|
||||
})
|
||||
```
|
||||
|
||||
## Implementation details
|
||||
|
||||
The `#[pyclass]` macros rely on a lot of conditional code generation: each `#[pyclass]` can optionally have a `#[pymethods]` block.
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Support `#[pyclass]` on enums that have non-unit variants.
|
|
@ -7,7 +7,7 @@ use crate::attributes::{
|
|||
};
|
||||
use crate::deprecations::Deprecations;
|
||||
use crate::konst::{ConstAttributes, ConstSpec};
|
||||
use crate::method::FnSpec;
|
||||
use crate::method::{FnArg, FnSpec};
|
||||
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
|
||||
use crate::pymethod::{
|
||||
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
|
||||
|
@ -16,7 +16,7 @@ use crate::pymethod::{
|
|||
use crate::utils::{self, apply_renaming_rule, get_pyo3_crate, PythonDoc};
|
||||
use crate::PyFunctionOptions;
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::quote;
|
||||
use quote::{format_ident, quote};
|
||||
use syn::ext::IdentExt;
|
||||
use syn::parse::{Parse, ParseStream};
|
||||
use syn::punctuated::Punctuated;
|
||||
|
@ -30,6 +30,7 @@ pub enum PyClassKind {
|
|||
}
|
||||
|
||||
/// The parsed arguments of the pyclass macro
|
||||
#[derive(Clone)]
|
||||
pub struct PyClassArgs {
|
||||
pub class_kind: PyClassKind,
|
||||
pub options: PyClassPyO3Options,
|
||||
|
@ -52,7 +53,7 @@ impl PyClassArgs {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct PyClassPyO3Options {
|
||||
pub krate: Option<CrateAttribute>,
|
||||
pub dict: Option<kw::dict>,
|
||||
|
@ -128,7 +129,7 @@ impl Parse for PyClassPyO3Option {
|
|||
}
|
||||
}
|
||||
|
||||
impl PyClassPyO3Options {
|
||||
impl Parse for PyClassPyO3Options {
|
||||
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
|
||||
let mut options: PyClassPyO3Options = Default::default();
|
||||
|
||||
|
@ -138,7 +139,9 @@ impl PyClassPyO3Options {
|
|||
|
||||
Ok(options)
|
||||
}
|
||||
}
|
||||
|
||||
impl PyClassPyO3Options {
|
||||
pub fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> syn::Result<()> {
|
||||
take_pyo3_options(attrs)?
|
||||
.into_iter()
|
||||
|
@ -369,73 +372,24 @@ fn impl_class(
|
|||
})
|
||||
}
|
||||
|
||||
struct PyClassEnumVariant<'a> {
|
||||
ident: &'a syn::Ident,
|
||||
options: EnumVariantPyO3Options,
|
||||
}
|
||||
|
||||
impl<'a> PyClassEnumVariant<'a> {
|
||||
fn python_name(&self, args: &PyClassArgs) -> Cow<'_, syn::Ident> {
|
||||
self.options
|
||||
.name
|
||||
.as_ref()
|
||||
.map(|name_attr| Cow::Borrowed(&name_attr.value.0))
|
||||
.unwrap_or_else(|| {
|
||||
let name = self.ident.unraw();
|
||||
if let Some(attr) = &args.options.rename_all {
|
||||
let new_name = apply_renaming_rule(attr.value.rule, &name.to_string());
|
||||
Cow::Owned(Ident::new(&new_name, Span::call_site()))
|
||||
} else {
|
||||
Cow::Owned(name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct PyClassEnum<'a> {
|
||||
ident: &'a syn::Ident,
|
||||
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
|
||||
// This matters when the underlying representation may not fit in `isize`.
|
||||
repr_type: syn::Ident,
|
||||
variants: Vec<PyClassEnumVariant<'a>>,
|
||||
enum PyClassEnum<'a> {
|
||||
Simple(PyClassSimpleEnum<'a>),
|
||||
Complex(PyClassComplexEnum<'a>),
|
||||
}
|
||||
|
||||
impl<'a> PyClassEnum<'a> {
|
||||
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
|
||||
fn is_numeric_type(t: &syn::Ident) -> bool {
|
||||
[
|
||||
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
|
||||
"isize",
|
||||
]
|
||||
.iter()
|
||||
.any(|&s| t == s)
|
||||
}
|
||||
let ident = &enum_.ident;
|
||||
// According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
|
||||
// "Under the default representation, the specified discriminant is interpreted as an isize
|
||||
// value", so `isize` should be enough by default.
|
||||
let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site());
|
||||
if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path().is_ident("repr")) {
|
||||
let args =
|
||||
attr.parse_args_with(Punctuated::<TokenStream, Token![!]>::parse_terminated)?;
|
||||
if let Some(ident) = args
|
||||
.into_iter()
|
||||
.filter_map(|ts| syn::parse2::<syn::Ident>(ts).ok())
|
||||
.find(is_numeric_type)
|
||||
{
|
||||
repr_type = ident;
|
||||
}
|
||||
}
|
||||
|
||||
let variants = enum_
|
||||
let has_only_unit_variants = enum_
|
||||
.variants
|
||||
.iter_mut()
|
||||
.map(extract_variant_data)
|
||||
.collect::<syn::Result<_>>()?;
|
||||
Ok(Self {
|
||||
ident,
|
||||
repr_type,
|
||||
variants,
|
||||
.iter()
|
||||
.all(|variant| matches!(variant.fields, syn::Fields::Unit));
|
||||
|
||||
Ok(if has_only_unit_variants {
|
||||
let simple_enum = PyClassSimpleEnum::new(enum_)?;
|
||||
Self::Simple(simple_enum)
|
||||
} else {
|
||||
let complex_enum = PyClassComplexEnum::new(enum_)?;
|
||||
Self::Complex(complex_enum)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -460,6 +414,208 @@ pub fn build_py_enum(
|
|||
impl_enum(enum_, &args, doc, method_type)
|
||||
}
|
||||
|
||||
struct PyClassSimpleEnum<'a> {
|
||||
ident: &'a syn::Ident,
|
||||
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
|
||||
// This matters when the underlying representation may not fit in `isize`.
|
||||
repr_type: syn::Ident,
|
||||
variants: Vec<PyClassEnumUnitVariant<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> PyClassSimpleEnum<'a> {
|
||||
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
|
||||
fn is_numeric_type(t: &syn::Ident) -> bool {
|
||||
[
|
||||
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
|
||||
"isize",
|
||||
]
|
||||
.iter()
|
||||
.any(|&s| t == s)
|
||||
}
|
||||
|
||||
fn extract_unit_variant_data(
|
||||
variant: &mut syn::Variant,
|
||||
) -> syn::Result<PyClassEnumUnitVariant<'_>> {
|
||||
use syn::Fields;
|
||||
let ident = match &variant.fields {
|
||||
Fields::Unit => &variant.ident,
|
||||
_ => bail_spanned!(variant.span() => "Must be a unit variant."),
|
||||
};
|
||||
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
|
||||
Ok(PyClassEnumUnitVariant { ident, options })
|
||||
}
|
||||
|
||||
let ident = &enum_.ident;
|
||||
|
||||
// According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
|
||||
// "Under the default representation, the specified discriminant is interpreted as an isize
|
||||
// value", so `isize` should be enough by default.
|
||||
let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site());
|
||||
if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path().is_ident("repr")) {
|
||||
let args =
|
||||
attr.parse_args_with(Punctuated::<TokenStream, Token![!]>::parse_terminated)?;
|
||||
if let Some(ident) = args
|
||||
.into_iter()
|
||||
.filter_map(|ts| syn::parse2::<syn::Ident>(ts).ok())
|
||||
.find(is_numeric_type)
|
||||
{
|
||||
repr_type = ident;
|
||||
}
|
||||
}
|
||||
|
||||
let variants: Vec<_> = enum_
|
||||
.variants
|
||||
.iter_mut()
|
||||
.map(extract_unit_variant_data)
|
||||
.collect::<syn::Result<_>>()?;
|
||||
Ok(Self {
|
||||
ident,
|
||||
repr_type,
|
||||
variants,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct PyClassComplexEnum<'a> {
|
||||
ident: &'a syn::Ident,
|
||||
variants: Vec<PyClassEnumVariant<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> PyClassComplexEnum<'a> {
|
||||
fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
|
||||
let witness = enum_
|
||||
.variants
|
||||
.iter()
|
||||
.find(|variant| !matches!(variant.fields, syn::Fields::Unit))
|
||||
.expect("complex enum has a non-unit variant")
|
||||
.ident
|
||||
.to_owned();
|
||||
|
||||
let extract_variant_data =
|
||||
|variant: &'a mut syn::Variant| -> syn::Result<PyClassEnumVariant<'a>> {
|
||||
use syn::Fields;
|
||||
let ident = &variant.ident;
|
||||
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
|
||||
|
||||
let variant = match &variant.fields {
|
||||
Fields::Unit => {
|
||||
bail_spanned!(variant.span() => format!(
|
||||
"Unit variant `{ident}` is not yet supported in a complex enum\n\
|
||||
= help: change to a struct variant with no fields: `{ident} {{ }}`\n\
|
||||
= note: the enum is complex because of non-unit variant `{witness}`",
|
||||
ident=ident, witness=witness))
|
||||
}
|
||||
Fields::Named(fields) => {
|
||||
let fields = fields
|
||||
.named
|
||||
.iter()
|
||||
.map(|field| PyClassEnumVariantNamedField {
|
||||
ident: field.ident.as_ref().expect("named field has an identifier"),
|
||||
ty: &field.ty,
|
||||
span: field.span(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
PyClassEnumVariant::Struct(PyClassEnumStructVariant {
|
||||
ident,
|
||||
fields,
|
||||
options,
|
||||
})
|
||||
}
|
||||
Fields::Unnamed(_) => {
|
||||
bail_spanned!(variant.span() => format!(
|
||||
"Tuple variant `{ident}` is not yet supported in a complex enum\n\
|
||||
= help: change to a struct variant with named fields: `{ident} {{ /* fields */ }}`\n\
|
||||
= note: the enum is complex because of non-unit variant `{witness}`",
|
||||
ident=ident, witness=witness))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(variant)
|
||||
};
|
||||
|
||||
let ident = &enum_.ident;
|
||||
|
||||
let variants: Vec<_> = enum_
|
||||
.variants
|
||||
.iter_mut()
|
||||
.map(extract_variant_data)
|
||||
.collect::<syn::Result<_>>()?;
|
||||
|
||||
Ok(Self { ident, variants })
|
||||
}
|
||||
}
|
||||
|
||||
enum PyClassEnumVariant<'a> {
|
||||
// TODO(mkovaxx): Unit(PyClassEnumUnitVariant<'a>),
|
||||
Struct(PyClassEnumStructVariant<'a>),
|
||||
// TODO(mkovaxx): Tuple(PyClassEnumTupleVariant<'a>),
|
||||
}
|
||||
|
||||
trait EnumVariant {
|
||||
fn get_ident(&self) -> &syn::Ident;
|
||||
fn get_options(&self) -> &EnumVariantPyO3Options;
|
||||
|
||||
fn get_python_name(&self, args: &PyClassArgs) -> Cow<'_, syn::Ident> {
|
||||
self.get_options()
|
||||
.name
|
||||
.as_ref()
|
||||
.map(|name_attr| Cow::Borrowed(&name_attr.value.0))
|
||||
.unwrap_or_else(|| {
|
||||
let name = self.get_ident().unraw();
|
||||
if let Some(attr) = &args.options.rename_all {
|
||||
let new_name = apply_renaming_rule(attr.value.rule, &name.to_string());
|
||||
Cow::Owned(Ident::new(&new_name, Span::call_site()))
|
||||
} else {
|
||||
Cow::Owned(name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> EnumVariant for PyClassEnumVariant<'a> {
|
||||
fn get_ident(&self) -> &syn::Ident {
|
||||
match self {
|
||||
PyClassEnumVariant::Struct(struct_variant) => struct_variant.ident,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_options(&self) -> &EnumVariantPyO3Options {
|
||||
match self {
|
||||
PyClassEnumVariant::Struct(struct_variant) => &struct_variant.options,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A unit variant has no fields
|
||||
struct PyClassEnumUnitVariant<'a> {
|
||||
ident: &'a syn::Ident,
|
||||
options: EnumVariantPyO3Options,
|
||||
}
|
||||
|
||||
impl<'a> EnumVariant for PyClassEnumUnitVariant<'a> {
|
||||
fn get_ident(&self) -> &syn::Ident {
|
||||
self.ident
|
||||
}
|
||||
|
||||
fn get_options(&self) -> &EnumVariantPyO3Options {
|
||||
&self.options
|
||||
}
|
||||
}
|
||||
|
||||
/// A struct variant has named fields
|
||||
struct PyClassEnumStructVariant<'a> {
|
||||
ident: &'a syn::Ident,
|
||||
fields: Vec<PyClassEnumVariantNamedField<'a>>,
|
||||
options: EnumVariantPyO3Options,
|
||||
}
|
||||
|
||||
struct PyClassEnumVariantNamedField<'a> {
|
||||
ident: &'a syn::Ident,
|
||||
ty: &'a syn::Type,
|
||||
span: Span,
|
||||
}
|
||||
|
||||
/// `#[pyo3()]` options for pyclass enum variants
|
||||
struct EnumVariantPyO3Options {
|
||||
name: Option<NameAttribute>,
|
||||
|
@ -505,11 +661,25 @@ fn impl_enum(
|
|||
args: &PyClassArgs,
|
||||
doc: PythonDoc,
|
||||
methods_type: PyClassMethodsType,
|
||||
) -> Result<TokenStream> {
|
||||
match enum_ {
|
||||
PyClassEnum::Simple(simple_enum) => impl_simple_enum(simple_enum, args, doc, methods_type),
|
||||
PyClassEnum::Complex(complex_enum) => {
|
||||
impl_complex_enum(complex_enum, args, doc, methods_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_simple_enum(
|
||||
simple_enum: PyClassSimpleEnum<'_>,
|
||||
args: &PyClassArgs,
|
||||
doc: PythonDoc,
|
||||
methods_type: PyClassMethodsType,
|
||||
) -> Result<TokenStream> {
|
||||
let krate = get_pyo3_crate(&args.options.krate);
|
||||
let cls = enum_.ident;
|
||||
let cls = simple_enum.ident;
|
||||
let ty: syn::Type = syn::parse_quote!(#cls);
|
||||
let variants = enum_.variants;
|
||||
let variants = simple_enum.variants;
|
||||
let pytypeinfo = impl_pytypeinfo(cls, args, None);
|
||||
|
||||
let (default_repr, default_repr_slot) = {
|
||||
|
@ -519,7 +689,7 @@ fn impl_enum(
|
|||
let repr = format!(
|
||||
"{}.{}",
|
||||
get_class_python_name(cls, args),
|
||||
variant.python_name(args),
|
||||
variant.get_python_name(args),
|
||||
);
|
||||
quote! { #cls::#variant_name => #repr, }
|
||||
});
|
||||
|
@ -534,7 +704,7 @@ fn impl_enum(
|
|||
(repr_impl, repr_slot)
|
||||
};
|
||||
|
||||
let repr_type = &enum_.repr_type;
|
||||
let repr_type = &simple_enum.repr_type;
|
||||
|
||||
let (default_int, default_int_slot) = {
|
||||
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
|
||||
|
@ -601,7 +771,10 @@ fn impl_enum(
|
|||
cls,
|
||||
args,
|
||||
methods_type,
|
||||
enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name(args)))),
|
||||
simple_enum_default_methods(
|
||||
cls,
|
||||
variants.iter().map(|v| (v.ident, v.get_python_name(args))),
|
||||
),
|
||||
default_slots,
|
||||
)
|
||||
.doc(doc)
|
||||
|
@ -626,6 +799,214 @@ fn impl_enum(
|
|||
})
|
||||
}
|
||||
|
||||
fn impl_complex_enum(
|
||||
complex_enum: PyClassComplexEnum<'_>,
|
||||
args: &PyClassArgs,
|
||||
doc: PythonDoc,
|
||||
methods_type: PyClassMethodsType,
|
||||
) -> Result<TokenStream> {
|
||||
// Need to rig the enum PyClass options
|
||||
let args = {
|
||||
let mut rigged_args = args.clone();
|
||||
// Needs to be frozen to disallow `&mut self` methods, which could break a runtime invariant
|
||||
rigged_args.options.frozen = parse_quote!(frozen);
|
||||
// Needs to be subclassable by the variant PyClasses
|
||||
rigged_args.options.subclass = parse_quote!(subclass);
|
||||
rigged_args
|
||||
};
|
||||
|
||||
let krate = get_pyo3_crate(&args.options.krate);
|
||||
let cls = complex_enum.ident;
|
||||
let variants = complex_enum.variants;
|
||||
let pytypeinfo = impl_pytypeinfo(cls, &args, None);
|
||||
|
||||
let default_slots = vec![];
|
||||
|
||||
let impl_builder = PyClassImplsBuilder::new(
|
||||
cls,
|
||||
&args,
|
||||
methods_type,
|
||||
complex_enum_default_methods(
|
||||
cls,
|
||||
variants
|
||||
.iter()
|
||||
.map(|v| (v.get_ident(), v.get_python_name(&args))),
|
||||
),
|
||||
default_slots,
|
||||
)
|
||||
.doc(doc);
|
||||
|
||||
// Need to customize the into_py impl so that it returns the variant PyClass
|
||||
let enum_into_py_impl = {
|
||||
let match_arms: Vec<TokenStream> = variants
|
||||
.iter()
|
||||
.map(|variant| {
|
||||
let variant_ident = variant.get_ident();
|
||||
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
|
||||
quote! {
|
||||
#cls::#variant_ident { .. } => {
|
||||
let pyclass_init = _pyo3::PyClassInitializer::from(self).add_subclass(#variant_cls);
|
||||
let variant_value = _pyo3::Py::new(py, pyclass_init).unwrap();
|
||||
_pyo3::IntoPy::into_py(variant_value, py)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
quote! {
|
||||
impl _pyo3::IntoPy<_pyo3::PyObject> for #cls {
|
||||
fn into_py(self, py: _pyo3::Python) -> _pyo3::PyObject {
|
||||
match self {
|
||||
#(#match_arms)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let pyclass_impls: TokenStream = vec![
|
||||
impl_builder.impl_pyclass(),
|
||||
impl_builder.impl_extractext(),
|
||||
enum_into_py_impl,
|
||||
impl_builder.impl_pyclassimpl()?,
|
||||
impl_builder.impl_freelist(),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let mut variant_cls_zsts = vec![];
|
||||
let mut variant_cls_pytypeinfos = vec![];
|
||||
let mut variant_cls_pyclass_impls = vec![];
|
||||
let mut variant_cls_impls = vec![];
|
||||
for variant in &variants {
|
||||
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
|
||||
|
||||
let variant_cls_zst = quote! {
|
||||
#[doc(hidden)]
|
||||
#[allow(non_camel_case_types)]
|
||||
struct #variant_cls;
|
||||
};
|
||||
variant_cls_zsts.push(variant_cls_zst);
|
||||
|
||||
let variant_args = PyClassArgs {
|
||||
class_kind: PyClassKind::Struct,
|
||||
// TODO(mkovaxx): propagate variant.options
|
||||
options: parse_quote!(extends = #cls, frozen),
|
||||
};
|
||||
|
||||
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None);
|
||||
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
|
||||
|
||||
let variant_new = complex_enum_variant_new(cls, variant)?;
|
||||
|
||||
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant)?;
|
||||
variant_cls_impls.push(variant_cls_impl);
|
||||
|
||||
let pyclass_impl = PyClassImplsBuilder::new(
|
||||
&variant_cls,
|
||||
&variant_args,
|
||||
methods_type,
|
||||
field_getters,
|
||||
vec![variant_new],
|
||||
)
|
||||
.impl_all()?;
|
||||
|
||||
variant_cls_pyclass_impls.push(pyclass_impl);
|
||||
}
|
||||
|
||||
Ok(quote! {
|
||||
const _: () = {
|
||||
use #krate as _pyo3;
|
||||
|
||||
#pytypeinfo
|
||||
|
||||
#pyclass_impls
|
||||
|
||||
#[doc(hidden)]
|
||||
#[allow(non_snake_case)]
|
||||
impl #cls {}
|
||||
|
||||
#(#variant_cls_zsts)*
|
||||
|
||||
#(#variant_cls_pytypeinfos)*
|
||||
|
||||
#(#variant_cls_pyclass_impls)*
|
||||
|
||||
#(#variant_cls_impls)*
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
fn impl_complex_enum_variant_cls(
|
||||
enum_name: &syn::Ident,
|
||||
variant: &PyClassEnumVariant<'_>,
|
||||
) -> Result<(TokenStream, Vec<MethodAndMethodDef>)> {
|
||||
match variant {
|
||||
PyClassEnumVariant::Struct(struct_variant) => {
|
||||
impl_complex_enum_struct_variant_cls(enum_name, struct_variant)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn impl_complex_enum_struct_variant_cls(
|
||||
enum_name: &syn::Ident,
|
||||
variant: &PyClassEnumStructVariant<'_>,
|
||||
) -> Result<(TokenStream, Vec<MethodAndMethodDef>)> {
|
||||
let variant_ident = &variant.ident;
|
||||
let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
|
||||
let variant_cls_type = parse_quote!(#variant_cls);
|
||||
|
||||
let mut field_names: Vec<Ident> = vec![];
|
||||
let mut fields_with_types: Vec<TokenStream> = vec![];
|
||||
let mut field_getters = vec![];
|
||||
let mut field_getter_impls: Vec<TokenStream> = vec![];
|
||||
for field in &variant.fields {
|
||||
let field_name = field.ident;
|
||||
let field_type = field.ty;
|
||||
let field_with_type = quote! { #field_name: #field_type };
|
||||
|
||||
let field_getter = complex_enum_variant_field_getter(
|
||||
&variant_cls_type,
|
||||
field_name,
|
||||
field_type,
|
||||
field.span,
|
||||
)?;
|
||||
|
||||
let field_getter_impl = quote! {
|
||||
fn #field_name(slf: _pyo3::PyRef<Self>) -> _pyo3::PyResult<#field_type> {
|
||||
match &*slf.into_super() {
|
||||
#enum_name::#variant_ident { #field_name, .. } => Ok(#field_name.clone()),
|
||||
_ => unreachable!("Wrong complex enum variant found in variant wrapper PyClass"),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
field_names.push(field_name.clone());
|
||||
fields_with_types.push(field_with_type);
|
||||
field_getters.push(field_getter);
|
||||
field_getter_impls.push(field_getter_impl);
|
||||
}
|
||||
|
||||
let cls_impl = quote! {
|
||||
#[doc(hidden)]
|
||||
#[allow(non_snake_case)]
|
||||
impl #variant_cls {
|
||||
fn __pymethod_constructor__(py: _pyo3::Python<'_>, #(#fields_with_types,)*) -> _pyo3::PyClassInitializer<#variant_cls> {
|
||||
let base_value = #enum_name::#variant_ident { #(#field_names,)* };
|
||||
_pyo3::PyClassInitializer::from(base_value).add_subclass(#variant_cls)
|
||||
}
|
||||
|
||||
#(#field_getter_impls)*
|
||||
}
|
||||
};
|
||||
|
||||
Ok((cls_impl, field_getters))
|
||||
}
|
||||
|
||||
fn gen_complex_enum_variant_class_ident(enum_: &syn::Ident, variant: &syn::Ident) -> syn::Ident {
|
||||
format_ident!("{}_{}", enum_, variant)
|
||||
}
|
||||
|
||||
fn generate_default_protocol_slot(
|
||||
cls: &syn::Type,
|
||||
method: &mut syn::ImplItemFn,
|
||||
|
@ -645,7 +1026,7 @@ fn generate_default_protocol_slot(
|
|||
)
|
||||
}
|
||||
|
||||
fn enum_default_methods<'a>(
|
||||
fn simple_enum_default_methods<'a>(
|
||||
cls: &'a syn::Ident,
|
||||
unit_variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
|
||||
) -> Vec<MethodAndMethodDef> {
|
||||
|
@ -667,14 +1048,167 @@ fn enum_default_methods<'a>(
|
|||
.collect()
|
||||
}
|
||||
|
||||
fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result<PyClassEnumVariant<'_>> {
|
||||
use syn::Fields;
|
||||
let ident = match variant.fields {
|
||||
Fields::Unit => &variant.ident,
|
||||
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
|
||||
fn complex_enum_default_methods<'a>(
|
||||
cls: &'a syn::Ident,
|
||||
variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
|
||||
) -> Vec<MethodAndMethodDef> {
|
||||
let cls_type = syn::parse_quote!(#cls);
|
||||
let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec {
|
||||
rust_ident: var_ident.clone(),
|
||||
attributes: ConstAttributes {
|
||||
is_class_attr: true,
|
||||
name: Some(NameAttribute {
|
||||
kw: syn::parse_quote! { name },
|
||||
value: NameLitStr(py_ident.clone()),
|
||||
}),
|
||||
deprecations: Default::default(),
|
||||
},
|
||||
};
|
||||
let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
|
||||
Ok(PyClassEnumVariant { ident, options })
|
||||
variant_names
|
||||
.into_iter()
|
||||
.map(|(var, py_name)| {
|
||||
gen_complex_enum_variant_attr(cls, &cls_type, &variant_to_attribute(var, &py_name))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn gen_complex_enum_variant_attr(
|
||||
cls: &syn::Ident,
|
||||
cls_type: &syn::Type,
|
||||
spec: &ConstSpec,
|
||||
) -> MethodAndMethodDef {
|
||||
let member = &spec.rust_ident;
|
||||
let wrapper_ident = format_ident!("__pymethod_variant_cls_{}__", member);
|
||||
let deprecations = &spec.attributes.deprecations;
|
||||
let python_name = &spec.null_terminated_python_name();
|
||||
|
||||
let variant_cls = format_ident!("{}_{}", cls, member);
|
||||
let associated_method = quote! {
|
||||
fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> {
|
||||
#deprecations
|
||||
::std::result::Result::Ok(py.get_type::<#variant_cls>().into())
|
||||
}
|
||||
};
|
||||
|
||||
let method_def = quote! {
|
||||
_pyo3::class::PyMethodDefType::ClassAttribute({
|
||||
_pyo3::class::PyClassAttributeDef::new(
|
||||
#python_name,
|
||||
_pyo3::impl_::pymethods::PyClassAttributeFactory(#cls_type::#wrapper_ident)
|
||||
)
|
||||
})
|
||||
};
|
||||
|
||||
MethodAndMethodDef {
|
||||
associated_method,
|
||||
method_def,
|
||||
}
|
||||
}
|
||||
|
||||
fn complex_enum_variant_new<'a>(
|
||||
cls: &'a syn::Ident,
|
||||
variant: &'a PyClassEnumVariant<'a>,
|
||||
) -> Result<MethodAndSlotDef> {
|
||||
match variant {
|
||||
PyClassEnumVariant::Struct(struct_variant) => {
|
||||
complex_enum_struct_variant_new(cls, struct_variant)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn complex_enum_struct_variant_new<'a>(
|
||||
cls: &'a syn::Ident,
|
||||
variant: &'a PyClassEnumStructVariant<'a>,
|
||||
) -> Result<MethodAndSlotDef> {
|
||||
let variant_cls = format_ident!("{}_{}", cls, variant.ident);
|
||||
let variant_cls_type: syn::Type = parse_quote!(#variant_cls);
|
||||
|
||||
let arg_py_ident: syn::Ident = parse_quote!(py);
|
||||
let arg_py_type: syn::Type = parse_quote!(_pyo3::Python<'_>);
|
||||
|
||||
let args = {
|
||||
let mut no_pyo3_attrs = vec![];
|
||||
let attrs = crate::pyfunction::PyFunctionArgPyO3Attributes::from_attrs(&mut no_pyo3_attrs)?;
|
||||
|
||||
let mut args = vec![
|
||||
// py: Python<'_>
|
||||
FnArg {
|
||||
name: &arg_py_ident,
|
||||
ty: &arg_py_type,
|
||||
optional: None,
|
||||
default: None,
|
||||
py: true,
|
||||
attrs: attrs.clone(),
|
||||
is_varargs: false,
|
||||
is_kwargs: false,
|
||||
is_cancel_handle: false,
|
||||
},
|
||||
];
|
||||
|
||||
for field in &variant.fields {
|
||||
args.push(FnArg {
|
||||
name: field.ident,
|
||||
ty: field.ty,
|
||||
optional: None,
|
||||
default: None,
|
||||
py: false,
|
||||
attrs: attrs.clone(),
|
||||
is_varargs: false,
|
||||
is_kwargs: false,
|
||||
is_cancel_handle: false,
|
||||
});
|
||||
}
|
||||
args
|
||||
};
|
||||
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;
|
||||
|
||||
let spec = FnSpec {
|
||||
tp: crate::method::FnType::FnNew,
|
||||
name: &format_ident!("__pymethod_constructor__"),
|
||||
python_name: format_ident!("__new__"),
|
||||
signature,
|
||||
output: variant_cls_type.clone(),
|
||||
convention: crate::method::CallingConvention::TpNew,
|
||||
text_signature: None,
|
||||
asyncness: None,
|
||||
unsafety: None,
|
||||
deprecations: Deprecations::default(),
|
||||
};
|
||||
|
||||
crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec)
|
||||
}
|
||||
|
||||
fn complex_enum_variant_field_getter<'a>(
|
||||
variant_cls_type: &'a syn::Type,
|
||||
field_name: &'a syn::Ident,
|
||||
field_type: &'a syn::Type,
|
||||
field_span: Span,
|
||||
) -> Result<MethodAndMethodDef> {
|
||||
let signature = crate::pyfunction::FunctionSignature::from_arguments(vec![])?;
|
||||
|
||||
let self_type = crate::method::SelfType::TryFromPyCell(field_span);
|
||||
|
||||
let spec = FnSpec {
|
||||
tp: crate::method::FnType::Getter(self_type.clone()),
|
||||
name: field_name,
|
||||
python_name: field_name.clone(),
|
||||
signature,
|
||||
output: field_type.clone(),
|
||||
convention: crate::method::CallingConvention::Noargs,
|
||||
text_signature: None,
|
||||
asyncness: None,
|
||||
unsafety: None,
|
||||
deprecations: Deprecations::default(),
|
||||
};
|
||||
|
||||
let property_type = crate::pymethod::PropertyType::Function {
|
||||
self_type: &self_type,
|
||||
spec: &spec,
|
||||
doc: crate::get_doc(&[], None),
|
||||
};
|
||||
|
||||
let getter = crate::pymethod::impl_py_getter_def(variant_cls_type, property_type)?;
|
||||
Ok(getter)
|
||||
}
|
||||
|
||||
fn descriptors_to_items(
|
||||
|
|
|
@ -324,7 +324,8 @@ pub fn impl_py_method_def(
|
|||
})
|
||||
}
|
||||
|
||||
fn impl_py_method_def_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<MethodAndSlotDef> {
|
||||
/// Also used by pyclass.
|
||||
pub fn impl_py_method_def_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<MethodAndSlotDef> {
|
||||
let wrapper_ident = syn::Ident::new("__pymethod___new____", Span::call_site());
|
||||
let associated_method = spec.get_wrapper_function(&wrapper_ident, Some(cls))?;
|
||||
// Use just the text_signature_call_signature() because the class' Python name
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import nox
|
||||
import sys
|
||||
from nox.command import CommandFailed
|
||||
|
||||
nox.options.sessions = ["test"]
|
||||
|
@ -13,7 +14,12 @@ def test(session: nox.Session):
|
|||
except CommandFailed:
|
||||
# No binary wheel for numpy available on this platform
|
||||
pass
|
||||
session.run("pytest", *session.posargs)
|
||||
ignored_paths = []
|
||||
if sys.version_info < (3, 10):
|
||||
# Match syntax is only available in Python >= 3.10
|
||||
ignored_paths.append("tests/test_enums_match.py")
|
||||
ignore_args = [f"--ignore={path}" for path in ignored_paths]
|
||||
session.run("pytest", *ignore_args, *session.posargs)
|
||||
|
||||
|
||||
@nox.session
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
use pyo3::{pyclass, pyfunction, pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||
|
||||
#[pymodule]
|
||||
pub fn enums(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<SimpleEnum>()?;
|
||||
m.add_class::<ComplexEnum>()?;
|
||||
m.add_wrapped(wrap_pyfunction!(do_simple_stuff))?;
|
||||
m.add_wrapped(wrap_pyfunction!(do_complex_stuff))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub enum SimpleEnum {
|
||||
Sunday,
|
||||
Monday,
|
||||
Tuesday,
|
||||
Wednesday,
|
||||
Thursday,
|
||||
Friday,
|
||||
Saturday,
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {
|
||||
match thing {
|
||||
SimpleEnum::Sunday => SimpleEnum::Monday,
|
||||
SimpleEnum::Monday => SimpleEnum::Tuesday,
|
||||
SimpleEnum::Tuesday => SimpleEnum::Wednesday,
|
||||
SimpleEnum::Wednesday => SimpleEnum::Thursday,
|
||||
SimpleEnum::Thursday => SimpleEnum::Friday,
|
||||
SimpleEnum::Friday => SimpleEnum::Saturday,
|
||||
SimpleEnum::Saturday => SimpleEnum::Sunday,
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub enum ComplexEnum {
|
||||
Int { i: i32 },
|
||||
Float { f: f64 },
|
||||
Str { s: String },
|
||||
EmptyStruct {},
|
||||
MultiFieldStruct { a: i32, b: f64, c: bool },
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
|
||||
match thing {
|
||||
ComplexEnum::Int { i } => ComplexEnum::Str { s: i.to_string() },
|
||||
ComplexEnum::Float { f } => ComplexEnum::Float { f: f * f },
|
||||
ComplexEnum::Str { s } => ComplexEnum::Int { i: s.len() as i32 },
|
||||
ComplexEnum::EmptyStruct {} => ComplexEnum::EmptyStruct {},
|
||||
ComplexEnum::MultiFieldStruct { a, b, c } => ComplexEnum::MultiFieldStruct {
|
||||
a: *a,
|
||||
b: *b,
|
||||
c: *c,
|
||||
},
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ pub mod buf_and_str;
|
|||
pub mod comparisons;
|
||||
pub mod datetime;
|
||||
pub mod dict_iter;
|
||||
pub mod enums;
|
||||
pub mod misc;
|
||||
pub mod objstore;
|
||||
pub mod othermod;
|
||||
|
@ -25,6 +26,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||
#[cfg(not(Py_LIMITED_API))]
|
||||
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
|
||||
m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?;
|
||||
m.add_wrapped(wrap_pymodule!(enums::enums))?;
|
||||
m.add_wrapped(wrap_pymodule!(misc::misc))?;
|
||||
m.add_wrapped(wrap_pymodule!(objstore::objstore))?;
|
||||
m.add_wrapped(wrap_pymodule!(othermod::othermod))?;
|
||||
|
@ -44,6 +46,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
|
||||
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
|
||||
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
|
||||
sys_modules.set_item("pyo3_pytests.enums", m.getattr("enums")?)?;
|
||||
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
|
||||
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
|
||||
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
import pytest
|
||||
from pyo3_pytests import enums
|
||||
|
||||
|
||||
def test_complex_enum_variant_constructors():
|
||||
int_variant = enums.ComplexEnum.Int(42)
|
||||
assert isinstance(int_variant, enums.ComplexEnum.Int)
|
||||
|
||||
float_variant = enums.ComplexEnum.Float(3.14)
|
||||
assert isinstance(float_variant, enums.ComplexEnum.Float)
|
||||
|
||||
str_variant = enums.ComplexEnum.Str("hello")
|
||||
assert isinstance(str_variant, enums.ComplexEnum.Str)
|
||||
|
||||
empty_struct_variant = enums.ComplexEnum.EmptyStruct()
|
||||
assert isinstance(empty_struct_variant, enums.ComplexEnum.EmptyStruct)
|
||||
|
||||
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
|
||||
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
[
|
||||
enums.ComplexEnum.Int(42),
|
||||
enums.ComplexEnum.Float(3.14),
|
||||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
|
||||
assert isinstance(variant, enums.ComplexEnum)
|
||||
|
||||
|
||||
def test_complex_enum_field_getters():
|
||||
int_variant = enums.ComplexEnum.Int(42)
|
||||
assert int_variant.i == 42
|
||||
|
||||
float_variant = enums.ComplexEnum.Float(3.14)
|
||||
assert float_variant.f == 3.14
|
||||
|
||||
str_variant = enums.ComplexEnum.Str("hello")
|
||||
assert str_variant.s == "hello"
|
||||
|
||||
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
|
||||
assert multi_field_struct_variant.a == 42
|
||||
assert multi_field_struct_variant.b == 3.14
|
||||
assert multi_field_struct_variant.c is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
[
|
||||
enums.ComplexEnum.Int(42),
|
||||
enums.ComplexEnum.Float(3.14),
|
||||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
|
||||
if isinstance(variant, enums.ComplexEnum.Int):
|
||||
x = variant.i
|
||||
assert x == 42
|
||||
elif isinstance(variant, enums.ComplexEnum.Float):
|
||||
x = variant.f
|
||||
assert x == 3.14
|
||||
elif isinstance(variant, enums.ComplexEnum.Str):
|
||||
x = variant.s
|
||||
assert x == "hello"
|
||||
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
|
||||
assert True
|
||||
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
|
||||
x = variant.a
|
||||
y = variant.b
|
||||
z = variant.c
|
||||
assert x == 42
|
||||
assert y == 3.14
|
||||
assert z is True
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
[
|
||||
enums.ComplexEnum.Int(42),
|
||||
enums.ComplexEnum.Float(3.14),
|
||||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
|
||||
variant = enums.do_complex_stuff(variant)
|
||||
if isinstance(variant, enums.ComplexEnum.Int):
|
||||
x = variant.i
|
||||
assert x == 5
|
||||
elif isinstance(variant, enums.ComplexEnum.Float):
|
||||
x = variant.f
|
||||
assert x == 9.8596
|
||||
elif isinstance(variant, enums.ComplexEnum.Str):
|
||||
x = variant.s
|
||||
assert x == "42"
|
||||
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
|
||||
assert True
|
||||
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
|
||||
x = variant.a
|
||||
y = variant.b
|
||||
z = variant.c
|
||||
assert x == 42
|
||||
assert y == 3.14
|
||||
assert z is True
|
||||
else:
|
||||
assert False
|
|
@ -0,0 +1,59 @@
|
|||
# This file is only collected when Python >= 3.10, because it tests match syntax.
|
||||
import pytest
|
||||
from pyo3_pytests import enums
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
[
|
||||
enums.ComplexEnum.Int(42),
|
||||
enums.ComplexEnum.Float(3.14),
|
||||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_match_statement(variant: enums.ComplexEnum):
|
||||
match variant:
|
||||
case enums.ComplexEnum.Int(i=x):
|
||||
assert x == 42
|
||||
case enums.ComplexEnum.Float(f=x):
|
||||
assert x == 3.14
|
||||
case enums.ComplexEnum.Str(s=x):
|
||||
assert x == "hello"
|
||||
case enums.ComplexEnum.EmptyStruct():
|
||||
assert True
|
||||
case enums.ComplexEnum.MultiFieldStruct(a=x, b=y, c=z):
|
||||
assert x == 42
|
||||
assert y == 3.14
|
||||
assert z is True
|
||||
case _:
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
[
|
||||
enums.ComplexEnum.Int(42),
|
||||
enums.ComplexEnum.Float(3.14),
|
||||
enums.ComplexEnum.Str("hello"),
|
||||
enums.ComplexEnum.EmptyStruct(),
|
||||
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
|
||||
],
|
||||
)
|
||||
def test_complex_enum_pyfunction_in_out(variant: enums.ComplexEnum):
|
||||
match enums.do_complex_stuff(variant):
|
||||
case enums.ComplexEnum.Int(i=x):
|
||||
assert x == 5
|
||||
case enums.ComplexEnum.Float(f=x):
|
||||
assert x == 9.8596
|
||||
case enums.ComplexEnum.Str(s=x):
|
||||
assert x == "42"
|
||||
case enums.ComplexEnum.EmptyStruct():
|
||||
assert True
|
||||
case enums.ComplexEnum.MultiFieldStruct(a=x, b=y, c=z):
|
||||
assert x == 42
|
||||
assert y == 3.14
|
||||
assert z is True
|
||||
case _:
|
||||
assert False
|
|
@ -14,6 +14,7 @@ fn test_compile_errors() {
|
|||
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
|
||||
t.compile_fail("tests/ui/invalid_pymethods_buffer.rs");
|
||||
t.compile_fail("tests/ui/invalid_pymethods_duplicates.rs");
|
||||
t.compile_fail("tests/ui/invalid_pymethod_enum.rs");
|
||||
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
|
||||
t.compile_fail("tests/ui/invalid_pymodule_args.rs");
|
||||
t.compile_fail("tests/ui/reject_generics.rs");
|
||||
|
|
|
@ -15,4 +15,16 @@ enum NotDrivedClass {
|
|||
#[pyclass]
|
||||
enum NoEmptyEnum {}
|
||||
|
||||
#[pyclass]
|
||||
enum NoUnitVariants {
|
||||
StructVariant { field: i32 },
|
||||
UnitVariant,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
enum NoTupleVariants {
|
||||
StructVariant { field: i32 },
|
||||
TupleVariant(i32),
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
|
|
@ -15,3 +15,19 @@ error: #[pyclass] can't be used on enums without any variants
|
|||
|
|
||||
16 | enum NoEmptyEnum {}
|
||||
| ^^
|
||||
|
||||
error: Unit variant `UnitVariant` is not yet supported in a complex enum
|
||||
= help: change to a struct variant with no fields: `UnitVariant { }`
|
||||
= note: the enum is complex because of non-unit variant `StructVariant`
|
||||
--> tests/ui/invalid_pyclass_enum.rs:21:5
|
||||
|
|
||||
21 | UnitVariant,
|
||||
| ^^^^^^^^^^^
|
||||
|
||||
error: Tuple variant `TupleVariant` is not yet supported in a complex enum
|
||||
= help: change to a struct variant with named fields: `TupleVariant { /* fields */ }`
|
||||
= note: the enum is complex because of non-unit variant `StructVariant`
|
||||
--> tests/ui/invalid_pyclass_enum.rs:27:5
|
||||
|
|
||||
27 | TupleVariant(i32),
|
||||
| ^^^^^^^^^^^^
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
use pyo3::prelude::*;
|
||||
|
||||
#[pyclass]
|
||||
enum ComplexEnum {
|
||||
Int { int: i32 },
|
||||
Str { string: String },
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl ComplexEnum {
|
||||
fn mutate_in_place(&mut self) {
|
||||
*self = match self {
|
||||
ComplexEnum::Int { int } => ComplexEnum::Str { string: int.to_string() },
|
||||
ComplexEnum::Str { string } => ComplexEnum::Int { int: string.len() as i32 },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,11 @@
|
|||
error[E0271]: type mismatch resolving `<ComplexEnum as PyClass>::Frozen == False`
|
||||
--> tests/ui/invalid_pymethod_enum.rs:11:24
|
||||
|
|
||||
11 | fn mutate_in_place(&mut self) {
|
||||
| ^ expected `False`, found `True`
|
||||
|
|
||||
note: required by a bound in `extract_pyclass_ref_mut`
|
||||
--> src/impl_/extract_argument.rs
|
||||
|
|
||||
| pub fn extract_pyclass_ref_mut<'a, 'py: 'a, T: PyClass<Frozen = False>>(
|
||||
| ^^^^^^^^^^^^^^ required by this bound in `extract_pyclass_ref_mut`
|
Loading…
Reference in New Issue