From f5fee94afcaf11edceceed45192aabd71aeb9415 Mon Sep 17 00:00:00 2001 From: Bruno Kolenbrander <59372212+mejrs@users.noreply.github.com> Date: Tue, 23 Apr 2024 20:01:41 +0200 Subject: [PATCH] Scope macro imports more tightly (#4088) --- pyo3-macros-backend/src/module.rs | 11 ++++++----- pytests/src/enums.rs | 4 +++- tests/test_no_imports.rs | 5 ++++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 3153279a..626cde12 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -382,10 +382,7 @@ fn module_initialization(options: PyModuleOptions, ident: &syn::Ident) -> TokenS fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> { let ctx = &Ctx::new(&options.krate); let Ctx { pyo3_path } = ctx; - let mut stmts: Vec = vec![syn::parse_quote!( - #[allow(unknown_lints, unused_imports, redundant_imports)] - use #pyo3_path::{PyNativeType, types::PyModuleMethods}; - )]; + let mut stmts: Vec = Vec::new(); for mut stmt in func.block.stmts.drain(..) { if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt { @@ -395,7 +392,11 @@ fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn let name = &func.sig.ident; let statements: Vec = syn::parse_quote! { #wrapped_function - #module_name.as_borrowed().add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?; + { + #[allow(unknown_lints, unused_imports, redundant_imports)] + use #pyo3_path::{PyNativeType, types::PyModuleMethods}; + #module_name.as_borrowed().add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?; + } }; stmts.extend(statements); } diff --git a/pytests/src/enums.rs b/pytests/src/enums.rs index 4bb269fb..0a1bc49b 100644 --- a/pytests/src/enums.rs +++ b/pytests/src/enums.rs @@ -1,5 +1,7 @@ use pyo3::{ - pyclass, pyfunction, pymodule, types::PyModule, wrap_pyfunction_bound, Bound, PyResult, + pyclass, pyfunction, pymodule, + types::{PyModule, PyModuleMethods}, + wrap_pyfunction_bound, Bound, PyResult, }; #[pymodule] diff --git a/tests/test_no_imports.rs b/tests/test_no_imports.rs index 35c978b0..022d61e0 100644 --- a/tests/test_no_imports.rs +++ b/tests/test_no_imports.rs @@ -30,7 +30,10 @@ fn basic_module_bound(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyRes 42 } - m.add_function(pyo3::wrap_pyfunction_bound!(basic_function, m)?)?; + pyo3::types::PyModuleMethods::add_function( + m, + pyo3::wrap_pyfunction_bound!(basic_function, m)?, + )?; Ok(()) }