Merge pull request #3587 from wyfo/classmethod_into
feat: allow classmethods to receive `Py<PyType>`
This commit is contained in:
commit
3f0dfa9698
2
newsfragments/3587.added.md
Normal file
2
newsfragments/3587.added.md
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
- Classmethods can now receive `Py<PyType>` as their first argument
|
||||||
|
- Function annotated with `pass_module` can now receive `Py<PyModule>` as their first argument
|
|
@ -113,12 +113,14 @@ impl FnType {
|
||||||
}
|
}
|
||||||
FnType::FnClass | FnType::FnNewClass => {
|
FnType::FnClass | FnType::FnNewClass => {
|
||||||
quote! {
|
quote! {
|
||||||
_pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject),
|
#[allow(clippy::useless_conversion)]
|
||||||
|
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
FnType::FnModule => {
|
FnType::FnModule => {
|
||||||
quote! {
|
quote! {
|
||||||
py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf),
|
#[allow(clippy::useless_conversion)]
|
||||||
|
::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -199,7 +199,8 @@ pub fn impl_wrap_pyfunction(
|
||||||
.collect::<syn::Result<Vec<_>>>()?;
|
.collect::<syn::Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let tp = if pass_module.is_some() {
|
let tp = if pass_module.is_some() {
|
||||||
const PASS_MODULE_ERR: &str = "expected &PyModule as first argument with `pass_module`";
|
const PASS_MODULE_ERR: &str =
|
||||||
|
"expected &PyModule or Py<PyModule> as first argument with `pass_module`";
|
||||||
ensure_spanned!(
|
ensure_spanned!(
|
||||||
!arguments.is_empty(),
|
!arguments.is_empty(),
|
||||||
func.span() => PASS_MODULE_ERR
|
func.span() => PASS_MODULE_ERR
|
||||||
|
@ -271,18 +272,32 @@ pub fn impl_wrap_pyfunction(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn type_is_pymodule(ty: &syn::Type) -> bool {
|
fn type_is_pymodule(ty: &syn::Type) -> bool {
|
||||||
if let syn::Type::Reference(tyref) = ty {
|
let is_pymodule = |typath: &syn::TypePath| {
|
||||||
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
|
typath
|
||||||
if typath
|
.path
|
||||||
.path
|
.segments
|
||||||
.segments
|
.last()
|
||||||
.last()
|
.map_or(false, |seg| seg.ident == "PyModule")
|
||||||
.map(|seg| seg.ident == "PyModule")
|
};
|
||||||
.unwrap_or(false)
|
match ty {
|
||||||
{
|
syn::Type::Reference(tyref) => {
|
||||||
return true;
|
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
|
||||||
|
return is_pymodule(typath);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
syn::Type::Path(typath) => {
|
||||||
|
if let Some(syn::PathSegment {
|
||||||
|
arguments: syn::PathArguments::AngleBracketed(args),
|
||||||
|
..
|
||||||
|
}) = typath.path.segments.last()
|
||||||
|
{
|
||||||
|
if args.args.len() != 1 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return matches!(args.args.first().unwrap(), syn::GenericArgument::Type(syn::Type::Path(typath)) if is_pymodule(typath));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,6 +76,14 @@ impl ClassMethod {
|
||||||
fn method(cls: &PyType) -> PyResult<String> {
|
fn method(cls: &PyType) -> PyResult<String> {
|
||||||
Ok(format!("{}.method()!", cls.name()?))
|
Ok(format!("{}.method()!", cls.name()?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[classmethod]
|
||||||
|
fn method_owned(cls: Py<PyType>) -> PyResult<String> {
|
||||||
|
Ok(format!(
|
||||||
|
"{}.method_owned()!",
|
||||||
|
Python::with_gil(|gil| cls.as_ref(gil).name().map(ToString::to_string))?
|
||||||
|
))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -84,6 +92,11 @@ fn class_method() {
|
||||||
let d = [("C", py.get_type::<ClassMethod>())].into_py_dict(py);
|
let d = [("C", py.get_type::<ClassMethod>())].into_py_dict(py);
|
||||||
py_assert!(py, *d, "C.method() == 'ClassMethod.method()!'");
|
py_assert!(py, *d, "C.method() == 'ClassMethod.method()!'");
|
||||||
py_assert!(py, *d, "C().method() == 'ClassMethod.method()!'");
|
py_assert!(py, *d, "C().method() == 'ClassMethod.method()!'");
|
||||||
|
py_assert!(
|
||||||
|
py,
|
||||||
|
*d,
|
||||||
|
"C().method_owned() == 'ClassMethod.method_owned()!'"
|
||||||
|
);
|
||||||
py_assert!(py, *d, "C.method.__doc__ == 'Test class method.'");
|
py_assert!(py, *d, "C.method.__doc__ == 'Test class method.'");
|
||||||
py_assert!(py, *d, "C().method.__doc__ == 'Test class method.'");
|
py_assert!(py, *d, "C().method.__doc__ == 'Test class method.'");
|
||||||
});
|
});
|
||||||
|
|
|
@ -348,6 +348,12 @@ fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
|
||||||
module.name()
|
module.name()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(pass_module)]
|
||||||
|
fn pyfunction_with_module_owned(module: Py<PyModule>) -> PyResult<String> {
|
||||||
|
Python::with_gil(|gil| module.as_ref(gil).name().map(Into::into))
|
||||||
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(pass_module)]
|
#[pyo3(pass_module)]
|
||||||
fn pyfunction_with_module_and_py<'a>(
|
fn pyfunction_with_module_and_py<'a>(
|
||||||
|
@ -393,6 +399,7 @@ fn pyfunction_with_pass_module_in_attribute(module: &PyModule) -> PyResult<&str>
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?;
|
m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(pyfunction_with_module_owned, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?;
|
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?;
|
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_default_arg, m)?)?;
|
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_default_arg, m)?)?;
|
||||||
|
@ -401,6 +408,7 @@ fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<
|
||||||
pyfunction_with_pass_module_in_attribute,
|
pyfunction_with_pass_module_in_attribute,
|
||||||
m
|
m
|
||||||
)?)?;
|
)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -413,6 +421,11 @@ fn test_module_functions_with_module() {
|
||||||
m,
|
m,
|
||||||
"m.pyfunction_with_module() == 'module_with_functions_with_module'"
|
"m.pyfunction_with_module() == 'module_with_functions_with_module'"
|
||||||
);
|
);
|
||||||
|
py_assert!(
|
||||||
|
py,
|
||||||
|
m,
|
||||||
|
"m.pyfunction_with_module_owned() == 'module_with_functions_with_module'"
|
||||||
|
);
|
||||||
py_assert!(
|
py_assert!(
|
||||||
py,
|
py,
|
||||||
m,
|
m,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
error: expected &PyModule as first argument with `pass_module`
|
error: expected &PyModule or Py<PyModule> as first argument with `pass_module`
|
||||||
--> tests/ui/invalid_need_module_arg_position.rs:6:21
|
--> tests/ui/invalid_need_module_arg_position.rs:6:21
|
||||||
|
|
|
|
||||||
6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
|
6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> {
|
||||||
|
|
Loading…
Reference in a new issue