Enable setting the module name of a class
This is relevant for pickling objects.
This commit is contained in:
parent
c4c75bbf81
commit
299d325375
|
@ -256,7 +256,7 @@ where
|
|||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
initialize_type::<Self>(py).unwrap_or_else(|_| {
|
||||
initialize_type::<Self>(py, None).unwrap_or_else(|_| {
|
||||
panic!("An error occurred while initializing class {}", Self::NAME)
|
||||
});
|
||||
}
|
||||
|
@ -290,21 +290,15 @@ pub trait PyTypeCreate: PyObjectAlloc + PyTypeObject + Sized {
|
|||
impl<T> PyTypeCreate for T where T: PyObjectAlloc + PyTypeObject + Sized {}
|
||||
|
||||
/// Register new type in python object system.
|
||||
///
|
||||
/// Currently, module_name is always None, so it defaults to pyo3_extension
|
||||
#[cfg(not(Py_LIMITED_API))]
|
||||
pub fn initialize_type<T>(py: Python) -> PyResult<*mut ffi::PyTypeObject>
|
||||
pub fn initialize_type<T>(py: Python, module_name: Option<&str>) -> PyResult<*mut ffi::PyTypeObject>
|
||||
where
|
||||
T: PyObjectAlloc + PyTypeInfo + PyMethodsProtocol,
|
||||
{
|
||||
let type_name = CString::new(T::NAME).expect("class name must not contain NUL byte");
|
||||
|
||||
let type_object: &mut ffi::PyTypeObject = unsafe { T::type_object() };
|
||||
let base_type_object: &mut ffi::PyTypeObject =
|
||||
unsafe { <T::BaseType as PyTypeInfo>::type_object() };
|
||||
|
||||
type_object.tp_name = type_name.into_raw();
|
||||
|
||||
// PyPy will segfault if passed only a nul terminator as `tp_doc`.
|
||||
// ptr::null() is OK though.
|
||||
if T::DESCRIPTION == "\0" {
|
||||
|
@ -315,6 +309,13 @@ where
|
|||
|
||||
type_object.tp_base = base_type_object;
|
||||
|
||||
let name = match module_name {
|
||||
Some(module_name) => format!("{}.{}", module_name, T::NAME),
|
||||
None => T::NAME.to_string(),
|
||||
};
|
||||
let name = CString::new(name).expect("Module name/type name must not contain NUL byte");
|
||||
type_object.tp_name = name.into_raw();
|
||||
|
||||
// dealloc
|
||||
type_object.tp_dealloc = Some(tp_dealloc_callback::<T>);
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use pyo3::prelude::*;
|
||||
use pyo3::type_object::initialize_type;
|
||||
|
||||
#[macro_use]
|
||||
mod common;
|
||||
|
@ -71,4 +72,9 @@ fn empty_class_in_module() {
|
|||
// We currently have no way of determining a canonical module, so builtins is better
|
||||
// than using whatever calls init first.
|
||||
assert_eq!(module, "builtins");
|
||||
|
||||
// The module name can also be set manually by calling `initialize_type`.
|
||||
initialize_type::<EmptyClassInModule>(py, Some("test_module.nested")).unwrap();
|
||||
let module: String = ty.getattr("__module__").unwrap().extract().unwrap();
|
||||
assert_eq!(module, "test_module.nested");
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use pyo3::prelude::*;
|
||||
use pyo3::type_object::initialize_type;
|
||||
use pyo3::types::IntoPyDict;
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::types::{PyDict, PyTuple};
|
||||
use pyo3::wrap_pyfunction;
|
||||
use std::isize;
|
||||
|
||||
|
@ -117,3 +118,55 @@ fn pytuple_pyclass_iter() {
|
|||
py_assert!(py, tup, "type(tup[0]).__name__ == type(tup[0]).__name__");
|
||||
py_assert!(py, tup, "tup[0] != tup[1]");
|
||||
}
|
||||
|
||||
#[pyclass(dict)]
|
||||
struct PickleSupport {}
|
||||
|
||||
#[pymethods]
|
||||
impl PickleSupport {
|
||||
#[new]
|
||||
fn new(obj: &PyRawObject) {
|
||||
obj.init({ PickleSupport {} });
|
||||
}
|
||||
|
||||
pub fn __reduce__(slf: PyRef<Self>) -> PyResult<(PyObject, Py<PyTuple>, PyObject)> {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let cls = slf.to_object(py).getattr(py, "__class__")?;
|
||||
let dict = slf.to_object(py).getattr(py, "__dict__")?;
|
||||
Ok((cls, PyTuple::empty(py), dict))
|
||||
}
|
||||
}
|
||||
|
||||
fn add_module(py: Python, module: &PyModule) -> PyResult<()> {
|
||||
py.import("sys")?
|
||||
.dict()
|
||||
.get_item("modules")
|
||||
.unwrap()
|
||||
.downcast_mut::<PyDict>()?
|
||||
.set_item(module.name()?, module)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pickle() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let module = PyModule::new(py, "test_module").unwrap();
|
||||
module.add_class::<PickleSupport>().unwrap();
|
||||
add_module(py, module).unwrap();
|
||||
initialize_type::<PickleSupport>(py, Some("test_module")).unwrap();
|
||||
let inst = PyRef::new(py, PickleSupport {}).unwrap();
|
||||
py_run!(
|
||||
py,
|
||||
inst,
|
||||
r#"
|
||||
inst.a = 1
|
||||
assert inst.__dict__ == {'a': 1}
|
||||
|
||||
import pickle
|
||||
inst2 = pickle.loads(pickle.dumps(inst))
|
||||
|
||||
assert inst2.__dict__ == {'a': 1}
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue