diff --git a/pyo3cls/src/func.rs b/pyo3cls/src/func.rs index f8a6f9a6..8c574df5 100644 --- a/pyo3cls/src/func.rs +++ b/pyo3cls/src/func.rs @@ -264,16 +264,38 @@ fn get_arg_ty(sig: &syn::MethodSig, idx: usize) -> syn::Ty { // Success fn get_res_success(ty: &syn::Ty) -> (Tokens, syn::Ty) { - let result; - let mut succ = match ty { + let mut result; + let mut succ; + + match ty { &syn::Ty::Path(_, ref path) => { if let Some(segment) = path.segments.last() { match segment.ident.as_ref() { - // check result type + // check for PyResult "PyResult" => match segment.parameters { syn::PathParameters::AngleBracketed(ref data) => { result = true; - data.types[0].clone() + succ = data.types[0].clone(); + + // check for PyResult> + match data.types[0] { + syn::Ty::Path(_, ref path) => + if let Some(segment) = path.segments.last() { + match segment.ident.as_ref() { + // get T from Option + "Option" => match segment.parameters { + syn::PathParameters::AngleBracketed(ref data) => + { + result = false; + succ = data.types[0].clone(); + }, + _ => (), + }, + _ => (), + } + }, + _ => () + } }, _ => panic!("fn result type is not supported"), }, diff --git a/src/class/iter.rs b/src/class/iter.rs index 75777e07..5052a651 100644 --- a/src/class/iter.rs +++ b/src/class/iter.rs @@ -11,7 +11,7 @@ use err::PyResult; use python::Python; use token::ToInstancePtr; use typeob::PyTypeInfo; -use callback::PyObjectCallbackConverter; +use callback::{PyObjectCallbackConverter, IterNextResultConverter}; /// Iterator protocol @@ -32,7 +32,7 @@ pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> { pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> { type Success: ::IntoPyObject; - type Result: Into>; + type Result: Into>>; } @@ -91,6 +91,7 @@ impl PyIterNextProtocolImpl for T where T: for<'p> PyIterNextProtocol<'p> + T { #[inline] fn tp_iternext() -> Option { - py_unary_func!(PyIterNextProtocol, T::__next__, T::Success, PyObjectCallbackConverter) + py_unary_func!(PyIterNextProtocol, T::__next__, + Option, IterNextResultConverter) } } diff --git a/tests/test_class.rs b/tests/test_class.rs index d86b4dfa..d7b05be0 100644 --- a/tests/test_class.rs +++ b/tests/test_class.rs @@ -4,7 +4,7 @@ extern crate pyo3; use pyo3::*; -use std::{mem, isize, iter}; +use std::{isize, iter}; use std::cell::RefCell; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; @@ -428,27 +428,35 @@ fn len() { py_expect_exception!(py, inst, "len(inst)", OverflowError); } -/*py_class!(class Iterator |py| { - data iter: RefCell + Send>>; +#[py::class] +struct Iterator{ + iter: Box + Send>, + token: PyToken, +} - def __iter__(&self) -> PyResult { - Ok(self.clone_ref(py)) +#[py::ptr(Iterator)] +struct IteratorPtr(PyPtr); + +#[py::proto] +impl PyIterProtocol for Iterator { + fn __iter__(&mut self, py: Python) -> PyResult { + Ok(self.to_inst_ptr()) } - def __next__(&self) -> PyResult> { - Ok(self.iter(py).borrow_mut().next()) + fn __next__(&mut self, py: Python) -> PyResult> { + Ok(self.iter.next()) } -}); +} #[test] fn iterator() { let gil = Python::acquire_gil(); let py = gil.python(); - let inst = Iterator::create_instance(py, RefCell::new(Box::new(5..8))).unwrap(); + let inst = py.init(|t| Iterator{iter: Box::new(5..8), token: t}).unwrap(); py_assert!(py, inst, "iter(inst) is inst"); py_assert!(py, inst, "list(inst) == [5, 6, 7]"); -}*/ +} #[py::class] struct StringMethods {token: PyToken}