From 721e7465854d4dbc893ff04c0c6da67ad9f934c9 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Fri, 12 Jul 2019 23:41:13 +0900 Subject: [PATCH] Allow py: Python as an argument of getter --- ci/travis/cover.sh | 2 +- pyo3-derive-backend/src/method.rs | 5 --- pyo3-derive-backend/src/pyclass.rs | 9 +++-- pyo3-derive-backend/src/pyimpl.rs | 10 ++--- pyo3-derive-backend/src/pymethod.rs | 54 +++++++++++++++++-------- tests/test_compile_error.rs | 3 +- tests/test_getter_setter.rs | 12 ++++-- tests/ui/too_many_args_to_getter.rs | 14 +++++++ tests/ui/too_many_args_to_getter.stderr | 5 +++ 9 files changed, 79 insertions(+), 35 deletions(-) mode change 100644 => 100755 tests/test_compile_error.rs create mode 100644 tests/ui/too_many_args_to_getter.rs create mode 100644 tests/ui/too_many_args_to_getter.stderr diff --git a/ci/travis/cover.sh b/ci/travis/cover.sh index 1b685005..2137e37b 100755 --- a/ci/travis/cover.sh +++ b/ci/travis/cover.sh @@ -23,7 +23,7 @@ echo $FILES | xargs -n1 -P1 sh -c ' echo "Collecting coverage data of $(basename $@)" kcov \ --exclude-path=./tests \ - --exclude-region="#[cfg(test)]:#[cfg(testkcovstopmarker)]" \ + --exclude-region="#[cfg(test)]:#[cfg(not(testkcovstopmarker))]" \ --exclude-pattern=/.cargo,/usr/lib \ --verify $dir "$@" 2>&1 >/dev/null ' _ diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index 920689fb..9cc02414 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -200,11 +200,6 @@ impl<'a> FnSpec<'a> { } false } - - /// A FnSpec is valid as getter if it has no argument or has one argument of type `Python` - pub fn valid_as_getter(&self) -> bool { - false - } } pub fn is_ref(name: &syn::Ident, ty: &syn::Type) -> bool { diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index cbd4fbfc..0d9a334a 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -417,9 +417,12 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> let field_ty = &field.ty; match *desc { - FnType::Getter(ref getter) => { - impl_py_getter_def(&name, doc, getter, &impl_wrap_getter(&cls, &name)) - } + FnType::Getter(ref getter) => impl_py_getter_def( + &name, + doc, + getter, + &impl_wrap_getter(&cls, &name, false), + ), FnType::Setter(ref setter) => { let setter_name = syn::Ident::new(&format!("set_{}", name), Span::call_site()); diff --git a/pyo3-derive-backend/src/pyimpl.rs b/pyo3-derive-backend/src/pyimpl.rs index 5af616c9..64abab02 100644 --- a/pyo3-derive-backend/src/pyimpl.rs +++ b/pyo3-derive-backend/src/pyimpl.rs @@ -16,11 +16,11 @@ pub fn build_py_methods(ast: &mut syn::ItemImpl) -> syn::Result { "#[pymethods] can not be used with lifetime parameters or generics", )) } else { - Ok(impl_methods(&ast.self_ty, &mut ast.items)) + impl_methods(&ast.self_ty, &mut ast.items) } } -pub fn impl_methods(ty: &syn::Type, impls: &mut Vec) -> TokenStream { +pub fn impl_methods(ty: &syn::Type, impls: &mut Vec) -> syn::Result { // get method names in impl block let mut methods = Vec::new(); for iimpl in impls.iter_mut() { @@ -31,16 +31,16 @@ pub fn impl_methods(ty: &syn::Type, impls: &mut Vec) -> TokenStre &name, &mut meth.sig, &mut meth.attrs, - )); + )?); } } - quote! { + Ok(quote! { pyo3::inventory::submit! { #![crate = pyo3] { type TyInventory = <#ty as pyo3::class::methods::PyMethodsInventoryDispatch>::InventoryType; ::new(&[#(#methods),*]) } } - } + }) } diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 994ab003..16d5b6cd 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -9,13 +9,13 @@ pub fn gen_py_method( name: &syn::Ident, sig: &mut syn::MethodSig, meth_attrs: &mut Vec, -) -> TokenStream { - check_generic(name, sig); +) -> syn::Result { + check_generic(name, sig)?; let doc = utils::get_doc(&meth_attrs, true); - let spec = FnSpec::parse(name, sig, meth_attrs).unwrap(); + let spec = FnSpec::parse(name, sig, meth_attrs)?; - match spec.tp { + Ok(match spec.tp { FnType::Fn => impl_py_method_def(name, doc, &spec, &impl_wrap(cls, name, &spec, true)), FnType::PySelf(ref self_ty) => impl_py_method_def( name, @@ -31,28 +31,43 @@ pub fn gen_py_method( impl_py_method_def_static(name, doc, &impl_wrap_static(cls, name, &spec)) } FnType::Getter(ref getter) => { - impl_py_getter_def(name, doc, getter, &impl_wrap_getter(cls, name)) + let mut takes_py = false; + for arg in &spec.args { + if !utils::if_type_is_python(arg.ty) { + return Err(syn::Error::new_spanned( + arg.ty, + "Getter function cannot have arguments other than pyo3::Python!", + )); + } + takes_py = true; + } + impl_py_getter_def(name, doc, getter, &impl_wrap_getter(cls, name, takes_py)) } FnType::Setter(ref setter) => { impl_py_setter_def(name, doc, setter, &impl_wrap_setter(cls, name, &spec)) } - } + }) } -fn check_generic(name: &syn::Ident, sig: &syn::MethodSig) { +fn check_generic(name: &syn::Ident, sig: &syn::MethodSig) -> syn::Result<()> { + let err_msg = |typ| { + format!( + "A Python method can't have a generic {} parameter: {}", + name, typ + ) + }; for param in &sig.decl.generics.params { match param { syn::GenericParam::Lifetime(_) => {} - syn::GenericParam::Type(_) => panic!( - "A Python method can't have a generic type parameter: {}", - name - ), - syn::GenericParam::Const(_) => panic!( - "A Python method can't have a const generic parameter: {}", - name - ), + syn::GenericParam::Type(_) => { + return Err(syn::Error::new_spanned(param, err_msg("type"))) + } + syn::GenericParam::Const(_) => { + return Err(syn::Error::new_spanned(param, err_msg("const"))) + } } } + Ok(()) } /// Generate function wrapper (PyCFunction, PyCFunctionWithKeywords) @@ -302,7 +317,12 @@ pub fn impl_wrap_static(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec<'_>) - } /// Generate functiona wrapper (PyCFunction, PyCFunctionWithKeywords) -pub(crate) fn impl_wrap_getter(cls: &syn::Type, name: &syn::Ident) -> TokenStream { +pub(crate) fn impl_wrap_getter(cls: &syn::Type, name: &syn::Ident, takes_py: bool) -> TokenStream { + let fncall = if takes_py { + quote! { _slf.#name(_py) } + } else { + quote! { _slf.#name() } + }; quote! { unsafe extern "C" fn __wrap( _slf: *mut pyo3::ffi::PyObject, _: *mut ::std::os::raw::c_void) -> *mut pyo3::ffi::PyObject @@ -313,7 +333,7 @@ pub(crate) fn impl_wrap_getter(cls: &syn::Type, name: &syn::Ident) -> TokenStrea let _py = pyo3::Python::assume_gil_acquired(); let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf); - let result = pyo3::derive_utils::IntoPyResult::into_py_result(_slf.#name()); + let result = pyo3::derive_utils::IntoPyResult::into_py_result(#fncall); match result { Ok(val) => { diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs old mode 100644 new mode 100755 index 0437d5e3..14be13ca --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -1,6 +1,7 @@ #[test] -#[cfg(testkcovstopmarker)] +#[cfg(not(testkcovstopmarker))] fn test_compile_errors() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/reject_generics.rs"); + t.compile_fail("tests/ui/too_many_args_to_getter.rs"); } diff --git a/tests/test_getter_setter.rs b/tests/test_getter_setter.rs index a82d1ef4..cacf5652 100644 --- a/tests/test_getter_setter.rs +++ b/tests/test_getter_setter.rs @@ -1,6 +1,6 @@ use pyo3::prelude::*; use pyo3::py_run; -use pyo3::types::IntoPyDict; +use pyo3::types::{IntoPyDict, PyList}; use std::isize; mod common; @@ -32,10 +32,16 @@ impl ClassWithProperties { fn get_unwrapped(&self) -> i32 { self.num } + #[setter] fn set_unwrapped(&mut self, value: i32) { self.num = value; } + + #[getter] + fn get_data_list<'py>(&self, py: Python<'py>) -> &'py PyList { + PyList::new(py, &[self.num]) + } } #[test] @@ -48,12 +54,12 @@ fn class_with_properties() { py_run!(py, inst, "assert inst.get_num() == 10"); py_run!(py, inst, "assert inst.get_num() == inst.DATA"); py_run!(py, inst, "inst.DATA = 20"); - py_run!(py, inst, "assert inst.get_num() == 20"); - py_run!(py, inst, "assert inst.get_num() == inst.DATA"); + py_run!(py, inst, "assert inst.get_num() == 20 == inst.DATA"); py_run!(py, inst, "assert inst.get_num() == inst.unwrapped == 20"); py_run!(py, inst, "inst.unwrapped = 42"); py_run!(py, inst, "assert inst.get_num() == inst.unwrapped == 42"); + py_run!(py, inst, "assert inst.data_list == [42]"); let d = [("C", py.get_type::())].into_py_dict(py); py.run( diff --git a/tests/ui/too_many_args_to_getter.rs b/tests/ui/too_many_args_to_getter.rs new file mode 100644 index 00000000..f28b53f6 --- /dev/null +++ b/tests/ui/too_many_args_to_getter.rs @@ -0,0 +1,14 @@ +use pyo3::prelude::*; + +#[pyclass] +struct ClassWithGetter { + a: u32, +} + +#[pymethods] +impl ClassWithGetter { + #[getter] + fn get_num(&self, index: u32) {} +} + +fn main() {} diff --git a/tests/ui/too_many_args_to_getter.stderr b/tests/ui/too_many_args_to_getter.stderr new file mode 100644 index 00000000..b1d67c14 --- /dev/null +++ b/tests/ui/too_many_args_to_getter.stderr @@ -0,0 +1,5 @@ +error: Getter function cannot have arguments other than pyo3::Python! + --> $DIR/too_many_args_to_getter.rs:11:30 + | +11 | fn get_num(&self, index: u32) {} + | ^^^