Added customizable conversion error message for type errors encounted in FromPyObject::extract
This commit is contained in:
parent
d71d4329ae
commit
8981b84b79
|
@ -12,6 +12,7 @@ pub mod kw {
|
|||
syn::custom_keyword!(annotation);
|
||||
syn::custom_keyword!(attribute);
|
||||
syn::custom_keyword!(from_py_with);
|
||||
syn::custom_keyword!(conversion_error);
|
||||
syn::custom_keyword!(item);
|
||||
syn::custom_keyword!(pass_module);
|
||||
syn::custom_keyword!(name);
|
||||
|
|
|
@ -241,21 +241,24 @@ impl<'a> Container<'a> {
|
|||
FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)),
|
||||
FieldGetter::GetItem(None) => quote!(get_item(stringify!(#ident))),
|
||||
};
|
||||
|
||||
let conversion_error_msg = attrs.conversion_error
|
||||
.as_ref()
|
||||
.map_or(format!("failed to extract field {}.{}",
|
||||
quote!(#self_ty),
|
||||
ident),
|
||||
|msg| msg.value());
|
||||
let get_field = quote!(obj.#getter?);
|
||||
let extractor = match &attrs.from_py_with {
|
||||
None => quote!(#get_field.extract().map_err(|inner| {
|
||||
let err_msg = format!("failed to extract field {}.{}\n\nCaused by:\n {}\n",
|
||||
stringify!(#self_ty),
|
||||
stringify!(#ident),
|
||||
let err_msg = format!("{}\n\nCaused by:\n {}\n",
|
||||
#conversion_error_msg,
|
||||
inner);
|
||||
pyo3::exceptions::PyTypeError::new_err(err_msg)
|
||||
})?),
|
||||
Some(FromPyWithAttribute(expr_path)) => quote! (#expr_path(#get_field).
|
||||
map_err(|inner| {
|
||||
let err_msg = format!("failed to extract field {}.{}\n\nCaused by:\n {}\n",
|
||||
stringify!(#self_ty),
|
||||
stringify!(#ident),
|
||||
let err_msg = format!("{}\n\nCaused by:\n {}\n",
|
||||
#conversion_error_msg,
|
||||
inner);
|
||||
pyo3::exceptions::PyTypeError::new_err(err_msg)
|
||||
})?),
|
||||
|
@ -336,6 +339,7 @@ impl ContainerOptions {
|
|||
struct FieldPyO3Attributes {
|
||||
getter: FieldGetter,
|
||||
from_py_with: Option<FromPyWithAttribute>,
|
||||
conversion_error: Option<LitStr>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -347,6 +351,7 @@ enum FieldGetter {
|
|||
enum FieldPyO3Attribute {
|
||||
Getter(FieldGetter),
|
||||
FromPyWith(FromPyWithAttribute),
|
||||
ConversionError(LitStr),
|
||||
}
|
||||
|
||||
impl Parse for FieldPyO3Attribute {
|
||||
|
@ -390,6 +395,10 @@ impl Parse for FieldPyO3Attribute {
|
|||
}
|
||||
} else if lookahead.peek(attributes::kw::from_py_with) {
|
||||
input.parse().map(FieldPyO3Attribute::FromPyWith)
|
||||
} else if lookahead.peek(attributes::kw::conversion_error) {
|
||||
let _: attributes::kw::conversion_error = input.parse()?;
|
||||
let _: Token![=] = input.parse()?;
|
||||
input.parse().map(FieldPyO3Attribute::ConversionError)
|
||||
} else {
|
||||
Err(lookahead.error())
|
||||
}
|
||||
|
@ -402,6 +411,7 @@ impl FieldPyO3Attributes {
|
|||
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
|
||||
let mut getter = None;
|
||||
let mut from_py_with = None;
|
||||
let mut conversion_error = None;
|
||||
|
||||
for attr in attrs {
|
||||
if let Some(pyo3_attrs) = get_pyo3_attributes(attr)? {
|
||||
|
@ -421,6 +431,13 @@ impl FieldPyO3Attributes {
|
|||
);
|
||||
from_py_with = Some(from_py_with_attr);
|
||||
}
|
||||
FieldPyO3Attribute::ConversionError(conversion_error_msg) => {
|
||||
ensure_spanned!(
|
||||
conversion_error.is_none(),
|
||||
attr.span() => "`conversion_error` may only be provided once"
|
||||
);
|
||||
conversion_error = Some(conversion_error_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -429,6 +446,7 @@ impl FieldPyO3Attributes {
|
|||
Ok(FieldPyO3Attributes {
|
||||
getter: getter.unwrap_or(FieldGetter::GetAttr(None)),
|
||||
from_py_with,
|
||||
conversion_error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue