use pyo3::class::{ PyAsyncProtocol, PyContextProtocol, PyDescrProtocol, PyIterProtocol, PyMappingProtocol, PyObjectProtocol, PySequenceProtocol, }; use pyo3::exceptions::{IndexError, ValueError}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyBytes, PySlice, PyType}; use pyo3::{ffi, py_run, AsPyPointer, PyCell}; use std::convert::TryFrom; use std::{isize, iter}; mod common; #[pyclass] pub struct Len { l: usize, } #[pyproto] impl PyMappingProtocol for Len { fn __len__(&self) -> PyResult { Ok(self.l) } } #[test] fn len() { let gil = Python::acquire_gil(); let py = gil.python(); let inst = Py::new(py, Len { l: 10 }).unwrap(); py_assert!(py, inst, "len(inst) == 10"); unsafe { assert_eq!(ffi::PyObject_Size(inst.as_ptr()), 10); assert_eq!(ffi::PyMapping_Size(inst.as_ptr()), 10); } let inst = Py::new( py, Len { l: (isize::MAX as usize) + 1, }, ) .unwrap(); py_expect_exception!(py, inst, "len(inst)", OverflowError); } #[pyclass] struct Iterator { iter: Box + Send>, } #[pyproto] impl<'p> PyIterProtocol for Iterator { fn __iter__(slf: PyRef<'p, Self>) -> PyResult> { Ok(slf.into()) } fn __next__(mut slf: PyRefMut<'p, Self>) -> PyResult> { Ok(slf.iter.next()) } } #[test] fn iterator() { let gil = Python::acquire_gil(); let py = gil.python(); let inst = Py::new( py, Iterator { iter: Box::new(5..8), }, ) .unwrap(); py_assert!(py, inst, "iter(inst) is inst"); py_assert!(py, inst, "list(inst) == [5, 6, 7]"); } #[pyclass] struct StringMethods {} #[pyproto] impl<'p> PyObjectProtocol<'p> for StringMethods { fn __str__(&self) -> PyResult<&'static str> { Ok("str") } fn __repr__(&self) -> PyResult<&'static str> { Ok("repr") } fn __format__(&self, format_spec: String) -> PyResult { Ok(format!("format({})", format_spec)) } fn __bytes__(&self) -> PyResult { let gil = GILGuard::acquire(); Ok(PyBytes::new(gil.python(), b"bytes").into()) } } #[test] fn string_methods() { let gil = Python::acquire_gil(); let py = gil.python(); let obj = Py::new(py, StringMethods {}).unwrap(); py_assert!(py, obj, "str(obj) == 'str'"); py_assert!(py, obj, "repr(obj) == 'repr'"); py_assert!(py, obj, "'{0:x}'.format(obj) == 'format(x)'"); py_assert!(py, obj, "bytes(obj) == b'bytes'"); } #[pyclass] struct Comparisons { val: i32, } #[pyproto] impl PyObjectProtocol for Comparisons { fn __hash__(&self) -> PyResult { Ok(self.val as isize) } fn __bool__(&self) -> PyResult { Ok(self.val != 0) } } #[test] fn comparisons() { let gil = Python::acquire_gil(); let py = gil.python(); let zero = Py::new(py, Comparisons { val: 0 }).unwrap(); let one = Py::new(py, Comparisons { val: 1 }).unwrap(); let ten = Py::new(py, Comparisons { val: 10 }).unwrap(); let minus_one = Py::new(py, Comparisons { val: -1 }).unwrap(); py_assert!(py, one, "hash(one) == 1"); py_assert!(py, ten, "hash(ten) == 10"); py_assert!(py, minus_one, "hash(minus_one) == -2"); py_assert!(py, one, "bool(one) is True"); py_assert!(py, zero, "not zero"); } #[pyclass] #[derive(Debug)] struct Sequence { fields: Vec, } impl Default for Sequence { fn default() -> Sequence { let mut fields = vec![]; for &s in &["A", "B", "C", "D", "E", "F", "G"] { fields.push(s.to_string()); } Sequence { fields } } } #[pyproto] impl PySequenceProtocol for Sequence { fn __len__(&self) -> PyResult { Ok(self.fields.len()) } fn __getitem__(&self, key: isize) -> PyResult { let idx = usize::try_from(key)?; if let Some(s) = self.fields.get(idx) { Ok(s.clone()) } else { Err(PyErr::new::(())) } } fn __setitem__(&mut self, idx: isize, value: String) -> PyResult<()> { let idx = usize::try_from(idx)?; if let Some(elem) = self.fields.get_mut(idx) { *elem = value; Ok(()) } else { Err(PyErr::new::(())) } } } #[test] fn sequence() { let gil = Python::acquire_gil(); let py = gil.python(); let c = Py::new(py, Sequence::default()).unwrap(); py_assert!(py, c, "list(c) == ['A', 'B', 'C', 'D', 'E', 'F', 'G']"); py_assert!(py, c, "c[-1] == 'G'"); py_run!( py, c, r#" c[0] = 'H' assert c[0] == 'H' "# ); py_expect_exception!(py, c, "c['abc']", TypeError); } #[pyclass] struct Callable {} #[pymethods] impl Callable { #[__call__] fn __call__(&self, arg: i32) -> PyResult { Ok(arg * 6) } } #[test] fn callable() { let gil = Python::acquire_gil(); let py = gil.python(); let c = Py::new(py, Callable {}).unwrap(); py_assert!(py, c, "callable(c)"); py_assert!(py, c, "c(7) == 42"); let nc = Py::new(py, Comparisons { val: 0 }).unwrap(); py_assert!(py, nc, "not callable(nc)"); } #[pyclass] #[derive(Debug)] struct SetItem { key: i32, val: i32, } #[pyproto] impl PyMappingProtocol<'a> for SetItem { fn __setitem__(&mut self, key: i32, val: i32) -> PyResult<()> { self.key = key; self.val = val; Ok(()) } } #[test] fn setitem() { let gil = Python::acquire_gil(); let py = gil.python(); let c = PyCell::new(py, SetItem { key: 0, val: 0 }).unwrap(); py_run!(py, c, "c[1] = 2"); { let c = c.borrow(); assert_eq!(c.key, 1); assert_eq!(c.val, 2); } py_expect_exception!(py, c, "del c[1]", NotImplementedError); } #[pyclass] struct DelItem { key: i32, } #[pyproto] impl PyMappingProtocol<'a> for DelItem { fn __delitem__(&mut self, key: i32) -> PyResult<()> { self.key = key; Ok(()) } } #[test] fn delitem() { let gil = Python::acquire_gil(); let py = gil.python(); let c = PyCell::new(py, DelItem { key: 0 }).unwrap(); py_run!(py, c, "del c[1]"); { let c = c.borrow(); assert_eq!(c.key, 1); } py_expect_exception!(py, c, "c[1] = 2", NotImplementedError); } #[pyclass] struct SetDelItem { val: Option, } #[pyproto] impl PyMappingProtocol for SetDelItem { fn __setitem__(&mut self, _key: i32, val: i32) -> PyResult<()> { self.val = Some(val); Ok(()) } fn __delitem__(&mut self, _key: i32) -> PyResult<()> { self.val = None; Ok(()) } } #[test] fn setdelitem() { let gil = Python::acquire_gil(); let py = gil.python(); let c = PyCell::new(py, SetDelItem { val: None }).unwrap(); py_run!(py, c, "c[1] = 2"); { let c = c.borrow(); assert_eq!(c.val, Some(2)); } py_run!(py, c, "del c[1]"); let c = c.borrow(); assert_eq!(c.val, None); } #[pyclass] struct Reversed {} #[pyproto] impl PyMappingProtocol for Reversed { fn __reversed__(&self) -> PyResult<&'static str> { Ok("I am reversed") } } #[test] fn reversed() { let gil = Python::acquire_gil(); let py = gil.python(); let c = Py::new(py, Reversed {}).unwrap(); py_run!(py, c, "assert reversed(c) == 'I am reversed'"); } #[pyclass] struct Contains {} #[pyproto] impl PySequenceProtocol for Contains { fn __contains__(&self, item: i32) -> PyResult { Ok(item >= 0) } } #[test] fn contains() { let gil = Python::acquire_gil(); let py = gil.python(); let c = Py::new(py, Contains {}).unwrap(); py_run!(py, c, "assert 1 in c"); py_run!(py, c, "assert -1 not in c"); py_expect_exception!(py, c, "assert 'wrong type' not in c", TypeError); } #[pyclass] struct ContextManager { exit_called: bool, } #[pyproto] impl<'p> PyContextProtocol<'p> for ContextManager { fn __enter__(&mut self) -> PyResult { Ok(42) } fn __exit__( &mut self, ty: Option<&'p PyType>, _value: Option<&'p PyAny>, _traceback: Option<&'p PyAny>, ) -> PyResult { let gil = GILGuard::acquire(); self.exit_called = true; if ty == Some(gil.python().get_type::()) { Ok(true) } else { Ok(false) } } } #[test] fn context_manager() { let gil = Python::acquire_gil(); let py = gil.python(); let c = PyCell::new(py, ContextManager { exit_called: false }).unwrap(); py_run!(py, c, "with c as x: assert x == 42"); { let mut c = c.borrow_mut(); assert!(c.exit_called); c.exit_called = false; } py_run!(py, c, "with c as x: raise ValueError"); { let mut c = c.borrow_mut(); assert!(c.exit_called); c.exit_called = false; } py_expect_exception!( py, c, "with c as x: raise NotImplementedError", NotImplementedError ); let c = c.borrow(); assert!(c.exit_called); } #[test] fn test_basics() { let gil = Python::acquire_gil(); let py = gil.python(); let v = PySlice::new(py, 1, 10, 2); let indices = v.indices(100).unwrap(); assert_eq!(1, indices.start); assert_eq!(10, indices.stop); assert_eq!(2, indices.step); assert_eq!(5, indices.slicelength); } #[pyclass] struct Test {} #[pyproto] impl<'p> PyMappingProtocol<'p> for Test { fn __getitem__(&self, idx: &PyAny) -> PyResult { let gil = GILGuard::acquire(); if let Ok(slice) = idx.cast_as::() { let indices = slice.indices(1000)?; if indices.start == 100 && indices.stop == 200 && indices.step == 1 { return Ok("slice".into_py(gil.python())); } } else if let Ok(idx) = idx.extract::() { if idx == 1 { return Ok("int".into_py(gil.python())); } } Err(PyErr::new::("error")) } } #[test] fn test_cls_impl() { let gil = Python::acquire_gil(); let py = gil.python(); let ob = Py::new(py, Test {}).unwrap(); let d = [("ob", ob)].into_py_dict(py); py.run("assert ob[1] == 'int'", None, Some(d)).unwrap(); py.run("assert ob[100:200:1] == 'slice'", None, Some(d)) .unwrap(); } #[pyclass(dict)] struct DunderDictSupport {} #[test] fn dunder_dict_support() { let gil = Python::acquire_gil(); let py = gil.python(); let inst = PyCell::new(py, DunderDictSupport {}).unwrap(); py_run!( py, inst, r#" inst.a = 1 assert inst.a == 1 "# ); } #[test] fn access_dunder_dict() { let gil = Python::acquire_gil(); let py = gil.python(); let inst = PyCell::new(py, DunderDictSupport {}).unwrap(); py_run!( py, inst, r#" inst.a = 1 assert inst.__dict__ == {'a': 1} "# ); } // If the base class has dict support, child class also has dict #[pyclass(extends=DunderDictSupport)] struct InheritDict { _value: usize, } #[test] fn inherited_dict() { let gil = Python::acquire_gil(); let py = gil.python(); let inst = PyCell::new(py, (InheritDict { _value: 0 }, DunderDictSupport {})).unwrap(); py_run!( py, inst, r#" inst.a = 1 assert inst.__dict__ == {'a': 1} "# ); } #[pyclass(weakref, dict)] struct WeakRefDunderDictSupport {} #[test] fn weakref_dunder_dict_support() { let gil = Python::acquire_gil(); let py = gil.python(); let inst = PyCell::new(py, WeakRefDunderDictSupport {}).unwrap(); py_run!( py, inst, "import weakref; assert weakref.ref(inst)() is inst; inst.a = 1; assert inst.a == 1" ); } #[pyclass] struct ClassWithGetAttr { #[pyo3(get, set)] data: u32, } #[pyproto] impl PyObjectProtocol for ClassWithGetAttr { fn __getattr__(&self, _name: &str) -> PyResult { Ok(self.data * 2) } } #[test] fn getattr_doesnt_override_member() { let gil = Python::acquire_gil(); let py = gil.python(); let inst = PyCell::new(py, ClassWithGetAttr { data: 4 }).unwrap(); py_assert!(py, inst, "inst.data == 4"); py_assert!(py, inst, "inst.a == 8"); } /// Wraps a Python future and yield it once. #[pyclass] struct OnceFuture { future: PyObject, polled: bool, } #[pymethods] impl OnceFuture { #[new] fn new(future: PyObject) -> Self { OnceFuture { future, polled: false, } } } #[pyproto] impl PyAsyncProtocol for OnceFuture { fn __await__(slf: PyRef<'p, Self>) -> PyResult> { Ok(slf) } } #[pyproto] impl PyIterProtocol for OnceFuture { fn __iter__(slf: PyRef<'p, Self>) -> PyResult> { Ok(slf) } fn __next__(mut slf: PyRefMut) -> PyResult> { if !slf.polled { slf.polled = true; Ok(Some(slf.future.clone())) } else { Ok(None) } } } #[test] fn test_await() { let gil = Python::acquire_gil(); let py = gil.python(); let once = py.get_type::(); let source = pyo3::indoc::indoc!( r#" import asyncio import sys async def main(): res = await Once(await asyncio.sleep(0.1)) return res # For an odd error similar to https://bugs.python.org/issue38563 if sys.platform == "win32" and sys.version_info >= (3, 8, 0): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # get_event_loop can raise an error: https://github.com/PyO3/pyo3/pull/961#issuecomment-645238579 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) assert loop.run_until_complete(main()) is None loop.close() "# ); let globals = PyModule::import(py, "__main__").unwrap().dict(); globals.set_item("Once", once).unwrap(); py.run(source, Some(globals), None) .map_err(|e| e.print(py)) .unwrap(); } /// Increment the count when `__get__` is called. #[pyclass] struct DescrCounter { #[pyo3(get)] count: usize, } #[pymethods] impl DescrCounter { #[new] fn new() -> Self { DescrCounter { count: 0 } } } #[pyproto] impl PyDescrProtocol for DescrCounter { fn __get__( mut slf: PyRefMut<'p, Self>, _instance: &PyAny, _owner: Option<&'p PyType>, ) -> PyResult> { slf.count += 1; Ok(slf) } fn __set__( _slf: PyRef<'p, Self>, _instance: &PyAny, mut new_value: PyRefMut<'p, Self>, ) -> PyResult<()> { new_value.count = _slf.count; Ok(()) } } #[test] fn descr_getset() { let gil = Python::acquire_gil(); let py = gil.python(); let counter = py.get_type::(); let source = pyo3::indoc::indoc!( r#" class Class: counter = Counter() c = Class() c.counter # count += 1 assert c.counter.count == 2 c.counter = Counter() assert c.counter.count == 3 "# ); let globals = PyModule::import(py, "__main__").unwrap().dict(); globals.set_item("Counter", counter).unwrap(); py.run(source, Some(globals), None) .map_err(|e| e.print(py)) .unwrap(); }