diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index f95fe820..9efe22d2 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -27,7 +27,7 @@ pub enum FnType { FnCall, FnClass, FnStatic, - PySelf(syn::TypePath), + PySelfNew(syn::TypeReference), } #[derive(Clone, PartialEq, Debug)] @@ -103,13 +103,16 @@ impl<'a> FnSpec<'a> { if fn_type == FnType::Fn && !has_self { if arguments.is_empty() { - panic!("Static method needs #[staticmethod] attribute"); + return Err(syn::Error::new_spanned( + name, + "Static method needs #[staticmethod] attribute", + )); } let tp = match arguments.remove(0).ty { - syn::Type::Path(p) => replace_self(p), - _ => panic!("Invalid type as self"), + syn::Type::Reference(r) => replace_self(r)?, + x => return Err(syn::Error::new_spanned(x, "Invalid type as custom self")), }; - fn_type = FnType::PySelf(tp); + fn_type = FnType::PySelfNew(tp); } Ok(FnSpec { @@ -386,15 +389,19 @@ fn parse_attributes(attrs: &mut Vec) -> syn::Result<(FnType, Vec } } -// Replace A with A<_> -fn replace_self(path: &syn::TypePath) -> syn::TypePath { +// Replace &A with &A<_> +fn replace_self(refn: &syn::TypeReference) -> syn::Result { fn infer(span: proc_macro2::Span) -> syn::GenericArgument { syn::GenericArgument::Type(syn::Type::Infer(syn::TypeInfer { underscore_token: syn::token::Underscore { spans: [span] }, })) } - let mut res = path.to_owned(); - for seg in &mut res.path.segments { + let mut res = refn.to_owned(); + let tp = match &mut *res.elem { + syn::Type::Path(p) => p, + _ => return Err(syn::Error::new_spanned(refn, "unsupported argument")), + }; + for seg in &mut tp.path.segments { if let syn::PathArguments::AngleBracketed(ref mut g) = seg.arguments { let mut args = syn::punctuated::Punctuated::new(); for arg in &g.args { @@ -415,5 +422,6 @@ fn replace_self(path: &syn::TypePath) -> syn::TypePath { g.args = args; } } - res + res.lifetime = None; + Ok(res) } diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 4d01f99c..e47edc04 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -34,7 +34,7 @@ pub fn gen_py_method( }; let text_signature = match &spec.tp { - FnType::Fn | FnType::PySelf(_) | FnType::FnClass | FnType::FnStatic => { + FnType::Fn | FnType::PySelfNew(_) | FnType::FnClass | FnType::FnStatic => { utils::parse_text_signature_attrs(&mut *meth_attrs, name)? } FnType::FnNew => parse_erroneous_text_signature( @@ -59,7 +59,7 @@ pub fn gen_py_method( Ok(match spec.tp { FnType::Fn => impl_py_method_def(name, doc, &spec, &impl_wrap(cls, name, &spec, true)), - FnType::PySelf(ref self_ty) => impl_py_method_def( + FnType::PySelfNew(ref self_ty) => impl_py_method_def( name, doc, &spec, @@ -127,7 +127,7 @@ pub fn impl_wrap_pyslf( cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>, - self_ty: &syn::TypePath, + self_ty: &syn::TypeReference, noargs: bool, ) -> TokenStream { let names = get_arg_names(spec); @@ -221,8 +221,7 @@ pub fn impl_proto_wrap(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> /// Generate class method wrapper (PyCFunction, PyCFunctionWithKeywords) pub fn impl_wrap_new(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> TokenStream { let names: Vec = get_arg_names(&spec); - let cb = quote! { #cls::#name(&_obj, #(#names),*) }; - + let cb = quote! { #cls::#name(#(#names),*) }; let body = impl_arg_params(spec, cb); quote! { @@ -240,11 +239,11 @@ pub fn impl_wrap_new(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) -> T let _args = _py.from_borrowed_ptr::(_args); let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs); - #body + # body - match <<#cls as pyo3::PyTypeInfo>::ConcreteLayout as pyo3::pyclass::PyClassNew>::new(_py, _result) { - Ok(_slf) => _slf as _, - Err(e) => e.restore_and_null(), + match _result.and_then(|slf| pyo3::PyClassShell::new(_py, slf)) { + Ok(slf) => slf as _, + Err(e) => e.restore_and_null(_py), } } }