refs #4286 -- allow setting submodule on declarative pymodules (#4301)

This commit is contained in:
Alex Gaynor 2024-07-02 07:24:47 -04:00 committed by GitHub
parent f3603a0a48
commit ccd04475a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 77 additions and 19 deletions

View File

@ -154,6 +154,8 @@ The `#[pymodule]` macro automatically sets the `module` attribute of the `#[pycl
For nested modules, the name of the parent module is automatically added. For nested modules, the name of the parent module is automatically added.
In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested
but the `Ext` class will have for `module` the default `builtins` because it not nested. but the `Ext` class will have for `module` the default `builtins` because it not nested.
You can provide the `submodule` argument to `pymodule()` for modules that are not top-level modules.
```rust ```rust
# mod declarative_module_module_attr_test { # mod declarative_module_module_attr_test {
use pyo3::prelude::*; use pyo3::prelude::*;
@ -168,7 +170,7 @@ mod my_extension {
#[pymodule_export] #[pymodule_export]
use super::Ext; use super::Ext;
#[pymodule] #[pymodule(submodule)]
mod submodule { mod submodule {
use super::*; use super::*;
// This is a submodule // This is a submodule

View File

@ -0,0 +1 @@
allow setting `submodule` on declarative `#[pymodule]`s

View File

