refactor `parse_fn_type`

This commit is contained in:
David Hewitt 2023-10-02 08:26:56 +01:00
parent 1158c08f42
commit b3ee70db40
1 changed files with 41 additions and 41 deletions

View File

@ -85,6 +85,18 @@ pub enum FnType {
}
impl FnType {
pub fn skip_first_rust_argument_in_python_signature(&self) -> bool {
match self {
FnType::Getter(_)
| FnType::Setter(_)
| FnType::Fn(_)
| FnType::FnClass
| FnType::FnNewClass
| FnType::FnModule => true,
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => false,
}
}
pub fn self_arg(&self, cls: Option<&syn::Type>, error_mode: ExtractErrorMode) -> TokenStream {
match self {
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) => {
@ -264,26 +276,23 @@ impl<'a> FnSpec<'a> {
let mut python_name = name.map(|name| name.value.0);
let (fn_type, skip_first_arg, fixed_convention) =
Self::parse_fn_type(sig, meth_attrs, &mut python_name)?;
let fn_type = Self::parse_fn_type(sig, meth_attrs, &mut python_name)?;
ensure_signatures_on_valid_method(&fn_type, signature.as_ref(), text_signature.as_ref())?;
let name = &sig.ident;
let ty = get_return_info(&sig.output);
let python_name = python_name.as_ref().unwrap_or(name).unraw();
let arguments: Vec<_> = if skip_first_arg {
sig.inputs
.iter_mut()
.skip(1)
.map(FnArg::parse)
.collect::<Result<_>>()?
} else {
sig.inputs
.iter_mut()
.map(FnArg::parse)
.collect::<Result<_>>()?
};
let arguments: Vec<_> = sig
.inputs
.iter_mut()
.skip(if fn_type.skip_first_rust_argument_in_python_signature() {
1
} else {
0
})
.map(FnArg::parse)
.collect::<Result<_>>()?;
let signature = if let Some(signature) = signature {
FunctionSignature::from_arguments_and_attribute(arguments, signature)?
@ -291,8 +300,11 @@ impl<'a> FnSpec<'a> {
FunctionSignature::from_arguments(arguments)?
};
let convention =
fixed_convention.unwrap_or_else(|| CallingConvention::from_signature(&signature));
let convention = if matches!(fn_type, FnType::FnNew | FnType::FnNewClass) {
CallingConvention::TpNew
} else {
CallingConvention::from_signature(&signature)
};
Ok(FnSpec {
tp: fn_type,
@ -314,7 +326,7 @@ impl<'a> FnSpec<'a> {
sig: &syn::Signature,
meth_attrs: &mut Vec<syn::Attribute>,
python_name: &mut Option<syn::Ident>,
) -> Result<(FnType, bool, Option<CallingConvention>)> {
) -> Result<FnType> {
let mut method_attributes = parse_method_attributes(meth_attrs)?;
let name = &sig.ident;
@ -334,16 +346,12 @@ impl<'a> FnSpec<'a> {
.map(|stripped| syn::Ident::new(stripped, name.span()))
};
let (fn_type, skip_first_arg, fixed_convention) = match method_attributes.as_mut_slice() {
[] => (
FnType::Fn(parse_receiver(
"static method needs #[staticmethod] attribute",
)?),
true,
None,
),
[MethodTypeAttribute::StaticMethod(_)] => (FnType::FnStatic, false, None),
[MethodTypeAttribute::ClassAttribute(_)] => (FnType::ClassAttribute, false, None),
let fn_type = match method_attributes.as_mut_slice() {
[] => FnType::Fn(parse_receiver(
"static method needs #[staticmethod] attribute",
)?),
[MethodTypeAttribute::StaticMethod(_)] => FnType::FnStatic,
[MethodTypeAttribute::ClassAttribute(_)] => FnType::ClassAttribute,
[MethodTypeAttribute::New(_)]
| [MethodTypeAttribute::New(_), MethodTypeAttribute::ClassMethod(_)]
| [MethodTypeAttribute::ClassMethod(_), MethodTypeAttribute::New(_)] => {
@ -352,12 +360,12 @@ impl<'a> FnSpec<'a> {
}
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
if matches!(method_attributes.as_slice(), [MethodTypeAttribute::New(_)]) {
(FnType::FnNew, false, Some(CallingConvention::TpNew))
FnType::FnNew
} else {
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
FnType::FnNewClass
}
}
[MethodTypeAttribute::ClassMethod(_)] => (FnType::FnClass, true, None),
[MethodTypeAttribute::ClassMethod(_)] => FnType::FnClass,
[MethodTypeAttribute::Getter(_, name)] => {
if let Some(name) = name.take() {
ensure_spanned!(
@ -369,11 +377,7 @@ impl<'a> FnSpec<'a> {
*python_name = strip_fn_name("get_");
}
(
FnType::Getter(parse_receiver("expected receiver for `#[getter]`")?),
true,
None,
)
FnType::Getter(parse_receiver("expected receiver for `#[getter]`")?)
}
[MethodTypeAttribute::Setter(_, name)] => {
if let Some(name) = name.take() {
@ -386,11 +390,7 @@ impl<'a> FnSpec<'a> {
*python_name = strip_fn_name("set_");
}
(
FnType::Setter(parse_receiver("expected receiver for `#[setter]`")?),
true,
None,
)
FnType::Setter(parse_receiver("expected receiver for `#[setter]`")?)
}
[first, rest @ .., last] => {
// Join as many of the spans together as possible
@ -416,7 +416,7 @@ impl<'a> FnSpec<'a> {
bail_spanned!(span => msg)
}
};
Ok((fn_type, skip_first_arg, fixed_convention))
Ok(fn_type)
}
/// Return a C wrapper function for this signature.