From fbd531195aeb18eb23a679bc30946285120376ce Mon Sep 17 00:00:00 2001 From: Thomas Tanon Date: Wed, 6 Mar 2024 19:20:02 +0100 Subject: [PATCH] PyAddToModule: Properly propagate initialization error (#3919) Better than panics --- pyo3-macros-backend/src/pyclass.rs | 19 ++++++++++++-- src/impl_/pymodule.rs | 9 +------ src/types/mod.rs | 12 +++++++++ tests/test_declarative_module.rs | 41 ++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 10 deletions(-) diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 94700457..3eca8086 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -881,11 +881,12 @@ fn impl_complex_enum( } }; - let pyclass_impls: TokenStream = vec![ + let pyclass_impls: TokenStream = [ impl_builder.impl_pyclass(ctx), impl_builder.impl_extractext(ctx), enum_into_py_impl, impl_builder.impl_pyclassimpl(ctx)?, + impl_builder.impl_add_to_module(ctx), impl_builder.impl_freelist(ctx), ] .into_iter() @@ -1372,11 +1373,12 @@ impl<'a> PyClassImplsBuilder<'a> { } fn impl_all(&self, ctx: &Ctx) -> Result { - let tokens = vec![ + let tokens = [ self.impl_pyclass(ctx), self.impl_extractext(ctx), self.impl_into_py(ctx), self.impl_pyclassimpl(ctx)?, + self.impl_add_to_module(ctx), self.impl_freelist(ctx), ] .into_iter() @@ -1625,6 +1627,19 @@ impl<'a> PyClassImplsBuilder<'a> { }) } + fn impl_add_to_module(&self, ctx: &Ctx) -> TokenStream { + let Ctx { pyo3_path } = ctx; + let cls = self.cls; + quote! { + impl #pyo3_path::impl_::pymodule::PyAddToModule for #cls { + fn add_to_module(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> { + use #pyo3_path::types::PyModuleMethods; + module.add_class::() + } + } + } + } + fn impl_freelist(&self, ctx: &Ctx) -> TokenStream { let cls = self.cls; let Ctx { pyo3_path } = ctx; diff --git a/src/impl_/pymodule.rs b/src/impl_/pymodule.rs index 9fff799c..a19cbda5 100644 --- a/src/impl_/pymodule.rs +++ b/src/impl_/pymodule.rs @@ -7,8 +7,7 @@ use portable_atomic::{AtomicI64, Ordering}; #[cfg(not(PyPy))] use crate::exceptions::PyImportError; -use crate::types::module::PyModuleMethods; -use crate::{ffi, sync::GILOnceCell, types::PyModule, Bound, Py, PyResult, PyTypeInfo, Python}; +use crate::{ffi, sync::GILOnceCell, types::PyModule, Bound, Py, PyResult, Python}; /// `Sync` wrapper of `ffi::PyModuleDef`. pub struct ModuleDef { @@ -141,12 +140,6 @@ pub trait PyAddToModule { fn add_to_module(module: &Bound<'_, PyModule>) -> PyResult<()>; } -impl PyAddToModule for T { - fn add_to_module(module: &Bound<'_, PyModule>) -> PyResult<()> { - module.add(Self::NAME, Self::type_object_bound(module.py())) - } -} - #[cfg(test)] mod tests { use std::sync::atomic::{AtomicBool, Ordering}; diff --git a/src/types/mod.rs b/src/types/mod.rs index fc74b03d..cee45e86 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -249,6 +249,18 @@ macro_rules! pyobject_native_type_info( } )? } + + impl<$($generics,)*> $crate::impl_::pymodule::PyAddToModule for $name { + fn add_to_module( + module: &$crate::Bound<'_, $crate::types::PyModule>, + ) -> $crate::PyResult<()> { + use $crate::types::PyModuleMethods; + module.add( + ::NAME, + ::type_object_bound(module.py()), + ) + } + } }; ); diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 3cb93765..9eea8e2d 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -3,6 +3,8 @@ use pyo3::create_exception; use pyo3::exceptions::PyException; use pyo3::prelude::*; +#[cfg(not(Py_LIMITED_API))] +use pyo3::types::PyBool; #[path = "../src/tests/common.rs"] mod common; @@ -127,3 +129,42 @@ fn test_declarative_module() { py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)"); }) } + +#[cfg(not(Py_LIMITED_API))] +#[pyclass(extends = PyBool)] +struct ExtendsBool; + +#[cfg(not(Py_LIMITED_API))] +#[pymodule] +mod class_initialization_module { + #[pymodule_export] + use super::ExtendsBool; +} + +#[test] +#[cfg(not(Py_LIMITED_API))] +fn test_class_initialization_fails() { + Python::with_gil(|py| { + let err = class_initialization_module::DEF + .make_module(py) + .unwrap_err(); + assert_eq!( + err.to_string(), + "RuntimeError: An error occurred while initializing class ExtendsBool" + ); + }) +} + +#[pymodule] +mod r#type { + #[pymodule_export] + use super::double; +} + +#[test] +fn test_raw_ident_module() { + Python::with_gil(|py| { + let m = pyo3::wrap_pymodule!(r#type)(py).into_bound(py); + py_assert!(py, m, "m.double(2) == 4"); + }) +}