'#[derive(FromPyObject)]` changes suggested by @davidwhewitt.

This commit is contained in:
Sebastian Pütz 2020-08-26 22:13:14 +02:00
parent 7168309464
commit 60fe4925f5
6 changed files with 522 additions and 452 deletions

View File

@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Add `Py::as_ref` and `Py::into_ref`. [#1098](https://github.com/PyO3/pyo3/pull/1098) - Add `Py::as_ref` and `Py::into_ref`. [#1098](https://github.com/PyO3/pyo3/pull/1098)
- Add optional implementations of `ToPyObject`, `IntoPy`, and `FromPyObject` for [hashbrown](https://crates.io/crates/hashbrown)'s `HashMap` and `HashSet` types. The `hashbrown` feature must be enabled for these implementations to be built. [#1114](https://github.com/PyO3/pyo3/pull/1114/) - Add optional implementations of `ToPyObject`, `IntoPy`, and `FromPyObject` for [hashbrown](https://crates.io/crates/hashbrown)'s `HashMap` and `HashSet` types. The `hashbrown` feature must be enabled for these implementations to be built. [#1114](https://github.com/PyO3/pyo3/pull/1114/)
- Allow other `Result` types when using `#[pyfunction]`. [#1106](https://github.com/PyO3/pyo3/issues/1106). - Allow other `Result` types when using `#[pyfunction]`. [#1106](https://github.com/PyO3/pyo3/issues/1106).
- Add `#[derive(FromPyObject)]` macro for enums and structs. [#1065](https://github.com/PyO3/pyo3/pull/1065)
### Changed ### Changed
- Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py<T>` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024) - Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py<T>` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024)

View File

@ -0,0 +1,496 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::punctuated::Punctuated;
use syn::{parse_quote, Attribute, DataEnum, DeriveInput, ExprCall, Fields, Ident, Result};
/// Describes derivation input of an enum.
#[derive(Debug)]
struct Enum<'a> {
enum_ident: &'a Ident,
variants: Vec<Container<'a>>,
}
impl<'a> Enum<'a> {
/// Construct a new enum representation.
///
/// `data_enum` is the `syn` representation of the input enum, `ident` is the
/// `Identifier` of the enum.
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
if data_enum.variants.is_empty() {
return Err(syn::Error::new_spanned(
&data_enum.variants,
"Cannot derive FromPyObject for empty enum.",
));
}
let vars = data_enum
.variants
.iter()
.map(|variant| {
let attrs = ContainerAttribute::parse_attrs(&variant.attrs)?;
let var_ident = &variant.ident;
Container::new(
&variant.fields,
parse_quote!(#ident::#var_ident),
attrs,
true,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Enum {
enum_ident: ident,
variants: vars,
})
}
/// Build derivation body for enums.
fn build(&self) -> TokenStream {
let mut var_extracts = Vec::new();
let mut error_names = String::new();
for (i, var) in self.variants.iter().enumerate() {
let struct_derive = var.build();
let ext = quote!(
let maybe_ret = || -> ::pyo3::PyResult<Self> {
#struct_derive
}();
if maybe_ret.is_ok() {
return maybe_ret
}
);
var_extracts.push(ext);
error_names.push_str(&var.err_name);
if i < self.variants.len() - 1 {
error_names.push_str(", ");
}
}
quote!(
#(#var_extracts)*
let type_name = obj.get_type().name();
let from = obj
.repr()
.map(|s| format!("{} ({})", s.to_string_lossy(), type_name))
.unwrap_or_else(|_| type_name.to_string());
let err_msg = format!("Can't convert {} to {}", from, #error_names);
Err(::pyo3::exceptions::PyTypeError::py_err(err_msg))
)
}
}
/// Container Style
///
/// Covers Structs, Tuplestructs and corresponding Newtypes.
#[derive(Debug)]
enum ContainerType<'a> {
/// Struct Container, e.g. `struct Foo { a: String }`
///
/// Variant contains the list of field identifiers and the corresponding extraction call.
Struct(Vec<(&'a Ident, FieldAttribute)>),
/// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
///
/// The field specified by the identifier is extracted directly from the object.
StructNewtype(&'a Ident),
/// Tuple struct, e.g. `struct Foo(String)`.
///
/// Fields are extracted from a tuple.
Tuple(usize),
/// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
///
/// The wrapped field is directly extracted from the object.
TupleNewtype,
}
/// Data container
///
/// Either describes a struct or an enum variant.
#[derive(Debug)]
struct Container<'a> {
path: syn::Path,
ty: ContainerType<'a>,
err_name: String,
is_enum_variant: bool,
}
impl<'a> Container<'a> {
/// Construct a container based on fields, identifier and attributes.
///
/// Fails if the variant has no fields or incompatible attributes.
fn new(
fields: &'a Fields,
path: syn::Path,
attrs: Vec<ContainerAttribute>,
is_enum_variant: bool,
) -> Result<Self> {
let transparent = attrs.iter().any(ContainerAttribute::transparent);
if transparent {
Self::check_transparent_len(fields)?;
}
let style = match (fields, transparent) {
(Fields::Unnamed(_), true) => ContainerType::TupleNewtype,
(Fields::Unnamed(unnamed), false) => ContainerType::Tuple(unnamed.unnamed.len()),
(Fields::Named(named), true) => {
let field = named
.named
.iter()
.next()
.expect("Check for len 1 is done above");
let ident = field
.ident
.as_ref()
.expect("Named fields should have identifiers");
ContainerType::StructNewtype(ident)
}
(Fields::Named(named), false) => {
let mut fields = Vec::new();
for field in named.named.iter() {
let ident = field
.ident
.as_ref()
.expect("Named fields should have identifiers");
let attr = FieldAttribute::parse_attrs(&field.attrs)?
.unwrap_or_else(|| FieldAttribute::Ident(parse_quote!(getattr)));
fields.push((ident, attr))
}
ContainerType::Struct(fields)
}
(Fields::Unit, _) => {
return Err(syn::Error::new_spanned(
&fields,
"Cannot derive FromPyObject for Unit structs and variants",
))
}
};
let err_name = attrs
.iter()
.find_map(|a| a.annotation())
.cloned()
.unwrap_or_else(|| path.segments.last().unwrap().ident.to_string());
let v = Container {
path,
ty: style,
err_name,
is_enum_variant,
};
Ok(v)
}
fn verify_struct_container_attrs(attrs: &'a [ContainerAttribute]) -> Result<()> {
for attr in attrs {
match attr {
ContainerAttribute::Transparent => continue,
ContainerAttribute::ErrorAnnotation(_) => {
return Err(syn::Error::new(
Span::call_site(),
"Annotating error messages for structs is \
not supported. Remove the annotation attribute.",
))
}
}
}
Ok(())
}
/// Build derivation body for a struct.
fn build(&self) -> TokenStream {
match &self.ty {
ContainerType::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)),
ContainerType::TupleNewtype => self.build_newtype_struct(None),
ContainerType::Tuple(len) => self.build_tuple_struct(*len),
ContainerType::Struct(tups) => self.build_struct(tups),
}
}
fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
let self_ty = &self.path;
if let Some(ident) = field_ident {
quote!(
Ok(#self_ty{#ident: obj.extract()?})
)
} else {
quote!(Ok(#self_ty(obj.extract()?)))
}
}
fn build_tuple_struct(&self, len: usize) -> TokenStream {
let self_ty = &self.path;
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for i in 0..len {
fields.push(quote!(slice[#i].extract()?));
}
let msg = if self.is_enum_variant {
quote!(format!(
"Expected tuple of length {}, but got length {}.",
#len,
s.len()
))
} else {
quote!("")
};
quote!(
let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?;
if s.len() != #len {
return Err(::pyo3::exceptions::PyValueError::py_err(#msg))
}
let slice = s.as_slice();
Ok(#self_ty(#fields))
)
}
fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream {
let self_ty = &self.path;
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for (ident, attr) in tups {
let ext_fn = match attr {
FieldAttribute::IdentWithArg(expr) => quote!(#expr),
FieldAttribute::Ident(meth) => {
let arg = ident.to_string();
quote!(#meth(#arg))
}
};
fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
}
quote!(Ok(#self_ty{#fields}))
}
fn check_transparent_len(fields: &Fields) -> Result<()> {
if fields.len() != 1 {
return Err(syn::Error::new_spanned(
fields,
"Transparent structs and variants can only have 1 field",
));
}
Ok(())
}
}
/// Attributes for deriving FromPyObject scoped on containers.
#[derive(Clone, Debug, PartialEq)]
enum ContainerAttribute {
/// Treat the Container as a Wrapper, directly extract its fields from the input object.
Transparent,
/// Change the name of an enum variant in the generated error message.
ErrorAnnotation(String),
}
impl ContainerAttribute {
/// Return whether this attribute is `Transparent`
fn transparent(&self) -> bool {
match self {
ContainerAttribute::Transparent => true,
_ => false,
}
}
/// Convenience method to access `ErrorAnnotation`.
fn annotation(&self) -> Option<&String> {
match self {
ContainerAttribute::ErrorAnnotation(s) => Some(s),
_ => None,
}
}
/// Parse valid container arguments
///
/// Fails if any are invalid.
fn parse_attrs(value: &[Attribute]) -> Result<Vec<Self>> {
let mut attrs = Vec::new();
let list = get_pyo3_meta_list(value)?;
for meta in list.nested {
if let syn::NestedMeta::Meta(metaitem) = &meta {
match metaitem {
syn::Meta::Path(p) if p.is_ident("transparent") => {
attrs.push(ContainerAttribute::Transparent)
}
syn::Meta::NameValue(nv) if nv.path.is_ident("annotation") => {
if let syn::Lit::Str(s) = &nv.lit {
attrs.push(ContainerAttribute::ErrorAnnotation(s.value()))
} else {
return Err(syn::Error::new_spanned(
&nv.lit,
"Expected string literal.",
));
}
}
_ => (),
}
} else {
return Err(syn::Error::new_spanned(
meta,
"Unknown container attribute, expected `transparent` or \
`annotation(\"err_name\")`",
));
}
}
Ok(attrs)
}
}
/// Attributes for deriving FromPyObject scoped on fields.
#[derive(Clone, Debug)]
enum FieldAttribute {
/// How a specific field should be extracted.
Ident(Ident),
IdentWithArg(ExprCall),
}
impl FieldAttribute {
/// Extract the field attribute.
///
/// Currently fails if more than 1 attribute is passed in `pyo3`
fn parse_attrs(attrs: &[Attribute]) -> Result<Option<Self>> {
let list = get_pyo3_meta_list(attrs)?;
if list.nested.len() > 1 {
return Err(syn::Error::new_spanned(
list,
"Only one of `item`, `attribute` can be provided, possibly as \
a key-value pair: `attribute = \"name\"`.",
));
}
let meta = if let Some(attr) = list.nested.first() {
attr
} else {
return Ok(None);
};
if let syn::NestedMeta::Meta(metaitem) = meta {
let path = metaitem.path();
let ident = Self::check_valid_ident(path)?;
match metaitem {
syn::Meta::NameValue(nv) => Self::get_ident_with_arg(ident, &nv.lit).map(Some),
syn::Meta::Path(_) => Ok(Some(FieldAttribute::Ident(parse_quote!(#ident)))),
_ => Err(syn::Error::new_spanned(
metaitem,
"`item` or `attribute` need to be passed alone or as key-value \
pairs, e.g. `attribute = \"name\"`.",
)),
}
} else {
Err(syn::Error::new_spanned(meta, "Unexpected literal."))
}
}
/// Verify the attribute path and return it if it is valid.
fn check_valid_ident(path: &syn::Path) -> Result<Ident> {
if path.is_ident("item") {
Ok(parse_quote!(get_item))
} else if path.is_ident("attribute") {
Ok(parse_quote!(getattr))
} else {
Err(syn::Error::new_spanned(
path,
"Expected `item` or `attribute`",
))
}
}
/// Try to build `IdentWithArg` based on identifier and literal.
fn get_ident_with_arg(ident: Ident, lit: &syn::Lit) -> Result<Self> {
if ident == "getattr" {
if let syn::Lit::Str(s) = lit {
return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#s))));
} else {
return Err(syn::Error::new_spanned(lit, "Expected string literal."));
}
}
if ident == "get_item" {
return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#lit))));
}
// path is already checked in the `parse_attrs` loop, returning the error here anyways.
Err(syn::Error::new_spanned(
ident,
"Expected `item` or `attribute`.",
))
}
}
/// Extract pyo3 metalist, flattens multiple lists into a single one.
fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result<syn::MetaList> {
let mut list: Punctuated<syn::NestedMeta, syn::Token![,]> = Punctuated::new();
for value in attrs {
match value.parse_meta()? {
syn::Meta::List(ml) if value.path.is_ident("pyo3") => {
for meta in ml.nested {
list.push(meta);
}
}
_ => {
return Err(syn::Error::new_spanned(
value,
"Expected `pyo3()` attribute.",
))
}
}
}
Ok(syn::MetaList {
path: parse_quote!(pyo3),
paren_token: syn::token::Paren::default(),
nested: list,
})
}
fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
let lifetimes = generics.lifetimes().collect::<Vec<_>>();
if lifetimes.len() > 1 {
return Err(syn::Error::new_spanned(
&generics,
"FromPyObject can only be derived with at most one lifetime parameter.",
));
}
Ok(lifetimes.into_iter().next())
}
/// Derive FromPyObject for enums and structs.
///
/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
/// * At least one field, in case of `#[transparent]`, exactly one field
/// * At least one variant for enums.
/// * Fields of input structs and enums must implement `FromPyObject`
/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
/// adds `T: FromPyObject` on the derived implementation.
pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
let mut trait_generics = tokens.generics.clone();
let generics = &tokens.generics;
let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
lt.clone()
} else {
trait_generics.params.push(parse_quote!('source));
parse_quote!('source)
};
let mut where_clause: syn::WhereClause = parse_quote!(where);
for param in generics.type_params() {
let gen_ident = &param.ident;
where_clause
.predicates
.push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
}
let derives = match &tokens.data {
syn::Data::Enum(en) => {
let en = Enum::new(en, &tokens.ident)?;
en.build()
}
syn::Data::Struct(st) => {
let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?;
Container::verify_struct_container_attrs(&attrs)?;
let ident = &tokens.ident;
let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?;
st.build()
}
_ => {
return Err(syn::Error::new_spanned(
tokens,
"FromPyObject can only be derived for structs and enums.",
))
}
};
let ident = &tokens.ident;
Ok(quote!(
#[automatically_derived]
impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause {
fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult<Self> {
#derives
}
}
))
}

View File

@ -1,428 +0,0 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Paren;
use syn::{
parse_quote, Attribute, DataEnum, DeriveInput, Expr, ExprCall, Fields, Ident, PatTuple, Result,
Variant,
};
/// Describes derivation input of an enum.
#[derive(Debug)]
struct Enum<'a> {
enum_ident: &'a Ident,
vars: Vec<Container<'a>>,
}
impl<'a> Enum<'a> {
/// Construct a new enum representation.
///
/// `data_enum` is the `syn` representation of the input enum, `ident` is the
/// `Identifier` of the enum.
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
if data_enum.variants.is_empty() {
return Err(syn::Error::new_spanned(
&data_enum.variants,
"Cannot derive FromPyObject for empty enum.",
));
}
let vars = data_enum
.variants
.iter()
.map(Container::from_variant)
.collect::<Result<Vec<_>>>()?;
Ok(Enum {
enum_ident: ident,
vars,
})
}
/// Build derivation body for enums.
fn derive_enum(&self) -> TokenStream {
let mut var_extracts = Vec::new();
let mut error_names = String::new();
for (i, var) in self.vars.iter().enumerate() {
let ext = match &var.style {
Style::Struct(tups) => self.build_struct_variant(tups, var.ident),
Style::StructNewtype(ident) => {
self.build_transparent_variant(var.ident, Some(ident))
}
Style::Tuple(len) => self.build_tuple_variant(var.ident, *len),
Style::TupleNewtype => self.build_transparent_variant(var.ident, None),
};
var_extracts.push(ext);
error_names.push_str(&var.err_name);
if i < self.vars.len() - 1 {
error_names.push_str(", ");
}
}
quote!(
#(#var_extracts)*
let type_name = obj.get_type().name();
let from = obj
.repr()
.map(|s| format!("{} ({})", s.to_string_lossy(), type_name))
.unwrap_or_else(|_| type_name.to_string());
let err_msg = format!("Can't convert {} to {}", from, #error_names);
Err(::pyo3::exceptions::PyTypeError::py_err(err_msg))
)
}
/// Build match for tuple struct variant.
fn build_tuple_variant(&self, var_ident: &Ident, len: usize) -> TokenStream {
let enum_ident = self.enum_ident;
let mut ext: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
let mut fields: Punctuated<Ident, syn::Token![,]> = Punctuated::new();
let mut field_pats = PatTuple {
attrs: vec![],
paren_token: Paren::default(),
elems: Default::default(),
};
for i in 0..len {
ext.push(parse_quote!(slice[#i].extract()));
let ident = Ident::new(&format!("_field{}", i), Span::call_site());
field_pats.elems.push(parse_quote!(Ok(#ident)));
fields.push(ident);
}
quote!(
match <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj) {
Ok(s) => {
if s.len() == #len {
let slice = s.as_slice();
if let (#field_pats) = (#ext) {
return Ok(#enum_ident::#var_ident(#fields))
}
}
},
Err(_) => {}
}
)
}
/// Build match for transparent enum variants.
fn build_transparent_variant(
&self,
var_ident: &Ident,
field_ident: Option<&Ident>,
) -> TokenStream {
let enum_ident = self.enum_ident;
if let Some(ident) = field_ident {
quote!(
if let Ok(#ident) = obj.extract() {
return Ok(#enum_ident::#var_ident{#ident})
}
)
} else {
quote!(
if let Ok(inner) = obj.extract() {
return Ok(#enum_ident::#var_ident(inner))
}
)
}
}
/// Build match for struct variant with named fields.
fn build_struct_variant(
&self,
tups: &[(&'a Ident, ExprCall)],
var_ident: &Ident,
) -> TokenStream {
let enum_ident = self.enum_ident;
let mut field_pats = PatTuple {
attrs: vec![],
paren_token: Paren::default(),
elems: Default::default(),
};
let mut fields: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
let mut ext: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
for (ident, ext_fn) in tups {
field_pats.elems.push(parse_quote!(Ok(#ident)));
fields.push(parse_quote!(#ident));
ext.push(parse_quote!(obj.#ext_fn.and_then(|o| o.extract())));
}
quote!(if let #field_pats = #ext {
return Ok(#enum_ident::#var_ident{#fields});
})
}
}
/// Container Style
///
/// Covers Structs, Tuplestructs and corresponding Newtypes.
#[derive(Clone, Debug)]
enum Style<'a> {
/// Struct Container, e.g. `struct Foo { a: String }`
///
/// Variant contains the list of field identifiers and the corresponding extraction call.
Struct(Vec<(&'a Ident, ExprCall)>),
/// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
///
/// The field specified by the identifier is extracted directly from the object.
StructNewtype(&'a Ident),
/// Tuple struct, e.g. `struct Foo(String)`.
///
/// Fields are extracted from a tuple.
Tuple(usize),
/// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
///
/// The wrapped field is directly extracted from the object.
TupleNewtype,
}
/// Data container
///
/// Either describes a struct or an enum variant.
#[derive(Debug)]
struct Container<'a> {
ident: &'a Ident,
style: Style<'a>,
err_name: String,
}
impl<'a> Container<'a> {
/// Construct a container from an enum Variant.
///
/// Fails if the variant has no fields or incompatible attributes.
fn from_variant(var: &'a Variant) -> Result<Self> {
Self::new(&var.fields, &var.ident, &var.attrs)
}
/// Construct a container based on fields, identifier and attributes.
///
/// Fails if the variant has no fields or incompatible attributes.
fn new(fields: &'a Fields, ident: &'a Ident, attrs: &'a [Attribute]) -> Result<Self> {
let transparent = attrs.iter().any(|a| a.path.is_ident("transparent"));
if transparent {
Self::check_transparent_len(fields)?;
}
let style = match fields {
Fields::Unnamed(unnamed) => {
if transparent {
Style::TupleNewtype
} else {
Style::Tuple(unnamed.unnamed.len())
}
}
Fields::Named(named) => {
if transparent {
let field = named
.named
.iter()
.next()
.expect("Check for len 1 is done above");
let ident = field
.ident
.as_ref()
.expect("Named fields should have identifiers");
Style::StructNewtype(ident)
} else {
let mut fields = Vec::new();
for field in named.named.iter() {
let ident = field
.ident
.as_ref()
.expect("Named fields should have identifiers");
fields.push((ident, ext_fn(&field.attrs, ident)?))
}
Style::Struct(fields)
}
}
Fields::Unit => {
return Err(syn::Error::new_spanned(
&fields,
"Cannot derive FromPyObject for Unit structs and variants",
))
}
};
let err_name = maybe_renamed_err(&attrs)?
.map(|s| s.value())
.unwrap_or_else(|| ident.to_string());
let v = Container {
ident: &ident,
style,
err_name,
};
Ok(v)
}
/// Build derivation body for a struct.
fn derive_struct(&self) -> TokenStream {
match &self.style {
Style::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)),
Style::TupleNewtype => self.build_newtype_struct(None),
Style::Tuple(len) => self.build_tuple_struct(*len),
Style::Struct(tups) => self.build_struct(tups),
}
}
fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
if let Some(ident) = field_ident {
quote!(
Ok(Self{#ident: obj.extract()?})
)
} else {
quote!(Ok(Self(obj.extract()?)))
}
}
fn build_tuple_struct(&self, len: usize) -> TokenStream {
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for i in 0..len {
fields.push(quote!(slice[#i].extract()?));
}
quote!(
let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?;
let seq_len = s.len();
if seq_len != #len {
let msg = format!(
"Expected tuple of length {}, but got length {}.",
#len,
seq_len
);
return Err(::pyo3::exceptions::PyValueError::py_err(msg))
}
let slice = s.as_slice();
Ok(Self(#fields))
)
}
fn build_struct(&self, tups: &[(&Ident, syn::ExprCall)]) -> TokenStream {
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for (ident, ext_fn) in tups {
fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
}
quote!(Ok(Self{#fields}))
}
fn check_transparent_len(fields: &Fields) -> Result<()> {
if fields.len() != 1 {
return Err(syn::Error::new_spanned(
fields,
"Transparent structs and variants can only have 1 field",
));
}
Ok(())
}
}
/// Get the extraction function that's called on the input object.
///
/// Valid arguments are `get_item`, `get_attr` which are called with the
/// stringified field identifier or a function call on `PyAny`, e.g. `get_attr("attr")`
fn ext_fn(attrs: &[Attribute], field_ident: &Ident) -> Result<syn::ExprCall> {
let attr = if let Some(attr) = attrs.iter().find(|a| a.path.is_ident("extract")) {
attr
} else {
return Ok(parse_quote!(getattr(stringify!(#field_ident))));
};
if let Ok(ident) = attr.parse_args::<Ident>() {
if ident != "getattr" && ident != "get_item" {
Err(syn::Error::new_spanned(
ident,
"Only get_item and getattr are valid for extraction.",
))
} else {
let arg = field_ident.to_string();
Ok(parse_quote!(#ident(#arg)))
}
} else if let Ok(call) = attr.parse_args() {
Ok(call)
} else {
Err(syn::Error::new_spanned(
attr,
"Only get_item and getattr are valid for extraction,\
both can be passed with or without an argument, e.g. \
#[extract(getattr(\"attr\")] and #[extract(getattr)]",
))
}
}
/// Returns the name of the variant for the error message if no variants match.
fn maybe_renamed_err(attrs: &[syn::Attribute]) -> Result<Option<syn::LitStr>> {
for attr in attrs {
if !attr.path.is_ident("rename_err") {
continue;
}
let attr = attr.parse_meta()?;
if let syn::Meta::NameValue(nv) = &attr {
match &nv.lit {
syn::Lit::Str(s) => {
return Ok(Some(s.clone()));
}
_ => {
return Err(syn::Error::new_spanned(
attr,
"rename_err attribute must be string literal: #[rename_err=\"Name\"]",
))
}
}
}
}
Ok(None)
}
fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
let lifetimes = generics.lifetimes().collect::<Vec<_>>();
if lifetimes.len() > 1 {
return Err(syn::Error::new_spanned(
&generics,
"Only a single lifetime parameter can be specified.",
));
}
Ok(lifetimes.into_iter().next())
}
/// Derive FromPyObject for enums and structs.
///
/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
/// * At least one field, in case of `#[transparent]`, exactly one field
/// * At least one variant for enums.
/// * Fields of input structs and enums must implement `FromPyObject`
/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
/// adds `T: FromPyObject` on the derived implementation.
pub fn build_derive_from_pyobject(tokens: &mut DeriveInput) -> Result<TokenStream> {
let mut trait_generics = tokens.generics.clone();
let generics = &tokens.generics;
let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
lt.clone()
} else {
trait_generics.params.push(parse_quote!('source));
parse_quote!('source)
};
let mut where_clause: syn::WhereClause = parse_quote!(where);
for param in generics.type_params() {
let gen_ident = &param.ident;
where_clause
.predicates
.push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
}
let derives = match &tokens.data {
syn::Data::Enum(en) => {
let en = Enum::new(en, &tokens.ident)?;
en.derive_enum()
}
syn::Data::Struct(st) => {
let st = Container::new(&st.fields, &tokens.ident, &tokens.attrs)?;
st.derive_struct()
}
_ => {
return Err(syn::Error::new_spanned(
tokens,
"FromPyObject can only be derived for structs and enums.",
))
}
};
let ident = &tokens.ident;
Ok(quote!(
#[automatically_derived]
impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause {
fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult<Self> {
#derives
}
}
))
}

View File

@ -4,7 +4,7 @@
#![recursion_limit = "1024"] #![recursion_limit = "1024"]
mod defs; mod defs;
mod frompy; mod from_pyobject;
mod konst; mod konst;
mod method; mod method;
mod module; mod module;
@ -16,7 +16,7 @@ mod pymethod;
mod pyproto; mod pyproto;
mod utils; mod utils;
pub use frompy::build_derive_from_pyobject; pub use from_pyobject::build_derive_from_pyobject;
pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use module::{add_fn_to_module, process_functions_in_module, py_init};
pub use pyclass::{build_py_class, PyClassArgs}; pub use pyclass::{build_py_class, PyClassArgs};
pub use pyfunction::{build_py_function, PyFunctionAttr}; pub use pyfunction::{build_py_function, PyFunctionAttr};

View File

@ -92,10 +92,10 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
.into() .into()
} }
#[proc_macro_derive(FromPyObject, attributes(transparent, extract, rename_err))] #[proc_macro_derive(FromPyObject, attributes(pyo3, extract))]
pub fn derive_from_py_object(item: TokenStream) -> TokenStream { pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
let mut ast = parse_macro_input!(item as syn::DeriveInput); let ast = parse_macro_input!(item as syn::DeriveInput);
let expanded = build_derive_from_pyobject(&mut ast).unwrap_or_else(|e| e.to_compile_error()); let expanded = build_derive_from_pyobject(&ast).unwrap_or_else(|e| e.to_compile_error());
quote!( quote!(
#expanded #expanded
) )

View File

@ -8,11 +8,11 @@ mod common;
#[derive(Debug, FromPyObject)] #[derive(Debug, FromPyObject)]
pub struct A<'a> { pub struct A<'a> {
#[extract(getattr)] #[pyo3(attribute)]
s: String, s: String,
#[extract(get_item)] #[pyo3(item)]
t: &'a PyString, t: &'a PyString,
#[extract(getattr("foo"))] #[pyo3(attribute = "foo")]
p: &'a PyAny, p: &'a PyAny,
} }
@ -51,7 +51,7 @@ fn test_named_fields_struct() {
} }
#[derive(Debug, FromPyObject)] #[derive(Debug, FromPyObject)]
#[transparent] #[pyo3(transparent)]
pub struct B { pub struct B {
test: String, test: String,
} }
@ -69,7 +69,7 @@ fn test_transparent_named_field_struct() {
} }
#[derive(Debug, FromPyObject)] #[derive(Debug, FromPyObject)]
#[transparent] #[pyo3(transparent)]
pub struct D<T> { pub struct D<T> {
test: T, test: T,
} }
@ -121,7 +121,7 @@ fn test_generic_named_fields_struct() {
#[derive(Debug, FromPyObject)] #[derive(Debug, FromPyObject)]
pub struct C { pub struct C {
#[extract(getattr("test"))] #[pyo3(attribute = "test")]
test: String, test: String,
} }
@ -155,17 +155,18 @@ fn test_tuple_struct() {
} }
#[derive(FromPyObject)] #[derive(FromPyObject)]
#[pyo3(transparent)]
pub struct TransparentTuple(String); pub struct TransparentTuple(String);
#[test] #[test]
fn test_transparent_tuple_struct() { fn test_transparent_tuple_struct() {
let gil = Python::acquire_gil(); let gil = Python::acquire_gil();
let py = gil.python(); let py = gil.python();
let tup = PyTuple::new(py, &[1.into_py(py)]); let tup: PyObject = 1.into_py(py);
let tup = TransparentTuple::extract(tup.as_ref()); let tup = TransparentTuple::extract(tup.as_ref(py));
assert!(tup.is_err()); assert!(tup.is_err());
let tup = PyTuple::new(py, &["test".into_py(py)]); let test = "test".into_py(py);
let tup = TransparentTuple::extract(tup.as_ref()) let tup = TransparentTuple::extract(test.as_ref(py))
.expect("Failed to extract TransparentTuple from PyTuple"); .expect("Failed to extract TransparentTuple from PyTuple");
assert_eq!(tup.0, "test"); assert_eq!(tup.0, "test");
} }
@ -176,25 +177,25 @@ pub enum Foo<'a> {
StructVar { StructVar {
test: &'a PyString, test: &'a PyString,
}, },
#[transparent] #[pyo3(transparent)]
TransparentTuple(usize), TransparentTuple(usize),
#[transparent] #[pyo3(transparent)]
TransparentStructVar { TransparentStructVar {
a: Option<String>, a: Option<String>,
}, },
StructVarGetAttrArg { StructVarGetAttrArg {
#[extract(getattr("bla"))] #[pyo3(attribute = "bla")]
a: bool, a: bool,
}, },
StructWithGetItem { StructWithGetItem {
#[extract(get_item)] #[pyo3(item)]
a: String, a: String,
}, },
StructWithGetItemArg { StructWithGetItemArg {
#[extract(get_item("foo"))] #[pyo3(item = "foo")]
a: String, a: String,
}, },
#[transparent] #[pyo3(transparent)]
CatchAll(&'a PyAny), CatchAll(&'a PyAny),
} }
@ -279,11 +280,11 @@ fn test_enum() {
#[derive(FromPyObject)] #[derive(FromPyObject)]
pub enum Bar { pub enum Bar {
#[rename_err = "str"] #[pyo3(annotation = "str")]
A(String), A(String),
#[rename_err = "uint"] #[pyo3(annotation = "uint")]
B(usize), B(usize),
#[rename_err = "int"] #[pyo3(annotation = "int", transparent)]
C(isize), C(isize),
} }