Merge pull request #2657 from mejrs/decorator_fix

Update decorator to use Cell counter
Bruno Kolenbrander 2022-10-10 19:55:36 +02:00 committed by GitHub
commit c9b26f57cd
3 changed files with 80 additions and 8 deletions


@ -1,14 +1,16 @@
 use pyo3::prelude::*;
 use pyo3::types::{PyDict, PyTuple};
+use std::cell::Cell;
 /// A function decorator that keeps track how often it is called.
 ///
 /// It otherwise doesn't do anything special.
 #[pyclass(name = "Counter")]
 pub struct PyCounter {
-    // We use `#[pyo3(get)]` so that python can read the count but not mutate it.
-    #[pyo3(get)]
-    count: u64,
+    // Keeps track of how many calls have gone through.
+    //
+    // See the discussion at the end for why `Cell` is used.
+    count: Cell<u64>,
     // This is the actual function being wrapped.
     wraps: Py<PyAny>,
@ -23,20 +25,30 @@ impl PyCounter {
     // 2. We still need to handle any exceptions that the function might raise
     #[new]
     fn __new__(wraps: Py<PyAny>) -> Self {
-        PyCounter { count: 0, wraps }
+        PyCounter {
+            count: Cell::new(0),
+            wraps,
+        }
+    }
+    #[getter]
+    fn count(&self) -> u64 {
+        self.count.get()
     }
     #[args(args = "*", kwargs = "**")]
     fn __call__(
-        &mut self,
+        &self,
         py: Python<'_>,
         args: &PyTuple,
         kwargs: Option<&PyDict>,
     ) -> PyResult<Py<PyAny>> {
-        self.count += 1;
+        let old_count = self.count.get();
+        let new_count = old_count + 1;
+        self.count.set(new_count);
         let name = self.wraps.getattr(py, "__name__")?;
-        println!("{} has been called {} time(s).", name, self.count);
+        println!("{} has been called {} time(s).", name, new_count);
         // After doing something, we finally forward the call to the wrapped function
         let ret = self.wraps.call(py, args, kwargs)?;


@ -38,3 +38,14 @@ def test_default_arg():
     say_hello()
     assert say_hello.count == 4
+# https://github.com/PyO3/pyo3/discussions/2598
+def test_discussion_2598():
+    @Counter
+    def say_hello():
+        if say_hello.count < 2:
+            print(f"hello from decorator")
+
+    say_hello()
+    say_hello()


@ -37,7 +37,7 @@ say_hello has been called 4 time(s).
 hello
 ```
-#### Pure Python implementation
+### Pure Python implementation
 A Python implementation of this looks similar to the Rust version:
@ -65,3 +65,52 @@ def Counter(wraps):
         return wraps(*args, **kwargs)
     return call
 ```
### What is the `Cell` for?
A [previous implementation] used a normal `u64`, which meant it required a `&mut self` receiver to update the count:
```rust,ignore
#[args(args = "*", kwargs = "**")]
fn __call__(&mut self, py: Python<'_>, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Py<PyAny>> {
    self.count += 1;
    let name = self.wraps.getattr(py, "__name__")?;
    println!("{} has been called {} time(s).", name, self.count);
    // After doing something, we finally forward the call to the wrapped function
    let ret = self.wraps.call(py, args, kwargs)?;
    // We could do something with the return value of
    // the function before returning it
    Ok(ret)
}
```
The problem with this is that the `&mut self` receiver means PyO3 has to borrow the object exclusively,
and hold this borrow across the `self.wraps.call(py, args, kwargs)` call. This call returns control to the user's Python code,
which is free to call arbitrary things, *including* the decorated function. If that happens, PyO3 is unable to create a second unique borrow and will be forced to raise an exception.
As a result, something innocent like this will raise an exception:
```py
@Counter
def say_hello():
    if say_hello.count < 2:
        print(f"hello from decorator")

say_hello()
# RuntimeError: Already borrowed
```
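The same failure mode can be reproduced in plain Rust, without PyO3. The sketch below is an analogy, not PyO3's actual code: `RefCell` stands in for the runtime borrow tracking that PyO3 performs on the object, and `ExclusiveCounter`, `call_exclusive`, and `nested_call_result` are hypothetical names introduced only for illustration.

```rust
use std::cell::RefCell;

/// Hypothetical stand-in for the old counter: a plain `u64`,
/// so updating it requires exclusive access.
struct ExclusiveCounter {
    count: u64,
}

/// Mimics a `&mut self` method: takes the unique borrow up front
/// and keeps holding it while the wrapped code runs.
fn call_exclusive(
    slot: &RefCell<ExclusiveCounter>,
    wrapped: &dyn Fn() -> String,
) -> Result<String, String> {
    let mut guard = slot
        .try_borrow_mut()
        .map_err(|_| "Already borrowed".to_string())?;
    guard.count += 1;
    // `guard` is still alive here, so a nested call cannot borrow again.
    Ok(wrapped())
}

/// The "decorated function" calls the counter again while the outer call
/// is still in progress. Returns the outer result and the final count.
fn nested_call_result() -> (Result<String, String>, u64) {
    let slot = RefCell::new(ExclusiveCounter { count: 0 });
    let wrapped = || match call_exclusive(&slot, &|| "inner call ran".to_string()) {
        Ok(message) => message,
        Err(error) => format!("error: {}", error),
    };
    let result = call_exclusive(&slot, &wrapped);
    let count = slot.borrow().count;
    (result, count)
}
```

Calling `nested_call_result()` shows that the outer call succeeds, while the nested attempt is rejected because the unique borrow is still held, so the count only reaches 1.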
The implementation in this chapter fixes that by never borrowing exclusively; all the methods take `&self` receivers, and multiple shared borrows may exist simultaneously. This requires a counter that can be updated through a shared reference, and the easiest way to do that is to use [`Cell`], so that's what is used here.
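As a minimal plain-Rust sketch of why `Cell` makes re-entrant calls harmless (again without PyO3, and with the hypothetical names `SharedCounter` and `nested_calls` introduced only for illustration):

```rust
use std::cell::Cell;

/// Hypothetical stand-in for the new counter: the count lives in a `Cell`,
/// so it can be updated through a shared `&self` reference.
struct SharedCounter {
    count: Cell<u64>,
}

impl SharedCounter {
    /// Mimics the new `__call__`: bump the count, then run the wrapped code.
    /// Returns the count as seen by this particular call.
    fn call(&self, wrapped: &dyn Fn()) -> u64 {
        let new_count = self.count.get() + 1;
        self.count.set(new_count);
        // No borrow is held here, so `wrapped` may call `self.call` again.
        wrapped();
        new_count
    }
}

/// Runs one call whose wrapped code calls the counter a second time.
/// Returns (count seen by the outer call, count seen by the inner call, final count).
fn nested_calls() -> (u64, u64, u64) {
    let counter = SharedCounter { count: Cell::new(0) };
    let inner_seen = Cell::new(0);
    let outer_seen = counter.call(&|| {
        inner_seen.set(counter.call(&|| {}));
    });
    (outer_seen, inner_seen.get(), counter.count.get())
}
```

Here the nested call simply increments the counter a second time, which is also why the new `__call__` prints the local `new_count` rather than re-reading the counter after the wrapped function has run.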
This shows the dangers of running arbitrary Python code. Note that "running arbitrary Python code" can be far more subtle than the example above:
- Python's asynchronous executor may park the current thread in the middle of Python code, even in Python code that *you* control, and let other Python code run.
- Dropping arbitrary Python objects may invoke destructors defined in Python (`__del__` methods).
- Calling Python's C API (most PyO3 APIs call C API functions internally) may raise exceptions, which may allow Python code in signal handlers to run.
This is especially important if you are writing unsafe code; Python code must never be able to cause undefined behavior. You must ensure that your Rust code is in a consistent state before doing any of the above things.
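One way to read "consistent state" is to finish your own bookkeeping and release any borrows before handing control to code you do not control. The following is a plain-Rust sketch of that pattern, not something taken from PyO3: the `Queue` type and the `push`, `drain_with`, and `reentrant_drain` functions are hypothetical, and the callback stands in for arbitrary Python code.

```rust
use std::cell::RefCell;

/// Hypothetical container whose callback may call back into it.
struct Queue {
    items: RefCell<Vec<u32>>,
}

impl Queue {
    fn push(&self, value: u32) {
        self.items.borrow_mut().push(value);
    }

    /// Hands every pending item to `callback`.
    fn drain_with(&self, callback: &dyn Fn(u32)) {
        // Move the items out first. The mutable borrow ends with this
        // statement, leaving the queue empty but valid.
        let pending = std::mem::take(&mut *self.items.borrow_mut());
        for item in pending {
            // No borrow is held while foreign code runs, so it may re-enter.
            callback(item);
        }
    }
}

/// The callback pushes new items into the same queue while it is being drained.
fn reentrant_drain() -> Vec<u32> {
    let queue = Queue {
        items: RefCell::new(vec![1, 2]),
    };
    queue.drain_with(&|item| queue.push(item * 10));
    queue.items.into_inner()
}
```

If `drain_with` instead iterated while still holding `borrow_mut()`, the nested `push` would panic, which is the plain-Rust counterpart of the `Already borrowed` error above.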
[previous implementation]: https://github.com/PyO3/pyo3/discussions/2598 "Thread Safe Decorator <Help Wanted> · Discussion #2598 · PyO3/pyo3"
[`Cell`]: https://doc.rust-lang.org/std/cell/struct.Cell.html "Cell in std::cell - Rust"