From b70ee9a5ad6a8fc788c49a701725f9df1ed287ae Mon Sep 17 00:00:00 2001 From: kngwyu Date: Sun, 21 Jun 2020 23:38:26 +0900 Subject: [PATCH] Use subclass correctly in tp_new --- CHANGELOG.md | 2 + examples/rustapi_module/src/subclassing.rs | 7 ++++ .../rustapi_module/tests/test_subclassing.py | 5 ++- pyo3-derive-backend/src/pymethod.rs | 7 ++-- src/freelist.rs | 15 ++++--- src/pycell.rs | 9 ++-- src/pyclass.rs | 40 ++++++++++-------- src/pyclass_init.rs | 21 ++++++++-- tests/test_class_new.rs | 41 +++++++++++++++++++ 9 files changed, 112 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ca6070c..e6756cb4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Update `num-complex` optional dependendency from `0.2` to `0.3`. [#977](https://github.com/PyO3/pyo3/pull/977) - Update `num-bigint` optional dependendency from `0.2` to `0.3`. [#978](https://github.com/PyO3/pyo3/pull/978) - `#[pyproto]` is re-implemented without specialization. [#961](https://github.com/PyO3/pyo3/pull/961) +- `PyClassAlloc::alloc` is renamed to `PyClassAlloc::new`. [#990](https://github.com/PyO3/pyo3/pull/990) ### Removed - Remove `ManagedPyRef` (unused, and needs specialization) [#930](https://github.com/PyO3/pyo3/pull/930) @@ -32,6 +33,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Fixed - Fix passing explicit `None` to `Option` argument `#[pyfunction]` with a default value. [#936](https://github.com/PyO3/pyo3/pull/936) +- Fix `PyClass.__new__`'s behavior when inherited by a Python class. [#990](https://github.com/PyO3/pyo3/pull/990) ## [0.10.1] - 2020-05-14 ### Fixed diff --git a/examples/rustapi_module/src/subclassing.rs b/examples/rustapi_module/src/subclassing.rs index 61ead72f..20e7d431 100644 --- a/examples/rustapi_module/src/subclassing.rs +++ b/examples/rustapi_module/src/subclassing.rs @@ -13,6 +13,13 @@ impl Subclassable { } } +#[pyproto] +impl pyo3::PyObjectProtocol for Subclassable { + fn __str__(&self) -> PyResult<&'static str> { + Ok("Subclassable") + } +} + #[pymodule] fn subclassing(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; diff --git a/examples/rustapi_module/tests/test_subclassing.py b/examples/rustapi_module/tests/test_subclassing.py index ccd61aa7..18516201 100644 --- a/examples/rustapi_module/tests/test_subclassing.py +++ b/examples/rustapi_module/tests/test_subclassing.py @@ -6,10 +6,11 @@ PYPY = platform.python_implementation() == "PyPy" class SomeSubClass(Subclassable): - pass + def __str__(self): + return "Subclass" def test_subclassing(): if not PYPY: a = SomeSubClass() - _b = str(a) + repr(a) + assert str(a) == "Subclass" diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index a4ef8c49..fce1b702 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -193,7 +193,7 @@ pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { quote! { #[allow(unused_mut)] unsafe extern "C" fn __wrap( - _cls: *mut pyo3::ffi::PyTypeObject, + subcls: *mut pyo3::ffi::PyTypeObject, _args: *mut pyo3::ffi::PyObject, _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject { @@ -204,8 +204,9 @@ pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream { let _args = _py.from_borrowed_ptr::(_args); let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs); - let _result = pyo3::derive_utils::IntoPyNewResult::into_pynew_result(#body); - let cell = pyo3::PyClassInitializer::from(_result?).create_cell(_py)?; + let _result = pyo3::derive_utils::IntoPyNewResult::into_pynew_result(#body)?; + let initializer = pyo3::PyClassInitializer::from(_result); + let cell = initializer.create_cell_from_subtype(_py, subcls)?; Ok(cell as *mut pyo3::ffi::PyObject) }) } diff --git a/src/freelist.rs b/src/freelist.rs index f7a6e4c6..6bd9e2a4 100644 --- a/src/freelist.rs +++ b/src/freelist.rs @@ -72,13 +72,16 @@ impl PyClassAlloc for T where T: PyTypeInfo + PyClassWithFreeList, { - unsafe fn alloc(py: Python) -> *mut Self::Layout { - if let Some(obj) = ::get_free_list().pop() { - ffi::PyObject_Init(obj, Self::type_object_raw(py) as *const _ as _); - obj as _ - } else { - crate::pyclass::default_alloc::(py) as _ + unsafe fn new(py: Python, subtype: *mut ffi::PyTypeObject) -> *mut Self::Layout { + let type_obj = Self::type_object_raw(py) as *const _ as _; + // if subtype is not equal to this type, we cannot use the freelist + if subtype == type_obj { + if let Some(obj) = ::get_free_list().pop() { + ffi::PyObject_Init(obj, subtype); + return obj as _; + } } + crate::pyclass::default_new::(py, subtype) as _ } unsafe fn dealloc(py: Python, self_: *mut Self::Layout) { diff --git a/src/pycell.rs b/src/pycell.rs index 15ef1035..2fa67288 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -334,13 +334,16 @@ impl PyCell { std::mem::swap(&mut *self.borrow_mut(), &mut *other.borrow_mut()) } - /// Allocates new PyCell without initilizing value. + /// Allocates a new PyCell given a type object `subtype`. Used by our `tp_new` implementation. /// Requires `T::BaseLayout: PyBorrowFlagLayout` to ensure `self` has a borrow flag. - pub(crate) unsafe fn internal_new(py: Python) -> PyResult<*mut Self> + pub(crate) unsafe fn internal_new( + py: Python, + subtype: *mut ffi::PyTypeObject, + ) -> PyResult<*mut Self> where T::BaseLayout: PyBorrowFlagLayout, { - let base = T::alloc(py); + let base = T::new(py, subtype); if base.is_null() { return Err(PyErr::fetch(py)); } diff --git a/src/pyclass.rs b/src/pyclass.rs index a92f7a3d..9b4cc184 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -11,27 +11,29 @@ use std::os::raw::c_void; use std::ptr; #[inline] -pub(crate) unsafe fn default_alloc(py: Python) -> *mut ffi::PyObject { - let type_obj = T::type_object_raw(py); +pub(crate) unsafe fn default_new( + py: Python, + subtype: *mut ffi::PyTypeObject, +) -> *mut ffi::PyObject { // if the class derives native types(e.g., PyDict), call special new if T::FLAGS & type_flags::EXTENDED != 0 && T::BaseLayout::IS_NATIVE_TYPE { let base_tp = T::BaseType::type_object_raw(py); if let Some(base_new) = base_tp.tp_new { - return base_new(type_obj as *const _ as _, ptr::null_mut(), ptr::null_mut()); + return base_new(subtype, ptr::null_mut(), ptr::null_mut()); } } - let alloc = type_obj.tp_alloc.unwrap_or(ffi::PyType_GenericAlloc); - alloc(type_obj as *const _ as _, 0) + let alloc = (*subtype).tp_alloc.unwrap_or(ffi::PyType_GenericAlloc); + alloc(subtype, 0) as _ } -/// This trait enables custom alloc/dealloc implementations for `T: PyClass`. +/// This trait enables custom `tp_new`/`tp_dealloc` implementations for `T: PyClass`. pub trait PyClassAlloc: PyTypeInfo + Sized { /// Allocate the actual field for `#[pyclass]`. /// /// # Safety /// This function must return a valid pointer to the Python heap. - unsafe fn alloc(py: Python) -> *mut Self::Layout { - default_alloc::(py) as _ + unsafe fn new(py: Python, subtype: *mut ffi::PyTypeObject) -> *mut Self::Layout { + default_new::(py, subtype) as _ } /// Deallocate `#[pyclass]` on the Python heap. @@ -52,6 +54,18 @@ pub trait PyClassAlloc: PyTypeInfo + Sized { } } +fn tp_dealloc() -> Option { + unsafe extern "C" fn dealloc(obj: *mut ffi::PyObject) + where + T: PyClassAlloc, + { + let pool = crate::GILPool::new(); + let py = pool.python(); + ::dealloc(py, (obj as *mut T::Layout) as _) + } + Some(dealloc::) +} + #[doc(hidden)] pub unsafe fn tp_free_fallback(obj: *mut ffi::PyObject) { let ty = ffi::Py_TYPE(obj); @@ -115,15 +129,7 @@ where }; // dealloc - unsafe extern "C" fn tp_dealloc_callback(obj: *mut ffi::PyObject) - where - T: PyClassAlloc, - { - let pool = crate::GILPool::new(); - let py = pool.python(); - ::dealloc(py, (obj as *mut T::Layout) as _) - } - type_object.tp_dealloc = Some(tp_dealloc_callback::); + type_object.tp_dealloc = tp_dealloc::(); // type size type_object.tp_basicsize = std::mem::size_of::() as ffi::Py_ssize_t; diff --git a/src/pyclass_init.rs b/src/pyclass_init.rs index 3be018f2..162690d9 100644 --- a/src/pyclass_init.rs +++ b/src/pyclass_init.rs @@ -114,14 +114,27 @@ impl PyClassInitializer { PyClassInitializer::new(subclass_value, self) } - // Create a new PyCell + initialize it - #[doc(hidden)] - pub unsafe fn create_cell(self, py: Python) -> PyResult<*mut PyCell> + // Create a new PyCell and initialize it. + pub(crate) unsafe fn create_cell(self, py: Python) -> PyResult<*mut PyCell> where T: PyClass, T::BaseLayout: PyBorrowFlagLayout, { - let cell = PyCell::internal_new(py)?; + self.create_cell_from_subtype(py, T::type_object_raw(py) as *const _ as _) + } + + /// Create a new PyCell and initialize it given a typeobject `subtype`. + /// Called by our `tp_new` generated by the `#[new]` attribute. + pub unsafe fn create_cell_from_subtype( + self, + py: Python, + subtype: *mut crate::ffi::PyTypeObject, + ) -> PyResult<*mut PyCell> + where + T: PyClass, + T::BaseLayout: PyBorrowFlagLayout, + { + let cell = PyCell::internal_new(py, subtype)?; self.init_class(&mut *cell); Ok(cell) } diff --git a/tests/test_class_new.rs b/tests/test_class_new.rs index c225608f..8d4d6068 100644 --- a/tests/test_class_new.rs +++ b/tests/test_class_new.rs @@ -79,3 +79,44 @@ fn new_with_two_args() { assert_eq!(obj_ref._data1, 10); assert_eq!(obj_ref._data2, 20); } + +#[pyclass(subclass)] +struct SuperClass { + #[pyo3(get)] + from_rust: bool, +} + +#[pymethods] +impl SuperClass { + #[new] + fn new() -> Self { + SuperClass { from_rust: true } + } +} + +/// Checks that `subclass.__new__` works correctly. +/// See https://github.com/PyO3/pyo3/issues/947 for the corresponding bug. +#[test] +fn subclass_new() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let super_cls = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +class Class(SuperClass): + def __new__(cls): + return super().__new__(cls) # This should return an instance of Class + + @property + def from_rust(self): + return False +c = Class() +assert c.from_rust is False +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + globals.set_item("SuperClass", super_cls).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); +}