diff --git a/CHANGELOG.md b/CHANGELOG.md index f237cfbf..209d369b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add buffer magic methods `__getbuffer__` and `__releasebuffer__` to `#[pymethods]`. [#2067](https://github.com/PyO3/pyo3/pull/2067) - Accept paths in `wrap_pyfunction` and `wrap_pymodule`. [#2081](https://github.com/PyO3/pyo3/pull/2081) - Add check for correct number of arguments on magic methods. [#2083](https://github.com/PyO3/pyo3/pull/2083) +- `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091) ### Changed diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 2dea4d79..ecfcc9f4 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -99,10 +99,11 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { if let syn::Stmt::Item(syn::Item::Fn(func)) = &mut stmt { if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? { let module_name = pyfn_args.modname; - let (ident, wrapped_function) = impl_wrap_pyfunction(func, pyfn_args.options)?; + let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?; + let name = &func.sig.ident; let statements: Vec = syn::parse_quote! { #wrapped_function - #module_name.add_function(#ident(#module_name)?)?; + #module_name.add_function(#name::wrap(#name::DEF, #module_name)?)?; }; stmts.extend(statements); } diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index df4e12ae..a45c6a9a 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -9,12 +9,11 @@ use crate::{ method::{self, CallingConvention, FnArg}, pymethod::check_generic, utils::{self, ensure_not_async_fn, get_pyo3_crate}, - wrap::function_wrapper_ident, }; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::punctuated::Punctuated; -use syn::{ext::IdentExt, spanned::Spanned, Ident, NestedMeta, Path, Result}; +use syn::{ext::IdentExt, spanned::Spanned, NestedMeta, Path, Result}; use syn::{ parse::{Parse, ParseBuffer, ParseStream}, token::Comma, @@ -364,7 +363,7 @@ pub fn build_py_function( mut options: PyFunctionOptions, ) -> syn::Result { options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?; - Ok(impl_wrap_pyfunction(ast, options)?.1) + impl_wrap_pyfunction(ast, options) } /// Generates python wrapper over a function that allows adding it to a python module as a python @@ -372,7 +371,7 @@ pub fn build_py_function( pub fn impl_wrap_pyfunction( func: &mut syn::ItemFn, options: PyFunctionOptions, -) -> syn::Result<(Ident, TokenStream)> { +) -> syn::Result { check_generic(&func.sig)?; ensure_not_async_fn(&func.sig)?; @@ -412,7 +411,6 @@ pub fn impl_wrap_pyfunction( .map(|attr| (&python_name, attr)), ); - let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); let krate = get_pyo3_crate(&options.krate); let spec = method::FnSpec { @@ -434,21 +432,40 @@ pub fn impl_wrap_pyfunction( unsafety: func.sig.unsafety, }; - let wrapper_ident = format_ident!("__pyo3_raw_{}", spec.name); + let vis = &func.vis; + let name = &func.sig.ident; + + let wrapper_ident = format_ident!("__pyfunction_{}", spec.name); let wrapper = spec.get_wrapper_function(&wrapper_ident, None)?; let methoddef = spec.get_methoddef(wrapper_ident); let wrapped_pyfunction = quote! { #wrapper - pub(crate) fn #function_wrapper_ident<'a>( - args: impl ::std::convert::Into<#krate::derive_utils::PyFunctionArguments<'a>> - ) -> #krate::PyResult<&'a #krate::types::PyCFunction> { + // Create a module with the same name as the `#[pyfunction]` - this way `use ` + // will actually bring both the module and the function into scope. + #[doc(hidden)] + #vis mod #name { use #krate as _pyo3; - _pyo3::types::PyCFunction::internal_new(#methoddef, args.into()) + pub(crate) struct PyO3Def; + + // Exported for `wrap_pyfunction!` + pub use _pyo3::impl_::pyfunction::wrap_pyfunction as wrap; + pub const DEF: _pyo3::PyMethodDef = ::DEF; } + + // Generate the definition inside an anonymous function in the same scope as the original function - + // this avoids complications around the fact that the generated module has a different scope + // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is + // inside a function body) + const _: () = { + use #krate as _pyo3; + impl _pyo3::impl_::pyfunction::PyFunctionDef for #name::PyO3Def { + const DEF: _pyo3::PyMethodDef = #methoddef; + } + }; }; - Ok((function_wrapper_ident, wrapped_pyfunction)) + Ok(wrapped_pyfunction) } fn type_is_pymodule(ty: &syn::Type) -> bool { diff --git a/pyo3-macros-backend/src/wrap.rs b/pyo3-macros-backend/src/wrap.rs index e76d9cb6..881d65be 100644 --- a/pyo3-macros-backend/src/wrap.rs +++ b/pyo3-macros-backend/src/wrap.rs @@ -22,25 +22,16 @@ impl Parse for WrapPyFunctionArgs { } } -pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> syn::Result { +pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> TokenStream { let WrapPyFunctionArgs { - mut function, + function, comma_and_arg, } = args; - let span = function.span(); - let last_segment = function - .segments - .last_mut() - .ok_or_else(|| err_spanned!(span => "expected non-empty path"))?; - - last_segment.ident = function_wrapper_ident(&last_segment.ident); - - let output = if let Some((_, arg)) = comma_and_arg { - quote! { #function(#arg) } + if let Some((_, arg)) = comma_and_arg { + quote! { #function::wrap(#function::DEF, #arg) } } else { - quote! { &|arg| #function(arg) } - }; - Ok(output) + quote! { &|arg| #function::wrap(#function::DEF, arg) } + } } pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result { @@ -58,10 +49,6 @@ pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result Ident { - format_ident!("__pyo3_get_function_{}", name) -} - pub(crate) fn module_def_ident(name: &Ident) -> Ident { format_ident!("__PYO3_PYMODULE_DEF_{}", name.to_string().to_uppercase()) } diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 43de9a98..c03d3c86 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -166,7 +166,8 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream { /// A proc macro used to expose Rust functions to Python. /// -/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]` options: +/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]` +/// options: /// /// | Annotation | Description | /// | :- | :- | @@ -176,6 +177,11 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream { /// /// For more on exposing functions see the [function section of the guide][1]. /// +/// Due to technical limitations on how `#[pyfunction]` is implemented, a function marked +/// `#[pyfunction]` cannot have a module with the same name in the same scope. (The +/// `#[pyfunction]` implementation generates a hidden module with the same name containing +/// metadata about the function, which is used by `wrap_pyfunction!`). +/// /// [1]: https://pyo3.rs/latest/function.html #[proc_macro_attribute] pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream { @@ -208,7 +214,7 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream { #[proc_macro] pub fn wrap_pyfunction(input: TokenStream) -> TokenStream { let args = parse_macro_input!(input as WrapPyFunctionArgs); - wrap_pyfunction_impl(args).unwrap_or_compile_error().into() + wrap_pyfunction_impl(args).into() } /// Returns a function that takes a `Python` instance and returns a Python module. diff --git a/src/derive_utils.rs b/src/derive_utils.rs index e949e255..2b2666bb 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -4,9 +4,7 @@ //! Functionality for the code generated by the derive backend -use crate::err::PyErr; -use crate::types::PyModule; -use crate::{PyCell, PyClass, Python}; +use crate::{types::PyModule, PyCell, PyClass, PyErr, Python}; /// Utility trait to enable &PyClass as a pymethod/function argument #[doc(hidden)] diff --git a/src/impl_.rs b/src/impl_.rs index 3a35a878..5f27a1fc 100644 --- a/src/impl_.rs +++ b/src/impl_.rs @@ -13,5 +13,7 @@ pub(crate) mod not_send; #[doc(hidden)] pub mod pyclass; #[doc(hidden)] +pub mod pyfunction; +#[doc(hidden)] pub mod pymethods; pub mod pymodule; diff --git a/src/impl_/pyfunction.rs b/src/impl_/pyfunction.rs new file mode 100644 index 00000000..f852e63b --- /dev/null +++ b/src/impl_/pyfunction.rs @@ -0,0 +1,14 @@ +use crate::{ + class::methods::PyMethodDef, derive_utils::PyFunctionArguments, types::PyCFunction, PyResult, +}; + +pub trait PyFunctionDef { + const DEF: crate::PyMethodDef; +} + +pub fn wrap_pyfunction<'a>( + method_def: PyMethodDef, + args: impl Into>, +) -> PyResult<&'a PyCFunction> { + PyCFunction::internal_new(method_def, args.into()) +} diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 3f0113c4..85bde39b 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -269,3 +269,29 @@ fn test_closure_counter() { py_assert!(py, counter_py, "counter_py() == 2"); py_assert!(py, counter_py, "counter_py() == 3"); } + +#[test] +fn use_pyfunction() { + mod function_in_module { + use pyo3::prelude::*; + + #[pyfunction] + pub fn foo(x: i32) -> i32 { + x + } + } + + Python::with_gil(|py| { + use function_in_module::foo; + + // check imported name can be wrapped + let f = wrap_pyfunction!(foo, py).unwrap(); + assert_eq!(f.call1((5,)).unwrap().extract::().unwrap(), 5); + assert_eq!(f.call1((42,)).unwrap().extract::().unwrap(), 42); + + // check path import can be wrapped + let f2 = wrap_pyfunction!(function_in_module::foo, py).unwrap(); + assert_eq!(f2.call1((5,)).unwrap().extract::().unwrap(), 5); + assert_eq!(f2.call1((42,)).unwrap().extract::().unwrap(), 42); + }) +}