From ad76a8a5ce7ae72c7b17dd2121be5feb8b62947e Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Fri, 7 Aug 2020 13:31:17 +0100 Subject: [PATCH] Change unsendable test to use Rust thread --- tests/test_class_basics.rs | 104 ++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 43 deletions(-) diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index 7d19039f..649944a1 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -1,5 +1,6 @@ use pyo3::prelude::*; -use pyo3::py_run; +use pyo3::types::PyType; +use pyo3::{py_run, PyClass}; mod common; @@ -166,59 +167,76 @@ fn class_with_object_field() { #[pyclass(unsendable)] struct UnsendableBase { - rc: std::rc::Rc, + value: std::rc::Rc, } #[pymethods] impl UnsendableBase { + #[new] + fn new(value: usize) -> UnsendableBase { + Self { + value: std::rc::Rc::new(value), + } + } + + #[getter] fn value(&self) -> usize { - *self.rc.as_ref() + *self.value } } #[pyclass(extends=UnsendableBase)] struct UnsendableChild {} +#[pymethods] +impl UnsendableChild { + #[new] + fn new(value: usize) -> (UnsendableChild, UnsendableBase) { + (UnsendableChild {}, UnsendableBase::new(value)) + } +} + +fn test_unsendable() -> PyResult<()> { + let obj = std::thread::spawn(|| -> PyResult<_> { + Python::with_gil(|py| { + let obj: Py = PyType::new::(py).call1((5,))?.extract()?; + + // Accessing the value inside this thread should not panic + let caught_panic = + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| -> PyResult<_> { + assert_eq!(obj.as_ref(py).getattr("value")?.extract::()?, 5); + Ok(()) + })) + .is_err(); + + assert_eq!(caught_panic, false); + Ok(obj) + }) + }) + .join() + .unwrap()?; + + // This access must panic + Python::with_gil(|py| { + obj.borrow(py); + }); + + panic!("Borrowing unsendable from receiving thread did not panic."); +} + /// If a class is marked as `unsendable`, it panics when accessed by another thread. #[test] -fn panic_unsendable() { - if option_env!("RUSTFLAGS") - .map(|s| s.contains("-Cpanic=abort")) - .unwrap_or(false) - { - return; - } - - let gil = Python::acquire_gil(); - let py = gil.python(); - let base = || UnsendableBase { - rc: std::rc::Rc::new(0), - }; - let unsendable_base = PyCell::new(py, base()).unwrap(); - let unsendable_child = PyCell::new(py, (UnsendableChild {}, base())).unwrap(); - - let source = pyo3::indoc::indoc!( - r#" -def value(): - return unsendable.value() - -import concurrent.futures -executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) -future = executor.submit(value) -try: - result = future.result() - assert False, 'future must panic' -except BaseException as e: - assert str(e) == 'test_class_basics::UnsendableBase is unsendable, but sent to another thread!' -"# - ); - let globals = PyModule::import(py, "__main__").unwrap().dict(); - let test = |unsendable| { - globals.set_item("unsendable", unsendable).unwrap(); - py.run(source, Some(globals), None) - .map_err(|e| e.print(py)) - .unwrap(); - }; - test(unsendable_base.as_ref()); - test(unsendable_child.as_ref()); +#[should_panic( + expected = "test_class_basics::UnsendableBase is unsendable, but sent to another thread!" +)] +fn panic_unsendable_base() { + test_unsendable::().unwrap(); +} + +#[test] +#[should_panic( + expected = "test_class_basics::UnsendableBase is unsendable, but sent to another thread!" +)] +fn panic_unsendable_child() { + test_unsendable::().unwrap(); }