Possible to pass PyModule as first arg.
This commit makes it possible to access the module of a function by passing the `need_module` argument to the pyfn and pyfunction macros.
This commit is contained in:
parent
3214249010
commit
795c054511
|
@ -2,7 +2,6 @@
|
|||
//! Code generation for the function that initializes a python module and adds classes and function.
|
||||
|
||||
use crate::method;
|
||||
use crate::pyfunction;
|
||||
use crate::pyfunction::PyFunctionAttr;
|
||||
use crate::pymethod;
|
||||
use crate::pymethod::get_arg_names;
|
||||
|
@ -78,11 +77,11 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result<method::FnArg<'a>>
|
|||
/// Extracts the data from the #[pyfn(...)] attribute of a function
|
||||
fn extract_pyfn_attrs(
|
||||
attrs: &mut Vec<syn::Attribute>,
|
||||
) -> syn::Result<Option<(syn::Path, Ident, Vec<pyfunction::Argument>)>> {
|
||||
) -> syn::Result<Option<(syn::Path, Ident, PyFunctionAttr)>> {
|
||||
let mut new_attrs = Vec::new();
|
||||
let mut fnname = None;
|
||||
let mut modname = None;
|
||||
let mut fn_attrs = Vec::new();
|
||||
let mut fn_attrs = PyFunctionAttr::default();
|
||||
|
||||
for attr in attrs.iter() {
|
||||
match attr.parse_meta() {
|
||||
|
@ -115,9 +114,7 @@ fn extract_pyfn_attrs(
|
|||
}
|
||||
// Read additional arguments
|
||||
if list.nested.len() >= 3 {
|
||||
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])
|
||||
.unwrap()
|
||||
.arguments;
|
||||
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?;
|
||||
}
|
||||
} else {
|
||||
return Err(syn::Error::new_spanned(
|
||||
|
@ -148,11 +145,11 @@ fn function_wrapper_ident(name: &Ident) -> Ident {
|
|||
pub fn add_fn_to_module(
|
||||
func: &mut syn::ItemFn,
|
||||
python_name: Ident,
|
||||
pyfn_attrs: Vec<pyfunction::Argument>,
|
||||
pyfn_attrs: PyFunctionAttr,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let mut arguments = Vec::new();
|
||||
|
||||
for input in func.sig.inputs.iter() {
|
||||
for (i, input) in func.sig.inputs.iter().enumerate() {
|
||||
match input {
|
||||
syn::FnArg::Receiver(_) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
|
@ -161,7 +158,27 @@ pub fn add_fn_to_module(
|
|||
))
|
||||
}
|
||||
syn::FnArg::Typed(ref cap) => {
|
||||
arguments.push(wrap_fn_argument(cap)?);
|
||||
if pyfn_attrs.need_module && i == 0 {
|
||||
if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
|
||||
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
|
||||
if typath
|
||||
.path
|
||||
.segments
|
||||
.last()
|
||||
.map(|seg| seg.ident == "PyModule")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(syn::Error::new_spanned(
|
||||
cap,
|
||||
"Expected &PyModule as first argument with `need_module`.",
|
||||
));
|
||||
} else {
|
||||
arguments.push(wrap_fn_argument(cap)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -177,7 +194,7 @@ pub fn add_fn_to_module(
|
|||
tp: method::FnType::FnStatic,
|
||||
name: &function_wrapper_ident,
|
||||
python_name,
|
||||
attrs: pyfn_attrs,
|
||||
attrs: pyfn_attrs.arguments,
|
||||
args: arguments,
|
||||
output: ty,
|
||||
doc,
|
||||
|
@ -187,7 +204,7 @@ pub fn add_fn_to_module(
|
|||
|
||||
let python_name = &spec.python_name;
|
||||
|
||||
let wrapper = function_c_wrapper(&func.sig.ident, &spec);
|
||||
let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.need_module);
|
||||
|
||||
Ok(quote! {
|
||||
fn #function_wrapper_ident<'a>(
|
||||
|
@ -230,12 +247,23 @@ pub fn add_fn_to_module(
|
|||
}
|
||||
|
||||
/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
|
||||
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
|
||||
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, need_module: bool) -> TokenStream {
|
||||
let names: Vec<Ident> = get_arg_names(&spec);
|
||||
let cb = quote! {
|
||||
#name(#(#names),*)
|
||||
let cb;
|
||||
let slf_module;
|
||||
if need_module {
|
||||
cb = quote! {
|
||||
#name(_slf, #(#names),*)
|
||||
};
|
||||
slf_module = Some(quote! {
|
||||
let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
|
||||
});
|
||||
} else {
|
||||
cb = quote! {
|
||||
#name(#(#names),*)
|
||||
};
|
||||
slf_module = None;
|
||||
};
|
||||
|
||||
let body = pymethod::impl_arg_params(spec, None, cb);
|
||||
|
||||
quote! {
|
||||
|
@ -246,6 +274,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
|
|||
{
|
||||
const _LOCATION: &'static str = concat!(stringify!(#name), "()");
|
||||
pyo3::callback_body!(_py, {
|
||||
#slf_module
|
||||
let _args = _py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
|
||||
let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs);
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ pub struct PyFunctionAttr {
|
|||
has_kw: bool,
|
||||
has_varargs: bool,
|
||||
has_kwargs: bool,
|
||||
pub need_module: bool,
|
||||
}
|
||||
|
||||
impl syn::parse::Parse for PyFunctionAttr {
|
||||
|
@ -45,6 +46,9 @@ impl PyFunctionAttr {
|
|||
|
||||
pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> {
|
||||
match item {
|
||||
NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("need_module") => {
|
||||
self.need_module = true;
|
||||
}
|
||||
NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?,
|
||||
NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => {
|
||||
self.add_name_value(item, nv)?;
|
||||
|
@ -204,7 +208,7 @@ pub fn parse_name_attribute(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Opti
|
|||
pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Result<TokenStream> {
|
||||
let python_name =
|
||||
parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw());
|
||||
add_fn_to_module(ast, python_name, args.arguments)
|
||||
add_fn_to_module(ast, python_name, args)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -4,6 +4,7 @@ fn test_compile_errors() {
|
|||
let t = trybuild::TestCases::new();
|
||||
t.compile_fail("tests/ui/invalid_frompy_derive.rs");
|
||||
t.compile_fail("tests/ui/invalid_macro_args.rs");
|
||||
t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
|
||||
t.compile_fail("tests/ui/invalid_property_args.rs");
|
||||
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
|
||||
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use pyo3::prelude::*;
|
||||
|
||||
use pyo3::types::{IntoPyDict, PyTuple};
|
||||
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
|
||||
|
||||
mod common;
|
||||
|
||||
|
@ -49,6 +49,11 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
|
|||
Ok(42)
|
||||
}
|
||||
|
||||
#[pyfn(m, "with_module", need_module)]
|
||||
fn with_module(module: &PyModule) -> PyResult<&str> {
|
||||
module.name()
|
||||
}
|
||||
|
||||
#[pyfn(m, "double_value")]
|
||||
fn double_value(v: &ValueClass) -> usize {
|
||||
v.value * 2
|
||||
|
@ -97,6 +102,7 @@ fn test_module_with_functions() {
|
|||
run("assert module_with_functions.also_double(3) == 6");
|
||||
run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'");
|
||||
run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2");
|
||||
run("assert module_with_functions.with_module() == 'module_with_functions'");
|
||||
}
|
||||
|
||||
#[pymodule(other_name)]
|
||||
|
@ -230,7 +236,7 @@ fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
|
|||
use pyo3::{wrap_pyfunction, wrap_pymodule};
|
||||
|
||||
module.add_function(wrap_pyfunction!(superfunction))?;
|
||||
module.add_module(wrap_pymodule!(submodule))?;
|
||||
module.add_submodule(wrap_pymodule!(submodule))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -305,3 +311,82 @@ fn test_module_with_constant() {
|
|||
py_assert!(py, m, "isinstance(m.ANON, m.AnonClass)");
|
||||
});
|
||||
}
|
||||
|
||||
#[pyfunction(need_module)]
|
||||
fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
|
||||
module.name()
|
||||
}
|
||||
|
||||
#[pyfunction(need_module)]
|
||||
fn pyfunction_with_module_and_py<'a>(
|
||||
module: &'a PyModule,
|
||||
_python: Python<'a>,
|
||||
) -> PyResult<&'a str> {
|
||||
module.name()
|
||||
}
|
||||
|
||||
#[pyfunction(need_module)]
|
||||
fn pyfunction_with_module_and_arg(module: &PyModule, string: String) -> PyResult<(&str, String)> {
|
||||
module.name().map(|s| (s, string))
|
||||
}
|
||||
|
||||
#[pyfunction(need_module, string = "\"foo\"")]
|
||||
fn pyfunction_with_module_and_default_arg<'a>(
|
||||
module: &'a PyModule,
|
||||
string: &str,
|
||||
) -> PyResult<(&'a str, String)> {
|
||||
module.name().map(|s| (s, string.into()))
|
||||
}
|
||||
|
||||
#[pyfunction(need_module, args = "*", kwargs = "**")]
|
||||
fn pyfunction_with_module_and_args_kwargs<'a>(
|
||||
module: &'a PyModule,
|
||||
args: &PyTuple,
|
||||
kwargs: Option<&PyDict>,
|
||||
) -> PyResult<(&'a str, usize, Option<usize>)> {
|
||||
module
|
||||
.name()
|
||||
.map(|s| (s, args.len(), kwargs.map(|d| d.len())))
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn module_with_functions_with_module(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module))?;
|
||||
m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_py))?;
|
||||
m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_arg))?;
|
||||
m.add_function(pyo3::wrap_pyfunction!(
|
||||
pyfunction_with_module_and_default_arg
|
||||
))?;
|
||||
m.add_function(pyo3::wrap_pyfunction!(
|
||||
pyfunction_with_module_and_args_kwargs
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_module_functions_with_module() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let m = pyo3::wrap_pymodule!(module_with_functions_with_module)(py);
|
||||
py_assert!(
|
||||
py,
|
||||
m,
|
||||
"m.pyfunction_with_module() == 'module_with_functions_with_module'"
|
||||
);
|
||||
py_assert!(
|
||||
py,
|
||||
m,
|
||||
"m.pyfunction_with_module_and_py() == 'module_with_functions_with_module'"
|
||||
);
|
||||
py_assert!(
|
||||
py,
|
||||
m,
|
||||
"m.pyfunction_with_module_and_default_arg() \
|
||||
== ('module_with_functions_with_module', 'foo')"
|
||||
);
|
||||
py_assert!(
|
||||
py,
|
||||
m,
|
||||
"m.pyfunction_with_module_and_args_kwargs(1, x=1, y=2) \
|
||||
== ('module_with_functions_with_module', 1, 2)"
|
||||
);
|
||||
}
|
||||
|
|
12
tests/ui/invalid_need_module_arg_position.rs
Normal file
12
tests/ui/invalid_need_module_arg_position.rs
Normal file
|
@ -0,0 +1,12 @@
|
|||
use pyo3::prelude::*;
|
||||
|
||||
#[pymodule]
|
||||
fn module(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
#[pyfn(m, "with_module", need_module)]
|
||||
fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
|
||||
module.name()
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main(){}
|
5
tests/ui/invalid_need_module_arg_position.stderr
Normal file
5
tests/ui/invalid_need_module_arg_position.stderr
Normal file
|
@ -0,0 +1,5 @@
|
|||
error: Expected &PyModule as first argument with `need_module`.
|
||||
--> $DIR/invalid_need_module_arg_position.rs:6:13
|
||||
|
|
||||
6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
|
||||
| ^^^^^^^^^^^^
|
Loading…
Reference in a new issue