Derive FromPyObject

This commit is contained in:
Sebastian Pütz 2020-08-25 00:00:12 +02:00 committed by Sebastian Pütz
parent f816786de4
commit 7168309464
5 changed files with 751 additions and 3 deletions

View File

@ -0,0 +1,428 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Paren;
use syn::{
parse_quote, Attribute, DataEnum, DeriveInput, Expr, ExprCall, Fields, Ident, PatTuple, Result,
Variant,
};
/// Describes derivation input of an enum.
#[derive(Debug)]
struct Enum<'a> {
enum_ident: &'a Ident,
vars: Vec<Container<'a>>,
}
impl<'a> Enum<'a> {
/// Construct a new enum representation.
///
/// `data_enum` is the `syn` representation of the input enum, `ident` is the
/// `Identifier` of the enum.
fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
if data_enum.variants.is_empty() {
return Err(syn::Error::new_spanned(
&data_enum.variants,
"Cannot derive FromPyObject for empty enum.",
));
}
let vars = data_enum
.variants
.iter()
.map(Container::from_variant)
.collect::<Result<Vec<_>>>()?;
Ok(Enum {
enum_ident: ident,
vars,
})
}
/// Build derivation body for enums.
fn derive_enum(&self) -> TokenStream {
let mut var_extracts = Vec::new();
let mut error_names = String::new();
for (i, var) in self.vars.iter().enumerate() {
let ext = match &var.style {
Style::Struct(tups) => self.build_struct_variant(tups, var.ident),
Style::StructNewtype(ident) => {
self.build_transparent_variant(var.ident, Some(ident))
}
Style::Tuple(len) => self.build_tuple_variant(var.ident, *len),
Style::TupleNewtype => self.build_transparent_variant(var.ident, None),
};
var_extracts.push(ext);
error_names.push_str(&var.err_name);
if i < self.vars.len() - 1 {
error_names.push_str(", ");
}
}
quote!(
#(#var_extracts)*
let type_name = obj.get_type().name();
let from = obj
.repr()
.map(|s| format!("{} ({})", s.to_string_lossy(), type_name))
.unwrap_or_else(|_| type_name.to_string());
let err_msg = format!("Can't convert {} to {}", from, #error_names);
Err(::pyo3::exceptions::PyTypeError::py_err(err_msg))
)
}
/// Build match for tuple struct variant.
fn build_tuple_variant(&self, var_ident: &Ident, len: usize) -> TokenStream {
let enum_ident = self.enum_ident;
let mut ext: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
let mut fields: Punctuated<Ident, syn::Token![,]> = Punctuated::new();
let mut field_pats = PatTuple {
attrs: vec![],
paren_token: Paren::default(),
elems: Default::default(),
};
for i in 0..len {
ext.push(parse_quote!(slice[#i].extract()));
let ident = Ident::new(&format!("_field{}", i), Span::call_site());
field_pats.elems.push(parse_quote!(Ok(#ident)));
fields.push(ident);
}
quote!(
match <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj) {
Ok(s) => {
if s.len() == #len {
let slice = s.as_slice();
if let (#field_pats) = (#ext) {
return Ok(#enum_ident::#var_ident(#fields))
}
}
},
Err(_) => {}
}
)
}
/// Build match for transparent enum variants.
fn build_transparent_variant(
&self,
var_ident: &Ident,
field_ident: Option<&Ident>,
) -> TokenStream {
let enum_ident = self.enum_ident;
if let Some(ident) = field_ident {
quote!(
if let Ok(#ident) = obj.extract() {
return Ok(#enum_ident::#var_ident{#ident})
}
)
} else {
quote!(
if let Ok(inner) = obj.extract() {
return Ok(#enum_ident::#var_ident(inner))
}
)
}
}
/// Build match for struct variant with named fields.
fn build_struct_variant(
&self,
tups: &[(&'a Ident, ExprCall)],
var_ident: &Ident,
) -> TokenStream {
let enum_ident = self.enum_ident;
let mut field_pats = PatTuple {
attrs: vec![],
paren_token: Paren::default(),
elems: Default::default(),
};
let mut fields: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
let mut ext: Punctuated<Expr, syn::Token![,]> = Punctuated::new();
for (ident, ext_fn) in tups {
field_pats.elems.push(parse_quote!(Ok(#ident)));
fields.push(parse_quote!(#ident));
ext.push(parse_quote!(obj.#ext_fn.and_then(|o| o.extract())));
}
quote!(if let #field_pats = #ext {
return Ok(#enum_ident::#var_ident{#fields});
})
}
}
/// Container Style
///
/// Covers Structs, Tuplestructs and corresponding Newtypes.
#[derive(Clone, Debug)]
enum Style<'a> {
/// Struct Container, e.g. `struct Foo { a: String }`
///
/// Variant contains the list of field identifiers and the corresponding extraction call.
Struct(Vec<(&'a Ident, ExprCall)>),
/// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
///
/// The field specified by the identifier is extracted directly from the object.
StructNewtype(&'a Ident),
/// Tuple struct, e.g. `struct Foo(String)`.
///
/// Fields are extracted from a tuple.
Tuple(usize),
/// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
///
/// The wrapped field is directly extracted from the object.
TupleNewtype,
}
/// Data container
///
/// Either describes a struct or an enum variant.
#[derive(Debug)]
struct Container<'a> {
ident: &'a Ident,
style: Style<'a>,
err_name: String,
}
impl<'a> Container<'a> {
/// Construct a container from an enum Variant.
///
/// Fails if the variant has no fields or incompatible attributes.
fn from_variant(var: &'a Variant) -> Result<Self> {
Self::new(&var.fields, &var.ident, &var.attrs)
}
/// Construct a container based on fields, identifier and attributes.
///
/// Fails if the variant has no fields or incompatible attributes.
fn new(fields: &'a Fields, ident: &'a Ident, attrs: &'a [Attribute]) -> Result<Self> {
let transparent = attrs.iter().any(|a| a.path.is_ident("transparent"));
if transparent {
Self::check_transparent_len(fields)?;
}
let style = match fields {
Fields::Unnamed(unnamed) => {
if transparent {
Style::TupleNewtype
} else {
Style::Tuple(unnamed.unnamed.len())
}
}
Fields::Named(named) => {
if transparent {
let field = named
.named
.iter()
.next()
.expect("Check for len 1 is done above");
let ident = field
.ident
.as_ref()
.expect("Named fields should have identifiers");
Style::StructNewtype(ident)
} else {
let mut fields = Vec::new();
for field in named.named.iter() {
let ident = field
.ident
.as_ref()
.expect("Named fields should have identifiers");
fields.push((ident, ext_fn(&field.attrs, ident)?))
}
Style::Struct(fields)
}
}
Fields::Unit => {
return Err(syn::Error::new_spanned(
&fields,
"Cannot derive FromPyObject for Unit structs and variants",
))
}
};
let err_name = maybe_renamed_err(&attrs)?
.map(|s| s.value())
.unwrap_or_else(|| ident.to_string());
let v = Container {
ident: &ident,
style,
err_name,
};
Ok(v)
}
/// Build derivation body for a struct.
fn derive_struct(&self) -> TokenStream {
match &self.style {
Style::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)),
Style::TupleNewtype => self.build_newtype_struct(None),
Style::Tuple(len) => self.build_tuple_struct(*len),
Style::Struct(tups) => self.build_struct(tups),
}
}
fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
if let Some(ident) = field_ident {
quote!(
Ok(Self{#ident: obj.extract()?})
)
} else {
quote!(Ok(Self(obj.extract()?)))
}
}
fn build_tuple_struct(&self, len: usize) -> TokenStream {
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for i in 0..len {
fields.push(quote!(slice[#i].extract()?));
}
quote!(
let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?;
let seq_len = s.len();
if seq_len != #len {
let msg = format!(
"Expected tuple of length {}, but got length {}.",
#len,
seq_len
);
return Err(::pyo3::exceptions::PyValueError::py_err(msg))
}
let slice = s.as_slice();
Ok(Self(#fields))
)
}
fn build_struct(&self, tups: &[(&Ident, syn::ExprCall)]) -> TokenStream {
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
for (ident, ext_fn) in tups {
fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
}
quote!(Ok(Self{#fields}))
}
fn check_transparent_len(fields: &Fields) -> Result<()> {
if fields.len() != 1 {
return Err(syn::Error::new_spanned(
fields,
"Transparent structs and variants can only have 1 field",
));
}
Ok(())
}
}
/// Get the extraction function that's called on the input object.
///
/// Valid arguments are `get_item`, `get_attr` which are called with the
/// stringified field identifier or a function call on `PyAny`, e.g. `get_attr("attr")`
fn ext_fn(attrs: &[Attribute], field_ident: &Ident) -> Result<syn::ExprCall> {
let attr = if let Some(attr) = attrs.iter().find(|a| a.path.is_ident("extract")) {
attr
} else {
return Ok(parse_quote!(getattr(stringify!(#field_ident))));
};
if let Ok(ident) = attr.parse_args::<Ident>() {
if ident != "getattr" && ident != "get_item" {
Err(syn::Error::new_spanned(
ident,
"Only get_item and getattr are valid for extraction.",
))
} else {
let arg = field_ident.to_string();
Ok(parse_quote!(#ident(#arg)))
}
} else if let Ok(call) = attr.parse_args() {
Ok(call)
} else {
Err(syn::Error::new_spanned(
attr,
"Only get_item and getattr are valid for extraction,\
both can be passed with or without an argument, e.g. \
#[extract(getattr(\"attr\")] and #[extract(getattr)]",
))
}
}
/// Returns the name of the variant for the error message if no variants match.
fn maybe_renamed_err(attrs: &[syn::Attribute]) -> Result<Option<syn::LitStr>> {
for attr in attrs {
if !attr.path.is_ident("rename_err") {
continue;
}
let attr = attr.parse_meta()?;
if let syn::Meta::NameValue(nv) = &attr {
match &nv.lit {
syn::Lit::Str(s) => {
return Ok(Some(s.clone()));
}
_ => {
return Err(syn::Error::new_spanned(
attr,
"rename_err attribute must be string literal: #[rename_err=\"Name\"]",
))
}
}
}
}
Ok(None)
}
fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
let lifetimes = generics.lifetimes().collect::<Vec<_>>();
if lifetimes.len() > 1 {
return Err(syn::Error::new_spanned(
&generics,
"Only a single lifetime parameter can be specified.",
));
}
Ok(lifetimes.into_iter().next())
}
/// Derive FromPyObject for enums and structs.
///
/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
/// * At least one field, in case of `#[transparent]`, exactly one field
/// * At least one variant for enums.
/// * Fields of input structs and enums must implement `FromPyObject`
/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
/// adds `T: FromPyObject` on the derived implementation.
pub fn build_derive_from_pyobject(tokens: &mut DeriveInput) -> Result<TokenStream> {
let mut trait_generics = tokens.generics.clone();
let generics = &tokens.generics;
let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
lt.clone()
} else {
trait_generics.params.push(parse_quote!('source));
parse_quote!('source)
};
let mut where_clause: syn::WhereClause = parse_quote!(where);
for param in generics.type_params() {
let gen_ident = &param.ident;
where_clause
.predicates
.push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
}
let derives = match &tokens.data {
syn::Data::Enum(en) => {
let en = Enum::new(en, &tokens.ident)?;
en.derive_enum()
}
syn::Data::Struct(st) => {
let st = Container::new(&st.fields, &tokens.ident, &tokens.attrs)?;
st.derive_struct()
}
_ => {
return Err(syn::Error::new_spanned(
tokens,
"FromPyObject can only be derived for structs and enums.",
))
}
};
let ident = &tokens.ident;
Ok(quote!(
#[automatically_derived]
impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause {
fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult<Self> {
#derives
}
}
))
}

View File

@ -4,6 +4,7 @@
#![recursion_limit = "1024"]
mod defs;
mod frompy;
mod konst;
mod method;
mod module;
@ -15,6 +16,7 @@ mod pymethod;
mod pyproto;
mod utils;
pub use frompy::build_derive_from_pyobject;
pub use module::{add_fn_to_module, process_functions_in_module, py_init};
pub use pyclass::{build_py_class, PyClassArgs};
pub use pyfunction::{build_py_function, PyFunctionAttr};

View File

@ -5,8 +5,8 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use pyo3_derive_backend::{
build_py_class, build_py_function, build_py_methods, build_py_proto, get_doc,
process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr,
build_derive_from_pyobject, build_py_class, build_py_function, build_py_methods,
build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr,
};
use quote::quote;
use syn::parse_macro_input;
@ -91,3 +91,13 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
)
.into()
}
#[proc_macro_derive(FromPyObject, attributes(transparent, extract, rename_err))]
pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
let mut ast = parse_macro_input!(item as syn::DeriveInput);
let expanded = build_derive_from_pyobject(&mut ast).unwrap_or_else(|e| e.to_compile_error());
quote!(
#expanded
)
.into()
}

View File

@ -20,4 +20,4 @@ pub use crate::{FromPyObject, IntoPy, IntoPyPointer, PyTryFrom, PyTryInto, ToPyO
// PyModule is only part of the prelude because we need it for the pymodule function
pub use crate::types::{PyAny, PyModule};
#[cfg(feature = "macros")]
pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto};
pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject};

308
tests/test_frompyobject.rs Normal file
View File

@ -0,0 +1,308 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString, PyTuple};
use pyo3::{PyErrValue, PyMappingProtocol};
#[macro_use]
mod common;
#[derive(Debug, FromPyObject)]
pub struct A<'a> {
#[extract(getattr)]
s: String,
#[extract(get_item)]
t: &'a PyString,
#[extract(getattr("foo"))]
p: &'a PyAny,
}
#[pyclass]
pub struct PyA {
#[pyo3(get)]
s: String,
#[pyo3(get)]
foo: Option<String>,
}
#[pyproto]
impl PyMappingProtocol for PyA {
fn __getitem__(&self, key: String) -> pyo3::PyResult<String> {
if key == "t" {
Ok("bar".into())
} else {
Err(PyValueError::py_err("Failed"))
}
}
}
#[test]
fn test_named_fields_struct() {
let gil = Python::acquire_gil();
let py = gil.python();
let pya = PyA {
s: "foo".into(),
foo: None,
};
let py_c = Py::new(py, pya).unwrap();
let a: A = FromPyObject::extract(py_c.as_ref(py)).expect("Failed to extract A from PyA");
assert_eq!(a.s, "foo");
assert_eq!(a.t.to_string_lossy(), "bar");
assert!(a.p.is_none());
}
#[derive(Debug, FromPyObject)]
#[transparent]
pub struct B {
test: String,
}
#[test]
fn test_transparent_named_field_struct() {
let gil = Python::acquire_gil();
let py = gil.python();
let test = "test".into_py(py);
let b: B = FromPyObject::extract(test.as_ref(py)).expect("Failed to extract B from String");
assert_eq!(b.test, "test");
let test: PyObject = 1.into_py(py);
let b = B::extract(test.as_ref(py));
assert!(b.is_err())
}
#[derive(Debug, FromPyObject)]
#[transparent]
pub struct D<T> {
test: T,
}
#[test]
fn test_generic_transparent_named_field_struct() {
let gil = Python::acquire_gil();
let py = gil.python();
let test = "test".into_py(py);
let d: D<String> =
D::extract(test.as_ref(py)).expect("Failed to extract D<String> from String");
assert_eq!(d.test, "test");
let test = 1usize.into_py(py);
let d: D<usize> = D::extract(test.as_ref(py)).expect("Failed to extract D<usize> from String");
assert_eq!(d.test, 1);
}
#[derive(Debug, FromPyObject)]
pub struct E<T, T2> {
test: T,
test2: T2,
}
#[pyclass]
pub struct PyE {
#[pyo3(get)]
test: String,
#[pyo3(get)]
test2: usize,
}
#[test]
fn test_generic_named_fields_struct() {
let gil = Python::acquire_gil();
let py = gil.python();
let pye = PyE {
test: "test".into(),
test2: 2,
}
.into_py(py);
let e: E<String, usize> =
E::extract(pye.as_ref(py)).expect("Failed to extract E<String, usize> from PyE");
assert_eq!(e.test, "test");
assert_eq!(e.test2, 2);
let e = E::<usize, usize>::extract(pye.as_ref(py));
assert!(e.is_err());
}
#[derive(Debug, FromPyObject)]
pub struct C {
#[extract(getattr("test"))]
test: String,
}
#[test]
fn test_named_field_with_ext_fn() {
let gil = Python::acquire_gil();
let py = gil.python();
let pyc = PyE {
test: "foo".into(),
test2: 0,
}
.into_py(py);
let c = C::extract(pyc.as_ref(py)).expect("Failed to extract C from PyE");
assert_eq!(c.test, "foo");
}
#[derive(FromPyObject)]
pub struct Tuple(String, usize);
#[test]
fn test_tuple_struct() {
let gil = Python::acquire_gil();
let py = gil.python();
let tup = PyTuple::new(py, &[1.into_py(py), "test".into_py(py)]);
let tup = Tuple::extract(tup.as_ref());
assert!(tup.is_err());
let tup = PyTuple::new(py, &["test".into_py(py), 1.into_py(py)]);
let tup = Tuple::extract(tup.as_ref()).expect("Failed to extract Tuple from PyTuple");
assert_eq!(tup.0, "test");
assert_eq!(tup.1, 1);
}
#[derive(FromPyObject)]
pub struct TransparentTuple(String);
#[test]
fn test_transparent_tuple_struct() {
let gil = Python::acquire_gil();
let py = gil.python();
let tup = PyTuple::new(py, &[1.into_py(py)]);
let tup = TransparentTuple::extract(tup.as_ref());
assert!(tup.is_err());
let tup = PyTuple::new(py, &["test".into_py(py)]);
let tup = TransparentTuple::extract(tup.as_ref())
.expect("Failed to extract TransparentTuple from PyTuple");
assert_eq!(tup.0, "test");
}
#[derive(Debug, FromPyObject)]
pub enum Foo<'a> {
TupleVar(usize, String),
StructVar {
test: &'a PyString,
},
#[transparent]
TransparentTuple(usize),
#[transparent]
TransparentStructVar {
a: Option<String>,
},
StructVarGetAttrArg {
#[extract(getattr("bla"))]
a: bool,
},
StructWithGetItem {
#[extract(get_item)]
a: String,
},
StructWithGetItemArg {
#[extract(get_item("foo"))]
a: String,
},
#[transparent]
CatchAll(&'a PyAny),
}
#[pyclass]
pub struct PyBool {
#[pyo3(get)]
bla: bool,
}
#[test]
fn test_enum() {
let gil = Python::acquire_gil();
let py = gil.python();
let tup = PyTuple::new(py, &[1.into_py(py), "test".into_py(py)]);
let f = Foo::extract(tup.as_ref()).expect("Failed to extract Foo from tuple");
match f {
Foo::TupleVar(test, test2) => {
assert_eq!(test, 1);
assert_eq!(test2, "test");
}
_ => panic!("Expected extracting Foo::TupleVar, got {:?}", f),
}
let pye = PyE {
test: "foo".into(),
test2: 0,
}
.into_py(py);
let f = Foo::extract(pye.as_ref(py)).expect("Failed to extract Foo from PyE");
match f {
Foo::StructVar { test } => assert_eq!(test.to_string_lossy(), "foo"),
_ => panic!("Expected extracting Foo::StructVar, got {:?}", f),
}
let int: PyObject = 1.into_py(py);
let f = Foo::extract(int.as_ref(py)).expect("Failed to extract Foo from int");
match f {
Foo::TransparentTuple(test) => assert_eq!(test, 1),
_ => panic!("Expected extracting Foo::TransparentTuple, got {:?}", f),
}
let none = py.None();
let f = Foo::extract(none.as_ref(py)).expect("Failed to extract Foo from int");
match f {
Foo::TransparentStructVar { a } => assert!(a.is_none()),
_ => panic!("Expected extracting Foo::TransparentStructVar, got {:?}", f),
}
let pybool = PyBool { bla: true }.into_py(py);
let f = Foo::extract(pybool.as_ref(py)).expect("Failed to extract Foo from PyBool");
match f {
Foo::StructVarGetAttrArg { a } => assert!(a),
_ => panic!("Expected extracting Foo::StructVarGetAttrArg, got {:?}", f),
}
let dict = PyDict::new(py);
dict.set_item("a", "test").expect("Failed to set item");
let f = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict");
match f {
Foo::StructWithGetItem { a } => assert_eq!(a, "test"),
_ => panic!("Expected extracting Foo::StructWithGetItem, got {:?}", f),
}
let dict = PyDict::new(py);
dict.set_item("foo", "test").expect("Failed to set item");
let f = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict");
match f {
Foo::StructWithGetItemArg { a } => assert_eq!(a, "test"),
_ => panic!("Expected extracting Foo::StructWithGetItemArg, got {:?}", f),
}
let dict = PyDict::new(py);
let f = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict");
match f {
Foo::CatchAll(any) => {
let d = <&PyDict>::extract(any).expect("Expected pydict");
assert!(d.is_empty());
}
_ => panic!("Expected extracting Foo::CatchAll, got {:?}", f),
}
}
#[derive(FromPyObject)]
pub enum Bar {
#[rename_err = "str"]
A(String),
#[rename_err = "uint"]
B(usize),
#[rename_err = "int"]
C(isize),
}
#[test]
fn test_err_rename() {
let gil = Python::acquire_gil();
let py = gil.python();
let dict = PyDict::new(py);
let f = Bar::extract(dict.as_ref());
assert!(f.is_err());
match f {
Ok(_) => {}
Err(e) => match e.pvalue {
PyErrValue::ToObject(to) => {
let o = to.to_object(py);
let s = String::extract(o.as_ref(py)).expect("Err val is not a string");
assert_eq!(s, "Can't convert {} (dict) to str, uint, int")
}
_ => panic!("Expected PyErrValue::ToObject"),
},
}
}