update `#[derive(FromPyObject)]` to use `extract_bound` (#3828)

* update `#[derive(FromPyObject)]` to use `extract_bound`

* type inference for `from_py_with` using function pointers
This commit is contained in:
Icxolu 2024-02-13 01:09:41 +01:00 committed by GitHub
parent 94b7d7e434
commit fbfeb2ff03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 57 additions and 22 deletions

View File

@ -271,7 +271,7 @@ impl<'a> Container<'a> {
value: expr_path, .. value: expr_path, ..
}) => quote! { }) => quote! {
Ok(#self_ty { Ok(#self_ty {
#ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj, #struct_name, #field_name)? #ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
}) })
}, },
} }
@ -283,7 +283,7 @@ impl<'a> Container<'a> {
Some(FromPyWithAttribute { Some(FromPyWithAttribute {
value: expr_path, .. value: expr_path, ..
}) => quote! ( }) => quote! (
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, obj, #struct_name, 0).map(#self_ty) _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
), ),
} }
} }
@ -298,12 +298,12 @@ impl<'a> Container<'a> {
let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| { let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
match &field.from_py_with { match &field.from_py_with {
None => quote!( None => quote!(
_pyo3::impl_::frompyobject::extract_tuple_struct_field(#ident, #struct_name, #index)? _pyo3::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
), ),
Some(FromPyWithAttribute { Some(FromPyWithAttribute {
value: expr_path, .. value: expr_path, ..
}) => quote! ( }) => quote! (
_pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, #ident, #struct_name, #index)? _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
), ),
} }
}); });
@ -339,12 +339,12 @@ impl<'a> Container<'a> {
}; };
let extractor = match &field.from_py_with { let extractor = match &field.from_py_with {
None => { None => {
quote!(_pyo3::impl_::frompyobject::extract_struct_field(obj.#getter?, #struct_name, #field_name)?) quote!(_pyo3::impl_::frompyobject::extract_struct_field(&obj.#getter?, #struct_name, #field_name)?)
} }
Some(FromPyWithAttribute { Some(FromPyWithAttribute {
value: expr_path, .. value: expr_path, ..
}) => { }) => {
quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj.#getter?, #struct_name, #field_name)?) quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &obj.#getter?, #struct_name, #field_name)?)
} }
}; };
@ -606,10 +606,11 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
Ok(quote!( Ok(quote!(
const _: () = { const _: () = {
use #krate as _pyo3; use #krate as _pyo3;
use _pyo3::prelude::PyAnyMethods;
#[automatically_derived] #[automatically_derived]
impl #trait_generics _pyo3::FromPyObject<#lt_param> for #ident #generics #where_clause { impl #trait_generics _pyo3::FromPyObject<#lt_param> for #ident #generics #where_clause {
fn extract(obj: &#lt_param _pyo3::PyAny) -> _pyo3::PyResult<Self> { fn extract_bound(obj: &_pyo3::Bound<#lt_param, _pyo3::PyAny>) -> _pyo3::PyResult<Self> {
#derives #derives
} }
} }

View File

@ -1,5 +1,33 @@
use crate::types::any::PyAnyMethods;
use crate::Bound;
use crate::{exceptions::PyTypeError, FromPyObject, PyAny, PyErr, PyResult, Python}; use crate::{exceptions::PyTypeError, FromPyObject, PyAny, PyErr, PyResult, Python};
pub enum Extractor<'a, 'py, T> {
Bound(fn(&'a Bound<'py, PyAny>) -> PyResult<T>),
GilRef(fn(&'a PyAny) -> PyResult<T>),
}
impl<'a, 'py, T> From<fn(&'a Bound<'py, PyAny>) -> PyResult<T>> for Extractor<'a, 'py, T> {
fn from(value: fn(&'a Bound<'py, PyAny>) -> PyResult<T>) -> Self {
Self::Bound(value)
}
}
impl<'a, T> From<fn(&'a PyAny) -> PyResult<T>> for Extractor<'a, '_, T> {
fn from(value: fn(&'a PyAny) -> PyResult<T>) -> Self {
Self::GilRef(value)
}
}
impl<'a, 'py, T> Extractor<'a, 'py, T> {
fn call(self, obj: &'a Bound<'py, PyAny>) -> PyResult<T> {
match self {
Extractor::Bound(f) => f(obj),
Extractor::GilRef(f) => f(obj.as_gil_ref()),
}
}
}
#[cold] #[cold]
pub fn failed_to_extract_enum( pub fn failed_to_extract_enum(
py: Python<'_>, py: Python<'_>,
@ -41,7 +69,7 @@ fn extract_traceback(py: Python<'_>, mut error: PyErr) -> String {
} }
pub fn extract_struct_field<'py, T>( pub fn extract_struct_field<'py, T>(
obj: &'py PyAny, obj: &Bound<'py, PyAny>,
struct_name: &str, struct_name: &str,
field_name: &str, field_name: &str,
) -> PyResult<T> ) -> PyResult<T>
@ -59,13 +87,13 @@ where
} }
} }
pub fn extract_struct_field_with<'py, T>( pub fn extract_struct_field_with<'a, 'py, T>(
extractor: impl FnOnce(&'py PyAny) -> PyResult<T>, extractor: impl Into<Extractor<'a, 'py, T>>,
obj: &'py PyAny, obj: &'a Bound<'py, PyAny>,
struct_name: &str, struct_name: &str,
field_name: &str, field_name: &str,
) -> PyResult<T> { ) -> PyResult<T> {
match extractor(obj) { match extractor.into().call(obj) {
Ok(value) => Ok(value), Ok(value) => Ok(value),
Err(err) => Err(failed_to_extract_struct_field( Err(err) => Err(failed_to_extract_struct_field(
obj.py(), obj.py(),
@ -92,7 +120,7 @@ fn failed_to_extract_struct_field(
} }
pub fn extract_tuple_struct_field<'py, T>( pub fn extract_tuple_struct_field<'py, T>(
obj: &'py PyAny, obj: &Bound<'py, PyAny>,
struct_name: &str, struct_name: &str,
index: usize, index: usize,
) -> PyResult<T> ) -> PyResult<T>
@ -110,13 +138,13 @@ where
} }
} }
pub fn extract_tuple_struct_field_with<'py, T>( pub fn extract_tuple_struct_field_with<'a, 'py, T>(
extractor: impl FnOnce(&'py PyAny) -> PyResult<T>, extractor: impl Into<Extractor<'a, 'py, T>>,
obj: &'py PyAny, obj: &'a Bound<'py, PyAny>,
struct_name: &str, struct_name: &str,
index: usize, index: usize,
) -> PyResult<T> { ) -> PyResult<T> {
match extractor(obj) { match extractor.into().call(obj) {
Ok(value) => Ok(value), Ok(value) => Ok(value),
Err(err) => Err(failed_to_extract_tuple_struct_field( Err(err) => Err(failed_to_extract_tuple_struct_field(
obj.py(), obj.py(),

View File

@ -502,7 +502,7 @@ pub struct Zap {
#[pyo3(item)] #[pyo3(item)]
name: String, name: String,
#[pyo3(from_py_with = "PyAny::len", item("my_object"))] #[pyo3(from_py_with = "Bound::<'_, PyAny>::len", item("my_object"))]
some_object_length: usize, some_object_length: usize,
} }
@ -525,7 +525,10 @@ fn test_from_py_with() {
} }
#[derive(Debug, FromPyObject)] #[derive(Debug, FromPyObject)]
pub struct ZapTuple(String, #[pyo3(from_py_with = "PyAny::len")] usize); pub struct ZapTuple(
String,
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize,
);
#[test] #[test]
fn test_from_py_with_tuple_struct() { fn test_from_py_with_tuple_struct() {
@ -560,8 +563,11 @@ fn test_from_py_with_tuple_struct_error() {
#[derive(Debug, FromPyObject, PartialEq, Eq)] #[derive(Debug, FromPyObject, PartialEq, Eq)]
pub enum ZapEnum { pub enum ZapEnum {
Zip(#[pyo3(from_py_with = "PyAny::len")] usize), Zip(#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize),
Zap(String, #[pyo3(from_py_with = "PyAny::len")] usize), Zap(
String,
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] usize,
),
} }
#[test] #[test]
@ -581,7 +587,7 @@ fn test_from_py_with_enum() {
#[derive(Debug, FromPyObject, PartialEq, Eq)] #[derive(Debug, FromPyObject, PartialEq, Eq)]
#[pyo3(transparent)] #[pyo3(transparent)]
pub struct TransparentFromPyWith { pub struct TransparentFromPyWith {
#[pyo3(from_py_with = "PyAny::len")] #[pyo3(from_py_with = "Bound::<'_, PyAny>::len")]
len: usize, len: usize,
} }