From f5f2e84f4b874c3050e8cb2adecdbd9b692f4479 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Tue, 28 Jul 2020 19:36:21 +0900 Subject: [PATCH] Enable &Self in #[pymethods] again --- CHANGELOG.md | 1 + pyo3-derive-backend/src/lib.rs | 2 +- pyo3-derive-backend/src/module.rs | 2 +- pyo3-derive-backend/src/pymethod.rs | 47 +++++++++++++++++++++++------ tests/test_methods.rs | 12 +++++--- 5 files changed, 48 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 97d0332c..a182341c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Fix segfault with #[pyclass(dict, unsendable)] [#1058](https://github.com/PyO3/pyo3/pull/1058) - Don't rely on the order of structmembers to compute offsets in PyCell. Related to [#1058](https://github.com/PyO3/pyo3/pull/1058). [#1059](https://github.com/PyO3/pyo3/pull/1059) +- Allows `&Self` as a `#[pymethods]` argument again. [#1071](https://github.com/PyO3/pyo3/pull/1071) ## [0.11.1] - 2020-06-30 ### Added diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index cd1b4c3b..d5e4e940 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -18,6 +18,6 @@ mod utils; pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use pyclass::{build_py_class, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionAttr}; -pub use pyimpl::{build_py_methods, impl_methods}; +pub use pyimpl::build_py_methods; pub use pyproto::build_py_proto; pub use utils::get_doc; diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 5720f9be..218ef3ca 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -211,7 +211,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream { #name(#(#names),*) }; - let body = pymethod::impl_arg_params(spec, cb); + let body = pymethod::impl_arg_params(spec, None, cb); quote! { unsafe extern "C" fn __wrap( diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index dd4f845b..fbda2efb 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -112,7 +112,7 @@ fn impl_wrap_common( } } } else { - let body = impl_arg_params(&spec, body); + let body = impl_arg_params(&spec, Some(cls), body); quote! { unsafe extern "C" fn __wrap( @@ -138,7 +138,7 @@ fn impl_wrap_common( pub fn impl_proto_wrap(cls: &syn::Type, spec: &FnSpec<'_>, self_ty: &SelfType) -> TokenStream { let python_name = &spec.python_name; let cb = impl_call(cls, &spec); - let body = impl_arg_params(&spec, cb); + let body = impl_arg_params(&spec, Some(cls), cb); let slf = self_ty.receiver(cls); quote! { @@ -166,7 +166,7 @@ pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { let python_name = &spec.python_name; let names: Vec = get_arg_names(&spec); let cb = quote! { #cls::#name(#(#names),*) }; - let body = impl_arg_params(spec, cb); + let body = impl_arg_params(spec, Some(cls), cb); quote! { #[allow(unused_mut)] @@ -198,7 +198,7 @@ pub fn impl_wrap_class(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { let names: Vec = get_arg_names(&spec); let cb = quote! { #cls::#name(&_cls, #(#names),*) }; - let body = impl_arg_params(spec, cb); + let body = impl_arg_params(spec, Some(cls), cb); quote! { #[allow(unused_mut)] @@ -226,7 +226,7 @@ pub fn impl_wrap_static(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { let names: Vec = get_arg_names(&spec); let cb = quote! { #cls::#name(#(#names),*) }; - let body = impl_arg_params(spec, cb); + let body = impl_arg_params(spec, Some(cls), cb); quote! { #[allow(unused_mut)] @@ -383,7 +383,11 @@ fn impl_call(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { quote! { #cls::#fname(_slf, #(#names),*) } } -pub fn impl_arg_params(spec: &FnSpec<'_>, body: TokenStream) -> TokenStream { +pub fn impl_arg_params( + spec: &FnSpec<'_>, + self_: Option<&syn::Type>, + body: TokenStream, +) -> TokenStream { if spec.args.is_empty() { return quote! { #body @@ -412,7 +416,7 @@ pub fn impl_arg_params(spec: &FnSpec<'_>, body: TokenStream) -> TokenStream { let mut param_conversion = Vec::new(); let mut option_pos = 0; for (idx, arg) in spec.args.iter().enumerate() { - param_conversion.push(impl_arg_param(&arg, &spec, idx, &mut option_pos)); + param_conversion.push(impl_arg_param(&arg, &spec, idx, self_, &mut option_pos)); } let (mut accept_args, mut accept_kwargs) = (false, false); @@ -458,6 +462,7 @@ fn impl_arg_param( arg: &FnArg<'_>, spec: &FnSpec<'_>, idx: usize, + self_: Option<&syn::Type>, option_pos: &mut usize, ) -> TokenStream { let arg_name = syn::Ident::new(&format!("arg{}", idx), Span::call_site()); @@ -491,7 +496,7 @@ fn impl_arg_param( quote! { None } }; if let syn::Type::Reference(tref) = ty { - let (tref, mut_) = tref_preprocess(tref); + let (tref, mut_) = preprocess_tref(tref, self_); // To support Rustc 1.39.0, we don't use as_deref here... let tmp_as_deref = if mut_.is_some() { quote! { _tmp.as_mut().map(std::ops::DerefMut::deref_mut) } @@ -524,7 +529,7 @@ fn impl_arg_param( }; } } else if let syn::Type::Reference(tref) = arg.ty { - let (tref, mut_) = tref_preprocess(tref); + let (tref, mut_) = preprocess_tref(tref, self_); // Get &T from PyRef quote! { let #mut_ _tmp: <#tref as pyo3::derive_utils::ExtractExt>::Target @@ -537,12 +542,34 @@ fn impl_arg_param( } }; - fn tref_preprocess(tref: &syn::TypeReference) -> (syn::TypeReference, Option) { + /// Replace `Self`, remove lifetime and get mutability from the type + fn preprocess_tref( + tref: &syn::TypeReference, + self_: Option<&syn::Type>, + ) -> (syn::TypeReference, Option) { let mut tref = tref.to_owned(); + if let Some(syn::Type::Path(tpath)) = self_ { + replace_self(&mut tref, &tpath.path); + } tref.lifetime = None; let mut_ = tref.mutability; (tref, mut_) } + + /// Replace `Self` with the exact type name since it is used out of the impl block + fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) { + match &mut *tref.elem { + syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path), + syn::Type::Path(ref mut tpath) => { + if let Some(ident) = tpath.path.get_ident() { + if ident == "Self" { + tpath.path = self_path.to_owned(); + } + } + } + _ => {} + } + } } pub fn impl_py_method_def(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream { diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 20292fda..40db57ae 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -16,6 +16,11 @@ impl InstanceMethod { fn method(&self) -> PyResult { Ok(self.member) } + + // Checks that &Self works + fn add_other(&self, other: &Self) -> i32 { + self.member + other.member + } } #[test] @@ -26,10 +31,9 @@ fn instance_method() { let obj = PyCell::new(py, InstanceMethod { member: 42 }).unwrap(); let obj_ref = obj.borrow(); assert_eq!(obj_ref.method().unwrap(), 42); - let d = [("obj", obj)].into_py_dict(py); - py.run("assert obj.method() == 42", None, Some(d)).unwrap(); - py.run("assert obj.method.__doc__ == 'Test method'", None, Some(d)) - .unwrap(); + py_assert!(py, obj, "obj.method() == 42"); + py_assert!(py, obj, "obj.add_other(obj) == 84"); + py_assert!(py, obj, "obj.method.__doc__ == 'Test method'"); } #[pyclass]