Make python function wrapper creation fallible.

Wrapping a function can fail if we can't get the module name.

Based on suggestion by @kngwyu
This commit is contained in:
Sebastian Pütz 2020-09-03 15:48:32 +02:00
parent 1f017b66fb
commit 3214249010
13 changed files with 34 additions and 31 deletions

View File

@ -67,7 +67,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
/// A Python module implemented in Rust.
#[pymodule]
fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
m.add_function(wrap_pyfunction!(sum_as_string))?;
Ok(())
}

View File

@ -55,7 +55,7 @@ fn count_line(line: &str, needle: &str) -> usize {
#[pymodule]
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(search))?;
m.add_wrapped(wrap_pyfunction!(search))?;
m.add_function(wrap_pyfunction!(search_sequential))?;
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;

View File

@ -35,7 +35,7 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// A good place to install the Rust -> Python logger.
pyo3_log::init();
m.add_wrapped(wrap_pyfunction!(log_something))?;
m.add_function(wrap_pyfunction!(log_something))?;
Ok(())
}
```

View File

@ -488,7 +488,7 @@ pub struct UserModel {
#[pymodule]
fn trait_exposure(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<UserModel>()?;
m.add_wrapped(wrap_pyfunction!(solve_wrapper))?;
m.add_function(wrap_pyfunction!(solve_wrapper))?;
Ok(())
}

View File

@ -192,7 +192,7 @@ pub fn add_fn_to_module(
Ok(quote! {
fn #function_wrapper_ident<'a>(
args: impl Into<pyo3::derive_utils::WrapPyFunctionArguments<'a>>
) -> pyo3::PyObject {
) -> pyo3::PyResult<pyo3::PyObject> {
let arg = args.into();
let (py, maybe_module) = arg.into_py_and_maybe_module();
#wrapper
@ -206,12 +206,8 @@ pub fn add_fn_to_module(
let (mod_ptr, name) = if let Some(m) = maybe_module {
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
let name = match m.name() {
Ok(name) => <&str as pyo3::conversion::IntoPy<PyObject>>::into_py(name, py),
Err(err) => {
return <PyErr as pyo3::conversion::IntoPy<PyObject>>::into_py(err, py);
}
};
let name = m.name()?;
let name = <&str as pyo3::conversion::IntoPy<PyObject>>::into_py(name, py);
(mod_ptr, <PyObject as pyo3::AsPyPointer>::as_ptr(&name))
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
@ -228,7 +224,7 @@ pub fn add_fn_to_module(
)
};
function
Ok(function)
}
})
}

View File

@ -2,6 +2,7 @@
//
// based on Daniel Grunwald's https://github.com/dgrunwald/rust-cpython
use crate::callback::IntoPyCallbackOutput;
use crate::err::{PyErr, PyResult};
use crate::exceptions;
use crate::ffi;
@ -197,8 +198,11 @@ impl PyModule {
///
/// **This function will be deprecated in the next release. Please use the specific
/// [add_function] and [add_module] functions instead.**
pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
pub fn add_wrapped<'a, T>(&'a self, wrapper: &impl Fn(Python<'a>) -> T) -> PyResult<()>
where
T: IntoPyCallbackOutput<PyObject>,
{
let function = wrapper(self.py()).convert(self.py())?;
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
}
@ -211,9 +215,9 @@ impl PyModule {
/// m.add_module(wrap_pymodule!(utils));
/// ```
pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
let module = wrapper(self.py());
let name = module.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, module)
}
/// Adds a function to a module, using the functions __name__ as name.
@ -229,8 +233,11 @@ impl PyModule {
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py, m));
/// ```
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
let function = wrapper(self);
pub fn add_function<'a>(
&'a self,
wrapper: &impl Fn(&'a Self) -> PyResult<PyObject>,
) -> PyResult<()> {
let function = wrapper(self)?;
let name = function.getattr(self.py(), "__name__")?;
self.add(name.extract(self.py())?, function)
}

View File

@ -14,7 +14,7 @@ fn test_pybytes_bytes_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(bytes_pybytes_conversion)(py);
let f = wrap_pyfunction!(bytes_pybytes_conversion)(py).unwrap();
py_assert!(py, f, "f(b'Hello World') == b'Hello World'");
}
@ -28,7 +28,7 @@ fn test_pybytes_vec_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(bytes_vec_conversion)(py);
let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap();
py_assert!(py, f, "f(b'Hello World') == b'Hello World'");
}
@ -37,6 +37,6 @@ fn test_bytearray_vec_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(bytes_vec_conversion)(py);
let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap();
py_assert!(py, f, "f(bytearray(b'Hello World')) == b'Hello World'");
}

View File

@ -19,7 +19,7 @@ fn fail_to_open_file() -> PyResult<()> {
fn test_filenotfounderror() {
let gil = Python::acquire_gil();
let py = gil.python();
let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py);
let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py).unwrap();
py_run!(
py,
@ -64,7 +64,7 @@ fn call_fail_with_custom_error() -> PyResult<()> {
fn test_custom_error() {
let gil = Python::acquire_gil();
let py = gil.python();
let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py);
let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py).unwrap();
py_run!(
py,

View File

@ -61,7 +61,7 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
m.add("foo", "bar").unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(m)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(m)?).unwrap();
Ok(())
}

View File

@ -14,7 +14,7 @@ fn test_optional_bool() {
// Regression test for issue #932
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(optional_bool)(py);
let f = wrap_pyfunction!(optional_bool)(py).unwrap();
py_assert!(py, f, "f() == 'Some(true)'");
py_assert!(py, f, "f(True) == 'Some(true)'");
@ -36,7 +36,7 @@ fn buffer_inplace_add(py: Python, x: PyBuffer<i32>, y: PyBuffer<i32>) {
fn test_buffer_add() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(buffer_inplace_add)(py);
let f = wrap_pyfunction!(buffer_inplace_add)(py).unwrap();
py_expect_exception!(
py,

View File

@ -14,7 +14,7 @@ fn test_unicode_encode_error() {
let gil = Python::acquire_gil();
let py = gil.python();
let take_str = wrap_pyfunction!(take_str)(py);
let take_str = wrap_pyfunction!(take_str)(py).unwrap();
py_run!(
py,
take_str,

View File

@ -104,7 +104,7 @@ fn test_function() {
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(my_function)(py);
let f = wrap_pyfunction!(my_function)(py).unwrap();
py_assert!(py, f, "f.__text_signature__ == '(a, b=None, *, c=42)'");
}

View File

@ -59,7 +59,7 @@ fn return_custom_class() {
assert_eq!(get_zero().unwrap().value, 0);
// Using from python
let get_zero = wrap_pyfunction!(get_zero)(py);
let get_zero = wrap_pyfunction!(get_zero)(py).unwrap();
py_assert!(py, get_zero, "get_zero().value == 0");
}
@ -206,5 +206,5 @@ fn result_conversion_function() -> Result<(), MyError> {
fn test_result_conversion() {
let gil = Python::acquire_gil();
let py = gil.python();
wrap_pyfunction!(result_conversion_function)(py);
wrap_pyfunction!(result_conversion_function)(py).unwrap();
}