'#[derive(FromPyObject)]` changes suggested by @davidwhewitt.
This commit is contained in:
parent
7168309464
commit
60fe4925f5
|
@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
- Add `Py::as_ref` and `Py::into_ref`. [#1098](https://github.com/PyO3/pyo3/pull/1098)
|
||||
- Add optional implementations of `ToPyObject`, `IntoPy`, and `FromPyObject` for [hashbrown](https://crates.io/crates/hashbrown)'s `HashMap` and `HashSet` types. The `hashbrown` feature must be enabled for these implementations to be built. [#1114](https://github.com/PyO3/pyo3/pull/1114/)
|
||||
- Allow other `Result` types when using `#[pyfunction]`. [#1106](https://github.com/PyO3/pyo3/issues/1106).
|
||||
- Add `#[derive(FromPyObject)]` macro for enums and structs. [#1065](https://github.com/PyO3/pyo3/pull/1065)
|
||||
|
||||
### Changed
|
||||
- Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py<T>` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024)
|
||||
|
|
|
@ -0,0 +1,496 @@
|
|||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::{parse_quote, Attribute, DataEnum, DeriveInput, ExprCall, Fields, Ident, Result};
|
||||
|
||||
/// Describes derivation input of an enum.
|
||||
#[derive(Debug)]
|
||||
struct Enum<'a> {
|
||||
enum_ident: &'a Ident,
|
||||
variants: Vec<Container<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> Enum<'a> {
|
||||
/// Construct a new enum representation.
|
||||
///
|
||||
/// `data_enum` is the `syn` representation of the input enum, `ident` is the
|
||||
/// `Identifier` of the enum.
|
||||
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
|
||||
if data_enum.variants.is_empty() {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&data_enum.variants,
|
||||
"Cannot derive FromPyObject for empty enum.",
|
||||
));
|
||||
}
|
||||
let vars = data_enum
|
||||
.variants
|
||||
.iter()
|
||||
.map(|variant| {
|
||||
let attrs = ContainerAttribute::parse_attrs(&variant.attrs)?;
|
||||
let var_ident = &variant.ident;
|
||||
Container::new(
|
||||
&variant.fields,
|
||||
parse_quote!(#ident::#var_ident),
|
||||
attrs,
|
||||
true,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
Ok(Enum {
|
||||
enum_ident: ident,
|
||||
variants: vars,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build derivation body for enums.
|
||||
fn build(&self) -> TokenStream {
|
||||
let mut var_extracts = Vec::new();
|
||||
let mut error_names = String::new();
|
||||
for (i, var) in self.variants.iter().enumerate() {
|
||||
let struct_derive = var.build();
|
||||
let ext = quote!(
|
||||
let maybe_ret = || -> ::pyo3::PyResult<Self> {
|
||||
#struct_derive
|
||||
}();
|
||||
if maybe_ret.is_ok() {
|
||||
return maybe_ret
|
||||
}
|
||||
);
|
||||
|
||||
var_extracts.push(ext);
|
||||
error_names.push_str(&var.err_name);
|
||||
if i < self.variants.len() - 1 {
|
||||
error_names.push_str(", ");
|
||||
}
|
||||
}
|
||||
quote!(
|
||||
#(#var_extracts)*
|
||||
let type_name = obj.get_type().name();
|
||||
let from = obj
|
||||
.repr()
|
||||
.map(|s| format!("{} ({})", s.to_string_lossy(), type_name))
|
||||
.unwrap_or_else(|_| type_name.to_string());
|
||||
let err_msg = format!("Can't convert {} to {}", from, #error_names);
|
||||
Err(::pyo3::exceptions::PyTypeError::py_err(err_msg))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Container Style
|
||||
///
|
||||
/// Covers Structs, Tuplestructs and corresponding Newtypes.
|
||||
#[derive(Debug)]
|
||||
enum ContainerType<'a> {
|
||||
/// Struct Container, e.g. `struct Foo { a: String }`
|
||||
///
|
||||
/// Variant contains the list of field identifiers and the corresponding extraction call.
|
||||
Struct(Vec<(&'a Ident, FieldAttribute)>),
|
||||
/// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
|
||||
///
|
||||
/// The field specified by the identifier is extracted directly from the object.
|
||||
StructNewtype(&'a Ident),
|
||||
/// Tuple struct, e.g. `struct Foo(String)`.
|
||||
///
|
||||
/// Fields are extracted from a tuple.
|
||||
Tuple(usize),
|
||||
/// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
|
||||
///
|
||||
/// The wrapped field is directly extracted from the object.
|
||||
TupleNewtype,
|
||||
}
|
||||
|
||||
/// Data container
|
||||
///
|
||||
/// Either describes a struct or an enum variant.
|
||||
#[derive(Debug)]
|
||||
struct Container<'a> {
|
||||
path: syn::Path,
|
||||
ty: ContainerType<'a>,
|
||||
err_name: String,
|
||||
is_enum_variant: bool,
|
||||
}
|
||||
|
||||
impl<'a> Container<'a> {
|
||||
/// Construct a container based on fields, identifier and attributes.
|
||||
///
|
||||
/// Fails if the variant has no fields or incompatible attributes.
|
||||
fn new(
|
||||
fields: &'a Fields,
|
||||
path: syn::Path,
|
||||
attrs: Vec<ContainerAttribute>,
|
||||
is_enum_variant: bool,
|
||||
) -> Result<Self> {
|
||||
let transparent = attrs.iter().any(ContainerAttribute::transparent);
|
||||
if transparent {
|
||||
Self::check_transparent_len(fields)?;
|
||||
}
|
||||
let style = match (fields, transparent) {
|
||||
(Fields::Unnamed(_), true) => ContainerType::TupleNewtype,
|
||||
(Fields::Unnamed(unnamed), false) => ContainerType::Tuple(unnamed.unnamed.len()),
|
||||
(Fields::Named(named), true) => {
|
||||
let field = named
|
||||
.named
|
||||
.iter()
|
||||
.next()
|
||||
.expect("Check for len 1 is done above");
|
||||
let ident = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("Named fields should have identifiers");
|
||||
ContainerType::StructNewtype(ident)
|
||||
}
|
||||
(Fields::Named(named), false) => {
|
||||
let mut fields = Vec::new();
|
||||
for field in named.named.iter() {
|
||||
let ident = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("Named fields should have identifiers");
|
||||
let attr = FieldAttribute::parse_attrs(&field.attrs)?
|
||||
.unwrap_or_else(|| FieldAttribute::Ident(parse_quote!(getattr)));
|
||||
fields.push((ident, attr))
|
||||
}
|
||||
ContainerType::Struct(fields)
|
||||
}
|
||||
(Fields::Unit, _) => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&fields,
|
||||
"Cannot derive FromPyObject for Unit structs and variants",
|
||||
))
|
||||
}
|
||||
};
|
||||
let err_name = attrs
|
||||
.iter()
|
||||
.find_map(|a| a.annotation())
|
||||
.cloned()
|
||||
.unwrap_or_else(|| path.segments.last().unwrap().ident.to_string());
|
||||
|
||||
let v = Container {
|
||||
path,
|
||||
ty: style,
|
||||
err_name,
|
||||
is_enum_variant,
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn verify_struct_container_attrs(attrs: &'a [ContainerAttribute]) -> Result<()> {
|
||||
for attr in attrs {
|
||||
match attr {
|
||||
ContainerAttribute::Transparent => continue,
|
||||
ContainerAttribute::ErrorAnnotation(_) => {
|
||||
return Err(syn::Error::new(
|
||||
Span::call_site(),
|
||||
"Annotating error messages for structs is \
|
||||
not supported. Remove the annotation attribute.",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build derivation body for a struct.
|
||||
fn build(&self) -> TokenStream {
|
||||
match &self.ty {
|
||||
ContainerType::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)),
|
||||
ContainerType::TupleNewtype => self.build_newtype_struct(None),
|
||||
ContainerType::Tuple(len) => self.build_tuple_struct(*len),
|
||||
ContainerType::Struct(tups) => self.build_struct(tups),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
|
||||
let self_ty = &self.path;
|
||||
if let Some(ident) = field_ident {
|
||||
quote!(
|
||||
Ok(#self_ty{#ident: obj.extract()?})
|
||||
)
|
||||
} else {
|
||||
quote!(Ok(#self_ty(obj.extract()?)))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_tuple_struct(&self, len: usize) -> TokenStream {
|
||||
let self_ty = &self.path;
|
||||
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
|
||||
for i in 0..len {
|
||||
fields.push(quote!(slice[#i].extract()?));
|
||||
}
|
||||
let msg = if self.is_enum_variant {
|
||||
quote!(format!(
|
||||
"Expected tuple of length {}, but got length {}.",
|
||||
#len,
|
||||
s.len()
|
||||
))
|
||||
} else {
|
||||
quote!("")
|
||||
};
|
||||
quote!(
|
||||
let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?;
|
||||
if s.len() != #len {
|
||||
return Err(::pyo3::exceptions::PyValueError::py_err(#msg))
|
||||
}
|
||||
let slice = s.as_slice();
|
||||
Ok(#self_ty(#fields))
|
||||
)
|
||||
}
|
||||
|
||||
fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream {
|
||||
let self_ty = &self.path;
|
||||
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
|
||||
for (ident, attr) in tups {
|
||||
let ext_fn = match attr {
|
||||
FieldAttribute::IdentWithArg(expr) => quote!(#expr),
|
||||
FieldAttribute::Ident(meth) => {
|
||||
let arg = ident.to_string();
|
||||
quote!(#meth(#arg))
|
||||
}
|
||||
};
|
||||
fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
|
||||
}
|
||||
quote!(Ok(#self_ty{#fields}))
|
||||
}
|
||||
|
||||
fn check_transparent_len(fields: &Fields) -> Result<()> {
|
||||
if fields.len() != 1 {
|
||||
return Err(syn::Error::new_spanned(
|
||||
fields,
|
||||
"Transparent structs and variants can only have 1 field",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Attributes for deriving FromPyObject scoped on containers.
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
enum ContainerAttribute {
|
||||
/// Treat the Container as a Wrapper, directly extract its fields from the input object.
|
||||
Transparent,
|
||||
/// Change the name of an enum variant in the generated error message.
|
||||
ErrorAnnotation(String),
|
||||
}
|
||||
|
||||
impl ContainerAttribute {
|
||||
/// Return whether this attribute is `Transparent`
|
||||
fn transparent(&self) -> bool {
|
||||
match self {
|
||||
ContainerAttribute::Transparent => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience method to access `ErrorAnnotation`.
|
||||
fn annotation(&self) -> Option<&String> {
|
||||
match self {
|
||||
ContainerAttribute::ErrorAnnotation(s) => Some(s),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse valid container arguments
|
||||
///
|
||||
/// Fails if any are invalid.
|
||||
fn parse_attrs(value: &[Attribute]) -> Result<Vec<Self>> {
|
||||
let mut attrs = Vec::new();
|
||||
let list = get_pyo3_meta_list(value)?;
|
||||
for meta in list.nested {
|
||||
if let syn::NestedMeta::Meta(metaitem) = &meta {
|
||||
match metaitem {
|
||||
syn::Meta::Path(p) if p.is_ident("transparent") => {
|
||||
attrs.push(ContainerAttribute::Transparent)
|
||||
}
|
||||
syn::Meta::NameValue(nv) if nv.path.is_ident("annotation") => {
|
||||
if let syn::Lit::Str(s) = &nv.lit {
|
||||
attrs.push(ContainerAttribute::ErrorAnnotation(s.value()))
|
||||
} else {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&nv.lit,
|
||||
"Expected string literal.",
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
} else {
|
||||
return Err(syn::Error::new_spanned(
|
||||
meta,
|
||||
"Unknown container attribute, expected `transparent` or \
|
||||
`annotation(\"err_name\")`",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(attrs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Attributes for deriving FromPyObject scoped on fields.
|
||||
#[derive(Clone, Debug)]
|
||||
enum FieldAttribute {
|
||||
/// How a specific field should be extracted.
|
||||
Ident(Ident),
|
||||
IdentWithArg(ExprCall),
|
||||
}
|
||||
|
||||
impl FieldAttribute {
|
||||
/// Extract the field attribute.
|
||||
///
|
||||
/// Currently fails if more than 1 attribute is passed in `pyo3`
|
||||
fn parse_attrs(attrs: &[Attribute]) -> Result<Option<Self>> {
|
||||
let list = get_pyo3_meta_list(attrs)?;
|
||||
if list.nested.len() > 1 {
|
||||
return Err(syn::Error::new_spanned(
|
||||
list,
|
||||
"Only one of `item`, `attribute` can be provided, possibly as \
|
||||
a key-value pair: `attribute = \"name\"`.",
|
||||
));
|
||||
}
|
||||
let meta = if let Some(attr) = list.nested.first() {
|
||||
attr
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
if let syn::NestedMeta::Meta(metaitem) = meta {
|
||||
let path = metaitem.path();
|
||||
let ident = Self::check_valid_ident(path)?;
|
||||
match metaitem {
|
||||
syn::Meta::NameValue(nv) => Self::get_ident_with_arg(ident, &nv.lit).map(Some),
|
||||
syn::Meta::Path(_) => Ok(Some(FieldAttribute::Ident(parse_quote!(#ident)))),
|
||||
_ => Err(syn::Error::new_spanned(
|
||||
metaitem,
|
||||
"`item` or `attribute` need to be passed alone or as key-value \
|
||||
pairs, e.g. `attribute = \"name\"`.",
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(meta, "Unexpected literal."))
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify the attribute path and return it if it is valid.
|
||||
fn check_valid_ident(path: &syn::Path) -> Result<Ident> {
|
||||
if path.is_ident("item") {
|
||||
Ok(parse_quote!(get_item))
|
||||
} else if path.is_ident("attribute") {
|
||||
Ok(parse_quote!(getattr))
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
path,
|
||||
"Expected `item` or `attribute`",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to build `IdentWithArg` based on identifier and literal.
|
||||
fn get_ident_with_arg(ident: Ident, lit: &syn::Lit) -> Result<Self> {
|
||||
if ident == "getattr" {
|
||||
if let syn::Lit::Str(s) = lit {
|
||||
return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#s))));
|
||||
} else {
|
||||
return Err(syn::Error::new_spanned(lit, "Expected string literal."));
|
||||
}
|
||||
}
|
||||
if ident == "get_item" {
|
||||
return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#lit))));
|
||||
}
|
||||
|
||||
// path is already checked in the `parse_attrs` loop, returning the error here anyways.
|
||||
Err(syn::Error::new_spanned(
|
||||
ident,
|
||||
"Expected `item` or `attribute`.",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract pyo3 metalist, flattens multiple lists into a single one.
|
||||
fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result<syn::MetaList> {
|
||||
let mut list: Punctuated<syn::NestedMeta, syn::Token![,]> = Punctuated::new();
|
||||
for value in attrs {
|
||||
match value.parse_meta()? {
|
||||
syn::Meta::List(ml) if value.path.is_ident("pyo3") => {
|
||||
for meta in ml.nested {
|
||||
list.push(meta);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
value,
|
||||
"Expected `pyo3()` attribute.",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(syn::MetaList {
|
||||
path: parse_quote!(pyo3),
|
||||
paren_token: syn::token::Paren::default(),
|
||||
nested: list,
|
||||
})
|
||||
}
|
||||
|
||||
fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
|
||||
let lifetimes = generics.lifetimes().collect::<Vec<_>>();
|
||||
if lifetimes.len() > 1 {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&generics,
|
||||
"FromPyObject can only be derived with at most one lifetime parameter.",
|
||||
));
|
||||
}
|
||||
Ok(lifetimes.into_iter().next())
|
||||
}
|
||||
|
||||
/// Derive FromPyObject for enums and structs.
|
||||
///
|
||||
/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
|
||||
/// * At least one field, in case of `#[transparent]`, exactly one field
|
||||
/// * At least one variant for enums.
|
||||
/// * Fields of input structs and enums must implement `FromPyObject`
|
||||
/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
|
||||
/// adds `T: FromPyObject` on the derived implementation.
|
||||
pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
|
||||
let mut trait_generics = tokens.generics.clone();
|
||||
let generics = &tokens.generics;
|
||||
let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
|
||||
lt.clone()
|
||||
} else {
|
||||
trait_generics.params.push(parse_quote!('source));
|
||||
parse_quote!('source)
|
||||
};
|
||||
let mut where_clause: syn::WhereClause = parse_quote!(where);
|
||||
for param in generics.type_params() {
|
||||
let gen_ident = ¶m.ident;
|
||||
where_clause
|
||||
.predicates
|
||||
.push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
|
||||
}
|
||||
let derives = match &tokens.data {
|
||||
syn::Data::Enum(en) => {
|
||||
let en = Enum::new(en, &tokens.ident)?;
|
||||
en.build()
|
||||
}
|
||||
syn::Data::Struct(st) => {
|
||||
let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?;
|
||||
Container::verify_struct_container_attrs(&attrs)?;
|
||||
let ident = &tokens.ident;
|
||||
let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?;
|
||||
st.build()
|
||||
}
|
||||
_ => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
tokens,
|
||||
"FromPyObject can only be derived for structs and enums.",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let ident = &tokens.ident;
|
||||
Ok(quote!(
|
||||
#[automatically_derived]
|
||||
impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause {
|
||||
fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult<Self> {
|
||||
#derives
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
|
@ -1,428 +0,0 @@
|
|||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token::Paren;
|
||||
use syn::{
|
||||
parse_quote, Attribute, DataEnum, DeriveInput, Expr, ExprCall, Fields, Ident, PatTuple, Result,
|
||||
Variant,
|
||||
};
|
||||
|
||||
/// Describes derivation input of an enum.
|
||||
#[derive(Debug)]
|
||||
struct Enum<'a> {
|
||||
enum_ident: &'a Ident,
|
||||
vars: Vec<Container<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> Enum<'a> {
|
||||
/// Construct a new enum representation.
|
||||
///
|
||||
/// `data_enum` is the `syn` representation of the input enum, `ident` is the
|
||||
/// `Identifier` of the enum.
|
||||
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
|
||||
if data_enum.variants.is_empty() {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&data_enum.variants,
|
||||
"Cannot derive FromPyObject for empty enum.",
|
||||
));
|
||||
}
|
||||
let vars = data_enum
|
||||
.variants
|
||||
.iter()
|
||||
.map(Container::from_variant)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
Ok(Enum {
|
||||
enum_ident: ident,
|
||||
vars,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build derivation body for enums.
|
||||
fn derive_enum(&self) -> TokenStream {
|
||||
let mut var_extracts = Vec::new();
|
||||
let mut error_names = String::new();
|
||||
for (i, var) in self.vars.iter().enumerate() {
|
||||
let ext = match &var.style {
|
||||
Style::Struct(tups) => self.build_struct_variant(tups, var.ident),
|
||||
Style::StructNewtype(ident) => {
|
||||
self.build_transparent_variant(var.ident, Some(ident))
|
||||
}
|
||||
Style::Tuple(len) => self.build_tuple_variant(var.ident, *len),
|
||||
Style::TupleNewtype => self.build_transparent_variant(var.ident, None),
|
||||
};
|
||||
var_extracts.push(ext);
|
||||
error_names.push_str(&var.err_name);
|
||||
if i < self.vars.len() - 1 {
|
||||
error_names.push_str(", ");
|
||||
}
|
||||
}
|
||||
quote!(
|
||||
#(#var_extracts)*
|
||||
let type_name = obj.get_type().name();
|
||||
let from = obj
|
||||
.repr()
|
||||
.map(|s| format!("{} ({})", s.to_string_lossy(), type_name))
|
||||
.unwrap_or_else(|_| type_name.to_string());
|
||||
let err_msg = format!("Can't convert {} to {}", from, #error_names);
|
||||
Err(::pyo3::exceptions::PyTypeError::py_err(err_msg))
|
||||
)
|
||||
}
|
||||
|
||||
/// Build match for tuple struct variant.
|
||||
fn build_tuple_variant(&self, var_ident: &Ident, len: usize) -> TokenStream {
|
||||
let enum_ident = self.enum_ident;
|
||||
let mut ext: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
|
||||
let mut fields: Punctuated<Ident, syn::Token![,]> = Punctuated::new();
|
||||
let mut field_pats = PatTuple {
|
||||
attrs: vec![],
|
||||
paren_token: Paren::default(),
|
||||
elems: Default::default(),
|
||||
};
|
||||
for i in 0..len {
|
||||
ext.push(parse_quote!(slice[#i].extract()));
|
||||
let ident = Ident::new(&format!("_field{}", i), Span::call_site());
|
||||
field_pats.elems.push(parse_quote!(Ok(#ident)));
|
||||
fields.push(ident);
|
||||
}
|
||||
|
||||
quote!(
|
||||
match <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj) {
|
||||
Ok(s) => {
|
||||
if s.len() == #len {
|
||||
let slice = s.as_slice();
|
||||
if let (#field_pats) = (#ext) {
|
||||
return Ok(#enum_ident::#var_ident(#fields))
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(_) => {}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Build match for transparent enum variants.
|
||||
fn build_transparent_variant(
|
||||
&self,
|
||||
var_ident: &Ident,
|
||||
field_ident: Option<&Ident>,
|
||||
) -> TokenStream {
|
||||
let enum_ident = self.enum_ident;
|
||||
if let Some(ident) = field_ident {
|
||||
quote!(
|
||||
if let Ok(#ident) = obj.extract() {
|
||||
return Ok(#enum_ident::#var_ident{#ident})
|
||||
}
|
||||
)
|
||||
} else {
|
||||
quote!(
|
||||
if let Ok(inner) = obj.extract() {
|
||||
return Ok(#enum_ident::#var_ident(inner))
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build match for struct variant with named fields.
|
||||
fn build_struct_variant(
|
||||
&self,
|
||||
tups: &[(&'a Ident, ExprCall)],
|
||||
var_ident: &Ident,
|
||||
) -> TokenStream {
|
||||
let enum_ident = self.enum_ident;
|
||||
let mut field_pats = PatTuple {
|
||||
attrs: vec![],
|
||||
paren_token: Paren::default(),
|
||||
elems: Default::default(),
|
||||
};
|
||||
let mut fields: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
|
||||
let mut ext: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
|
||||
for (ident, ext_fn) in tups {
|
||||
field_pats.elems.push(parse_quote!(Ok(#ident)));
|
||||
fields.push(parse_quote!(#ident));
|
||||
ext.push(parse_quote!(obj.#ext_fn.and_then(|o| o.extract())));
|
||||
}
|
||||
quote!(if let #field_pats = #ext {
|
||||
return Ok(#enum_ident::#var_ident{#fields});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Container Style
|
||||
///
|
||||
/// Covers Structs, Tuplestructs and corresponding Newtypes.
|
||||
#[derive(Clone, Debug)]
|
||||
enum Style<'a> {
|
||||
/// Struct Container, e.g. `struct Foo { a: String }`
|
||||
///
|
||||
/// Variant contains the list of field identifiers and the corresponding extraction call.
|
||||
Struct(Vec<(&'a Ident, ExprCall)>),
|
||||
/// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
|
||||
///
|
||||
/// The field specified by the identifier is extracted directly from the object.
|
||||
StructNewtype(&'a Ident),
|
||||
/// Tuple struct, e.g. `struct Foo(String)`.
|
||||
///
|
||||
/// Fields are extracted from a tuple.
|
||||
Tuple(usize),
|
||||
/// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
|
||||
///
|
||||
/// The wrapped field is directly extracted from the object.
|
||||
TupleNewtype,
|
||||
}
|
||||
|
||||
/// Data container
|
||||
///
|
||||
/// Either describes a struct or an enum variant.
|
||||
#[derive(Debug)]
|
||||
struct Container<'a> {
|
||||
ident: &'a Ident,
|
||||
style: Style<'a>,
|
||||
err_name: String,
|
||||
}
|
||||
|
||||
impl<'a> Container<'a> {
|
||||
/// Construct a container from an enum Variant.
|
||||
///
|
||||
/// Fails if the variant has no fields or incompatible attributes.
|
||||
fn from_variant(var: &'a Variant) -> Result<Self> {
|
||||
Self::new(&var.fields, &var.ident, &var.attrs)
|
||||
}
|
||||
|
||||
/// Construct a container based on fields, identifier and attributes.
|
||||
///
|
||||
/// Fails if the variant has no fields or incompatible attributes.
|
||||
fn new(fields: &'a Fields, ident: &'a Ident, attrs: &'a [Attribute]) -> Result<Self> {
|
||||
let transparent = attrs.iter().any(|a| a.path.is_ident("transparent"));
|
||||
if transparent {
|
||||
Self::check_transparent_len(fields)?;
|
||||
}
|
||||
let style = match fields {
|
||||
Fields::Unnamed(unnamed) => {
|
||||
if transparent {
|
||||
Style::TupleNewtype
|
||||
} else {
|
||||
Style::Tuple(unnamed.unnamed.len())
|
||||
}
|
||||
}
|
||||
Fields::Named(named) => {
|
||||
if transparent {
|
||||
let field = named
|
||||
.named
|
||||
.iter()
|
||||
.next()
|
||||
.expect("Check for len 1 is done above");
|
||||
let ident = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("Named fields should have identifiers");
|
||||
Style::StructNewtype(ident)
|
||||
} else {
|
||||
let mut fields = Vec::new();
|
||||
for field in named.named.iter() {
|
||||
let ident = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("Named fields should have identifiers");
|
||||
fields.push((ident, ext_fn(&field.attrs, ident)?))
|
||||
}
|
||||
Style::Struct(fields)
|
||||
}
|
||||
}
|
||||
Fields::Unit => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&fields,
|
||||
"Cannot derive FromPyObject for Unit structs and variants",
|
||||
))
|
||||
}
|
||||
};
|
||||
let err_name = maybe_renamed_err(&attrs)?
|
||||
.map(|s| s.value())
|
||||
.unwrap_or_else(|| ident.to_string());
|
||||
|
||||
let v = Container {
|
||||
ident: &ident,
|
||||
style,
|
||||
err_name,
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
/// Build derivation body for a struct.
|
||||
fn derive_struct(&self) -> TokenStream {
|
||||
match &self.style {
|
||||
Style::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)),
|
||||
Style::TupleNewtype => self.build_newtype_struct(None),
|
||||
Style::Tuple(len) => self.build_tuple_struct(*len),
|
||||
Style::Struct(tups) => self.build_struct(tups),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
|
||||
if let Some(ident) = field_ident {
|
||||
quote!(
|
||||
Ok(Self{#ident: obj.extract()?})
|
||||
)
|
||||
} else {
|
||||
quote!(Ok(Self(obj.extract()?)))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_tuple_struct(&self, len: usize) -> TokenStream {
|
||||
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
|
||||
for i in 0..len {
|
||||
fields.push(quote!(slice[#i].extract()?));
|
||||
}
|
||||
quote!(
|
||||
let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?;
|
||||
let seq_len = s.len();
|
||||
if seq_len != #len {
|
||||
let msg = format!(
|
||||
"Expected tuple of length {}, but got length {}.",
|
||||
#len,
|
||||
seq_len
|
||||
);
|
||||
return Err(::pyo3::exceptions::PyValueError::py_err(msg))
|
||||
}
|
||||
let slice = s.as_slice();
|
||||
Ok(Self(#fields))
|
||||
)
|
||||
}
|
||||
|
||||
fn build_struct(&self, tups: &[(&Ident, syn::ExprCall)]) -> TokenStream {
|
||||
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
|
||||
for (ident, ext_fn) in tups {
|
||||
fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
|
||||
}
|
||||
quote!(Ok(Self{#fields}))
|
||||
}
|
||||
|
||||
fn check_transparent_len(fields: &Fields) -> Result<()> {
|
||||
if fields.len() != 1 {
|
||||
return Err(syn::Error::new_spanned(
|
||||
fields,
|
||||
"Transparent structs and variants can only have 1 field",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the extraction function that's called on the input object.
|
||||
///
|
||||
/// Valid arguments are `get_item`, `get_attr` which are called with the
|
||||
/// stringified field identifier or a function call on `PyAny`, e.g. `get_attr("attr")`
|
||||
fn ext_fn(attrs: &[Attribute], field_ident: &Ident) -> Result<syn::ExprCall> {
|
||||
let attr = if let Some(attr) = attrs.iter().find(|a| a.path.is_ident("extract")) {
|
||||
attr
|
||||
} else {
|
||||
return Ok(parse_quote!(getattr(stringify!(#field_ident))));
|
||||
};
|
||||
if let Ok(ident) = attr.parse_args::<Ident>() {
|
||||
if ident != "getattr" && ident != "get_item" {
|
||||
Err(syn::Error::new_spanned(
|
||||
ident,
|
||||
"Only get_item and getattr are valid for extraction.",
|
||||
))
|
||||
} else {
|
||||
let arg = field_ident.to_string();
|
||||
Ok(parse_quote!(#ident(#arg)))
|
||||
}
|
||||
} else if let Ok(call) = attr.parse_args() {
|
||||
Ok(call)
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
attr,
|
||||
"Only get_item and getattr are valid for extraction,\
|
||||
both can be passed with or without an argument, e.g. \
|
||||
#[extract(getattr(\"attr\")] and #[extract(getattr)]",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the name of the variant for the error message if no variants match.
|
||||
fn maybe_renamed_err(attrs: &[syn::Attribute]) -> Result<Option<syn::LitStr>> {
|
||||
for attr in attrs {
|
||||
if !attr.path.is_ident("rename_err") {
|
||||
continue;
|
||||
}
|
||||
let attr = attr.parse_meta()?;
|
||||
if let syn::Meta::NameValue(nv) = &attr {
|
||||
match &nv.lit {
|
||||
syn::Lit::Str(s) => {
|
||||
return Ok(Some(s.clone()));
|
||||
}
|
||||
_ => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
attr,
|
||||
"rename_err attribute must be string literal: #[rename_err=\"Name\"]",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
|
||||
let lifetimes = generics.lifetimes().collect::<Vec<_>>();
|
||||
if lifetimes.len() > 1 {
|
||||
return Err(syn::Error::new_spanned(
|
||||
&generics,
|
||||
"Only a single lifetime parameter can be specified.",
|
||||
));
|
||||
}
|
||||
Ok(lifetimes.into_iter().next())
|
||||
}
|
||||
|
||||
/// Derive FromPyObject for enums and structs.
|
||||
///
|
||||
/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
|
||||
/// * At least one field, in case of `#[transparent]`, exactly one field
|
||||
/// * At least one variant for enums.
|
||||
/// * Fields of input structs and enums must implement `FromPyObject`
|
||||
/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
|
||||
/// adds `T: FromPyObject` on the derived implementation.
|
||||
pub fn build_derive_from_pyobject(tokens: &mut DeriveInput) -> Result<TokenStream> {
|
||||
let mut trait_generics = tokens.generics.clone();
|
||||
let generics = &tokens.generics;
|
||||
let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
|
||||
lt.clone()
|
||||
} else {
|
||||
trait_generics.params.push(parse_quote!('source));
|
||||
parse_quote!('source)
|
||||
};
|
||||
let mut where_clause: syn::WhereClause = parse_quote!(where);
|
||||
for param in generics.type_params() {
|
||||
let gen_ident = ¶m.ident;
|
||||
where_clause
|
||||
.predicates
|
||||
.push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
|
||||
}
|
||||
let derives = match &tokens.data {
|
||||
syn::Data::Enum(en) => {
|
||||
let en = Enum::new(en, &tokens.ident)?;
|
||||
en.derive_enum()
|
||||
}
|
||||
syn::Data::Struct(st) => {
|
||||
let st = Container::new(&st.fields, &tokens.ident, &tokens.attrs)?;
|
||||
st.derive_struct()
|
||||
}
|
||||
_ => {
|
||||
return Err(syn::Error::new_spanned(
|
||||
tokens,
|
||||
"FromPyObject can only be derived for structs and enums.",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let ident = &tokens.ident;
|
||||
Ok(quote!(
|
||||
#[automatically_derived]
|
||||
impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause {
|
||||
fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult<Self> {
|
||||
#derives
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
|
@ -4,7 +4,7 @@
|
|||
#![recursion_limit = "1024"]
|
||||
|
||||
mod defs;
|
||||
mod frompy;
|
||||
mod from_pyobject;
|
||||
mod konst;
|
||||
mod method;
|
||||
mod module;
|
||||
|
@ -16,7 +16,7 @@ mod pymethod;
|
|||
mod pyproto;
|
||||
mod utils;
|
||||
|
||||
pub use frompy::build_derive_from_pyobject;
|
||||
pub use from_pyobject::build_derive_from_pyobject;
|
||||
pub use module::{add_fn_to_module, process_functions_in_module, py_init};
|
||||
pub use pyclass::{build_py_class, PyClassArgs};
|
||||
pub use pyfunction::{build_py_function, PyFunctionAttr};
|
||||
|
|
|
@ -92,10 +92,10 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
|
|||
.into()
|
||||
}
|
||||
|
||||
#[proc_macro_derive(FromPyObject, attributes(transparent, extract, rename_err))]
|
||||
#[proc_macro_derive(FromPyObject, attributes(pyo3, extract))]
|
||||
pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
|
||||
let mut ast = parse_macro_input!(item as syn::DeriveInput);
|
||||
let expanded = build_derive_from_pyobject(&mut ast).unwrap_or_else(|e| e.to_compile_error());
|
||||
let ast = parse_macro_input!(item as syn::DeriveInput);
|
||||
let expanded = build_derive_from_pyobject(&ast).unwrap_or_else(|e| e.to_compile_error());
|
||||
quote!(
|
||||
#expanded
|
||||
)
|
||||
|
|
|
@ -8,11 +8,11 @@ mod common;
|
|||
|
||||
#[derive(Debug, FromPyObject)]
|
||||
pub struct A<'a> {
|
||||
#[extract(getattr)]
|
||||
#[pyo3(attribute)]
|
||||
s: String,
|
||||
#[extract(get_item)]
|
||||
#[pyo3(item)]
|
||||
t: &'a PyString,
|
||||
#[extract(getattr("foo"))]
|
||||
#[pyo3(attribute = "foo")]
|
||||
p: &'a PyAny,
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ fn test_named_fields_struct() {
|
|||
}
|
||||
|
||||
#[derive(Debug, FromPyObject)]
|
||||
#[transparent]
|
||||
#[pyo3(transparent)]
|
||||
pub struct B {
|
||||
test: String,
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ fn test_transparent_named_field_struct() {
|
|||
}
|
||||
|
||||
#[derive(Debug, FromPyObject)]
|
||||
#[transparent]
|
||||
#[pyo3(transparent)]
|
||||
pub struct D<T> {
|
||||
test: T,
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ fn test_generic_named_fields_struct() {
|
|||
|
||||
#[derive(Debug, FromPyObject)]
|
||||
pub struct C {
|
||||
#[extract(getattr("test"))]
|
||||
#[pyo3(attribute = "test")]
|
||||
test: String,
|
||||
}
|
||||
|
||||
|
@ -155,17 +155,18 @@ fn test_tuple_struct() {
|
|||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
#[pyo3(transparent)]
|
||||
pub struct TransparentTuple(String);
|
||||
|
||||
#[test]
|
||||
fn test_transparent_tuple_struct() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let tup = PyTuple::new(py, &[1.into_py(py)]);
|
||||
let tup = TransparentTuple::extract(tup.as_ref());
|
||||
let tup: PyObject = 1.into_py(py);
|
||||
let tup = TransparentTuple::extract(tup.as_ref(py));
|
||||
assert!(tup.is_err());
|
||||
let tup = PyTuple::new(py, &["test".into_py(py)]);
|
||||
let tup = TransparentTuple::extract(tup.as_ref())
|
||||
let test = "test".into_py(py);
|
||||
let tup = TransparentTuple::extract(test.as_ref(py))
|
||||
.expect("Failed to extract TransparentTuple from PyTuple");
|
||||
assert_eq!(tup.0, "test");
|
||||
}
|
||||
|
@ -176,25 +177,25 @@ pub enum Foo<'a> {
|
|||
StructVar {
|
||||
test: &'a PyString,
|
||||
},
|
||||
#[transparent]
|
||||
#[pyo3(transparent)]
|
||||
TransparentTuple(usize),
|
||||
#[transparent]
|
||||
#[pyo3(transparent)]
|
||||
TransparentStructVar {
|
||||
a: Option<String>,
|
||||
},
|
||||
StructVarGetAttrArg {
|
||||
#[extract(getattr("bla"))]
|
||||
#[pyo3(attribute = "bla")]
|
||||
a: bool,
|
||||
},
|
||||
StructWithGetItem {
|
||||
#[extract(get_item)]
|
||||
#[pyo3(item)]
|
||||
a: String,
|
||||
},
|
||||
StructWithGetItemArg {
|
||||
#[extract(get_item("foo"))]
|
||||
#[pyo3(item = "foo")]
|
||||
a: String,
|
||||
},
|
||||
#[transparent]
|
||||
#[pyo3(transparent)]
|
||||
CatchAll(&'a PyAny),
|
||||
}
|
||||
|
||||
|
@ -279,11 +280,11 @@ fn test_enum() {
|
|||
|
||||
#[derive(FromPyObject)]
|
||||
pub enum Bar {
|
||||
#[rename_err = "str"]
|
||||
#[pyo3(annotation = "str")]
|
||||
A(String),
|
||||
#[rename_err = "uint"]
|
||||
#[pyo3(annotation = "uint")]
|
||||
B(usize),
|
||||
#[rename_err = "int"]
|
||||
#[pyo3(annotation = "int", transparent)]
|
||||
C(isize),
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue