diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 034d079b..b9c78ef6 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -81,7 +81,7 @@ pub enum FnType { FnNewClass, FnClass, FnStatic, - FnModule, + FnModule(Span), ClassAttribute, } @@ -93,7 +93,7 @@ impl FnType { | FnType::Fn(_) | FnType::FnClass | FnType::FnNewClass - | FnType::FnModule => true, + | FnType::FnModule(_) => true, FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => false, } } @@ -117,8 +117,8 @@ impl FnType { ::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject)), } } - FnType::FnModule => { - quote! { + FnType::FnModule(span) => { + quote_spanned! { *span => #[allow(clippy::useless_conversion)] ::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)), } @@ -633,7 +633,7 @@ impl<'a> FnSpec<'a> { // Getters / Setters / ClassAttribute are not callables on the Python side FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None, FnType::Fn(_) => Some("self"), - FnType::FnModule => Some("module"), + FnType::FnModule(_) => Some("module"), FnType::FnClass | FnType::FnNewClass => Some("cls"), FnType::FnStatic | FnType::FnNew => None, }; diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 12f25411..5aedc410 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -189,30 +189,30 @@ pub fn impl_wrap_pyfunction( let python_name = name.map_or_else(|| func.sig.ident.unraw(), |name| name.value.0); - let mut arguments = func - .sig - .inputs - .iter_mut() - .map(FnArg::parse) - .collect::>>()?; - let tp = if pass_module.is_some() { - const PASS_MODULE_ERR: &str = - "expected &PyModule or Py as first argument with `pass_module`"; - ensure_spanned!( - !arguments.is_empty(), - func.span() => PASS_MODULE_ERR - ); - let arg = arguments.remove(0); - ensure_spanned!( - type_is_pymodule(arg.ty), - arg.ty.span() => PASS_MODULE_ERR - ); - method::FnType::FnModule + let span = match func.sig.inputs.first() { + Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(), + Some(syn::FnArg::Receiver(_)) | None => bail_spanned!( + func.span() => "expected `&PyModule` or `Py` as first argument with `pass_module`" + ), + }; + method::FnType::FnModule(span) } else { method::FnType::FnStatic }; + let arguments = func + .sig + .inputs + .iter_mut() + .skip(if tp.skip_first_rust_argument_in_python_signature() { + 1 + } else { + 0 + }) + .map(FnArg::parse) + .collect::>>()?; + let signature = if let Some(signature) = signature { FunctionSignature::from_arguments_and_attribute(arguments, signature)? } else { @@ -269,34 +269,3 @@ pub fn impl_wrap_pyfunction( }; Ok(wrapped_pyfunction) } - -fn type_is_pymodule(ty: &syn::Type) -> bool { - let is_pymodule = |typath: &syn::TypePath| { - typath - .path - .segments - .last() - .map_or(false, |seg| seg.ident == "PyModule") - }; - match ty { - syn::Type::Reference(tyref) => { - if let syn::Type::Path(typath) = tyref.elem.as_ref() { - return is_pymodule(typath); - } - } - syn::Type::Path(typath) => { - if let Some(syn::PathSegment { - arguments: syn::PathArguments::AngleBracketed(args), - .. - }) = typath.path.segments.last() - { - if args.args.len() != 1 { - return false; - } - return matches!(args.args.first().unwrap(), syn::GenericArgument::Type(syn::Type::Path(typath)) if is_pymodule(typath)); - } - } - _ => {} - } - false -} diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index a8fd3b41..e058e2f2 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -257,7 +257,7 @@ pub fn gen_py_method( doc: spec.get_doc(meth_attrs), }, )?), - (_, FnType::FnModule) => { + (_, FnType::FnModule(_)) => { unreachable!("methods cannot be FnModule") } }) diff --git a/tests/ui/invalid_need_module_arg_position.rs b/tests/ui/invalid_need_module_arg_position.rs index b3722ae4..2d45f35b 100644 --- a/tests/ui/invalid_need_module_arg_position.rs +++ b/tests/ui/invalid_need_module_arg_position.rs @@ -3,10 +3,10 @@ use pyo3::prelude::*; #[pymodule] fn module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { #[pyfn(m, pass_module)] - fn fail(string: &str, module: &PyModule) -> PyResult<&str> { + fn fail<'py>(string: &str, module: &'py PyModule) -> PyResult<&'py str> { module.name() } Ok(()) } -fn main(){} +fn main() {} diff --git a/tests/ui/invalid_need_module_arg_position.stderr b/tests/ui/invalid_need_module_arg_position.stderr index 65ab4b16..b9231c30 100644 --- a/tests/ui/invalid_need_module_arg_position.stderr +++ b/tests/ui/invalid_need_module_arg_position.stderr @@ -1,5 +1,14 @@ -error: expected &PyModule or Py as first argument with `pass_module` - --> tests/ui/invalid_need_module_arg_position.rs:6:21 +error[E0277]: the trait bound `&str: From<&pyo3::prelude::PyModule>` is not satisfied + --> tests/ui/invalid_need_module_arg_position.rs:6:26 | -6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> { - | ^ +6 | fn fail<'py>(string: &str, module: &'py PyModule) -> PyResult<&'py str> { + | ^ the trait `From<&pyo3::prelude::PyModule>` is not implemented for `&str` + | + = help: the following other types implement trait `From`: + > + >> + >> + > + > + > + = note: required for `&pyo3::prelude::PyModule` to implement `Into<&str>`