Make `PyIterProtocol` methods accept both `PyRef` and `PyRefMut`

This commit is contained in:
Martin Larralde 2020-04-18 03:39:50 +02:00
parent 42e84ea6ff
commit 187d889565
5 changed files with 73 additions and 16 deletions

View File

@ -260,13 +260,15 @@ pub const ITER: Proto = Proto {
name: "Iter",
py_methods: &[],
methods: &[
MethodProto::Unary {
MethodProto::UnaryS {
name: "__iter__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::iter::PyIterIterProtocol",
},
MethodProto::Unary {
MethodProto::UnaryS {
name: "__next__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::iter::PyIterNextProtocol",
},

View File

@ -18,6 +18,12 @@ pub enum MethodProto {
pyres: bool,
proto: &'static str,
},
UnaryS {
name: &'static str,
arg: &'static str,
pyres: bool,
proto: &'static str,
},
Binary {
name: &'static str,
arg: &'static str,
@ -60,6 +66,7 @@ impl MethodProto {
match *self {
MethodProto::Free { ref name, .. } => name,
MethodProto::Unary { ref name, .. } => name,
MethodProto::UnaryS { ref name, .. } => name,
MethodProto::Binary { ref name, .. } => name,
MethodProto::BinaryS { ref name, .. } => name,
MethodProto::Ternary { ref name, .. } => name,
@ -114,6 +121,54 @@ pub(crate) fn impl_method_proto(
}
}
}
MethodProto::UnaryS { pyres, proto, arg, .. } => {
let p: syn::Path = syn::parse_str(proto).unwrap();
let (ty, succ) = get_res_success(ty);
let slf_name = syn::Ident::new(arg, Span::call_site());
let mut slf_ty = get_arg_ty(sig, 0);
// update the type if not lifetime was given:
// PyRef<Self> --> PyRef<'p, Self>
if let syn::Type::Path(path) = &mut slf_ty {
if let syn::PathArguments::AngleBracketed(args) = &mut path.path.segments[0].arguments {
if let syn::GenericArgument::Lifetime(_) = args.args[0] {
} else {
let lt = syn::parse_quote! {'p};
args.args.insert(0, lt);
}
}
}
let tmp: syn::ItemFn = syn::parse_quote! {
fn test(&self) -> <#cls as #p<'p>>::Result {}
};
sig.output = tmp.sig.output;
modify_self_ty(sig);
if let syn::FnArg::Typed(ref mut arg) = sig.inputs[0] {
arg.ty = Box::new(syn::parse_quote! {
<#cls as #p<'p>>::#slf_name
});
}
if pyres {
quote! {
impl<'p> #p<'p> for #cls {
type #slf_name = #slf_ty;
type Success = #succ;
type Result = #ty;
}
}
} else {
quote! {
impl<'p> #p<'p> for #cls {
type #slf_name = #slf_ty;
type Result = #ty;
}
}
}
}
MethodProto::Binary {
name,
arg,

View File

@ -4,7 +4,8 @@
use crate::callback::IntoPyCallbackOutput;
use crate::err::PyResult;
use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, PyRefMut, Python};
use crate::pycell::TryFromPyCell;
use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, Python};
/// Python Iterator Interface.
///
@ -12,14 +13,14 @@ use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, PyRefMut, Python};
/// for more.
#[allow(unused_variables)]
pub trait PyIterProtocol<'p>: PyClass {
fn __iter__(slf: PyRefMut<Self>) -> Self::Result
fn __iter__(slf: Self::Receiver) -> Self::Result
where
Self: PyIterIterProtocol<'p>,
{
unimplemented!()
}
fn __next__(slf: PyRefMut<Self>) -> Self::Result
fn __next__(slf: Self::Receiver) -> Self::Result
where
Self: PyIterNextProtocol<'p>,
{
@ -28,11 +29,13 @@ pub trait PyIterProtocol<'p>: PyClass {
}
pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Self::Success>>;
}
pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Option<Self::Success>>>;
}
@ -76,7 +79,7 @@ where
{
#[inline]
fn tp_iter() -> Option<ffi::getiterfunc> {
py_unary_refmut_func!(PyIterIterProtocol, T::__iter__)
py_unarys_func!(PyIterIterProtocol, T::__iter__)
}
}
@ -99,7 +102,7 @@ where
{
#[inline]
fn tp_iternext() -> Option<ffi::iternextfunc> {
py_unary_refmut_func!(PyIterNextProtocol, T::__next__, IterNextConverter)
py_unarys_func!(PyIterNextProtocol, T::__next__, IterNextConverter)
}
}

View File

@ -28,7 +28,7 @@ macro_rules! py_unary_func {
#[macro_export]
#[doc(hidden)]
macro_rules! py_unary_refmut_func {
macro_rules! py_unarys_func {
($trait:ident, $class:ident :: $f:ident $(, $conv:expr)?) => {{
unsafe extern "C" fn wrap<T>(slf: *mut $crate::ffi::PyObject) -> *mut $crate::ffi::PyObject
where
@ -38,7 +38,9 @@ macro_rules! py_unary_refmut_func {
let py = pool.python();
$crate::run_callback(py, || {
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let res = $class::$f(slf.borrow_mut()).into();
let borrow = <T::Receiver>::try_from_pycell(slf)
.map_err(|e| e.into())?;
let res = $class::$f(borrow).into();
$crate::callback::convert(py, res $(.map($conv))?)
})
}

View File

@ -5,7 +5,6 @@ use crate::pyclass_slots::{PyClassDict, PyClassWeakRef};
use crate::type_object::{PyBorrowFlagLayout, PyDowncastImpl, PyLayout, PySizedLayout, PyTypeInfo};
use crate::{ffi, FromPy, PyAny, PyClass, PyErr, PyNativeType, PyObject, PyResult, Python};
use std::cell::{Cell, UnsafeCell};
use std::error::Error;
use std::fmt;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
@ -653,7 +652,7 @@ impl<T: PyClass + fmt::Debug> fmt::Debug for PyRefMut<'_, T> {
/// This serves to unify the use of `PyRef` and `PyRefMut` in automatically
/// derived code, since both types can be obtained from a `PyCell`.
pub trait TryFromPyCell<'a, T: PyClass>: Sized {
type Error: Error + Into<PyErr>;
type Error: Into<PyErr>;
fn try_from_pycell(cell: &'a crate::PyCell<T>) -> Result<Self, Self::Error>;
}
@ -661,7 +660,7 @@ impl <'a, T, R> TryFromPyCell<'a, T> for R
where
T: 'a + PyClass,
R: std::convert::TryFrom<&'a PyCell<T>>,
R::Error: Error + Into<PyErr>,
R::Error: Into<PyErr>,
{
type Error = R::Error;
fn try_from_pycell(cell: &'a crate::PyCell<T>) -> Result<Self, Self::Error> {
@ -702,8 +701,6 @@ impl fmt::Display for PyBorrowError {
}
}
impl Error for PyBorrowError {}
/// An error returned by [`PyCell::try_borrow_mut`](struct.PyCell.html#method.try_borrow_mut).
///
/// In Python, you can catch this error by `except RuntimeError`.
@ -723,7 +720,5 @@ impl fmt::Display for PyBorrowMutError {
}
}
impl Error for PyBorrowMutError {}
pyo3_exception!(PyBorrowError, crate::exceptions::RuntimeError);
pyo3_exception!(PyBorrowMutError, crate::exceptions::RuntimeError);