Allow `#[new]` to return existing instances

fixes #2384
This commit is contained in:
Alex Gaynor 2023-07-02 17:26:31 -04:00
parent 1a0c9bec61
commit 0b78bb851e
Failed to extract signature
4 changed files with 88 additions and 7 deletions

View File

@ -114,6 +114,9 @@ impl Nonzero {
} }
``` ```
If you want to return an existing object (for example, because your `new`
method caches the values it returns), `new` can return `pyo3::Py<Self>`.
As you can see, the Rust method name is not important here; this way you can As you can see, the Rust method name is not important here; this way you can
still, use `new()` for a Rust-level constructor. still, use `new()` for a Rust-level constructor.

View File

@ -0,0 +1 @@
`#[new]` methods may now return `Py<Self>` in order to return existing instances

View File

@ -1,7 +1,7 @@
//! Contains initialization utilities for `#[pyclass]`. //! Contains initialization utilities for `#[pyclass]`.
use crate::callback::IntoPyCallbackOutput; use crate::callback::IntoPyCallbackOutput;
use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef}; use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef};
use crate::{ffi, PyCell, PyClass, PyErr, PyResult, Python}; use crate::{ffi, IntoPyPointer, Py, PyCell, PyClass, PyErr, PyResult, Python};
use crate::{ use crate::{
ffi::PyTypeObject, ffi::PyTypeObject,
pycell::{ pycell::{
@ -134,9 +134,14 @@ impl<T: PyTypeInfo> PyObjectInit<T> for PyNativeTypeInitializer<T> {
/// ); /// );
/// }); /// });
/// ``` /// ```
pub struct PyClassInitializer<T: PyClass> { pub struct PyClassInitializer<T: PyClass>(PyClassInitializerImpl<T>);
init: T,
super_init: <T::BaseType as PyClassBaseType>::Initializer, enum PyClassInitializerImpl<T: PyClass> {
Existing(Py<T>),
New {
init: T,
super_init: <T::BaseType as PyClassBaseType>::Initializer,
},
} }
impl<T: PyClass> PyClassInitializer<T> { impl<T: PyClass> PyClassInitializer<T> {
@ -144,7 +149,7 @@ impl<T: PyClass> PyClassInitializer<T> {
/// ///
/// It is recommended to use `add_subclass` instead of this method for most usage. /// It is recommended to use `add_subclass` instead of this method for most usage.
pub fn new(init: T, super_init: <T::BaseType as PyClassBaseType>::Initializer) -> Self { pub fn new(init: T, super_init: <T::BaseType as PyClassBaseType>::Initializer) -> Self {
Self { init, super_init } Self(PyClassInitializerImpl::New { init, super_init })
} }
/// Constructs a new initializer from an initializer for the base class. /// Constructs a new initializer from an initializer for the base class.
@ -242,13 +247,18 @@ impl<T: PyClass> PyObjectInit<T> for PyClassInitializer<T> {
contents: MaybeUninit<PyCellContents<T>>, contents: MaybeUninit<PyCellContents<T>>,
} }
let obj = self.super_init.into_new_object(py, subtype)?; let (init, super_init) = match self.0 {
PyClassInitializerImpl::Existing(value) => return Ok(value.into_ptr()),
PyClassInitializerImpl::New { init, super_init } => (init, super_init),
};
let obj = super_init.into_new_object(py, subtype)?;
let cell: *mut PartiallyInitializedPyCell<T> = obj as _; let cell: *mut PartiallyInitializedPyCell<T> = obj as _;
std::ptr::write( std::ptr::write(
(*cell).contents.as_mut_ptr(), (*cell).contents.as_mut_ptr(),
PyCellContents { PyCellContents {
value: ManuallyDrop::new(UnsafeCell::new(self.init)), value: ManuallyDrop::new(UnsafeCell::new(init)),
borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage::new(), borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage::new(),
thread_checker: T::ThreadChecker::new(), thread_checker: T::ThreadChecker::new(),
dict: T::Dict::INIT, dict: T::Dict::INIT,
@ -284,6 +294,13 @@ where
} }
} }
impl<T: PyClass> From<Py<T>> for PyClassInitializer<T> {
#[inline]
fn from(value: Py<T>) -> PyClassInitializer<T> {
PyClassInitializer(PyClassInitializerImpl::Existing(value))
}
}
// Implementation used by proc macros to allow anything convertible to PyClassInitializer<T> to be // Implementation used by proc macros to allow anything convertible to PyClassInitializer<T> to be
// the return value of pyclass #[new] method (optionally wrapped in `Result<U, E>`). // the return value of pyclass #[new] method (optionally wrapped in `Result<U, E>`).
impl<T, U> IntoPyCallbackOutput<PyClassInitializer<T>> for U impl<T, U> IntoPyCallbackOutput<PyClassInitializer<T>> for U

View File

@ -2,6 +2,7 @@
use pyo3::exceptions::PyValueError; use pyo3::exceptions::PyValueError;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::IntoPyDict; use pyo3::types::IntoPyDict;
#[pyclass] #[pyclass]
@ -204,3 +205,62 @@ fn new_with_custom_error() {
assert_eq!(err.to_string(), "ValueError: custom error"); assert_eq!(err.to_string(), "ValueError: custom error");
}); });
} }
#[pyclass]
struct NewExisting {
#[pyo3(get)]
num: usize,
}
#[pymethods]
impl NewExisting {
#[new]
fn new(py: pyo3::Python<'_>, val: usize) -> pyo3::Py<NewExisting> {
static PRE_BUILT: GILOnceCell<[pyo3::Py<NewExisting>; 2]> = GILOnceCell::new();
let existing = PRE_BUILT.get_or_init(py, || {
[
pyo3::PyCell::new(py, NewExisting { num: 0 })
.unwrap()
.into(),
pyo3::PyCell::new(py, NewExisting { num: 1 })
.unwrap()
.into(),
]
});
if val < existing.len() {
return existing[val].clone_ref(py);
}
pyo3::PyCell::new(py, NewExisting { num: val })
.unwrap()
.into()
}
}
#[test]
fn test_new_existing() {
Python::with_gil(|py| {
let typeobj = py.get_type::<NewExisting>();
let obj1 = typeobj.call1((0,)).unwrap();
let obj2 = typeobj.call1((0,)).unwrap();
let obj3 = typeobj.call1((1,)).unwrap();
let obj4 = typeobj.call1((1,)).unwrap();
let obj5 = typeobj.call1((2,)).unwrap();
let obj6 = typeobj.call1((2,)).unwrap();
assert!(obj1.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
assert!(obj2.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
assert!(obj3.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
assert!(obj4.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
assert!(obj5.getattr("num").unwrap().extract::<u32>().unwrap() == 2);
assert!(obj6.getattr("num").unwrap().extract::<u32>().unwrap() == 2);
assert!(obj1.is(obj2));
assert!(obj3.is(obj4));
assert!(!obj1.is(obj3));
assert!(!obj1.is(obj5));
assert!(!obj5.is(obj6));
});
}