pymethods: more tests for magic methods

This commit is contained in:
David Hewitt 2022-02-15 22:51:37 +00:00
parent 9704c862fa
commit 6af47c78f1
4 changed files with 109 additions and 20 deletions

View file

@ -51,6 +51,8 @@ impl PyMethodKind {
"__anext__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ANEXT__)),
"__len__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__LEN__)),
"__contains__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONTAINS__)),
"__concat__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONCAT__)),
"__repeat__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPEAT__)),
"__getitem__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETITEM__)),
"__pos__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__POS__)),
"__neg__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEG__)),
@ -602,6 +604,8 @@ const __LEN__: SlotDef = SlotDef::new("Py_mp_length", "lenfunc").ret_ty(Ty::PySs
const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc")
.arguments(&[Ty::Object])
.ret_ty(Ty::Int);
const __CONCAT__: SlotDef = SlotDef::new("Py_sq_concat", "binaryfunc").arguments(&[Ty::Object]);
const __REPEAT__: SlotDef = SlotDef::new("Py_sq_repeat", "ssizeargfunc").arguments(&[Ty::PySsizeT]);
const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]);
const __POS__: SlotDef = SlotDef::new("Py_nb_positive", "unaryfunc");

View file

@ -52,6 +52,39 @@ fn unary_arithmetic() {
py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'");
}
#[pyclass]
struct Indexable(i32);
#[pymethods]
impl Indexable {
fn __index__(&self) -> i32 {
self.0
}
fn __int__(&self) -> i32 {
self.0
}
fn __float__(&self) -> f64 {
f64::from(self.0)
}
fn __invert__(&self) -> Self {
Self(!self.0)
}
}
#[test]
fn indexable() {
Python::with_gil(|py| {
let i = PyCell::new(py, Indexable(5)).unwrap();
py_run!(py, i, "assert int(i) == 5");
py_run!(py, i, "assert [0, 1, 2, 3, 4, 5][i] == 5");
py_run!(py, i, "assert float(i) == 5.0");
py_run!(py, i, "assert int(~i) == -6");
})
}
#[pyclass]
struct InPlaceOperations {
value: u32,

View file

@ -608,6 +608,7 @@ fn getattr_doesnt_override_member() {
/// Wraps a Python future and yield it once.
#[pyclass]
#[derive(Debug)]
struct OnceFuture {
future: PyObject,
polled: bool,
@ -645,24 +646,20 @@ fn test_await() {
let gil = Python::acquire_gil();
let py = gil.python();
let once = py.get_type::<OnceFuture>();
let source = pyo3::indoc::indoc!(
r#"
let source = r#"
import asyncio
import sys
async def main():
res = await Once(await asyncio.sleep(0.1))
return res
assert res is None
# For an odd error similar to https://bugs.python.org/issue38563
if sys.platform == "win32" and sys.version_info >= (3, 8, 0):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# get_event_loop can raise an error: https://github.com/PyO3/pyo3/pull/961#issuecomment-645238579
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
assert loop.run_until_complete(main()) is None
loop.close()
"#
);
asyncio.run(main())
"#;
let globals = PyModule::import(py, "__main__").unwrap().dict();
globals.set_item("Once", once).unwrap();
py.run(source, Some(globals), None)
@ -670,6 +667,62 @@ loop.close()
.unwrap();
}
#[pyclass]
struct AsyncIterator {
future: Option<Py<OnceFuture>>,
}
#[pymethods]
impl AsyncIterator {
#[new]
fn new(future: Py<OnceFuture>) -> Self {
Self {
future: Some(future),
}
}
fn __aiter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
fn __anext__(&mut self) -> Option<Py<OnceFuture>> {
self.future.take()
}
}
#[test]
fn test_anext_aiter() {
let gil = Python::acquire_gil();
let py = gil.python();
let once = py.get_type::<OnceFuture>();
let source = r#"
import asyncio
import sys
async def main():
count = 0
async for result in AsyncIterator(Once(await asyncio.sleep(0.1))):
# The Once is awaited as part of the `async for` and produces None
assert result is None
count +=1
assert count == 1
# For an odd error similar to https://bugs.python.org/issue38563
if sys.platform == "win32" and sys.version_info >= (3, 8, 0):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(main())
"#;
let globals = PyModule::import(py, "__main__").unwrap().dict();
globals.set_item("Once", once).unwrap();
globals
.set_item("AsyncIterator", py.get_type::<AsyncIterator>())
.unwrap();
py.run(source, Some(globals), None)
.map_err(|e| e.print(py))
.unwrap();
}
/// Increment the count when `__get__` is called.
#[pyclass]
struct DescrCounter {

View file

@ -1,7 +1,5 @@
#![cfg(feature = "macros")]
#![cfg(feature = "pyproto")] // FIXME: change this to use #[pymethods] once supports sequence protocol
use pyo3::class::PySequenceProtocol;
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyList};
@ -32,10 +30,7 @@ impl ByteSequence {
})
}
}
}
#[pyproto]
impl PySequenceProtocol for ByteSequence {
fn __len__(&self) -> usize {
self.elements.len()
}
@ -51,8 +46,12 @@ impl PySequenceProtocol for ByteSequence {
self.elements[idx as usize] = value;
}
fn __delitem__(&mut self, idx: isize) -> PyResult<()> {
if (idx < self.elements.len() as isize) && (idx >= 0) {
fn __delitem__(&mut self, mut idx: isize) -> PyResult<()> {
let self_len = self.elements.len() as isize;
if idx < 0 {
idx += self_len;
}
if (idx < self_len) && (idx >= 0) {
self.elements.remove(idx as usize);
Ok(())
} else {
@ -67,7 +66,7 @@ impl PySequenceProtocol for ByteSequence {
}
}
fn __concat__(&self, other: PyRef<'p, Self>) -> Self {
fn __concat__(&self, other: PyRef<Self>) -> Self {
let mut elements = self.elements.clone();
elements.extend_from_slice(&other.elements);
Self { elements }
@ -274,8 +273,8 @@ struct OptionList {
items: Vec<Option<i64>>,
}
#[pyproto]
impl PySequenceProtocol for OptionList {
#[pymethods]
impl OptionList {
fn __getitem__(&self, idx: isize) -> PyResult<Option<i64>> {
match self.items.get(idx as usize) {
Some(x) => Ok(*x),