Merge pull request #3689 from PyO3/unsendable-threadsafe-traverse

Turn calls of __traverse__ into no-ops for unsendable pyclass if on the wrong thread
This commit is contained in:
Adam Reichold 2023-12-23 14:13:46 +00:00 committed by GitHub
commit 8bef6e3398
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 2 deletions

View file

@ -16,7 +16,7 @@
| `set_all` | Generates setters for all fields of the pyclass. |
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |
| <span style="white-space: pre">`text_signature = "(arg1, arg2, ...)"`</span> | Sets the text signature for the Python class' `__new__` method. |
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread.|
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread. Also note the Python's GC is multi-threaded and while unsendable classes will not be traversed on foreign threads to avoid UB, this can lead to memory leaks. |
| `weakref` | Allows this class to be [weakly referenceable][params-6]. |
All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or

View file

@ -0,0 +1 @@
Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoiding hard aborts at the cost of potential leakage.

View file

@ -1013,6 +1013,7 @@ impl<T> PyClassNewTextSignature<T> for &'_ PyClassImplCollector<T> {
#[doc(hidden)]
pub trait PyClassThreadChecker<T>: Sized {
fn ensure(&self);
fn check(&self) -> bool;
fn can_drop(&self, py: Python<'_>) -> bool;
fn new() -> Self;
private_decl! {}
@ -1028,6 +1029,9 @@ pub struct SendablePyClass<T: Send>(PhantomData<T>);
impl<T: Send> PyClassThreadChecker<T> for SendablePyClass<T> {
fn ensure(&self) {}
fn check(&self) -> bool {
true
}
fn can_drop(&self, _py: Python<'_>) -> bool {
true
}
@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl {
);
}
fn check(&self) -> bool {
thread::current().id() == self.0
}
fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
if thread::current().id() != self.0 {
PyRuntimeError::new_err(format!(
@ -1071,6 +1079,9 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
fn ensure(&self) {
self.ensure(std::any::type_name::<T>());
}
fn check(&self) -> bool {
self.check()
}
fn can_drop(&self, py: Python<'_>) -> bool {
self.can_drop(py, std::any::type_name::<T>())
}

View file

@ -269,7 +269,7 @@ where
let py = Python::assume_gil_acquired();
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
let borrow = slf.try_borrow();
let borrow = slf.try_borrow_threadsafe();
let visit = PyVisit::from_raw(visit, arg, py);
let retval = if let Ok(borrow) = borrow {

View file

@ -351,6 +351,14 @@ impl<T: PyClass> PyCell<T> {
.map(|_| PyRef { inner: self })
}
/// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread
pub(crate) fn try_borrow_threadsafe(&self) -> Result<PyRef<'_, T>, PyBorrowError> {
self.check_threadsafe()?;
self.borrow_checker()
.try_borrow()
.map(|_| PyRef { inner: self })
}
/// Mutably borrows the value `T`, returning an error if the value is currently borrowed.
/// This borrow lasts as long as the returned `PyRefMut` exists.
///
@ -975,6 +983,7 @@ impl From<PyBorrowMutError> for PyErr {
#[doc(hidden)]
pub trait PyCellLayout<T>: PyLayout<T> {
fn ensure_threadsafe(&self);
fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
/// Implementation of tp_dealloc.
/// # Safety
/// - slf must be a valid pointer to an instance of a T or a subclass.
@ -988,6 +997,9 @@ where
T: PyTypeInfo,
{
fn ensure_threadsafe(&self) {}
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
Ok(())
}
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
let type_obj = T::type_object_raw(py);
// For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
@ -1025,6 +1037,12 @@ where
self.contents.thread_checker.ensure();
self.ob_base.ensure_threadsafe();
}
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
if !self.contents.thread_checker.check() {
return Err(PyBorrowError { _private: () });
}
self.ob_base.check_threadsafe()
}
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
// Safety: Python only calls tp_dealloc when no references to the object remain.
let cell = &mut *(slf as *mut PyCell<T>);

View file

@ -512,6 +512,55 @@ fn drop_during_traversal_without_gil() {
assert!(drop_called.load(Ordering::Relaxed));
}
#[pyclass(unsendable)]
struct UnsendableTraversal {
traversed: Cell<bool>,
}
#[pymethods]
impl UnsendableTraversal {
fn __clear__(&mut self) {}
#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.traversed.set(true);
Ok(())
}
}
#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn unsendable_are_not_traversed_on_foreign_thread() {
Python::with_gil(|py| unsafe {
let ty = py.get_type::<UnsendableTraversal>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();
let obj = Py::new(
py,
UnsendableTraversal {
traversed: Cell::new(false),
},
)
.unwrap();
let ptr = SendablePtr(obj.as_ptr());
std::thread::spawn(move || {
// traversal on foreign thread is a no-op
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
})
.join()
.unwrap();
assert!(!obj.borrow(py).traversed.get());
// traversal on home thread still works
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
assert!(obj.borrow(py).traversed.get());
});
}
// Manual traversal utilities
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
@ -533,3 +582,8 @@ extern "C" fn visit_error(
) -> std::os::raw::c_int {
-1
}
#[derive(Clone, Copy)]
struct SendablePtr(*mut pyo3::ffi::PyObject);
unsafe impl Send for SendablePtr {}