pyfunction: allow wrap_pyfunction to work on imports (even cross-crate)

This commit is contained in:
David Hewitt 2022-01-06 23:21:06 +00:00
parent 2cee7feaaf
commit de8174684f
9 changed files with 89 additions and 37 deletions

View File

@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add buffer magic methods `__getbuffer__` and `__releasebuffer__` to `#[pymethods]`. [#2067](https://github.com/PyO3/pyo3/pull/2067)
- Accept paths in `wrap_pyfunction` and `wrap_pymodule`. [#2081](https://github.com/PyO3/pyo3/pull/2081)
- Add check for correct number of arguments on magic methods. [#2083](https://github.com/PyO3/pyo3/pull/2083)
- `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091)
### Changed

View File

@ -99,10 +99,11 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
if let syn::Stmt::Item(syn::Item::Fn(func)) = &mut stmt {
if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
let module_name = pyfn_args.modname;
let (ident, wrapped_function) = impl_wrap_pyfunction(func, pyfn_args.options)?;
let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
let name = &func.sig.ident;
let statements: Vec<syn::Stmt> = syn::parse_quote! {
#wrapped_function
#module_name.add_function(#ident(#module_name)?)?;
#module_name.add_function(#name::wrap(#name::DEF, #module_name)?)?;
};
stmts.extend(statements);
}

View File

@ -9,12 +9,11 @@ use crate::{
method::{self, CallingConvention, FnArg},
pymethod::check_generic,
utils::{self, ensure_not_async_fn, get_pyo3_crate},
wrap::function_wrapper_ident,
};
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{ext::IdentExt, spanned::Spanned, Ident, NestedMeta, Path, Result};
use syn::{ext::IdentExt, spanned::Spanned, NestedMeta, Path, Result};
use syn::{
parse::{Parse, ParseBuffer, ParseStream},
token::Comma,
@ -364,7 +363,7 @@ pub fn build_py_function(
mut options: PyFunctionOptions,
) -> syn::Result<TokenStream> {
options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
Ok(impl_wrap_pyfunction(ast, options)?.1)
impl_wrap_pyfunction(ast, options)
}
/// Generates python wrapper over a function that allows adding it to a python module as a python
@ -372,7 +371,7 @@ pub fn build_py_function(
pub fn impl_wrap_pyfunction(
func: &mut syn::ItemFn,
options: PyFunctionOptions,
) -> syn::Result<(Ident, TokenStream)> {
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
ensure_not_async_fn(&func.sig)?;
@ -412,7 +411,6 @@ pub fn impl_wrap_pyfunction(
.map(|attr| (&python_name, attr)),
);
let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
let krate = get_pyo3_crate(&options.krate);
let spec = method::FnSpec {
@ -434,21 +432,40 @@ pub fn impl_wrap_pyfunction(
unsafety: func.sig.unsafety,
};
let wrapper_ident = format_ident!("__pyo3_raw_{}", spec.name);
let vis = &func.vis;
let name = &func.sig.ident;
let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
let wrapper = spec.get_wrapper_function(&wrapper_ident, None)?;
let methoddef = spec.get_methoddef(wrapper_ident);
let wrapped_pyfunction = quote! {
#wrapper
pub(crate) fn #function_wrapper_ident<'a>(
args: impl ::std::convert::Into<#krate::derive_utils::PyFunctionArguments<'a>>
) -> #krate::PyResult<&'a #krate::types::PyCFunction> {
// Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
// will actually bring both the module and the function into scope.
#[doc(hidden)]
#vis mod #name {
use #krate as _pyo3;
_pyo3::types::PyCFunction::internal_new(#methoddef, args.into())
pub(crate) struct PyO3Def;
// Exported for `wrap_pyfunction!`
pub use _pyo3::impl_::pyfunction::wrap_pyfunction as wrap;
pub const DEF: _pyo3::PyMethodDef = <PyO3Def as _pyo3::impl_::pyfunction::PyFunctionDef>::DEF;
}
// Generate the definition inside an anonymous function in the same scope as the original function -
// this avoids complications around the fact that the generated module has a different scope
// (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
// inside a function body)
const _: () = {
use #krate as _pyo3;
impl _pyo3::impl_::pyfunction::PyFunctionDef for #name::PyO3Def {
const DEF: _pyo3::PyMethodDef = #methoddef;
}
};
};
Ok((function_wrapper_ident, wrapped_pyfunction))
Ok(wrapped_pyfunction)
}
fn type_is_pymodule(ty: &syn::Type) -> bool {

View File

@ -22,25 +22,16 @@ impl Parse for WrapPyFunctionArgs {
}
}
pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> syn::Result<TokenStream> {
pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> TokenStream {
let WrapPyFunctionArgs {
mut function,
function,
comma_and_arg,
} = args;
let span = function.span();
let last_segment = function
.segments
.last_mut()
.ok_or_else(|| err_spanned!(span => "expected non-empty path"))?;
last_segment.ident = function_wrapper_ident(&last_segment.ident);
let output = if let Some((_, arg)) = comma_and_arg {
quote! { #function(#arg) }
if let Some((_, arg)) = comma_and_arg {
quote! { #function::wrap(#function::DEF, #arg) }
} else {
quote! { &|arg| #function(arg) }
};
Ok(output)
quote! { &|arg| #function::wrap(#function::DEF, arg) }
}
}
pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream> {
@ -58,10 +49,6 @@ pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream
})
}
pub(crate) fn function_wrapper_ident(name: &Ident) -> Ident {
format_ident!("__pyo3_get_function_{}", name)
}
pub(crate) fn module_def_ident(name: &Ident) -> Ident {
format_ident!("__PYO3_PYMODULE_DEF_{}", name.to_string().to_uppercase())
}

View File

@ -166,7 +166,8 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream {
/// A proc macro used to expose Rust functions to Python.
///
/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]` options:
/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]`
/// options:
///
/// | Annotation | Description |
/// | :- | :- |
@ -176,6 +177,11 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream {
///
/// For more on exposing functions see the [function section of the guide][1].
///
/// Due to technical limitations on how `#[pyfunction]` is implemented, a function marked
/// `#[pyfunction]` cannot have a module with the same name in the same scope. (The
/// `#[pyfunction]` implementation generates a hidden module with the same name containing
/// metadata about the function, which is used by `wrap_pyfunction!`).
///
/// [1]: https://pyo3.rs/latest/function.html
#[proc_macro_attribute]
pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
@ -208,7 +214,7 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
#[proc_macro]
pub fn wrap_pyfunction(input: TokenStream) -> TokenStream {
let args = parse_macro_input!(input as WrapPyFunctionArgs);
wrap_pyfunction_impl(args).unwrap_or_compile_error().into()
wrap_pyfunction_impl(args).into()
}
/// Returns a function that takes a `Python` instance and returns a Python module.

View File

@ -4,9 +4,7 @@
//! Functionality for the code generated by the derive backend
use crate::err::PyErr;
use crate::types::PyModule;
use crate::{PyCell, PyClass, Python};
use crate::{types::PyModule, PyCell, PyClass, PyErr, Python};
/// Utility trait to enable &PyClass as a pymethod/function argument
#[doc(hidden)]

View File

@ -13,5 +13,7 @@ pub(crate) mod not_send;
#[doc(hidden)]
pub mod pyclass;
#[doc(hidden)]
pub mod pyfunction;
#[doc(hidden)]
pub mod pymethods;
pub mod pymodule;

14
src/impl_/pyfunction.rs Normal file
View File

@ -0,0 +1,14 @@
use crate::{
class::methods::PyMethodDef, derive_utils::PyFunctionArguments, types::PyCFunction, PyResult,
};
pub trait PyFunctionDef {
const DEF: crate::PyMethodDef;
}
pub fn wrap_pyfunction<'a>(
method_def: PyMethodDef,
args: impl Into<PyFunctionArguments<'a>>,
) -> PyResult<&'a PyCFunction> {
PyCFunction::internal_new(method_def, args.into())
}

View File

@ -269,3 +269,29 @@ fn test_closure_counter() {
py_assert!(py, counter_py, "counter_py() == 2");
py_assert!(py, counter_py, "counter_py() == 3");
}
#[test]
fn use_pyfunction() {
mod function_in_module {
use pyo3::prelude::*;
#[pyfunction]
pub fn foo(x: i32) -> i32 {
x
}
}
Python::with_gil(|py| {
use function_in_module::foo;
// check imported name can be wrapped
let f = wrap_pyfunction!(foo, py).unwrap();
assert_eq!(f.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(f.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);
// check path import can be wrapped
let f2 = wrap_pyfunction!(function_in_module::foo, py).unwrap();
assert_eq!(f2.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(f2.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);
})
}