diff --git a/CHANGELOG.md b/CHANGELOG.md index cc44503d..62a1e0f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Don't rely on the order of structmembers to compute offsets in PyCell. Related to [#1058](https://github.com/PyO3/pyo3/pull/1058). [#1059](https://github.com/PyO3/pyo3/pull/1059) - Allows `&Self` as a `#[pymethods]` argument again. [#1071](https://github.com/PyO3/pyo3/pull/1071) +- Improve lifetime insertion in `#[pyproto]`. [#1093](https://github.com/PyO3/pyo3/pull/1093) ## [0.11.1] - 2020-06-30 ### Added diff --git a/guide/src/class.md b/guide/src/class.md index b9f7947b..bcd21741 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -923,8 +923,8 @@ struct MyIterator { #[pyproto] impl PyIterProtocol for MyIterator { - fn __iter__(slf: PyRef) -> Py { - slf.into() + fn __iter__(slf: PyRef) -> PyRef { + slf } fn __next__(mut slf: PyRefMut) -> Option { slf.iter.next() @@ -948,8 +948,8 @@ struct Iter { #[pyproto] impl PyIterProtocol for Iter { - fn __iter__(slf: PyRefMut) -> Py { - slf.into() + fn __iter__(slf: PyRef) -> PyRef { + slf } fn __next__(mut slf: PyRefMut) -> Option { @@ -964,7 +964,7 @@ struct Container { #[pyproto] impl PyIterProtocol for Container { - fn __iter__(slf: PyRefMut) -> PyResult> { + fn __iter__(slf: PyRef) -> PyResult> { let iter = Iter { inner: slf.iter.clone().into_iter(), }; diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index 13b15cd9..6004a368 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -1,5 +1,5 @@ // Copyright (c) 2017-present PyO3 Project and Contributors -use crate::func::MethodProto; +use crate::proto_method::MethodProto; /// Predicates for `#[pyproto]`. pub struct Proto { diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index d5e4e940..5695adf4 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -4,10 +4,10 @@ #![recursion_limit = "1024"] mod defs; -mod func; mod konst; mod method; mod module; +mod proto_method; mod pyclass; mod pyfunction; mod pyimpl; diff --git a/pyo3-derive-backend/src/func.rs b/pyo3-derive-backend/src/proto_method.rs similarity index 84% rename from pyo3-derive-backend/src/func.rs rename to pyo3-derive-backend/src/proto_method.rs index 0a88e4dd..a4c49071 100644 --- a/pyo3-derive-backend/src/func.rs +++ b/pyo3-derive-backend/src/proto_method.rs @@ -6,7 +6,6 @@ use syn::Token; // TODO: // Add lifetime support for args with Rptr - #[derive(Debug)] pub enum MethodProto { Free { @@ -77,7 +76,11 @@ pub(crate) fn impl_method_proto( ) -> TokenStream { let ret_ty = match &sig.output { syn::ReturnType::Default => quote! { () }, - syn::ReturnType::Type(_, ty) => ty.to_token_stream(), + syn::ReturnType::Type(_, ty) => { + let mut ty = ty.clone(); + insert_lifetime(&mut ty); + ty.to_token_stream() + } }; match *meth { @@ -106,22 +109,7 @@ pub(crate) fn impl_method_proto( let p: syn::Path = syn::parse_str(proto).unwrap(); let slf_name = syn::Ident::new(arg, Span::call_site()); - let mut slf_ty = get_arg_ty(sig, 0); - - // update the type if no lifetime was given: - // PyRef --> PyRef<'p, Self> - if let syn::Type::Path(ref mut path) = slf_ty { - if let syn::PathArguments::AngleBracketed(ref mut args) = - path.path.segments[0].arguments - { - if let syn::GenericArgument::Lifetime(_) = args.args[0] { - } else { - let lt = syn::parse_quote! {'p}; - args.args.insert(0, lt); - } - } - } - + let slf_ty = get_arg_ty(sig, 0); let tmp: syn::ItemFn = syn::parse_quote! { fn test(&self) -> <#cls as #p<'p>>::Result {} }; @@ -336,40 +324,64 @@ pub(crate) fn impl_method_proto( } } -// TODO: better arg ty detection +/// Some hacks for arguments: get `T` from `Option` and insert lifetime fn get_arg_ty(sig: &syn::Signature, idx: usize) -> syn::Type { - let mut ty = match sig.inputs[idx] { - syn::FnArg::Typed(ref cap) => { - match *cap.ty { - syn::Type::Path(ref ty) => { - // use only last path segment for Option<> - let seg = ty.path.segments.last().unwrap().clone(); - if seg.ident == "Option" { - if let syn::PathArguments::AngleBracketed(ref data) = seg.arguments { - if let Some(pair) = data.args.last() { - match pair { - syn::GenericArgument::Type(ref ty) => return ty.clone(), - _ => panic!("Option only accepted for concrete types"), - } - }; - } - } - *cap.ty.clone() + fn get_option_ty(path: &syn::Path) -> Option { + let seg = path.segments.last()?; + if seg.ident == "Option" { + if let syn::PathArguments::AngleBracketed(ref data) = seg.arguments { + if let Some(syn::GenericArgument::Type(ref ty)) = data.args.last() { + return Some(ty.to_owned()); } - _ => *cap.ty.clone(), } } - _ => panic!("fn arg type is not supported"), - }; - - // Add a lifetime if there is none - if let syn::Type::Reference(ref mut r) = ty { - r.lifetime.get_or_insert(syn::parse_quote! {'p}); + None } + let mut ty = match &sig.inputs[idx] { + syn::FnArg::Typed(ref cap) => match &*cap.ty { + // For `Option`, we use `T` as an associated type for the protocol. + syn::Type::Path(ref ty) => get_option_ty(&ty.path).unwrap_or_else(|| *cap.ty.clone()), + _ => *cap.ty.clone(), + }, + ty => panic!("Unsupported argument type: {:?}", ty), + }; + insert_lifetime(&mut ty); ty } +/// Insert lifetime `'p` to `PyRef` or references (e.g., `&PyType`). +fn insert_lifetime(ty: &mut syn::Type) { + fn insert_lifetime_for_path(path: &mut syn::TypePath) { + if let Some(seg) = path.path.segments.last_mut() { + if let syn::PathArguments::AngleBracketed(ref mut args) = seg.arguments { + let mut has_lifetime = false; + for arg in &mut args.args { + match arg { + // Insert `'p` recursively for `Option>` or so. + syn::GenericArgument::Type(ref mut ty) => insert_lifetime(ty), + syn::GenericArgument::Lifetime(_) => has_lifetime = true, + _ => {} + } + } + // Insert lifetime to PyRef (i.e., PyRef -> PyRef<'p, Self>) + if !has_lifetime && (seg.ident == "PyRef" || seg.ident == "PyRefMut") { + args.args.insert(0, syn::parse_quote! {'p}); + } + } + } + } + + match ty { + syn::Type::Reference(ref mut r) => { + r.lifetime.get_or_insert(syn::parse_quote! {'p}); + insert_lifetime(&mut *r.elem); + } + syn::Type::Path(ref mut path) => insert_lifetime_for_path(path), + _ => {} + } +} + fn extract_decl(spec: syn::Item) -> syn::Signature { match spec { syn::Item::Fn(f) => f.sig, diff --git a/pyo3-derive-backend/src/pyproto.rs b/pyo3-derive-backend/src/pyproto.rs index 308bcbe3..44fc2065 100644 --- a/pyo3-derive-backend/src/pyproto.rs +++ b/pyo3-derive-backend/src/pyproto.rs @@ -1,8 +1,8 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use crate::defs; -use crate::func::impl_method_proto; use crate::method::{FnSpec, FnType}; +use crate::proto_method::impl_method_proto; use crate::pymethod; use proc_macro2::{Span, TokenStream}; use quote::quote; diff --git a/tests/test_dunder.rs b/tests/test_dunder.rs index 34828ca1..5cc56f61 100644 --- a/tests/test_dunder.rs +++ b/tests/test_dunder.rs @@ -51,12 +51,12 @@ struct Iterator { } #[pyproto] -impl<'p> PyIterProtocol for Iterator { - fn __iter__(slf: PyRef<'p, Self>) -> Py { - slf.into() +impl PyIterProtocol for Iterator { + fn __iter__(slf: PyRef) -> PyRef { + slf } - fn __next__(mut slf: PyRefMut<'p, Self>) -> Option { + fn __next__(mut slf: PyRefMut) -> Option { slf.iter.next() } } @@ -81,7 +81,7 @@ fn iterator() { struct StringMethods {} #[pyproto] -impl<'p> PyObjectProtocol<'p> for StringMethods { +impl PyObjectProtocol for StringMethods { fn __str__(&self) -> &'static str { "str" } @@ -236,7 +236,7 @@ struct SetItem { } #[pyproto] -impl PyMappingProtocol<'a> for SetItem { +impl PyMappingProtocol for SetItem { fn __setitem__(&mut self, key: i32, val: i32) { self.key = key; self.val = val; @@ -362,16 +362,16 @@ struct ContextManager { } #[pyproto] -impl<'p> PyContextProtocol<'p> for ContextManager { +impl PyContextProtocol for ContextManager { fn __enter__(&mut self) -> i32 { 42 } fn __exit__( &mut self, - ty: Option<&'p PyType>, - _value: Option<&'p PyAny>, - _traceback: Option<&'p PyAny>, + ty: Option<&PyType>, + _value: Option<&PyAny>, + _traceback: Option<&PyAny>, ) -> bool { let gil = Python::acquire_gil(); self.exit_called = true; @@ -564,14 +564,14 @@ impl OnceFuture { #[pyproto] impl PyAsyncProtocol for OnceFuture { - fn __await__(slf: PyRef<'p, Self>) -> PyRef<'p, Self> { + fn __await__(slf: PyRef) -> PyRef { slf } } #[pyproto] impl PyIterProtocol for OnceFuture { - fn __iter__(slf: PyRef<'p, Self>) -> PyRef<'p, Self> { + fn __iter__(slf: PyRef) -> PyRef { slf } fn __next__(mut slf: PyRefMut) -> Option { @@ -632,14 +632,14 @@ impl DescrCounter { #[pyproto] impl PyDescrProtocol for DescrCounter { fn __get__( - mut slf: PyRefMut<'p, Self>, + mut slf: PyRefMut, _instance: &PyAny, - _owner: Option<&'p PyType>, - ) -> PyRefMut<'p, Self> { + _owner: Option<&PyType>, + ) -> PyRefMut { slf.count += 1; slf } - fn __set__(_slf: PyRef<'p, Self>, _instance: &PyAny, mut new_value: PyRefMut<'p, Self>) { + fn __set__(_slf: PyRef, _instance: &PyAny, mut new_value: PyRefMut) { new_value.count = _slf.count; } }