From 9afc38ae416bb750efd227e4f8b4302392c0d303 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 5 Jul 2024 05:16:06 -0400 Subject: [PATCH] fixes #4285 -- allow full-path to pymodule with nested declarative modules (#4288) --- newsfragments/4288.fixed.md | 1 + pyo3-macros-backend/src/module.rs | 88 ++++++++++++++++++++++++++++--- tests/test_declarative_module.rs | 11 ++++ 3 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 newsfragments/4288.fixed.md diff --git a/newsfragments/4288.fixed.md b/newsfragments/4288.fixed.md new file mode 100644 index 00000000..105bb042 --- /dev/null +++ b/newsfragments/4288.fixed.md @@ -0,0 +1 @@ +allow `#[pyo3::prelude::pymodule]` with nested declarative modules diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index faa7032d..2ca084a6 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -8,7 +8,7 @@ use crate::{ get_doc, pyclass::PyClassPyO3Option, pyfunction::{impl_wrap_pyfunction, PyFunctionOptions}, - utils::{Ctx, LitCStr}, + utils::{Ctx, LitCStr, PyO3CratePath}, }; use proc_macro2::{Span, TokenStream}; use quote::quote; @@ -183,7 +183,18 @@ pub fn pymodule_module_impl( ); ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified"); pymodule_init = Some(quote! { #ident(module)?; }); - } else if has_attribute(&item_fn.attrs, "pyfunction") { + } else if has_attribute(&item_fn.attrs, "pyfunction") + || has_attribute_with_namespace( + &item_fn.attrs, + Some(pyo3_path), + &["pyfunction"], + ) + || has_attribute_with_namespace( + &item_fn.attrs, + Some(pyo3_path), + &["prelude", "pyfunction"], + ) + { module_items.push(ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs)); } @@ -193,7 +204,18 @@ pub fn pymodule_module_impl( !has_attribute(&item_struct.attrs, "pymodule_export"), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); - if has_attribute(&item_struct.attrs, "pyclass") { + if has_attribute(&item_struct.attrs, "pyclass") + || has_attribute_with_namespace( + &item_struct.attrs, + Some(pyo3_path), + &["pyclass"], + ) + || has_attribute_with_namespace( + &item_struct.attrs, + Some(pyo3_path), + &["prelude", "pyclass"], + ) + { module_items.push(item_struct.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs)); if !has_pyo3_module_declared::( @@ -210,7 +232,14 @@ pub fn pymodule_module_impl( !has_attribute(&item_enum.attrs, "pymodule_export"), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); - if has_attribute(&item_enum.attrs, "pyclass") { + if has_attribute(&item_enum.attrs, "pyclass") + || has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"]) + || has_attribute_with_namespace( + &item_enum.attrs, + Some(pyo3_path), + &["prelude", "pyclass"], + ) + { module_items.push(item_enum.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs)); if !has_pyo3_module_declared::( @@ -227,7 +256,14 @@ pub fn pymodule_module_impl( !has_attribute(&item_mod.attrs, "pymodule_export"), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); - if has_attribute(&item_mod.attrs, "pymodule") { + if has_attribute(&item_mod.attrs, "pymodule") + || has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"]) + || has_attribute_with_namespace( + &item_mod.attrs, + Some(pyo3_path), + &["prelude", "pymodule"], + ) + { module_items.push(item_mod.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs)); if !has_pyo3_module_declared::( @@ -555,8 +591,48 @@ fn find_and_remove_attribute(attrs: &mut Vec, ident: &str) -> bo found } +enum IdentOrStr<'a> { + Str(&'a str), + Ident(syn::Ident), +} + +impl<'a> PartialEq for IdentOrStr<'a> { + fn eq(&self, other: &syn::Ident) -> bool { + match self { + IdentOrStr::Str(s) => other == s, + IdentOrStr::Ident(i) => other == i, + } + } +} fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool { - attrs.iter().any(|attr| attr.path().is_ident(ident)) + has_attribute_with_namespace(attrs, None, &[ident]) +} + +fn has_attribute_with_namespace( + attrs: &[syn::Attribute], + crate_path: Option<&PyO3CratePath>, + idents: &[&str], +) -> bool { + let mut segments = vec![]; + if let Some(c) = crate_path { + match c { + PyO3CratePath::Given(paths) => { + for p in &paths.segments { + segments.push(IdentOrStr::Ident(p.ident.clone())); + } + } + PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")), + } + }; + for i in idents { + segments.push(IdentOrStr::Str(i)); + } + + attrs.iter().any(|attr| { + segments + .iter() + .eq(attr.path().segments.iter().map(|v| &v.ident)) + }) } fn set_module_attribute(attrs: &mut Vec, module_name: &str) { diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 0bf426a5..f62d5182 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -124,6 +124,9 @@ mod declarative_module { struct Struct; } + #[pyo3::prelude::pymodule] + mod full_path_inner {} + #[pymodule_init] fn init(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("double2", m.getattr("double")?) @@ -247,3 +250,11 @@ fn test_module_names() { ); }) } + +#[test] +fn test_inner_module_full_path() { + Python::with_gil(|py| { + let m = declarative_module(py); + py_assert!(py, m, "m.full_path_inner"); + }) +}