Merge pull request #597 from kngwyu/err-nosegv

Reguire GIL before constructing PyErr from Rust value
This commit is contained in:
Yuji Kanagawa 2019-09-28 15:11:23 +09:00 committed by GitHub
commit d860ee3f21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 21 deletions

View File

@ -16,6 +16,12 @@ use std::io;
use std::os::raw::c_char;
/// Represents a `PyErr` value
///
/// **CAUTION**
///
/// When you construct an instance of `PyErrValue`, we highly recommend to use `from_err_args` method.
/// If you want to to construct `PyErrValue::ToArgs` directly, please do not forget calling
/// `Python::acquire_gil`.
pub enum PyErrValue {
None,
Value(PyObject),
@ -23,6 +29,13 @@ pub enum PyErrValue {
ToObject(Box<dyn ToPyObject>),
}
impl PyErrValue {
pub fn from_err_args<T: 'static + PyErrArguments>(value: T) -> Self {
let _ = Python::acquire_gil();
PyErrValue::ToArgs(Box::new(value))
}
}
/// Represents a Python exception that was raised.
pub struct PyErr {
/// The type of the exception. This should be either a `PyClass` or a `PyType`.
@ -417,7 +430,7 @@ macro_rules! impl_to_pyerr {
impl std::convert::From<$err> for PyErr {
fn from(err: $err) -> PyErr {
PyErr::from_value::<$pyexc>(PyErrValue::ToArgs(Box::new(err)))
PyErr::from_value::<$pyexc>(PyErrValue::from_err_args(err))
}
}
};
@ -426,34 +439,35 @@ macro_rules! impl_to_pyerr {
/// Create `OSError` from `io::Error`
impl std::convert::From<io::Error> for PyErr {
fn from(err: io::Error) -> PyErr {
macro_rules! err_value {
() => {
PyErrValue::from_err_args(err)
};
}
match err.kind() {
io::ErrorKind::BrokenPipe => {
PyErr::from_value::<exceptions::BrokenPipeError>(PyErrValue::ToArgs(Box::new(err)))
PyErr::from_value::<exceptions::BrokenPipeError>(err_value!())
}
io::ErrorKind::ConnectionRefused => {
PyErr::from_value::<exceptions::ConnectionRefusedError>(err_value!())
}
io::ErrorKind::ConnectionAborted => {
PyErr::from_value::<exceptions::ConnectionAbortedError>(err_value!())
}
io::ErrorKind::ConnectionRefused => PyErr::from_value::<
exceptions::ConnectionRefusedError,
>(PyErrValue::ToArgs(Box::new(err))),
io::ErrorKind::ConnectionAborted => PyErr::from_value::<
exceptions::ConnectionAbortedError,
>(PyErrValue::ToArgs(Box::new(err))),
io::ErrorKind::ConnectionReset => {
PyErr::from_value::<exceptions::ConnectionResetError>(PyErrValue::ToArgs(Box::new(
err,
)))
PyErr::from_value::<exceptions::ConnectionResetError>(err_value!())
}
io::ErrorKind::Interrupted => {
PyErr::from_value::<exceptions::InterruptedError>(PyErrValue::ToArgs(Box::new(err)))
PyErr::from_value::<exceptions::InterruptedError>(err_value!())
}
io::ErrorKind::NotFound => {
PyErr::from_value::<exceptions::FileNotFoundError>(err_value!())
}
io::ErrorKind::NotFound => PyErr::from_value::<exceptions::FileNotFoundError>(
PyErrValue::ToArgs(Box::new(err)),
),
io::ErrorKind::WouldBlock => {
PyErr::from_value::<exceptions::BlockingIOError>(PyErrValue::ToArgs(Box::new(err)))
PyErr::from_value::<exceptions::BlockingIOError>(err_value!())
}
io::ErrorKind::TimedOut => {
PyErr::from_value::<exceptions::TimeoutError>(PyErrValue::ToArgs(Box::new(err)))
}
_ => PyErr::from_value::<exceptions::OSError>(PyErrValue::ToArgs(Box::new(err))),
io::ErrorKind::TimedOut => PyErr::from_value::<exceptions::TimeoutError>(err_value!()),
_ => PyErr::from_value::<exceptions::OSError>(err_value!()),
}
}
}
@ -466,7 +480,7 @@ impl PyErrArguments for io::Error {
impl<W: 'static + Send + std::fmt::Debug> std::convert::From<std::io::IntoInnerError<W>> for PyErr {
fn from(err: std::io::IntoInnerError<W>) -> PyErr {
PyErr::from_value::<exceptions::OSError>(PyErrValue::ToArgs(Box::new(err)))
PyErr::from_value::<exceptions::OSError>(PyErrValue::from_err_args(err))
}
}

View File

@ -75,3 +75,18 @@ fn test_custom_error() {
"#
);
}
#[test]
fn test_exception_nosegfault() {
use std::{net::TcpListener, panic};
fn io_err() -> PyResult<()> {
TcpListener::bind("no:address")?;
Ok(())
}
fn parse_int() -> PyResult<()> {
"@_@".parse::<i64>()?;
Ok(())
}
assert!(io_err().is_err());
assert!(parse_int().is_err());
}