Merge pull request #1901 from LaurentMazare/closures

Support for wrapping rust closures as python functions
This commit is contained in:
David Hewitt 2021-10-17 09:46:26 +01:00 committed by GitHub
commit fbb5e3cd91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 13 deletions

View File

@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add commonly-used sequence methods to `PyList` and `PyTuple`. [#1849](https://github.com/PyO3/pyo3/pull/1849) - Add commonly-used sequence methods to `PyList` and `PyTuple`. [#1849](https://github.com/PyO3/pyo3/pull/1849)
- Add `as_sequence` methods to `PyList` and `PyTuple`. [#1860](https://github.com/PyO3/pyo3/pull/1860) - Add `as_sequence` methods to `PyList` and `PyTuple`. [#1860](https://github.com/PyO3/pyo3/pull/1860)
- Add `abi3-py310` feature. [#1889](https://github.com/PyO3/pyo3/pull/1889) - Add `abi3-py310` feature. [#1889](https://github.com/PyO3/pyo3/pull/1889)
- Add `PyCFunction::new_closure` to create a Python function from a Rust closure. [#1901](https://github.com/PyO3/pyo3/pull/1901)
### Changed ### Changed

View File

@ -3,8 +3,9 @@ use crate::exceptions::PyValueError;
use crate::prelude::*; use crate::prelude::*;
use crate::{ use crate::{
class::methods::{self, PyMethodDef}, class::methods::{self, PyMethodDef},
ffi, AsPyPointer, ffi, types, AsPyPointer,
}; };
use std::os::raw::c_void;
/// Represents a builtin Python function object. /// Represents a builtin Python function object.
#[repr(transparent)] #[repr(transparent)]
@ -12,6 +13,48 @@ pub struct PyCFunction(PyAny);
pyobject_native_type_core!(PyCFunction, ffi::PyCFunction_Type, #checkfunction=ffi::PyCFunction_Check); pyobject_native_type_core!(PyCFunction, ffi::PyCFunction_Type, #checkfunction=ffi::PyCFunction_Check);
const CLOSURE_CAPSULE_NAME: &[u8] = b"pyo3-closure\0";
unsafe extern "C" fn run_closure<F, R>(
capsule_ptr: *mut ffi::PyObject,
args: *mut ffi::PyObject,
kwargs: *mut ffi::PyObject,
) -> *mut ffi::PyObject
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
crate::callback_body!(py, {
let boxed_fn: &F =
&*(ffi::PyCapsule_GetPointer(capsule_ptr, CLOSURE_CAPSULE_NAME.as_ptr() as *const _)
as *mut F);
let args = py.from_borrowed_ptr::<types::PyTuple>(args);
let kwargs = py.from_borrowed_ptr_or_opt::<types::PyDict>(kwargs);
boxed_fn(args, kwargs)
})
}
unsafe extern "C" fn drop_closure<F, R>(capsule_ptr: *mut ffi::PyObject)
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
let result = std::panic::catch_unwind(|| {
let boxed_fn: Box<F> = Box::from_raw(ffi::PyCapsule_GetPointer(
capsule_ptr,
CLOSURE_CAPSULE_NAME.as_ptr() as *const _,
) as *mut F);
drop(boxed_fn)
});
if let Err(err) = result {
// This second layer of catch_unwind is useful as eprintln! can also panic.
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
eprintln!("--- PyO3 intercepted a panic when dropping a closure");
eprintln!("{:?}", err);
}));
}
}
impl PyCFunction { impl PyCFunction {
/// Create a new built-in function with keywords. /// Create a new built-in function with keywords.
pub fn new_with_keywords<'a>( pub fn new_with_keywords<'a>(
@ -39,23 +82,57 @@ impl PyCFunction {
) )
} }
/// Create a new function from a closure.
///
/// # Examples
///
/// ```
/// # use pyo3::prelude::*;
/// # use pyo3::{py_run, types};
///
/// Python::with_gil(|py| {
/// let add_one = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> {
/// let i = args.extract::<(i64,)>()?.0;
/// Ok(i+1)
/// };
/// let add_one = types::PyCFunction::new_closure(add_one, py).unwrap();
/// py_run!(py, add_one, "assert add_one(42) == 43");
/// });
/// ```
pub fn new_closure<F, R>(f: F, py: Python) -> PyResult<&PyCFunction>
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
let function_ptr = Box::into_raw(Box::new(f));
let capsule = unsafe {
PyObject::from_owned_ptr_or_err(
py,
ffi::PyCapsule_New(
function_ptr as *mut c_void,
CLOSURE_CAPSULE_NAME.as_ptr() as *const _,
Some(drop_closure::<F, R>),
),
)?
};
let method_def = methods::PyMethodDef::cfunction_with_keywords(
"pyo3-closure",
methods::PyCFunctionWithKeywords(run_closure::<F, R>),
"",
);
Self::internal_new_from_pointers(method_def, py, capsule.as_ptr(), std::ptr::null_mut())
}
#[doc(hidden)] #[doc(hidden)]
pub fn internal_new( fn internal_new_from_pointers(
method_def: PyMethodDef, method_def: PyMethodDef,
py_or_module: PyFunctionArguments, py: Python,
mod_ptr: *mut ffi::PyObject,
module_name: *mut ffi::PyObject,
) -> PyResult<&Self> { ) -> PyResult<&Self> {
let (py, module) = py_or_module.into_py_and_maybe_module();
let def = method_def let def = method_def
.as_method_def() .as_method_def()
.map_err(|err| PyValueError::new_err(err.0))?; .map_err(|err| PyValueError::new_err(err.0))?;
let (mod_ptr, module_name) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
let name = m.name()?.into_py(py);
(mod_ptr, name.as_ptr())
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};
unsafe { unsafe {
py.from_owned_ptr_or_err::<PyCFunction>(ffi::PyCFunction_NewEx( py.from_owned_ptr_or_err::<PyCFunction>(ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(def)), Box::into_raw(Box::new(def)),
@ -64,6 +141,22 @@ impl PyCFunction {
)) ))
} }
} }
#[doc(hidden)]
pub fn internal_new(
method_def: PyMethodDef,
py_or_module: PyFunctionArguments,
) -> PyResult<&Self> {
let (py, module) = py_or_module.into_py_and_maybe_module();
let (mod_ptr, module_name) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
let name = m.name()?.into_py(py);
(mod_ptr, name.as_ptr())
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};
Self::internal_new_from_pointers(method_def, py, mod_ptr, module_name)
}
} }
/// Represents a Python function object. /// Represents a Python function object.

