From cf965155f41fe2ca17e6430130051061d8f5ce85 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 21 Dec 2021 07:01:11 +0000 Subject: [PATCH] pymethods: support buffer protocol --- CHANGELOG.md | 1 + guide/src/class.md | 6 -- guide/src/class/protocols.md | 3 +- pyo3-macros-backend/src/method.rs | 2 + pyo3-macros-backend/src/pyclass.rs | 6 +- pyo3-macros-backend/src/pyfunction.rs | 1 + pyo3-macros-backend/src/pymethod.rs | 38 ++++++- pyo3-macros-backend/src/pyproto.rs | 27 ----- src/class/impl_.rs | 22 ---- src/pyclass.rs | 88 ++++++++------- tests/test_buffer_protocol.rs | 69 ++++++------ tests/test_buffer_protocol_pyproto.rs | 132 +++++++++++++++++++++++ tests/test_compile_error.rs | 2 + tests/ui/invalid_pymethods_buffer.rs | 18 ++++ tests/ui/invalid_pymethods_buffer.stderr | 11 ++ 15 files changed, 290 insertions(+), 136 deletions(-) create mode 100644 tests/test_buffer_protocol_pyproto.rs create mode 100644 tests/ui/invalid_pymethods_buffer.rs create mode 100644 tests/ui/invalid_pymethods_buffer.stderr diff --git a/CHANGELOG.md b/CHANGELOG.md index 693ea4f6..f60101ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - All PyO3 proc-macros except the deprecated `#[pyproto]` now accept a supplemental attribute `#[pyo3(crate = "some::path")]` that specifies where to find the `pyo3` crate, in case it has been renamed or is re-exported and not found at the crate root. [#2022](https://github.com/PyO3/pyo3/pull/2022) - Expose `pyo3-build-config` APIs for cross-compiling and Python configuration discovery for use in other projects. [#1996](https://github.com/PyO3/pyo3/pull/1996) +- Add buffer magic methods `__getbuffer__` and `__releasebuffer__` to `#[pymethods]`. [#2067](https://github.com/PyO3/pyo3/pull/2067) - Accept paths in `wrap_pyfunction` and `wrap_pymodule`. [#2081](https://github.com/PyO3/pyo3/pull/2081) ### Changed diff --git a/guide/src/class.md b/guide/src/class.md index 651032ae..51c6e01d 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -899,12 +899,6 @@ impl pyo3::class::impl_::PyClassImpl for MyClass { visitor(collector.buffer_protocol_slots()); visitor(collector.methods_protocol_slots()); } - - fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> { - use pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - collector.buffer_procs() - } } # Python::with_gil(|py| { # let cls = py.get_type::(); diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index 4bdcaa47..bab3d3d0 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -175,7 +175,8 @@ TODO; see [#1884](https://github.com/PyO3/pyo3/issues/1884) #### Buffer objects -TODO; see [#1884](https://github.com/PyO3/pyo3/issues/1884) + - `__getbuffer__(, *mut ffi::Py_buffer, flags) -> ()` + - `__releasebuffer__(, *mut ffi::Py_buffer)` (no return value, not even `PyResult`) #### Garbage Collector Integration diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index e7d0a00e..4af70180 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -229,6 +229,7 @@ pub struct FnSpec<'a> { pub convention: CallingConvention, pub text_signature: Option, pub krate: syn::Path, + pub unsafety: Option, } pub fn get_return_info(output: &syn::ReturnType) -> syn::Type { @@ -316,6 +317,7 @@ impl<'a> FnSpec<'a> { deprecations, text_signature, krate, + unsafety: sig.unsafety, }) } diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index f7367eef..3a32f4be 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -847,12 +847,8 @@ impl<'a> PyClassImplsBuilder<'a> { #methods_protos } - fn get_buffer() -> ::std::option::Option<&'static _pyo3::class::impl_::PyBufferProcs> { - use _pyo3::class::impl_::*; - let collector = PyClassImplCollector::::new(); - collector.buffer_procs() - } #dict_offset + #weaklist_offset } diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 3e6bd3f7..df4e12ae 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -431,6 +431,7 @@ pub fn impl_wrap_pyfunction( deprecations: options.deprecations, text_signature: options.text_signature, krate: krate.clone(), + unsafety: func.sig.unsafety, }; let wrapper_ident = format_ident!("__pyo3_raw_{}", spec.name); diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index c8b4d607..b8cc764d 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -103,7 +103,7 @@ pub fn gen_py_method( ensure_no_forbidden_protocol_attributes(spec, &method.method_name)?; match proto_kind { PyMethodProtoKind::Slot(slot_def) => { - let slot = slot_def.generate_type_slot(cls, spec)?; + let slot = slot_def.generate_type_slot(cls, spec, &method.method_name)?; GeneratedPyMethod::Proto(slot) } PyMethodProtoKind::Call => { @@ -556,6 +556,14 @@ const __IOR__: SlotDef = SlotDef::new("Py_nb_inplace_or", "binaryfunc") .arguments(&[Ty::Object]) .extract_error_mode(ExtractErrorMode::NotImplemented) .return_self(); +const __GETBUFFER__: SlotDef = SlotDef::new("Py_bf_getbuffer", "getbufferproc") + .arguments(&[Ty::PyBuffer, Ty::Int]) + .ret_ty(Ty::Int) + .require_unsafe(); +const __RELEASEBUFFER__: SlotDef = SlotDef::new("Py_bf_releasebuffer", "releasebufferproc") + .arguments(&[Ty::PyBuffer]) + .ret_ty(Ty::Void) + .require_unsafe(); fn pyproto(method_name: &str) -> Option<&'static SlotDef> { match method_name { @@ -594,6 +602,8 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> { "__iand__" => Some(&__IAND__), "__ixor__" => Some(&__IXOR__), "__ior__" => Some(&__IOR__), + "__getbuffer__" => Some(&__GETBUFFER__), + "__releasebuffer__" => Some(&__RELEASEBUFFER__), _ => None, } } @@ -608,6 +618,7 @@ enum Ty { PyHashT, PySsizeT, Void, + PyBuffer, } impl Ty { @@ -619,6 +630,7 @@ impl Ty { Ty::PyHashT => quote! { _pyo3::ffi::Py_hash_t }, Ty::PySsizeT => quote! { _pyo3::ffi::Py_ssize_t }, Ty::Void => quote! { () }, + Ty::PyBuffer => quote! { *mut _pyo3::ffi::Py_buffer }, } } @@ -680,7 +692,8 @@ impl Ty { let #ident = #extract; } } - Ty::Int | Ty::PyHashT | Ty::PySsizeT | Ty::Void => todo!(), + // Just pass other types through unmodified + Ty::PyBuffer | Ty::Int | Ty::PyHashT | Ty::PySsizeT | Ty::Void => quote! {}, } } } @@ -752,6 +765,7 @@ struct SlotDef { before_call_method: Option, extract_error_mode: ExtractErrorMode, return_mode: Option, + require_unsafe: bool, } const NO_ARGUMENTS: &[Ty] = &[]; @@ -766,6 +780,7 @@ impl SlotDef { before_call_method: None, extract_error_mode: ExtractErrorMode::Raise, return_mode: None, + require_unsafe: false, } } @@ -799,7 +814,17 @@ impl SlotDef { self } - fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> Result { + const fn require_unsafe(mut self) -> Self { + self.require_unsafe = true; + self + } + + fn generate_type_slot( + &self, + cls: &syn::Type, + spec: &FnSpec, + method_name: &str, + ) -> Result { let SlotDef { slot, func_ty, @@ -808,7 +833,14 @@ impl SlotDef { extract_error_mode, ret_ty, return_mode, + require_unsafe, } = self; + if *require_unsafe { + ensure_spanned!( + spec.unsafety.is_some(), + spec.name.span() => format!("`{}` must be `unsafe fn`", method_name) + ); + } let py = syn::Ident::new("_py", Span::call_site()); let method_arguments = generate_method_arguments(arguments); let ret_ty = ret_ty.ffi_type(); diff --git a/pyo3-macros-backend/src/pyproto.rs b/pyo3-macros-backend/src/pyproto.rs index 08b8527c..198fb517 100644 --- a/pyo3-macros-backend/src/pyproto.rs +++ b/pyo3-macros-backend/src/pyproto.rs @@ -134,31 +134,6 @@ fn impl_proto_methods( let slots_trait = proto.slots_trait(); let slots_trait_slots = proto.slots_trait_slots(); - let mut maybe_buffer_methods = None; - - let build_config = pyo3_build_config::get(); - const PY39: pyo3_build_config::PythonVersion = - pyo3_build_config::PythonVersion { major: 3, minor: 9 }; - - if build_config.version <= PY39 && proto.name == "Buffer" { - maybe_buffer_methods = Some(quote! { - impl _pyo3::class::impl_::PyBufferProtocolProcs<#ty> - for _pyo3::class::impl_::PyClassImplCollector<#ty> - { - fn buffer_procs( - self - ) -> ::std::option::Option<&'static _pyo3::class::impl_::PyBufferProcs> { - static PROCS: _pyo3::class::impl_::PyBufferProcs - = _pyo3::class::impl_::PyBufferProcs { - bf_getbuffer: ::std::option::Option::Some(_pyo3::class::buffer::getbuffer::<#ty>), - bf_releasebuffer: ::std::option::Option::Some(_pyo3::class::buffer::releasebuffer::<#ty>), - }; - ::std::option::Option::Some(&PROCS) - } - } - }); - } - let mut tokens = proto .slot_defs(method_names) .map(|def| { @@ -178,8 +153,6 @@ fn impl_proto_methods( } quote! { - #maybe_buffer_methods - impl _pyo3::class::impl_::#slots_trait<#ty> for _pyo3::class::impl_::PyClassImplCollector<#ty> { diff --git a/src/class/impl_.rs b/src/class/impl_.rs index 61d97018..8a88e343 100644 --- a/src/class/impl_.rs +++ b/src/class/impl_.rs @@ -85,9 +85,6 @@ pub trait PyClassImpl: Sized { fn get_free() -> Option { None } - fn get_buffer() -> Option<&'static PyBufferProcs> { - None - } #[inline] fn dict_offset() -> Option { None @@ -685,25 +682,6 @@ methods_trait!(PyDescrProtocolMethods, descr_protocol_methods); methods_trait!(PyMappingProtocolMethods, mapping_protocol_methods); methods_trait!(PyNumberProtocolMethods, number_protocol_methods); -// On Python < 3.9 setting the buffer protocol using slots doesn't work, so these procs are used -// on those versions to set the slots manually (on the limited API). - -#[cfg(not(Py_LIMITED_API))] -pub use ffi::PyBufferProcs; - -#[cfg(Py_LIMITED_API)] -pub struct PyBufferProcs; - -pub trait PyBufferProtocolProcs { - fn buffer_procs(self) -> Option<&'static PyBufferProcs>; -} - -impl PyBufferProtocolProcs for &'_ PyClassImplCollector { - fn buffer_procs(self) -> Option<&'static PyBufferProcs> { - None - } -} - // Thread checkers #[doc(hidden)] diff --git a/src/pyclass.rs b/src/pyclass.rs index be6d7549..b2e1b0fc 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -1,6 +1,6 @@ //! `PyClass` and related traits. use crate::{ - class::impl_::{fallback_new, tp_dealloc, PyBufferProcs, PyClassImpl}, + class::impl_::{fallback_new, tp_dealloc, PyClassImpl}, ffi, impl_::pyclass::{PyClassDict, PyClassWeakRef}, PyCell, PyErr, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python, @@ -55,7 +55,6 @@ where &T::for_each_proto_slot, T::IS_GC, T::IS_BASETYPE, - T::get_buffer(), ) } { Ok(type_object) => type_object, @@ -81,7 +80,6 @@ unsafe fn create_type_object_impl( for_each_proto_slot: &dyn Fn(&mut dyn FnMut(&[ffi::PyType_Slot])), is_gc: bool, is_basetype: bool, - buffer_procs: Option<&PyBufferProcs>, ) -> PyResult<*mut ffi::PyTypeObject> { let mut slots = Vec::new(); @@ -130,10 +128,26 @@ unsafe fn create_type_object_impl( // protocol methods let mut has_gc_methods = false; + // Before Python 3.9, need to patch in buffer methods manually (they don't work in slots) + #[cfg(all(not(Py_3_9), not(Py_LIMITED_API)))] + let mut buffer_procs: ffi::PyBufferProcs = Default::default(); + for_each_proto_slot(&mut |proto_slots| { - has_gc_methods |= proto_slots - .iter() - .any(|slot| slot.slot == ffi::Py_tp_clear || slot.slot == ffi::Py_tp_traverse); + for slot in proto_slots { + has_gc_methods |= slot.slot == ffi::Py_tp_clear || slot.slot == ffi::Py_tp_traverse; + + #[cfg(all(not(Py_3_9), not(Py_LIMITED_API)))] + if slot.slot == ffi::Py_bf_getbuffer { + // Safety: slot.pfunc is a valid function pointer + buffer_procs.bf_getbuffer = Some(std::mem::transmute(slot.pfunc)); + } + + #[cfg(all(not(Py_3_9), not(Py_LIMITED_API)))] + if slot.slot == ffi::Py_bf_releasebuffer { + // Safety: slot.pfunc is a valid function pointer + buffer_procs.bf_releasebuffer = Some(std::mem::transmute(slot.pfunc)); + } + } slots.extend_from_slice(proto_slots); }); @@ -153,8 +167,11 @@ unsafe fn create_type_object_impl( tp_init_additional( type_object as _, tp_doc, - buffer_procs, + #[cfg(all(not(Py_3_9), not(Py_LIMITED_API)))] + &buffer_procs, + #[cfg(not(Py_3_9))] dict_offset, + #[cfg(not(Py_3_9))] weaklist_offset, ); Ok(type_object as _) @@ -169,12 +186,12 @@ fn type_object_creation_failed(py: Python, e: PyErr, name: &'static str) -> ! { /// Additional type initializations necessary before Python 3.10 #[cfg(all(not(Py_LIMITED_API), not(Py_3_10)))] -fn tp_init_additional( +unsafe fn tp_init_additional( type_object: *mut ffi::PyTypeObject, _tp_doc: &str, - _buffer_procs: Option<&PyBufferProcs>, - _dict_offset: Option, - _weaklist_offset: Option, + #[cfg(not(Py_3_9))] buffer_procs: &ffi::PyBufferProcs, + #[cfg(not(Py_3_9))] dict_offset: Option, + #[cfg(not(Py_3_9))] weaklist_offset: Option, ) { // Just patch the type objects for the things there's no // PyType_FromSpec API for... there's no reason this should work, @@ -184,16 +201,14 @@ fn tp_init_additional( #[cfg(all(not(PyPy), not(Py_3_10)))] { if _tp_doc != "\0" { - unsafe { - // Until CPython 3.10, tp_doc was treated specially for - // heap-types, and it removed the text_signature value from it. - // We go in after the fact and replace tp_doc with something - // that _does_ include the text_signature value! - ffi::PyObject_Free((*type_object).tp_doc as _); - let data = ffi::PyObject_Malloc(_tp_doc.len()); - data.copy_from(_tp_doc.as_ptr() as _, _tp_doc.len()); - (*type_object).tp_doc = data as _; - } + // Until CPython 3.10, tp_doc was treated specially for + // heap-types, and it removed the text_signature value from it. + // We go in after the fact and replace tp_doc with something + // that _does_ include the text_signature value! + ffi::PyObject_Free((*type_object).tp_doc as _); + let data = ffi::PyObject_Malloc(_tp_doc.len()); + data.copy_from(_tp_doc.as_ptr() as _, _tp_doc.len()); + (*type_object).tp_doc = data as _; } } @@ -201,23 +216,15 @@ fn tp_init_additional( // Python 3.9, so on older versions we must manually fixup the type object. #[cfg(not(Py_3_9))] { - if let Some(buffer) = _buffer_procs { - unsafe { - (*(*type_object).tp_as_buffer).bf_getbuffer = buffer.bf_getbuffer; - (*(*type_object).tp_as_buffer).bf_releasebuffer = buffer.bf_releasebuffer; - } + (*(*type_object).tp_as_buffer).bf_getbuffer = buffer_procs.bf_getbuffer; + (*(*type_object).tp_as_buffer).bf_releasebuffer = buffer_procs.bf_releasebuffer; + + if let Some(dict_offset) = dict_offset { + (*type_object).tp_dictoffset = dict_offset; } - if let Some(dict_offset) = _dict_offset { - unsafe { - (*type_object).tp_dictoffset = dict_offset; - } - } - - if let Some(weaklist_offset) = _weaklist_offset { - unsafe { - (*type_object).tp_weaklistoffset = weaklist_offset; - } + if let Some(weaklist_offset) = weaklist_offset { + (*type_object).tp_weaklistoffset = weaklist_offset; } } } @@ -226,9 +233,9 @@ fn tp_init_additional( fn tp_init_additional( _type_object: *mut ffi::PyTypeObject, _tp_doc: &str, - _buffer_procs: Option<&PyBufferProcs>, - _dict_offset: Option, - _weaklist_offset: Option, + #[cfg(all(not(Py_3_9), not(Py_LIMITED_API)))] _buffer_procs: &ffi::PyBufferProcs, + #[cfg(not(Py_3_9))] _dict_offset: Option, + #[cfg(not(Py_3_9))] _weaklist_offset: Option, ) { } @@ -290,6 +297,7 @@ fn py_class_method_defs( }); if !defs.is_empty() { + // Safety: Python expects a zeroed entry to mark the end of the defs defs.push(unsafe { std::mem::zeroed() }); } @@ -329,6 +337,7 @@ fn py_class_members( } if !members.is_empty() { + // Safety: Python expects a zeroed entry to mark the end of the defs members.push(unsafe { std::mem::zeroed() }); } @@ -370,6 +379,7 @@ fn py_class_properties( push_dict_getset(&mut props, is_dummy); if !props.is_empty() { + // Safety: Python expects a zeroed entry to mark the end of the defs props.push(unsafe { std::mem::zeroed() }); } props diff --git a/tests/test_buffer_protocol.rs b/tests/test_buffer_protocol.rs index 81c5a7f0..bb2cb253 100644 --- a/tests/test_buffer_protocol.rs +++ b/tests/test_buffer_protocol.rs @@ -2,13 +2,12 @@ #![cfg(not(Py_LIMITED_API))] use pyo3::buffer::PyBuffer; -use pyo3::class::PyBufferProtocol; use pyo3::exceptions::PyBufferError; use pyo3::ffi; use pyo3::prelude::*; use pyo3::types::IntoPyDict; use pyo3::AsPyPointer; -use std::ffi::CStr; +use std::ffi::CString; use std::os::raw::{c_int, c_void}; use std::ptr; use std::sync::atomic::{AtomicBool, Ordering}; @@ -22,9 +21,13 @@ struct TestBufferClass { drop_called: Arc, } -#[pyproto] -impl PyBufferProtocol for TestBufferClass { - fn bf_getbuffer(slf: PyRefMut, view: *mut ffi::Py_buffer, flags: c_int) -> PyResult<()> { +#[pymethods] +impl TestBufferClass { + unsafe fn __getbuffer__( + mut slf: PyRefMut, + view: *mut ffi::Py_buffer, + flags: c_int, + ) -> PyResult<()> { if view.is_null() { return Err(PyBufferError::new_err("View is null")); } @@ -33,43 +36,43 @@ impl PyBufferProtocol for TestBufferClass { return Err(PyBufferError::new_err("Object is not writable")); } - unsafe { - (*view).obj = ffi::_Py_NewRef(slf.as_ptr()); - } + (*view).obj = ffi::_Py_NewRef(slf.as_ptr()); - let bytes = &slf.vec; + (*view).buf = slf.vec.as_mut_ptr() as *mut c_void; + (*view).len = slf.vec.len() as isize; + (*view).readonly = 1; + (*view).itemsize = 1; - unsafe { - (*view).buf = bytes.as_ptr() as *mut c_void; - (*view).len = bytes.len() as isize; - (*view).readonly = 1; - (*view).itemsize = 1; + (*view).format = if (flags & ffi::PyBUF_FORMAT) == ffi::PyBUF_FORMAT { + let msg = CString::new("B").unwrap(); + msg.into_raw() + } else { + ptr::null_mut() + }; - (*view).format = ptr::null_mut(); - if (flags & ffi::PyBUF_FORMAT) == ffi::PyBUF_FORMAT { - let msg = CStr::from_bytes_with_nul(b"B\0").unwrap(); - (*view).format = msg.as_ptr() as *mut _; - } + (*view).ndim = 1; + (*view).shape = if (flags & ffi::PyBUF_ND) == ffi::PyBUF_ND { + &mut (*view).len + } else { + ptr::null_mut() + }; - (*view).ndim = 1; - (*view).shape = ptr::null_mut(); - if (flags & ffi::PyBUF_ND) == ffi::PyBUF_ND { - (*view).shape = &mut (*view).len; - } + (*view).strides = if (flags & ffi::PyBUF_STRIDES) == ffi::PyBUF_STRIDES { + &mut (*view).itemsize + } else { + ptr::null_mut() + }; - (*view).strides = ptr::null_mut(); - if (flags & ffi::PyBUF_STRIDES) == ffi::PyBUF_STRIDES { - (*view).strides = &mut (*view).itemsize; - } - - (*view).suboffsets = ptr::null_mut(); - (*view).internal = ptr::null_mut(); - } + (*view).suboffsets = ptr::null_mut(); + (*view).internal = ptr::null_mut(); Ok(()) } - fn bf_releasebuffer(_slf: PyRefMut, _view: *mut ffi::Py_buffer) {} + unsafe fn __releasebuffer__(&self, view: *mut ffi::Py_buffer) { + // Release memory held by the format string + drop(CString::from_raw((*view).format)); + } } impl Drop for TestBufferClass { diff --git a/tests/test_buffer_protocol_pyproto.rs b/tests/test_buffer_protocol_pyproto.rs new file mode 100644 index 00000000..81c5a7f0 --- /dev/null +++ b/tests/test_buffer_protocol_pyproto.rs @@ -0,0 +1,132 @@ +#![cfg(feature = "macros")] +#![cfg(not(Py_LIMITED_API))] + +use pyo3::buffer::PyBuffer; +use pyo3::class::PyBufferProtocol; +use pyo3::exceptions::PyBufferError; +use pyo3::ffi; +use pyo3::prelude::*; +use pyo3::types::IntoPyDict; +use pyo3::AsPyPointer; +use std::ffi::CStr; +use std::os::raw::{c_int, c_void}; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +mod common; + +#[pyclass] +struct TestBufferClass { + vec: Vec, + drop_called: Arc, +} + +#[pyproto] +impl PyBufferProtocol for TestBufferClass { + fn bf_getbuffer(slf: PyRefMut, view: *mut ffi::Py_buffer, flags: c_int) -> PyResult<()> { + if view.is_null() { + return Err(PyBufferError::new_err("View is null")); + } + + if (flags & ffi::PyBUF_WRITABLE) == ffi::PyBUF_WRITABLE { + return Err(PyBufferError::new_err("Object is not writable")); + } + + unsafe { + (*view).obj = ffi::_Py_NewRef(slf.as_ptr()); + } + + let bytes = &slf.vec; + + unsafe { + (*view).buf = bytes.as_ptr() as *mut c_void; + (*view).len = bytes.len() as isize; + (*view).readonly = 1; + (*view).itemsize = 1; + + (*view).format = ptr::null_mut(); + if (flags & ffi::PyBUF_FORMAT) == ffi::PyBUF_FORMAT { + let msg = CStr::from_bytes_with_nul(b"B\0").unwrap(); + (*view).format = msg.as_ptr() as *mut _; + } + + (*view).ndim = 1; + (*view).shape = ptr::null_mut(); + if (flags & ffi::PyBUF_ND) == ffi::PyBUF_ND { + (*view).shape = &mut (*view).len; + } + + (*view).strides = ptr::null_mut(); + if (flags & ffi::PyBUF_STRIDES) == ffi::PyBUF_STRIDES { + (*view).strides = &mut (*view).itemsize; + } + + (*view).suboffsets = ptr::null_mut(); + (*view).internal = ptr::null_mut(); + } + + Ok(()) + } + + fn bf_releasebuffer(_slf: PyRefMut, _view: *mut ffi::Py_buffer) {} +} + +impl Drop for TestBufferClass { + fn drop(&mut self) { + print!("dropped"); + self.drop_called.store(true, Ordering::Relaxed); + } +} + +#[test] +fn test_buffer() { + let drop_called = Arc::new(AtomicBool::new(false)); + + { + let gil = Python::acquire_gil(); + let py = gil.python(); + let instance = Py::new( + py, + TestBufferClass { + vec: vec![b' ', b'2', b'3'], + drop_called: drop_called.clone(), + }, + ) + .unwrap(); + let env = [("ob", instance)].into_py_dict(py); + py_assert!(py, *env, "bytes(ob) == b' 23'"); + } + + assert!(drop_called.load(Ordering::Relaxed)); +} + +#[test] +fn test_buffer_referenced() { + let drop_called = Arc::new(AtomicBool::new(false)); + + let buf = { + let input = vec![b' ', b'2', b'3']; + let gil = Python::acquire_gil(); + let py = gil.python(); + let instance: PyObject = TestBufferClass { + vec: input.clone(), + drop_called: drop_called.clone(), + } + .into_py(py); + + let buf = PyBuffer::::get(instance.as_ref(py)).unwrap(); + assert_eq!(buf.to_vec(py).unwrap(), input); + drop(instance); + buf + }; + + assert!(!drop_called.load(Ordering::Relaxed)); + + { + let _py = Python::acquire_gil().python(); + drop(buf); + } + + assert!(drop_called.load(Ordering::Relaxed)); +} diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index daade52d..fb3ad668 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -25,6 +25,8 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/invalid_pyclass_item.rs"); t.compile_fail("tests/ui/invalid_pyfunctions.rs"); t.compile_fail("tests/ui/invalid_pymethods.rs"); + #[cfg(not(Py_LIMITED_API))] + t.compile_fail("tests/ui/invalid_pymethods_buffer.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); t.compile_fail("tests/ui/invalid_pymodule_args.rs"); t.compile_fail("tests/ui/missing_clone.rs"); diff --git a/tests/ui/invalid_pymethods_buffer.rs b/tests/ui/invalid_pymethods_buffer.rs new file mode 100644 index 00000000..4eb7c28e --- /dev/null +++ b/tests/ui/invalid_pymethods_buffer.rs @@ -0,0 +1,18 @@ +use pyo3::prelude::*; + +#[pyclass] +struct MyClass {} + +#[pymethods] +impl MyClass { + #[pyo3(name = "__getbuffer__")] + fn getbuffer_must_be_unsafe(&self, _view: *mut pyo3::ffi::Py_buffer, _flags: std::os::raw::c_int) {} +} + +#[pymethods] +impl MyClass { + #[pyo3(name = "__releasebuffer__")] + fn releasebuffer_must_be_unsafe(&self, _view: *mut pyo3::ffi::Py_buffer) {} +} + +fn main() {} diff --git a/tests/ui/invalid_pymethods_buffer.stderr b/tests/ui/invalid_pymethods_buffer.stderr new file mode 100644 index 00000000..3480848f --- /dev/null +++ b/tests/ui/invalid_pymethods_buffer.stderr @@ -0,0 +1,11 @@ +error: `__getbuffer__` must be `unsafe fn` + --> tests/ui/invalid_pymethods_buffer.rs:9:8 + | +9 | fn getbuffer_must_be_unsafe(&self, _view: *mut pyo3::ffi::Py_buffer, _flags: std::os::raw::c_int) {} + | ^^^^^^^^^^^^^^^^^^^^^^^^ + +error: `__releasebuffer__` must be `unsafe fn` + --> tests/ui/invalid_pymethods_buffer.rs:15:8 + | +15 | fn releasebuffer_must_be_unsafe(&self, _view: *mut pyo3::ffi::Py_buffer) {} + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^