From 45bb09b3e8b64b7a1d733f36b9a10a6fa84b255f Mon Sep 17 00:00:00 2001 From: konstin Date: Fri, 6 Apr 2018 17:19:32 +0200 Subject: [PATCH] Relax return type requirements Allows returning essentially arbitrary types by wrapping them into a PyResult. This is done with a conversion trait that specializes for PyResult. --- pyo3-derive-backend/src/args.rs | 2 +- pyo3-derive-backend/src/method.rs | 21 ++++---- pyo3-derive-backend/src/module.rs | 27 +++++----- pyo3-derive-backend/src/py_class.rs | 2 +- pyo3-derive-backend/src/py_method.rs | 77 +++++++++++++++------------- src/conversion.rs | 28 ++++++++++ src/err.rs | 2 +- src/lib.rs | 3 +- tests/test_class.rs | 8 ++- 9 files changed, 105 insertions(+), 65 deletions(-) diff --git a/pyo3-derive-backend/src/args.rs b/pyo3-derive-backend/src/args.rs index 6eb09b9f..dd191dc1 100644 --- a/pyo3-derive-backend/src/args.rs +++ b/pyo3-derive-backend/src/args.rs @@ -1,7 +1,7 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use syn; -#[derive(Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum Argument { VarArgsSeparator, VarArgs(String), diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index 70648593..7343fefb 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -1,13 +1,12 @@ // Copyright (c) 2017-present PyO3 Project and Contributors -use syn; -use quote::{Tokens, Ident}; - use args::{Argument, parse_arguments}; +use quote::{Ident, Tokens}; +use syn; use utils::for_err_msg; -#[derive(Clone, Debug)] +#[derive(Clone, PartialEq, Debug)] pub struct FnArg<'a> { pub name: &'a syn::Ident, pub mode: &'a syn::BindingMode, @@ -29,6 +28,7 @@ pub enum FnType { FnStatic, } +#[derive(Clone, PartialEq, Debug)] pub struct FnSpec<'a> { pub tp: FnType, pub attrs: Vec, @@ -36,8 +36,14 @@ pub struct FnSpec<'a> { pub output: syn::Ty, } -impl<'a> FnSpec<'a> { +pub fn get_return_info(output: &syn::FunctionRetTy) -> syn::Ty { + match output { + syn::FunctionRetTy::Default => syn::Ty::Tup(vec![]), + syn::FunctionRetTy::Ty(ref ty) => ty.clone() + } +} +impl<'a> FnSpec<'a> { /// Parser function signature and function attributes pub fn parse(name: &'a syn::Ident, sig: &'a syn::MethodSig, @@ -96,10 +102,7 @@ impl<'a> FnSpec<'a> { } } - let ty = match sig.decl.output { - syn::FunctionRetTy::Default => syn::Ty::Infer, - syn::FunctionRetTy::Ty(ref ty) => ty.clone() - }; + let ty = get_return_info(&sig.decl.output); FnSpec { tp: fn_type, diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index c3da66fe..770a837a 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -227,22 +227,21 @@ fn wrap_fn(item: &mut syn::Item) -> Option> { }; let opt = method::check_arg_ty_and_optional(&name, ty); - arguments.push(method::FnArg {name: ident, - mode: mode, - ty: ty, - optional: opt, - py: py, - reference: method::is_ref(&name, ty)}); + arguments.push(method::FnArg { + name: ident, + mode: mode, + ty: ty, + optional: opt, + py: py, + reference: method::is_ref(&name, ty), + }); } &syn::FnArg::Ignored(_) => panic!("ignored argument: {:?}", name), } } - let ty = match decl.output { - syn::FunctionRetTy::Default => syn::Ty::Infer, - syn::FunctionRetTy::Ty(ref ty) => ty.clone() - }; + let ty = method::get_return_info(&decl.output); let spec = method::FnSpec { tp: method::FnType::Fn, @@ -309,11 +308,11 @@ pub fn impl_wrap(name: &syn::Ident, spec: &method::FnSpec) -> Tokens { |item| if item.1.py {syn::Ident::from("_py")} else { syn::Ident::from(format!("arg{}", item.0))}).collect(); let cb = quote! {{ - #name(#(#names),*) + #name(#(#names),*).return_type_into_py_result() }}; let body = py_method::impl_arg_params(spec, cb); - let output = &spec.output; + let body_to_result = py_method::body_to_result(&body, spec); quote! { #[allow(unused_variables, unused_imports)] @@ -329,9 +328,7 @@ pub fn impl_wrap(name: &syn::Ident, spec: &method::FnSpec) -> Tokens { let _args = _py.from_borrowed_ptr::<_pyo3::PyTuple>(_args); let _kwargs = _pyo3::argparse::get_kwargs(_py, _kwargs); - let _result: #output = { - #body - }; + #body_to_result _pyo3::callback::cb_convert( _pyo3::callback::PyObjectCallbackConverter, _py, _result) } diff --git a/pyo3-derive-backend/src/py_class.rs b/pyo3-derive-backend/src/py_class.rs index 1d25c025..5756ea80 100644 --- a/pyo3-derive-backend/src/py_class.rs +++ b/pyo3-derive-backend/src/py_class.rs @@ -337,7 +337,7 @@ fn impl_descriptors(cls: &syn::Ty, descriptors: Vec<(syn::Field, Vec)>) py: true, reference: false }], - output: syn::parse::ty("PyResult<()>").expect("error parse PyResult<()>") + output: syn::parse::ty("PyResult<()>").expect("error parse PyResult<()>"), }; impl_py_setter_def(&name, doc, setter, &impl_wrap_setter(&Box::new(cls.clone()), &setter_name, &spec)) }, diff --git a/pyo3-derive-backend/src/py_method.rs b/pyo3-derive-backend/src/py_method.rs index c44e24f7..1bcd7e58 100644 --- a/pyo3-derive-backend/src/py_method.rs +++ b/pyo3-derive-backend/src/py_method.rs @@ -43,12 +43,23 @@ fn check_generic(name: &syn::Ident, sig: &syn::MethodSig) { } +pub fn body_to_result(body: &Tokens, spec: &FnSpec) -> Tokens { + let output = &spec.output; + quote! { + use pyo3::ReturnTypeIntoPyResult; + let _result: PyResult<<#output as ReturnTypeIntoPyResult>::Inner> = { + #body + }; + } +} + /// Generate function wrapper (PyCFunction, PyCFunctionWithKeywords) pub fn impl_wrap(cls: &Box, name: &syn::Ident, spec: &FnSpec, noargs: bool) -> Tokens { - let cb = impl_call(cls, name, &spec); - let output = &spec.output; + let body = impl_call(cls, name, &spec); if spec.args.is_empty() && noargs { + let body_to_result = body_to_result(&body, spec); + quote! { unsafe extern "C" fn __wrap( _slf: *mut _pyo3::ffi::PyObject) -> *mut _pyo3::ffi::PyObject @@ -59,15 +70,14 @@ pub fn impl_wrap(cls: &Box, name: &syn::Ident, spec: &FnSpec, noargs: b let _py = _pyo3::Python::assume_gil_acquired(); let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf); - let _result: #output = { - #cb - }; + #body_to_result _pyo3::callback::cb_convert( _pyo3::callback::PyObjectCallbackConverter, _py, _result) } } } else { - let body = impl_arg_params(&spec, cb); + let body = impl_arg_params(&spec, body); + let body_to_result = body_to_result(&body, spec); quote! { unsafe extern "C" fn __wrap( @@ -84,9 +94,7 @@ pub fn impl_wrap(cls: &Box, name: &syn::Ident, spec: &FnSpec, noargs: b let _args = _py.from_borrowed_ptr::<_pyo3::PyTuple>(_args); let _kwargs = _pyo3::argparse::get_kwargs(_py, _kwargs); - let _result: #output = { - #body - }; + #body_to_result _pyo3::callback::cb_convert( _pyo3::callback::PyObjectCallbackConverter, _py, _result) } @@ -128,11 +136,11 @@ pub fn impl_wrap_new(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> To |item| if item.1.py {syn::Ident::from("_py")} else { syn::Ident::from(format!("arg{}", item.0))}).collect(); let cb = quote! {{ - #cls::#name(&_obj, #(#names),*) + #cls::#name(&_obj, #(#names),*).return_type_into_py_result() }}; let body = impl_arg_params(spec, cb); - let output = &spec.output; + let body_to_result = body_to_result(&body, spec); quote! { #[allow(unused_mut)] @@ -152,9 +160,7 @@ pub fn impl_wrap_new(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> To let _args = _py.from_borrowed_ptr::<_pyo3::PyTuple>(_args); let _kwargs = _pyo3::argparse::get_kwargs(_py, _kwargs); - let _result: #output = { - #body - }; + #body_to_result match _result { Ok(_) => _obj.into_ptr(), @@ -176,7 +182,13 @@ pub fn impl_wrap_new(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> To /// Generate function wrapper for ffi::initproc fn impl_wrap_init(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> Tokens { let cb = impl_call(cls, name, &spec); + let output = &spec.output; + if quote! {#output} != quote! {PyResult<()>} || quote! {#output} != quote! {()}{ + panic!("Constructor must return PyResult<()> or a ()"); + } + let body = impl_arg_params(&spec, cb); + let body_to_result = body_to_result(&body, spec); quote! { #[allow(unused_mut)] @@ -192,9 +204,7 @@ fn impl_wrap_init(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> Token let _args = _py.from_borrowed_ptr::<_pyo3::PyTuple>(_args); let _kwargs = _pyo3::argparse::get_kwargs(_py, _kwargs); - let _result: PyResult<()> = { - #body - }; + #body_to_result match _result { Ok(_) => 0, Err(e) => { @@ -212,10 +222,11 @@ pub fn impl_wrap_class(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> |item| if item.1.py {syn::Ident::from("_py")} else { syn::Ident::from(format!("arg{}", item.0))}).collect(); let cb = quote! {{ - #cls::#name(&_cls, #(#names),*) + #cls::#name(&_cls, #(#names),*).return_type_into_py_result() }}; + let body = impl_arg_params(spec, cb); - let output = &spec.output; + let body_to_result = body_to_result(&body, spec); quote! { #[allow(unused_mut)] @@ -231,9 +242,7 @@ pub fn impl_wrap_class(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> let _args = _py.from_borrowed_ptr::<_pyo3::PyTuple>(_args); let _kwargs = _pyo3::argparse::get_kwargs(_py, _kwargs); - let _result: #output = { - #body - }; + #body_to_result _pyo3::callback::cb_convert( _pyo3::callback::PyObjectCallbackConverter, _py, _result) } @@ -246,11 +255,11 @@ pub fn impl_wrap_static(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> |item| if item.1.py {syn::Ident::from("_py")} else { syn::Ident::from(format!("arg{}", item.0))}).collect(); let cb = quote! {{ - #cls::#name(#(#names),*) + #cls::#name(#(#names),*).return_type_into_py_result() }}; let body = impl_arg_params(spec, cb); - let output = &spec.output; + let body_to_result = body_to_result(&body, spec); quote! { #[allow(unused_mut)] @@ -265,9 +274,7 @@ pub fn impl_wrap_static(cls: &Box, name: &syn::Ident, spec: &FnSpec) -> let _args = _py.from_borrowed_ptr::<_pyo3::PyTuple>(_args); let _kwargs = _pyo3::argparse::get_kwargs(_py, _kwargs); - let _result: #output = { - #body - }; + #body_to_result _pyo3::callback::cb_convert( _pyo3::callback::PyObjectCallbackConverter, _py, _result) } @@ -344,7 +351,7 @@ fn impl_call(_cls: &Box, fname: &syn::Ident, spec: &FnSpec) -> Tokens { } ).collect(); quote! {{ - _slf.#fname(#(#names),*) + _slf.#fname(#(#names),*).return_type_into_py_result() }} } @@ -393,6 +400,7 @@ pub fn impl_arg_params(spec: &FnSpec, body: Tokens) -> Tokens { let mut rargs = spec.args.clone(); rargs.reverse(); let mut body = body; + for (idx, arg) in rargs.iter().enumerate() { body = impl_arg_param(&arg, &spec, &body, len-idx-1); } @@ -424,7 +432,7 @@ pub fn impl_arg_params(spec: &FnSpec, body: Tokens) -> Tokens { fn impl_arg_param(arg: &FnArg, spec: &FnSpec, body: &Tokens, idx: usize) -> Tokens { if arg.py { - return body.clone() + return body.clone(); } let ty = arg.ty; let name = arg.name; @@ -444,19 +452,17 @@ fn impl_arg_param(arg: &FnArg, spec: &FnSpec, body: &Tokens, idx: usize) -> Toke Err(e) => Err(e) } } - } - else if spec.is_kwargs(&name) { + } else if spec.is_kwargs(&name) { quote! {{ let #arg_name = _kwargs; #body }} - } - else { + } else { if let Some(_) = arg.optional { // default value let mut default = Tokens::new(); if let Some(d) = spec.default_value(name) { - let dt = quote!{ Some(#d) }; + let dt = quote! { Some(#d) }; dt.to_tokens(&mut default); } else { syn::Ident::from("None").to_tokens(&mut default); @@ -500,8 +506,7 @@ fn impl_arg_param(arg: &FnArg, spec: &FnSpec, body: &Tokens, idx: usize) -> Toke Err(e) => Err(e) } } - } - else { + } else { quote! { match _iter.next().unwrap().as_ref().unwrap().extract() { Ok(#arg_name) => { diff --git a/src/conversion.rs b/src/conversion.rs index 518b95fc..37165e96 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -301,3 +301,31 @@ impl PyTryFrom for T where T: PyTypeInfo { } } } + + +/// This trait wraps a T: IntoPyObject into PyResult while PyResult remains PyResult. +/// +/// This is necessaty because proc macros run before typechecking and can't decide +/// whether a return type is a (possibly aliased) PyResult or not. It is also quite handy because +/// the codegen is currently built on the assumption that all functions return a PyResult. +pub trait ReturnTypeIntoPyResult { + type Inner: ToPyObject; + + fn return_type_into_py_result(self) -> PyResult; +} + +impl ReturnTypeIntoPyResult for T { + type Inner = T; + + default fn return_type_into_py_result(self) -> PyResult { + Ok(self) + } +} + +impl ReturnTypeIntoPyResult for PyResult { + type Inner = T; + + fn return_type_into_py_result(self) -> PyResult { + self + } +} \ No newline at end of file diff --git a/src/err.rs b/src/err.rs index 59134ba2..4b77e243 100644 --- a/src/err.rs +++ b/src/err.rs @@ -140,7 +140,7 @@ pub struct PyDowncastError; /// Helper conversion trait that allows to use custom arguments for exception constructor. pub trait PyErrArguments { /// Arguments for exception - fn arguments(&self, Python) -> PyObject; + fn arguments(&self, _: Python) -> PyObject; } impl PyErr { diff --git a/src/lib.rs b/src/lib.rs index 8f1b65f6..9f4d8fcf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -165,7 +165,8 @@ pub use python::{Python, ToPyPointer, IntoPyPointer, IntoPyDictPointer}; pub use pythonrun::{GILGuard, GILPool, prepare_freethreaded_python, prepare_pyo3_library}; pub use instance::{PyToken, PyObjectWithToken, AsPyRef, Py, PyNativeType}; pub use conversion::{FromPyObject, PyTryFrom, PyTryInto, - ToPyObject, ToBorrowedObject, IntoPyObject, IntoPyTuple}; + ToPyObject, ToBorrowedObject, IntoPyObject, IntoPyTuple, + ReturnTypeIntoPyResult}; pub mod class; pub use class::*; diff --git a/tests/test_class.rs b/tests/test_class.rs index 04c659f4..09c10f5e 100644 --- a/tests/test_class.rs +++ b/tests/test_class.rs @@ -396,6 +396,11 @@ impl StaticMethod { fn method(py: Python) -> PyResult<&'static str> { Ok("StaticMethod.method()!") } + + #[staticmethod] + fn no_parameters() -> PyResult<&'static str> { + Ok("StaticMethod.no_parameters()!") + } } #[test] @@ -408,6 +413,7 @@ fn static_method() { d.set_item("C", py.get_type::()).unwrap(); py.run("assert C.method() == 'StaticMethod.method()!'", None, Some(d)).unwrap(); py.run("assert C().method() == 'StaticMethod.method()!'", None, Some(d)).unwrap(); + py.run("assert C.no_parameters() == 'StaticMethod.no_parameters()!'", None, Some(d)).unwrap(); } #[py::class] @@ -1444,4 +1450,4 @@ fn mut_ref_arg() { py.run("inst1.set_other(inst2)", None, Some(d)).unwrap(); assert_eq!(inst2.as_ref(py).n, 100); -} +} \ No newline at end of file