Add constructor for PyCFunction.

This commit is contained in:
Sebastian Pütz 2020-09-08 14:23:24 +02:00
parent 2e8010b5df
commit be877d133f
4 changed files with 129 additions and 36 deletions

View file

@ -204,49 +204,28 @@ pub fn add_fn_to_module(
let python_name = &spec.python_name;
let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.pass_module);
let name = &func.sig.ident;
let wrapper_ident = format_ident!("__pyo3_raw_{}", name);
let wrapper = function_c_wrapper(name, &wrapper_ident, &spec, pyfn_attrs.pass_module);
Ok(quote! {
#wrapper
fn #function_wrapper_ident<'a>(
args: impl Into<pyo3::derive_utils::WrapPyFunctionArguments<'a>>
) -> pyo3::PyResult<&'a pyo3::types::PyCFunction> {
let arg = args.into();
let (py, maybe_module) = arg.into_py_and_maybe_module();
#wrapper
let _def = pyo3::class::PyMethodDef {
ml_name: stringify!(#python_name),
ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap),
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS,
ml_doc: #doc,
};
let (mod_ptr, name) = if let Some(m) = maybe_module {
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
let name = m.name()?;
let name = <&str as pyo3::conversion::IntoPy<PyObject>>::into_py(name, py);
(mod_ptr, <PyObject as pyo3::AsPyPointer>::as_ptr(&name))
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};
let function = unsafe {
py.from_owned_ptr::<pyo3::types::PyCFunction>(
pyo3::ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(_def.as_method_def())),
mod_ptr,
name
)
)
};
Ok(function)
pyo3::types::PyCFunction::new_with_keywords(#wrapper_ident, stringify!(#python_name), #doc, maybe_module, py)
}
})
}
/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, pass_module: bool) -> TokenStream {
fn function_c_wrapper(
name: &Ident,
wrapper_ident: &Ident,
spec: &method::FnSpec<'_>,
pass_module: bool,
) -> TokenStream {
let names: Vec<Ident> = get_arg_names(&spec);
let cb;
let slf_module;
@ -264,9 +243,8 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, pass_module: bool
slf_module = None;
};
let body = pymethod::impl_arg_params(spec, None, cb);
quote! {
unsafe extern "C" fn __wrap(
unsafe extern "C" fn #wrapper_ident(
_slf: *mut pyo3::ffi::PyObject,
_args: *mut pyo3::ffi::PyObject,
_kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject

View file

@ -220,6 +220,16 @@ macro_rules! wrap_pyfunction {
}};
}
/// Returns the function that is called in the C-FFI.
///
/// Use this together with `#[pyfunction]` and [types::PyCFunction].
#[macro_export]
macro_rules! raw_pycfunction {
($function_name: ident) => {{
pyo3::paste::expr! { [<__pyo3_raw_ $function_name>] }
}};
}
/// Returns a function that takes a [Python] instance and returns a Python module.
///
/// Use this together with `#[pymodule]` and [types::PyModule::add_wrapped].

View file

@ -1,5 +1,6 @@
use crate::ffi;
use crate::exceptions::PyValueError;
use crate::prelude::*;
use crate::{class, ffi, AsPyPointer, PyMethodDef, PyMethodType};
/// Represents a builtin Python function object.
#[repr(transparent)]
@ -7,6 +8,85 @@ pub struct PyCFunction(PyAny);
pyobject_native_var_type!(PyCFunction, ffi::PyCFunction_Type, ffi::PyCFunction_Check);
impl PyCFunction {
/// Create a new built-in function with keywords.
pub fn new_with_keywords<'a>(
fun: ffi::PyCFunctionWithKeywords,
name: &str,
doc: &str,
module: Option<&'a PyModule>,
py: Python<'a>,
) -> PyResult<&'a PyCFunction> {
let fun = PyMethodType::PyCFunctionWithKeywords(fun);
Self::new_(fun, name, doc, module, py)
}
/// Create a new built-in function without keywords.
pub fn new<'a>(
fun: ffi::PyCFunction,
name: &str,
doc: &str,
module: Option<&'a PyModule>,
py: Python<'a>,
) -> PyResult<&'a PyCFunction> {
let fun = PyMethodType::PyCFunction(fun);
Self::new_(fun, name, doc, module, py)
}
fn new_<'a>(
fun: class::PyMethodType,
name: &str,
doc: &str,
module: Option<&'a PyModule>,
py: Python<'a>,
) -> PyResult<&'a PyCFunction> {
let name = name.to_string();
let name: &'static str = Box::leak(name.into_boxed_str());
// this is ugly but necessary since `PyMethodDef::ml_doc` is &str and not `CStr`
let doc = if doc.ends_with('\0') {
doc.to_string()
} else {
format!("{}\0", doc)
};
let doc: &'static str = Box::leak(doc.into_boxed_str());
let def = match &fun {
PyMethodType::PyCFunction(_) => PyMethodDef {
ml_name: name,
ml_meth: fun,
ml_flags: ffi::METH_VARARGS,
ml_doc: doc,
},
PyMethodType::PyCFunctionWithKeywords(_) => PyMethodDef {
ml_name: name,
ml_meth: fun,
ml_flags: ffi::METH_VARARGS | ffi::METH_KEYWORDS,
ml_doc: doc,
},
_ => {
return Err(PyValueError::py_err(
"Only PyCFunction and PyCFunctionWithKeywords are valid.",
))
}
};
let def = def.as_method_def();
let (mod_ptr, module_name) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
let name = m.name()?.into_py(py);
(mod_ptr, name.as_ptr())
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};
unsafe {
py.from_owned_ptr_or_err::<PyCFunction>(ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(def)),
mod_ptr,
module_name,
))
}
}
}
/// Represents a Python function object.
#[repr(transparent)]
pub struct PyFunction(PyAny);

View file

@ -1,7 +1,7 @@
use pyo3::buffer::PyBuffer;
use pyo3::prelude::*;
use pyo3::types::{PyCFunction, PyFunction};
use pyo3::wrap_pyfunction;
use pyo3::{raw_pycfunction, wrap_pyfunction};
mod common;
@ -94,3 +94,28 @@ fn test_functions_with_function_args() {
"#
)
}
#[test]
fn test_raw_function() {
let gil = Python::acquire_gil();
let py = gil.python();
let raw_func = raw_pycfunction!(optional_bool);
let fun = PyCFunction::new_with_keywords(raw_func, "fun", "", None, py).unwrap();
let res = fun.call((), None).unwrap().extract::<&str>().unwrap();
assert_eq!(res, "Some(true)");
let res = fun.call((false,), None).unwrap().extract::<&str>().unwrap();
assert_eq!(res, "Some(false)");
let no_module = fun.getattr("__module__").unwrap().is_none();
assert!(no_module);
let module = PyModule::new(py, "cool_module").unwrap();
module.add_function(&|_| Ok(fun)).unwrap();
let res = module
.getattr("fun")
.unwrap()
.call((), None)
.unwrap()
.extract::<&str>()
.unwrap();
assert_eq!(res, "Some(true)");
}