diff --git a/src/callback.rs b/src/callback.rs index d81fdccc..b5252cab 100644 --- a/src/callback.rs +++ b/src/callback.rs @@ -5,11 +5,13 @@ use crate::err::{PyErr, PyResult}; use crate::exceptions::PyOverflowError; use crate::ffi::{self, Py_hash_t}; -use crate::IntoPyPointer; +use crate::panic::PanicException; +use crate::{GILPool, IntoPyPointer}; use crate::{IntoPy, PyObject, Python}; -use std::isize; +use std::any::Any; use std::os::raw::c_int; -use std::panic::UnwindSafe; +use std::panic::{AssertUnwindSafe, UnwindSafe}; +use std::{isize, panic}; /// A type which can be the return type of a python C-API callback pub trait PyCallbackOutput: Copy { @@ -234,30 +236,33 @@ macro_rules! callback_body { #[doc(hidden)] pub unsafe fn handle_panic(body: F) -> R where - F: FnOnce(Python) -> crate::PyResult + UnwindSafe, + F: FnOnce(Python) -> PyResult + UnwindSafe, R: PyCallbackOutput, { - let pool = crate::GILPool::new(); - let unwind_safe_py = std::panic::AssertUnwindSafe(pool.python()); - let result = - match std::panic::catch_unwind(move || -> crate::PyResult<_> { body(*unwind_safe_py) }) { - Ok(result) => result, - Err(e) => { - // Try to format the error in the same way panic does - if let Some(string) = e.downcast_ref::() { - Err(crate::panic::PanicException::new_err((string.clone(),))) - } else if let Some(s) = e.downcast_ref::<&str>() { - Err(crate::panic::PanicException::new_err((s.to_string(),))) - } else { - Err(crate::panic::PanicException::new_err(( - "panic from Rust code", - ))) - } - } - }; + let pool = GILPool::new(); + let unwind_safe_py = AssertUnwindSafe(pool.python()); + let panic_result = panic::catch_unwind(move || -> PyResult<_> { + let py = *unwind_safe_py; + body(py) + }); - result.unwrap_or_else(|e| { - e.restore(pool.python()); - crate::callback::callback_error() + panic_result_into_callback_output(pool.python(), panic_result) +} + +fn panic_result_into_callback_output( + py: Python, + panic_result: Result, Box>, +) -> R +where + R: PyCallbackOutput, +{ + let py_result = match panic_result { + Ok(py_result) => py_result, + Err(panic_err) => Err(PanicException::from_panic(panic_err)), + }; + + py_result.unwrap_or_else(|py_err| { + py_err.restore(py); + R::ERR_VALUE }) } diff --git a/src/panic.rs b/src/panic.rs index 4884278d..366e886a 100644 --- a/src/panic.rs +++ b/src/panic.rs @@ -1,4 +1,6 @@ use crate::exceptions::PyBaseException; +use crate::PyErr; +use std::any::Any; pyo3_exception!( " @@ -11,3 +13,16 @@ pyo3_exception!( PanicException, PyBaseException ); + +impl PanicException { + // Try to format the error in the same way panic does + pub(crate) fn from_panic(e: Box) -> PyErr { + if let Some(string) = e.downcast_ref::() { + Self::new_err((string.clone(),)) + } else if let Some(s) = e.downcast_ref::<&str>() { + Self::new_err((s.to_string(),)) + } else { + Self::new_err(("panic from Rust code",)) + } + } +}