pyfunction: reject generic functions
This commit is contained in:
parent
fe75b2da59
commit
9613228a0c
|
@ -6,7 +6,6 @@ use crate::utils;
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::ToTokens;
|
||||
use quote::{quote, quote_spanned};
|
||||
use std::ops::Deref;
|
||||
use syn::ext::IdentExt;
|
||||
use syn::spanned::Spanned;
|
||||
|
||||
|
@ -21,6 +20,44 @@ pub struct FnArg<'a> {
|
|||
pub attrs: PyFunctionArgAttrs,
|
||||
}
|
||||
|
||||
impl<'a> FnArg<'a> {
|
||||
/// Transforms a rust fn arg parsed with syn into a method::FnArg
|
||||
pub fn parse(arg: &'a mut syn::FnArg) -> syn::Result<Self> {
|
||||
match arg {
|
||||
syn::FnArg::Receiver(recv) => {
|
||||
bail_spanned!(recv.span() => "unexpected receiver")
|
||||
} // checked in parse_fn_type
|
||||
syn::FnArg::Typed(cap) => {
|
||||
ensure_spanned!(
|
||||
!matches!(&*cap.ty, syn::Type::ImplTrait(_)),
|
||||
cap.ty.span() => IMPL_TRAIT_ERR
|
||||
);
|
||||
|
||||
let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?;
|
||||
let (ident, by_ref, mutability) = match *cap.pat {
|
||||
syn::Pat::Ident(syn::PatIdent {
|
||||
ref ident,
|
||||
ref by_ref,
|
||||
ref mutability,
|
||||
..
|
||||
}) => (ident, by_ref, mutability),
|
||||
_ => bail_spanned!(cap.pat.span() => "unsupported argument"),
|
||||
};
|
||||
|
||||
Ok(FnArg {
|
||||
name: ident,
|
||||
by_ref,
|
||||
mutability,
|
||||
ty: &cap.ty,
|
||||
optional: utils::option_type_argument(&cap.ty),
|
||||
py: utils::is_python(&cap.ty),
|
||||
attrs: arg_attrs,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Copy, Eq)]
|
||||
pub enum MethodTypeAttribute {
|
||||
/// #[new]
|
||||
|
@ -111,7 +148,13 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> syn::Result<SelfType> {
|
|||
syn::FnArg::Receiver(recv) => Ok(SelfType::Receiver {
|
||||
mutable: recv.mutability.is_some(),
|
||||
}),
|
||||
syn::FnArg::Typed(syn::PatType { ty, .. }) => Ok(SelfType::TryFromPyCell(ty.span())),
|
||||
syn::FnArg::Typed(syn::PatType { ty, .. }) => {
|
||||
ensure_spanned!(
|
||||
!matches!(&**ty, syn::Type::ImplTrait(_)),
|
||||
ty.span() => IMPL_TRAIT_ERR
|
||||
);
|
||||
Ok(SelfType::TryFromPyCell(ty.span()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,9 +181,16 @@ impl<'a> FnSpec<'a> {
|
|||
let doc = utils::get_doc(&meth_attrs, text_signature, true)?;
|
||||
|
||||
let arguments = if skip_first_arg {
|
||||
Self::parse_arguments(&mut sig.inputs.iter_mut().skip(1))?
|
||||
sig.inputs
|
||||
.iter_mut()
|
||||
.skip(1)
|
||||
.map(FnArg::parse)
|
||||
.collect::<syn::Result<_>>()?
|
||||
} else {
|
||||
Self::parse_arguments(&mut sig.inputs.iter_mut())?
|
||||
sig.inputs
|
||||
.iter_mut()
|
||||
.map(FnArg::parse)
|
||||
.collect::<syn::Result<_>>()?
|
||||
};
|
||||
|
||||
Ok(FnSpec {
|
||||
|
@ -186,44 +236,6 @@ impl<'a> FnSpec<'a> {
|
|||
Ok(text_signature)
|
||||
}
|
||||
|
||||
fn parse_arguments(
|
||||
// inputs: &'a mut [syn::FnArg],
|
||||
inputs_iter: impl Iterator<Item = &'a mut syn::FnArg>,
|
||||
) -> syn::Result<Vec<FnArg<'a>>> {
|
||||
let mut arguments = vec![];
|
||||
for input in inputs_iter {
|
||||
match input {
|
||||
syn::FnArg::Receiver(recv) => {
|
||||
bail_spanned!(recv.span() => "unexpected receiver for method")
|
||||
} // checked in parse_fn_type
|
||||
syn::FnArg::Typed(cap) => {
|
||||
let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?;
|
||||
let (ident, by_ref, mutability) = match *cap.pat {
|
||||
syn::Pat::Ident(syn::PatIdent {
|
||||
ref ident,
|
||||
ref by_ref,
|
||||
ref mutability,
|
||||
..
|
||||
}) => (ident, by_ref, mutability),
|
||||
_ => bail_spanned!(cap.pat.span() => "unsupported argument"),
|
||||
};
|
||||
|
||||
arguments.push(FnArg {
|
||||
name: ident,
|
||||
by_ref,
|
||||
mutability,
|
||||
ty: cap.ty.deref(),
|
||||
optional: utils::option_type_argument(cap.ty.deref()),
|
||||
py: utils::is_python(cap.ty.deref()),
|
||||
attrs: arg_attrs,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(arguments)
|
||||
}
|
||||
|
||||
fn parse_fn_type(
|
||||
sig: &syn::Signature,
|
||||
fn_type_attr: Option<MethodTypeAttribute>,
|
||||
|
@ -493,3 +505,5 @@ fn parse_method_name_attribute(
|
|||
_ => name,
|
||||
})
|
||||
}
|
||||
|
||||
const IMPL_TRAIT_ERR: &str = "Python functions cannot have `impl Trait` arguments";
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
// Copyright (c) 2017-present PyO3 Project and Contributors
|
||||
//! Code generation for the function that initializes a python module and adds classes and function.
|
||||
|
||||
use crate::method;
|
||||
use crate::pyfunction::{PyFunctionArgAttrs, PyFunctionAttr};
|
||||
use crate::pymethod;
|
||||
use crate::pymethod::get_arg_names;
|
||||
use crate::method::{self, FnArg};
|
||||
use crate::pyfunction::PyFunctionAttr;
|
||||
use crate::pymethod::{check_generic, get_arg_names, impl_arg_params};
|
||||
use crate::utils;
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::{format_ident, quote};
|
||||
|
@ -57,26 +56,6 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Transforms a rust fn arg parsed with syn into a method::FnArg
|
||||
fn wrap_fn_argument(cap: &mut syn::PatType) -> syn::Result<method::FnArg> {
|
||||
let arg_attrs = PyFunctionArgAttrs::from_attrs(&mut cap.attrs)?;
|
||||
|
||||
let (mutability, by_ref, ident) = match &*cap.pat {
|
||||
syn::Pat::Ident(patid) => (&patid.mutability, &patid.by_ref, &patid.ident),
|
||||
_ => bail_spanned!(cap.pat.span() => "unsupported argument"),
|
||||
};
|
||||
|
||||
Ok(method::FnArg {
|
||||
name: ident,
|
||||
mutability,
|
||||
by_ref,
|
||||
ty: &cap.ty,
|
||||
optional: utils::option_type_argument(&cap.ty),
|
||||
py: utils::is_python(&cap.ty),
|
||||
attrs: arg_attrs,
|
||||
})
|
||||
}
|
||||
|
||||
/// Extracts the data from the #[pyfn(...)] attribute of a function
|
||||
fn extract_pyfn_attrs(
|
||||
attrs: &mut Vec<syn::Attribute>,
|
||||
|
@ -143,36 +122,26 @@ pub fn add_fn_to_module(
|
|||
python_name: Ident,
|
||||
pyfn_attrs: PyFunctionAttr,
|
||||
) -> syn::Result<TokenStream> {
|
||||
let mut arguments = Vec::new();
|
||||
check_generic(&func.sig)?;
|
||||
|
||||
for (i, input) in func.sig.inputs.iter_mut().enumerate() {
|
||||
match input {
|
||||
syn::FnArg::Receiver(_) => {
|
||||
bail_spanned!(input.span() => "unexpected receiver for #[pyfn]");
|
||||
}
|
||||
syn::FnArg::Typed(cap) => {
|
||||
if pyfn_attrs.pass_module && i == 0 {
|
||||
if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
|
||||
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
|
||||
if typath
|
||||
.path
|
||||
.segments
|
||||
.last()
|
||||
.map(|seg| seg.ident == "PyModule")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
bail_spanned!(
|
||||
cap.span() => "expected &PyModule as first argument with `pass_module`"
|
||||
);
|
||||
} else {
|
||||
arguments.push(wrap_fn_argument(cap)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut arguments = func
|
||||
.sig
|
||||
.inputs
|
||||
.iter_mut()
|
||||
.map(FnArg::parse)
|
||||
.collect::<syn::Result<Vec<_>>>()?;
|
||||
|
||||
if pyfn_attrs.pass_module {
|
||||
const PASS_MODULE_ERR: &str = "expected &PyModule as first argument with `pass_module`";
|
||||
ensure_spanned!(
|
||||
!arguments.is_empty(),
|
||||
func.span() => PASS_MODULE_ERR
|
||||
);
|
||||
let arg = arguments.remove(0);
|
||||
ensure_spanned!(
|
||||
type_is_pymodule(arg.ty),
|
||||
arg.ty.span() => PASS_MODULE_ERR
|
||||
);
|
||||
}
|
||||
|
||||
let ty = method::get_return_info(&func.sig.output);
|
||||
|
@ -217,6 +186,23 @@ pub fn add_fn_to_module(
|
|||
})
|
||||
}
|
||||
|
||||
fn type_is_pymodule(ty: &syn::Type) -> bool {
|
||||
if let syn::Type::Reference(tyref) = ty {
|
||||
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
|
||||
if typath
|
||||
.path
|
||||
.segments
|
||||
.last()
|
||||
.map(|seg| seg.ident == "PyModule")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
|
||||
fn function_c_wrapper(
|
||||
name: &Ident,
|
||||
|
@ -240,7 +226,7 @@ fn function_c_wrapper(
|
|||
};
|
||||
slf_module = None;
|
||||
};
|
||||
let body = pymethod::impl_arg_params(spec, None, cb)?;
|
||||
let body = impl_arg_params(spec, None, cb)?;
|
||||
Ok(quote! {
|
||||
unsafe extern "C" fn #wrapper_ident(
|
||||
_slf: *mut pyo3::ffi::PyObject,
|
||||
|
|
|
@ -63,8 +63,8 @@ pub fn gen_py_method(
|
|||
})
|
||||
}
|
||||
|
||||
fn check_generic(sig: &syn::Signature) -> syn::Result<()> {
|
||||
let err_msg = |typ| format!("a Python method can't have a generic {} parameter", typ);
|
||||
pub(crate) fn check_generic(sig: &syn::Signature) -> syn::Result<()> {
|
||||
let err_msg = |typ| format!("Python functions cannot have generic {} parameters", typ);
|
||||
for param in &sig.generics.params {
|
||||
match param {
|
||||
syn::GenericParam::Lifetime(_) => {}
|
||||
|
|
|
@ -6,6 +6,7 @@ fn test_compile_errors() {
|
|||
t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
|
||||
t.compile_fail("tests/ui/invalid_property_args.rs");
|
||||
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
|
||||
t.compile_fail("tests/ui/invalid_pyfunctions.rs");
|
||||
t.compile_fail("tests/ui/invalid_pymethods.rs");
|
||||
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
|
||||
t.compile_fail("tests/ui/invalid_argument_attributes.rs");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
error: expected &PyModule as first argument with `pass_module`
|
||||
--> $DIR/invalid_need_module_arg_position.rs:6:13
|
||||
--> $DIR/invalid_need_module_arg_position.rs:6:21
|
||||
|
|
||||
6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
|
||||
| ^^^^^^
|
||||
| ^
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
use pyo3::prelude::*;
|
||||
|
||||
#[pyfunction]
|
||||
fn generic_function<T>(value: T) {}
|
||||
|
||||
#[pyfunction]
|
||||
fn impl_trait_function(impl_trait: impl AsRef<PyAny>) {}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,11 @@
|
|||
error: Python functions cannot have generic type parameters
|
||||
--> $DIR/invalid_pyfunctions.rs:4:21
|
||||
|
|
||||
4 | fn generic_function<T>(value: T) {}
|
||||
| ^
|
||||
|
||||
error: Python functions cannot have `impl Trait` arguments
|
||||
--> $DIR/invalid_pyfunctions.rs:7:36
|
||||
|
|
||||
7 | fn impl_trait_function(impl_trait: impl AsRef<PyAny>) {}
|
||||
| ^^^^
|
|
@ -81,5 +81,20 @@ impl MyClass {
|
|||
fn multiple_method_types() {}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl MyClass {
|
||||
fn generic_method<T>(value: T) {}
|
||||
}
|
||||
|
||||
|
||||
#[pymethods]
|
||||
impl MyClass {
|
||||
fn impl_trait_method_first_arg(impl_trait: impl AsRef<PyAny>) {}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl MyClass {
|
||||
fn impl_trait_method_second_arg(&self, impl_trait: impl AsRef<PyAny>) {}
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
|
|
@ -10,7 +10,7 @@ error: static method needs #[staticmethod] attribute
|
|||
14 | fn staticmethod_without_attribute() {}
|
||||
| ^^
|
||||
|
||||
error: unexpected receiver for method
|
||||
error: unexpected receiver
|
||||
--> $DIR/invalid_pymethods.rs:20:35
|
||||
|
|
||||
20 | fn staticmethod_with_receiver(&self) {}
|
||||
|
@ -63,3 +63,21 @@ error: cannot specify a second method type
|
|||
|
|
||||
80 | #[staticmethod]
|
||||
| ^^^^^^^^^^^^
|
||||
|
||||
error: Python functions cannot have generic type parameters
|
||||
--> $DIR/invalid_pymethods.rs:86:23
|
||||
|
|
||||
86 | fn generic_method<T>(value: T) {}
|
||||
| ^
|
||||
|
||||
error: Python functions cannot have `impl Trait` arguments
|
||||
--> $DIR/invalid_pymethods.rs:92:48
|
||||
|
|
||||
92 | fn impl_trait_method_first_arg(impl_trait: impl AsRef<PyAny>) {}
|
||||
| ^^^^
|
||||
|
||||
error: Python functions cannot have `impl Trait` arguments
|
||||
--> $DIR/invalid_pymethods.rs:97:56
|
||||
|
|
||||
97 | fn impl_trait_method_second_arg(&self, impl_trait: impl AsRef<PyAny>) {}
|
||||
| ^^^^
|
||||
|
|
Loading…
Reference in New Issue