Possible to pass PyModule as first arg.

This commit makes it possible to access the module of a function
by passing the `need_module` argument to the pyfn and pyfunction
macros.
This commit is contained in:
Sebastian Pütz 2020-09-03 17:27:24 +02:00
parent 3214249010
commit 795c054511
6 changed files with 154 additions and 18 deletions

View file

@ -2,7 +2,6 @@
//! Code generation for the function that initializes a python module and adds classes and function.
use crate::method;
use crate::pyfunction;
use crate::pyfunction::PyFunctionAttr;
use crate::pymethod;
use crate::pymethod::get_arg_names;
@ -78,11 +77,11 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result<method::FnArg<'a>>
/// Extracts the data from the #[pyfn(...)] attribute of a function
fn extract_pyfn_attrs(
attrs: &mut Vec<syn::Attribute>,
) -> syn::Result<Option<(syn::Path, Ident, Vec<pyfunction::Argument>)>> {
) -> syn::Result<Option<(syn::Path, Ident, PyFunctionAttr)>> {
let mut new_attrs = Vec::new();
let mut fnname = None;
let mut modname = None;
let mut fn_attrs = Vec::new();
let mut fn_attrs = PyFunctionAttr::default();
for attr in attrs.iter() {
match attr.parse_meta() {
@ -115,9 +114,7 @@ fn extract_pyfn_attrs(
}
// Read additional arguments
if list.nested.len() >= 3 {
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])
.unwrap()
.arguments;
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?;
}
} else {
return Err(syn::Error::new_spanned(
@ -148,11 +145,11 @@ fn function_wrapper_ident(name: &Ident) -> Ident {
pub fn add_fn_to_module(
func: &mut syn::ItemFn,
python_name: Ident,
pyfn_attrs: Vec<pyfunction::Argument>,
pyfn_attrs: PyFunctionAttr,
) -> syn::Result<TokenStream> {
let mut arguments = Vec::new();
for input in func.sig.inputs.iter() {
for (i, input) in func.sig.inputs.iter().enumerate() {
match input {
syn::FnArg::Receiver(_) => {
return Err(syn::Error::new_spanned(
@ -161,7 +158,27 @@ pub fn add_fn_to_module(
))
}
syn::FnArg::Typed(ref cap) => {
arguments.push(wrap_fn_argument(cap)?);
if pyfn_attrs.need_module && i == 0 {
if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
if typath
.path
.segments
.last()
.map(|seg| seg.ident == "PyModule")
.unwrap_or(false)
{
continue;
}
}
}
return Err(syn::Error::new_spanned(
cap,
"Expected &PyModule as first argument with `need_module`.",
));
} else {
arguments.push(wrap_fn_argument(cap)?);
}
}
}
}
@ -177,7 +194,7 @@ pub fn add_fn_to_module(
tp: method::FnType::FnStatic,
name: &function_wrapper_ident,
python_name,
attrs: pyfn_attrs,
attrs: pyfn_attrs.arguments,
args: arguments,
output: ty,
doc,
@ -187,7 +204,7 @@ pub fn add_fn_to_module(
let python_name = &spec.python_name;
let wrapper = function_c_wrapper(&func.sig.ident, &spec);
let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.need_module);
Ok(quote! {
fn #function_wrapper_ident<'a>(
@ -230,12 +247,23 @@ pub fn add_fn_to_module(
}
/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, need_module: bool) -> TokenStream {
let names: Vec<Ident> = get_arg_names(&spec);
let cb = quote! {
#name(#(#names),*)
let cb;
let slf_module;
if need_module {
cb = quote! {
#name(_slf, #(#names),*)
};
slf_module = Some(quote! {
let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
});
} else {
cb = quote! {
#name(#(#names),*)
};
slf_module = None;
};
let body = pymethod::impl_arg_params(spec, None, cb);
quote! {
@ -246,6 +274,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
{
const _LOCATION: &'static str = concat!(stringify!(#name), "()");
pyo3::callback_body!(_py, {
#slf_module
let _args = _py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs);

View file

@ -24,6 +24,7 @@ pub struct PyFunctionAttr {
has_kw: bool,
has_varargs: bool,
has_kwargs: bool,
pub need_module: bool,
}
impl syn::parse::Parse for PyFunctionAttr {
@ -45,6 +46,9 @@ impl PyFunctionAttr {
pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> {
match item {
NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("need_module") => {
self.need_module = true;
}
NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?,
NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => {
self.add_name_value(item, nv)?;
@ -204,7 +208,7 @@ pub fn parse_name_attribute(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Opti
pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Result<TokenStream> {
let python_name =
parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw());
add_fn_to_module(ast, python_name, args.arguments)
add_fn_to_module(ast, python_name, args)
}
#[cfg(test)]

View file

@ -4,6 +4,7 @@ fn test_compile_errors() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/ui/invalid_frompy_derive.rs");
t.compile_fail("tests/ui/invalid_macro_args.rs");
t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
t.compile_fail("tests/ui/invalid_property_args.rs");
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");

View file

@ -1,6 +1,6 @@
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyTuple};
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
mod common;
@ -49,6 +49,11 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
Ok(42)
}
#[pyfn(m, "with_module", need_module)]
fn with_module(module: &PyModule) -> PyResult<&str> {
module.name()
}
#[pyfn(m, "double_value")]
fn double_value(v: &ValueClass) -> usize {
v.value * 2
@ -97,6 +102,7 @@ fn test_module_with_functions() {
run("assert module_with_functions.also_double(3) == 6");
run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'");
run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2");
run("assert module_with_functions.with_module() == 'module_with_functions'");
}
#[pymodule(other_name)]
@ -230,7 +236,7 @@ fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::{wrap_pyfunction, wrap_pymodule};
module.add_function(wrap_pyfunction!(superfunction))?;
module.add_module(wrap_pymodule!(submodule))?;
module.add_submodule(wrap_pymodule!(submodule))?;
Ok(())
}
@ -305,3 +311,82 @@ fn test_module_with_constant() {
py_assert!(py, m, "isinstance(m.ANON, m.AnonClass)");
});
}
#[pyfunction(need_module)]
fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
module.name()
}
#[pyfunction(need_module)]
fn pyfunction_with_module_and_py<'a>(
module: &'a PyModule,
_python: Python<'a>,
) -> PyResult<&'a str> {
module.name()
}
#[pyfunction(need_module)]
fn pyfunction_with_module_and_arg(module: &PyModule, string: String) -> PyResult<(&str, String)> {
module.name().map(|s| (s, string))
}
#[pyfunction(need_module, string = "\"foo\"")]
fn pyfunction_with_module_and_default_arg<'a>(
module: &'a PyModule,
string: &str,
) -> PyResult<(&'a str, String)> {
module.name().map(|s| (s, string.into()))
}
#[pyfunction(need_module, args = "*", kwargs = "**")]
fn pyfunction_with_module_and_args_kwargs<'a>(
module: &'a PyModule,
args: &PyTuple,
kwargs: Option<&PyDict>,
) -> PyResult<(&'a str, usize, Option<usize>)> {
module
.name()
.map(|s| (s, args.len(), kwargs.map(|d| d.len())))
}
#[pymodule]
fn module_with_functions_with_module(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module))?;
m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_py))?;
m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_arg))?;
m.add_function(pyo3::wrap_pyfunction!(
pyfunction_with_module_and_default_arg
))?;
m.add_function(pyo3::wrap_pyfunction!(
pyfunction_with_module_and_args_kwargs
))
}
#[test]
fn test_module_functions_with_module() {
let gil = Python::acquire_gil();
let py = gil.python();
let m = pyo3::wrap_pymodule!(module_with_functions_with_module)(py);
py_assert!(
py,
m,
"m.pyfunction_with_module() == 'module_with_functions_with_module'"
);
py_assert!(
py,
m,
"m.pyfunction_with_module_and_py() == 'module_with_functions_with_module'"
);
py_assert!(
py,
m,
"m.pyfunction_with_module_and_default_arg() \
== ('module_with_functions_with_module', 'foo')"
);
py_assert!(
py,
m,
"m.pyfunction_with_module_and_args_kwargs(1, x=1, y=2) \
== ('module_with_functions_with_module', 1, 2)"
);
}

View file

@ -0,0 +1,12 @@
use pyo3::prelude::*;
#[pymodule]
fn module(_py: Python, m: &PyModule) -> PyResult<()> {
#[pyfn(m, "with_module", need_module)]
fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
module.name()
}
Ok(())
}
fn main(){}

View file

@ -0,0 +1,5 @@
error: Expected &PyModule as first argument with `need_module`.
--> $DIR/invalid_need_module_arg_position.rs:6:13
|
6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
| ^^^^^^^^^^^^