Merge pull request #3689 from PyO3/unsendable-threadsafe-traverse
Turn calls of __traverse__ into no-ops for unsendable pyclass if on the wrong thread
This commit is contained in:
commit
8bef6e3398
|
@ -16,7 +16,7 @@
|
||||||
| `set_all` | Generates setters for all fields of the pyclass. |
|
| `set_all` | Generates setters for all fields of the pyclass. |
|
||||||
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |
|
| `subclass` | Allows other Python classes and `#[pyclass]` to inherit from this class. Enums cannot be subclassed. |
|
||||||
| <span style="white-space: pre">`text_signature = "(arg1, arg2, ...)"`</span> | Sets the text signature for the Python class' `__new__` method. |
|
| <span style="white-space: pre">`text_signature = "(arg1, arg2, ...)"`</span> | Sets the text signature for the Python class' `__new__` method. |
|
||||||
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread.|
|
| `unsendable` | Required if your struct is not [`Send`][params-3]. Rather than using `unsendable`, consider implementing your struct in a threadsafe way by e.g. substituting [`Rc`][params-4] with [`Arc`][params-5]. By using `unsendable`, your class will panic when accessed by another thread. Also note the Python's GC is multi-threaded and while unsendable classes will not be traversed on foreign threads to avoid UB, this can lead to memory leaks. |
|
||||||
| `weakref` | Allows this class to be [weakly referenceable][params-6]. |
|
| `weakref` | Allows this class to be [weakly referenceable][params-6]. |
|
||||||
|
|
||||||
All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or
|
All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation, or as one or
|
||||||
|
|
1
newsfragments/3689.changed.md
Normal file
1
newsfragments/3689.changed.md
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoiding hard aborts at the cost of potential leakage.
|
|
@ -1013,6 +1013,7 @@ impl<T> PyClassNewTextSignature<T> for &'_ PyClassImplCollector<T> {
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub trait PyClassThreadChecker<T>: Sized {
|
pub trait PyClassThreadChecker<T>: Sized {
|
||||||
fn ensure(&self);
|
fn ensure(&self);
|
||||||
|
fn check(&self) -> bool;
|
||||||
fn can_drop(&self, py: Python<'_>) -> bool;
|
fn can_drop(&self, py: Python<'_>) -> bool;
|
||||||
fn new() -> Self;
|
fn new() -> Self;
|
||||||
private_decl! {}
|
private_decl! {}
|
||||||
|
@ -1028,6 +1029,9 @@ pub struct SendablePyClass<T: Send>(PhantomData<T>);
|
||||||
|
|
||||||
impl<T: Send> PyClassThreadChecker<T> for SendablePyClass<T> {
|
impl<T: Send> PyClassThreadChecker<T> for SendablePyClass<T> {
|
||||||
fn ensure(&self) {}
|
fn ensure(&self) {}
|
||||||
|
fn check(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
fn can_drop(&self, _py: Python<'_>) -> bool {
|
fn can_drop(&self, _py: Python<'_>) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn check(&self) -> bool {
|
||||||
|
thread::current().id() == self.0
|
||||||
|
}
|
||||||
|
|
||||||
fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
|
fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
|
||||||
if thread::current().id() != self.0 {
|
if thread::current().id() != self.0 {
|
||||||
PyRuntimeError::new_err(format!(
|
PyRuntimeError::new_err(format!(
|
||||||
|
@ -1071,6 +1079,9 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
|
||||||
fn ensure(&self) {
|
fn ensure(&self) {
|
||||||
self.ensure(std::any::type_name::<T>());
|
self.ensure(std::any::type_name::<T>());
|
||||||
}
|
}
|
||||||
|
fn check(&self) -> bool {
|
||||||
|
self.check()
|
||||||
|
}
|
||||||
fn can_drop(&self, py: Python<'_>) -> bool {
|
fn can_drop(&self, py: Python<'_>) -> bool {
|
||||||
self.can_drop(py, std::any::type_name::<T>())
|
self.can_drop(py, std::any::type_name::<T>())
|
||||||
}
|
}
|
||||||
|
|
|
@ -269,7 +269,7 @@ where
|
||||||
|
|
||||||
let py = Python::assume_gil_acquired();
|
let py = Python::assume_gil_acquired();
|
||||||
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
|
let slf = py.from_borrowed_ptr::<PyCell<T>>(slf);
|
||||||
let borrow = slf.try_borrow();
|
let borrow = slf.try_borrow_threadsafe();
|
||||||
let visit = PyVisit::from_raw(visit, arg, py);
|
let visit = PyVisit::from_raw(visit, arg, py);
|
||||||
|
|
||||||
let retval = if let Ok(borrow) = borrow {
|
let retval = if let Ok(borrow) = borrow {
|
||||||
|
|
|
@ -351,6 +351,14 @@ impl<T: PyClass> PyCell<T> {
|
||||||
.map(|_| PyRef { inner: self })
|
.map(|_| PyRef { inner: self })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread
|
||||||
|
pub(crate) fn try_borrow_threadsafe(&self) -> Result<PyRef<'_, T>, PyBorrowError> {
|
||||||
|
self.check_threadsafe()?;
|
||||||
|
self.borrow_checker()
|
||||||
|
.try_borrow()
|
||||||
|
.map(|_| PyRef { inner: self })
|
||||||
|
}
|
||||||
|
|
||||||
/// Mutably borrows the value `T`, returning an error if the value is currently borrowed.
|
/// Mutably borrows the value `T`, returning an error if the value is currently borrowed.
|
||||||
/// This borrow lasts as long as the returned `PyRefMut` exists.
|
/// This borrow lasts as long as the returned `PyRefMut` exists.
|
||||||
///
|
///
|
||||||
|
@ -975,6 +983,7 @@ impl From<PyBorrowMutError> for PyErr {
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub trait PyCellLayout<T>: PyLayout<T> {
|
pub trait PyCellLayout<T>: PyLayout<T> {
|
||||||
fn ensure_threadsafe(&self);
|
fn ensure_threadsafe(&self);
|
||||||
|
fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
|
||||||
/// Implementation of tp_dealloc.
|
/// Implementation of tp_dealloc.
|
||||||
/// # Safety
|
/// # Safety
|
||||||
/// - slf must be a valid pointer to an instance of a T or a subclass.
|
/// - slf must be a valid pointer to an instance of a T or a subclass.
|
||||||
|
@ -988,6 +997,9 @@ where
|
||||||
T: PyTypeInfo,
|
T: PyTypeInfo,
|
||||||
{
|
{
|
||||||
fn ensure_threadsafe(&self) {}
|
fn ensure_threadsafe(&self) {}
|
||||||
|
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
||||||
let type_obj = T::type_object_raw(py);
|
let type_obj = T::type_object_raw(py);
|
||||||
// For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
|
// For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
|
||||||
|
@ -1025,6 +1037,12 @@ where
|
||||||
self.contents.thread_checker.ensure();
|
self.contents.thread_checker.ensure();
|
||||||
self.ob_base.ensure_threadsafe();
|
self.ob_base.ensure_threadsafe();
|
||||||
}
|
}
|
||||||
|
fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
|
||||||
|
if !self.contents.thread_checker.check() {
|
||||||
|
return Err(PyBorrowError { _private: () });
|
||||||
|
}
|
||||||
|
self.ob_base.check_threadsafe()
|
||||||
|
}
|
||||||
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
|
||||||
// Safety: Python only calls tp_dealloc when no references to the object remain.
|
// Safety: Python only calls tp_dealloc when no references to the object remain.
|
||||||
let cell = &mut *(slf as *mut PyCell<T>);
|
let cell = &mut *(slf as *mut PyCell<T>);
|
||||||
|
|
|
@ -512,6 +512,55 @@ fn drop_during_traversal_without_gil() {
|
||||||
assert!(drop_called.load(Ordering::Relaxed));
|
assert!(drop_called.load(Ordering::Relaxed));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyclass(unsendable)]
|
||||||
|
struct UnsendableTraversal {
|
||||||
|
traversed: Cell<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl UnsendableTraversal {
|
||||||
|
fn __clear__(&mut self) {}
|
||||||
|
|
||||||
|
#[allow(clippy::unnecessary_wraps)]
|
||||||
|
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||||
|
self.traversed.set(true);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
|
||||||
|
fn unsendable_are_not_traversed_on_foreign_thread() {
|
||||||
|
Python::with_gil(|py| unsafe {
|
||||||
|
let ty = py.get_type::<UnsendableTraversal>().as_type_ptr();
|
||||||
|
let traverse = get_type_traverse(ty).unwrap();
|
||||||
|
|
||||||
|
let obj = Py::new(
|
||||||
|
py,
|
||||||
|
UnsendableTraversal {
|
||||||
|
traversed: Cell::new(false),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ptr = SendablePtr(obj.as_ptr());
|
||||||
|
|
||||||
|
std::thread::spawn(move || {
|
||||||
|
// traversal on foreign thread is a no-op
|
||||||
|
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
|
||||||
|
})
|
||||||
|
.join()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!obj.borrow(py).traversed.get());
|
||||||
|
|
||||||
|
// traversal on home thread still works
|
||||||
|
assert_eq!(traverse({ ptr }.0, novisit, std::ptr::null_mut()), 0);
|
||||||
|
|
||||||
|
assert!(obj.borrow(py).traversed.get());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Manual traversal utilities
|
// Manual traversal utilities
|
||||||
|
|
||||||
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
|
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
|
||||||
|
@ -533,3 +582,8 @@ extern "C" fn visit_error(
|
||||||
) -> std::os::raw::c_int {
|
) -> std::os::raw::c_int {
|
||||||
-1
|
-1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
struct SendablePtr(*mut pyo3::ffi::PyObject);
|
||||||
|
|
||||||
|
unsafe impl Send for SendablePtr {}
|
||||||
|
|
Loading…
Reference in a new issue