pyo3/tests/test_inheritance.rs
2023-07-24 22:14:55 +01:00

365 lines
8.5 KiB
Rust

#![cfg(feature = "macros")]
use pyo3::prelude::*;
use pyo3::py_run;
use pyo3::types::IntoPyDict;
mod common;
#[pyclass(subclass)]
struct BaseClass {
#[pyo3(get)]
val1: usize,
}
#[pyclass(subclass)]
struct SubclassAble {}
#[test]
fn subclass() {
Python::with_gil(|py| {
let d = [("SubclassAble", py.get_type::<SubclassAble>())].into_py_dict(py);
py.run(
"class A(SubclassAble): pass\nassert issubclass(A, SubclassAble)",
None,
Some(d),
)
.map_err(|e| e.display(py))
.unwrap();
});
}
#[pymethods]
impl BaseClass {
#[new]
fn new() -> Self {
BaseClass { val1: 10 }
}
fn base_method(&self, x: usize) -> usize {
x * self.val1
}
fn base_set(&mut self, fn_: &pyo3::PyAny) -> PyResult<()> {
let value: usize = fn_.call0()?.extract()?;
self.val1 = value;
Ok(())
}
}
#[pyclass(extends=BaseClass)]
struct SubClass {
#[pyo3(get)]
val2: usize,
}
#[pymethods]
impl SubClass {
#[new]
fn new() -> (Self, BaseClass) {
(SubClass { val2: 5 }, BaseClass { val1: 10 })
}
fn sub_method(&self, x: usize) -> usize {
x * self.val2
}
fn sub_set_and_ret(&mut self, x: usize) -> usize {
self.val2 = x;
x
}
}
#[test]
fn inheritance_with_new_methods() {
Python::with_gil(|py| {
let typeobj = py.get_type::<SubClass>();
let inst = typeobj.call((), None).unwrap();
py_run!(py, inst, "assert inst.val1 == 10; assert inst.val2 == 5");
});
}
#[test]
fn call_base_and_sub_methods() {
Python::with_gil(|py| {
let obj = PyCell::new(py, SubClass::new()).unwrap();
py_run!(
py,
obj,
r#"
assert obj.base_method(10) == 100
assert obj.sub_method(10) == 50
"#
);
});
}
#[test]
fn mutation_fails() {
Python::with_gil(|py| {
let obj = PyCell::new(py, SubClass::new()).unwrap();
let global = Some([("obj", obj)].into_py_dict(py));
let e = py
.run("obj.base_set(lambda: obj.sub_set_and_ret(1))", global, None)
.unwrap_err();
assert_eq!(&e.to_string(), "RuntimeError: Already borrowed");
});
}
#[test]
fn is_subclass_and_is_instance() {
Python::with_gil(|py| {
let sub_ty = py.get_type::<SubClass>();
let base_ty = py.get_type::<BaseClass>();
assert!(sub_ty.is_subclass_of::<BaseClass>().unwrap());
assert!(sub_ty.is_subclass(base_ty).unwrap());
let obj = PyCell::new(py, SubClass::new()).unwrap();
assert!(obj.is_instance_of::<SubClass>());
assert!(obj.is_instance_of::<BaseClass>());
assert!(obj.is_instance(sub_ty).unwrap());
assert!(obj.is_instance(base_ty).unwrap());
});
}
#[pyclass(subclass)]
struct BaseClassWithResult {
_val: usize,
}
#[pymethods]
impl BaseClassWithResult {
#[new]
fn new(value: isize) -> PyResult<Self> {
Ok(Self {
_val: std::convert::TryFrom::try_from(value)?,
})
}
}
#[pyclass(extends=BaseClassWithResult)]
struct SubClass2 {}
#[pymethods]
impl SubClass2 {
#[new]
fn new(value: isize) -> PyResult<(Self, BaseClassWithResult)> {
let base = BaseClassWithResult::new(value)?;
Ok((Self {}, base))
}
}
#[test]
fn handle_result_in_new() {
Python::with_gil(|py| {
let subclass = py.get_type::<SubClass2>();
py_run!(
py,
subclass,
r#"
try:
subclass(-10)
assert Fals
except ValueError as e:
pass
except Exception as e:
raise e
"#
);
});
}
// Subclassing builtin types is not allowed in the LIMITED API.
#[cfg(not(Py_LIMITED_API))]
mod inheriting_native_type {
use super::*;
use pyo3::exceptions::PyException;
use pyo3::types::{IntoPyDict, PyDict};
#[cfg(not(PyPy))]
#[test]
fn inherit_set() {
use pyo3::types::PySet;
#[cfg(not(PyPy))]
#[pyclass(extends=PySet)]
#[derive(Debug)]
struct SetWithName {
#[pyo3(get, name = "name")]
_name: &'static str,
}
#[cfg(not(PyPy))]
#[pymethods]
impl SetWithName {
#[new]
fn new() -> Self {
SetWithName { _name: "Hello :)" }
}
}
Python::with_gil(|py| {
let set_sub = pyo3::PyCell::new(py, SetWithName::new()).unwrap();
py_run!(
py,
set_sub,
r#"set_sub.add(10); assert list(set_sub) == [10]; assert set_sub.name == "Hello :)""#
);
});
}
#[pyclass(extends=PyDict)]
#[derive(Debug)]
struct DictWithName {
#[pyo3(get, name = "name")]
_name: &'static str,
}
#[pymethods]
impl DictWithName {
#[new]
fn new() -> Self {
DictWithName { _name: "Hello :)" }
}
}
#[test]
fn inherit_dict() {
Python::with_gil(|py| {
let dict_sub = pyo3::PyCell::new(py, DictWithName::new()).unwrap();
py_run!(
py,
dict_sub,
r#"dict_sub[0] = 1; assert dict_sub[0] == 1; assert dict_sub.name == "Hello :)""#
);
});
}
#[test]
fn inherit_dict_drop() {
Python::with_gil(|py| {
let dict_sub = pyo3::Py::new(py, DictWithName::new()).unwrap();
assert_eq!(dict_sub.get_refcnt(py), 1);
let item = py.eval("object()", None, None).unwrap();
assert_eq!(item.get_refcnt(), 1);
dict_sub.as_ref(py).set_item("foo", item).unwrap();
assert_eq!(item.get_refcnt(), 2);
drop(dict_sub);
assert_eq!(item.get_refcnt(), 1);
})
}
#[pyclass(extends=PyException)]
struct CustomException {
#[pyo3(get)]
context: &'static str,
}
#[pymethods]
impl CustomException {
#[new]
fn new(_exc_arg: &PyAny) -> Self {
CustomException {
context: "Hello :)",
}
}
}
#[test]
fn custom_exception() {
Python::with_gil(|py| {
let cls = py.get_type::<CustomException>();
let dict = [("cls", cls)].into_py_dict(py);
let res = py.run(
"e = cls('hello'); assert str(e) == 'hello'; assert e.context == 'Hello :)'; raise e",
None,
Some(dict)
);
let err = res.unwrap_err();
assert!(err.matches(py, cls), "{}", err);
// catching the exception in Python also works:
py_run!(
py,
cls,
r#"
try:
raise cls("foo")
except cls:
pass
"#
)
})
}
}
#[pyclass(subclass)]
struct SimpleClass {}
#[pymethods]
impl SimpleClass {
#[new]
fn new() -> Self {
Self {}
}
}
#[test]
fn test_subclass_ref_counts() {
// regression test for issue #1363
Python::with_gil(|py| {
#[allow(non_snake_case)]
let SimpleClass = py.get_type::<SimpleClass>();
py_run!(
py,
SimpleClass,
r#"
import gc
import sys
class SubClass(SimpleClass):
pass
gc.collect()
count = sys.getrefcount(SubClass)
for i in range(1000):
c = SubClass()
del c
gc.collect()
after = sys.getrefcount(SubClass)
# depending on Python's GC the count may be either identical or exactly 1000 higher,
# both are expected values that are not representative of the issue.
#
# (With issue #1363 the count will be decreased.)
assert after == count or (after == count + 1000), f"{after} vs {count}"
"#
);
})
}
#[test]
#[cfg(not(Py_LIMITED_API))]
fn module_add_class_inherit_bool_fails() {
use pyo3::types::PyBool;
#[pyclass(extends = PyBool)]
struct ExtendsBool;
Python::with_gil(|py| {
let m = PyModule::new(py, "test_module").unwrap();
let err = m.add_class::<ExtendsBool>().unwrap_err();
assert_eq!(
err.to_string(),
"RuntimeError: An error occurred while initializing class ExtendsBool"
);
assert_eq!(
err.cause(py).unwrap().to_string(),
"TypeError: type 'bool' is not an acceptable base type"
);
})
}