diff --git a/tests/test_gc.rs b/tests/test_gc.rs index 3a369c6e..c84d6784 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -4,6 +4,7 @@ use pyo3::class::PyTraverseError; use pyo3::class::PyVisit; use pyo3::prelude::*; use pyo3::{py_run, AsPyPointer, PyCell, PyTryInto}; +use std::cell::Cell; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -368,6 +369,144 @@ fn tries_gil_in_traverse() { }) } +#[pyclass] +struct HijackedTraverse { + traversed: Cell, + hijacked: Cell, +} + +impl HijackedTraverse { + fn new() -> Self { + Self { + traversed: Cell::new(false), + hijacked: Cell::new(false), + } + } + + fn traversed_and_hijacked(&self) -> (bool, bool) { + (self.traversed.get(), self.hijacked.get()) + } +} + +#[pymethods] +impl HijackedTraverse { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.traversed.set(true); + Ok(()) + } +} + +trait Traversable { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError>; +} + +impl<'a> Traversable for PyRef<'a, HijackedTraverse> { + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.hijacked.set(true); + Ok(()) + } +} + +#[test] +fn traverse_cannot_be_hijacked() { + Python::with_gil(|py| unsafe { + // get the traverse function + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + let cell = PyCell::new(py, HijackedTraverse::new()).unwrap(); + let obj = cell.to_object(py); + assert_eq!(cell.borrow().traversed_and_hijacked(), (false, false)); + traverse(obj.as_ptr(), novisit, std::ptr::null_mut()); + assert_eq!(cell.borrow().traversed_and_hijacked(), (true, false)); + }) +} + +#[allow(dead_code)] +#[pyclass] +struct DropDuringTraversal { + cycle: Cell>>, + dropped: TestDropCall, +} + +#[pymethods] +impl DropDuringTraversal { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.cycle.take(); + Ok(()) + } + + fn __clear__(&mut self) { + self.cycle.take(); + } +} + +#[test] +fn drop_during_traversal_with_gil() { + let drop_called = Arc::new(AtomicBool::new(false)); + + Python::with_gil(|py| { + let inst = Py::new( + py, + DropDuringTraversal { + cycle: Cell::new(None), + dropped: TestDropCall { + drop_called: Arc::clone(&drop_called), + }, + }, + ) + .unwrap(); + + inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py))); + + drop(inst); + }); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + Python::with_gil(|py| { + py.run("import gc; gc.collect()", None, None).unwrap(); + }); + } + assert!(drop_called.load(Ordering::Relaxed)); +} + +#[test] +fn drop_during_traversal_without_gil() { + let drop_called = Arc::new(AtomicBool::new(false)); + + let inst = Python::with_gil(|py| { + let inst = Py::new( + py, + DropDuringTraversal { + cycle: Cell::new(None), + dropped: TestDropCall { + drop_called: Arc::clone(&drop_called), + }, + }, + ) + .unwrap(); + + inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py))); + + inst + }); + + drop(inst); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + Python::with_gil(|py| { + py.run("import gc; gc.collect()", None, None).unwrap(); + }); + } + assert!(drop_called.load(Ordering::Relaxed)); +} + // Manual traversal utilities unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option { @@ -389,4 +528,3 @@ extern "C" fn visit_error( ) -> std::os::raw::c_int { -1 } -