View File

@ -23,6 +23,7 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/invalid_pymethods.rs"); t.compile_fail("tests/ui/invalid_pymethods.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/invalid_argument_attributes.rs"); t.compile_fail("tests/ui/invalid_argument_attributes.rs");
t.compile_fail("tests/ui/invalid_closure.rs");
t.compile_fail("tests/ui/reject_generics.rs"); t.compile_fail("tests/ui/reject_generics.rs");
tests_rust_1_48(&t); tests_rust_1_48(&t);

View File

@ -1,7 +1,7 @@
#[cfg(not(Py_LIMITED_API))] #[cfg(not(Py_LIMITED_API))]
use pyo3::buffer::PyBuffer; use pyo3::buffer::PyBuffer;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::PyCFunction; use pyo3::types::{self, PyCFunction};
#[cfg(not(Py_LIMITED_API))] #[cfg(not(Py_LIMITED_API))]
use pyo3::types::{PyDateTime, PyFunction}; use pyo3::types::{PyDateTime, PyFunction};
@ -213,3 +213,57 @@ fn test_conversion_error() {
"argument 'option_arg': 'str' object cannot be interpreted as an integer" "argument 'option_arg': 'str' object cannot be interpreted as an integer"
); );
} }
#[test]
fn test_closure() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> {
let gil = Python::acquire_gil();
let py = gil.python();
let res: Vec<_> = args
.iter()
.map(|elem| {
if let Ok(i) = elem.extract::<i64>() {
(i + 1).into_py(py)
} else if let Ok(f) = elem.extract::<f64>() {
(2. * f).into_py(py)
} else if let Ok(mut s) = elem.extract::<String>() {
s.push_str("-py");
s.into_py(py)
} else {
panic!("unexpected argument type for {:?}", elem)
}
})
.collect();
Ok(res)
};
let closure_py = PyCFunction::new_closure(f, py).unwrap();
py_assert!(py, closure_py, "closure_py(42) == [43]");
py_assert!(
py,
closure_py,
"closure_py(42, 3.14, 'foo') == [43, 6.28, 'foo-py']"
);
}
#[test]
fn test_closure_counter() {
let gil = Python::acquire_gil();
let py = gil.python();
let counter = std::cell::RefCell::new(0);
let counter_fn =
move |_args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<i32> {
let mut counter = counter.borrow_mut();
*counter += 1;
Ok(*counter)
};
let counter_py = PyCFunction::new_closure(counter_fn, py).unwrap();
py_assert!(py, counter_py, "counter_py() == 1");
py_assert!(py, counter_py, "counter_py() == 2");
py_assert!(py, counter_py, "counter_py() == 3");
}

View File

@ -0,0 +1,19 @@
use pyo3::prelude::*;
use pyo3::types::{PyCFunction, PyDict, PyTuple};
fn main() {
let fun: Py<PyCFunction> = Python::with_gil(|py| {
let local_data = vec![0, 1, 2, 3, 4];
let ref_: &[u8] = &local_data;
let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
println!("This is five: {:?}", ref_.len());
Ok(())
};
PyCFunction::new_closure(closure_fn, py).unwrap().into()
});
Python::with_gil(|py| {
fun.call0(py).unwrap();
});
}

View File

@ -0,0 +1,28 @@
error[E0597]: `local_data` does not live long enough
--> tests/ui/invalid_closure.rs:7:27
|
7 | let ref_: &[u8] = &local_data;
| ^^^^^^^^^^^ borrowed value does not live long enough
...
13 | PyCFunction::new_closure(closure_fn, py).unwrap().into()
| ---------------------------------------- argument requires that `local_data` is borrowed for `'static`
14 | });
| - `local_data` dropped here while still borrowed
error[E0373]: closure may outlive the current function, but it borrows `ref_`, which is owned by the current function
--> tests/ui/invalid_closure.rs:9:26
|
9 | let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ may outlive borrowed value `ref_`
10 | println!("This is five: {:?}", ref_.len());
| ---- `ref_` is borrowed here
|
note: function requires argument type to outlive `'static`
--> tests/ui/invalid_closure.rs:13:9
|
13 | PyCFunction::new_closure(closure_fn, py).unwrap().into()
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
help: to force the closure to take ownership of `ref_` (and any other referenced variables), use the `move` keyword
|
9 | let closure_fn = move |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^