Typed PyBuffer
This commit is contained in:
parent
75b2b62dd9
commit
688021315e
|
@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
### Changed
|
||||
- Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934)
|
||||
- Call `Py_Finalize` at exit to flush buffers, etc. [#943](https://github.com/PyO3/pyo3/pull/943)
|
||||
- PyBuffer is now typed. #[951](https://github.com/PyO3/pyo3/pull/951)
|
||||
|
||||
### Removed
|
||||
- Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930)
|
||||
|
|
175
src/buffer.rs
175
src/buffer.rs
|
@ -18,8 +18,9 @@
|
|||
|
||||
//! `PyBuffer` implementation
|
||||
use crate::err::{self, PyResult};
|
||||
use crate::{exceptions, ffi, AsPyPointer, PyAny, Python};
|
||||
use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python};
|
||||
use std::ffi::CStr;
|
||||
use std::marker::PhantomData;
|
||||
use std::os::raw;
|
||||
use std::pin::Pin;
|
||||
use std::{cell, mem, ptr, slice};
|
||||
|
@ -27,12 +28,12 @@ use std::{cell, mem, ptr, slice};
|
|||
/// Allows access to the underlying buffer used by a python object such as `bytes`, `bytearray` or `array.array`.
|
||||
// use Pin<Box> because Python expects that the Py_buffer struct has a stable memory address
|
||||
#[repr(transparent)]
|
||||
pub struct PyBuffer(Pin<Box<ffi::Py_buffer>>);
|
||||
pub struct PyBuffer<T>(Pin<Box<ffi::Py_buffer>>, PhantomData<T>);
|
||||
|
||||
// PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists.
|
||||
// Accessing the buffer contents is protected using the GIL.
|
||||
unsafe impl Send for PyBuffer {}
|
||||
unsafe impl Sync for PyBuffer {}
|
||||
unsafe impl<T> Send for PyBuffer<T> {}
|
||||
unsafe impl<T> Sync for PyBuffer<T> {}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum ElementType {
|
||||
|
@ -146,29 +147,51 @@ fn is_matching_endian(c: u8) -> bool {
|
|||
}
|
||||
|
||||
/// Trait implemented for possible element types of `PyBuffer`.
|
||||
pub unsafe trait Element {
|
||||
pub unsafe trait Element: Copy {
|
||||
/// Gets whether the element specified in the format string is potentially compatible.
|
||||
/// Alignment and size are checked separately from this function.
|
||||
fn is_compatible_format(format: &CStr) -> bool;
|
||||
}
|
||||
|
||||
fn validate(b: &ffi::Py_buffer) {
|
||||
fn validate(b: &ffi::Py_buffer) -> PyResult<()> {
|
||||
// shape and stride information must be provided when we use PyBUF_FULL_RO
|
||||
assert!(!b.shape.is_null());
|
||||
assert!(!b.strides.is_null());
|
||||
if b.shape.is_null() {
|
||||
return Err(exceptions::BufferError::py_err("Shape is Null"));
|
||||
}
|
||||
if b.strides.is_null() {
|
||||
return Err(exceptions::BufferError::py_err("PyBuffer: Strides is Null"));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl PyBuffer {
|
||||
impl<'source, T: Element> FromPyObject<'source> for PyBuffer<T> {
|
||||
fn extract(obj: &PyAny) -> PyResult<PyBuffer<T>> {
|
||||
Self::get(obj)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Element> PyBuffer<T> {
|
||||
/// Get the underlying buffer from the specified python object.
|
||||
pub fn get(py: Python, obj: &PyAny) -> PyResult<PyBuffer> {
|
||||
pub fn get(obj: &PyAny) -> PyResult<PyBuffer<T>> {
|
||||
unsafe {
|
||||
let mut buf = Box::pin(mem::zeroed::<ffi::Py_buffer>());
|
||||
let mut buf = Box::pin(ffi::Py_buffer::new());
|
||||
err::error_on_minusone(
|
||||
py,
|
||||
obj.py(),
|
||||
ffi::PyObject_GetBuffer(obj.as_ptr(), &mut *buf, ffi::PyBUF_FULL_RO),
|
||||
)?;
|
||||
validate(&buf);
|
||||
Ok(PyBuffer(buf))
|
||||
validate(&buf)?;
|
||||
let buf = PyBuffer(buf, PhantomData);
|
||||
// Type Check
|
||||
if mem::size_of::<T>() == buf.item_size()
|
||||
&& (buf.0.buf as usize) % mem::align_of::<T>() == 0
|
||||
&& T::is_compatible_format(buf.format())
|
||||
{
|
||||
Ok(buf)
|
||||
} else {
|
||||
Err(exceptions::BufferError::py_err(
|
||||
"Incompatible type as buffer",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -307,12 +330,8 @@ impl PyBuffer {
|
|||
///
|
||||
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
|
||||
/// to modify the values in the slice.
|
||||
pub fn as_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
|
||||
if mem::size_of::<T>() == self.item_size()
|
||||
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
|
||||
&& self.is_c_contiguous()
|
||||
&& T::is_compatible_format(self.format())
|
||||
{
|
||||
pub fn as_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
|
||||
if self.is_c_contiguous() {
|
||||
unsafe {
|
||||
Some(slice::from_raw_parts(
|
||||
self.0.buf as *mut ReadOnlyCell<T>,
|
||||
|
@ -334,13 +353,8 @@ impl PyBuffer {
|
|||
///
|
||||
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
|
||||
/// to modify the values in the slice.
|
||||
pub fn as_mut_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
|
||||
if !self.readonly()
|
||||
&& mem::size_of::<T>() == self.item_size()
|
||||
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
|
||||
&& self.is_c_contiguous()
|
||||
&& T::is_compatible_format(self.format())
|
||||
{
|
||||
pub fn as_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
|
||||
if !self.readonly() && self.is_c_contiguous() {
|
||||
unsafe {
|
||||
Some(slice::from_raw_parts(
|
||||
self.0.buf as *mut cell::Cell<T>,
|
||||
|
@ -361,15 +375,8 @@ impl PyBuffer {
|
|||
///
|
||||
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
|
||||
/// to modify the values in the slice.
|
||||
pub fn as_fortran_slice<'a, T: Element>(
|
||||
&'a self,
|
||||
_py: Python<'a>,
|
||||
) -> Option<&'a [ReadOnlyCell<T>]> {
|
||||
if mem::size_of::<T>() == self.item_size()
|
||||
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
|
||||
&& self.is_fortran_contiguous()
|
||||
&& T::is_compatible_format(self.format())
|
||||
{
|
||||
pub fn as_fortran_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
|
||||
if mem::size_of::<T>() == self.item_size() && self.is_fortran_contiguous() {
|
||||
unsafe {
|
||||
Some(slice::from_raw_parts(
|
||||
self.0.buf as *mut ReadOnlyCell<T>,
|
||||
|
@ -391,16 +398,8 @@ impl PyBuffer {
|
|||
///
|
||||
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
|
||||
/// to modify the values in the slice.
|
||||
pub fn as_fortran_mut_slice<'a, T: Element>(
|
||||
&'a self,
|
||||
_py: Python<'a>,
|
||||
) -> Option<&'a [cell::Cell<T>]> {
|
||||
if !self.readonly()
|
||||
&& mem::size_of::<T>() == self.item_size()
|
||||
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
|
||||
&& self.is_fortran_contiguous()
|
||||
&& T::is_compatible_format(self.format())
|
||||
{
|
||||
pub fn as_fortran_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
|
||||
if !self.readonly() && self.is_fortran_contiguous() {
|
||||
unsafe {
|
||||
Some(slice::from_raw_parts(
|
||||
self.0.buf as *mut cell::Cell<T>,
|
||||
|
@ -421,7 +420,7 @@ impl PyBuffer {
|
|||
/// To check whether the buffer format is compatible before calling this method,
|
||||
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
|
||||
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
|
||||
pub fn copy_to_slice<T: Element + Copy>(&self, py: Python, target: &mut [T]) -> PyResult<()> {
|
||||
pub fn copy_to_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
|
||||
self.copy_to_slice_impl(py, target, b'C')
|
||||
}
|
||||
|
||||
|
@ -434,28 +433,16 @@ impl PyBuffer {
|
|||
/// To check whether the buffer format is compatible before calling this method,
|
||||
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
|
||||
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
|
||||
pub fn copy_to_fortran_slice<T: Element + Copy>(
|
||||
&self,
|
||||
py: Python,
|
||||
target: &mut [T],
|
||||
) -> PyResult<()> {
|
||||
pub fn copy_to_fortran_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
|
||||
self.copy_to_slice_impl(py, target, b'F')
|
||||
}
|
||||
|
||||
fn copy_to_slice_impl<T: Element + Copy>(
|
||||
&self,
|
||||
py: Python,
|
||||
target: &mut [T],
|
||||
fort: u8,
|
||||
) -> PyResult<()> {
|
||||
fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> {
|
||||
if mem::size_of_val(target) != self.len_bytes() {
|
||||
return Err(exceptions::BufferError::py_err(
|
||||
"Slice length does not match buffer length.",
|
||||
));
|
||||
}
|
||||
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
|
||||
return incompatible_format_error();
|
||||
}
|
||||
unsafe {
|
||||
err::error_on_minusone(
|
||||
py,
|
||||
|
@ -473,7 +460,7 @@ impl PyBuffer {
|
|||
/// If the buffer is multi-dimensional, the elements are written in C-style order.
|
||||
///
|
||||
/// Fails if the buffer format is not compatible with type `T`.
|
||||
pub fn to_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
|
||||
pub fn to_vec(&self, py: Python) -> PyResult<Vec<T>> {
|
||||
self.to_vec_impl(py, b'C')
|
||||
}
|
||||
|
||||
|
@ -481,15 +468,11 @@ impl PyBuffer {
|
|||
/// If the buffer is multi-dimensional, the elements are written in Fortran-style order.
|
||||
///
|
||||
/// Fails if the buffer format is not compatible with type `T`.
|
||||
pub fn to_fortran_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
|
||||
pub fn to_fortran_vec(&self, py: Python) -> PyResult<Vec<T>> {
|
||||
self.to_vec_impl(py, b'F')
|
||||
}
|
||||
|
||||
fn to_vec_impl<T: Element + Copy>(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
|
||||
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
|
||||
incompatible_format_error()?;
|
||||
unreachable!();
|
||||
}
|
||||
fn to_vec_impl(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
|
||||
let item_count = self.item_count();
|
||||
let mut vec: Vec<T> = Vec::with_capacity(item_count);
|
||||
unsafe {
|
||||
|
@ -520,7 +503,7 @@ impl PyBuffer {
|
|||
/// To check whether the buffer format is compatible before calling this method,
|
||||
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
|
||||
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
|
||||
pub fn copy_from_slice<T: Element + Copy>(&self, py: Python, source: &[T]) -> PyResult<()> {
|
||||
pub fn copy_from_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
|
||||
self.copy_from_slice_impl(py, source, b'C')
|
||||
}
|
||||
|
||||
|
@ -534,20 +517,11 @@ impl PyBuffer {
|
|||
/// To check whether the buffer format is compatible before calling this method,
|
||||
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
|
||||
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
|
||||
pub fn copy_from_fortran_slice<T: Element + Copy>(
|
||||
&self,
|
||||
py: Python,
|
||||
source: &[T],
|
||||
) -> PyResult<()> {
|
||||
pub fn copy_from_fortran_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
|
||||
self.copy_from_slice_impl(py, source, b'F')
|
||||
}
|
||||
|
||||
fn copy_from_slice_impl<T: Element + Copy>(
|
||||
&self,
|
||||
py: Python,
|
||||
source: &[T],
|
||||
fort: u8,
|
||||
) -> PyResult<()> {
|
||||
fn copy_from_slice_impl(&self, py: Python, source: &[T], fort: u8) -> PyResult<()> {
|
||||
if self.readonly() {
|
||||
return buffer_readonly_error();
|
||||
}
|
||||
|
@ -556,9 +530,6 @@ impl PyBuffer {
|
|||
"Slice length does not match buffer length.",
|
||||
));
|
||||
}
|
||||
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
|
||||
return incompatible_format_error();
|
||||
}
|
||||
unsafe {
|
||||
err::error_on_minusone(
|
||||
py,
|
||||
|
@ -589,19 +560,14 @@ impl PyBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
fn incompatible_format_error() -> PyResult<()> {
|
||||
Err(exceptions::BufferError::py_err(
|
||||
"Slice type is incompatible with buffer format.",
|
||||
))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn buffer_readonly_error() -> PyResult<()> {
|
||||
Err(exceptions::BufferError::py_err(
|
||||
"Cannot write to read-only buffer.",
|
||||
))
|
||||
}
|
||||
|
||||
impl Drop for PyBuffer {
|
||||
impl<T> Drop for PyBuffer<T> {
|
||||
fn drop(&mut self) {
|
||||
let _gil_guard = Python::acquire_gil();
|
||||
unsafe { ffi::PyBuffer_Release(&mut *self.0) }
|
||||
|
@ -614,9 +580,9 @@ impl Drop for PyBuffer {
|
|||
/// The data cannot be modified through the reference, but other references may
|
||||
/// be modifying the data.
|
||||
#[repr(transparent)]
|
||||
pub struct ReadOnlyCell<T>(cell::UnsafeCell<T>);
|
||||
pub struct ReadOnlyCell<T: Element>(cell::UnsafeCell<T>);
|
||||
|
||||
impl<T: Copy> ReadOnlyCell<T> {
|
||||
impl<T: Element> ReadOnlyCell<T> {
|
||||
#[inline]
|
||||
pub fn get(&self) -> T {
|
||||
unsafe { *self.0.get() }
|
||||
|
@ -675,7 +641,7 @@ mod test {
|
|||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let bytes = py.eval("b'abcde'", None, None).unwrap();
|
||||
let buffer = PyBuffer::get(py, &bytes).unwrap();
|
||||
let buffer = PyBuffer::get(&bytes).unwrap();
|
||||
assert_eq!(buffer.dimensions(), 1);
|
||||
assert_eq!(buffer.item_count(), 5);
|
||||
assert_eq!(buffer.format().to_str().unwrap(), "B");
|
||||
|
@ -684,26 +650,18 @@ mod test {
|
|||
assert!(buffer.is_c_contiguous());
|
||||
assert!(buffer.is_fortran_contiguous());
|
||||
|
||||
assert!(buffer.as_slice::<f64>(py).is_none());
|
||||
assert!(buffer.as_slice::<i8>(py).is_none());
|
||||
|
||||
let slice = buffer.as_slice::<u8>(py).unwrap();
|
||||
let slice = buffer.as_slice(py).unwrap();
|
||||
assert_eq!(slice.len(), 5);
|
||||
assert_eq!(slice[0].get(), b'a');
|
||||
assert_eq!(slice[2].get(), b'c');
|
||||
|
||||
assert!(buffer.as_mut_slice::<u8>(py).is_none());
|
||||
|
||||
assert!(buffer.copy_to_slice(py, &mut [0u8]).is_err());
|
||||
let mut arr = [0; 5];
|
||||
buffer.copy_to_slice(py, &mut arr).unwrap();
|
||||
assert_eq!(arr, b"abcde" as &[u8]);
|
||||
|
||||
assert!(buffer.copy_from_slice(py, &[0u8; 5]).is_err());
|
||||
|
||||
assert!(buffer.to_vec::<i8>(py).is_err());
|
||||
assert!(buffer.to_vec::<u16>(py).is_err());
|
||||
assert_eq!(buffer.to_vec::<u8>(py).unwrap(), b"abcde");
|
||||
assert_eq!(buffer.to_vec(py).unwrap(), b"abcde");
|
||||
}
|
||||
|
||||
#[allow(clippy::float_cmp)] // The test wants to ensure that no precision was lost on the Python round-trip
|
||||
|
@ -716,21 +674,18 @@ mod test {
|
|||
.unwrap()
|
||||
.call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None)
|
||||
.unwrap();
|
||||
let buffer = PyBuffer::get(py, array).unwrap();
|
||||
let buffer = PyBuffer::get(array).unwrap();
|
||||
assert_eq!(buffer.dimensions(), 1);
|
||||
assert_eq!(buffer.item_count(), 4);
|
||||
assert_eq!(buffer.format().to_str().unwrap(), "f");
|
||||
assert_eq!(buffer.shape(), [4]);
|
||||
|
||||
assert!(buffer.as_slice::<f64>(py).is_none());
|
||||
assert!(buffer.as_slice::<i32>(py).is_none());
|
||||
|
||||
let slice = buffer.as_slice::<f32>(py).unwrap();
|
||||
let slice = buffer.as_slice(py).unwrap();
|
||||
assert_eq!(slice.len(), 4);
|
||||
assert_eq!(slice[0].get(), 1.0);
|
||||
assert_eq!(slice[3].get(), 2.5);
|
||||
|
||||
let mut_slice = buffer.as_mut_slice::<f32>(py).unwrap();
|
||||
let mut_slice = buffer.as_mut_slice(py).unwrap();
|
||||
assert_eq!(mut_slice.len(), 4);
|
||||
assert_eq!(mut_slice[0].get(), 1.0);
|
||||
mut_slice[3].set(2.75);
|
||||
|
@ -741,6 +696,6 @@ mod test {
|
|||
.unwrap();
|
||||
assert_eq!(slice[2].get(), 12.0);
|
||||
|
||||
assert_eq!(buffer.to_vec::<f32>(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
|
||||
assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -133,8 +133,8 @@ pub type objobjargproc =
|
|||
#[cfg(not(Py_LIMITED_API))]
|
||||
mod bufferinfo {
|
||||
use crate::ffi::pyport::Py_ssize_t;
|
||||
use std::mem;
|
||||
use std::os::raw::{c_char, c_int, c_void};
|
||||
use std::ptr;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
|
@ -152,10 +152,21 @@ mod bufferinfo {
|
|||
pub internal: *mut c_void,
|
||||
}
|
||||
|
||||
impl Default for Py_buffer {
|
||||
#[inline]
|
||||
fn default() -> Self {
|
||||
unsafe { mem::zeroed() }
|
||||
impl Py_buffer {
|
||||
pub const fn new() -> Self {
|
||||
Py_buffer {
|
||||
buf: ptr::null_mut(),
|
||||
obj: ptr::null_mut(),
|
||||
len: 0,
|
||||
itemsize: 0,
|
||||
readonly: 0,
|
||||
ndim: 0,
|
||||
format: ptr::null_mut(),
|
||||
shape: ptr::null_mut(),
|
||||
strides: ptr::null_mut(),
|
||||
suboffsets: ptr::null_mut(),
|
||||
internal: ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -279,7 +279,7 @@ macro_rules! array_impls {
|
|||
fn extract(obj: &'source PyAny) -> PyResult<Self> {
|
||||
let mut array = [T::default(); $N];
|
||||
// first try buffer protocol
|
||||
if let Ok(buf) = buffer::PyBuffer::get(obj.py(), obj) {
|
||||
if let Ok(buf) = buffer::PyBuffer::get(obj) {
|
||||
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
|
||||
buf.release(obj.py());
|
||||
return Ok(array);
|
||||
|
@ -315,9 +315,9 @@ where
|
|||
{
|
||||
fn extract(obj: &'source PyAny) -> PyResult<Self> {
|
||||
// first try buffer protocol
|
||||
if let Ok(buf) = buffer::PyBuffer::get(obj.py(), obj) {
|
||||
if let Ok(buf) = buffer::PyBuffer::get(obj) {
|
||||
if buf.dimensions() == 1 {
|
||||
if let Ok(v) = buf.to_vec::<T>(obj.py()) {
|
||||
if let Ok(v) = buf.to_vec(obj.py()) {
|
||||
buf.release(obj.py());
|
||||
return Ok(v);
|
||||
}
|
||||
|
|
|
@ -114,8 +114,8 @@ fn test_buffer_referenced() {
|
|||
}
|
||||
.into_py(py);
|
||||
|
||||
let buf = PyBuffer::get(py, instance.as_ref(py)).unwrap();
|
||||
assert_eq!(buf.to_vec::<u8>(py).unwrap(), input);
|
||||
let buf = PyBuffer::<u8>::get(instance.as_ref(py)).unwrap();
|
||||
assert_eq!(buf.to_vec(py).unwrap(), input);
|
||||
drop(instance);
|
||||
buf
|
||||
};
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use pyo3::buffer::PyBuffer;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
|
||||
|
@ -20,3 +21,45 @@ fn test_optional_bool() {
|
|||
py_assert!(py, f, "f(False) == 'Some(false)'");
|
||||
py_assert!(py, f, "f(None) == 'None'");
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn buffer_inplace_add(py: Python, x: PyBuffer<i32>, y: PyBuffer<i32>) {
|
||||
let x = x.as_mut_slice(py).unwrap();
|
||||
let y = y.as_slice(py).unwrap();
|
||||
for (xi, yi) in x.iter().zip(y) {
|
||||
let xi_plus_yi = xi.get() + yi.get();
|
||||
xi.set(xi_plus_yi);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_buffer_add() {
|
||||
// Regression test for issue #932
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let f = wrap_pyfunction!(buffer_inplace_add)(py);
|
||||
|
||||
py_expect_exception!(
|
||||
py,
|
||||
f,
|
||||
r#"
|
||||
import array
|
||||
a = array.array("i", [0, 1, 2, 3])
|
||||
b = array.array("I", [0, 1, 2, 3])
|
||||
f(a, b)
|
||||
"#,
|
||||
BufferError
|
||||
);
|
||||
|
||||
pyo3::py_run!(
|
||||
py,
|
||||
f,
|
||||
r#"
|
||||
import array
|
||||
a = array.array("i", [0, 1, 2, 3])
|
||||
b = array.array("i", [2, 3, 4, 5])
|
||||
f(a, b)
|
||||
assert a, array.array("i", [2, 4, 6, 8])
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue