diff --git a/guide/src/function.md b/guide/src/function.md index 226cfe5f..1f00d462 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -46,7 +46,7 @@ fn double(x: usize) -> usize { #[modinit(module_with_functions)] fn init_mod(py: Python, m: &PyModule) -> PyResult<()> { - add_function_to_module!(m, double, py); + m.add_function(wrap_function!(double)); Ok(()) } diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index c6eacd58..8d6428a7 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -103,13 +103,13 @@ pub fn process_functions_in_module(ast: &mut syn::Item) { if let Some((module_name, python_name, pyfn_attrs)) = extract_pyfn_attrs(&mut item.attrs) { - let function_to_python = add_fn_to_module(item, python_name, pyfn_attrs); - let function_wrapper_ident = get_add_to_module_ident(&item.ident); + let function_to_python = add_fn_to_module(item, &python_name, pyfn_attrs); + let function_wrapper_ident = function_wrapper_ident(&item.ident); let tokens = quote! { fn block_wrapper() { #function_to_python - #function_wrapper_ident(#module_name, py); + #module_name.add_function(&#function_wrapper_ident); } }.to_string(); @@ -205,15 +205,16 @@ fn extract_pyfn_attrs( } /// Coordinates the naming of a the add-function-to-python-module function -fn get_add_to_module_ident(name: &syn::Ident) -> syn::Ident { - syn::Ident::new("__pyo3_add_to_module_".to_string() + &name.to_string()) +fn function_wrapper_ident(name: &syn::Ident) -> syn::Ident { + // Make sure this ident matches the one of wrap_function + syn::Ident::new("__pyo3_get_function_".to_string() + &name.to_string()) } /// Generates python wrapper over a function that allows adding it to a python module as a python /// function pub fn add_fn_to_module( item: &mut syn::Item, - python_name: syn::Ident, + python_name: &syn::Ident, pyfn_attrs: Vec, ) -> Tokens { let name = item.ident.clone(); @@ -241,13 +242,13 @@ pub fn add_fn_to_module( output: ty, }; - let add_to_module_ident = get_add_to_module_ident(&name); + let function_wrapper_ident = function_wrapper_ident(&name); let wrapper = function_c_wrapper(&name, &spec); let doc = utils::get_doc(&item.attrs, true); let tokens = quote! ( - fn #add_to_module_ident(module: &::pyo3::PyModule, py: ::pyo3::Python) { + fn #function_wrapper_ident(py: ::pyo3::Python) -> ::pyo3::PyObject { use std; use pyo3 as _pyo3; use pyo3::ObjectProtocol; @@ -271,7 +272,7 @@ pub fn add_fn_to_module( ) }; - module.add(stringify!(#python_name), function); + function } ); diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index c715ebfa..b3023dee 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -130,7 +130,7 @@ pub fn function(_: TokenStream, input: TokenStream) -> TokenStream { // Build the output let python_name = ast.ident.clone(); - let expanded = module::add_fn_to_module(&mut ast, python_name, Vec::new()); + let expanded = module::add_fn_to_module(&mut ast, &python_name, Vec::new()); // Return the generated impl as a TokenStream let mut tokens = Tokens::new(); diff --git a/src/lib.rs b/src/lib.rs index dc67a782..6f37764e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,13 +193,14 @@ macro_rules! cstr { ); } -/// Registers a function annotated with `#[function]` in module. -/// The first parameter is the module, the second the name of the function and the third is an -/// instance of `Python`. +/// Returns a function that takes a Python instance and returns a python function. +/// +/// Use this together with `#[function]` and [PyModule::add_function]. #[macro_export] -macro_rules! add_function_to_module( - ($modname:expr, $function_name:ident, $python:expr) => { - concat_idents!(__pyo3_add_to_module_, $function_name)($modname, $python); +macro_rules! wrap_function ( + ($function_name:ident) => { + // Make sure this ident matches the one in function_wrapper_ident + &concat_idents!(__pyo3_get_function_, $function_name) }; ); diff --git a/src/objects/module.rs b/src/objects/module.rs index 517f25cb..ae67c4cd 100644 --- a/src/objects/module.rs +++ b/src/objects/module.rs @@ -16,7 +16,6 @@ use objectprotocol::ObjectProtocol; use instance::PyObjectWithToken; use err::{PyResult, PyErr}; - /// Represents a Python `module` object. pub struct PyModule(PyObject); @@ -140,4 +139,23 @@ impl PyModule { self.setattr(T::NAME, ty) } + + /// Adds a function to a module, using the functions __name__ as name. + /// + /// Use this together with the`#[function]` and [wrap_function!] macro. + /// + /// ```rust,ignore + /// m.add_function(wrap_function!(double)); + /// ``` + /// + /// You can also add a function with a custom name using [add](PyModule::add): + /// + /// ```rust,ignore + /// m.add("also_double", wrap_function!(double)(py)); + /// ``` + pub fn add_function(&self, wrapper: &Fn(Python) -> PyObject) -> PyResult<()> { + let function = wrapper(self.py()); + let name = function.getattr(self.py(), "__name__").expect("A function must have a __name__"); + self.add(name.extract(self.py()).unwrap(), function) + } } diff --git a/tests/test_module.rs b/tests/test_module.rs index a69ab2e0..762021fb 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -37,7 +37,8 @@ fn init_mod(py: Python, m: &PyModule) -> PyResult<()> { m.add("foo", "bar"); - add_function_to_module!(m, double, py); + m.add_function(wrap_function!(double)); + m.add("also_double", wrap_function!(double)(py)); Ok(()) } @@ -56,4 +57,5 @@ fn test_module_with_functions() { py.run("assert module_with_functions.foo == 'bar'", None, Some(d)).unwrap(); py.run("assert module_with_functions.EmptyClass != None", None, Some(d)).unwrap(); py.run("assert module_with_functions.double(3) == 6", None, Some(d)).unwrap(); + py.run("assert module_with_functions.also_double(3) == 6", None, Some(d)).unwrap(); }