Typed PyBuffer

This commit is contained in:
kngwyu 2020-06-04 22:00:47 +09:00
parent 75b2b62dd9
commit 688021315e
6 changed files with 130 additions and 120 deletions

View File

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Changed
- Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934)
- Call `Py_Finalize` at exit to flush buffers, etc. [#943](https://github.com/PyO3/pyo3/pull/943)
- PyBuffer is now typed. #[951](https://github.com/PyO3/pyo3/pull/951)
### Removed
- Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930)

View File

@ -18,8 +18,9 @@
//! `PyBuffer` implementation
use crate::err::{self, PyResult};
use crate::{exceptions, ffi, AsPyPointer, PyAny, Python};
use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python};
use std::ffi::CStr;
use std::marker::PhantomData;
use std::os::raw;
use std::pin::Pin;
use std::{cell, mem, ptr, slice};
@ -27,12 +28,12 @@ use std::{cell, mem, ptr, slice};
/// Allows access to the underlying buffer used by a python object such as `bytes`, `bytearray` or `array.array`.
// use Pin<Box> because Python expects that the Py_buffer struct has a stable memory address
#[repr(transparent)]
pub struct PyBuffer(Pin<Box<ffi::Py_buffer>>);
pub struct PyBuffer<T>(Pin<Box<ffi::Py_buffer>>, PhantomData<T>);
// PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists.
// Accessing the buffer contents is protected using the GIL.
unsafe impl Send for PyBuffer {}
unsafe impl Sync for PyBuffer {}
unsafe impl<T> Send for PyBuffer<T> {}
unsafe impl<T> Sync for PyBuffer<T> {}
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum ElementType {
@ -146,29 +147,51 @@ fn is_matching_endian(c: u8) -> bool {
}
/// Trait implemented for possible element types of `PyBuffer`.
pub unsafe trait Element {
pub unsafe trait Element: Copy {
/// Gets whether the element specified in the format string is potentially compatible.
/// Alignment and size are checked separately from this function.
fn is_compatible_format(format: &CStr) -> bool;
}
fn validate(b: &ffi::Py_buffer) {
fn validate(b: &ffi::Py_buffer) -> PyResult<()> {
// shape and stride information must be provided when we use PyBUF_FULL_RO
assert!(!b.shape.is_null());
assert!(!b.strides.is_null());
if b.shape.is_null() {
return Err(exceptions::BufferError::py_err("Shape is Null"));
}
if b.strides.is_null() {
return Err(exceptions::BufferError::py_err("PyBuffer: Strides is Null"));
}
Ok(())
}
impl PyBuffer {
impl<'source, T: Element> FromPyObject<'source> for PyBuffer<T> {
fn extract(obj: &PyAny) -> PyResult<PyBuffer<T>> {
Self::get(obj)
}
}
impl<T: Element> PyBuffer<T> {
/// Get the underlying buffer from the specified python object.
pub fn get(py: Python, obj: &PyAny) -> PyResult<PyBuffer> {
pub fn get(obj: &PyAny) -> PyResult<PyBuffer<T>> {
unsafe {
let mut buf = Box::pin(mem::zeroed::<ffi::Py_buffer>());
let mut buf = Box::pin(ffi::Py_buffer::new());
err::error_on_minusone(
py,
obj.py(),
ffi::PyObject_GetBuffer(obj.as_ptr(), &mut *buf, ffi::PyBUF_FULL_RO),
)?;
validate(&buf);
Ok(PyBuffer(buf))
validate(&buf)?;
let buf = PyBuffer(buf, PhantomData);
// Type Check
if mem::size_of::<T>() == buf.item_size()
&& (buf.0.buf as usize) % mem::align_of::<T>() == 0
&& T::is_compatible_format(buf.format())
{
Ok(buf)
} else {
Err(exceptions::BufferError::py_err(
"Incompatible type as buffer",
))
}
}
}
@ -307,12 +330,8 @@ impl PyBuffer {
///
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
if mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_c_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
if self.is_c_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut ReadOnlyCell<T>,
@ -334,13 +353,8 @@ impl PyBuffer {
///
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_mut_slice<'a, T: Element>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly()
&& mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_c_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly() && self.is_c_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut cell::Cell<T>,
@ -361,15 +375,8 @@ impl PyBuffer {
///
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_fortran_slice<'a, T: Element>(
&'a self,
_py: Python<'a>,
) -> Option<&'a [ReadOnlyCell<T>]> {
if mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_fortran_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_fortran_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell<T>]> {
if mem::size_of::<T>() == self.item_size() && self.is_fortran_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut ReadOnlyCell<T>,
@ -391,16 +398,8 @@ impl PyBuffer {
///
/// The returned slice uses type `Cell<T>` because it's theoretically possible for any call into the Python runtime
/// to modify the values in the slice.
pub fn as_fortran_mut_slice<'a, T: Element>(
&'a self,
_py: Python<'a>,
) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly()
&& mem::size_of::<T>() == self.item_size()
&& (self.0.buf as usize) % mem::align_of::<T>() == 0
&& self.is_fortran_contiguous()
&& T::is_compatible_format(self.format())
{
pub fn as_fortran_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell<T>]> {
if !self.readonly() && self.is_fortran_contiguous() {
unsafe {
Some(slice::from_raw_parts(
self.0.buf as *mut cell::Cell<T>,
@ -421,7 +420,7 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_to_slice<T: Element + Copy>(&self, py: Python, target: &mut [T]) -> PyResult<()> {
pub fn copy_to_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
self.copy_to_slice_impl(py, target, b'C')
}
@ -434,28 +433,16 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// you can use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_to_fortran_slice<T: Element + Copy>(
&self,
py: Python,
target: &mut [T],
) -> PyResult<()> {
pub fn copy_to_fortran_slice(&self, py: Python, target: &mut [T]) -> PyResult<()> {
self.copy_to_slice_impl(py, target, b'F')
}
fn copy_to_slice_impl<T: Element + Copy>(
&self,
py: Python,
target: &mut [T],
fort: u8,
) -> PyResult<()> {
fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> {
if mem::size_of_val(target) != self.len_bytes() {
return Err(exceptions::BufferError::py_err(
"Slice length does not match buffer length.",
));
}
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
return incompatible_format_error();
}
unsafe {
err::error_on_minusone(
py,
@ -473,7 +460,7 @@ impl PyBuffer {
/// If the buffer is multi-dimensional, the elements are written in C-style order.
///
/// Fails if the buffer format is not compatible with type `T`.
pub fn to_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
pub fn to_vec(&self, py: Python) -> PyResult<Vec<T>> {
self.to_vec_impl(py, b'C')
}
@ -481,15 +468,11 @@ impl PyBuffer {
/// If the buffer is multi-dimensional, the elements are written in Fortran-style order.
///
/// Fails if the buffer format is not compatible with type `T`.
pub fn to_fortran_vec<T: Element + Copy>(&self, py: Python) -> PyResult<Vec<T>> {
pub fn to_fortran_vec(&self, py: Python) -> PyResult<Vec<T>> {
self.to_vec_impl(py, b'F')
}
fn to_vec_impl<T: Element + Copy>(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
incompatible_format_error()?;
unreachable!();
}
fn to_vec_impl(&self, py: Python, fort: u8) -> PyResult<Vec<T>> {
let item_count = self.item_count();
let mut vec: Vec<T> = Vec::with_capacity(item_count);
unsafe {
@ -520,7 +503,7 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_from_slice<T: Element + Copy>(&self, py: Python, source: &[T]) -> PyResult<()> {
pub fn copy_from_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
self.copy_from_slice_impl(py, source, b'C')
}
@ -534,20 +517,11 @@ impl PyBuffer {
/// To check whether the buffer format is compatible before calling this method,
/// use `<T as buffer::Element>::is_compatible_format(buf.format())`.
/// Alternatively, `match buffer::ElementType::from_format(buf.format())`.
pub fn copy_from_fortran_slice<T: Element + Copy>(
&self,
py: Python,
source: &[T],
) -> PyResult<()> {
pub fn copy_from_fortran_slice(&self, py: Python, source: &[T]) -> PyResult<()> {
self.copy_from_slice_impl(py, source, b'F')
}
fn copy_from_slice_impl<T: Element + Copy>(
&self,
py: Python,
source: &[T],
fort: u8,
) -> PyResult<()> {
fn copy_from_slice_impl(&self, py: Python, source: &[T], fort: u8) -> PyResult<()> {
if self.readonly() {
return buffer_readonly_error();
}
@ -556,9 +530,6 @@ impl PyBuffer {
"Slice length does not match buffer length.",
));
}
if !T::is_compatible_format(self.format()) || mem::size_of::<T>() != self.item_size() {
return incompatible_format_error();
}
unsafe {
err::error_on_minusone(
py,
@ -589,19 +560,14 @@ impl PyBuffer {
}
}
fn incompatible_format_error() -> PyResult<()> {
Err(exceptions::BufferError::py_err(
"Slice type is incompatible with buffer format.",
))
}
#[inline(always)]
fn buffer_readonly_error() -> PyResult<()> {
Err(exceptions::BufferError::py_err(
"Cannot write to read-only buffer.",
))
}
impl Drop for PyBuffer {
impl<T> Drop for PyBuffer<T> {
fn drop(&mut self) {
let _gil_guard = Python::acquire_gil();
unsafe { ffi::PyBuffer_Release(&mut *self.0) }
@ -614,9 +580,9 @@ impl Drop for PyBuffer {
/// The data cannot be modified through the reference, but other references may
/// be modifying the data.
#[repr(transparent)]
pub struct ReadOnlyCell<T>(cell::UnsafeCell<T>);
pub struct ReadOnlyCell<T: Element>(cell::UnsafeCell<T>);
impl<T: Copy> ReadOnlyCell<T> {
impl<T: Element> ReadOnlyCell<T> {
#[inline]
pub fn get(&self) -> T {
unsafe { *self.0.get() }
@ -675,7 +641,7 @@ mod test {
let gil = Python::acquire_gil();
let py = gil.python();
let bytes = py.eval("b'abcde'", None, None).unwrap();
let buffer = PyBuffer::get(py, &bytes).unwrap();
let buffer = PyBuffer::get(&bytes).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 5);
assert_eq!(buffer.format().to_str().unwrap(), "B");
@ -684,26 +650,18 @@ mod test {
assert!(buffer.is_c_contiguous());
assert!(buffer.is_fortran_contiguous());
assert!(buffer.as_slice::<f64>(py).is_none());
assert!(buffer.as_slice::<i8>(py).is_none());
let slice = buffer.as_slice::<u8>(py).unwrap();
let slice = buffer.as_slice(py).unwrap();
assert_eq!(slice.len(), 5);
assert_eq!(slice[0].get(), b'a');
assert_eq!(slice[2].get(), b'c');
assert!(buffer.as_mut_slice::<u8>(py).is_none());
assert!(buffer.copy_to_slice(py, &mut [0u8]).is_err());
let mut arr = [0; 5];
buffer.copy_to_slice(py, &mut arr).unwrap();
assert_eq!(arr, b"abcde" as &[u8]);
assert!(buffer.copy_from_slice(py, &[0u8; 5]).is_err());
assert!(buffer.to_vec::<i8>(py).is_err());
assert!(buffer.to_vec::<u16>(py).is_err());
assert_eq!(buffer.to_vec::<u8>(py).unwrap(), b"abcde");
assert_eq!(buffer.to_vec(py).unwrap(), b"abcde");
}
#[allow(clippy::float_cmp)] // The test wants to ensure that no precision was lost on the Python round-trip
@ -716,21 +674,18 @@ mod test {
.unwrap()
.call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None)
.unwrap();
let buffer = PyBuffer::get(py, array).unwrap();
let buffer = PyBuffer::get(array).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 4);
assert_eq!(buffer.format().to_str().unwrap(), "f");
assert_eq!(buffer.shape(), [4]);
assert!(buffer.as_slice::<f64>(py).is_none());
assert!(buffer.as_slice::<i32>(py).is_none());
let slice = buffer.as_slice::<f32>(py).unwrap();
let slice = buffer.as_slice(py).unwrap();
assert_eq!(slice.len(), 4);
assert_eq!(slice[0].get(), 1.0);
assert_eq!(slice[3].get(), 2.5);
let mut_slice = buffer.as_mut_slice::<f32>(py).unwrap();
let mut_slice = buffer.as_mut_slice(py).unwrap();
assert_eq!(mut_slice.len(), 4);
assert_eq!(mut_slice[0].get(), 1.0);
mut_slice[3].set(2.75);
@ -741,6 +696,6 @@ mod test {
.unwrap();
assert_eq!(slice[2].get(), 12.0);
assert_eq!(buffer.to_vec::<f32>(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
}
}

View File

@ -133,8 +133,8 @@ pub type objobjargproc =
#[cfg(not(Py_LIMITED_API))]
mod bufferinfo {
use crate::ffi::pyport::Py_ssize_t;
use std::mem;
use std::os::raw::{c_char, c_int, c_void};
use std::ptr;
#[repr(C)]
#[derive(Copy, Clone)]
@ -152,10 +152,21 @@ mod bufferinfo {
pub internal: *mut c_void,
}
impl Default for Py_buffer {
#[inline]
fn default() -> Self {
unsafe { mem::zeroed() }
impl Py_buffer {
pub const fn new() -> Self {
Py_buffer {
buf: ptr::null_mut(),
obj: ptr::null_mut(),
len: 0,
itemsize: 0,
readonly: 0,
ndim: 0,
format: ptr::null_mut(),
shape: ptr::null_mut(),
strides: ptr::null_mut(),
suboffsets: ptr::null_mut(),
internal: ptr::null_mut(),
}
}
}

View File

@ -279,7 +279,7 @@ macro_rules! array_impls {
fn extract(obj: &'source PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
// first try buffer protocol
if let Ok(buf) = buffer::PyBuffer::get(obj.py(), obj) {
if let Ok(buf) = buffer::PyBuffer::get(obj) {
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
return Ok(array);
@ -315,9 +315,9 @@ where
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
// first try buffer protocol
if let Ok(buf) = buffer::PyBuffer::get(obj.py(), obj) {
if let Ok(buf) = buffer::PyBuffer::get(obj) {
if buf.dimensions() == 1 {
if let Ok(v) = buf.to_vec::<T>(obj.py()) {
if let Ok(v) = buf.to_vec(obj.py()) {
buf.release(obj.py());
return Ok(v);
}

View File

@ -114,8 +114,8 @@ fn test_buffer_referenced() {
}
.into_py(py);
let buf = PyBuffer::get(py, instance.as_ref(py)).unwrap();
assert_eq!(buf.to_vec::<u8>(py).unwrap(), input);
let buf = PyBuffer::<u8>::get(instance.as_ref(py)).unwrap();
assert_eq!(buf.to_vec(py).unwrap(), input);
drop(instance);
buf
};

View File

@ -1,3 +1,4 @@
use pyo3::buffer::PyBuffer;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
@ -20,3 +21,45 @@ fn test_optional_bool() {
py_assert!(py, f, "f(False) == 'Some(false)'");
py_assert!(py, f, "f(None) == 'None'");
}
#[pyfunction]
fn buffer_inplace_add(py: Python, x: PyBuffer<i32>, y: PyBuffer<i32>) {
let x = x.as_mut_slice(py).unwrap();
let y = y.as_slice(py).unwrap();
for (xi, yi) in x.iter().zip(y) {
let xi_plus_yi = xi.get() + yi.get();
xi.set(xi_plus_yi);
}
}
#[test]
fn test_buffer_add() {
// Regression test for issue #932
let gil = Python::acquire_gil();
let py = gil.python();
let f = wrap_pyfunction!(buffer_inplace_add)(py);
py_expect_exception!(
py,
f,
r#"
import array
a = array.array("i", [0, 1, 2, 3])
b = array.array("I", [0, 1, 2, 3])
f(a, b)
"#,
BufferError
);
pyo3::py_run!(
py,
f,
r#"
import array
a = array.array("i", [0, 1, 2, 3])
b = array.array("i", [2, 3, 4, 5])
f(a, b)
assert a, array.array("i", [2, 4, 6, 8])
"#
);
}