diff --git a/CHANGELOG.md b/CHANGELOG.md index e8e2dd1a..3b3e0592 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add `Py::as_ref` and `Py::into_ref`. [#1098](https://github.com/PyO3/pyo3/pull/1098) - Add optional implementations of `ToPyObject`, `IntoPy`, and `FromPyObject` for [hashbrown](https://crates.io/crates/hashbrown)'s `HashMap` and `HashSet` types. The `hashbrown` feature must be enabled for these implementations to be built. [#1114](https://github.com/PyO3/pyo3/pull/1114/) - Allow other `Result` types when using `#[pyfunction]`. [#1106](https://github.com/PyO3/pyo3/issues/1106). +- Add `#[derive(FromPyObject)]` macro for enums and structs. [#1065](https://github.com/PyO3/pyo3/pull/1065) ### Changed - Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024) diff --git a/pyo3-derive-backend/src/from_pyobject.rs b/pyo3-derive-backend/src/from_pyobject.rs new file mode 100644 index 00000000..1f222c4c --- /dev/null +++ b/pyo3-derive-backend/src/from_pyobject.rs @@ -0,0 +1,496 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::{parse_quote, Attribute, DataEnum, DeriveInput, ExprCall, Fields, Ident, Result}; + +/// Describes derivation input of an enum. +#[derive(Debug)] +struct Enum<'a> { + enum_ident: &'a Ident, + variants: Vec>, +} + +impl<'a> Enum<'a> { + /// Construct a new enum representation. + /// + /// `data_enum` is the `syn` representation of the input enum, `ident` is the + /// `Identifier` of the enum. + fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result { + if data_enum.variants.is_empty() { + return Err(syn::Error::new_spanned( + &data_enum.variants, + "Cannot derive FromPyObject for empty enum.", + )); + } + let vars = data_enum + .variants + .iter() + .map(|variant| { + let attrs = ContainerAttribute::parse_attrs(&variant.attrs)?; + let var_ident = &variant.ident; + Container::new( + &variant.fields, + parse_quote!(#ident::#var_ident), + attrs, + true, + ) + }) + .collect::>>()?; + + Ok(Enum { + enum_ident: ident, + variants: vars, + }) + } + + /// Build derivation body for enums. + fn build(&self) -> TokenStream { + let mut var_extracts = Vec::new(); + let mut error_names = String::new(); + for (i, var) in self.variants.iter().enumerate() { + let struct_derive = var.build(); + let ext = quote!( + let maybe_ret = || -> ::pyo3::PyResult { + #struct_derive + }(); + if maybe_ret.is_ok() { + return maybe_ret + } + ); + + var_extracts.push(ext); + error_names.push_str(&var.err_name); + if i < self.variants.len() - 1 { + error_names.push_str(", "); + } + } + quote!( + #(#var_extracts)* + let type_name = obj.get_type().name(); + let from = obj + .repr() + .map(|s| format!("{} ({})", s.to_string_lossy(), type_name)) + .unwrap_or_else(|_| type_name.to_string()); + let err_msg = format!("Can't convert {} to {}", from, #error_names); + Err(::pyo3::exceptions::PyTypeError::py_err(err_msg)) + ) + } +} + +/// Container Style +/// +/// Covers Structs, Tuplestructs and corresponding Newtypes. +#[derive(Debug)] +enum ContainerType<'a> { + /// Struct Container, e.g. `struct Foo { a: String }` + /// + /// Variant contains the list of field identifiers and the corresponding extraction call. + Struct(Vec<(&'a Ident, FieldAttribute)>), + /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` + /// + /// The field specified by the identifier is extracted directly from the object. + StructNewtype(&'a Ident), + /// Tuple struct, e.g. `struct Foo(String)`. + /// + /// Fields are extracted from a tuple. + Tuple(usize), + /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` + /// + /// The wrapped field is directly extracted from the object. + TupleNewtype, +} + +/// Data container +/// +/// Either describes a struct or an enum variant. +#[derive(Debug)] +struct Container<'a> { + path: syn::Path, + ty: ContainerType<'a>, + err_name: String, + is_enum_variant: bool, +} + +impl<'a> Container<'a> { + /// Construct a container based on fields, identifier and attributes. + /// + /// Fails if the variant has no fields or incompatible attributes. + fn new( + fields: &'a Fields, + path: syn::Path, + attrs: Vec, + is_enum_variant: bool, + ) -> Result { + let transparent = attrs.iter().any(ContainerAttribute::transparent); + if transparent { + Self::check_transparent_len(fields)?; + } + let style = match (fields, transparent) { + (Fields::Unnamed(_), true) => ContainerType::TupleNewtype, + (Fields::Unnamed(unnamed), false) => ContainerType::Tuple(unnamed.unnamed.len()), + (Fields::Named(named), true) => { + let field = named + .named + .iter() + .next() + .expect("Check for len 1 is done above"); + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + ContainerType::StructNewtype(ident) + } + (Fields::Named(named), false) => { + let mut fields = Vec::new(); + for field in named.named.iter() { + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + let attr = FieldAttribute::parse_attrs(&field.attrs)? + .unwrap_or_else(|| FieldAttribute::Ident(parse_quote!(getattr))); + fields.push((ident, attr)) + } + ContainerType::Struct(fields) + } + (Fields::Unit, _) => { + return Err(syn::Error::new_spanned( + &fields, + "Cannot derive FromPyObject for Unit structs and variants", + )) + } + }; + let err_name = attrs + .iter() + .find_map(|a| a.annotation()) + .cloned() + .unwrap_or_else(|| path.segments.last().unwrap().ident.to_string()); + + let v = Container { + path, + ty: style, + err_name, + is_enum_variant, + }; + Ok(v) + } + + fn verify_struct_container_attrs(attrs: &'a [ContainerAttribute]) -> Result<()> { + for attr in attrs { + match attr { + ContainerAttribute::Transparent => continue, + ContainerAttribute::ErrorAnnotation(_) => { + return Err(syn::Error::new( + Span::call_site(), + "Annotating error messages for structs is \ + not supported. Remove the annotation attribute.", + )) + } + } + } + Ok(()) + } + + /// Build derivation body for a struct. + fn build(&self) -> TokenStream { + match &self.ty { + ContainerType::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)), + ContainerType::TupleNewtype => self.build_newtype_struct(None), + ContainerType::Tuple(len) => self.build_tuple_struct(*len), + ContainerType::Struct(tups) => self.build_struct(tups), + } + } + + fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream { + let self_ty = &self.path; + if let Some(ident) = field_ident { + quote!( + Ok(#self_ty{#ident: obj.extract()?}) + ) + } else { + quote!(Ok(#self_ty(obj.extract()?))) + } + } + + fn build_tuple_struct(&self, len: usize) -> TokenStream { + let self_ty = &self.path; + let mut fields: Punctuated = Punctuated::new(); + for i in 0..len { + fields.push(quote!(slice[#i].extract()?)); + } + let msg = if self.is_enum_variant { + quote!(format!( + "Expected tuple of length {}, but got length {}.", + #len, + s.len() + )) + } else { + quote!("") + }; + quote!( + let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?; + if s.len() != #len { + return Err(::pyo3::exceptions::PyValueError::py_err(#msg)) + } + let slice = s.as_slice(); + Ok(#self_ty(#fields)) + ) + } + + fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream { + let self_ty = &self.path; + let mut fields: Punctuated = Punctuated::new(); + for (ident, attr) in tups { + let ext_fn = match attr { + FieldAttribute::IdentWithArg(expr) => quote!(#expr), + FieldAttribute::Ident(meth) => { + let arg = ident.to_string(); + quote!(#meth(#arg)) + } + }; + fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); + } + quote!(Ok(#self_ty{#fields})) + } + + fn check_transparent_len(fields: &Fields) -> Result<()> { + if fields.len() != 1 { + return Err(syn::Error::new_spanned( + fields, + "Transparent structs and variants can only have 1 field", + )); + } + Ok(()) + } +} + +/// Attributes for deriving FromPyObject scoped on containers. +#[derive(Clone, Debug, PartialEq)] +enum ContainerAttribute { + /// Treat the Container as a Wrapper, directly extract its fields from the input object. + Transparent, + /// Change the name of an enum variant in the generated error message. + ErrorAnnotation(String), +} + +impl ContainerAttribute { + /// Return whether this attribute is `Transparent` + fn transparent(&self) -> bool { + match self { + ContainerAttribute::Transparent => true, + _ => false, + } + } + + /// Convenience method to access `ErrorAnnotation`. + fn annotation(&self) -> Option<&String> { + match self { + ContainerAttribute::ErrorAnnotation(s) => Some(s), + _ => None, + } + } + + /// Parse valid container arguments + /// + /// Fails if any are invalid. + fn parse_attrs(value: &[Attribute]) -> Result> { + let mut attrs = Vec::new(); + let list = get_pyo3_meta_list(value)?; + for meta in list.nested { + if let syn::NestedMeta::Meta(metaitem) = &meta { + match metaitem { + syn::Meta::Path(p) if p.is_ident("transparent") => { + attrs.push(ContainerAttribute::Transparent) + } + syn::Meta::NameValue(nv) if nv.path.is_ident("annotation") => { + if let syn::Lit::Str(s) = &nv.lit { + attrs.push(ContainerAttribute::ErrorAnnotation(s.value())) + } else { + return Err(syn::Error::new_spanned( + &nv.lit, + "Expected string literal.", + )); + } + } + _ => (), + } + } else { + return Err(syn::Error::new_spanned( + meta, + "Unknown container attribute, expected `transparent` or \ + `annotation(\"err_name\")`", + )); + } + } + Ok(attrs) + } +} + +/// Attributes for deriving FromPyObject scoped on fields. +#[derive(Clone, Debug)] +enum FieldAttribute { + /// How a specific field should be extracted. + Ident(Ident), + IdentWithArg(ExprCall), +} + +impl FieldAttribute { + /// Extract the field attribute. + /// + /// Currently fails if more than 1 attribute is passed in `pyo3` + fn parse_attrs(attrs: &[Attribute]) -> Result> { + let list = get_pyo3_meta_list(attrs)?; + if list.nested.len() > 1 { + return Err(syn::Error::new_spanned( + list, + "Only one of `item`, `attribute` can be provided, possibly as \ + a key-value pair: `attribute = \"name\"`.", + )); + } + let meta = if let Some(attr) = list.nested.first() { + attr + } else { + return Ok(None); + }; + if let syn::NestedMeta::Meta(metaitem) = meta { + let path = metaitem.path(); + let ident = Self::check_valid_ident(path)?; + match metaitem { + syn::Meta::NameValue(nv) => Self::get_ident_with_arg(ident, &nv.lit).map(Some), + syn::Meta::Path(_) => Ok(Some(FieldAttribute::Ident(parse_quote!(#ident)))), + _ => Err(syn::Error::new_spanned( + metaitem, + "`item` or `attribute` need to be passed alone or as key-value \ + pairs, e.g. `attribute = \"name\"`.", + )), + } + } else { + Err(syn::Error::new_spanned(meta, "Unexpected literal.")) + } + } + + /// Verify the attribute path and return it if it is valid. + fn check_valid_ident(path: &syn::Path) -> Result { + if path.is_ident("item") { + Ok(parse_quote!(get_item)) + } else if path.is_ident("attribute") { + Ok(parse_quote!(getattr)) + } else { + Err(syn::Error::new_spanned( + path, + "Expected `item` or `attribute`", + )) + } + } + + /// Try to build `IdentWithArg` based on identifier and literal. + fn get_ident_with_arg(ident: Ident, lit: &syn::Lit) -> Result { + if ident == "getattr" { + if let syn::Lit::Str(s) = lit { + return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#s)))); + } else { + return Err(syn::Error::new_spanned(lit, "Expected string literal.")); + } + } + if ident == "get_item" { + return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#lit)))); + } + + // path is already checked in the `parse_attrs` loop, returning the error here anyways. + Err(syn::Error::new_spanned( + ident, + "Expected `item` or `attribute`.", + )) + } +} + +/// Extract pyo3 metalist, flattens multiple lists into a single one. +fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { + let mut list: Punctuated = Punctuated::new(); + for value in attrs { + match value.parse_meta()? { + syn::Meta::List(ml) if value.path.is_ident("pyo3") => { + for meta in ml.nested { + list.push(meta); + } + } + _ => { + return Err(syn::Error::new_spanned( + value, + "Expected `pyo3()` attribute.", + )) + } + } + } + Ok(syn::MetaList { + path: parse_quote!(pyo3), + paren_token: syn::token::Paren::default(), + nested: list, + }) +} + +fn verify_and_get_lifetime(generics: &syn::Generics) -> Result> { + let lifetimes = generics.lifetimes().collect::>(); + if lifetimes.len() > 1 { + return Err(syn::Error::new_spanned( + &generics, + "FromPyObject can only be derived with at most one lifetime parameter.", + )); + } + Ok(lifetimes.into_iter().next()) +} + +/// Derive FromPyObject for enums and structs. +/// +/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier +/// * At least one field, in case of `#[transparent]`, exactly one field +/// * At least one variant for enums. +/// * Fields of input structs and enums must implement `FromPyObject` +/// * Derivation for structs with generic fields like `struct Foo(T)` +/// adds `T: FromPyObject` on the derived implementation. +pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { + let mut trait_generics = tokens.generics.clone(); + let generics = &tokens.generics; + let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? { + lt.clone() + } else { + trait_generics.params.push(parse_quote!('source)); + parse_quote!('source) + }; + let mut where_clause: syn::WhereClause = parse_quote!(where); + for param in generics.type_params() { + let gen_ident = ¶m.ident; + where_clause + .predicates + .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>)) + } + let derives = match &tokens.data { + syn::Data::Enum(en) => { + let en = Enum::new(en, &tokens.ident)?; + en.build() + } + syn::Data::Struct(st) => { + let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?; + Container::verify_struct_container_attrs(&attrs)?; + let ident = &tokens.ident; + let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?; + st.build() + } + _ => { + return Err(syn::Error::new_spanned( + tokens, + "FromPyObject can only be derived for structs and enums.", + )) + } + }; + + let ident = &tokens.ident; + Ok(quote!( + #[automatically_derived] + impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause { + fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult { + #derives + } + } + )) +} diff --git a/pyo3-derive-backend/src/frompy.rs b/pyo3-derive-backend/src/frompy.rs deleted file mode 100644 index e60a0e45..00000000 --- a/pyo3-derive-backend/src/frompy.rs +++ /dev/null @@ -1,428 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use quote::quote; -use syn::punctuated::Punctuated; -use syn::token::Paren; -use syn::{ - parse_quote, Attribute, DataEnum, DeriveInput, Expr, ExprCall, Fields, Ident, PatTuple, Result, - Variant, -}; - -/// Describes derivation input of an enum. -#[derive(Debug)] -struct Enum<'a> { - enum_ident: &'a Ident, - vars: Vec>, -} - -impl<'a> Enum<'a> { - /// Construct a new enum representation. - /// - /// `data_enum` is the `syn` representation of the input enum, `ident` is the - /// `Identifier` of the enum. - fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result { - if data_enum.variants.is_empty() { - return Err(syn::Error::new_spanned( - &data_enum.variants, - "Cannot derive FromPyObject for empty enum.", - )); - } - let vars = data_enum - .variants - .iter() - .map(Container::from_variant) - .collect::>>()?; - - Ok(Enum { - enum_ident: ident, - vars, - }) - } - - /// Build derivation body for enums. - fn derive_enum(&self) -> TokenStream { - let mut var_extracts = Vec::new(); - let mut error_names = String::new(); - for (i, var) in self.vars.iter().enumerate() { - let ext = match &var.style { - Style::Struct(tups) => self.build_struct_variant(tups, var.ident), - Style::StructNewtype(ident) => { - self.build_transparent_variant(var.ident, Some(ident)) - } - Style::Tuple(len) => self.build_tuple_variant(var.ident, *len), - Style::TupleNewtype => self.build_transparent_variant(var.ident, None), - }; - var_extracts.push(ext); - error_names.push_str(&var.err_name); - if i < self.vars.len() - 1 { - error_names.push_str(", "); - } - } - quote!( - #(#var_extracts)* - let type_name = obj.get_type().name(); - let from = obj - .repr() - .map(|s| format!("{} ({})", s.to_string_lossy(), type_name)) - .unwrap_or_else(|_| type_name.to_string()); - let err_msg = format!("Can't convert {} to {}", from, #error_names); - Err(::pyo3::exceptions::PyTypeError::py_err(err_msg)) - ) - } - - /// Build match for tuple struct variant. - fn build_tuple_variant(&self, var_ident: &Ident, len: usize) -> TokenStream { - let enum_ident = self.enum_ident; - let mut ext: Punctuated = Punctuated::new(); - let mut fields: Punctuated = Punctuated::new(); - let mut field_pats = PatTuple { - attrs: vec![], - paren_token: Paren::default(), - elems: Default::default(), - }; - for i in 0..len { - ext.push(parse_quote!(slice[#i].extract())); - let ident = Ident::new(&format!("_field{}", i), Span::call_site()); - field_pats.elems.push(parse_quote!(Ok(#ident))); - fields.push(ident); - } - - quote!( - match <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj) { - Ok(s) => { - if s.len() == #len { - let slice = s.as_slice(); - if let (#field_pats) = (#ext) { - return Ok(#enum_ident::#var_ident(#fields)) - } - } - }, - Err(_) => {} - } - ) - } - - /// Build match for transparent enum variants. - fn build_transparent_variant( - &self, - var_ident: &Ident, - field_ident: Option<&Ident>, - ) -> TokenStream { - let enum_ident = self.enum_ident; - if let Some(ident) = field_ident { - quote!( - if let Ok(#ident) = obj.extract() { - return Ok(#enum_ident::#var_ident{#ident}) - } - ) - } else { - quote!( - if let Ok(inner) = obj.extract() { - return Ok(#enum_ident::#var_ident(inner)) - } - ) - } - } - - /// Build match for struct variant with named fields. - fn build_struct_variant( - &self, - tups: &[(&'a Ident, ExprCall)], - var_ident: &Ident, - ) -> TokenStream { - let enum_ident = self.enum_ident; - let mut field_pats = PatTuple { - attrs: vec![], - paren_token: Paren::default(), - elems: Default::default(), - }; - let mut fields: Punctuated = Punctuated::new(); - let mut ext: Punctuated = Punctuated::new(); - for (ident, ext_fn) in tups { - field_pats.elems.push(parse_quote!(Ok(#ident))); - fields.push(parse_quote!(#ident)); - ext.push(parse_quote!(obj.#ext_fn.and_then(|o| o.extract()))); - } - quote!(if let #field_pats = #ext { - return Ok(#enum_ident::#var_ident{#fields}); - }) - } -} - -/// Container Style -/// -/// Covers Structs, Tuplestructs and corresponding Newtypes. -#[derive(Clone, Debug)] -enum Style<'a> { - /// Struct Container, e.g. `struct Foo { a: String }` - /// - /// Variant contains the list of field identifiers and the corresponding extraction call. - Struct(Vec<(&'a Ident, ExprCall)>), - /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` - /// - /// The field specified by the identifier is extracted directly from the object. - StructNewtype(&'a Ident), - /// Tuple struct, e.g. `struct Foo(String)`. - /// - /// Fields are extracted from a tuple. - Tuple(usize), - /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` - /// - /// The wrapped field is directly extracted from the object. - TupleNewtype, -} - -/// Data container -/// -/// Either describes a struct or an enum variant. -#[derive(Debug)] -struct Container<'a> { - ident: &'a Ident, - style: Style<'a>, - err_name: String, -} - -impl<'a> Container<'a> { - /// Construct a container from an enum Variant. - /// - /// Fails if the variant has no fields or incompatible attributes. - fn from_variant(var: &'a Variant) -> Result { - Self::new(&var.fields, &var.ident, &var.attrs) - } - - /// Construct a container based on fields, identifier and attributes. - /// - /// Fails if the variant has no fields or incompatible attributes. - fn new(fields: &'a Fields, ident: &'a Ident, attrs: &'a [Attribute]) -> Result { - let transparent = attrs.iter().any(|a| a.path.is_ident("transparent")); - if transparent { - Self::check_transparent_len(fields)?; - } - let style = match fields { - Fields::Unnamed(unnamed) => { - if transparent { - Style::TupleNewtype - } else { - Style::Tuple(unnamed.unnamed.len()) - } - } - Fields::Named(named) => { - if transparent { - let field = named - .named - .iter() - .next() - .expect("Check for len 1 is done above"); - let ident = field - .ident - .as_ref() - .expect("Named fields should have identifiers"); - Style::StructNewtype(ident) - } else { - let mut fields = Vec::new(); - for field in named.named.iter() { - let ident = field - .ident - .as_ref() - .expect("Named fields should have identifiers"); - fields.push((ident, ext_fn(&field.attrs, ident)?)) - } - Style::Struct(fields) - } - } - Fields::Unit => { - return Err(syn::Error::new_spanned( - &fields, - "Cannot derive FromPyObject for Unit structs and variants", - )) - } - }; - let err_name = maybe_renamed_err(&attrs)? - .map(|s| s.value()) - .unwrap_or_else(|| ident.to_string()); - - let v = Container { - ident: &ident, - style, - err_name, - }; - Ok(v) - } - - /// Build derivation body for a struct. - fn derive_struct(&self) -> TokenStream { - match &self.style { - Style::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)), - Style::TupleNewtype => self.build_newtype_struct(None), - Style::Tuple(len) => self.build_tuple_struct(*len), - Style::Struct(tups) => self.build_struct(tups), - } - } - - fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream { - if let Some(ident) = field_ident { - quote!( - Ok(Self{#ident: obj.extract()?}) - ) - } else { - quote!(Ok(Self(obj.extract()?))) - } - } - - fn build_tuple_struct(&self, len: usize) -> TokenStream { - let mut fields: Punctuated = Punctuated::new(); - for i in 0..len { - fields.push(quote!(slice[#i].extract()?)); - } - quote!( - let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?; - let seq_len = s.len(); - if seq_len != #len { - let msg = format!( - "Expected tuple of length {}, but got length {}.", - #len, - seq_len - ); - return Err(::pyo3::exceptions::PyValueError::py_err(msg)) - } - let slice = s.as_slice(); - Ok(Self(#fields)) - ) - } - - fn build_struct(&self, tups: &[(&Ident, syn::ExprCall)]) -> TokenStream { - let mut fields: Punctuated = Punctuated::new(); - for (ident, ext_fn) in tups { - fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); - } - quote!(Ok(Self{#fields})) - } - - fn check_transparent_len(fields: &Fields) -> Result<()> { - if fields.len() != 1 { - return Err(syn::Error::new_spanned( - fields, - "Transparent structs and variants can only have 1 field", - )); - } - Ok(()) - } -} - -/// Get the extraction function that's called on the input object. -/// -/// Valid arguments are `get_item`, `get_attr` which are called with the -/// stringified field identifier or a function call on `PyAny`, e.g. `get_attr("attr")` -fn ext_fn(attrs: &[Attribute], field_ident: &Ident) -> Result { - let attr = if let Some(attr) = attrs.iter().find(|a| a.path.is_ident("extract")) { - attr - } else { - return Ok(parse_quote!(getattr(stringify!(#field_ident)))); - }; - if let Ok(ident) = attr.parse_args::() { - if ident != "getattr" && ident != "get_item" { - Err(syn::Error::new_spanned( - ident, - "Only get_item and getattr are valid for extraction.", - )) - } else { - let arg = field_ident.to_string(); - Ok(parse_quote!(#ident(#arg))) - } - } else if let Ok(call) = attr.parse_args() { - Ok(call) - } else { - Err(syn::Error::new_spanned( - attr, - "Only get_item and getattr are valid for extraction,\ - both can be passed with or without an argument, e.g. \ - #[extract(getattr(\"attr\")] and #[extract(getattr)]", - )) - } -} - -/// Returns the name of the variant for the error message if no variants match. -fn maybe_renamed_err(attrs: &[syn::Attribute]) -> Result> { - for attr in attrs { - if !attr.path.is_ident("rename_err") { - continue; - } - let attr = attr.parse_meta()?; - if let syn::Meta::NameValue(nv) = &attr { - match &nv.lit { - syn::Lit::Str(s) => { - return Ok(Some(s.clone())); - } - _ => { - return Err(syn::Error::new_spanned( - attr, - "rename_err attribute must be string literal: #[rename_err=\"Name\"]", - )) - } - } - } - } - Ok(None) -} - -fn verify_and_get_lifetime(generics: &syn::Generics) -> Result> { - let lifetimes = generics.lifetimes().collect::>(); - if lifetimes.len() > 1 { - return Err(syn::Error::new_spanned( - &generics, - "Only a single lifetime parameter can be specified.", - )); - } - Ok(lifetimes.into_iter().next()) -} - -/// Derive FromPyObject for enums and structs. -/// -/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier -/// * At least one field, in case of `#[transparent]`, exactly one field -/// * At least one variant for enums. -/// * Fields of input structs and enums must implement `FromPyObject` -/// * Derivation for structs with generic fields like `struct Foo(T)` -/// adds `T: FromPyObject` on the derived implementation. -pub fn build_derive_from_pyobject(tokens: &mut DeriveInput) -> Result { - let mut trait_generics = tokens.generics.clone(); - let generics = &tokens.generics; - let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? { - lt.clone() - } else { - trait_generics.params.push(parse_quote!('source)); - parse_quote!('source) - }; - let mut where_clause: syn::WhereClause = parse_quote!(where); - for param in generics.type_params() { - let gen_ident = ¶m.ident; - where_clause - .predicates - .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>)) - } - let derives = match &tokens.data { - syn::Data::Enum(en) => { - let en = Enum::new(en, &tokens.ident)?; - en.derive_enum() - } - syn::Data::Struct(st) => { - let st = Container::new(&st.fields, &tokens.ident, &tokens.attrs)?; - st.derive_struct() - } - _ => { - return Err(syn::Error::new_spanned( - tokens, - "FromPyObject can only be derived for structs and enums.", - )) - } - }; - - let ident = &tokens.ident; - Ok(quote!( - #[automatically_derived] - impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause { - fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult { - #derives - } - } - )) -} diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index 78db3736..2a943850 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -4,7 +4,7 @@ #![recursion_limit = "1024"] mod defs; -mod frompy; +mod from_pyobject; mod konst; mod method; mod module; @@ -16,7 +16,7 @@ mod pymethod; mod pyproto; mod utils; -pub use frompy::build_derive_from_pyobject; +pub use from_pyobject::build_derive_from_pyobject; pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use pyclass::{build_py_class, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionAttr}; diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index 2377ec03..bade299b 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -92,10 +92,10 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream { .into() } -#[proc_macro_derive(FromPyObject, attributes(transparent, extract, rename_err))] +#[proc_macro_derive(FromPyObject, attributes(pyo3, extract))] pub fn derive_from_py_object(item: TokenStream) -> TokenStream { - let mut ast = parse_macro_input!(item as syn::DeriveInput); - let expanded = build_derive_from_pyobject(&mut ast).unwrap_or_else(|e| e.to_compile_error()); + let ast = parse_macro_input!(item as syn::DeriveInput); + let expanded = build_derive_from_pyobject(&ast).unwrap_or_else(|e| e.to_compile_error()); quote!( #expanded ) diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index b1bf0cdd..23b42a4c 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -8,11 +8,11 @@ mod common; #[derive(Debug, FromPyObject)] pub struct A<'a> { - #[extract(getattr)] + #[pyo3(attribute)] s: String, - #[extract(get_item)] + #[pyo3(item)] t: &'a PyString, - #[extract(getattr("foo"))] + #[pyo3(attribute = "foo")] p: &'a PyAny, } @@ -51,7 +51,7 @@ fn test_named_fields_struct() { } #[derive(Debug, FromPyObject)] -#[transparent] +#[pyo3(transparent)] pub struct B { test: String, } @@ -69,7 +69,7 @@ fn test_transparent_named_field_struct() { } #[derive(Debug, FromPyObject)] -#[transparent] +#[pyo3(transparent)] pub struct D { test: T, } @@ -121,7 +121,7 @@ fn test_generic_named_fields_struct() { #[derive(Debug, FromPyObject)] pub struct C { - #[extract(getattr("test"))] + #[pyo3(attribute = "test")] test: String, } @@ -155,17 +155,18 @@ fn test_tuple_struct() { } #[derive(FromPyObject)] +#[pyo3(transparent)] pub struct TransparentTuple(String); #[test] fn test_transparent_tuple_struct() { let gil = Python::acquire_gil(); let py = gil.python(); - let tup = PyTuple::new(py, &[1.into_py(py)]); - let tup = TransparentTuple::extract(tup.as_ref()); + let tup: PyObject = 1.into_py(py); + let tup = TransparentTuple::extract(tup.as_ref(py)); assert!(tup.is_err()); - let tup = PyTuple::new(py, &["test".into_py(py)]); - let tup = TransparentTuple::extract(tup.as_ref()) + let test = "test".into_py(py); + let tup = TransparentTuple::extract(test.as_ref(py)) .expect("Failed to extract TransparentTuple from PyTuple"); assert_eq!(tup.0, "test"); } @@ -176,25 +177,25 @@ pub enum Foo<'a> { StructVar { test: &'a PyString, }, - #[transparent] + #[pyo3(transparent)] TransparentTuple(usize), - #[transparent] + #[pyo3(transparent)] TransparentStructVar { a: Option, }, StructVarGetAttrArg { - #[extract(getattr("bla"))] + #[pyo3(attribute = "bla")] a: bool, }, StructWithGetItem { - #[extract(get_item)] + #[pyo3(item)] a: String, }, StructWithGetItemArg { - #[extract(get_item("foo"))] + #[pyo3(item = "foo")] a: String, }, - #[transparent] + #[pyo3(transparent)] CatchAll(&'a PyAny), } @@ -279,11 +280,11 @@ fn test_enum() { #[derive(FromPyObject)] pub enum Bar { - #[rename_err = "str"] + #[pyo3(annotation = "str")] A(String), - #[rename_err = "uint"] + #[pyo3(annotation = "uint")] B(usize), - #[rename_err = "int"] + #[pyo3(annotation = "int", transparent)] C(isize), }