From 53582177fd558885dfc9c80fe74507c811b0c45e Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 29 Sep 2021 19:29:32 +0100 Subject: [PATCH] exceptions: add test coverage for all exceptions --- src/exceptions.rs | 131 ++++++++++++++++++++++++++++++++++++++++++++-- src/types/mod.rs | 4 +- 2 files changed, 128 insertions(+), 7 deletions(-) diff --git a/src/exceptions.rs b/src/exceptions.rs index 67b1a751..71c051c5 100644 --- a/src/exceptions.rs +++ b/src/exceptions.rs @@ -19,6 +19,7 @@ use std::os::raw::c_char; macro_rules! impl_exception_boilerplate { ($name: ident) => { impl ::std::convert::From<&$name> for $crate::PyErr { + #[inline] fn from(err: &$name) -> $crate::PyErr { $crate::PyErr::from_instance(err) } @@ -28,6 +29,7 @@ macro_rules! impl_exception_boilerplate { /// Creates a new [`PyErr`] of this type. /// /// [`PyErr`]: https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3" + #[inline] pub fn new_err(args: A) -> $crate::PyErr where A: $crate::PyErrArguments + ::std::marker::Send + ::std::marker::Sync + 'static, @@ -268,20 +270,20 @@ except ", $name, " as e: ``` # Example: Catching ", $name, " in Rust - + ``` use pyo3::prelude::*; use pyo3::exceptions::Py", $name, "; - + Python::with_gil(|py| { let result: PyResult<()> = py.run(\"raise ", $name, "\", None, None); - + let error_type = match result { Ok(_) => \"Not an error\", Err(error) if error.is_instance::(py) => \"" , $name, "\", Err(_) => \"Some other error\", }; - + assert_eq!(error_type, \"", $name, "\"); }); ``` @@ -578,6 +580,38 @@ impl PyUnicodeDecodeError { } } +#[cfg(test)] +macro_rules! test_exception { + ($exc_ty:ident $(, $constructor:expr)?) => { + #[allow(non_snake_case)] + #[test] + fn $exc_ty () { + use super::$exc_ty; + + $crate::Python::with_gil(|py| { + use std::error::Error; + let err: $crate::PyErr = { + None + $( + .or(Some($constructor(py))) + )? + .unwrap_or($exc_ty::new_err("a test exception")) + }; + + assert!(err.is_instance::<$exc_ty>(py)); + + let value: &$exc_ty = err.instance(py).downcast().unwrap(); + assert!(value.source().is_none()); + + err.set_cause(py, Some($crate::exceptions::PyValueError::new_err("a cause"))); + assert!(value.source().is_some()); + + assert!($crate::PyErr::from(value).is_instance::<$exc_ty>(py)); + }) + } + }; +} + /// Exceptions defined in Python's [`asyncio`](https://docs.python.org/3/library/asyncio.html) /// module. pub mod asyncio { @@ -588,6 +622,21 @@ pub mod asyncio { import_exception!(asyncio, LimitOverrunError); import_exception!(asyncio, QueueEmpty); import_exception!(asyncio, QueueFull); + + #[cfg(test)] + mod tests { + test_exception!(CancelledError); + test_exception!(InvalidStateError); + test_exception!(TimeoutError); + test_exception!(IncompleteReadError, |_| { + IncompleteReadError::new_err(("partial", "expected")) + }); + test_exception!(LimitOverrunError, |_| { + LimitOverrunError::new_err(("message", "consumed")) + }); + test_exception!(QueueEmpty); + test_exception!(QueueFull); + } } /// Exceptions defined in Python's [`socket`](https://docs.python.org/3/library/socket.html) @@ -596,11 +645,18 @@ pub mod socket { import_exception!(socket, herror); import_exception!(socket, gaierror); import_exception!(socket, timeout); + + #[cfg(test)] + mod tests { + test_exception!(herror); + test_exception!(gaierror); + test_exception!(timeout); + } } #[cfg(test)] mod tests { - use super::{PyException, PyUnicodeDecodeError}; + use super::*; use crate::types::{IntoPyDict, PyDict}; use crate::{PyErr, Python}; @@ -765,4 +821,69 @@ mod tests { ); }); } + + test_exception!(PyBaseException); + test_exception!(PyException); + test_exception!(PyStopAsyncIteration); + test_exception!(PyStopIteration); + test_exception!(PyGeneratorExit); + test_exception!(PyArithmeticError); + test_exception!(PyLookupError); + test_exception!(PyAssertionError); + test_exception!(PyAttributeError); + test_exception!(PyBufferError); + test_exception!(PyEOFError); + test_exception!(PyFloatingPointError); + test_exception!(PyOSError); + test_exception!(PyImportError); + test_exception!(PyModuleNotFoundError); + test_exception!(PyIndexError); + test_exception!(PyKeyError); + test_exception!(PyKeyboardInterrupt); + test_exception!(PyMemoryError); + test_exception!(PyNameError); + test_exception!(PyOverflowError); + test_exception!(PyRuntimeError); + test_exception!(PyRecursionError); + test_exception!(PyNotImplementedError); + test_exception!(PySyntaxError); + test_exception!(PyReferenceError); + test_exception!(PySystemError); + test_exception!(PySystemExit); + test_exception!(PyTypeError); + test_exception!(PyUnboundLocalError); + test_exception!(PyUnicodeError); + test_exception!(PyUnicodeDecodeError, |py| { + let invalid_utf8 = b"fo\xd8o"; + let err = std::str::from_utf8(invalid_utf8).expect_err("should be invalid utf8"); + PyErr::from_instance(PyUnicodeDecodeError::new_utf8(py, invalid_utf8, err).unwrap()) + }); + test_exception!(PyUnicodeEncodeError, |py: Python<'_>| { + py.eval("chr(40960).encode('ascii')", None, None) + .unwrap_err() + }); + test_exception!(PyUnicodeTranslateError, |_| { + PyUnicodeTranslateError::new_err(("\u{3042}", 0, 1, "ouch")) + }); + test_exception!(PyValueError); + test_exception!(PyZeroDivisionError); + test_exception!(PyBlockingIOError); + test_exception!(PyBrokenPipeError); + test_exception!(PyChildProcessError); + test_exception!(PyConnectionError); + test_exception!(PyConnectionAbortedError); + test_exception!(PyConnectionRefusedError); + test_exception!(PyConnectionResetError); + test_exception!(PyFileExistsError); + test_exception!(PyFileNotFoundError); + test_exception!(PyInterruptedError); + test_exception!(PyIsADirectoryError); + test_exception!(PyNotADirectoryError); + test_exception!(PyPermissionError); + test_exception!(PyProcessLookupError); + test_exception!(PyTimeoutError); + test_exception!(PyEnvironmentError); + test_exception!(PyIOError); + #[cfg(windows)] + test_exception!(PyWindowsError); } diff --git a/src/types/mod.rs b/src/types/mod.rs index cb1d91fa..817cc175 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -41,7 +41,7 @@ macro_rules! pyobject_native_type_base( fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> { - let s = self.repr().map_err(|_| ::std::fmt::Error)?; + let s = self.repr().or(::std::result::Result::Err(::std::fmt::Error))?; f.write_str(&s.to_string_lossy()) } } @@ -50,7 +50,7 @@ macro_rules! pyobject_native_type_base( fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> { - let s = self.str().map_err(|_| ::std::fmt::Error)?; + let s = self.str().or(::std::result::Result::Err(::std::fmt::Error))?; f.write_str(&s.to_string_lossy()) } }