Give a better error message for Python in traverse

This commit is contained in:
mejrs 2023-06-19 23:25:51 +02:00
parent e664749d61
commit 51a6863440
6 changed files with 64 additions and 34 deletions

View File

@ -210,7 +210,7 @@ pub fn gen_py_method(
GeneratedPyMethod::Proto(impl_call_slot(cls, method.spec)?)
}
PyMethodProtoKind::Traverse => {
GeneratedPyMethod::Proto(impl_traverse_slot(cls, spec.name))
GeneratedPyMethod::Proto(impl_traverse_slot(cls, spec)?)
}
PyMethodProtoKind::SlotFragment(slot_fragment_def) => {
let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec)?;
@ -404,7 +404,16 @@ fn impl_call_slot(cls: &syn::Type, mut spec: FnSpec<'_>) -> Result<MethodAndSlot
})
}
fn impl_traverse_slot(cls: &syn::Type, rust_fn_ident: &syn::Ident) -> MethodAndSlotDef {
fn impl_traverse_slot(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<MethodAndSlotDef> {
if let (Some(py_arg), _) = split_off_python_arg(&spec.signature.arguments) {
return Err(syn::Error::new_spanned(py_arg.ty, "__traverse__ may not take `Python`. \
Usually, an implementation of `__traverse__` should do nothing but calls to `visit.call`. \
Most importantly, safe access to the GIL is prohibited inside implementations of `__traverse__`, \
i.e. `Python::with_gil` will panic."));
}
let rust_fn_ident = spec.name;
let associated_method = quote! {
pub unsafe extern "C" fn __pymethod_traverse__(
slf: *mut _pyo3::ffi::PyObject,
@ -420,10 +429,10 @@ fn impl_traverse_slot(cls: &syn::Type, rust_fn_ident: &syn::Ident) -> MethodAndS
pfunc: #cls::__pymethod_traverse__ as _pyo3::ffi::traverseproc as _
}
};
MethodAndSlotDef {
Ok(MethodAndSlotDef {
associated_method,
slot_def,
}
})
}
fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<MethodAndMethodDef> {

View File

@ -36,5 +36,5 @@ fn test_compile_errors() {
t.compile_fail("tests/ui/not_send.rs");
t.compile_fail("tests/ui/not_send2.rs");
t.compile_fail("tests/ui/get_set_all.rs");
t.compile_fail("tests/ui/traverse_bare_self.rs");
t.compile_fail("tests/ui/traverse.rs");
}

27
tests/ui/traverse.rs Normal file
View File

@ -0,0 +1,27 @@
use pyo3::prelude::*;
use pyo3::PyVisit;
use pyo3::PyTraverseError;
#[pyclass]
struct TraverseTriesToTakePyRef {}
#[pymethods]
impl TraverseTriesToTakePyRef {
fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
}
#[pyclass]
struct Class;
#[pymethods]
impl Class {
fn __traverse__(&self, py: Python<'_>, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(())
}
fn __clear__(&mut self) {
}
}
fn main() {}

23
tests/ui/traverse.stderr Normal file
View File

@ -0,0 +1,23 @@
error: __traverse__ may not take `Python`. Usually, an implementation of `__traverse__` should do nothing but calls to `visit.call`. Most importantly, safe access to the GIL is prohibited inside implementations of `__traverse__`, i.e. `Python::with_gil` will panic.
--> tests/ui/traverse.rs:18:32
|
18 | fn __traverse__(&self, py: Python<'_>, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
| ^^^^^^^^^^
error[E0308]: mismatched types
--> tests/ui/traverse.rs:9:6
|
8 | #[pymethods]
| ------------ arguments to this function are incorrect
9 | impl TraverseTriesToTakePyRef {
| ______^
10 | | fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
| |___________________^ expected fn pointer, found fn item
|
= note: expected fn pointer `for<'a, 'b> fn(&'a TraverseTriesToTakePyRef, PyVisit<'b>) -> Result<(), PyTraverseError>`
found fn item `for<'a, 'b> fn(pyo3::PyRef<'a, TraverseTriesToTakePyRef>, PyVisit<'b>) {TraverseTriesToTakePyRef::__traverse__}`
note: function defined here
--> src/impl_/pymethods.rs
|
| pub unsafe fn call_traverse_impl<T>(
| ^^^^^^^^^^^^^^^^^^

View File

@ -1,12 +0,0 @@
use pyo3::prelude::*;
use pyo3::PyVisit;
#[pyclass]
struct TraverseTriesToTakePyRef {}
#[pymethods]
impl TraverseTriesToTakePyRef {
fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
}
fn main() {}

View File

@ -1,17 +0,0 @@
error[E0308]: mismatched types
--> tests/ui/traverse_bare_self.rs:8:6
|
7 | #[pymethods]
| ------------ arguments to this function are incorrect
8 | impl TraverseTriesToTakePyRef {
| ______^
9 | | fn __traverse__(slf: PyRef<Self>, visit: PyVisit) {}
| |___________________^ expected fn pointer, found fn item
|
= note: expected fn pointer `for<'a, 'b> fn(&'a TraverseTriesToTakePyRef, PyVisit<'b>) -> Result<(), PyTraverseError>`
found fn item `for<'a, 'b> fn(pyo3::PyRef<'a, TraverseTriesToTakePyRef>, PyVisit<'b>) {TraverseTriesToTakePyRef::__traverse__}`
note: function defined here
--> src/impl_/pymethods.rs
|
| pub unsafe fn call_traverse_impl<T>(
| ^^^^^^^^^^^^^^^^^^