Merge pull request #3689 from PyO3/unsendable-threadsafe-traverse
Turn calls of __traverse__ into no-ops for unsendable pyclass if on the wrong thread
This commit is contained in:
commit
8bef6e3398
|
@ -16,7 +16,7 @@
|
|||
| `set_all` | Generates setters for all fields of the pyclass. |
|
||||
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |
|
||||
| <span style="white-space: pre">`text_signature = "(arg1, arg2, ...)"`</span> | Sets the text signature for the Python class' `__new__` method. |
|
||||
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread.|
|
||||
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread. Also note the Python's GC is multi-threaded and while unsendable classes will not be traversed on foreign threads to avoid UB, this can lead to memory leaks. |
|
||||
| `weakref` | Allows this class to be [weakly referenceable][params-6]. |
|
||||
|
||||
All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or
|
||||
|
|
1
newsfragments/3689.changed.md
Normal file
1
newsfragments/3689.changed.md
Normal file
|
@ -0,0 +1 @@
|
|||
Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoiding hard aborts at the cost of potential leakage.
|
|
@ -1013,6 +1013,7 @@ impl<T> PyClassNewTextSignature<T> for &'_ PyClassImplCollector<T> {
|
|||
#[doc(hidden)]
|
||||
pub trait PyClassThreadChecker<T>: Sized {
|
||||
fn ensure(&self);
|
||||
fn check(&self) -> bool;
|
||||
fn can_drop(&self, py: Python<'_>) -> bool;
|
||||
fn new() -> Self;
|
||||
private_decl! {}
|
||||
|
@ -1028,6 +1029,9 @@ pub struct SendablePyClass<T: Send>(PhantomData<T>);
|
|||
|
||||
impl<T: Send> PyClassThreadChecker<T> for SendablePyClass<T> {
|
||||
fn ensure(&self) {}
|
||||
fn check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn can_drop(&self, _py: Python<'_>) -> bool {
|
||||
true
|
||||
}
|
||||
|
@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl {
|
|||
);
|
||||
}
|
||||
|
||||
fn check(&self) -> bool {
|
||||
thread::current().id() == self.0
|
||||
}
|
||||
|
||||
fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
|
||||
if thread::current().id() != self.0 {
|
||||
PyRuntimeError::new_err(format!(
|
||||
|
@ -1071,6 +1079,9 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
|
|||
fn ensure(&self) {
|
||||
self.ensure(std::any::type_name::<T>());
|
||||
}
|
||||
fn check(&self) -> bool {
|
||||
self.check()
|
||||
}
|
||||
fn can_drop(&self, py: Python<'_>) -> bool {
|
||||
self.can_drop(py, std::any::type_name::<T>())
|
||||
}
|
||||
|
|
|
@ -269,7 +269,7 @@ where
|
|||
|
||||
let py = Python::assume_gil_acquired();
|
||||
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
|
||||
let borrow = slf.try_borrow();
|
||||
let borrow = slf.try_borrow_threadsafe();
|
||||
let visit = PyVisit::from_raw(visit, arg, py);
|
||||
|
||||
let retval = if let Ok(borrow) = borrow {
|
||||
|
|
|
@ -351,6 +351,14 @@ impl<T: PyClass> PyCell<T> {
|
|||
.map(|_| PyRef { inner: self })
|
||||
}
|
||||
|
||||
/// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread
|
||||
pub(crate) fn try_borrow_threadsafe(&self) -> Result<PyRef<'_, T>, PyBorrowError> {
|
||||
self.check_threadsafe()?;
|
||||
self.borrow_checker()
|
||||
.try_borrow()
|
||||
.map(|_| PyRef { inner: self })
|
||||
}
|
||||
|
||||
/// Mutably borrows the value `T`, returning an error if the value is currently borrowed.
|
||||
/// This borrow lasts as long as the returned `PyRefMut` exists.
|
||||
///
|
||||
|
@ -975,6 +983,7 @@ impl From<PyBorrowMutError> for PyErr {
|
|||
#[doc(hidden)]
|
||||
pub trait PyCellLayout<T>: PyLayout<T> {
|
||||
fn ensure_threadsafe(&self);
|
||||
fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
|
||||
/// Implementation of tp_dealloc.
|
||||
/// # Safety
|
||||
/// - slf must be a valid pointer to an instance of a T or a subclass.
|
||||
|
@ -988,6 +997,9 @@ where
|
|||
T: PyTypeInfo,
|
||||
{
|
||||
fn ensure_threadsafe(&self) {}
|
||||
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
|
||||
Ok(())
|
||||
}
|
||||
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
||||
let type_obj = T::type_object_raw(py);
|
||||
// For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
|
||||
|
@ -1025,6 +1037,12 @@ where
|
|||
self.contents.thread_checker.ensure();
|
||||
self.ob_base.ensure_threadsafe();
|
||||
}
|
||||
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
|
||||
if !self.contents.thread_checker.check() {
|
||||
return Err(PyBorrowError { _private: () });
|
||||
}
|
||||
self.ob_base.check_threadsafe()
|
||||
}
|
||||
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
||||
// Safety: Python only calls tp_dealloc when no references to the object remain.
|
||||
let cell = &mut *(slf as *mut PyCell<T>);
|
||||
|
|
|
@ -512,6 +512,55 @@ fn drop_during_traversal_without_gil() {
|
|||
assert!(drop_called.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[pyclass(unsendable)]
|
||||
struct UnsendableTraversal {
|
||||
traversed: Cell<bool>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl UnsendableTraversal {
|
||||
fn __clear__(&mut self) {}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
self.traversed.set(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
|
||||
fn unsendable_are_not_traversed_on_foreign_thread() {
|
||||
Python::with_gil(|py| unsafe {
|
||||
let ty = py.get_type::<UnsendableTraversal>().as_type_ptr();
|
||||
let traverse = get_type_traverse(ty).unwrap();
|
||||
|
||||
let obj = Py::new(
|
||||
py,
|
||||
UnsendableTraversal {
|
||||
traversed: Cell::new(false),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let ptr = SendablePtr(obj.as_ptr());
|
||||
|
||||
std::thread::spawn(move || {
|
||||
// traversal on foreign thread is a no-op
|
||||
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
|
||||
})
|
||||
.join()
|
||||
.unwrap();
|
||||
|
||||
assert!(!obj.borrow(py).traversed.get());
|
||||
|
||||
// traversal on home thread still works
|
||||
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
|
||||
|
||||
assert!(obj.borrow(py).traversed.get());
|
||||
});
|
||||
}
|
||||
|
||||
// Manual traversal utilities
|
||||
|
||||
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
|
||||
|
@ -533,3 +582,8 @@ extern "C" fn visit_error(
|
|||
) -> std::os::raw::c_int {
|
||||
-1
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct SendablePtr(*mut pyo3::ffi::PyObject);
|
||||
|
||||
unsafe impl Send for SendablePtr {}
|
||||
|
|
Loading…
Reference in a new issue