pyfunction: reject generic functions

This commit is contained in:
David Hewitt 2021-03-09 23:37:01 +00:00
parent fe75b2da59
commit 9613228a0c
9 changed files with 155 additions and 101 deletions

View File

@ -6,7 +6,6 @@ use crate::utils;
use proc_macro2::TokenStream;
use quote::ToTokens;
use quote::{quote, quote_spanned};
use std::ops::Deref;
use syn::ext::IdentExt;
use syn::spanned::Spanned;
@ -21,6 +20,44 @@ pub struct FnArg<'a> {
pub attrs: PyFunctionArgAttrs,
}
impl<'a> FnArg<'a> {
/// Transforms a rust fn arg parsed with syn into a method::FnArg
pub fn parse(arg: &'a mut syn::FnArg) -> syn::Result<Self> {
match arg {
syn::FnArg::Receiver(recv) => {
bail_spanned!(recv.span() => "unexpected receiver")
} // checked in parse_fn_type
syn::FnArg::Typed(cap) => {
ensure_spanned!(
!matches!(&*cap.ty, syn::Type::ImplTrait(_)),
cap.ty.span() => IMPL_TRAIT_ERR
);
let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?;
let (ident, by_ref, mutability) = match *cap.pat {
syn::Pat::Ident(syn::PatIdent {
ref ident,
ref by_ref,
ref mutability,
..
}) => (ident, by_ref, mutability),
_ => bail_spanned!(cap.pat.span() => "unsupported argument"),
};
Ok(FnArg {
name: ident,
by_ref,
mutability,
ty: &cap.ty,
optional: utils::option_type_argument(&cap.ty),
py: utils::is_python(&cap.ty),
attrs: arg_attrs,
})
}
}
}
}
#[derive(Clone, PartialEq, Debug, Copy, Eq)]
pub enum MethodTypeAttribute {
/// #[new]
@ -111,7 +148,13 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> syn::Result<SelfType> {
syn::FnArg::Receiver(recv) => Ok(SelfType::Receiver {
mutable: recv.mutability.is_some(),
}),
syn::FnArg::Typed(syn::PatType { ty, .. }) => Ok(SelfType::TryFromPyCell(ty.span())),
syn::FnArg::Typed(syn::PatType { ty, .. }) => {
ensure_spanned!(
!matches!(&**ty, syn::Type::ImplTrait(_)),
ty.span() => IMPL_TRAIT_ERR
);
Ok(SelfType::TryFromPyCell(ty.span()))
}
}
}
@ -138,9 +181,16 @@ impl<'a> FnSpec<'a> {
let doc = utils::get_doc(&meth_attrs, text_signature, true)?;
let arguments = if skip_first_arg {
Self::parse_arguments(&mut sig.inputs.iter_mut().skip(1))?
sig.inputs
.iter_mut()
.skip(1)
.map(FnArg::parse)
.collect::<syn::Result<_>>()?
} else {
Self::parse_arguments(&mut sig.inputs.iter_mut())?
sig.inputs
.iter_mut()
.map(FnArg::parse)
.collect::<syn::Result<_>>()?
};
Ok(FnSpec {
@ -186,44 +236,6 @@ impl<'a> FnSpec<'a> {
Ok(text_signature)
}
fn parse_arguments(
// inputs: &'a mut [syn::FnArg],
inputs_iter: impl Iterator<Item = &'a mut syn::FnArg>,
) -> syn::Result<Vec<FnArg<'a>>> {
let mut arguments = vec![];
for input in inputs_iter {
match input {
syn::FnArg::Receiver(recv) => {
bail_spanned!(recv.span() => "unexpected receiver for method")
} // checked in parse_fn_type
syn::FnArg::Typed(cap) => {
let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?;
let (ident, by_ref, mutability) = match *cap.pat {
syn::Pat::Ident(syn::PatIdent {
ref ident,
ref by_ref,
ref mutability,
..
}) => (ident, by_ref, mutability),
_ => bail_spanned!(cap.pat.span() => "unsupported argument"),
};
arguments.push(FnArg {
name: ident,
by_ref,
mutability,
ty: cap.ty.deref(),
optional: utils::option_type_argument(cap.ty.deref()),
py: utils::is_python(cap.ty.deref()),
attrs: arg_attrs,
});
}
}
}
Ok(arguments)
}
fn parse_fn_type(
sig: &syn::Signature,
fn_type_attr: Option<MethodTypeAttribute>,
@ -493,3 +505,5 @@ fn parse_method_name_attribute(
_ => name,
})
}
const IMPL_TRAIT_ERR: &str = "Python functions cannot have `impl Trait` arguments";

View File

@ -1,10 +1,9 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
//! Code generation for the function that initializes a python module and adds classes and function.
use crate::method;
use crate::pyfunction::{PyFunctionArgAttrs, PyFunctionAttr};
use crate::pymethod;
use crate::pymethod::get_arg_names;
use crate::method::{self, FnArg};
use crate::pyfunction::PyFunctionAttr;
use crate::pymethod::{check_generic, get_arg_names, impl_arg_params};
use crate::utils;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
@ -57,26 +56,6 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
Ok(())
}
/// Transforms a rust fn arg parsed with syn into a method::FnArg
fn wrap_fn_argument(cap: &mut syn::PatType) -> syn::Result<method::FnArg> {
let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?;
let (mutability, by_ref, ident) = match &*cap.pat {
syn::Pat::Ident(patid) => (&patid.mutability, &patid.by_ref, &patid.ident),
_ => bail_spanned!(cap.pat.span() => "unsupported argument"),
};
Ok(method::FnArg {
name: ident,
mutability,
by_ref,
ty: &cap.ty,
optional: utils::option_type_argument(&cap.ty),
py: utils::is_python(&cap.ty),
attrs: arg_attrs,
})
}
/// Extracts the data from the #[pyfn(...)] attribute of a function
fn extract_pyfn_attrs(
attrs: &mut Vec<syn::Attribute>,
@ -143,36 +122,26 @@ pub fn add_fn_to_module(
python_name: Ident,
pyfn_attrs: PyFunctionAttr,
) -> syn::Result<TokenStream> {
let mut arguments = Vec::new();
check_generic(&func.sig)?;
for (i, input) in func.sig.inputs.iter_mut().enumerate() {
match input {
syn::FnArg::Receiver(_) => {
bail_spanned!(input.span() => "unexpected receiver for #[pyfn]");
}
syn::FnArg::Typed(cap) => {
if pyfn_attrs.pass_module && i == 0 {
if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
if typath
.path
.segments
.last()
.map(|seg| seg.ident == "PyModule")
.unwrap_or(false)
{
continue;
}
}
}
bail_spanned!(
cap.span() => "expected &PyModule as first argument with `pass_module`"
);
} else {
arguments.push(wrap_fn_argument(cap)?);
}
}
}
let mut arguments = func
.sig
.inputs
.iter_mut()
.map(FnArg::parse)
.collect::<syn::Result<Vec<_>>>()?;
if pyfn_attrs.pass_module {
const PASS_MODULE_ERR: &str = "expected &PyModule as first argument with `pass_module`";
ensure_spanned!(
!arguments.is_empty(),
func.span() => PASS_MODULE_ERR
);
let arg = arguments.remove(0);
ensure_spanned!(
type_is_pymodule(arg.ty),
arg.ty.span() => PASS_MODULE_ERR
);
}
let ty = method::get_return_info(&func.sig.output);
@ -217,6 +186,23 @@ pub fn add_fn_to_module(
})
}
fn type_is_pymodule(ty: &syn::Type) -> bool {
if let syn::Type::Reference(tyref) = ty {
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
if typath
.path
.segments
.last()
.map(|seg| seg.ident == "PyModule")
.unwrap_or(false)
{
return true;
}
}
}
false
}
/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
fn function_c_wrapper(
name: &Ident,
@ -240,7 +226,7 @@ fn function_c_wrapper(
};
slf_module = None;
};
let body = pymethod::impl_arg_params(spec, None, cb)?;
let body = impl_arg_params(spec, None, cb)?;
Ok(quote! {
unsafe extern "C" fn #wrapper_ident(
_slf: *mut pyo3::ffi::PyObject,

View File

@ -63,8 +63,8 @@ pub fn gen_py_method(
})
}
fn check_generic(sig: &syn::Signature) -> syn::Result<()> {
let err_msg = |typ| format!("a Python method can't have a generic {} parameter", typ);
pub(crate) fn check_generic(sig: &syn::Signature) -> syn::Result<()> {
let err_msg = |typ| format!("Python functions cannot have generic {} parameters", typ);
for param in &sig.generics.params {
match param {
syn::GenericParam::Lifetime(_) => {}

View File

@ -6,6 +6,7 @@ fn test_compile_errors() {
t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
t.compile_fail("tests/ui/invalid_property_args.rs");
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
t.compile_fail("tests/ui/invalid_pyfunctions.rs");
t.compile_fail("tests/ui/invalid_pymethods.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/invalid_argument_attributes.rs");

View File

@ -1,5 +1,5 @@
error: expected &PyModule as first argument with `pass_module`
--> $DIR/invalid_need_module_arg_position.rs:6:13
--> $DIR/invalid_need_module_arg_position.rs:6:21
|
6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
| ^^^^^^
| ^

View File

@ -0,0 +1,9 @@
use pyo3::prelude::*;
#[pyfunction]
fn generic_function<T>(value: T) {}
#[pyfunction]
fn impl_trait_function(impl_trait: impl AsRef<PyAny>) {}
fn main() {}

View File

@ -0,0 +1,11 @@
error: Python functions cannot have generic type parameters
--> $DIR/invalid_pyfunctions.rs:4:21
|
4 | fn generic_function<T>(value: T) {}
| ^
error: Python functions cannot have `impl Trait` arguments
--> $DIR/invalid_pyfunctions.rs:7:36
|
7 | fn impl_trait_function(impl_trait: impl AsRef<PyAny>) {}
| ^^^^

View File

@ -81,5 +81,20 @@ impl MyClass {
fn multiple_method_types() {}
}
#[pymethods]
impl MyClass {
fn generic_method<T>(value: T) {}
}
#[pymethods]
impl MyClass {
fn impl_trait_method_first_arg(impl_trait: impl AsRef<PyAny>) {}
}
#[pymethods]
impl MyClass {
fn impl_trait_method_second_arg(&self, impl_trait: impl AsRef<PyAny>) {}
}
fn main() {}

View File

@ -10,7 +10,7 @@ error: static method needs #[staticmethod] attribute
14 | fn staticmethod_without_attribute() {}
| ^^
error: unexpected receiver for method
error: unexpected receiver
--> $DIR/invalid_pymethods.rs:20:35
|
20 | fn staticmethod_with_receiver(&self) {}
@ -63,3 +63,21 @@ error: cannot specify a second method type
|
80 | #[staticmethod]
| ^^^^^^^^^^^^
error: Python functions cannot have generic type parameters
--> $DIR/invalid_pymethods.rs:86:23
|
86 | fn generic_method<T>(value: T) {}
| ^
error: Python functions cannot have `impl Trait` arguments
--> $DIR/invalid_pymethods.rs:92:48
|
92 | fn impl_trait_method_first_arg(impl_trait: impl AsRef<PyAny>) {}
| ^^^^
error: Python functions cannot have `impl Trait` arguments
--> $DIR/invalid_pymethods.rs:97:56
|
97 | fn impl_trait_method_second_arg(&self, impl_trait: impl AsRef<PyAny>) {}
| ^^^^