diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b98face..50bb2505 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `#[pyclass(unsendable)]`. [#1009](https://github.com/PyO3/pyo3/pull/1009) ## [0.11.0] - 2020-06-28 ### Added diff --git a/guide/src/class.md b/guide/src/class.md index 679b61c9..4724184f 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -124,6 +124,8 @@ If a custom class contains references to other Python objects that can be collec * `extends=BaseType` - Use a custom base class. The base `BaseType` must implement `PyTypeInfo`. * `subclass` - Allows Python classes to inherit from this class. * `dict` - Adds `__dict__` support, so that the instances of this type have a dictionary containing arbitrary instance variables. +* `unsendable` - Makes it safe to expose `!Send` structs to Python, where all objects can be accessed + by multiple threads. A class marked with `unsendable` panics when accessed by another thread. * `module="XXX"` - Set the name of the module the class will be shown as defined in. If not given, the class will be a virtual member of the `builtins` module. @@ -974,6 +976,10 @@ impl pyo3::class::proto_methods::HasProtoRegistry for MyClass { &REGISTRY } } + +impl pyo3::pyclass::PyClassSend for MyClass { + type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<MyClass>; +} # let gil = Python::acquire_gil(); # let py = gil.python(); # let cls = py.get_type::<MyClass>(); diff --git a/guide/src/migration.md b/guide/src/migration.md index 5033cc11..004e20ce 100644 --- a/guide/src/migration.md +++ b/guide/src/migration.md @@ -8,43 +8,71 @@ For a detailed list of all changes, see [CHANGELOG.md](https://github.com/PyO3/p ### Stable Rust PyO3 now supports the stable Rust toolchain. The minimum required version is 1.39.0. 
-### `#[pyclass]` structs must now be `Send` +### `#[pyclass]` structs must now be `Send` or `unsendable` Because `#[pyclass]` structs can be sent between threads by the Python interpreter, they must implement -`Send` to guarantee thread safety. This bound was added in PyO3 `0.11.0`. +`Send` or be declared as `unsendable` (by `#[pyclass(unsendable)]`). +Note that `unsendable` was added in PyO3 `0.11.1` and `Send` is always required in PyO3 `0.11.0`. -This may "break" some code which previously was accepted, even though it was unsound. To resolve this, -consider using types like `Arc` instead of `Rc`, `Mutex` instead of `RefCell`, and add `Send` to any -boxed closures stored inside the `#[pyclass]`. +This may "break" some code which previously was accepted, even though it could be unsound. +There can be two fixes: -Before: -```rust,compile_fail -use pyo3::prelude::*; -use std::rc::Rc; -use std::cell::RefCell; +1. If you think that your `#[pyclass]` actually must be `Send`able, then let's implement `Send`. + A common, safer way is using thread-safe types. E.g., `Arc` instead of `Rc`, `Mutex` instead of + `RefCell`, and `Box<dyn Send + T>` instead of `Box<dyn T>`. -#[pyclass] -struct NotThreadSafe { - shared_bools: Rc<RefCell<Vec<bool>>>, - closure: Box<dyn Fn()> -} -``` + Before: + ```rust,compile_fail + use pyo3::prelude::*; + use std::rc::Rc; + use std::cell::RefCell; -After: -```rust -use pyo3::prelude::*; -use std::sync::{Arc, Mutex}; + #[pyclass] + struct NotThreadSafe { + shared_bools: Rc<RefCell<Vec<bool>>>, + closure: Box<dyn Fn()> + } + ``` -#[pyclass] -struct ThreadSafe { - shared_bools: Arc<Mutex<Vec<bool>>>, - closure: Box<dyn Fn() + Send> -} -``` + After: + ```rust + use pyo3::prelude::*; + use std::sync::{Arc, Mutex}; -Or in situations where you cannot change your `#[pyclass]` to automatically implement `Send` -(e.g., when it contains a raw pointer), you can use `unsafe impl Send`. -In such cases, care should be taken to ensure the struct is actually thread safe. -See [the Rustnomicon](ttps://doc.rust-lang.org/nomicon/send-and-sync.html) for more. 
+ #[pyclass] + struct ThreadSafe { + shared_bools: Arc<Mutex<Vec<bool>>>, + closure: Box<dyn Fn() + Send> + } + ``` + + In situations where you cannot change your `#[pyclass]` to automatically implement `Send` + (e.g., when it contains a raw pointer), you can use `unsafe impl Send`. + In such cases, care should be taken to ensure the struct is actually thread safe. + See [the Rustnomicon](https://doc.rust-lang.org/nomicon/send-and-sync.html) for more. + +2. If you think that your `#[pyclass]` should not be accessed by another thread, you can use + the `unsendable` flag. A class marked with `unsendable` panics when accessed by another thread, + making it thread-safe to expose an unsendable object to the Python interpreter. + + Before: + ```rust,compile_fail + use pyo3::prelude::*; + + #[pyclass] + struct Unsendable { + pointers: Vec<*mut std::os::raw::c_char>, + } + ``` + + After: + ```rust + use pyo3::prelude::*; + + #[pyclass(unsendable)] + struct Unsendable { + pointers: Vec<*mut std::os::raw::c_char>, + } + ``` ### All `PyObject` and `Py<T>` methods now take `Python` as an argument Previously, a few methods such as `Object::get_refcnt` did not take `Python` as an argument (to diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index cf52f334..dffb346a 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -19,6 +19,7 @@ pub struct PyClassArgs { pub flags: Vec, pub base: syn::TypePath, pub has_extends: bool, + pub has_unsendable: bool, pub module: Option, } @@ -45,6 +46,7 @@ impl Default for PyClassArgs { flags: vec![parse_quote! { 0 }], base: parse_quote! { pyo3::PyAny }, has_extends: false, + has_unsendable: false, } } } @@ -60,7 +62,7 @@ impl PyClassArgs { } } - /// Match a single flag + /// Match a key/value flag fn add_assign(&mut self, assign: &syn::ExprAssign) -> syn::Result<()> { let syn::ExprAssign { left, right, .. 
} = assign; let key = match &**left { @@ -120,31 +122,27 @@ impl PyClassArgs { Ok(()) } - /// Match a key/value flag + /// Match a single flag fn add_path(&mut self, exp: &syn::ExprPath) -> syn::Result<()> { let flag = exp.path.segments.first().unwrap().ident.to_string(); - let path = match flag.as_str() { - "gc" => { - parse_quote! {pyo3::type_flags::GC} - } - "weakref" => { - parse_quote! {pyo3::type_flags::WEAKREF} - } - "subclass" => { - parse_quote! {pyo3::type_flags::BASETYPE} - } - "dict" => { - parse_quote! {pyo3::type_flags::DICT} + let mut push_flag = |flag| { + self.flags.push(syn::Expr::Path(flag)); + }; + match flag.as_str() { + "gc" => push_flag(parse_quote! {pyo3::type_flags::GC}), + "weakref" => push_flag(parse_quote! {pyo3::type_flags::WEAKREF}), + "subclass" => push_flag(parse_quote! {pyo3::type_flags::BASETYPE}), + "dict" => push_flag(parse_quote! {pyo3::type_flags::DICT}), + "unsendable" => { + self.has_unsendable = true; } _ => { return Err(syn::Error::new_spanned( &exp.path, - "Expected one of gc/weakref/subclass/dict", + "Expected one of gc/weakref/subclass/dict/unsendable", )) } }; - - self.flags.push(syn::Expr::Path(path)); Ok(()) } } @@ -386,6 +384,16 @@ fn impl_class( quote! {} }; + let thread_checker = if attr.has_unsendable { + quote! { pyo3::pyclass::ThreadCheckerImpl<#cls> } + } else if attr.has_extends { + quote! { + pyo3::pyclass::ThreadCheckerInherited<#cls, <#cls as pyo3::type_object::PyTypeInfo>::BaseType> + } + } else { + quote! { pyo3::pyclass::ThreadCheckerStub<#cls> } + }; + Ok(quote! 
{ unsafe impl pyo3::type_object::PyTypeInfo for #cls { type Type = #cls; @@ -424,6 +432,10 @@ fn impl_class( type Target = pyo3::PyRefMut<'a, #cls>; } + impl pyo3::pyclass::PyClassSend for #cls { + type ThreadChecker = #thread_checker; + } + #into_pyobject #impl_inventory @@ -433,7 +445,6 @@ fn impl_class( #extra #gc_impl - }) } diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 528d6a2a..c10c172c 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -7,7 +7,7 @@ use crate::err::{PyErr, PyResult}; use crate::exceptions::TypeError; use crate::instance::PyNativeType; -use crate::pyclass::PyClass; +use crate::pyclass::{PyClass, PyClassThreadChecker}; use crate::types::{PyAny, PyDict, PyModule, PyTuple}; use crate::{ffi, GILPool, IntoPy, PyCell, Python}; use std::cell::UnsafeCell; @@ -157,11 +157,12 @@ impl ModuleDef { /// Utilities for basetype #[doc(hidden)] -pub trait PyBaseTypeUtils { +pub trait PyBaseTypeUtils: Sized { type Dict; type WeakRef; type LayoutAsBase; type BaseNativeType; + type ThreadChecker: PyClassThreadChecker; } impl PyBaseTypeUtils for T { @@ -169,6 +170,7 @@ impl PyBaseTypeUtils for T { type WeakRef = T::WeakRef; type LayoutAsBase = crate::pycell::PyCellInner; type BaseNativeType = T::BaseNativeType; + type ThreadChecker = T::ThreadChecker; } /// Utility trait to enable &PyClass as a pymethod/function argument diff --git a/src/pycell.rs b/src/pycell.rs index 2fa67288..19f0d199 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -1,10 +1,11 @@ //! Includes `PyCell` implementation. 
use crate::conversion::{AsPyPointer, FromPyPointer, ToPyObject}; +use crate::pyclass::{PyClass, PyClassThreadChecker}; use crate::pyclass_init::PyClassInitializer; use crate::pyclass_slots::{PyClassDict, PyClassWeakRef}; use crate::type_object::{PyBorrowFlagLayout, PyLayout, PySizedLayout, PyTypeInfo}; use crate::types::PyAny; -use crate::{ffi, FromPy, PyClass, PyErr, PyNativeType, PyObject, PyResult, Python}; +use crate::{ffi, FromPy, PyErr, PyNativeType, PyObject, PyResult, Python}; use std::cell::{Cell, UnsafeCell}; use std::fmt; use std::mem::ManuallyDrop; @@ -161,6 +162,7 @@ pub struct PyCell { inner: PyCellInner, dict: T::Dict, weakref: T::WeakRef, + thread_checker: T::ThreadChecker, } unsafe impl PyNativeType for PyCell {} @@ -227,6 +229,7 @@ impl PyCell { /// } /// ``` pub fn try_borrow(&self) -> Result, PyBorrowError> { + self.thread_checker.ensure(); let flag = self.inner.get_borrow_flag(); if flag == BorrowFlag::HAS_MUTABLE_BORROW { Err(PyBorrowError { _private: () }) @@ -258,6 +261,7 @@ impl PyCell { /// assert!(c.try_borrow_mut().is_ok()); /// ``` pub fn try_borrow_mut(&self) -> Result, PyBorrowMutError> { + self.thread_checker.ensure(); if self.inner.get_borrow_flag() != BorrowFlag::UNUSED { Err(PyBorrowMutError { _private: () }) } else { @@ -296,6 +300,7 @@ impl PyCell { /// } /// ``` pub unsafe fn try_borrow_unguarded(&self) -> Result<&T, PyBorrowError> { + self.thread_checker.ensure(); if self.inner.get_borrow_flag() == BorrowFlag::HAS_MUTABLE_BORROW { Err(PyBorrowError { _private: () }) } else { @@ -352,6 +357,7 @@ impl PyCell { let self_ = base as *mut Self; (*self_).dict = T::Dict::new(); (*self_).weakref = T::WeakRef::new(); + (*self_).thread_checker = T::ThreadChecker::new(); Ok(self_) } } diff --git a/src/pyclass.rs b/src/pyclass.rs index 28cadef6..b6b05f86 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -1,14 +1,16 @@ -//! `PyClass` trait +//! `PyClass` and related traits. 
use crate::class::methods::{PyClassAttributeDef, PyMethodDefType, PyMethods}; use crate::class::proto_methods::PyProtoMethods; use crate::conversion::{AsPyPointer, FromPyPointer}; +use crate::derive_utils::PyBaseTypeUtils; use crate::pyclass_slots::{PyClassDict, PyClassWeakRef}; use crate::type_object::{type_flags, PyLayout}; use crate::types::PyAny; use crate::{class, ffi, PyCell, PyErr, PyNativeType, PyResult, PyTypeInfo, Python}; use std::ffi::CString; +use std::marker::PhantomData; use std::os::raw::c_void; -use std::ptr; +use std::{ptr, thread}; #[inline] pub(crate) unsafe fn default_new( @@ -91,10 +93,10 @@ pub(crate) unsafe fn tp_free_fallback(obj: *mut ffi::PyObject) { pub trait PyClass: PyTypeInfo, AsRefTarget = PyCell> + Sized + + PyClassSend + PyClassAlloc + PyMethods + PyProtoMethods - + Send { /// Specify this class has `#[pyclass(dict)]` or not. type Dict: PyClassDict; @@ -308,3 +310,76 @@ fn py_class_properties() -> Vec { defs.values().cloned().collect() } + +/// This trait is implemented for `#[pyclass]` and handles following two situations: +/// 1. In case `T` is `Send`, stub `ThreadChecker` is used and does nothing. +/// This implementation is used by default. Compile fails if `T: !Send`. +/// 2. In case `T` is `!Send`, `ThreadChecker` panics when `T` is accessed by another thread. +/// This implementation is used when `#[pyclass(unsendable)]` is given. +/// Panicking makes it safe to expose `T: !Send` to the Python interpreter, where all objects +/// can be accessed by multiple threads by `threading` module. +pub trait PyClassSend: Sized { + type ThreadChecker: PyClassThreadChecker; +} + +#[doc(hidden)] +pub trait PyClassThreadChecker: Sized { + fn ensure(&self); + fn new() -> Self; + private_decl! {} +} + +/// Stub checker for `Send` types. +#[doc(hidden)] +pub struct ThreadCheckerStub(PhantomData); + +impl PyClassThreadChecker for ThreadCheckerStub { + fn ensure(&self) {} + fn new() -> Self { + ThreadCheckerStub(PhantomData) + } + private_impl! 
{} +} + +impl PyClassThreadChecker for ThreadCheckerStub { + fn ensure(&self) {} + fn new() -> Self { + ThreadCheckerStub(PhantomData) + } + private_impl! {} +} + +/// Thread checker for unsendable types. +/// Panics when the value is accessed by another thread. +#[doc(hidden)] +pub struct ThreadCheckerImpl(thread::ThreadId, PhantomData); + +impl PyClassThreadChecker for ThreadCheckerImpl { + fn ensure(&self) { + if thread::current().id() != self.0 { + panic!( + "{} is unsendable, but sent to another thread!", + std::any::type_name::() + ); + } + } + fn new() -> Self { + ThreadCheckerImpl(thread::current().id(), PhantomData) + } + private_impl! {} +} + +/// Thread checker for types that have `Send` and `extends=...`. +/// Ensures that `T: Send` and the parent is not accessed by another thread. +#[doc(hidden)] +pub struct ThreadCheckerInherited(PhantomData, U::ThreadChecker); + +impl PyClassThreadChecker for ThreadCheckerInherited { + fn ensure(&self) { + self.1.ensure(); + } + fn new() -> Self { + ThreadCheckerInherited(PhantomData, U::ThreadChecker::new()) + } + private_impl! {} +} diff --git a/src/types/mod.rs b/src/types/mod.rs index c238e72a..fafc3580 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -75,6 +75,7 @@ macro_rules! 
pyobject_native_type { type WeakRef = $crate::pyclass_slots::PyClassDummySlot; type LayoutAsBase = $crate::pycell::PyCellBase<$name>; type BaseNativeType = $name; + type ThreadChecker = $crate::pyclass::ThreadCheckerStub<$crate::PyObject>; } pyobject_native_type_named!($name $(,$type_param)*); pyobject_native_type_convert!($name, $layout, $typeobject, $module, $checkfunction $(,$type_param)*); diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index 82b4810c..03a13103 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -163,3 +163,55 @@ fn class_with_object_field() { py_assert!(py, ty, "ty(5).value == 5"); py_assert!(py, ty, "ty(None).value == None"); } + +#[pyclass(unsendable)] +struct UnsendableBase { + rc: std::rc::Rc, +} + +#[pymethods] +impl UnsendableBase { + fn value(&self) -> usize { + *self.rc.as_ref() + } +} + +#[pyclass(extends=UnsendableBase)] +struct UnsendableChild {} + +/// If a class is marked as `unsendable`, it panics when accessed by another thread. +#[test] +fn panic_unsendable() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let base = || UnsendableBase { + rc: std::rc::Rc::new(0), + }; + let unsendable_base = PyCell::new(py, base()).unwrap(); + let unsendable_child = PyCell::new(py, (UnsendableChild {}, base())).unwrap(); + + let source = pyo3::indoc::indoc!( + r#" +def value(): + return unsendable.value() + +import concurrent.futures +executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) +future = executor.submit(value) +try: + result = future.result() + assert False, 'future must panic' +except BaseException as e: + assert str(e) == 'test_class_basics::UnsendableBase is unsendable, but sent to another thread!' 
+"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + let test = |unsendable| { + globals.set_item("unsendable", unsendable).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); + }; + test(unsendable_base.as_ref()); + test(unsendable_child.as_ref()); +} diff --git a/tests/ui/invalid_pyclass_args.stderr b/tests/ui/invalid_pyclass_args.stderr index 72373cd6..42b2c460 100644 --- a/tests/ui/invalid_pyclass_args.stderr +++ b/tests/ui/invalid_pyclass_args.stderr @@ -22,7 +22,7 @@ error: Expected string literal (e.g., "my_mod") 12 | #[pyclass(module = my_module)] | ^^^^^^^^^ -error: Expected one of gc/weakref/subclass/dict +error: Expected one of gc/weakref/subclass/dict/unsendable --> $DIR/invalid_pyclass_args.rs:15:11 | 15 | #[pyclass(weakrev)]