Update decorator to use atomics
This commit is contained in:
parent
7b3ad2b718
commit
611ea4db49
|
@ -1,14 +1,16 @@
|
|||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyDict, PyTuple};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
/// A function decorator that keeps track how often it is called.
|
||||
///
|
||||
/// It otherwise doesn't do anything special.
|
||||
#[pyclass(name = "Counter")]
|
||||
pub struct PyCounter {
|
||||
// We use `#[pyo3(get)]` so that python can read the count but not mutate it.
|
||||
#[pyo3(get)]
|
||||
count: u64,
|
||||
// Keeps track of how many calls have gone through.
|
||||
//
|
||||
// See the discussion at the end for why an atomic is used.
|
||||
count: AtomicU64,
|
||||
|
||||
// This is the actual function being wrapped.
|
||||
wraps: Py<PyAny>,
|
||||
|
@ -23,20 +25,28 @@ impl PyCounter {
|
|||
// 2. We still need to handle any exceptions that the function might raise
|
||||
#[new]
|
||||
fn __new__(wraps: Py<PyAny>) -> Self {
|
||||
PyCounter { count: 0, wraps }
|
||||
PyCounter {
|
||||
count: AtomicU64::new(0),
|
||||
wraps,
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn count(&self) -> u64 {
|
||||
self.count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
#[args(args = "*", kwargs = "**")]
|
||||
fn __call__(
|
||||
&mut self,
|
||||
&self,
|
||||
py: Python<'_>,
|
||||
args: &PyTuple,
|
||||
kwargs: Option<&PyDict>,
|
||||
) -> PyResult<Py<PyAny>> {
|
||||
self.count += 1;
|
||||
let old_count = self.count.fetch_add(1, Ordering::Relaxed);
|
||||
let name = self.wraps.getattr(py, "__name__")?;
|
||||
|
||||
println!("{} has been called {} time(s).", name, self.count);
|
||||
println!("{} has been called {} time(s).", name, old_count + 1);
|
||||
|
||||
// After doing something, we finally forward the call to the wrapped function
|
||||
let ret = self.wraps.call(py, args, kwargs)?;
|
||||
|
|
|
@ -38,3 +38,14 @@ def test_default_arg():
|
|||
say_hello()
|
||||
|
||||
assert say_hello.count == 4
|
||||
|
||||
|
||||
# https://github.com/PyO3/pyo3/discussions/2598
|
||||
def test_discussion_2598():
|
||||
@Counter
|
||||
def say_hello():
|
||||
if say_hello.count < 2:
|
||||
print(f"hello from decorator")
|
||||
|
||||
say_hello()
|
||||
say_hello()
|
||||
|
|
|
@ -37,7 +37,7 @@ say_hello has been called 4 time(s).
|
|||
hello
|
||||
```
|
||||
|
||||
#### Pure Python implementation
|
||||
### Pure Python implementation
|
||||
|
||||
A Python implementation of this looks similar to the Rust version:
|
||||
|
||||
|
@ -65,3 +65,51 @@ def Counter(wraps):
|
|||
return wraps(*args, **kwargs)
|
||||
return call
|
||||
```
|
||||
|
||||
### What are the atomics for?
|
||||
|
||||
A [previous implementation] used a normal `u64`, which meant it required a `&mut self` receiver to update the count:
|
||||
|
||||
```rust,ignore
|
||||
#[args(args = "*", kwargs = "**")]
|
||||
fn __call__(&mut self, py: Python<'_>, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Py<PyAny>> {
|
||||
self.count += 1;
|
||||
let name = self.wraps.getattr(py, "__name__")?;
|
||||
|
||||
println!("{} has been called {} time(s).", name, self.count);
|
||||
|
||||
// After doing something, we finally forward the call to the wrapped function
|
||||
let ret = self.wraps.call(py, args, kwargs)?;
|
||||
|
||||
// We could do something with the return value of
|
||||
// the function before returning it
|
||||
Ok(ret)
|
||||
}
|
||||
```
|
||||
|
||||
The problem with this is that the `&mut self` receiver means PyO3 has to borrow it exclusively,
|
||||
and hold this borrow across the`self.wraps.call(py, args, kwargs)` call. This call returns control to the user's Python code
|
||||
which is free to call arbitrary things, *including* the decorated function. If that happens PyO3 is unable to create a second unique borrow and will be forced to raise an exception.
|
||||
|
||||
As a result, something innocent like this will raise an exception:
|
||||
|
||||
```py
|
||||
@Counter
|
||||
def say_hello():
|
||||
if say_hello.count < 2:
|
||||
print(f"hello from decorator")
|
||||
|
||||
say_hello()
|
||||
# RuntimeError: Already borrowed
|
||||
```
|
||||
|
||||
The implementation in this chapter fixes that by never borrowing exclusively; all the methods take `&self` as receivers, of which multiple may exist simultaneously. This requires a shared and threadsafe counter and the easiest way to do that is to use atomics, so that's what is used here.
|
||||
|
||||
This shows the dangers of running arbitrary Python code - note that "running arbitrary Python code" can be far more subtle than the example above:
|
||||
- Python's asynchronous executor may park the current thread in the middle of Python code, even in Python code that *you* control, and let other Python code run.
|
||||
- Dropping arbitrary Python objects may invoke destructors defined in Python (`__del__` methods).
|
||||
- Calling Python's C-api (most PyO3 apis call C-api functions internally) may raise exceptions, which may allow Python code in signal handlers to run.
|
||||
|
||||
This especially important if you are writing unsafe code; Python code must never be able to cause undefined behavior. You must ensure that your Rust code is in a consistent state before doing any of the above things.
|
||||
|
||||
[previous implementation]: https://github.com/PyO3/pyo3/discussions/2598 "Thread Safe Decorator <Help Wanted> · Discussion #2598 · PyO3/pyo3"
|
Loading…
Reference in a new issue