diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index c974e860..5697df95 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -17,7 +17,6 @@ pub struct FnArg<'a> { pub ty: &'a syn::Type, pub optional: Option<&'a syn::Type>, pub py: bool, - pub reference: bool, } #[derive(Clone, PartialEq, Debug, Copy, Eq)] @@ -214,17 +213,13 @@ impl<'a> FnSpec<'a> { } }; - let py = crate::utils::if_type_is_python(ty); - - let opt = check_ty_optional(ty); arguments.push(FnArg { name: ident, by_ref, mutability, ty, - optional: opt, - py, - reference: is_ref(name, ty)?, + optional: utils::option_type_argument(ty), + py: utils::is_python(ty), }); } } @@ -323,57 +318,6 @@ impl<'a> FnSpec<'a> { } } -pub fn is_ref(name: &syn::Ident, ty: &syn::Type) -> syn::Result { - match ty { - syn::Type::Reference(_) => return Ok(true), - syn::Type::Path(syn::TypePath { ref path, .. }) => { - if let Some(segment) = path.segments.last() { - if "Option" == segment.ident.to_string().as_str() { - match segment.arguments { - syn::PathArguments::AngleBracketed(ref params) => { - if params.args.len() != 1 { - let msg = format!("argument type is not supported by python method: {:?} ({:?}) {:?}", - name, - ty, - path); - syn::Error::new_spanned(segment, msg); - } - let last = ¶ms.args[params.args.len() - 1]; - if let syn::GenericArgument::Type(syn::Type::Reference(_)) = last { - return Ok(true); - } - } - _ => { - let msg = format!( - "argument type is not supported by python method: {:?} ({:?}) {:?}", - name, ty, path - ); - syn::Error::new_spanned(segment, msg); - } - } - } - } - } - _ => (), - } - Ok(false) -} - -pub(crate) fn check_ty_optional(ty: &syn::Type) -> Option<&syn::Type> { - let path = match ty { - syn::Type::Path(syn::TypePath { ref path, .. }) => path, - _ => return None, - }; - let seg = path.segments.last().filter(|s| s.ident == "Option")?; - match seg.arguments { - syn::PathArguments::AngleBracketed(ref params) => match params.args.first() { - Some(syn::GenericArgument::Type(ref ty)) => Some(ty), - _ => None, - }, - _ => None, - } -} - #[derive(Clone, PartialEq, Debug)] struct MethodAttributes { ty: Option, diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 9a51a1e5..bd6e4182 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -59,22 +59,19 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { } /// Transforms a rust fn arg parsed with syn into a method::FnArg -fn wrap_fn_argument<'a>(cap: &'a syn::PatType, name: &'a Ident) -> syn::Result> { +fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result> { let (mutability, by_ref, ident) = match *cap.pat { syn::Pat::Ident(ref patid) => (&patid.mutability, &patid.by_ref, &patid.ident), _ => return Err(syn::Error::new_spanned(&cap.pat, "Unsupported argument")), }; - let py = crate::utils::if_type_is_python(&cap.ty); - let opt = method::check_ty_optional(&cap.ty); Ok(method::FnArg { name: ident, mutability, by_ref, ty: &cap.ty, - optional: opt, - py, - reference: method::is_ref(&name, &cap.ty)?, + optional: utils::option_type_argument(&cap.ty), + py: utils::is_python(&cap.ty), }) } @@ -164,7 +161,7 @@ pub fn add_fn_to_module( )) } syn::FnArg::Typed(ref cap) => { - arguments.push(wrap_fn_argument(cap, &func.sig.ident)?); + arguments.push(wrap_fn_argument(cap)?); } } } diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index fbda2efb..62ced447 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -488,58 +488,48 @@ fn impl_arg_param( let arg_value = quote!(output[#option_pos]); *option_pos += 1; - return if let Some(ty) = arg.optional.as_ref() { - let default = if let Some(d) = spec.default_value(name).filter(|d| d.to_string() != "None") - { - quote! { Some(#d) } - } else { - quote! { None } - }; - if let syn::Type::Reference(tref) = ty { - let (tref, mut_) = preprocess_tref(tref, self_); - // To support Rustc 1.39.0, we don't use as_deref here... - let tmp_as_deref = if mut_.is_some() { - quote! { _tmp.as_mut().map(std::ops::DerefMut::deref_mut) } - } else { - quote! { _tmp.as_ref().map(std::ops::Deref::deref) } - }; + let default = match (spec.default_value(name), arg.optional.is_some()) { + (Some(default), true) if default.to_string() != "None" => quote! { Some(#default) }, + (Some(default), _) => quote! { #default }, + (None, true) => quote! { None }, + (None, false) => quote! { panic!("Failed to extract required method argument") }, + }; + + return if let syn::Type::Reference(tref) = arg.optional.as_ref().unwrap_or(&ty) { + let (tref, mut_) = preprocess_tref(tref, self_); + let (target_ty, borrow_tmp) = if arg.optional.is_some() { // Get Option<&T> from Option> - quote! { - let #mut_ _tmp = match #arg_value { - Some(_obj) => { - _obj.extract::::Target>>()? - }, - None => #default, - }; - let #arg_name = #tmp_as_deref; - } + ( + quote! { Option<<#tref as pyo3::derive_utils::ExtractExt>::Target> }, + // To support Rustc 1.39.0, we don't use as_deref here... + if mut_.is_some() { + quote! { _tmp.as_mut().map(std::ops::DerefMut::deref_mut) } + } else { + quote! { _tmp.as_ref().map(std::ops::Deref::deref) } + }, + ) } else { - quote! { - let #arg_name = match #arg_value { - Some(_obj) => _obj.extract()?, - None => #default, - }; - } + // Get &T from PyRef + ( + quote! { <#tref as pyo3::derive_utils::ExtractExt>::Target }, + quote! { &#mut_ *_tmp }, + ) + }; + + quote! { + let #mut_ _tmp: #target_ty = match #arg_value { + Some(_obj) => _obj.extract()?, + None => #default, + }; + let #arg_name = #borrow_tmp; } - } else if let Some(default) = spec.default_value(name) { + } else { quote! { let #arg_name = match #arg_value { Some(_obj) => _obj.extract()?, None => #default, }; } - } else if let syn::Type::Reference(tref) = arg.ty { - let (tref, mut_) = preprocess_tref(tref, self_); - // Get &T from PyRef - quote! { - let #mut_ _tmp: <#tref as pyo3::derive_utils::ExtractExt>::Target - = #arg_value.unwrap().extract()?; - let #arg_name = &#mut_ *_tmp; - } - } else { - quote! { - let #arg_name = #arg_value.unwrap().extract()?; - } }; /// Replace `Self`, remove lifetime and get mutability from the type @@ -739,7 +729,11 @@ pub(crate) fn impl_py_getter_def( /// Split an argument of pyo3::Python from the front of the arg list, if present fn split_off_python_arg<'a>(args: &'a [FnArg<'a>]) -> (Option<&FnArg>, &[FnArg]) { - if args.get(0).map(|py| utils::if_type_is_python(&py.ty)) == Some(true) { + if args + .get(0) + .map(|py| utils::is_python(&py.ty)) + .unwrap_or(false) + { (Some(&args[0]), &args[1..]) } else { (None, args) diff --git a/pyo3-derive-backend/src/utils.rs b/pyo3-derive-backend/src/utils.rs index 6bd78570..1bfc1a05 100644 --- a/pyo3-derive-backend/src/utils.rs +++ b/pyo3-derive-backend/src/utils.rs @@ -8,7 +8,7 @@ pub fn print_err(msg: String, t: TokenStream) { } /// Check if the given type `ty` is `pyo3::Python`. -pub fn if_type_is_python(ty: &syn::Type) -> bool { +pub fn is_python(ty: &syn::Type) -> bool { match ty { syn::Type::Path(ref typath) => typath .path @@ -20,6 +20,19 @@ pub fn if_type_is_python(ty: &syn::Type) -> bool { } } +/// If `ty` is Option, return `Some(T)`, else None. +pub fn option_type_argument(ty: &syn::Type) -> Option<&syn::Type> { + if let syn::Type::Path(syn::TypePath { path, .. }) = ty { + let seg = path.segments.last().filter(|s| s.ident == "Option")?; + if let syn::PathArguments::AngleBracketed(params) = &seg.arguments { + if let syn::GenericArgument::Type(ty) = params.args.first()? { + return Some(ty); + } + } + } + None +} + pub fn is_text_signature_attr(attr: &syn::Attribute) -> bool { attr.path.is_ident("text_signature") }