diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e466334..507d3d4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091) - Add `PyAny::contains` method (`in` operator for `PyAny`). [#2115](https://github.com/PyO3/pyo3/pull/2115) - Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133) +- Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159) ### Changed diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index d0d570de..61a452ce 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -175,7 +175,8 @@ given signatures should be interpreted as follows: #### Garbage Collector Integration -TODO; see [#1884](https://github.com/PyO3/pyo3/issues/1884) + - `__traverse__(, visit: pyo3::class::gc::PyVisit) -> Result<(), pyo3::class::gc::PyTraverseError>` + - `__clear__() -> ()` ### `#[pyproto]` traits diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 74f8b945..9d909b3d 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -757,7 +757,6 @@ impl<'a> PyClassImplsBuilder<'a> { self.impl_into_py(), self.impl_pyclassimpl(), self.impl_freelist(), - self.impl_gc(), ] .into_iter() .collect() @@ -981,26 +980,6 @@ impl<'a> PyClassImplsBuilder<'a> { Vec::new() } } - - /// Enforce at compile time that PyGCProtocol is implemented - fn impl_gc(&self) -> TokenStream { - let cls = self.cls; - let attr = self.attr; - if attr.is_gc { - let closure_name = format!("__assertion_closure_{}", cls); - let closure_token = syn::Ident::new(&closure_name, Span::call_site()); - quote! { - fn #closure_token() { - use _pyo3::class; - - fn _assert_implements_protocol<'p, T: _pyo3::class::PyGCProtocol<'p>>() {} - _assert_implements_protocol::<#cls>(); - } - } - } else { - quote! {} - } - } } fn define_inventory_class(inventory_class_name: &syn::Ident) -> TokenStream { diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 4ec87956..297f206b 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -36,14 +36,86 @@ enum PyMethodKind { impl PyMethodKind { fn from_name(name: &str) -> Self { - if let Some(slot_def) = pyproto(name) { - PyMethodKind::Proto(PyMethodProtoKind::Slot(slot_def)) - } else if name == "__call__" { - PyMethodKind::Proto(PyMethodProtoKind::Call) - } else if let Some(slot_fragment_def) = pyproto_fragment(name) { - PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(slot_fragment_def)) - } else { - PyMethodKind::Fn + match name { + // Protocol implemented through slots + "__getattr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETATTR__)), + "__str__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__STR__)), + "__repr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPR__)), + "__hash__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__HASH__)), + "__richcmp__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RICHCMP__)), + "__get__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GET__)), + "__iter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITER__)), + "__next__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEXT__)), + "__await__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AWAIT__)), + "__aiter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AITER__)), + "__anext__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ANEXT__)), + "__len__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__LEN__)), + "__contains__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONTAINS__)), + "__getitem__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETITEM__)), + "__pos__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__POS__)), + "__neg__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEG__)), + "__abs__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ABS__)), + "__invert__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INVERT__)), + "__index__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INDEX__)), + "__int__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INT__)), + "__float__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__FLOAT__)), + "__bool__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__BOOL__)), + "__iadd__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IADD__)), + "__isub__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ISUB__)), + "__imul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMUL__)), + "__imatmul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMATMUL__)), + "__itruediv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITRUEDIV__)), + "__ifloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IFLOORDIV__)), + "__imod__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMOD__)), + "__ipow__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IPOW__)), + "__ilshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ILSHIFT__)), + "__irshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IRSHIFT__)), + "__iand__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IAND__)), + "__ixor__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IXOR__)), + "__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)), + "__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)), + "__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)), + "__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)), + // Protocols implemented through traits + "__setattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETATTR__)), + "__delattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELATTR__)), + "__set__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SET__)), + "__delete__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELETE__)), + "__setitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETITEM__)), + "__delitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELITEM__)), + "__add__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ADD__)), + "__radd__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RADD__)), + "__sub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SUB__)), + "__rsub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSUB__)), + "__mul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MUL__)), + "__rmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMUL__)), + "__matmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MATMUL__)), + "__rmatmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMATMUL__)), + "__floordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__FLOORDIV__)), + "__rfloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RFLOORDIV__)), + "__truediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__TRUEDIV__)), + "__rtruediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RTRUEDIV__)), + "__divmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DIVMOD__)), + "__rdivmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RDIVMOD__)), + "__mod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MOD__)), + "__rmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMOD__)), + "__lshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LSHIFT__)), + "__rlshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RLSHIFT__)), + "__rshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSHIFT__)), + "__rrshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RRSHIFT__)), + "__and__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__AND__)), + "__rand__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RAND__)), + "__xor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__XOR__)), + "__rxor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RXOR__)), + "__or__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__OR__)), + "__ror__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ROR__)), + "__pow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__POW__)), + "__rpow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RPOW__)), + // Some tricky protocols which don't fit the pattern of the rest + "__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call), + "__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse), + // Not a proto + _ => PyMethodKind::Fn, } } } @@ -51,6 +123,7 @@ impl PyMethodKind { enum PyMethodProtoKind { Slot(&'static SlotDef), Call, + Traverse, SlotFragment(&'static SlotFragmentDef), } @@ -108,6 +181,9 @@ pub fn gen_py_method( PyMethodProtoKind::Call => { GeneratedPyMethod::Proto(impl_call_slot(cls, method.spec)?) } + PyMethodProtoKind::Traverse => { + GeneratedPyMethod::Proto(impl_traverse_slot(cls, method.spec)?) + } PyMethodProtoKind::SlotFragment(slot_fragment_def) => { let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec)?; GeneratedPyMethod::SlotTraitImpl(method.method_name, proto) @@ -220,6 +296,36 @@ fn impl_call_slot(cls: &syn::Type, mut spec: FnSpec) -> Result { }}) } +fn impl_traverse_slot(cls: &syn::Type, spec: FnSpec) -> Result { + let ident = spec.name; + Ok(quote! {{ + pub unsafe extern "C" fn __wrap_( + slf: *mut _pyo3::ffi::PyObject, + visit: _pyo3::ffi::visitproc, + arg: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int + { + let pool = _pyo3::GILPool::new(); + let py = pool.python(); + _pyo3::callback::abort_on_traverse_panic(::std::panic::catch_unwind(move || { + let slf = py.from_borrowed_ptr::<_pyo3::PyCell<#cls>>(slf); + + let visit = _pyo3::class::gc::PyVisit::from_raw(visit, arg, py); + let borrow = slf.try_borrow(); + if let ::std::result::Result::Ok(borrow) = borrow { + _pyo3::class::gc::unwrap_traverse_result(borrow.#ident(visit)) + } else { + 0 + } + })) + } + _pyo3::ffi::PyType_Slot { + slot: _pyo3::ffi::Py_tp_traverse, + pfunc: __wrap_ as _pyo3::ffi::traverseproc as _ + } + }}) +} + fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec) -> TokenStream { let name = &spec.name; let deprecations = &spec.deprecations; @@ -567,49 +673,9 @@ const __RELEASEBUFFER__: SlotDef = SlotDef::new("Py_bf_releasebuffer", "releaseb .arguments(&[Ty::PyBuffer]) .ret_ty(Ty::Void) .require_unsafe(); - -fn pyproto(method_name: &str) -> Option<&'static SlotDef> { - match method_name { - "__getattr__" => Some(&__GETATTR__), - "__str__" => Some(&__STR__), - "__repr__" => Some(&__REPR__), - "__hash__" => Some(&__HASH__), - "__richcmp__" => Some(&__RICHCMP__), - "__get__" => Some(&__GET__), - "__iter__" => Some(&__ITER__), - "__next__" => Some(&__NEXT__), - "__await__" => Some(&__AWAIT__), - "__aiter__" => Some(&__AITER__), - "__anext__" => Some(&__ANEXT__), - "__len__" => Some(&__LEN__), - "__contains__" => Some(&__CONTAINS__), - "__getitem__" => Some(&__GETITEM__), - "__pos__" => Some(&__POS__), - "__neg__" => Some(&__NEG__), - "__abs__" => Some(&__ABS__), - "__invert__" => Some(&__INVERT__), - "__index__" => Some(&__INDEX__), - "__int__" => Some(&__INT__), - "__float__" => Some(&__FLOAT__), - "__bool__" => Some(&__BOOL__), - "__iadd__" => Some(&__IADD__), - "__isub__" => Some(&__ISUB__), - "__imul__" => Some(&__IMUL__), - "__imatmul__" => Some(&__IMATMUL__), - "__itruediv__" => Some(&__ITRUEDIV__), - "__ifloordiv__" => Some(&__IFLOORDIV__), - "__imod__" => Some(&__IMOD__), - "__ipow__" => Some(&__IPOW__), - "__ilshift__" => Some(&__ILSHIFT__), - "__irshift__" => Some(&__IRSHIFT__), - "__iand__" => Some(&__IAND__), - "__ixor__" => Some(&__IXOR__), - "__ior__" => Some(&__IOR__), - "__getbuffer__" => Some(&__GETBUFFER__), - "__releasebuffer__" => Some(&__RELEASEBUFFER__), - _ => None, - } -} +const __CLEAR__: SlotDef = SlotDef::new("Py_tp_clear", "inquiry") + .arguments(&[]) + .ret_ty(Ty::Int); #[derive(Clone, Copy)] enum Ty { @@ -1045,46 +1111,6 @@ const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object, .extract_error_mode(ExtractErrorMode::NotImplemented) .ret_ty(Ty::Object); -fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> { - match method_name { - "__setattr__" => Some(&__SETATTR__), - "__delattr__" => Some(&__DELATTR__), - "__set__" => Some(&__SET__), - "__delete__" => Some(&__DELETE__), - "__setitem__" => Some(&__SETITEM__), - "__delitem__" => Some(&__DELITEM__), - "__add__" => Some(&__ADD__), - "__radd__" => Some(&__RADD__), - "__sub__" => Some(&__SUB__), - "__rsub__" => Some(&__RSUB__), - "__mul__" => Some(&__MUL__), - "__rmul__" => Some(&__RMUL__), - "__matmul__" => Some(&__MATMUL__), - "__rmatmul__" => Some(&__RMATMUL__), - "__floordiv__" => Some(&__FLOORDIV__), - "__rfloordiv__" => Some(&__RFLOORDIV__), - "__truediv__" => Some(&__TRUEDIV__), - "__rtruediv__" => Some(&__RTRUEDIV__), - "__divmod__" => Some(&__DIVMOD__), - "__rdivmod__" => Some(&__RDIVMOD__), - "__mod__" => Some(&__MOD__), - "__rmod__" => Some(&__RMOD__), - "__lshift__" => Some(&__LSHIFT__), - "__rlshift__" => Some(&__RLSHIFT__), - "__rshift__" => Some(&__RSHIFT__), - "__rrshift__" => Some(&__RRSHIFT__), - "__and__" => Some(&__AND__), - "__rand__" => Some(&__RAND__), - "__xor__" => Some(&__XOR__), - "__rxor__" => Some(&__RXOR__), - "__or__" => Some(&__OR__), - "__ror__" => Some(&__ROR__), - "__pow__" => Some(&__POW__), - "__rpow__" => Some(&__RPOW__), - _ => None, - } -} - fn extract_proto_arguments( cls: &syn::Type, py: &syn::Ident, diff --git a/src/callback.rs b/src/callback.rs index 4b54de81..e32def03 100644 --- a/src/callback.rs +++ b/src/callback.rs @@ -266,3 +266,18 @@ where R::ERR_VALUE }) } + +/// Aborts if panic has occurred. Used inside `__traverse__` implementations, where panicking is not possible. +#[doc(hidden)] +#[inline] +pub fn abort_on_traverse_panic( + panic_result: Result>, +) -> c_int { + match panic_result { + Ok(traverse_result) => traverse_result, + Err(_payload) => { + eprintln!("FATAL: panic inside __traverse__ handler; aborting."); + ::std::process::abort() + } + } +} diff --git a/src/class/gc.rs b/src/class/gc.rs index 2641a2b3..81e761fe 100644 --- a/src/class/gc.rs +++ b/src/class/gc.rs @@ -83,4 +83,20 @@ impl<'p> PyVisit<'p> { Err(PyTraverseError(r)) } } + + /// Creates the PyVisit from the arguments to tp_traverse + #[doc(hidden)] + pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, _py: Python<'p>) -> Self { + Self { visit, arg, _py } + } +} + +/// Unwraps the result of __traverse__ for tp_traverse +#[doc(hidden)] +#[inline] +pub fn unwrap_traverse_result(result: Result<(), PyTraverseError>) -> c_int { + match result { + Ok(()) => 0, + Err(PyTraverseError(value)) => value, + } } diff --git a/tests/test_gc.rs b/tests/test_gc.rs index 92854473..d137d04b 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -1,7 +1,5 @@ #![cfg(feature = "macros")] -#![cfg(feature = "pyproto")] // FIXME: #[pymethods] to support gc protocol -use pyo3::class::PyGCProtocol; use pyo3::class::PyTraverseError; use pyo3::class::PyVisit; use pyo3::prelude::*; @@ -90,8 +88,8 @@ struct GcIntegration { dropped: TestDropCall, } -#[pyproto] -impl PyGCProtocol for GcIntegration { +#[pymethods] +impl GcIntegration { fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> { visit.call(&self.self_ref) } @@ -133,8 +131,8 @@ fn gc_integration() { #[pyclass(gc)] struct GcIntegration2 {} -#[pyproto] -impl PyGCProtocol for GcIntegration2 { +#[pymethods] +impl GcIntegration2 { fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> { Ok(()) } @@ -230,8 +228,8 @@ impl TraversableClass { } } -#[pyproto] -impl PyGCProtocol for TraversableClass { +#[pymethods] +impl TraversableClass { fn __clear__(&mut self) {} fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> { self.traversed.store(true, Ordering::Relaxed);