convert `PyBuffer` to `Bound` API (#3836)

This commit is contained in:
Icxolu 2024-02-14 23:10:59 +01:00 committed by GitHub
parent 9902633116
commit f3ddd023c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 28 additions and 16 deletions

View File

@ -35,8 +35,8 @@ impl BytesExtractor {
}
#[staticmethod]
pub fn from_buffer(buf: &PyAny) -> PyResult<usize> {
let buf = PyBuffer::<u8>::get(buf)?;
pub fn from_buffer(buf: &Bound<'_, PyAny>) -> PyResult<usize> {
let buf = PyBuffer::<u8>::get_bound(buf)?;
Ok(buf.item_count())
}
}

View File

@ -18,8 +18,8 @@
// DEALINGS IN THE SOFTWARE.
//! `PyBuffer` implementation
use crate::instance::Bound;
use crate::{err, exceptions::PyBufferError, ffi, FromPyObject, PyAny, PyResult, Python};
use crate::{Bound, PyNativeType};
use std::marker::PhantomData;
use std::os::raw;
use std::pin::Pin;
@ -184,13 +184,25 @@ pub unsafe trait Element: Copy {
impl<'py, T: Element> FromPyObject<'py> for PyBuffer<T> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<PyBuffer<T>> {
Self::get(obj.as_gil_ref())
Self::get_bound(obj)
}
}
impl<T: Element> PyBuffer<T> {
/// Gets the underlying buffer from the specified python object.
/// Deprecated form of [`PyBuffer::get_bound`]
#[cfg_attr(
not(feature = "gil-refs"),
deprecated(
since = "0.21.0",
note = "`PyBuffer::get` will be replaced by `PyBuffer::get_bound` in a future PyO3 version"
)
)]
pub fn get(obj: &PyAny) -> PyResult<PyBuffer<T>> {
Self::get_bound(&obj.as_borrowed())
}
/// Gets the underlying buffer from the specified python object.
pub fn get_bound(obj: &Bound<'_, PyAny>) -> PyResult<PyBuffer<T>> {
// TODO: use nightly API Box::new_uninit() once stable
let mut buf = Box::new(mem::MaybeUninit::uninit());
let buf: Box<ffi::Py_buffer> = {
@ -696,7 +708,7 @@ mod tests {
fn test_debug() {
Python::with_gil(|py| {
let bytes = py.eval_bound("b'abcde'", None, None).unwrap();
let buffer: PyBuffer<u8> = PyBuffer::get(bytes.as_gil_ref()).unwrap();
let buffer: PyBuffer<u8> = PyBuffer::get_bound(&bytes).unwrap();
let expected = format!(
concat!(
"PyBuffer {{ buf: {:?}, obj: {:?}, ",
@ -859,7 +871,7 @@ mod tests {
fn test_bytes_buffer() {
Python::with_gil(|py| {
let bytes = py.eval_bound("b'abcde'", None, None).unwrap();
let buffer = PyBuffer::get(bytes.as_gil_ref()).unwrap();
let buffer = PyBuffer::get_bound(&bytes).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 5);
assert_eq!(buffer.format().to_str().unwrap(), "B");
@ -895,7 +907,7 @@ mod tests {
.unwrap()
.call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None)
.unwrap();
let buffer = PyBuffer::get(array.as_gil_ref()).unwrap();
let buffer = PyBuffer::get_bound(&array).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 4);
assert_eq!(buffer.format().to_str().unwrap(), "f");
@ -925,7 +937,7 @@ mod tests {
assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);
// F-contiguous fns
let buffer = PyBuffer::get(array.as_gil_ref()).unwrap();
let buffer = PyBuffer::get_bound(&array).unwrap();
let slice = buffer.as_fortran_slice(py).unwrap();
assert_eq!(slice.len(), 4);
assert_eq!(slice[1].get(), 11.0);

View File

@ -96,11 +96,11 @@ fn test_get_buffer_errors() {
)
.unwrap();
assert!(PyBuffer::<u32>::get(instance.as_ref(py)).is_ok());
assert!(PyBuffer::<u32>::get_bound(instance.bind(py).as_any()).is_ok());
instance.borrow_mut(py).error = Some(TestGetBufferError::NullShape);
assert_eq!(
PyBuffer::<u32>::get(instance.as_ref(py))
PyBuffer::<u32>::get_bound(instance.bind(py).as_any())
.unwrap_err()
.to_string(),
"BufferError: shape is null"
@ -108,7 +108,7 @@ fn test_get_buffer_errors() {
instance.borrow_mut(py).error = Some(TestGetBufferError::NullStrides);
assert_eq!(
PyBuffer::<u32>::get(instance.as_ref(py))
PyBuffer::<u32>::get_bound(instance.bind(py).as_any())
.unwrap_err()
.to_string(),
"BufferError: strides is null"
@ -116,7 +116,7 @@ fn test_get_buffer_errors() {
instance.borrow_mut(py).error = Some(TestGetBufferError::IncorrectItemSize);
assert_eq!(
PyBuffer::<u32>::get(instance.as_ref(py))
PyBuffer::<u32>::get_bound(instance.bind(py).as_any())
.unwrap_err()
.to_string(),
"BufferError: buffer contents are not compatible with u32"
@ -124,7 +124,7 @@ fn test_get_buffer_errors() {
instance.borrow_mut(py).error = Some(TestGetBufferError::IncorrectFormat);
assert_eq!(
PyBuffer::<u32>::get(instance.as_ref(py))
PyBuffer::<u32>::get_bound(instance.bind(py).as_any())
.unwrap_err()
.to_string(),
"BufferError: buffer contents are not compatible with u32"
@ -132,7 +132,7 @@ fn test_get_buffer_errors() {
instance.borrow_mut(py).error = Some(TestGetBufferError::IncorrectAlignment);
assert_eq!(
PyBuffer::<u32>::get(instance.as_ref(py))
PyBuffer::<u32>::get_bound(instance.bind(py).as_any())
.unwrap_err()
.to_string(),
"BufferError: buffer contents are insufficiently aligned for u32"

View File

@ -77,7 +77,7 @@ fn test_buffer_referenced() {
}
.into_py(py);
let buf = PyBuffer::<u8>::get(instance.as_ref(py)).unwrap();
let buf = PyBuffer::<u8>::get_bound(instance.bind(py)).unwrap();
assert_eq!(buf.to_vec(py).unwrap(), input);
drop(instance);
buf