Make process_functions_... only accept ItemFn

This commit is contained in:
Martin Larralde 2018-05-13 21:33:59 +02:00
parent b43b481980
commit bf550c3c1f

View file

@ -9,10 +9,9 @@ use utils;
/// Generates the function that is called by the python interpreter to initialize the native /// Generates the function that is called by the python interpreter to initialize the native
/// module /// module
pub fn py3_init(fnname: &syn::Ident, name: &String, doc: syn::Lit) -> Tokens { pub fn py3_init(fnname: &syn::Ident, name: &syn::Ident, doc: syn::Lit) -> Tokens {
let m_name: syn::Ident = syn::parse_str(name.trim().as_ref()).unwrap(); let cb_name: syn::Ident = syn::parse_str(&format!("PyInit_{}", name)).unwrap();
let cb_name: syn::Ident = syn::parse_str(&format!("PyInit_{}", name.trim())).unwrap();
quote! { quote! {
#[no_mangle] #[no_mangle]
@ -27,7 +26,7 @@ pub fn py3_init(fnname: &syn::Ident, name: &String, doc: syn::Lit) -> Tokens {
static mut MODULE_DEF: pyo3::ffi::PyModuleDef = pyo3::ffi::PyModuleDef_INIT; static mut MODULE_DEF: pyo3::ffi::PyModuleDef = pyo3::ffi::PyModuleDef_INIT;
// We can't convert &'static str to *const c_char within a static initializer, // We can't convert &'static str to *const c_char within a static initializer,
// so we'll do it here in the module initialization: // so we'll do it here in the module initialization:
MODULE_DEF.m_name = concat!(stringify!(#m_name), "\0").as_ptr() as *const _; MODULE_DEF.name = concat!(stringify!(#name), "\0").as_ptr() as *const _;
#[cfg(py_sys_config = "WITH_THREAD")] #[cfg(py_sys_config = "WITH_THREAD")]
pyo3::ffi::PyEval_InitThreads(); pyo3::ffi::PyEval_InitThreads();
@ -58,9 +57,9 @@ pub fn py3_init(fnname: &syn::Ident, name: &String, doc: syn::Lit) -> Tokens {
} }
} }
pub fn py2_init(fnname: &syn::Ident, name: &String, doc: syn::Lit) -> Tokens { pub fn py2_init(fnname: &syn::Ident, name: &syn::Ident, doc: syn::Lit) -> Tokens {
let m_name: syn::Ident = syn::parse_str(name.trim().as_ref()).unwrap();
let cb_name: syn::Ident = syn::parse_str(&format!("PyInit_{}", name.trim())).unwrap(); let cb_name: syn::Ident = syn::parse_str(&format!("PyInit_{}", name)).unwrap();
quote! { quote! {
#[no_mangle] #[no_mangle]
@ -72,7 +71,7 @@ pub fn py2_init(fnname: &syn::Ident, name: &String, doc: syn::Lit) -> Tokens {
pyo3::prepare_pyo3_library(); pyo3::prepare_pyo3_library();
pyo3::ffi::PyEval_InitThreads(); pyo3::ffi::PyEval_InitThreads();
let _name = concat!(stringify!(#m_name), "\0").as_ptr() as *const _; let _name = concat!(stringify!(#name), "\0").as_ptr() as *const _;
let _pool = pyo3::GILPool::new(); let _pool = pyo3::GILPool::new();
let _py = pyo3::Python::assume_gil_acquired(); let _py = pyo3::Python::assume_gil_acquired();
let _module = pyo3::ffi::Py_InitModule(_name, std::ptr::null_mut()); let _module = pyo3::ffi::Py_InitModule(_name, std::ptr::null_mut());
@ -97,8 +96,7 @@ pub fn py2_init(fnname: &syn::Ident, name: &String, doc: syn::Lit) -> Tokens {
} }
/// Finds and takes care of the #[pyfn(...)] in #[modinit(...)] /// Finds and takes care of the #[pyfn(...)] in #[modinit(...)]
pub fn process_functions_in_module(ast: &mut syn::Item) { pub fn process_functions_in_module(func: &mut syn::ItemFn) {
if let syn::Item::Fn(ref mut func) = ast {
let mut stmts: Vec<syn::Stmt> = Vec::new(); let mut stmts: Vec<syn::Stmt> = Vec::new();
for stmt in func.block.stmts.iter_mut() { for stmt in func.block.stmts.iter_mut() {
@ -121,9 +119,6 @@ pub fn process_functions_in_module(ast: &mut syn::Item) {
} }
func.block.stmts = stmts; func.block.stmts = stmts;
} else {
panic!("#[modinit] can only be used with fn block");
}
} }
/// Transforms a rust fn arg parsed with syn into a method::FnArg /// Transforms a rust fn arg parsed with syn into a method::FnArg