diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index a6947b26..4e4607c2 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -4,6 +4,7 @@ use crate::{ attributes::FromPyWithAttribute, method::{FnArg, FnSpec}, pyfunction::Argument, + utils::unwrap_ty_group, }; use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; @@ -258,7 +259,7 @@ fn impl_arg_param( } }; - return if let syn::Type::Reference(tref) = arg.optional.as_ref().unwrap_or(&ty) { + return if let syn::Type::Reference(tref) = unwrap_ty_group(arg.optional.unwrap_or(&ty)) { let (tref, mut_) = preprocess_tref(tref, self_); let (target_ty, borrow_tmp) = if arg.optional.is_some() { // Get Option<&T> from Option> diff --git a/pyo3-macros-backend/src/utils.rs b/pyo3-macros-backend/src/utils.rs index cc35c108..101a9b05 100644 --- a/pyo3-macros-backend/src/utils.rs +++ b/pyo3-macros-backend/src/utils.rs @@ -29,12 +29,8 @@ macro_rules! ensure_spanned { } /// Check if the given type `ty` is `pyo3::Python`. -pub fn is_python(mut ty: &syn::Type) -> bool { - while let syn::Type::Group(group) = ty { - // Macros can create invisible delimiters around types. - ty = &*group.elem; - } - match ty { +pub fn is_python(ty: &syn::Type) -> bool { + match unwrap_ty_group(ty) { syn::Type::Path(typath) => typath .path .segments @@ -124,3 +120,10 @@ pub fn unwrap_group(mut expr: &syn::Expr) -> &syn::Expr { } expr } + +pub fn unwrap_ty_group(mut ty: &syn::Type) -> &syn::Type { + while let syn::Type::Group(g) = ty { + ty = &*g.elem; + } + ty +}