From 611ea4db49b3cbc60bada4b622c556264a0b238c Mon Sep 17 00:00:00 2001 From: mejrs <> Date: Tue, 4 Oct 2022 17:59:46 +0200 Subject: [PATCH] Update decorator to use atomics --- examples/decorator/src/lib.rs | 24 ++++++++++----- examples/decorator/tests/test_.py | 11 +++++++ guide/src/class/call.md | 50 ++++++++++++++++++++++++++++++- 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/examples/decorator/src/lib.rs b/examples/decorator/src/lib.rs index 078b33f0..ca6db532 100644 --- a/examples/decorator/src/lib.rs +++ b/examples/decorator/src/lib.rs @@ -1,14 +1,16 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyTuple}; +use std::sync::atomic::{AtomicU64, Ordering}; /// A function decorator that keeps track how often it is called. /// /// It otherwise doesn't do anything special. #[pyclass(name = "Counter")] pub struct PyCounter { - // We use `#[pyo3(get)]` so that python can read the count but not mutate it. - #[pyo3(get)] - count: u64, + // Keeps track of how many calls have gone through. + // + // See the discussion at the end for why an atomic is used. + count: AtomicU64, // This is the actual function being wrapped. wraps: Py, @@ -23,20 +25,28 @@ impl PyCounter { // 2. 
We still need to handle any exceptions that the function might raise #[new] fn __new__(wraps: Py) -> Self { - PyCounter { count: 0, wraps } + PyCounter { + count: AtomicU64::new(0), + wraps, + } + } + + #[getter] + fn count(&self) -> u64 { + self.count.load(Ordering::Relaxed) } #[args(args = "*", kwargs = "**")] fn __call__( - &mut self, + &self, py: Python<'_>, args: &PyTuple, kwargs: Option<&PyDict>, ) -> PyResult> { - self.count += 1; + let old_count = self.count.fetch_add(1, Ordering::Relaxed); let name = self.wraps.getattr(py, "__name__")?; - println!("{} has been called {} time(s).", name, self.count); + println!("{} has been called {} time(s).", name, old_count + 1); // After doing something, we finally forward the call to the wrapped function let ret = self.wraps.call(py, args, kwargs)?; diff --git a/examples/decorator/tests/test_.py b/examples/decorator/tests/test_.py index 1031c442..26056713 100644 --- a/examples/decorator/tests/test_.py +++ b/examples/decorator/tests/test_.py @@ -38,3 +38,14 @@ def test_default_arg(): say_hello() assert say_hello.count == 4 + + +# https://github.com/PyO3/pyo3/discussions/2598 +def test_discussion_2598(): + @Counter + def say_hello(): + if say_hello.count < 2: + print(f"hello from decorator") + + say_hello() + say_hello() diff --git a/guide/src/class/call.md b/guide/src/class/call.md index 9a470373..be544e70 100644 --- a/guide/src/class/call.md +++ b/guide/src/class/call.md @@ -37,7 +37,7 @@ say_hello has been called 4 time(s). hello ``` -#### Pure Python implementation +### Pure Python implementation A Python implementation of this looks similar to the Rust version: @@ -65,3 +65,51 @@ def Counter(wraps): return wraps(*args, **kwargs) return call ``` + +### What are the atomics for? 
+ +A [previous implementation] used a normal `u64`, which meant it required a `&mut self` receiver to update the count: + +```rust,ignore +#[args(args = "*", kwargs = "**")] +fn __call__(&mut self, py: Python<'_>, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Py<PyAny>> { + self.count += 1; + let name = self.wraps.getattr(py, "__name__")?; + + println!("{} has been called {} time(s).", name, self.count); + + // After doing something, we finally forward the call to the wrapped function + let ret = self.wraps.call(py, args, kwargs)?; + + // We could do something with the return value of + // the function before returning it + Ok(ret) +} +``` + +The problem with this is that the `&mut self` receiver means PyO3 has to borrow it exclusively, + and hold this borrow across the `self.wraps.call(py, args, kwargs)` call. This call returns control to the user's Python code + which is free to call arbitrary things, *including* the decorated function. If that happens PyO3 is unable to create a second unique borrow and will be forced to raise an exception. + +As a result, something innocent like this will raise an exception: + +```py +@Counter +def say_hello(): + if say_hello.count < 2: + print(f"hello from decorator") + +say_hello() +# RuntimeError: Already borrowed +``` + +The implementation in this chapter fixes that by never borrowing exclusively; all the methods take `&self` as receivers, of which multiple may exist simultaneously. This requires a shared and threadsafe counter and the easiest way to do that is to use atomics, so that's what is used here. + +This shows the dangers of running arbitrary Python code - note that "running arbitrary Python code" can be far more subtle than the example above: +- Python's asynchronous executor may park the current thread in the middle of Python code, even in Python code that *you* control, and let other Python code run. +- Dropping arbitrary Python objects may invoke destructors defined in Python (`__del__` methods). 
+- Calling Python's C-api (most PyO3 apis call C-api functions internally) may raise exceptions, which may allow Python code in signal handlers to run. + +This is especially important if you are writing unsafe code; Python code must never be able to cause undefined behavior. You must ensure that your Rust code is in a consistent state before doing any of the above things. + +[previous implementation]: https://github.com/PyO3/pyo3/discussions/2598 "Thread Safe Decorator · Discussion #2598 · PyO3/pyo3" \ No newline at end of file