diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 7528ef81..7173fa86 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -115,19 +115,10 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { for item in &mut *items { match item { Item::Use(item_use) => { - let mut is_pyo3 = false; - item_use.attrs.retain(|attr| { - let found = attr.path().is_ident("pymodule_export"); - is_pyo3 |= found; - !found - }); - if is_pyo3 { - let cfg_attrs = item_use - .attrs - .iter() - .filter(|attr| attr.path().is_ident("cfg")) - .cloned() - .collect::>(); + let is_pymodule_export = + find_and_remove_attribute(&mut item_use.attrs, "pymodule_export"); + if is_pymodule_export { + let cfg_attrs = get_cfg_attributes(&item_use.attrs); extract_use_items( &item_use.tree, &cfg_attrs, @@ -137,23 +128,116 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { } } Item::Fn(item_fn) => { - let mut is_module_init = false; - item_fn.attrs.retain(|attr| { - let found = attr.path().is_ident("pymodule_init"); - is_module_init |= found; - !found - }); - if is_module_init { - ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one pymodule_init may be specified"); - let ident = &item_fn.sig.ident; + ensure_spanned!( + !has_attribute(&item_fn.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + let is_pymodule_init = + find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init"); + let ident = &item_fn.sig.ident; + if is_pymodule_init { + ensure_spanned!( + !has_attribute(&item_fn.attrs, "pyfunction"), + item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`" + ); + ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified"); pymodule_init = Some(quote! { #ident(module)?; }); - } else { - bail_spanned!(item.span() => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]") + } else if has_attribute(&item_fn.attrs, "pyfunction") { + module_items.push(ident.clone()); + module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs)); } } - item => { - bail_spanned!(item.span() => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]") + Item::Struct(item_struct) => { + ensure_spanned!( + !has_attribute(&item_struct.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + if has_attribute(&item_struct.attrs, "pyclass") { + module_items.push(item_struct.ident.clone()); + module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs)); + } } + Item::Enum(item_enum) => { + ensure_spanned!( + !has_attribute(&item_enum.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + if has_attribute(&item_enum.attrs, "pyclass") { + module_items.push(item_enum.ident.clone()); + module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs)); + } + } + Item::Mod(item_mod) => { + ensure_spanned!( + !has_attribute(&item_mod.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + if has_attribute(&item_mod.attrs, "pymodule") { + module_items.push(item_mod.ident.clone()); + module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs)); + } + } + Item::ForeignMod(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::Trait(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::Const(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::Static(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::Macro(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::ExternCrate(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::Impl(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::TraitAlias(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::Type(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + Item::Union(item) => { + ensure_spanned!( + !has_attribute(&item.attrs, "pymodule_export"), + item.span() => "`#[pymodule_export]` may only be used on `use` statements" + ); + } + _ => (), } } @@ -355,6 +439,31 @@ fn get_pyfn_attr(attrs: &mut Vec) -> syn::Result Vec { + attrs + .iter() + .filter(|attr| attr.path().is_ident("cfg")) + .cloned() + .collect() +} + +fn find_and_remove_attribute(attrs: &mut Vec, ident: &str) -> bool { + let mut found = false; + attrs.retain(|attr| { + if attr.path().is_ident(ident) { + found = true; + false + } else { + true + } + }); + found +} + +fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool { + attrs.iter().any(|attr| attr.path().is_ident(ident)) +} + enum PyModulePyO3Option { Crate(CrateAttribute), Name(NameAttribute), diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 86913d9b..3cb93765 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -15,8 +15,8 @@ struct ValueClass { #[pymethods] impl ValueClass { #[new] - fn new(value: usize) -> ValueClass { - ValueClass { value } + fn new(value: usize) -> Self { + Self { value } } } @@ -48,6 +48,33 @@ mod declarative_module { #[pymodule_export] use super::{declarative_module2, double, MyError, ValueClass as Value}; + #[pymodule] + mod inner { + use super::*; + + #[pyfunction] + fn triple(x: usize) -> usize { + x * 3 + } + + #[pyclass] + struct Struct; + + #[pymethods] + impl Struct { + #[new] + fn new() -> Self { + Self + } + } + + #[pyclass] + enum Enum { + A, + B, + } + } + #[pymodule_init] fn init(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("double2", m.getattr("double")?) @@ -65,7 +92,6 @@ mod declarative_submodule { use super::{double, double_value}; } -/// A module written using declarative syntax. #[pymodule] #[pyo3(name = "declarative_module_renamed")] mod declarative_module2 { @@ -84,7 +110,7 @@ fn test_declarative_module() { ); py_assert!(py, m, "m.double(2) == 4"); - py_assert!(py, m, "m.double2(3) == 6"); + py_assert!(py, m, "m.inner.triple(3) == 9"); py_assert!(py, m, "m.declarative_submodule.double(4) == 8"); py_assert!( py, @@ -97,5 +123,7 @@ fn test_declarative_module() { py_assert!(py, m, "not hasattr(m, 'LocatedClass')"); #[cfg(not(Py_LIMITED_API))] py_assert!(py, m, "hasattr(m, 'LocatedClass')"); + py_assert!(py, m, "isinstance(m.inner.Struct(), m.inner.Struct)"); + py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)"); }) } diff --git a/tests/ui/invalid_pymodule_trait.stderr b/tests/ui/invalid_pymodule_trait.stderr index 3ed12861..4b02f14a 100644 --- a/tests/ui/invalid_pymodule_trait.stderr +++ b/tests/ui/invalid_pymodule_trait.stderr @@ -1,4 +1,4 @@ -error: only 'use' statements and and pymodule_init functions are allowed in #[pymodule] +error: `#[pymodule_export]` may only be used on `use` statements --> tests/ui/invalid_pymodule_trait.rs:5:5 | 5 | #[pymodule_export] diff --git a/tests/ui/invalid_pymodule_two_pymodule_init.stderr b/tests/ui/invalid_pymodule_two_pymodule_init.stderr index 9f0900f9..c117ebd5 100644 --- a/tests/ui/invalid_pymodule_two_pymodule_init.stderr +++ b/tests/ui/invalid_pymodule_two_pymodule_init.stderr @@ -1,4 +1,4 @@ -error: only one pymodule_init may be specified +error: only one `#[pymodule_init]` may be specified --> tests/ui/invalid_pymodule_two_pymodule_init.rs:11:5 | 11 | fn init2(m: &PyModule) -> PyResult<()> {