diff --git a/CHANGELOG.md b/CHANGELOG.md index f5d77293..ef99487e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add macro attribute to `#[pyfn]` and `#[pyfunction]` to pass the module of a Python function to the function body. [#1143](https://github.com/PyO3/pyo3/pull/1143) - Add `add_function()` and `add_submodule()` functions to `PyModule` [#1143](https://github.com/PyO3/pyo3/pull/1143) +- Add native `PyCFunction` and `PyFunction` types, change `add_function` to take a wrapper returning + a `&PyCFunction`instead of `PyObject`. [#1163](https://github.com/PyO3/pyo3/pull/1163) ### Changed - Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024) diff --git a/README.md b/README.md index ff041bf6..2e57a358 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { /// A Python module implemented in Rust. #[pymodule] fn string_sum(py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(sum_as_string))?; + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; Ok(()) } diff --git a/examples/rustapi_module/src/datetime.rs b/examples/rustapi_module/src/datetime.rs index 3ccb7c69..062dc958 100644 --- a/examples/rustapi_module/src/datetime.rs +++ b/examples/rustapi_module/src/datetime.rs @@ -215,29 +215,29 @@ impl TzClass { #[pymodule] fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(make_date))?; - m.add_function(wrap_pyfunction!(get_date_tuple))?; - m.add_function(wrap_pyfunction!(date_from_timestamp))?; - m.add_function(wrap_pyfunction!(make_time))?; - m.add_function(wrap_pyfunction!(get_time_tuple))?; - m.add_function(wrap_pyfunction!(make_delta))?; - m.add_function(wrap_pyfunction!(get_delta_tuple))?; - m.add_function(wrap_pyfunction!(make_datetime))?; - m.add_function(wrap_pyfunction!(get_datetime_tuple))?; - m.add_function(wrap_pyfunction!(datetime_from_timestamp))?; + m.add_function(wrap_pyfunction!(make_date, m)?)?; + m.add_function(wrap_pyfunction!(get_date_tuple, m)?)?; + m.add_function(wrap_pyfunction!(date_from_timestamp, m)?)?; + m.add_function(wrap_pyfunction!(make_time, m)?)?; + m.add_function(wrap_pyfunction!(get_time_tuple, m)?)?; + m.add_function(wrap_pyfunction!(make_delta, m)?)?; + m.add_function(wrap_pyfunction!(get_delta_tuple, m)?)?; + m.add_function(wrap_pyfunction!(make_datetime, m)?)?; + m.add_function(wrap_pyfunction!(get_datetime_tuple, m)?)?; + m.add_function(wrap_pyfunction!(datetime_from_timestamp, m)?)?; // Python 3.6+ functions #[cfg(Py_3_6)] { - m.add_function(wrap_pyfunction!(time_with_fold))?; + m.add_function(wrap_pyfunction!(time_with_fold, m)?)?; #[cfg(not(PyPy))] { - m.add_function(wrap_pyfunction!(get_time_tuple_fold))?; - m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?; + m.add_function(wrap_pyfunction!(get_time_tuple_fold, m)?)?; + m.add_function(wrap_pyfunction!(get_datetime_tuple_fold, m)?)?; } } - m.add_function(wrap_pyfunction!(issue_219))?; + m.add_function(wrap_pyfunction!(issue_219, m)?)?; m.add_class::()?; Ok(()) diff --git a/examples/rustapi_module/src/othermod.rs b/examples/rustapi_module/src/othermod.rs index b9955806..5c149d29 100644 --- a/examples/rustapi_module/src/othermod.rs +++ b/examples/rustapi_module/src/othermod.rs @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 { #[pymodule] fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(double))?; + m.add_function(wrap_pyfunction!(double, m)?)?; m.add_class::()?; diff --git a/examples/word-count/src/lib.rs b/examples/word-count/src/lib.rs index 50a00780..78213102 100644 --- a/examples/word-count/src/lib.rs +++ b/examples/word-count/src/lib.rs @@ -56,8 +56,8 @@ fn count_line(line: &str, needle: &str) -> usize { #[pymodule] fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(search))?; - m.add_function(wrap_pyfunction!(search_sequential))?; - m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?; + m.add_function(wrap_pyfunction!(search_sequential, m)?)?; + m.add_function(wrap_pyfunction!(search_sequential_allow_threads, m)?)?; Ok(()) } diff --git a/guide/src/function.md b/guide/src/function.md index b33221c9..0e0277f3 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -36,7 +36,7 @@ fn double(x: usize) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(double)).unwrap(); + m.add_function(wrap_pyfunction!(double, m)?).unwrap(); Ok(()) } @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(num_kwds)).unwrap(); + m.add_function(wrap_pyfunction!(num_kwds, m)?).unwrap(); Ok(()) } @@ -206,7 +206,7 @@ fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { #[pymodule] fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(pyfunction_with_module)) + m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?) } # fn main() {} diff --git a/guide/src/module.md b/guide/src/module.md index 6b1d4581..aabba50b 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -73,7 +73,7 @@ fn subfunction() -> String { } fn init_submodule(module: &PyModule) -> PyResult<()> { - module.add_function(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction, module)?)?; Ok(()) } diff --git a/guide/src/trait_bounds.md b/guide/src/trait_bounds.md index 65e173cd..f845b5c8 100644 --- a/guide/src/trait_bounds.md +++ b/guide/src/trait_bounds.md @@ -488,7 +488,7 @@ pub struct UserModel { #[pymodule] fn trait_exposure(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; - m.add_function(wrap_pyfunction!(solve_wrapper))?; + m.add_function(wrap_pyfunction!(solve_wrapper, m)?)?; Ok(()) } diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index a706100e..657d1134 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -44,7 +44,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { let item: syn::ItemFn = syn::parse_quote! { fn block_wrapper() { #function_to_python - #module_name.add_function(&#function_wrapper_ident)?; + #module_name.add_function(#function_wrapper_ident(#module_name)?)?; } }; stmts.extend(item.block.stmts.into_iter()); @@ -204,50 +204,26 @@ pub fn add_fn_to_module( let python_name = &spec.python_name; - let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.pass_module); - + let name = &func.sig.ident; + let wrapper_ident = format_ident!("__pyo3_raw_{}", name); + let wrapper = function_c_wrapper(name, &wrapper_ident, &spec, pyfn_attrs.pass_module); Ok(quote! { + #wrapper fn #function_wrapper_ident<'a>( - args: impl Into> - ) -> pyo3::PyResult { - let arg = args.into(); - let (py, maybe_module) = arg.into_py_and_maybe_module(); - #wrapper - - let _def = pyo3::class::PyMethodDef { - ml_name: stringify!(#python_name), - ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), - ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS, - ml_doc: #doc, - }; - - let (mod_ptr, name) = if let Some(m) = maybe_module { - let mod_ptr = ::as_ptr(m); - let name = m.name()?; - let name = <&str as pyo3::conversion::IntoPy>::into_py(name, py); - (mod_ptr, ::as_ptr(&name)) - } else { - (std::ptr::null_mut(), std::ptr::null_mut()) - }; - - let function = unsafe { - pyo3::PyObject::from_owned_ptr( - py, - pyo3::ffi::PyCFunction_NewEx( - Box::into_raw(Box::new(_def.as_method_def())), - mod_ptr, - name - ) - ) - }; - - Ok(function) + args: impl Into> + ) -> pyo3::PyResult<&'a pyo3::types::PyCFunction> { + pyo3::types::PyCFunction::new_with_keywords(#wrapper_ident, stringify!(#python_name), #doc, args.into()) } }) } /// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords) -fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, pass_module: bool) -> TokenStream { +fn function_c_wrapper( + name: &Ident, + wrapper_ident: &Ident, + spec: &method::FnSpec<'_>, + pass_module: bool, +) -> TokenStream { let names: Vec = get_arg_names(&spec); let cb; let slf_module; @@ -265,9 +241,8 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, pass_module: bool slf_module = None; }; let body = pymethod::impl_arg_params(spec, None, cb); - quote! { - unsafe extern "C" fn __wrap( + unsafe extern "C" fn #wrapper_ident( _slf: *mut pyo3::ffi::PyObject, _args: *mut pyo3::ffi::PyObject, _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject diff --git a/src/derive_utils.rs b/src/derive_utils.rs index cef90fd1..819599e9 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -209,17 +209,16 @@ where } /// Enum to abstract over the arguments of Python function wrappers. -#[doc(hidden)] -pub enum WrapPyFunctionArguments<'a> { +pub enum PyFunctionArguments<'a> { Python(Python<'a>), PyModule(&'a PyModule), } -impl<'a> WrapPyFunctionArguments<'a> { +impl<'a> PyFunctionArguments<'a> { pub fn into_py_and_maybe_module(self) -> (Python<'a>, Option<&'a PyModule>) { match self { - WrapPyFunctionArguments::Python(py) => (py, None), - WrapPyFunctionArguments::PyModule(module) => { + PyFunctionArguments::Python(py) => (py, None), + PyFunctionArguments::PyModule(module) => { let py = module.py(); (py, Some(module)) } @@ -227,14 +226,14 @@ impl<'a> WrapPyFunctionArguments<'a> { } } -impl<'a> From> for WrapPyFunctionArguments<'a> { - fn from(py: Python<'a>) -> WrapPyFunctionArguments<'a> { - WrapPyFunctionArguments::Python(py) +impl<'a> From> for PyFunctionArguments<'a> { + fn from(py: Python<'a>) -> PyFunctionArguments<'a> { + PyFunctionArguments::Python(py) } } -impl<'a> From<&'a PyModule> for WrapPyFunctionArguments<'a> { - fn from(module: &'a PyModule) -> WrapPyFunctionArguments<'a> { - WrapPyFunctionArguments::PyModule(module) +impl<'a> From<&'a PyModule> for PyFunctionArguments<'a> { + fn from(module: &'a PyModule) -> PyFunctionArguments<'a> { + PyFunctionArguments::PyModule(module) } } diff --git a/src/ffi/funcobject.rs b/src/ffi/funcobject.rs new file mode 100644 index 00000000..10b15345 --- /dev/null +++ b/src/ffi/funcobject.rs @@ -0,0 +1,34 @@ +use std::os::raw::c_int; + +use crate::ffi::object::{PyObject, PyTypeObject, Py_TYPE}; + +#[cfg_attr(windows, link(name = "pythonXY"))] +extern "C" { + #[cfg_attr(PyPy, link_name = "PyPyFunction_Type")] + pub static mut PyFunction_Type: PyTypeObject; +} + +#[inline] +pub unsafe fn PyFunction_Check(op: *mut PyObject) -> c_int { + (Py_TYPE(op) == &mut PyFunction_Type) as c_int +} + +extern "C" { + pub fn PyFunction_NewWithQualName( + code: *mut PyObject, + globals: *mut PyObject, + qualname: *mut PyObject, + ) -> *mut PyObject; + pub fn PyFunction_New(code: *mut PyObject, globals: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_Code(op: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_GetGlobals(op: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_GetModule(op: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_GetDefaults(op: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_SetDefaults(op: *mut PyObject, defaults: *mut PyObject) -> c_int; + pub fn PyFunction_GetKwDefaults(op: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_SetKwDefaults(op: *mut PyObject, defaults: *mut PyObject) -> c_int; + pub fn PyFunction_GetClosure(op: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_SetClosure(op: *mut PyObject, closure: *mut PyObject) -> c_int; + pub fn PyFunction_GetAnnotations(op: *mut PyObject) -> *mut PyObject; + pub fn PyFunction_SetAnnotations(op: *mut PyObject, annotations: *mut PyObject) -> c_int; +} diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index 3c5e9e94..4015b2cd 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -19,6 +19,7 @@ pub use self::eval::*; pub use self::fileobject::*; pub use self::floatobject::*; pub use self::frameobject::PyFrameObject; +pub use self::funcobject::*; pub use self::genobject::*; pub use self::import::*; pub use self::intrcheck::*; @@ -157,3 +158,5 @@ pub mod frameobject { pub(crate) mod datetime; pub(crate) mod marshal; + +pub(crate) mod funcobject; diff --git a/src/lib.rs b/src/lib.rs index 4c2313e3..2935c390 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ //! #[pymodule] //! /// A Python module implemented in Rust. //! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> { -//! m.add_function(wrap_pyfunction!(sum_as_string))?; +//! m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; //! //! Ok(()) //! } @@ -218,6 +218,20 @@ macro_rules! wrap_pyfunction { ($function_name: ident) => {{ &pyo3::paste::expr! { [<__pyo3_get_function_ $function_name>] } }}; + + ($function_name: ident, $arg: expr) => { + pyo3::wrap_pyfunction!($function_name)(pyo3::derive_utils::PyFunctionArguments::from($arg)) + }; +} + +/// Returns the function that is called in the C-FFI. +/// +/// Use this together with `#[pyfunction]` and [types::PyCFunction]. +#[macro_export] +macro_rules! raw_pycfunction { + ($function_name: ident) => {{ + pyo3::paste::expr! { [<__pyo3_raw_ $function_name>] } + }}; } /// Returns a function that takes a [Python] instance and returns a Python module. diff --git a/src/python.rs b/src/python.rs index 9b71ebf0..403737fe 100644 --- a/src/python.rs +++ b/src/python.rs @@ -134,7 +134,7 @@ impl<'p> Python<'p> { /// let gil = Python::acquire_gil(); /// let py = gil.python(); /// let m = PyModule::new(py, "pcount").unwrap(); - /// m.add_function(wrap_pyfunction!(parallel_count)).unwrap(); + /// m.add_function(wrap_pyfunction!(parallel_count, m).unwrap()).unwrap(); /// let locals = [("pcount", m)].into_py_dict(py); /// py.run(r#" /// s = ["Flow", "my", "tears", "the", "Policeman", "Said"] diff --git a/src/types/function.rs b/src/types/function.rs new file mode 100644 index 00000000..73aa04b7 --- /dev/null +++ b/src/types/function.rs @@ -0,0 +1,89 @@ +use std::ffi::{CStr, CString}; + +use crate::derive_utils::PyFunctionArguments; +use crate::exceptions::PyValueError; +use crate::prelude::*; +use crate::{class, ffi, AsPyPointer, PyMethodType}; + +/// Represents a builtin Python function object. +#[repr(transparent)] +pub struct PyCFunction(PyAny); + +pyobject_native_var_type!(PyCFunction, ffi::PyCFunction_Type, ffi::PyCFunction_Check); + +impl PyCFunction { + /// Create a new built-in function with keywords. + pub fn new_with_keywords<'a>( + fun: ffi::PyCFunctionWithKeywords, + name: &str, + doc: &'static str, + py_or_module: PyFunctionArguments<'a>, + ) -> PyResult<&'a PyCFunction> { + let fun = PyMethodType::PyCFunctionWithKeywords(fun); + Self::new_(fun, name, doc, py_or_module) + } + + /// Create a new built-in function without keywords. + pub fn new<'a>( + fun: ffi::PyCFunction, + name: &str, + doc: &'static str, + py_or_module: PyFunctionArguments<'a>, + ) -> PyResult<&'a PyCFunction> { + let fun = PyMethodType::PyCFunction(fun); + Self::new_(fun, name, doc, py_or_module) + } + + fn new_<'a>( + fun: class::PyMethodType, + name: &str, + doc: &'static str, + py_or_module: PyFunctionArguments<'a>, + ) -> PyResult<&'a PyCFunction> { + let (py, module) = py_or_module.into_py_and_maybe_module(); + let doc: &'static CStr = CStr::from_bytes_with_nul(doc.as_bytes()) + .map_err(|_| PyValueError::py_err("docstring must end with NULL byte."))?; + let name = CString::new(name.as_bytes()) + .map_err(|_| PyValueError::py_err("Function name cannot contain contain NULL byte."))?; + let def = match fun { + PyMethodType::PyCFunction(fun) => ffi::PyMethodDef { + ml_name: name.into_raw() as _, + ml_meth: Some(fun), + ml_flags: ffi::METH_VARARGS, + ml_doc: doc.as_ptr() as _, + }, + PyMethodType::PyCFunctionWithKeywords(fun) => ffi::PyMethodDef { + ml_name: name.into_raw() as _, + ml_meth: Some(unsafe { std::mem::transmute(fun) }), + ml_flags: ffi::METH_VARARGS | ffi::METH_KEYWORDS, + ml_doc: doc.as_ptr() as _, + }, + _ => { + return Err(PyValueError::py_err( + "Only PyCFunction and PyCFunctionWithKeywords are valid.", + )) + } + }; + let (mod_ptr, module_name) = if let Some(m) = module { + let mod_ptr = m.as_ptr(); + let name = m.name()?.into_py(py); + (mod_ptr, name.as_ptr()) + } else { + (std::ptr::null_mut(), std::ptr::null_mut()) + }; + + unsafe { + py.from_owned_ptr_or_err::(ffi::PyCFunction_NewEx( + Box::into_raw(Box::new(def)), + mod_ptr, + module_name, + )) + } + } +} + +/// Represents a Python function object. +#[repr(transparent)] +pub struct PyFunction(PyAny); + +pyobject_native_var_type!(PyFunction, ffi::PyFunction_Type, ffi::PyFunction_Check); diff --git a/src/types/mod.rs b/src/types/mod.rs index abdcdd00..b3915a8f 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -13,6 +13,7 @@ pub use self::datetime::{ }; pub use self::dict::{IntoPyDict, PyDict}; pub use self::floatob::PyFloat; +pub use self::function::{PyCFunction, PyFunction}; pub use self::iterator::PyIterator; pub use self::list::PyList; pub use self::module::PyModule; @@ -226,6 +227,7 @@ mod complex; mod datetime; mod dict; mod floatob; +mod function; mod iterator; mod list; mod module; diff --git a/src/types/module.rs b/src/types/module.rs index b77a6fa4..baa343f6 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -9,8 +9,8 @@ use crate::ffi; use crate::instance::PyNativeType; use crate::pyclass::PyClass; use crate::type_object::PyTypeObject; -use crate::types::PyTuple; use crate::types::{PyAny, PyDict, PyList}; +use crate::types::{PyCFunction, PyTuple}; use crate::{AsPyPointer, IntoPy, Py, PyObject, Python}; use std::ffi::{CStr, CString}; use std::os::raw::c_char; @@ -258,7 +258,7 @@ impl PyModule { /// } /// #[pymodule] /// fn double_mod(_py: Python, module: &PyModule) -> PyResult<()> { - /// module.add_function(pyo3::wrap_pyfunction!(double)) + /// module.add_function(pyo3::wrap_pyfunction!(double, module)?) /// } /// ``` /// @@ -272,17 +272,11 @@ impl PyModule { /// } /// #[pymodule] /// fn double_mod(_py: Python, module: &PyModule) -> PyResult<()> { - /// module.add("also_double", pyo3::wrap_pyfunction!(double)(module)?) + /// module.add("also_double", pyo3::wrap_pyfunction!(double, module)?) /// } /// ``` - pub fn add_function<'a>( - &'a self, - wrapper: &impl Fn(&'a Self) -> PyResult, - ) -> PyResult<()> { - let py = self.py(); - let function = wrapper(self)?; - let name = function.getattr(py, "__name__")?; - let name = name.extract(py)?; - self.add(name, function) + pub fn add_function<'a>(&'a self, fun: &'a PyCFunction) -> PyResult<()> { + let name = fun.getattr("__name__")?.extract()?; + self.add(name, fun) } } diff --git a/tests/test_module.rs b/tests/test_module.rs index 7c278bdc..5ed75bf4 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -65,8 +65,8 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { m.add("foo", "bar").unwrap(); - m.add_function(wrap_pyfunction!(double)).unwrap(); - m.add("also_double", wrap_pyfunction!(double)(m)?).unwrap(); + m.add_function(wrap_pyfunction!(double, m)?).unwrap(); + m.add("also_double", wrap_pyfunction!(double, m)?).unwrap(); Ok(()) } @@ -163,7 +163,7 @@ fn r#move() -> usize { fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_function(wrap_pyfunction!(r#move)) + module.add_function(wrap_pyfunction!(r#move, module)?) } #[test] @@ -188,7 +188,7 @@ fn custom_named_fn() -> usize { fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - m.add_function(wrap_pyfunction!(custom_named_fn))?; + m.add_function(wrap_pyfunction!(custom_named_fn, m)?)?; m.dict().set_item("yay", "me")?; Ok(()) } @@ -221,7 +221,7 @@ fn subfunction() -> String { fn submodule(module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_function(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction, module)?)?; Ok(()) } @@ -229,7 +229,7 @@ fn submodule(module: &PyModule) -> PyResult<()> { fn submodule_with_init_fn(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_function(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction, module)?)?; Ok(()) } @@ -242,7 +242,7 @@ fn superfunction() -> String { fn supermodule(py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_function(wrap_pyfunction!(superfunction))?; + module.add_function(wrap_pyfunction!(superfunction, module)?)?; let module_to_add = PyModule::new(py, "submodule")?; submodule(module_to_add)?; module.add_submodule(module_to_add)?; @@ -291,7 +291,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> { ext_vararg_fn(py, a, vararg) } - m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn)) + m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn, m)?) .unwrap(); Ok(()) } @@ -368,15 +368,17 @@ fn pyfunction_with_module_and_args_kwargs<'a>( #[pymodule] fn module_with_functions_with_module(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module))?; - m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_py))?; - m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_arg))?; + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?; m.add_function(pyo3::wrap_pyfunction!( - pyfunction_with_module_and_default_arg - ))?; + pyfunction_with_module_and_default_arg, + m + )?)?; m.add_function(pyo3::wrap_pyfunction!( - pyfunction_with_module_and_args_kwargs - )) + pyfunction_with_module_and_args_kwargs, + m + )?) } #[test] diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 0d8500a3..89b4dbd0 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -1,6 +1,7 @@ use pyo3::buffer::PyBuffer; use pyo3::prelude::*; -use pyo3::wrap_pyfunction; +use pyo3::types::{PyCFunction, PyFunction}; +use pyo3::{raw_pycfunction, wrap_pyfunction}; mod common; @@ -62,3 +63,59 @@ assert a, array.array("i", [2, 4, 6, 8]) "# ); } + +#[pyfunction] +fn function_with_pyfunction_arg(fun: &PyFunction) -> PyResult<&PyAny> { + fun.call((), None) +} + +#[pyfunction] +fn function_with_pycfunction_arg(fun: &PyCFunction) -> PyResult<&PyAny> { + fun.call((), None) +} + +#[test] +fn test_functions_with_function_args() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let py_func_arg = wrap_pyfunction!(function_with_pyfunction_arg)(py).unwrap(); + let py_cfunc_arg = wrap_pyfunction!(function_with_pycfunction_arg)(py).unwrap(); + let bool_to_string = wrap_pyfunction!(optional_bool)(py).unwrap(); + + pyo3::py_run!( + py, + py_func_arg + py_cfunc_arg + bool_to_string, + r#" + def foo(): return "bar" + assert py_func_arg(foo) == "bar" + assert py_cfunc_arg(bool_to_string) == "Some(true)" + "# + ) +} + +#[test] +fn test_raw_function() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let raw_func = raw_pycfunction!(optional_bool); + let fun = PyCFunction::new_with_keywords(raw_func, "fun", "\0", py.into()).unwrap(); + let res = fun.call((), None).unwrap().extract::<&str>().unwrap(); + assert_eq!(res, "Some(true)"); + let res = fun.call((false,), None).unwrap().extract::<&str>().unwrap(); + assert_eq!(res, "Some(false)"); + let no_module = fun.getattr("__module__").unwrap().is_none(); + assert!(no_module); + + let module = PyModule::new(py, "cool_module").unwrap(); + module.add_function(fun).unwrap(); + let res = module + .getattr("fun") + .unwrap() + .call((), None) + .unwrap() + .extract::<&str>() + .unwrap(); + assert_eq!(res, "Some(true)"); +}