Enable setting the module name of a class

This is relevant for pickling objects.
This commit is contained in:
Alexander Niederbühl 2019-03-19 21:45:54 +01:00
parent c4c75bbf81
commit 299d325375
3 changed files with 69 additions and 9 deletions

View file

@ -256,7 +256,7 @@ where
let gil = Python::acquire_gil();
let py = gil.python();
initialize_type::<Self>(py).unwrap_or_else(|_| {
initialize_type::<Self>(py, None).unwrap_or_else(|_| {
panic!("An error occurred while initializing class {}", Self::NAME)
});
}
@ -290,21 +290,15 @@ pub trait PyTypeCreate: PyObjectAlloc + PyTypeObject + Sized {
impl<T> PyTypeCreate for T where T: PyObjectAlloc + PyTypeObject + Sized {}
/// Register new type in python object system.
///
/// Currently, module_name is always None, so it defaults to pyo3_extension
#[cfg(not(Py_LIMITED_API))]
pub fn initialize_type<T>(py: Python) -> PyResult<*mut ffi::PyTypeObject>
pub fn initialize_type<T>(py: Python, module_name: Option<&str>) -> PyResult<*mut ffi::PyTypeObject>
where
T: PyObjectAlloc + PyTypeInfo + PyMethodsProtocol,
{
let type_name = CString::new(T::NAME).expect("class name must not contain NUL byte");
let type_object: &mut ffi::PyTypeObject = unsafe { T::type_object() };
let base_type_object: &mut ffi::PyTypeObject =
unsafe { <T::BaseType as PyTypeInfo>::type_object() };
type_object.tp_name = type_name.into_raw();
// PyPy will segfault if passed only a nul terminator as `tp_doc`.
// ptr::null() is OK though.
if T::DESCRIPTION == "\0" {
@ -315,6 +309,13 @@ where
type_object.tp_base = base_type_object;
let name = match module_name {
Some(module_name) => format!("{}.{}", module_name, T::NAME),
None => T::NAME.to_string(),
};
let name = CString::new(name).expect("Module name/type name must not contain NUL byte");
type_object.tp_name = name.into_raw();
// dealloc
type_object.tp_dealloc = Some(tp_dealloc_callback::<T>);

View file

@ -1,4 +1,5 @@
use pyo3::prelude::*;
use pyo3::type_object::initialize_type;
#[macro_use]
mod common;
@ -71,4 +72,9 @@ fn empty_class_in_module() {
// We currently have no way of determining a canonical module, so builtins is better
// than using whatever calls init first.
assert_eq!(module, "builtins");
// The module name can also be set manually by calling `initialize_type`.
initialize_type::<EmptyClassInModule>(py, Some("test_module.nested")).unwrap();
let module: String = ty.getattr("__module__").unwrap().extract().unwrap();
assert_eq!(module, "test_module.nested");
}

View file

@ -1,6 +1,7 @@
use pyo3::prelude::*;
use pyo3::type_object::initialize_type;
use pyo3::types::IntoPyDict;
use pyo3::types::PyTuple;
use pyo3::types::{PyDict, PyTuple};
use pyo3::wrap_pyfunction;
use std::isize;
@ -117,3 +118,55 @@ fn pytuple_pyclass_iter() {
py_assert!(py, tup, "type(tup[0]).__name__ == type(tup[0]).__name__");
py_assert!(py, tup, "tup[0] != tup[1]");
}
#[pyclass(dict)]
struct PickleSupport {}
#[pymethods]
impl PickleSupport {
#[new]
fn new(obj: &PyRawObject) {
obj.init({ PickleSupport {} });
}
pub fn __reduce__(slf: PyRef<Self>) -> PyResult<(PyObject, Py<PyTuple>, PyObject)> {
let gil = Python::acquire_gil();
let py = gil.python();
let cls = slf.to_object(py).getattr(py, "__class__")?;
let dict = slf.to_object(py).getattr(py, "__dict__")?;
Ok((cls, PyTuple::empty(py), dict))
}
}
fn add_module(py: Python, module: &PyModule) -> PyResult<()> {
py.import("sys")?
.dict()
.get_item("modules")
.unwrap()
.downcast_mut::<PyDict>()?
.set_item(module.name()?, module)
}
#[test]
fn test_pickle() {
let gil = Python::acquire_gil();
let py = gil.python();
let module = PyModule::new(py, "test_module").unwrap();
module.add_class::<PickleSupport>().unwrap();
add_module(py, module).unwrap();
initialize_type::<PickleSupport>(py, Some("test_module")).unwrap();
let inst = PyRef::new(py, PickleSupport {}).unwrap();
py_run!(
py,
inst,
r#"
inst.a = 1
assert inst.__dict__ == {'a': 1}
import pickle
inst2 = pickle.loads(pickle.dumps(inst))
assert inst2.__dict__ == {'a': 1}
"#
);
}