From 92b724f64f65c6dc651e687274b5af9343cea785 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 11 Jul 2023 22:13:21 +0100 Subject: [PATCH] normalize exception in `PyErr::matches` and `PyErr::get_type` --- newsfragments/3313.fixed.md | 1 + src/err/err_state.rs | 16 +++++++++---- src/err/mod.rs | 47 ++++++++++++++++++------------------- 3 files changed, 35 insertions(+), 29 deletions(-) create mode 100644 newsfragments/3313.fixed.md diff --git a/newsfragments/3313.fixed.md b/newsfragments/3313.fixed.md new file mode 100644 index 00000000..e5fc5c0e --- /dev/null +++ b/newsfragments/3313.fixed.md @@ -0,0 +1 @@ +Fix case where `PyErr::matches` and `PyErr::is_instance` returned results inconsistent with `PyErr::get_type`. diff --git a/src/err/err_state.rs b/src/err/err_state.rs index bf4fb3fd..50a17fda 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -68,11 +68,17 @@ impl PyErrState { ) } } - PyErrState::LazyValue { ptype, pvalue } => ( - ptype.into_ptr(), - pvalue(py).into_ptr(), - std::ptr::null_mut(), - ), + PyErrState::LazyValue { ptype, pvalue } => { + if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 { + Self::exceptions_must_derive_from_base_exception(py).into_ffi_tuple(py) + } else { + ( + ptype.into_ptr(), + pvalue(py).into_ptr(), + std::ptr::null_mut(), + ) + } + } PyErrState::FfiTuple { ptype, pvalue, diff --git a/src/err/mod.rs b/src/err/mod.rs index 3af7b92b..f90b9995 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -138,10 +138,6 @@ impl PyErr { where A: PyErrArguments + Send + Sync + 'static, { - if unsafe { ffi::PyExceptionClass_Check(ty.as_ptr()) } == 0 { - return exceptions_must_derive_from_base_exception(ty.py()); - } - PyErr::from_state(PyErrState::LazyValue { ptype: ty.into(), pvalue: boxed_args(args), @@ -438,16 +434,13 @@ impl PyErr { where T: ToPyObject, { - fn inner(err: &PyErr, py: Python<'_>, exc: PyObject) -> bool { - (unsafe { ffi::PyErr_GivenExceptionMatches(err.type_ptr(py), exc.as_ptr()) }) != 0 - } - inner(self, py, exc.to_object(py)) + self.is_instance(py, exc.to_object(py).as_ref(py)) } /// Returns true if the current exception is instance of `T`. #[inline] pub fn is_instance(&self, py: Python<'_>, ty: &PyAny) -> bool { - (unsafe { ffi::PyErr_GivenExceptionMatches(self.type_ptr(py), ty.as_ptr()) }) != 0 + (unsafe { ffi::PyErr_GivenExceptionMatches(self.get_type(py).as_ptr(), ty.as_ptr()) }) != 0 } /// Returns true if the current exception is instance of `T`. @@ -630,19 +623,6 @@ impl PyErr { } } - /// Returns borrowed reference to this Err's type - fn type_ptr(&self, py: Python<'_>) -> *mut ffi::PyObject { - match unsafe { &*self.state.get() } { - // In lazy type case, normalize before returning ptype in case the type is not a valid - // exception type. - Some(PyErrState::LazyTypeAndValue { .. }) => self.normalized(py).ptype.as_ptr(), - Some(PyErrState::LazyValue { ptype, .. }) => ptype.as_ptr(), - Some(PyErrState::FfiTuple { ptype, .. }) => ptype.as_ptr(), - Some(PyErrState::Normalized(n)) => n.ptype.as_ptr(), - None => panic!("Cannot access exception type while normalizing"), - } - } - #[inline] fn normalized(&self, py: Python<'_>) -> &PyErrStateNormalized { if let Some(PyErrState::Normalized(n)) = unsafe { @@ -822,8 +802,8 @@ fn exceptions_must_derive_from_base_exception(py: Python<'_>) -> PyErr { #[cfg(test)] mod tests { use super::PyErrState; - use crate::exceptions; - use crate::{PyErr, Python}; + use crate::exceptions::{self, PyTypeError, PyValueError}; + use crate::{PyErr, PyTypeInfo, Python}; #[test] fn no_error() { @@ -937,6 +917,25 @@ mod tests { is_sync::(); } + #[test] + fn test_pyerr_matches() { + Python::with_gil(|py| { + let err = PyErr::new::("foo"); + assert!(err.matches(py, PyValueError::type_object(py))); + + assert!(err.matches( + py, + (PyValueError::type_object(py), PyTypeError::type_object(py)) + )); + + assert!(!err.matches(py, PyTypeError::type_object(py))); + + // String is not a valid exception class, so we should get a TypeError + let err: PyErr = PyErr::from_type(crate::types::PyString::type_object(py), "foo"); + assert!(err.matches(py, PyTypeError::type_object(py))); + }) + } + #[test] fn test_pyerr_cause() { Python::with_gil(|py| {