@ -37,6 +37,7 @@ pub mod kw {
syn::custom_keyword!(set_all); syn::custom_keyword!(set_all);
syn::custom_keyword!(signature); syn::custom_keyword!(signature);
syn::custom_keyword!(subclass); syn::custom_keyword!(subclass);
syn::custom_keyword!(submodule);
syn::custom_keyword!(text_signature); syn::custom_keyword!(text_signature);
syn::custom_keyword!(transparent); syn::custom_keyword!(transparent);
syn::custom_keyword!(unsendable); syn::custom_keyword!(unsendable);
@ -178,6 +179,7 @@ pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;
pub type NameAttribute = KeywordAttribute<kw::name, NameLitStr>; pub type NameAttribute = KeywordAttribute<kw::name, NameLitStr>;
pub type RenameAllAttribute = KeywordAttribute<kw::rename_all, RenamingRuleLitStr>; pub type RenameAllAttribute = KeywordAttribute<kw::rename_all, RenamingRuleLitStr>;
pub type TextSignatureAttribute = KeywordAttribute<kw::text_signature, TextSignatureAttributeValue>; pub type TextSignatureAttribute = KeywordAttribute<kw::text_signature, TextSignatureAttributeValue>;
pub type SubmoduleAttribute = kw::submodule;
impl<K: Parse + std::fmt::Debug, V: Parse> Parse for KeywordAttribute<K, V> { impl<K: Parse + std::fmt::Debug, V: Parse> Parse for KeywordAttribute<K, V> {
fn parse(input: ParseStream<'_>) -> Result<Self> { fn parse(input: ParseStream<'_>) -> Result<Self> {

View File

@ -3,6 +3,7 @@
use crate::{ use crate::{
attributes::{ attributes::{
self, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute, NameAttribute, self, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute, NameAttribute,
SubmoduleAttribute,
}, },
get_doc, get_doc,
pyclass::PyClassPyO3Option, pyclass::PyClassPyO3Option,
@ -27,6 +28,7 @@ pub struct PyModuleOptions {
krate: Option<CrateAttribute>, krate: Option<CrateAttribute>,
name: Option<syn::Ident>, name: Option<syn::Ident>,
module: Option<ModuleAttribute>, module: Option<ModuleAttribute>,
is_submodule: bool,
} }
impl PyModuleOptions { impl PyModuleOptions {
@ -38,6 +40,7 @@ impl PyModuleOptions {
PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?, PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?,
PyModulePyO3Option::Crate(path) => options.set_crate(path)?, PyModulePyO3Option::Crate(path) => options.set_crate(path)?,
PyModulePyO3Option::Module(module) => options.set_module(module)?, PyModulePyO3Option::Module(module) => options.set_module(module)?,
PyModulePyO3Option::Submodule(submod) => options.set_submodule(submod)?,
} }
} }
@ -73,9 +76,22 @@ impl PyModuleOptions {
self.module = Some(name); self.module = Some(name);
Ok(()) Ok(())
} }
fn set_submodule(&mut self, submod: SubmoduleAttribute) -> Result<()> {
ensure_spanned!(
!self.is_submodule,
submod.span() => "`submodule` may only be specified once"
);
self.is_submodule = true;
Ok(())
}
} }
pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> { pub fn pymodule_module_impl(
mut module: syn::ItemMod,
mut is_submodule: bool,
) -> Result<TokenStream> {
let syn::ItemMod { let syn::ItemMod {
attrs, attrs,
vis, vis,
@ -100,6 +116,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
} else { } else {
name.to_string() name.to_string()
}; };
is_submodule = is_submodule || options.is_submodule;
let mut module_items = Vec::new(); let mut module_items = Vec::new();
let mut module_items_cfg_attrs = Vec::new(); let mut module_items_cfg_attrs = Vec::new();
@ -297,7 +314,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
) )
} }
}}; }};
let initialization = module_initialization(&name, ctx, module_def); let initialization = module_initialization(&name, ctx, module_def, is_submodule);
Ok(quote!( Ok(quote!(
#(#attrs)* #(#attrs)*
#vis mod #ident { #vis mod #ident {
@ -331,7 +348,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
let vis = &function.vis; let vis = &function.vis;
let doc = get_doc(&function.attrs, None, ctx); let doc = get_doc(&function.attrs, None, ctx);
let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }); let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }, false);
// Module function called with optional Python<'_> marker as first arg, followed by the module. // Module function called with optional Python<'_> marker as first arg, followed by the module.
let mut module_args = Vec::new(); let mut module_args = Vec::new();
@ -396,20 +413,27 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
}) })
} }
fn module_initialization(name: &syn::Ident, ctx: &Ctx, module_def: TokenStream) -> TokenStream { fn module_initialization(
name: &syn::Ident,
ctx: &Ctx,
module_def: TokenStream,
is_submodule: bool,
) -> TokenStream {
let Ctx { pyo3_path, .. } = ctx; let Ctx { pyo3_path, .. } = ctx;
let pyinit_symbol = format!("PyInit_{}", name); let pyinit_symbol = format!("PyInit_{}", name);
let name = name.to_string(); let name = name.to_string();
let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx); let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);
quote! { let mut result = quote! {
#[doc(hidden)] #[doc(hidden)]
pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name; pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
pub(super) struct MakeDef; pub(super) struct MakeDef;
#[doc(hidden)] #[doc(hidden)]
pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def; pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
};
if !is_submodule {
result.extend(quote! {
/// This autogenerated function is called by the python interpreter when importing /// This autogenerated function is called by the python interpreter when importing
/// the module. /// the module.
#[doc(hidden)] #[doc(hidden)]
@ -417,7 +441,9 @@ fn module_initialization(name: &syn::Ident, ctx: &Ctx, module_def: TokenStream)
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject { pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py)) #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
} }
});
} }
result
} }
/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]` /// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
@ -557,6 +583,7 @@ fn has_pyo3_module_declared<T: Parse>(
} }
enum PyModulePyO3Option { enum PyModulePyO3Option {
Submodule(SubmoduleAttribute),
Crate(CrateAttribute), Crate(CrateAttribute),
Name(NameAttribute), Name(NameAttribute),
Module(ModuleAttribute), Module(ModuleAttribute),
@ -571,6 +598,8 @@ impl Parse for PyModulePyO3Option {
input.parse().map(PyModulePyO3Option::Crate) input.parse().map(PyModulePyO3Option::Crate)
} else if lookahead.peek(attributes::kw::module) { } else if lookahead.peek(attributes::kw::module) {
input.parse().map(PyModulePyO3Option::Module) input.parse().map(PyModulePyO3Option::Module)
} else if lookahead.peek(attributes::kw::submodule) {
input.parse().map(PyModulePyO3Option::Submodule)
} else { } else {
Err(lookahead.error()) Err(lookahead.error())
} }

View File

@ -3,7 +3,7 @@
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Span, TokenStream as TokenStream2};
use pyo3_macros_backend::{ use pyo3_macros_backend::{
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods, build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
pymodule_function_impl, pymodule_module_impl, PyClassArgs, PyClassMethodsType, pymodule_function_impl, pymodule_module_impl, PyClassArgs, PyClassMethodsType,
@ -35,10 +35,26 @@ use syn::{parse::Nothing, parse_macro_input, Item};
/// [1]: https://pyo3.rs/latest/module.html /// [1]: https://pyo3.rs/latest/module.html
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream { pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream {
parse_macro_input!(args as Nothing);
match parse_macro_input!(input as Item) { match parse_macro_input!(input as Item) {
Item::Mod(module) => pymodule_module_impl(module), Item::Mod(module) => {
Item::Fn(function) => pymodule_function_impl(function), let is_submodule = match parse_macro_input!(args as Option<syn::Ident>) {
Some(i) if i == "submodule" => true,
Some(_) => {
return syn::Error::new(
Span::call_site(),
"#[pymodule] only accepts submodule as an argument",
)
.into_compile_error()
.into();
}
None => false,
};
pymodule_module_impl(module, is_submodule)
}
Item::Fn(function) => {
parse_macro_input!(args as Nothing);
pymodule_function_impl(function)
}
unsupported => Err(syn::Error::new_spanned( unsupported => Err(syn::Error::new_spanned(
unsupported, unsupported,
"#[pymodule] only supports modules and functions.", "#[pymodule] only supports modules and functions.",

View File

@ -49,6 +49,10 @@ create_exception!(
"Some description." "Some description."
); );
#[pymodule]
#[pyo3(submodule)]
mod external_submodule {}
/// A module written using declarative syntax. /// A module written using declarative syntax.
#[pymodule] #[pymodule]
mod declarative_module { mod declarative_module {
@ -70,6 +74,9 @@ mod declarative_module {
#[pymodule_export] #[pymodule_export]
use super::some_module::SomeException; use super::some_module::SomeException;
#[pymodule_export]
use super::external_submodule;
#[pymodule] #[pymodule]
mod inner { mod inner {
use super::*; use super::*;
@ -108,7 +115,7 @@ mod declarative_module {
} }
} }
#[pymodule] #[pymodule(submodule)]
#[pyo3(module = "custom_root")] #[pyo3(module = "custom_root")]
mod inner_custom_root { mod inner_custom_root {
use super::*; use super::*;
@ -174,6 +181,7 @@ fn test_declarative_module() {
py_assert!(py, m, "hasattr(m, 'LocatedClass')"); py_assert!(py, m, "hasattr(m, 'LocatedClass')");
py_assert!(py, m, "isinstance(m.inner.Struct(), m.inner.Struct)"); py_assert!(py, m, "isinstance(m.inner.Struct(), m.inner.Struct)");
py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)"); py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)");
py_assert!(py, m, "hasattr(m, 'external_submodule')")
}) })
} }