From 98d810e662856a3da90e907a31f073f755b9e11b Mon Sep 17 00:00:00 2001 From: Yuji Kanagawa Date: Tue, 18 Feb 2020 12:31:45 +0900 Subject: [PATCH] Apply suggestions from davidhewitt's review Co-Authored-By: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- guide/src/class.md | 6 ++--- pyo3-derive-backend/src/module.rs | 33 +++++++++++---------------- pyo3-derive-backend/src/pyfunction.rs | 2 +- pyo3cls/src/lib.rs | 4 +++- src/derive_utils.rs | 2 +- src/objectprotocol.rs | 8 +++---- 6 files changed, 25 insertions(+), 30 deletions(-) diff --git a/guide/src/class.md b/guide/src/class.md index 3d0d40af..02a6abc3 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -116,7 +116,7 @@ from Rust code (e.g., for testing it). `PyCell` is always allocated in the Python heap, so we don't have the ownership of it. We can get `&PyCell`, not `PyCell`. -Thus, to mutate data behind `&PyCell` safely, we employs +Thus, to mutate data behind `&PyCell` safely, we employ the [Interior Mutability Pattern](https://doc.rust-lang.org/book/ch15-05-interior-mutability.html) like [std::cell::RefCell](https://doc.rust-lang.org/std/cell/struct.RefCell.html). @@ -145,13 +145,13 @@ let obj = PyCell::new(py, MyClass { num: 3, debug: true }).unwrap(); { let obj_ref = obj.borrow(); // Get PyRef assert_eq!(obj_ref.num, 3); - // You cannot get PyRefMut unless all PyRef drop + // You cannot get PyRefMut unless all PyRefs are dropped assert!(obj.try_borrow_mut().is_err()); } { let mut obj_mut = obj.borrow_mut(); // Get PyRefMut obj_mut.num = 5; - // You cannot get PyRef unless all PyRefMut drop + // You cannot get any other refs until the PyRefMut is dropped assert!(obj.try_borrow().is_err()); } // You can convert `&PyCell` to Python object diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 05046fdd..b282c3aa 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -28,7 +28,7 @@ pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::LitStr) -> TokenStream { } /// Finds and takes care of the #[pyfn(...)] in `#[pymodule]` -pub fn process_functions_in_module(func: &mut syn::ItemFn) { +pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { let mut stmts: Vec = Vec::new(); for stmt in func.block.stmts.iter_mut() { @@ -36,7 +36,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) { if let Some((module_name, python_name, pyfn_attrs)) = extract_pyfn_attrs(&mut func.attrs) { - let function_to_python = add_fn_to_module(func, python_name, pyfn_attrs); + let function_to_python = add_fn_to_module(func, python_name, pyfn_attrs)?; let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); let item: syn::ItemFn = syn::parse_quote! { fn block_wrapper() { @@ -51,18 +51,19 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) { } func.block.stmts = stmts; + Ok(()) } /// Transforms a rust fn arg parsed with syn into a method::FnArg -fn wrap_fn_argument<'a>(cap: &'a syn::PatType, name: &'a Ident) -> method::FnArg<'a> { +fn wrap_fn_argument<'a>(cap: &'a syn::PatType, name: &'a Ident) -> syn::Result> { let (mutability, by_ref, ident) = match *cap.pat { syn::Pat::Ident(ref patid) => (&patid.mutability, &patid.by_ref, &patid.ident), - _ => panic!("unsupported argument: {:?}", cap.pat), + _ => return Err(syn::Error::new_spanned(&cap.pat, "Unsupported argument")), }; let py = crate::utils::if_type_is_python(&cap.ty); let opt = method::check_arg_ty_and_optional(&name, &cap.ty); - method::FnArg { + Ok(method::FnArg { name: ident, mutability, by_ref, @@ -70,7 +71,7 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType, name: &'a Ident) -> method::FnArg optional: opt, py, reference: method::is_ref(&name, &cap.ty), - } + }) } /// Extracts the data from the #[pyfn(...)] attribute of a function @@ -131,7 +132,7 @@ pub fn add_fn_to_module( func: &mut syn::ItemFn, python_name: Ident, pyfn_attrs: Vec, -) -> TokenStream { +) -> syn::Result { let mut arguments = Vec::new(); let mut self_ = None; @@ -141,21 +142,15 @@ pub fn add_fn_to_module( self_ = Some(recv.mutability.is_some()); } syn::FnArg::Typed(ref cap) => { - arguments.push(wrap_fn_argument(cap, &func.sig.ident)); + arguments.push(wrap_fn_argument(cap, &func.sig.ident)?); } } } let ty = method::get_return_info(&func.sig.output); - let text_signature = match utils::parse_text_signature_attrs(&mut func.attrs, &python_name) { - Ok(text_signature) => text_signature, - Err(err) => return err.to_compile_error(), - }; - let doc = match utils::get_doc(&func.attrs, text_signature, true) { - Ok(doc) => doc, - Err(err) => return err.to_compile_error(), - }; + let text_signature = utils::parse_text_signature_attrs(&mut func.attrs, &python_name)?; + let doc = utils::get_doc(&func.attrs, text_signature, true)?; let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); @@ -176,7 +171,7 @@ pub fn add_fn_to_module( let wrapper = function_c_wrapper(&func.sig.ident, &spec); - let tokens = quote! { + Ok(quote! { fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject { #wrapper @@ -199,9 +194,7 @@ pub fn add_fn_to_module( function } - }; - - tokens + }) } /// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords) diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index 588deb6a..c7bcacd3 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -234,7 +234,7 @@ pub fn parse_name_attribute(attrs: &mut Vec) -> syn::Result syn::Result { let python_name = parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw()); - Ok(add_fn_to_module(ast, python_name, args.arguments)) + add_fn_to_module(ast, python_name, args.arguments) } #[cfg(test)] diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index 0ba93c99..795423cb 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -23,7 +23,9 @@ pub fn pymodule(attr: TokenStream, input: TokenStream) -> TokenStream { parse_macro_input!(attr as syn::Ident) }; - process_functions_in_module(&mut ast); + if let Err(err) = process_functions_in_module(&mut ast) { + return err.to_compile_error().into(); + } let doc = match get_doc(&ast.attrs, None, false) { Ok(doc) => doc, diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 1901f76a..da8cf01c 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -219,7 +219,7 @@ impl GetPropertyValue for PyObject { } } -/// Utitlities for basetype +/// Utilities for basetype pub trait PyBaseTypeUtils { type Dict; type WeakRef; diff --git a/src/objectprotocol.rs b/src/objectprotocol.rs index 55122ad8..d777ee63 100644 --- a/src/objectprotocol.rs +++ b/src/objectprotocol.rs @@ -171,10 +171,10 @@ pub trait ObjectProtocol { fn get_type_ptr(&self) -> *mut ffi::PyTypeObject; /// Gets the Python base object for this object. - fn get_base<'py>(&'py self) -> &'py ::BaseType + fn get_base(&self) -> &::BaseType where Self: PyTypeInfo, - ::BaseType: FromPyPointer<'py>; + ::BaseType: for<'py> FromPyPointer<'py>; /// Casts the PyObject to a concrete Python object type. fn cast_as<'a, D>(&'a self) -> Result<&'a D, PyDowncastError> @@ -445,10 +445,10 @@ where unsafe { (*self.as_ptr()).ob_type } } - fn get_base<'py>(&'py self) -> &'py ::BaseType + fn get_base(&self) -> &::BaseType where Self: PyTypeInfo, - ::BaseType: FromPyPointer<'py>, + ::BaseType: for<'py> FromPyPointer<'py>, { unsafe { self.py().from_borrowed_ptr(self.as_ptr()) } }