Update decorator to use atomics

This commit is contained in:
mejrs 2022-10-04 17:59:46 +02:00
parent 7b3ad2b718
commit 611ea4db49
3 changed files with 77 additions and 8 deletions

View file

@ -1,14 +1,16 @@
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use std::sync::atomic::{AtomicU64, Ordering};
/// A function decorator that keeps track how often it is called.
///
/// It otherwise doesn't do anything special.
#[pyclass(name = "Counter")]
pub struct PyCounter {
// We use `#[pyo3(get)]` so that python can read the count but not mutate it.
#[pyo3(get)]
count: u64,
// Keeps track of how many calls have gone through.
//
// See the discussion at the end for why an atomic is used.
count: AtomicU64,
// This is the actual function being wrapped.
wraps: Py<PyAny>,
@ -23,20 +25,28 @@ impl PyCounter {
// 2. We still need to handle any exceptions that the function might raise
#[new]
fn __new__(wraps: Py<PyAny>) -> Self {
PyCounter { count: 0, wraps }
PyCounter {
count: AtomicU64::new(0),
wraps,
}
}
#[getter]
fn count(&self) -> u64 {
self.count.load(Ordering::Relaxed)
}
#[args(args = "*", kwargs = "**")]
fn __call__(
&mut self,
&self,
py: Python<'_>,
args: &PyTuple,
kwargs: Option<&PyDict>,
) -> PyResult<Py<PyAny>> {
self.count += 1;
let old_count = self.count.fetch_add(1, Ordering::Relaxed);
let name = self.wraps.getattr(py, "__name__")?;
println!("{} has been called {} time(s).", name, self.count);
println!("{} has been called {} time(s).", name, old_count + 1);
// After doing something, we finally forward the call to the wrapped function
let ret = self.wraps.call(py, args, kwargs)?;

View file

@ -38,3 +38,14 @@ def test_default_arg():
say_hello()
assert say_hello.count == 4
# https://github.com/PyO3/pyo3/discussions/2598
def test_discussion_2598():
@Counter
def say_hello():
if say_hello.count < 2:
print(f"hello from decorator")
say_hello()
say_hello()

View file

@ -37,7 +37,7 @@ say_hello has been called 4 time(s).
hello
```
#### Pure Python implementation
### Pure Python implementation
A Python implementation of this looks similar to the Rust version:
@ -65,3 +65,51 @@ def Counter(wraps):
return wraps(*args, **kwargs)
return call
```
### What are the atomics for?
A [previous implementation] used a normal `u64`, which meant it required a `&mut self` receiver to update the count:
```rust,ignore
#[args(args = "*", kwargs = "**")]
fn __call__(&mut self, py: Python<'_>, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Py<PyAny>> {
self.count += 1;
let name = self.wraps.getattr(py, "__name__")?;
println!("{} has been called {} time(s).", name, self.count);
// After doing something, we finally forward the call to the wrapped function
let ret = self.wraps.call(py, args, kwargs)?;
// We could do something with the return value of
// the function before returning it
Ok(ret)
}
```
The problem with this is that the `&mut self` receiver means PyO3 has to borrow it exclusively,
and hold this borrow across the`self.wraps.call(py, args, kwargs)` call. This call returns control to the user's Python code
which is free to call arbitrary things, *including* the decorated function. If that happens PyO3 is unable to create a second unique borrow and will be forced to raise an exception.
As a result, something innocent like this will raise an exception:
```py
@Counter
def say_hello():
if say_hello.count < 2:
print(f"hello from decorator")
say_hello()
# RuntimeError: Already borrowed
```
The implementation in this chapter fixes that by never borrowing exclusively; all the methods take `&self` as receivers, of which multiple may exist simultaneously. This requires a shared and threadsafe counter and the easiest way to do that is to use atomics, so that's what is used here.
This shows the dangers of running arbitrary Python code - note that "running arbitrary Python code" can be far more subtle than the example above:
- Python's asynchronous executor may park the current thread in the middle of Python code, even in Python code that *you* control, and let other Python code run.
- Dropping arbitrary Python objects may invoke destructors defined in Python (`__del__` methods).
- Calling Python's C-api (most PyO3 apis call C-api functions internally) may raise exceptions, which may allow Python code in signal handlers to run.
This especially important if you are writing unsafe code; Python code must never be able to cause undefined behavior. You must ensure that your Rust code is in a consistent state before doing any of the above things.
[previous implementation]: https://github.com/PyO3/pyo3/discussions/2598 "Thread Safe Decorator <Help Wanted> · Discussion #2598 · PyO3/pyo3"