Added customizable conversion error message for type errors encounted in FromPyObject::extract

This commit is contained in:
R2D2 2021-05-28 00:24:47 +02:00
parent d71d4329ae
commit 8981b84b79
2 changed files with 26 additions and 7 deletions

View file

@ -12,6 +12,7 @@ pub mod kw {
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(from_py_with);
syn::custom_keyword!(conversion_error);
syn::custom_keyword!(item);
syn::custom_keyword!(pass_module);
syn::custom_keyword!(name);

View file

@ -241,21 +241,24 @@ impl<'a> Container<'a> {
FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)),
FieldGetter::GetItem(None) => quote!(get_item(stringify!(#ident))),
};
let conversion_error_msg = attrs.conversion_error
.as_ref()
.map_or(format!("failed to extract field {}.{}",
quote!(#self_ty),
ident),
|msg| msg.value());
let get_field = quote!(obj.#getter?);
let extractor = match &attrs.from_py_with {
None => quote!(#get_field.extract().map_err(|inner| {
let err_msg = format!("failed to extract field {}.{}\n\nCaused by:\n {}\n",
stringify!(#self_ty),
stringify!(#ident),
let err_msg = format!("{}\n\nCaused by:\n {}\n",
#conversion_error_msg,
inner);
pyo3::exceptions::PyTypeError::new_err(err_msg)
})?),
Some(FromPyWithAttribute(expr_path)) => quote! (#expr_path(#get_field).
map_err(|inner| {
let err_msg = format!("failed to extract field {}.{}\n\nCaused by:\n {}\n",
stringify!(#self_ty),
stringify!(#ident),
let err_msg = format!("{}\n\nCaused by:\n {}\n",
#conversion_error_msg,
inner);
pyo3::exceptions::PyTypeError::new_err(err_msg)
})?),
@ -336,6 +339,7 @@ impl ContainerOptions {
struct FieldPyO3Attributes {
getter: FieldGetter,
from_py_with: Option<FromPyWithAttribute>,
conversion_error: Option<LitStr>,
}
#[derive(Clone, Debug)]
@ -347,6 +351,7 @@ enum FieldGetter {
enum FieldPyO3Attribute {
Getter(FieldGetter),
FromPyWith(FromPyWithAttribute),
ConversionError(LitStr),
}
impl Parse for FieldPyO3Attribute {
@ -390,6 +395,10 @@ impl Parse for FieldPyO3Attribute {
}
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(FieldPyO3Attribute::FromPyWith)
} else if lookahead.peek(attributes::kw::conversion_error) {
let _: attributes::kw::conversion_error = input.parse()?;
let _: Token![=] = input.parse()?;
input.parse().map(FieldPyO3Attribute::ConversionError)
} else {
Err(lookahead.error())
}
@ -402,6 +411,7 @@ impl FieldPyO3Attributes {
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
let mut getter = None;
let mut from_py_with = None;
let mut conversion_error = None;
for attr in attrs {
if let Some(pyo3_attrs) = get_pyo3_attributes(attr)? {
@ -421,6 +431,13 @@ impl FieldPyO3Attributes {
);
from_py_with = Some(from_py_with_attr);
}
FieldPyO3Attribute::ConversionError(conversion_error_msg) => {
ensure_spanned!(
conversion_error.is_none(),
attr.span() => "`conversion_error` may only be provided once"
);
conversion_error = Some(conversion_error_msg)
}
}
}
}
@ -429,6 +446,7 @@ impl FieldPyO3Attributes {
Ok(FieldPyO3Attributes {
getter: getter.unwrap_or(FieldGetter::GetAttr(None)),
from_py_with,
conversion_error,
})
}
}