diff --git a/CHANGELOG.md b/CHANGELOG.md index c74c2a71..43aa8457 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add commonly-used sequence methods to `PyList` and `PyTuple`. [#1849](https://github.com/PyO3/pyo3/pull/1849) - Add `as_sequence` methods to `PyList` and `PyTuple`. [#1860](https://github.com/PyO3/pyo3/pull/1860) - Add `abi3-py310` feature. [#1889](https://github.com/PyO3/pyo3/pull/1889) +- Add `PyCFunction::new_closure` to create a Python function from a Rust closure. [#1901](https://github.com/PyO3/pyo3/pull/1901) ### Changed diff --git a/src/types/function.rs b/src/types/function.rs index fd315bc7..fb31b087 100644 --- a/src/types/function.rs +++ b/src/types/function.rs @@ -3,8 +3,9 @@ use crate::exceptions::PyValueError; use crate::prelude::*; use crate::{ class::methods::{self, PyMethodDef}, - ffi, AsPyPointer, + ffi, types, AsPyPointer, }; +use std::os::raw::c_void; /// Represents a builtin Python function object. #[repr(transparent)] @@ -12,6 +13,48 @@ pub struct PyCFunction(PyAny); pyobject_native_type_core!(PyCFunction, ffi::PyCFunction_Type, #checkfunction=ffi::PyCFunction_Check); +const CLOSURE_CAPSULE_NAME: &[u8] = b"pyo3-closure\0"; + +unsafe extern "C" fn run_closure( + capsule_ptr: *mut ffi::PyObject, + args: *mut ffi::PyObject, + kwargs: *mut ffi::PyObject, +) -> *mut ffi::PyObject +where + F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static, + R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>, +{ + crate::callback_body!(py, { + let boxed_fn: &F = + &*(ffi::PyCapsule_GetPointer(capsule_ptr, CLOSURE_CAPSULE_NAME.as_ptr() as *const _) + as *mut F); + let args = py.from_borrowed_ptr::(args); + let kwargs = py.from_borrowed_ptr_or_opt::(kwargs); + boxed_fn(args, kwargs) + }) +} + +unsafe extern "C" fn drop_closure(capsule_ptr: *mut ffi::PyObject) +where + F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static, + R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>, +{ + let result = std::panic::catch_unwind(|| { + let boxed_fn: Box = Box::from_raw(ffi::PyCapsule_GetPointer( + capsule_ptr, + CLOSURE_CAPSULE_NAME.as_ptr() as *const _, + ) as *mut F); + drop(boxed_fn) + }); + if let Err(err) = result { + // This second layer of catch_unwind is useful as eprintln! can also panic. + let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + eprintln!("--- PyO3 intercepted a panic when dropping a closure"); + eprintln!("{:?}", err); + })); + } +} + impl PyCFunction { /// Create a new built-in function with keywords. pub fn new_with_keywords<'a>( @@ -39,23 +82,57 @@ impl PyCFunction { ) } + /// Create a new function from a closure. + /// + /// # Examples + /// + /// ``` + /// # use pyo3::prelude::*; + /// # use pyo3::{py_run, types}; + /// + /// Python::with_gil(|py| { + /// let add_one = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> { + /// let i = args.extract::<(i64,)>()?.0; + /// Ok(i+1) + /// }; + /// let add_one = types::PyCFunction::new_closure(add_one, py).unwrap(); + /// py_run!(py, add_one, "assert add_one(42) == 43"); + /// }); + /// ``` + pub fn new_closure(f: F, py: Python) -> PyResult<&PyCFunction> + where + F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static, + R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>, + { + let function_ptr = Box::into_raw(Box::new(f)); + let capsule = unsafe { + PyObject::from_owned_ptr_or_err( + py, + ffi::PyCapsule_New( + function_ptr as *mut c_void, + CLOSURE_CAPSULE_NAME.as_ptr() as *const _, + Some(drop_closure::), + ), + )? + }; + let method_def = methods::PyMethodDef::cfunction_with_keywords( + "pyo3-closure", + methods::PyCFunctionWithKeywords(run_closure::), + "", + ); + Self::internal_new_from_pointers(method_def, py, capsule.as_ptr(), std::ptr::null_mut()) + } + #[doc(hidden)] - pub fn internal_new( + fn internal_new_from_pointers( method_def: PyMethodDef, - py_or_module: PyFunctionArguments, + py: Python, + mod_ptr: *mut ffi::PyObject, + module_name: *mut ffi::PyObject, ) -> PyResult<&Self> { - let (py, module) = py_or_module.into_py_and_maybe_module(); let def = method_def .as_method_def() .map_err(|err| PyValueError::new_err(err.0))?; - let (mod_ptr, module_name) = if let Some(m) = module { - let mod_ptr = m.as_ptr(); - let name = m.name()?.into_py(py); - (mod_ptr, name.as_ptr()) - } else { - (std::ptr::null_mut(), std::ptr::null_mut()) - }; - unsafe { py.from_owned_ptr_or_err::(ffi::PyCFunction_NewEx( Box::into_raw(Box::new(def)), @@ -64,6 +141,22 @@ impl PyCFunction { )) } } + + #[doc(hidden)] + pub fn internal_new( + method_def: PyMethodDef, + py_or_module: PyFunctionArguments, + ) -> PyResult<&Self> { + let (py, module) = py_or_module.into_py_and_maybe_module(); + let (mod_ptr, module_name) = if let Some(m) = module { + let mod_ptr = m.as_ptr(); + let name = m.name()?.into_py(py); + (mod_ptr, name.as_ptr()) + } else { + (std::ptr::null_mut(), std::ptr::null_mut()) + }; + Self::internal_new_from_pointers(method_def, py, mod_ptr, module_name) + } } /// Represents a Python function object. diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index d534b006..5782cdf5 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -23,6 +23,7 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/invalid_pymethods.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); t.compile_fail("tests/ui/invalid_argument_attributes.rs"); + t.compile_fail("tests/ui/invalid_closure.rs"); t.compile_fail("tests/ui/reject_generics.rs"); tests_rust_1_48(&t); diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index d3cecdb4..2fe675b4 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -1,7 +1,7 @@ #[cfg(not(Py_LIMITED_API))] use pyo3::buffer::PyBuffer; use pyo3::prelude::*; -use pyo3::types::PyCFunction; +use pyo3::types::{self, PyCFunction}; #[cfg(not(Py_LIMITED_API))] use pyo3::types::{PyDateTime, PyFunction}; @@ -213,3 +213,57 @@ fn test_conversion_error() { "argument 'option_arg': 'str' object cannot be interpreted as an integer" ); } + +#[test] +fn test_closure() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let f = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> { + let gil = Python::acquire_gil(); + let py = gil.python(); + let res: Vec<_> = args + .iter() + .map(|elem| { + if let Ok(i) = elem.extract::() { + (i + 1).into_py(py) + } else if let Ok(f) = elem.extract::() { + (2. * f).into_py(py) + } else if let Ok(mut s) = elem.extract::() { + s.push_str("-py"); + s.into_py(py) + } else { + panic!("unexpected argument type for {:?}", elem) + } + }) + .collect(); + Ok(res) + }; + let closure_py = PyCFunction::new_closure(f, py).unwrap(); + + py_assert!(py, closure_py, "closure_py(42) == [43]"); + py_assert!( + py, + closure_py, + "closure_py(42, 3.14, 'foo') == [43, 6.28, 'foo-py']" + ); +} + +#[test] +fn test_closure_counter() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let counter = std::cell::RefCell::new(0); + let counter_fn = + move |_args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult { + let mut counter = counter.borrow_mut(); + *counter += 1; + Ok(*counter) + }; + let counter_py = PyCFunction::new_closure(counter_fn, py).unwrap(); + + py_assert!(py, counter_py, "counter_py() == 1"); + py_assert!(py, counter_py, "counter_py() == 2"); + py_assert!(py, counter_py, "counter_py() == 3"); +} diff --git a/tests/ui/invalid_closure.rs b/tests/ui/invalid_closure.rs new file mode 100644 index 00000000..58f7148c --- /dev/null +++ b/tests/ui/invalid_closure.rs @@ -0,0 +1,19 @@ +use pyo3::prelude::*; +use pyo3::types::{PyCFunction, PyDict, PyTuple}; + +fn main() { + let fun: Py = Python::with_gil(|py| { + let local_data = vec![0, 1, 2, 3, 4]; + let ref_: &[u8] = &local_data; + + let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> { + println!("This is five: {:?}", ref_.len()); + Ok(()) + }; + PyCFunction::new_closure(closure_fn, py).unwrap().into() + }); + + Python::with_gil(|py| { + fun.call0(py).unwrap(); + }); +} diff --git a/tests/ui/invalid_closure.stderr b/tests/ui/invalid_closure.stderr new file mode 100644 index 00000000..87db62f6 --- /dev/null +++ b/tests/ui/invalid_closure.stderr @@ -0,0 +1,28 @@ +error[E0597]: `local_data` does not live long enough + --> tests/ui/invalid_closure.rs:7:27 + | +7 | let ref_: &[u8] = &local_data; + | ^^^^^^^^^^^ borrowed value does not live long enough +... +13 | PyCFunction::new_closure(closure_fn, py).unwrap().into() + | ---------------------------------------- argument requires that `local_data` is borrowed for `'static` +14 | }); + | - `local_data` dropped here while still borrowed + +error[E0373]: closure may outlive the current function, but it borrows `ref_`, which is owned by the current function + --> tests/ui/invalid_closure.rs:9:26 + | +9 | let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ may outlive borrowed value `ref_` +10 | println!("This is five: {:?}", ref_.len()); + | ---- `ref_` is borrowed here + | +note: function requires argument type to outlive `'static` + --> tests/ui/invalid_closure.rs:13:9 + | +13 | PyCFunction::new_closure(closure_fn, py).unwrap().into() + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +help: to force the closure to take ownership of `ref_` (and any other referenced variables), use the `move` keyword + | +9 | let closure_fn = move |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^