Merge pull request #2014 from b05902132/default_impl

Support default method implementation
This commit is contained in:
David Hewitt 2021-11-29 23:21:11 +00:00 committed by GitHub
commit 8a03778ca3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 122 additions and 1 deletions

View File

@ -3,7 +3,7 @@
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType};
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
use crate::utils::{self, unwrap_group, PythonDoc};
use proc_macro2::{Span, TokenStream};
@ -425,6 +425,27 @@ fn impl_enum_class(
.impl_all();
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));
let default_repr_impl = {
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!("{}.{}", cls, variant_name);
quote! { #cls::#variant_name => #repr, }
});
quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
#[pyo3(name = "__repr__")]
fn __pyo3__repr__(&self) -> &'static str {
match self {
#(#variants_repr)*
_ => unreachable!("Unsupported variant type."),
}
}
}
};
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
Ok(quote! {
#pytypeinfo
@ -433,6 +454,8 @@ fn impl_enum_class(
#descriptors
#default_impls
})
}
@ -758,6 +781,9 @@ impl<'a> PyClassImplsBuilder<'a> {
// Implementation which uses dtolnay specialization to load all slots.
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
// This depends on Python implementation detail;
// an old slot entry will be overriden by newer ones.
visitor(collector.py_class_default_slots());
visitor(collector.object_protocol_slots());
visitor(collector.number_protocol_slots());
visitor(collector.iter_protocol_slots());

View File

@ -139,6 +139,47 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
}
}
pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec<TokenStream>) -> TokenStream {
// This function uses a lot of `unwrap()`; since method_defs are provided by us, they should
// all succeed.
let ty: syn::Type = syn::parse_quote!(#cls);
let mut method_defs: Vec<_> = method_defs
.into_iter()
.map(|token| syn::parse2::<syn::ImplItemMethod>(token).unwrap())
.collect();
let mut proto_impls = Vec::new();
for meth in &mut method_defs {
let options = PyFunctionOptions::from_attrs(&mut meth.attrs).unwrap();
match pymethod::gen_py_method(&ty, &mut meth.sig, &mut meth.attrs, options).unwrap() {
GeneratedPyMethod::Proto(token_stream) => {
let attrs = get_cfg_attributes(&meth.attrs);
proto_impls.push(quote!(#(#attrs)* #token_stream))
}
GeneratedPyMethod::SlotTraitImpl(..) => {
panic!("SlotFragment methods cannot have default implementation!")
}
GeneratedPyMethod::Method(_) | GeneratedPyMethod::TraitImpl(_) => {
panic!("Only protocol methods can have default implementation!")
}
}
}
quote! {
impl #cls {
#(#method_defs)*
}
impl ::pyo3::class::impl_::PyClassDefaultSlots<#cls>
for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
&[#(#proto_impls),*]
}
}
}
}
fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
quote! {
impl ::pyo3::class::impl_::PyMethods<#ty>

View File

@ -657,6 +657,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots);
slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots);
slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots);
// slots that PyO3 implements by default, but can be overidden by the users.
slots_trait!(PyClassDefaultSlots, py_class_default_slots);
// Protocol slots from #[pymethods] if not using inventory.
#[cfg(not(feature = "multiple-pymethods"))]
slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots);

View File

@ -0,0 +1,41 @@
use pyo3::prelude::*;
mod common;
// Test default generated __repr__.
#[pyclass]
enum TestDefaultRepr {
Var,
}
#[test]
fn test_default_slot_exists() {
Python::with_gil(|py| {
let test_object = Py::new(py, TestDefaultRepr::Var).unwrap();
py_assert!(
py,
test_object,
"repr(test_object) == 'TestDefaultRepr.Var'"
);
})
}
#[pyclass]
enum OverrideSlot {
Var,
}
#[pymethods]
impl OverrideSlot {
fn __repr__(&self) -> &str {
"overriden"
}
}
#[test]
fn test_override_slot() {
Python::with_gil(|py| {
let test_object = Py::new(py, OverrideSlot::Var).unwrap();
py_assert!(py, test_object, "repr(test_object) == 'overriden'");
})
}

View File

@ -51,3 +51,13 @@ fn test_enum_arg() {
py_run!(py, f mynum, "f(mynum.Variant)")
}
#[test]
fn test_default_repr_correct() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
let var2 = Py::new(py, MyEnum::OtherVariant).unwrap();
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
})
}