macros: Support #[pyo3(name)] on enum variants

This commit is contained in:
Gabriel Smith 2022-06-16 15:18:55 -04:00
parent 2122faa547
commit 75656949f9
3 changed files with 83 additions and 12 deletions

View File

@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383) - Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383)
- Add `CompareOp::matches` to easily implement `__richcmp__` as the result of a - Add `CompareOp::matches` to easily implement `__richcmp__` as the result of a
Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460) Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460)
- Supprt `#[pyo3(name)]` on enum variants [#2457](https://github.com/PyO3/pyo3/pull/2457)
### Changed ### Changed

View File

@ -330,7 +330,17 @@ fn impl_class(
struct PyClassEnumVariant<'a> { struct PyClassEnumVariant<'a> {
ident: &'a syn::Ident, ident: &'a syn::Ident,
/* currently have no more options */ options: EnumVariantPyO3Options,
}
impl<'a> PyClassEnumVariant<'a> {
fn python_name(&self) -> Cow<'_, syn::Ident> {
self.options
.name
.as_ref()
.map(|name_attr| Cow::Borrowed(&name_attr.value.0))
.unwrap_or_else(|| Cow::Owned(self.ident.unraw()))
}
} }
struct PyClassEnum<'a> { struct PyClassEnum<'a> {
@ -342,7 +352,7 @@ struct PyClassEnum<'a> {
} }
impl<'a> PyClassEnum<'a> { impl<'a> PyClassEnum<'a> {
fn new(enum_: &'a syn::ItemEnum) -> syn::Result<Self> { fn new(enum_: &'a mut syn::ItemEnum) -> syn::Result<Self> {
fn is_numeric_type(t: &syn::Ident) -> bool { fn is_numeric_type(t: &syn::Ident) -> bool {
[ [
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize", "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
@ -370,7 +380,7 @@ impl<'a> PyClassEnum<'a> {
let variants = enum_ let variants = enum_
.variants .variants
.iter() .iter_mut()
.map(extract_variant_data) .map(extract_variant_data)
.collect::<syn::Result<_>>()?; .collect::<syn::Result<_>>()?;
Ok(Self { Ok(Self {
@ -407,6 +417,46 @@ pub fn build_py_enum(
Ok(impl_enum(enum_, &args, doc, method_type)) Ok(impl_enum(enum_, &args, doc, method_type))
} }
/// `#[pyo3()]` options for pyclass enum variants
struct EnumVariantPyO3Options {
name: Option<NameAttribute>,
}
enum EnumVariantPyO3Option {
Name(NameAttribute),
}
impl Parse for EnumVariantPyO3Option {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
input.parse().map(EnumVariantPyO3Option::Name)
} else {
Err(lookahead.error())
}
}
}
impl EnumVariantPyO3Options {
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
let mut options = EnumVariantPyO3Options { name: None };
for option in take_pyo3_options(attrs)? {
match option {
EnumVariantPyO3Option::Name(name) => {
ensure_spanned!(
options.name.is_none(),
name.span() => "`name` may only be specified once"
);
options.name = Some(name);
}
}
}
Ok(options)
}
}
fn impl_enum( fn impl_enum(
enum_: PyClassEnum<'_>, enum_: PyClassEnum<'_>,
args: &PyClassArgs, args: &PyClassArgs,
@ -433,7 +483,11 @@ fn impl_enum_class(
let variants_repr = variants.iter().map(|variant| { let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident; let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support. // Assuming all variants are unit variants because they are the only type we support.
let repr = format!("{}.{}", get_class_python_name(cls, args), variant_name); let repr = format!(
"{}.{}",
get_class_python_name(cls, args),
variant.python_name(),
);
quote! { #cls::#variant_name => #repr, } quote! { #cls::#variant_name => #repr, }
}); });
let mut repr_impl: syn::ImplItemMethod = syn::parse_quote! { let mut repr_impl: syn::ImplItemMethod = syn::parse_quote! {
@ -511,7 +565,7 @@ fn impl_enum_class(
cls, cls,
args, args,
methods_type, methods_type,
enum_default_methods(cls, variants.iter().map(|v| v.ident)), enum_default_methods(cls, variants.iter().map(|v| (v.ident, v.python_name()))),
default_slots, default_slots,
) )
.doc(doc) .doc(doc)
@ -557,33 +611,34 @@ fn generate_default_protocol_slot(
fn enum_default_methods<'a>( fn enum_default_methods<'a>(
cls: &'a syn::Ident, cls: &'a syn::Ident,
unit_variant_names: impl IntoIterator<Item = &'a syn::Ident>, unit_variant_names: impl IntoIterator<Item = (&'a syn::Ident, Cow<'a, syn::Ident>)>,
) -> Vec<MethodAndMethodDef> { ) -> Vec<MethodAndMethodDef> {
let cls_type = syn::parse_quote!(#cls); let cls_type = syn::parse_quote!(#cls);
let variant_to_attribute = |ident: &syn::Ident| ConstSpec { let variant_to_attribute = |var_ident: &syn::Ident, py_ident: &syn::Ident| ConstSpec {
rust_ident: ident.clone(), rust_ident: var_ident.clone(),
attributes: ConstAttributes { attributes: ConstAttributes {
is_class_attr: true, is_class_attr: true,
name: Some(NameAttribute { name: Some(NameAttribute {
kw: syn::parse_quote! { name }, kw: syn::parse_quote! { name },
value: NameLitStr(ident.clone()), value: NameLitStr(py_ident.clone()),
}), }),
deprecations: Default::default(), deprecations: Default::default(),
}, },
}; };
unit_variant_names unit_variant_names
.into_iter() .into_iter()
.map(|var| gen_py_const(&cls_type, &variant_to_attribute(var))) .map(|(var, py_name)| gen_py_const(&cls_type, &variant_to_attribute(var, &py_name)))
.collect() .collect()
} }
fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVariant<'_>> { fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result<PyClassEnumVariant<'_>> {
use syn::Fields; use syn::Fields;
let ident = match variant.fields { let ident = match variant.fields {
Fields::Unit => &variant.ident, Fields::Unit => &variant.ident,
_ => bail_spanned!(variant.span() => "Currently only support unit variants."), _ => bail_spanned!(variant.span() => "Currently only support unit variants."),
}; };
Ok(PyClassEnumVariant { ident }) let options = EnumVariantPyO3Options::take_pyo3_options(&mut variant.attrs)?;
Ok(PyClassEnumVariant { ident, options })
} }
fn descriptors_to_items( fn descriptors_to_items(

View File

@ -175,3 +175,18 @@ fn test_rename_enum_repr_correct() {
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'"); py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
}) })
} }
#[pyclass]
#[derive(Debug, PartialEq, Clone)]
pub enum RenameVariantEnum {
#[pyo3(name = "VARIANT")]
Variant,
}
#[test]
fn test_rename_variant_repr_correct() {
Python::with_gil(|py| {
let var1 = Py::new(py, RenameVariantEnum::Variant).unwrap();
py_assert!(py, var1, "repr(var1) == 'RenameVariantEnum.VARIANT'");
})
}