Merge pull request #624 from kngwyu/seq-setitem

Fix PySequenceProtocol::set_item
This commit is contained in:
Yuji Kanagawa 2019-10-19 13:44:49 +09:00 committed by GitHub
commit f6f607ef68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 64 deletions

View file

@ -212,60 +212,6 @@ where
}
}
trait PySequenceSetItemProtocolImpl {
fn sq_ass_item() -> Option<ffi::ssizeobjargproc>;
}
impl<'p, T> PySequenceSetItemProtocolImpl for T
where
T: PySequenceProtocol<'p>,
{
default fn sq_ass_item() -> Option<ffi::ssizeobjargproc> {
None
}
}
impl<T> PySequenceSetItemProtocolImpl for T
where
T: for<'p> PySequenceSetItemProtocol<'p>,
{
fn sq_ass_item() -> Option<ffi::ssizeobjargproc> {
unsafe extern "C" fn wrap<T>(
slf: *mut ffi::PyObject,
key: ffi::Py_ssize_t,
value: *mut ffi::PyObject,
) -> c_int
where
T: for<'p> PySequenceSetItemProtocol<'p>,
{
let py = Python::assume_gil_acquired();
let _pool = crate::GILPool::new(py);
let slf = py.mut_from_borrowed_ptr::<T>(slf);
let result = if value.is_null() {
Err(PyErr::new::<exceptions::NotImplementedError, _>(format!(
"Item deletion not supported by {:?}",
stringify!(T)
)))
} else {
let value = py.from_borrowed_ptr::<PyAny>(value);
match value.extract() {
Ok(value) => slf.__setitem__(key.into(), value).into(),
Err(e) => Err(e),
}
};
match result {
Ok(_) => 0,
Err(e) => {
e.restore(py);
-1
}
}
}
Some(wrap::<T>)
}
}
/// It can be possible to delete and set items (PySequenceSetItemProtocol and
/// PySequenceDelItemProtocol implemented), only to delete (PySequenceDelItemProtocol implemented)
/// or no deleting or setting is possible
@ -286,11 +232,68 @@ mod sq_ass_item_impl {
Some(del_set_item)
} else if let Some(del_item) = T::del_item() {
Some(del_item)
} else if let Some(set_item) = T::set_item() {
Some(set_item)
} else {
None
}
}
trait SetItem {
fn set_item() -> Option<ffi::ssizeobjargproc>;
}
impl<'p, T> SetItem for T
where
T: PySequenceProtocol<'p>,
{
default fn set_item() -> Option<ffi::ssizeobjargproc> {
None
}
}
impl<T> SetItem for T
where
T: for<'p> PySequenceSetItemProtocol<'p>,
{
fn set_item() -> Option<ffi::ssizeobjargproc> {
unsafe extern "C" fn wrap<T>(
slf: *mut ffi::PyObject,
key: ffi::Py_ssize_t,
value: *mut ffi::PyObject,
) -> c_int
where
T: for<'p> PySequenceSetItemProtocol<'p>,
{
let py = Python::assume_gil_acquired();
let _pool = crate::GILPool::new(py);
let slf = py.mut_from_borrowed_ptr::<T>(slf);
let result = if value.is_null() {
Err(PyErr::new::<exceptions::NotImplementedError, _>(format!(
"Item deletion is not supported by {:?}",
stringify!(T)
)))
} else {
let value = py.from_borrowed_ptr::<PyAny>(value);
match value.extract() {
Ok(value) => slf.__setitem__(key.into(), value).into(),
Err(e) => Err(e),
}
};
match result {
Ok(_) => 0,
Err(e) => {
e.restore(py);
-1
}
}
}
Some(wrap::<T>)
}
}
trait DelItem {
fn del_item() -> Option<ffi::ssizeobjargproc>;
}

View file

@ -9,6 +9,7 @@ use pyo3::prelude::*;
use pyo3::py_run;
use pyo3::types::{IntoPyDict, PyAny, PyBytes, PySlice, PyType};
use pyo3::AsPyPointer;
use std::convert::TryFrom;
use std::{isize, iter};
mod common;
@ -147,20 +148,44 @@ fn comparisons() {
}
#[pyclass]
struct Sequence {}
#[derive(Debug)]
struct Sequence {
fields: Vec<String>,
}
impl Default for Sequence {
fn default() -> Sequence {
let mut fields = vec![];
for s in &["A", "B", "C", "D", "E", "F", "G"] {
fields.push(s.to_string());
}
Sequence { fields }
}
}
#[pyproto]
impl PySequenceProtocol for Sequence {
fn __len__(&self) -> PyResult<usize> {
Ok(5)
Ok(self.fields.len())
}
fn __getitem__(&self, key: isize) -> PyResult<usize> {
let idx: usize = std::convert::TryFrom::try_from(key)?;
if idx == 5 {
return Err(PyErr::new::<IndexError, _>(()));
fn __getitem__(&self, key: isize) -> PyResult<String> {
let idx = usize::try_from(key)?;
if let Some(s) = self.fields.get(idx) {
Ok(s.clone())
} else {
Err(PyErr::new::<IndexError, _>(()))
}
}
fn __setitem__(&mut self, idx: isize, value: String) -> PyResult<()> {
let idx = usize::try_from(idx)?;
if let Some(elem) = self.fields.get_mut(idx) {
*elem = value;
Ok(())
} else {
Err(PyErr::new::<IndexError, _>(()))
}
Ok(idx)
}
}
@ -169,9 +194,17 @@ fn sequence() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = Py::new(py, Sequence {}).unwrap();
py_assert!(py, c, "list(c) == [0, 1, 2, 3, 4]");
py_assert!(py, c, "c[-1] == 4");
let c = Py::new(py, Sequence::default()).unwrap();
py_assert!(py, c, "list(c) == ['A', 'B', 'C', 'D', 'E', 'F', 'G']");
py_assert!(py, c, "c[-1] == 'G'");
py_run!(
py,
c,
r#"
c[0] = 'H'
assert c[0] == 'H'
"#
);
py_expect_exception!(py, c, "c['abc']", TypeError);
}