Allow slf: Py<Self>/PyRef<Self>/PyRefMut<Self> in pymethods

This commit is contained in:
kngwyu 2019-03-28 22:55:41 +09:00
parent f8bf258602
commit 515c7beac0
5 changed files with 199 additions and 8 deletions

View file

@ -27,6 +27,35 @@ pub enum FnType {
FnCall,
FnClass,
FnStatic,
PySelf(PySelfType),
}
/// For fn(slf: &PyRef<Self>) support
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum PySelfType {
Py,
PyRef,
PyRefMut,
}
impl PySelfType {
fn from_args<'a>(args: &[FnArg<'a>]) -> Option<Self> {
let arg = args.iter().next()?;
let path = match arg.ty {
syn::Type::Path(p) => p,
_ => return None,
};
let last_seg = match path.path.segments.last()? {
syn::punctuated::Pair::Punctuated(t, _) => t,
syn::punctuated::Pair::End(t) => t,
};
match &*last_seg.ident.to_string() {
"Py" => Some(PySelfType::Py),
"PyRef" => Some(PySelfType::PyRef),
"PyRefMut" => Some(PySelfType::PyRefMut),
_ => None,
}
}
}
#[derive(Clone, PartialEq, Debug)]
@ -51,11 +80,10 @@ impl<'a> FnSpec<'a> {
sig: &'a syn::MethodSig,
meth_attrs: &'a mut Vec<syn::Attribute>,
) -> syn::Result<FnSpec<'a>> {
let (fn_type, fn_attrs) = parse_attributes(meth_attrs)?;
let (mut fn_type, fn_attrs) = parse_attributes(meth_attrs)?;
let mut has_self = false;
let mut arguments = Vec::new();
for input in sig.decl.inputs.iter() {
match input {
syn::FnArg::SelfRef(_) => {
@ -119,6 +147,17 @@ impl<'a> FnSpec<'a> {
let ty = get_return_info(&sig.decl.output);
if fn_type == FnType::Fn && !has_self {
if let Some(pyslf) = PySelfType::from_args(&arguments) {
fn_type = FnType::PySelf(pyslf);
arguments.remove(0);
} else {
panic!(
"Static method needs an attribute #[staticmethod] or PyRef/PyRefMut as the 1st arg"
);
}
}
Ok(FnSpec {
tp: fn_type,
attrs: fn_attrs,

View file

@ -1,6 +1,6 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
use crate::method::{FnArg, FnSpec, FnType};
use crate::method::{FnArg, FnSpec, FnType, PySelfType};
use crate::utils;
use proc_macro2::{Span, TokenStream};
use quote::quote;
@ -18,6 +18,12 @@ pub fn gen_py_method(
match spec.tp {
FnType::Fn => impl_py_method_def(name, doc, &spec, &impl_wrap(cls, name, &spec, true)),
FnType::PySelf(pyslf) => impl_py_method_def(
name,
doc,
&spec,
&impl_wrap_pyslf(cls, name, &spec, pyslf, true),
),
FnType::FnNew => impl_py_method_def_new(name, doc, &impl_wrap_new(cls, name, &spec)),
FnType::FnInit => impl_py_method_def_init(name, doc, &impl_wrap_init(cls, name, &spec)),
FnType::FnCall => impl_py_method_def_call(name, doc, &impl_wrap(cls, name, &spec, false)),
@ -48,7 +54,45 @@ pub fn impl_wrap(
noargs: bool,
) -> TokenStream {
let body = impl_call(cls, name, &spec);
let slf = quote! {
let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf);
};
impl_wrap_common(cls, name, spec, noargs, slf, body)
}
pub fn impl_wrap_pyslf(
cls: &syn::Type,
name: &syn::Ident,
spec: &FnSpec<'_>,
slftype: PySelfType,
noargs: bool,
) -> TokenStream {
let names = get_arg_names(spec);
let body = quote! {
#cls::#name(_slf, #(#names),*)
};
let slf = match slftype {
PySelfType::Py => quote! {
let _slf = pyo3::Py::<#cls>::from_borrowed_ptr(_slf);
},
PySelfType::PyRef => quote! {
let _slf = pyo3::PyRef::<#cls>::from_borrowed_ptr(_py, _slf);
},
PySelfType::PyRefMut => quote! {
let _slf = pyo3::PyRefMut::<#cls>::from_borrowed_ptr(_py, _slf);
},
};
impl_wrap_common(cls, name, spec, noargs, slf, body)
}
fn impl_wrap_common(
cls: &syn::Type,
name: &syn::Ident,
spec: &FnSpec<'_>,
noargs: bool,
slf: TokenStream,
body: TokenStream,
) -> TokenStream {
if spec.args.is_empty() && noargs {
quote! {
unsafe extern "C" fn __wrap(
@ -59,8 +103,7 @@ pub fn impl_wrap(
stringify!(#cls), ".", stringify!(#name), "()");
let _pool = pyo3::GILPool::new();
let _py = pyo3::Python::assume_gil_acquired();
let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf);
#slf
let _result = {
pyo3::derive_utils::IntoPyResult::into_py_result(#body)
};
@ -82,7 +125,7 @@ pub fn impl_wrap(
stringify!(#cls), ".", stringify!(#name), "()");
let _pool = pyo3::GILPool::new();
let _py = pyo3::Python::assume_gil_acquired();
let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf);
#slf
let _args = _py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs);

View file

@ -11,7 +11,7 @@ use crate::types::PyAny;
use crate::AsPyPointer;
use crate::IntoPyPointer;
use crate::Python;
use crate::{FromPyObject, IntoPyObject, ToPyObject};
use crate::{FromPyObject, IntoPyObject, PyTryFrom, ToPyObject};
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};
@ -72,6 +72,12 @@ impl<'a, T: PyTypeInfo> PyRef<'a, T> {
pub(crate) fn from_ref(r: &'a T) -> Self {
PyRef(r, PhantomData)
}
pub unsafe fn from_owned_ptr(py: Python<'a>, ptr: *mut ffi::PyObject) -> Self {
Self::from_ref(py.from_owned_ptr(ptr))
}
pub unsafe fn from_borrowed_ptr(py: Python<'a>, ptr: *mut ffi::PyObject) -> Self {
Self::from_ref(py.from_borrowed_ptr(ptr))
}
}
impl<'a, T> PyRef<'a, T>
@ -105,6 +111,15 @@ impl<'a, T: PyTypeInfo> Deref for PyRef<'a, T> {
}
}
impl<'a, T> FromPyObject<'a> for PyRef<'a, T>
where
T: PyTypeInfo,
{
fn extract(ob: &'a PyAny) -> PyResult<PyRef<'a, T>> {
T::try_from(ob).map(PyRef::from_ref).map_err(Into::into)
}
}
/// Mutable version of [`PyRef`](struct.PyRef.html).
/// # Example
/// ```
@ -135,6 +150,12 @@ impl<'a, T: PyTypeInfo> PyRefMut<'a, T> {
pub(crate) fn from_mut(t: &'a mut T) -> Self {
PyRefMut(t, PhantomData)
}
pub unsafe fn from_owned_ptr(py: Python<'a>, ptr: *mut ffi::PyObject) -> Self {
Self::from_mut(py.mut_from_owned_ptr(ptr))
}
pub unsafe fn from_borrowed_ptr(py: Python<'a>, ptr: *mut ffi::PyObject) -> Self {
Self::from_mut(py.mut_from_borrowed_ptr(ptr))
}
}
impl<'a, T> PyRefMut<'a, T>
@ -174,6 +195,17 @@ impl<'a, T: PyTypeInfo> DerefMut for PyRefMut<'a, T> {
}
}
impl<'a, T> FromPyObject<'a> for PyRefMut<'a, T>
where
T: PyTypeInfo,
{
fn extract(ob: &'a PyAny) -> PyResult<PyRefMut<'a, T>> {
T::try_from_mut(ob)
.map(PyRefMut::from_mut)
.map_err(Into::into)
}
}
/// Trait implements object reference extraction from python managed pointer.
pub trait AsPyRef<T: PyTypeInfo>: Sized {
/// Return reference to object.

View file

@ -240,7 +240,6 @@ impl<'p> Python<'p> {
/// Register `ffi::PyObject` pointer in release pool,
/// and do unchecked downcast to specific type.
pub unsafe fn from_owned_ptr<T>(self, ptr: *mut ffi::PyObject) -> &'p T
where
T: PyTypeInfo,

78
tests/test_nested_iter.rs Normal file
View file

@ -0,0 +1,78 @@
//! Rust value -> Python Iterator
//! Inspired by https://github.com/jothan/cordoba, thanks.
use pyo3;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyString};
use pyo3::PyIterProtocol;
use std::collections::HashMap;
#[macro_use]
mod common;
/// Assumes it's a file reader or so.
#[pyclass]
struct Reader {
inner: HashMap<u8, String>,
}
#[pymethods]
impl Reader {
fn get_optional(&self, test: Option<i32>) -> PyResult<i32> {
Ok(test.unwrap_or(10))
}
fn get_iter(slf: PyRef<Self>, keys: Py<PyBytes>) -> PyResult<Iter> {
Ok(Iter {
reader: slf.into(),
keys: keys,
idx: 0,
})
}
}
#[pyclass]
struct Iter {
reader: Py<Reader>,
keys: Py<PyBytes>,
idx: usize,
}
#[pyproto]
impl PyIterProtocol for Iter {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyObject> {
let py = unsafe { Python::assume_gil_acquired() };
Ok(slf.to_object(py))
}
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyObject>> {
let py = unsafe { Python::assume_gil_acquired() };
match slf.keys.as_ref(py).as_bytes().get(slf.idx) {
Some(&b) => {
let res = slf
.reader
.as_ref(py)
.inner
.get(&b)
.map(|s| PyString::new(py, s).into());
slf.idx += 1;
Ok(res)
}
None => Ok(None),
}
}
}
#[test]
fn test_nested_iter() {
let gil = Python::acquire_gil();
let py = gil.python();
let reader = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")];
let reader = Reader {
inner: reader.iter().map(|(k, v)| (*k, v.to_string())).collect(),
}
.into_object(py);
py_assert!(
py,
reader,
"list(reader.get_iter(bytes([3, 5, 2]))) == ['c', 'e', 'b']"
);
}