diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c838ce1..f660d7ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add support for `#[pyclass(extends=Exception)]`. [#1591](https://github.com/PyO3/pyo3/pull/1591) - Add support for extracting `PathBuf` from `pathlib.Path`. [#1654](https://github.com/PyO3/pyo3/pull/1654) - Add `#[pyo3(text_signature = "...")]` syntax for setting text signature. [#1658](https://github.com/PyO3/pyo3/pull/1658) +- Add support for setting and retrieving exception cause. [#1679](https://github.com/PyO3/pyo3/pull/1679) ### Changed diff --git a/src/err/mod.rs b/src/err/mod.rs index 367f9064..1634ee2f 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -390,6 +390,28 @@ impl PyErr { PyErr::from_state(PyErrState::Normalized(self.normalized(py).clone())) } + /// Return the cause (either an exception instance, or None, set by `raise ... from ...`) + /// associated with the exception, as accessible from Python through `__cause__`. + pub fn cause(&self, py: Python) -> Option { + let ptr = unsafe { ffi::PyException_GetCause(self.pvalue(py).as_ptr()) }; + let obj = unsafe { py.from_owned_ptr_or_opt::(ptr) }; + obj.map(|x| Self::from_instance(x)) + } + + /// Set the cause associated with the exception, pass `None` to clear it. + pub fn set_cause(&self, py: Python, cause: Option) { + if let Some(cause) = cause { + let cause = cause.into_instance(py); + unsafe { + ffi::PyException_SetCause(self.pvalue(py).as_ptr(), cause.as_ptr()); + } + } else { + unsafe { + ffi::PyException_SetCause(self.pvalue(py).as_ptr(), std::ptr::null_mut()); + } + } + } + fn from_state(state: PyErrState) -> PyErr { PyErr { state: UnsafeCell::new(Some(state)), @@ -627,4 +649,36 @@ mod tests { is_send::(); is_sync::(); } + + #[test] + fn test_pyerr_cause() { + Python::with_gil(|py| { + let err = py + .run("raise Exception('banana')", None, None) + .expect_err("raising should have given us an error"); + assert!(err.cause(py).is_none()); + + let err = py + .run( + "raise Exception('banana') from Exception('apple')", + None, + None, + ) + .expect_err("raising should have given us an error"); + let cause = err + .cause(py) + .expect("raising from should have given us a cause"); + assert_eq!(cause.to_string(), "Exception: apple"); + + err.set_cause(py, None); + assert!(err.cause(py).is_none()); + + let new_cause = exceptions::PyValueError::new_err("orange"); + err.set_cause(py, Some(new_cause)); + let cause = err + .cause(py) + .expect("set_cause should have given us a cause"); + assert_eq!(cause.to_string(), "ValueError: orange"); + }); + } }