diff --git a/newsfragments/3587.added.md b/newsfragments/3587.added.md new file mode 100644 index 00000000..f8ea280d --- /dev/null +++ b/newsfragments/3587.added.md @@ -0,0 +1,2 @@ +- Classmethods can now receive `Py` as their first argument +- Function annotated with `pass_module` can now receive `Py` as their first argument \ No newline at end of file diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 428efc95..2cff0a79 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -113,12 +113,14 @@ impl FnType { } FnType::FnClass | FnType::FnNewClass => { quote! { - _pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject), + #[allow(clippy::useless_conversion)] + ::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject)), } } FnType::FnModule => { quote! { - py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf), + #[allow(clippy::useless_conversion)] + ::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)), } } } diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index b1a2bcc7..f1985047 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -199,7 +199,8 @@ pub fn impl_wrap_pyfunction( .collect::>>()?; let tp = if pass_module.is_some() { - const PASS_MODULE_ERR: &str = "expected &PyModule as first argument with `pass_module`"; + const PASS_MODULE_ERR: &str = + "expected &PyModule or Py as first argument with `pass_module`"; ensure_spanned!( !arguments.is_empty(), func.span() => PASS_MODULE_ERR @@ -271,18 +272,32 @@ pub fn impl_wrap_pyfunction( } fn type_is_pymodule(ty: &syn::Type) -> bool { - if let syn::Type::Reference(tyref) = ty { - if let syn::Type::Path(typath) = tyref.elem.as_ref() { - if typath - .path - .segments - .last() - .map(|seg| seg.ident == "PyModule") - .unwrap_or(false) - { - return true; + let is_pymodule = |typath: &syn::TypePath| { + typath + .path + .segments + .last() + .map_or(false, |seg| seg.ident == "PyModule") + }; + match ty { + syn::Type::Reference(tyref) => { + if let syn::Type::Path(typath) = tyref.elem.as_ref() { + return is_pymodule(typath); } } + syn::Type::Path(typath) => { + if let Some(syn::PathSegment { + arguments: syn::PathArguments::AngleBracketed(args), + .. + }) = typath.path.segments.last() + { + if args.args.len() != 1 { + return false; + } + return matches!(args.args.first().unwrap(), syn::GenericArgument::Type(syn::Type::Path(typath)) if is_pymodule(typath)); + } + } + _ => {} } false } diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 8de5f556..7919ac0c 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -76,6 +76,14 @@ impl ClassMethod { fn method(cls: &PyType) -> PyResult { Ok(format!("{}.method()!", cls.name()?)) } + + #[classmethod] + fn method_owned(cls: Py) -> PyResult { + Ok(format!( + "{}.method_owned()!", + Python::with_gil(|gil| cls.as_ref(gil).name().map(ToString::to_string))? + )) + } } #[test] @@ -84,6 +92,11 @@ fn class_method() { let d = [("C", py.get_type::())].into_py_dict(py); py_assert!(py, *d, "C.method() == 'ClassMethod.method()!'"); py_assert!(py, *d, "C().method() == 'ClassMethod.method()!'"); + py_assert!( + py, + *d, + "C().method_owned() == 'ClassMethod.method_owned()!'" + ); py_assert!(py, *d, "C.method.__doc__ == 'Test class method.'"); py_assert!(py, *d, "C().method.__doc__ == 'Test class method.'"); }); diff --git a/tests/test_module.rs b/tests/test_module.rs index aef6995c..2de23b38 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -348,6 +348,12 @@ fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { module.name() } +#[pyfunction] +#[pyo3(pass_module)] +fn pyfunction_with_module_owned(module: Py) -> PyResult { + Python::with_gil(|gil| module.as_ref(gil).name().map(Into::into)) +} + #[pyfunction] #[pyo3(pass_module)] fn pyfunction_with_module_and_py<'a>( @@ -393,6 +399,7 @@ fn pyfunction_with_pass_module_in_attribute(module: &PyModule) -> PyResult<&str> #[pymodule] fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?; + m.add_function(wrap_pyfunction!(pyfunction_with_module_owned, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_default_arg, m)?)?; @@ -401,6 +408,7 @@ fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult< pyfunction_with_pass_module_in_attribute, m )?)?; + m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?; Ok(()) } @@ -413,6 +421,11 @@ fn test_module_functions_with_module() { m, "m.pyfunction_with_module() == 'module_with_functions_with_module'" ); + py_assert!( + py, + m, + "m.pyfunction_with_module_owned() == 'module_with_functions_with_module'" + ); py_assert!( py, m, diff --git a/tests/ui/invalid_need_module_arg_position.stderr b/tests/ui/invalid_need_module_arg_position.stderr index 8fce151f..65ab4b16 100644 --- a/tests/ui/invalid_need_module_arg_position.stderr +++ b/tests/ui/invalid_need_module_arg_position.stderr @@ -1,4 +1,4 @@ -error: expected &PyModule as first argument with `pass_module` +error: expected &PyModule or Py as first argument with `pass_module` --> tests/ui/invalid_need_module_arg_position.rs:6:21 | 6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {