Merge pull request #2657 from mejrs/decorator_fix
Update decorator to use Cell counter
This commit is contained in:
commit
c9b26f57cd
|
@ -1,14 +1,16 @@
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::{PyDict, PyTuple};
|
use pyo3::types::{PyDict, PyTuple};
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
/// A function decorator that keeps track how often it is called.
|
/// A function decorator that keeps track how often it is called.
|
||||||
///
|
///
|
||||||
/// It otherwise doesn't do anything special.
|
/// It otherwise doesn't do anything special.
|
||||||
#[pyclass(name = "Counter")]
|
#[pyclass(name = "Counter")]
|
||||||
pub struct PyCounter {
|
pub struct PyCounter {
|
||||||
// We use `#[pyo3(get)]` so that python can read the count but not mutate it.
|
// Keeps track of how many calls have gone through.
|
||||||
#[pyo3(get)]
|
//
|
||||||
count: u64,
|
// See the discussion at the end for why `Cell` is used.
|
||||||
|
count: Cell<u64>,
|
||||||
|
|
||||||
// This is the actual function being wrapped.
|
// This is the actual function being wrapped.
|
||||||
wraps: Py<PyAny>,
|
wraps: Py<PyAny>,
|
||||||
|
@ -23,20 +25,30 @@ impl PyCounter {
|
||||||
// 2. We still need to handle any exceptions that the function might raise
|
// 2. We still need to handle any exceptions that the function might raise
|
||||||
#[new]
|
#[new]
|
||||||
fn __new__(wraps: Py<PyAny>) -> Self {
|
fn __new__(wraps: Py<PyAny>) -> Self {
|
||||||
PyCounter { count: 0, wraps }
|
PyCounter {
|
||||||
|
count: Cell::new(0),
|
||||||
|
wraps,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn count(&self) -> u64 {
|
||||||
|
self.count.get()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[args(args = "*", kwargs = "**")]
|
#[args(args = "*", kwargs = "**")]
|
||||||
fn __call__(
|
fn __call__(
|
||||||
&mut self,
|
&self,
|
||||||
py: Python<'_>,
|
py: Python<'_>,
|
||||||
args: &PyTuple,
|
args: &PyTuple,
|
||||||
kwargs: Option<&PyDict>,
|
kwargs: Option<&PyDict>,
|
||||||
) -> PyResult<Py<PyAny>> {
|
) -> PyResult<Py<PyAny>> {
|
||||||
self.count += 1;
|
let old_count = self.count.get();
|
||||||
|
let new_count = old_count + 1;
|
||||||
|
self.count.set(new_count);
|
||||||
let name = self.wraps.getattr(py, "__name__")?;
|
let name = self.wraps.getattr(py, "__name__")?;
|
||||||
|
|
||||||
println!("{} has been called {} time(s).", name, self.count);
|
println!("{} has been called {} time(s).", name, new_count);
|
||||||
|
|
||||||
// After doing something, we finally forward the call to the wrapped function
|
// After doing something, we finally forward the call to the wrapped function
|
||||||
let ret = self.wraps.call(py, args, kwargs)?;
|
let ret = self.wraps.call(py, args, kwargs)?;
|
||||||
|
|
|
@ -38,3 +38,14 @@ def test_default_arg():
|
||||||
say_hello()
|
say_hello()
|
||||||
|
|
||||||
assert say_hello.count == 4
|
assert say_hello.count == 4
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/PyO3/pyo3/discussions/2598
|
||||||
|
def test_discussion_2598():
|
||||||
|
@Counter
|
||||||
|
def say_hello():
|
||||||
|
if say_hello.count < 2:
|
||||||
|
print(f"hello from decorator")
|
||||||
|
|
||||||
|
say_hello()
|
||||||
|
say_hello()
|
||||||
|
|
|
@ -37,7 +37,7 @@ say_hello has been called 4 time(s).
|
||||||
hello
|
hello
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Pure Python implementation
|
### Pure Python implementation
|
||||||
|
|
||||||
A Python implementation of this looks similar to the Rust version:
|
A Python implementation of this looks similar to the Rust version:
|
||||||
|
|
||||||
|
@ -65,3 +65,52 @@ def Counter(wraps):
|
||||||
return wraps(*args, **kwargs)
|
return wraps(*args, **kwargs)
|
||||||
return call
|
return call
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### What is the `Cell` for?
|
||||||
|
|
||||||
|
A [previous implementation] used a normal `u64`, which meant it required a `&mut self` receiver to update the count:
|
||||||
|
|
||||||
|
```rust,ignore
|
||||||
|
#[args(args = "*", kwargs = "**")]
|
||||||
|
fn __call__(&mut self, py: Python<'_>, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Py<PyAny>> {
|
||||||
|
self.count += 1;
|
||||||
|
let name = self.wraps.getattr(py, "__name__")?;
|
||||||
|
|
||||||
|
println!("{} has been called {} time(s).", name, self.count);
|
||||||
|
|
||||||
|
// After doing something, we finally forward the call to the wrapped function
|
||||||
|
let ret = self.wraps.call(py, args, kwargs)?;
|
||||||
|
|
||||||
|
// We could do something with the return value of
|
||||||
|
// the function before returning it
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The problem with this is that the `&mut self` receiver means PyO3 has to borrow it exclusively,
|
||||||
|
and hold this borrow across the`self.wraps.call(py, args, kwargs)` call. This call returns control to the user's Python code
|
||||||
|
which is free to call arbitrary things, *including* the decorated function. If that happens PyO3 is unable to create a second unique borrow and will be forced to raise an exception.
|
||||||
|
|
||||||
|
As a result, something innocent like this will raise an exception:
|
||||||
|
|
||||||
|
```py
|
||||||
|
@Counter
|
||||||
|
def say_hello():
|
||||||
|
if say_hello.count < 2:
|
||||||
|
print(f"hello from decorator")
|
||||||
|
|
||||||
|
say_hello()
|
||||||
|
# RuntimeError: Already borrowed
|
||||||
|
```
|
||||||
|
|
||||||
|
The implementation in this chapter fixes that by never borrowing exclusively; all the methods take `&self` as receivers, of which multiple may exist simultaneously. This requires a shared counter and the easiest way to do that is to use [`Cell`], so that's what is used here.
|
||||||
|
|
||||||
|
This shows the dangers of running arbitrary Python code - note that "running arbitrary Python code" can be far more subtle than the example above:
|
||||||
|
- Python's asynchronous executor may park the current thread in the middle of Python code, even in Python code that *you* control, and let other Python code run.
|
||||||
|
- Dropping arbitrary Python objects may invoke destructors defined in Python (`__del__` methods).
|
||||||
|
- Calling Python's C-api (most PyO3 apis call C-api functions internally) may raise exceptions, which may allow Python code in signal handlers to run.
|
||||||
|
|
||||||
|
This is especially important if you are writing unsafe code; Python code must never be able to cause undefined behavior. You must ensure that your Rust code is in a consistent state before doing any of the above things.
|
||||||
|
|
||||||
|
[previous implementation]: https://github.com/PyO3/pyo3/discussions/2598 "Thread Safe Decorator <Help Wanted> · Discussion #2598 · PyO3/pyo3"
|
||||||
|
[`Cell`]: https://doc.rust-lang.org/std/cell/struct.Cell.html "Cell in std::cell - Rust"
|
Loading…
Reference in a new issue