pymethods: support gc protocol

This commit is contained in:
David Hewitt 2022-02-10 21:28:42 +00:00
parent 7851e865ae
commit 676295b8de
7 changed files with 157 additions and 121 deletions

View File

@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091) - `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091)
- Add `PyAny::contains` method (`in` operator for `PyAny`). [#2115](https://github.com/PyO3/pyo3/pull/2115) - Add `PyAny::contains` method (`in` operator for `PyAny`). [#2115](https://github.com/PyO3/pyo3/pull/2115)
- Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133) - Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133)
- Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159)
### Changed ### Changed

View File

@ -175,7 +175,8 @@ given signatures should be interpreted as follows:
#### Garbage Collector Integration #### Garbage Collector Integration
TODO; see [#1884](https://github.com/PyO3/pyo3/issues/1884) - `__traverse__(<self>, visit: pyo3::class::gc::PyVisit) -> Result<(), pyo3::class::gc::PyTraverseError>`
- `__clear__(<self>) -> ()`
### `#[pyproto]` traits ### `#[pyproto]` traits

View File

@ -757,7 +757,6 @@ impl<'a> PyClassImplsBuilder<'a> {
self.impl_into_py(), self.impl_into_py(),
self.impl_pyclassimpl(), self.impl_pyclassimpl(),
self.impl_freelist(), self.impl_freelist(),
self.impl_gc(),
] ]
.into_iter() .into_iter()
.collect() .collect()
@ -981,26 +980,6 @@ impl<'a> PyClassImplsBuilder<'a> {
Vec::new() Vec::new()
} }
} }
/// Enforce at compile time that PyGCProtocol is implemented
fn impl_gc(&self) -> TokenStream {
let cls = self.cls;
let attr = self.attr;
if attr.is_gc {
let closure_name = format!("__assertion_closure_{}", cls);
let closure_token = syn::Ident::new(&closure_name, Span::call_site());
quote! {
fn #closure_token() {
use _pyo3::class;
fn _assert_implements_protocol<'p, T: _pyo3::class::PyGCProtocol<'p>>() {}
_assert_implements_protocol::<#cls>();
}
}
} else {
quote! {}
}
}
} }
fn define_inventory_class(inventory_class_name: &syn::Ident) -> TokenStream { fn define_inventory_class(inventory_class_name: &syn::Ident) -> TokenStream {

View File

@ -36,14 +36,86 @@ enum PyMethodKind {
impl PyMethodKind { impl PyMethodKind {
fn from_name(name: &str) -> Self { fn from_name(name: &str) -> Self {
if let Some(slot_def) = pyproto(name) { match name {
PyMethodKind::Proto(PyMethodProtoKind::Slot(slot_def)) // Protocol implemented through slots
} else if name == "__call__" { "__getattr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETATTR__)),
PyMethodKind::Proto(PyMethodProtoKind::Call) "__str__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__STR__)),
} else if let Some(slot_fragment_def) = pyproto_fragment(name) { "__repr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPR__)),
PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(slot_fragment_def)) "__hash__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__HASH__)),
} else { "__richcmp__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RICHCMP__)),
PyMethodKind::Fn "__get__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GET__)),
"__iter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITER__)),
"__next__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEXT__)),
"__await__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AWAIT__)),
"__aiter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AITER__)),
"__anext__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ANEXT__)),
"__len__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__LEN__)),
"__contains__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONTAINS__)),
"__getitem__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETITEM__)),
"__pos__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__POS__)),
"__neg__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEG__)),
"__abs__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ABS__)),
"__invert__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INVERT__)),
"__index__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INDEX__)),
"__int__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INT__)),
"__float__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__FLOAT__)),
"__bool__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__BOOL__)),
"__iadd__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IADD__)),
"__isub__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ISUB__)),
"__imul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMUL__)),
"__imatmul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMATMUL__)),
"__itruediv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITRUEDIV__)),
"__ifloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IFLOORDIV__)),
"__imod__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMOD__)),
"__ipow__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IPOW__)),
"__ilshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ILSHIFT__)),
"__irshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IRSHIFT__)),
"__iand__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IAND__)),
"__ixor__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IXOR__)),
"__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)),
"__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)),
"__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)),
// Protocols implemented through traits
"__setattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETATTR__)),
"__delattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELATTR__)),
"__set__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SET__)),
"__delete__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELETE__)),
"__setitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETITEM__)),
"__delitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELITEM__)),
"__add__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ADD__)),
"__radd__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RADD__)),
"__sub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SUB__)),
"__rsub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSUB__)),
"__mul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MUL__)),
"__rmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMUL__)),
"__matmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MATMUL__)),
"__rmatmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMATMUL__)),
"__floordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__FLOORDIV__)),
"__rfloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RFLOORDIV__)),
"__truediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__TRUEDIV__)),
"__rtruediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RTRUEDIV__)),
"__divmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DIVMOD__)),
"__rdivmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RDIVMOD__)),
"__mod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MOD__)),
"__rmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMOD__)),
"__lshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LSHIFT__)),
"__rlshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RLSHIFT__)),
"__rshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSHIFT__)),
"__rrshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RRSHIFT__)),
"__and__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__AND__)),
"__rand__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RAND__)),
"__xor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__XOR__)),
"__rxor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RXOR__)),
"__or__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__OR__)),
"__ror__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ROR__)),
"__pow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__POW__)),
"__rpow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RPOW__)),
// Some tricky protocols which don't fit the pattern of the rest
"__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
"__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
// Not a proto
_ => PyMethodKind::Fn,
} }
} }
} }
@ -51,6 +123,7 @@ impl PyMethodKind {
enum PyMethodProtoKind { enum PyMethodProtoKind {
Slot(&'static SlotDef), Slot(&'static SlotDef),
Call, Call,
Traverse,
SlotFragment(&'static SlotFragmentDef), SlotFragment(&'static SlotFragmentDef),
} }
@ -108,6 +181,9 @@ pub fn gen_py_method(
PyMethodProtoKind::Call => { PyMethodProtoKind::Call => {
GeneratedPyMethod::Proto(impl_call_slot(cls, method.spec)?) GeneratedPyMethod::Proto(impl_call_slot(cls, method.spec)?)
} }
PyMethodProtoKind::Traverse => {
GeneratedPyMethod::Proto(impl_traverse_slot(cls, method.spec)?)
}
PyMethodProtoKind::SlotFragment(slot_fragment_def) => { PyMethodProtoKind::SlotFragment(slot_fragment_def) => {
let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec)?; let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec)?;
GeneratedPyMethod::SlotTraitImpl(method.method_name, proto) GeneratedPyMethod::SlotTraitImpl(method.method_name, proto)
@ -220,6 +296,36 @@ fn impl_call_slot(cls: &syn::Type, mut spec: FnSpec) -> Result<TokenStream> {
}}) }})
} }
fn impl_traverse_slot(cls: &syn::Type, spec: FnSpec) -> Result<TokenStream> {
let ident = spec.name;
Ok(quote! {{
pub unsafe extern "C" fn __wrap_(
slf: *mut _pyo3::ffi::PyObject,
visit: _pyo3::ffi::visitproc,
arg: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int
{
let pool = _pyo3::GILPool::new();
let py = pool.python();
_pyo3::callback::abort_on_traverse_panic(::std::panic::catch_unwind(move || {
let slf = py.from_borrowed_ptr::<_pyo3::PyCell<#cls>>(slf);
let visit = _pyo3::class::gc::PyVisit::from_raw(visit, arg, py);
let borrow = slf.try_borrow();
if let ::std::result::Result::Ok(borrow) = borrow {
_pyo3::class::gc::unwrap_traverse_result(borrow.#ident(visit))
} else {
0
}
}))
}
_pyo3::ffi::PyType_Slot {
slot: _pyo3::ffi::Py_tp_traverse,
pfunc: __wrap_ as _pyo3::ffi::traverseproc as _
}
}})
}
fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec) -> TokenStream { fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec) -> TokenStream {
let name = &spec.name; let name = &spec.name;
let deprecations = &spec.deprecations; let deprecations = &spec.deprecations;
@ -567,49 +673,9 @@ const __RELEASEBUFFER__: SlotDef = SlotDef::new("Py_bf_releasebuffer", "releaseb
.arguments(&[Ty::PyBuffer]) .arguments(&[Ty::PyBuffer])
.ret_ty(Ty::Void) .ret_ty(Ty::Void)
.require_unsafe(); .require_unsafe();
const __CLEAR__: SlotDef = SlotDef::new("Py_tp_clear", "inquiry")
fn pyproto(method_name: &str) -> Option<&'static SlotDef> { .arguments(&[])
match method_name { .ret_ty(Ty::Int);
"__getattr__" => Some(&__GETATTR__),
"__str__" => Some(&__STR__),
"__repr__" => Some(&__REPR__),
"__hash__" => Some(&__HASH__),
"__richcmp__" => Some(&__RICHCMP__),
"__get__" => Some(&__GET__),
"__iter__" => Some(&__ITER__),
"__next__" => Some(&__NEXT__),
"__await__" => Some(&__AWAIT__),
"__aiter__" => Some(&__AITER__),
"__anext__" => Some(&__ANEXT__),
"__len__" => Some(&__LEN__),
"__contains__" => Some(&__CONTAINS__),
"__getitem__" => Some(&__GETITEM__),
"__pos__" => Some(&__POS__),
"__neg__" => Some(&__NEG__),
"__abs__" => Some(&__ABS__),
"__invert__" => Some(&__INVERT__),
"__index__" => Some(&__INDEX__),
"__int__" => Some(&__INT__),
"__float__" => Some(&__FLOAT__),
"__bool__" => Some(&__BOOL__),
"__iadd__" => Some(&__IADD__),
"__isub__" => Some(&__ISUB__),
"__imul__" => Some(&__IMUL__),
"__imatmul__" => Some(&__IMATMUL__),
"__itruediv__" => Some(&__ITRUEDIV__),
"__ifloordiv__" => Some(&__IFLOORDIV__),
"__imod__" => Some(&__IMOD__),
"__ipow__" => Some(&__IPOW__),
"__ilshift__" => Some(&__ILSHIFT__),
"__irshift__" => Some(&__IRSHIFT__),
"__iand__" => Some(&__IAND__),
"__ixor__" => Some(&__IXOR__),
"__ior__" => Some(&__IOR__),
"__getbuffer__" => Some(&__GETBUFFER__),
"__releasebuffer__" => Some(&__RELEASEBUFFER__),
_ => None,
}
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
enum Ty { enum Ty {
@ -1045,46 +1111,6 @@ const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object,
.extract_error_mode(ExtractErrorMode::NotImplemented) .extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object); .ret_ty(Ty::Object);
fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> {
match method_name {
"__setattr__" => Some(&__SETATTR__),
"__delattr__" => Some(&__DELATTR__),
"__set__" => Some(&__SET__),
"__delete__" => Some(&__DELETE__),
"__setitem__" => Some(&__SETITEM__),
"__delitem__" => Some(&__DELITEM__),
"__add__" => Some(&__ADD__),
"__radd__" => Some(&__RADD__),
"__sub__" => Some(&__SUB__),
"__rsub__" => Some(&__RSUB__),
"__mul__" => Some(&__MUL__),
"__rmul__" => Some(&__RMUL__),
"__matmul__" => Some(&__MATMUL__),
"__rmatmul__" => Some(&__RMATMUL__),
"__floordiv__" => Some(&__FLOORDIV__),
"__rfloordiv__" => Some(&__RFLOORDIV__),
"__truediv__" => Some(&__TRUEDIV__),
"__rtruediv__" => Some(&__RTRUEDIV__),
"__divmod__" => Some(&__DIVMOD__),
"__rdivmod__" => Some(&__RDIVMOD__),
"__mod__" => Some(&__MOD__),
"__rmod__" => Some(&__RMOD__),
"__lshift__" => Some(&__LSHIFT__),
"__rlshift__" => Some(&__RLSHIFT__),
"__rshift__" => Some(&__RSHIFT__),
"__rrshift__" => Some(&__RRSHIFT__),
"__and__" => Some(&__AND__),
"__rand__" => Some(&__RAND__),
"__xor__" => Some(&__XOR__),
"__rxor__" => Some(&__RXOR__),
"__or__" => Some(&__OR__),
"__ror__" => Some(&__ROR__),
"__pow__" => Some(&__POW__),
"__rpow__" => Some(&__RPOW__),
_ => None,
}
}
fn extract_proto_arguments( fn extract_proto_arguments(
cls: &syn::Type, cls: &syn::Type,
py: &syn::Ident, py: &syn::Ident,

View File

@ -266,3 +266,18 @@ where
R::ERR_VALUE R::ERR_VALUE
}) })
} }
/// Aborts if panic has occurred. Used inside `__traverse__` implementations, where panicking is not possible.
#[doc(hidden)]
#[inline]
pub fn abort_on_traverse_panic(
panic_result: Result<c_int, Box<dyn Any + Send + 'static>>,
) -> c_int {
match panic_result {
Ok(traverse_result) => traverse_result,
Err(_payload) => {
eprintln!("FATAL: panic inside __traverse__ handler; aborting.");
::std::process::abort()
}
}
}

