diff --git a/newsfragments/3905.changed.md b/newsfragments/3905.changed.md new file mode 100644 index 00000000..917584eb --- /dev/null +++ b/newsfragments/3905.changed.md @@ -0,0 +1 @@ +The `#[pymodule]` macro now supports module functions that take a single argument as a `&Bound<'_, PyModule>`. diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 40cf34f2..0bdb1cbc 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -201,6 +201,14 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result let doc = get_doc(&function.attrs, None); let initialization = module_initialization(options, ident); + + // Module function called with optional Python<'_> marker as first arg, followed by the module. + let mut module_args = Vec::new(); + if function.sig.inputs.len() == 2 { + module_args.push(quote!(module.py())); + } + module_args.push(quote!(::std::convert::Into::into(BoundRef(module)))); + Ok(quote! { #function #vis mod #ident { @@ -218,7 +226,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result use #krate::impl_::pymethods::BoundRef; fn __pyo3_pymodule(module: &#krate::Bound<'_, #krate::types::PyModule>) -> #krate::PyResult<()> { - #ident(module.py(), ::std::convert::Into::into(BoundRef(module))) + #ident(#(#module_args),*) } impl #ident::MakeDef { diff --git a/src/tests/hygiene/pyfunction.rs b/src/tests/hygiene/pyfunction.rs index 19fe2739..9cfad0db 100644 --- a/src/tests/hygiene/pyfunction.rs +++ b/src/tests/hygiene/pyfunction.rs @@ -14,3 +14,11 @@ fn invoke_wrap_pyfunction() { crate::py_run!(py, func, r#"func(5)"#); }); } + +#[test] +fn invoke_wrap_pyfunction_bound() { + crate::Python::with_gil(|py| { + let func = crate::wrap_pyfunction_bound!(do_something, py).unwrap(); + crate::py_run!(py, func, r#"func(5)"#); + }); +} diff --git a/src/tests/hygiene/pymodule.rs b/src/tests/hygiene/pymodule.rs index 0b37c440..bb49d382 100644 --- a/src/tests/hygiene/pymodule.rs +++ b/src/tests/hygiene/pymodule.rs @@ -21,3 +21,18 @@ fn my_module(_py: crate::Python<'_>, m: &crate::types::PyModule) -> crate::PyRes ::std::result::Result::Ok(()) } + +#[crate::pymodule] +#[pyo3(crate = "crate")] +fn my_module_bound(m: &crate::Bound<'_, crate::types::PyModule>) -> crate::PyResult<()> { + as crate::types::PyModuleMethods>::add_function( + m, + crate::wrap_pyfunction_bound!(do_something, m)?, + )?; + as crate::types::PyModuleMethods>::add_wrapped( + m, + crate::wrap_pymodule!(foo), + )?; + + ::std::result::Result::Ok(()) +} diff --git a/tests/test_no_imports.rs b/tests/test_no_imports.rs index 88932ed2..69f4b6e4 100644 --- a/tests/test_no_imports.rs +++ b/tests/test_no_imports.rs @@ -22,6 +22,18 @@ fn basic_module(_py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyRes Ok(()) } +#[pyo3::pymodule] +fn basic_module_bound(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyResult<()> { + #[pyfn(m)] + fn answer() -> usize { + 42 + } + + m.add_function(pyo3::wrap_pyfunction_bound!(basic_function, m)?)?; + + Ok(()) +} + #[pyo3::pyclass] struct BasicClass { #[pyo3(get)]