Turn calls of __traverse__ into no-ops for unsendable pyclass if on the wrong thread
Adds a "threadsafe" variant of `PyCell::try_borrow` which will fail instead of panicking if called on the wrong thread and use it in `call_traverse` to turn GC traversals of unsendable pyclasses into no-ops if on the wrong thread. This can imply leaking the underlying resource if the originator thread has already exited so that the GC will never run there again, but it does avoid hard aborts as we cannot raise an exception from within `call_traverse`.
This commit is contained in:
parent
65f25d4133
commit
4dc6c1643e
|
@ -16,7 +16,7 @@
|
|||
| `set_all` | Generates setters for all fields of the pyclass. |
|
||||
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |
|
||||
| <span style="white-space: pre">`text_signature = "(arg1, arg2, ...)"`</span> | Sets the text signature for the Python class' `__new__` method. |
|
||||
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread.|
|
||||
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread. Also note the Python's GC is multi-threaded and while unsendable classes will not be traversed on foreign threads to avoid UB, this can lead to memory leaks. |
|
||||
| `weakref` | Allows this class to be [weakly referenceable][params-6]. |
|
||||
|
||||
All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoiding hard aborts at the cost of potential leakage.
|
|
@ -1013,6 +1013,7 @@ impl<T> PyClassNewTextSignature<T> for &'_ PyClassImplCollector<T> {
|
|||
#[doc(hidden)]
|
||||
pub trait PyClassThreadChecker<T>: Sized {
|
||||
fn ensure(&self);
|
||||
fn check(&self) -> bool;
|
||||
fn can_drop(&self, py: Python<'_>) -> bool;
|
||||
fn new() -> Self;
|
||||
private_decl! {}
|
||||
|
@ -1028,6 +1029,9 @@ pub struct SendablePyClass<T: Send>(PhantomData<T>);
|
|||
|
||||
impl<T: Send> PyClassThreadChecker<T> for SendablePyClass<T> {
|
||||
fn ensure(&self) {}
|
||||
fn check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn can_drop(&self, _py: Python<'_>) -> bool {
|
||||
true
|
||||
}
|
||||
|
@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl {
|
|||
);
|
||||
}
|
||||
|
||||
fn check(&self) -> bool {
|
||||
thread::current().id() == self.0
|
||||
}
|
||||
|
||||
fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
|
||||
if thread::current().id() != self.0 {
|
||||
PyRuntimeError::new_err(format!(
|
||||
|
@ -1071,6 +1079,9 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
|
|||
fn ensure(&self) {
|
||||
self.ensure(std::any::type_name::<T>());
|
||||
}
|
||||
fn check(&self) -> bool {
|
||||
self.check()
|
||||
}
|
||||
fn can_drop(&self, py: Python<'_>) -> bool {
|
||||
self.can_drop(py, std::any::type_name::<T>())
|
||||
}
|
||||
|
|
|
@ -269,7 +269,7 @@ where
|
|||
|
||||
let py = Python::assume_gil_acquired();
|
||||
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
|
||||
let borrow = slf.try_borrow();
|
||||
let borrow = slf.try_borrow_threadsafe();
|
||||
let visit = PyVisit::from_raw(visit, arg, py);
|
||||
|
||||
let retval = if let Ok(borrow) = borrow {
|
||||
|
|
|
@ -351,6 +351,14 @@ impl<T: PyClass> PyCell<T> {
|
|||
.map(|_| PyRef { inner: self })
|
||||
}
|
||||
|
||||
/// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread
|
||||
pub(crate) fn try_borrow_threadsafe(&self) -> Result<PyRef<'_, T>, PyBorrowError> {
|
||||
self.check_threadsafe()?;
|
||||
self.borrow_checker()
|
||||
.try_borrow()
|
||||
.map(|_| PyRef { inner: self })
|
||||
}
|
||||
|
||||
/// Mutably borrows the value `T`, returning an error if the value is currently borrowed.
|
||||
/// This borrow lasts as long as the returned `PyRefMut` exists.
|
||||
///
|
||||
|
@ -975,6 +983,7 @@ impl From<PyBorrowMutError> for PyErr {
|
|||
#[doc(hidden)]
|
||||
pub trait PyCellLayout<T>: PyLayout<T> {
|
||||
fn ensure_threadsafe(&self);
|
||||
fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
|
||||
/// Implementation of tp_dealloc.
|
||||
/// # Safety
|
||||
/// - slf must be a valid pointer to an instance of a T or a subclass.
|
||||
|
@ -988,6 +997,9 @@ where
|
|||
T: PyTypeInfo,
|
||||
{
|
||||
fn ensure_threadsafe(&self) {}
|
||||
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
|
||||
Ok(())
|
||||
}
|
||||
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
||||
let type_obj = T::type_object_raw(py);
|
||||
// For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
|
||||
|
@ -1025,6 +1037,12 @@ where
|
|||
self.contents.thread_checker.ensure();
|
||||
self.ob_base.ensure_threadsafe();
|
||||
}
|
||||
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
|
||||
if !self.contents.thread_checker.check() {
|
||||
return Err(PyBorrowError { _private: () });
|
||||
}
|
||||
self.ob_base.check_threadsafe()
|
||||
}
|
||||
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
||||
// Safety: Python only calls tp_dealloc when no references to the object remain.
|
||||
let cell = &mut *(slf as *mut PyCell<T>);
|
||||
|
|
|
@ -512,6 +512,55 @@ fn drop_during_traversal_without_gil() {
|
|||
assert!(drop_called.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[pyclass(unsendable)]
|
||||
struct UnsendableTraversal {
|
||||
traversed: Cell<bool>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl UnsendableTraversal {
|
||||
fn __clear__(&mut self) {}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
self.traversed.set(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
|
||||
fn unsendable_are_not_traversed_on_foreign_thread() {
|
||||
Python::with_gil(|py| unsafe {
|
||||
let ty = py.get_type::<UnsendableTraversal>().as_type_ptr();
|
||||
let traverse = get_type_traverse(ty).unwrap();
|
||||
|
||||
let obj = Py::new(
|
||||
py,
|
||||
UnsendableTraversal {
|
||||
traversed: Cell::new(false),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let ptr = SendablePtr(obj.as_ptr());
|
||||
|
||||
std::thread::spawn(move || {
|
||||
// traversal on foreign thread is a no-op
|
||||
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
|
||||
})
|
||||
.join()
|
||||
.unwrap();
|
||||
|
||||
assert!(!obj.borrow(py).traversed.get());
|
||||
|
||||
// traversal on home thread still works
|
||||
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
|
||||
|
||||
assert!(obj.borrow(py).traversed.get());
|
||||
});
|
||||
}
|
||||
|
||||
// Manual traversal utilities
|
||||
|
||||
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
|
||||
|
@ -533,3 +582,8 @@ extern "C" fn visit_error(
|
|||
) -> std::os::raw::c_int {
|
||||
-1
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct SendablePtr(*mut pyo3::ffi::PyObject);
|
||||
|
||||
unsafe impl Send for SendablePtr {}
|
||||
|
|
Loading…
Reference in New Issue