View File

@ -83,4 +83,20 @@ impl<'p> PyVisit<'p> {
Err(PyTraverseError(r)) Err(PyTraverseError(r))
} }
} }
/// Creates the PyVisit from the arguments to tp_traverse
#[doc(hidden)]
pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, _py: Python<'p>) -> Self {
Self { visit, arg, _py }
}
}
/// Unwraps the result of __traverse__ for tp_traverse
#[doc(hidden)]
#[inline]
pub fn unwrap_traverse_result(result: Result<(), PyTraverseError>) -> c_int {
match result {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
}
} }

View File

@ -1,7 +1,5 @@
#![cfg(feature = "macros")] #![cfg(feature = "macros")]
#![cfg(feature = "pyproto")] // FIXME: #[pymethods] to support gc protocol
use pyo3::class::PyGCProtocol;
use pyo3::class::PyTraverseError; use pyo3::class::PyTraverseError;
use pyo3::class::PyVisit; use pyo3::class::PyVisit;
use pyo3::prelude::*; use pyo3::prelude::*;
@ -90,8 +88,8 @@ struct GcIntegration {
dropped: TestDropCall, dropped: TestDropCall,
} }
#[pyproto] #[pymethods]
impl PyGCProtocol for GcIntegration { impl GcIntegration {
fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> { fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
visit.call(&self.self_ref) visit.call(&self.self_ref)
} }
@ -133,8 +131,8 @@ fn gc_integration() {
#[pyclass(gc)] #[pyclass(gc)]
struct GcIntegration2 {} struct GcIntegration2 {}
#[pyproto] #[pymethods]
impl PyGCProtocol for GcIntegration2 { impl GcIntegration2 {
fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> { fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> {
Ok(()) Ok(())
} }
@ -230,8 +228,8 @@ impl TraversableClass {
} }
} }
#[pyproto] #[pymethods]
impl PyGCProtocol for TraversableClass { impl TraversableClass {
fn __clear__(&mut self) {} fn __clear__(&mut self) {}
fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> { fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> {
self.traversed.store(true, Ordering::Relaxed); self.traversed.store(true, Ordering::Relaxed);