Merge pull request #3587 from wyfo/classmethod_into

feat: allow classmethods to receive `Py<PyType>`
This commit is contained in:
David Hewitt 2023-11-22 19:34:19 +00:00 committed by GitHub
commit 3f0dfa9698
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 14 deletions

View file

@ -0,0 +1,2 @@
- Classmethods can now receive `Py<PyType>` as their first argument
- Function annotated with `pass_module` can now receive `Py<PyModule>` as their first argument

View file

@ -113,12 +113,14 @@ impl FnType {
} }
FnType::FnClass | FnType::FnNewClass => { FnType::FnClass | FnType::FnNewClass => {
quote! { quote! {
_pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject), #[allow(clippy::useless_conversion)]
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject)),
} }
} }
FnType::FnModule => { FnType::FnModule => {
quote! { quote! {
py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf), #[allow(clippy::useless_conversion)]
::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)),
} }
} }
} }

View file

@ -199,7 +199,8 @@ pub fn impl_wrap_pyfunction(
.collect::<syn::Result<Vec<_>>>()?; .collect::<syn::Result<Vec<_>>>()?;
let tp = if pass_module.is_some() { let tp = if pass_module.is_some() {
const PASS_MODULE_ERR: &str = "expected &PyModule as first argument with `pass_module`"; const PASS_MODULE_ERR: &str =
"expected &PyModule or Py<PyModule> as first argument with `pass_module`";
ensure_spanned!( ensure_spanned!(
!arguments.is_empty(), !arguments.is_empty(),
func.span() => PASS_MODULE_ERR func.span() => PASS_MODULE_ERR
@ -271,18 +272,32 @@ pub fn impl_wrap_pyfunction(
} }
fn type_is_pymodule(ty: &syn::Type) -> bool { fn type_is_pymodule(ty: &syn::Type) -> bool {
if let syn::Type::Reference(tyref) = ty { let is_pymodule = |typath: &syn::TypePath| {
if let syn::Type::Path(typath) = tyref.elem.as_ref() { typath
if typath .path
.path .segments
.segments .last()
.last() .map_or(false, |seg| seg.ident == "PyModule")
.map(|seg| seg.ident == "PyModule") };
.unwrap_or(false) match ty {
{ syn::Type::Reference(tyref) => {
return true; if let syn::Type::Path(typath) = tyref.elem.as_ref() {
return is_pymodule(typath);
} }
} }
syn::Type::Path(typath) => {
if let Some(syn::PathSegment {
arguments: syn::PathArguments::AngleBracketed(args),
..
}) = typath.path.segments.last()
{
if args.args.len() != 1 {
return false;
}
return matches!(args.args.first().unwrap(), syn::GenericArgument::Type(syn::Type::Path(typath)) if is_pymodule(typath));
}
}
_ => {}
} }
false false
} }

View file

@ -76,6 +76,14 @@ impl ClassMethod {
fn method(cls: &PyType) -> PyResult<String> { fn method(cls: &PyType) -> PyResult<String> {
Ok(format!("{}.method()!", cls.name()?)) Ok(format!("{}.method()!", cls.name()?))
} }
#[classmethod]
fn method_owned(cls: Py<PyType>) -> PyResult<String> {
Ok(format!(
"{}.method_owned()!",
Python::with_gil(|gil| cls.as_ref(gil).name().map(ToString::to_string))?
))
}
} }
#[test] #[test]
@ -84,6 +92,11 @@ fn class_method() {
let d = [("C", py.get_type::<ClassMethod>())].into_py_dict(py); let d = [("C", py.get_type::<ClassMethod>())].into_py_dict(py);
py_assert!(py, *d, "C.method() == 'ClassMethod.method()!'"); py_assert!(py, *d, "C.method() == 'ClassMethod.method()!'");
py_assert!(py, *d, "C().method() == 'ClassMethod.method()!'"); py_assert!(py, *d, "C().method() == 'ClassMethod.method()!'");
py_assert!(
py,
*d,
"C().method_owned() == 'ClassMethod.method_owned()!'"
);
py_assert!(py, *d, "C.method.__doc__ == 'Test class method.'"); py_assert!(py, *d, "C.method.__doc__ == 'Test class method.'");
py_assert!(py, *d, "C().method.__doc__ == 'Test class method.'"); py_assert!(py, *d, "C().method.__doc__ == 'Test class method.'");
}); });

View file

@ -348,6 +348,12 @@ fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
module.name() module.name()
} }
#[pyfunction]
#[pyo3(pass_module)]
fn pyfunction_with_module_owned(module: Py<PyModule>) -> PyResult<String> {
Python::with_gil(|gil| module.as_ref(gil).name().map(Into::into))
}
#[pyfunction] #[pyfunction]
#[pyo3(pass_module)] #[pyo3(pass_module)]
fn pyfunction_with_module_and_py<'a>( fn pyfunction_with_module_and_py<'a>(
@ -393,6 +399,7 @@ fn pyfunction_with_pass_module_in_attribute(module: &PyModule) -> PyResult<&str>
#[pymodule] #[pymodule]
fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_owned, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_default_arg, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_default_arg, m)?)?;
@ -401,6 +408,7 @@ fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<
pyfunction_with_pass_module_in_attribute, pyfunction_with_pass_module_in_attribute,
m m
)?)?; )?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?;
Ok(()) Ok(())
} }
@ -413,6 +421,11 @@ fn test_module_functions_with_module() {
m, m,
"m.pyfunction_with_module() == 'module_with_functions_with_module'" "m.pyfunction_with_module() == 'module_with_functions_with_module'"
); );
py_assert!(
py,
m,
"m.pyfunction_with_module_owned() == 'module_with_functions_with_module'"
);
py_assert!( py_assert!(
py, py,
m, m,

View file

@ -1,4 +1,4 @@
error: expected &PyModule as first argument with `pass_module` error: expected &PyModule or Py<PyModule> as first argument with `pass_module`
--> tests/ui/invalid_need_module_arg_position.rs:6:21 --> tests/ui/invalid_need_module_arg_position.rs:6:21
| |
6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> { 6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {