diff --git a/guide/src/class.md b/guide/src/class.md index dfe5c8ae..6ca582ed 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -691,7 +691,7 @@ This is the equivalent of the Python decorator `@classmethod`. #[pymethods] impl MyClass { #[classmethod] - fn cls_method(cls: &PyType) -> PyResult { + fn cls_method(cls: &Bound<'_, PyType>) -> PyResult { Ok(10) } } @@ -719,10 +719,10 @@ To create a constructor which takes a positional class argument, you can combine impl BaseClass { #[new] #[classmethod] - fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult { + fn py_new(cls: &Bound<'_, PyType>) -> PyResult { // Get an abstract attribute (presumably) declared on a subclass of this class. - let subclass_attr = cls.getattr("a_class_attr")?; - Ok(Self(subclass_attr.to_object(py))) + let subclass_attr: Bound<'_, PyAny> = cls.getattr("a_class_attr")?; + Ok(Self(subclass_attr.unbind())) } } ``` @@ -928,7 +928,7 @@ impl MyClass { // similarly for classmethod arguments, use $cls #[classmethod] #[pyo3(text_signature = "($cls, e, f)")] - fn my_class_method(cls: &PyType, e: i32, f: i32) -> i32 { + fn my_class_method(cls: &Bound<'_, PyType>, e: i32, f: i32) -> i32 { e + f } #[staticmethod] diff --git a/guide/src/function.md b/guide/src/function.md index 49ec716a..f3955ba5 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -83,10 +83,11 @@ The `#[pyo3]` attribute can be used to modify properties of the generated Python ```rust use pyo3::prelude::*; + use pyo3::types::PyString; #[pyfunction] #[pyo3(pass_module)] - fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { + fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult> { module.name() } diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index c158ec9f..f492a330 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -127,13 +127,21 @@ impl FnType { let slf: Ident = syn::Ident::new("_slf", Span::call_site()); quote_spanned! { *span => #[allow(clippy::useless_conversion)] - ::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(#py, #slf.cast())), + ::std::convert::Into::into( + _pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast()) + .downcast_unchecked::<_pyo3::types::PyType>() + ), } } FnType::FnModule(span) => { + let py = syn::Ident::new("py", Span::call_site()); + let slf: Ident = syn::Ident::new("_slf", Span::call_site()); quote_spanned! { *span => #[allow(clippy::useless_conversion)] - ::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)), + ::std::convert::Into::into( + _pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast()) + .downcast_unchecked::<_pyo3::types::PyModule>() + ), } } } @@ -409,7 +417,7 @@ impl<'a> FnSpec<'a> { // will error on incorrect type. Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(), Some(syn::FnArg::Receiver(_)) | None => bail_spanned!( - sig.paren_token.span.join() => "Expected `&PyType` or `Py` as the first argument to `#[classmethod]`" + sig.paren_token.span.join() => "Expected `&Bound` or `Py` as the first argument to `#[classmethod]`" ), }; FnType::FnClass(span) diff --git a/pytests/src/pyclasses.rs b/pytests/src/pyclasses.rs index 326893d1..9c7b2d25 100644 --- a/pytests/src/pyclasses.rs +++ b/pytests/src/pyclasses.rs @@ -44,6 +44,25 @@ struct AssertingBaseClass; #[pymethods] impl AssertingBaseClass { + #[new] + #[classmethod] + fn new(cls: &Bound<'_, PyType>, expected_type: Bound<'_, PyType>) -> PyResult { + if !cls.is(&expected_type) { + return Err(PyValueError::new_err(format!( + "{:?} != {:?}", + cls, expected_type + ))); + } + Ok(Self) + } +} + +#[pyclass(subclass)] +#[derive(Clone, Debug)] +struct AssertingBaseClassGilRef; + +#[pymethods] +impl AssertingBaseClassGilRef { #[new] #[classmethod] fn new(cls: &PyType, expected_type: &PyType) -> PyResult { @@ -65,6 +84,7 @@ pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; Ok(()) } diff --git a/pytests/tests/test_pyclasses.py b/pytests/tests/test_pyclasses.py index 54828628..9a9b44b5 100644 --- a/pytests/tests/test_pyclasses.py +++ b/pytests/tests/test_pyclasses.py @@ -41,6 +41,17 @@ def test_new_classmethod(): _ = AssertingSubClass(expected_type=str) +def test_new_classmethod_gil_ref(): + class AssertingSubClass(pyclasses.AssertingBaseClassGilRef): + pass + + # The `AssertingBaseClass` constructor errors if it is not passed the + # relevant subclass. + _ = AssertingSubClass(expected_type=AssertingSubClass) + with pytest.raises(ValueError): + _ = AssertingSubClass(expected_type=str) + + class ClassWithoutConstructorPy: def __new__(cls): raise TypeError("No constructor defined") diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index 4fac39fb..eef3b569 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -3,8 +3,10 @@ use crate::exceptions::PyStopAsyncIteration; use crate::gil::LockGIL; use crate::impl_::panic::PanicTrap; use crate::internal_tricks::extract_c_string; +use crate::types::{any::PyAnyMethods, PyModule, PyType}; use crate::{ - ffi, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, Python, + ffi, Bound, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, + Python, }; use std::borrow::Cow; use std::ffi::CStr; @@ -466,3 +468,52 @@ pub trait AsyncIterResultOptionKind { } impl AsyncIterResultOptionKind for Result, Error> {} + +/// Used in `#[classmethod]` to pass the class object to the method +/// and also in `#[pyfunction(pass_module)]`. +/// +/// This is a wrapper to avoid implementing `From` for GIL Refs. +/// +/// Once the GIL Ref API is fully removed, it should be possible to simplify +/// this to just `&'a Bound<'py, T>` and `From` implementations. +pub struct BoundRef<'a, 'py, T>(pub &'a Bound<'py, T>); + +impl<'a, 'py> BoundRef<'a, 'py, PyAny> { + pub unsafe fn ref_from_ptr(py: Python<'py>, ptr: &'a *mut ffi::PyObject) -> Self { + BoundRef(Bound::ref_from_ptr(py, ptr)) + } + + pub unsafe fn downcast_unchecked(self) -> BoundRef<'a, 'py, T> { + BoundRef(self.0.downcast_unchecked::()) + } +} + +// GIL Ref implementations for &'a T ran into trouble with orphan rules, +// so explicit implementations are used instead for the two relevant types. +impl<'a> From> for &'a PyType { + #[inline] + fn from(bound: BoundRef<'a, 'a, PyType>) -> Self { + bound.0.as_gil_ref() + } +} + +impl<'a> From> for &'a PyModule { + #[inline] + fn from(bound: BoundRef<'a, 'a, PyModule>) -> Self { + bound.0.as_gil_ref() + } +} + +impl<'a, 'py, T> From> for &'a Bound<'py, T> { + #[inline] + fn from(bound: BoundRef<'a, 'py, T>) -> Self { + bound.0 + } +} + +impl From> for Py { + #[inline] + fn from(bound: BoundRef<'_, '_, T>) -> Self { + bound.0.clone().unbind() + } +} diff --git a/src/instance.rs b/src/instance.rs index 54a93452..ff553c44 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -138,6 +138,24 @@ impl<'py> Bound<'py, PyAny> { ) -> PyResult { Py::from_owned_ptr_or_err(py, ptr).map(|obj| Self(py, ManuallyDrop::new(obj))) } + + /// This slightly strange method is used to obtain `&Bound` from a pointer in macro code + /// where we need to constrain the lifetime `'a` safely. + /// + /// Note that `'py` is required to outlive `'a` implicitly by the nature of the fact that + /// `&'a Bound<'py>` means that `Bound<'py>` exists for at least the lifetime `'a`. + /// + /// # Safety + /// - `ptr` must be a valid pointer to a Python object for the lifetime `'a`. The `ptr` can + /// be either a borrowed reference or an owned reference, it does not matter, as this is + /// just `&Bound` there will never be any ownership transfer. + #[inline] + pub(crate) unsafe fn ref_from_ptr<'a>( + _py: Python<'py>, + ptr: &'a *mut ffi::PyObject, + ) -> &'a Self { + &*(ptr as *const *mut ffi::PyObject).cast::>() + } } impl<'py, T> Bound<'py, T> diff --git a/src/tests/hygiene/pymethods.rs b/src/tests/hygiene/pymethods.rs index 15ea6759..a00e67d9 100644 --- a/src/tests/hygiene/pymethods.rs +++ b/src/tests/hygiene/pymethods.rs @@ -375,7 +375,7 @@ impl Dummy { #[staticmethod] fn staticmethod() {} #[classmethod] - fn clsmethod(_: &crate::types::PyType) {} + fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {} #[pyo3(signature = (*_args, **_kwds))] fn __call__( &self, @@ -770,7 +770,7 @@ impl Dummy { #[staticmethod] fn staticmethod() {} #[classmethod] - fn clsmethod(_: &crate::types::PyType) {} + fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {} #[pyo3(signature = (*_args, **_kwds))] fn __call__( &self, diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index bbf37a2d..ff3555c3 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -284,7 +284,7 @@ fn panic_unsendable_child() { test_unsendable::().unwrap(); } -fn get_length(obj: &PyAny) -> PyResult { +fn get_length(obj: &Bound<'_, PyAny>) -> PyResult { let length = obj.len()?; Ok(length) @@ -299,7 +299,18 @@ impl ClassWithFromPyWithMethods { argument } #[classmethod] - fn classmethod(_cls: &PyType, #[pyo3(from_py_with = "PyAny::len")] argument: usize) -> usize { + fn classmethod( + _cls: &Bound<'_, PyType>, + #[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] argument: usize, + ) -> usize { + argument + } + + #[classmethod] + fn classmethod_gil_ref( + _cls: &PyType, + #[pyo3(from_py_with = "PyAny::len")] argument: usize, + ) -> usize { argument } @@ -322,6 +333,7 @@ fn test_pymethods_from_py_with() { assert instance.instance_method(arg) == 2 assert instance.classmethod(arg) == 2 + assert instance.classmethod_gil_ref(arg) == 2 assert instance.staticmethod(arg) == 2 "# ); diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 08356533..2114ead2 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -73,7 +73,13 @@ impl ClassMethod { #[classmethod] /// Test class method. - fn method(cls: &PyType) -> PyResult { + fn method(cls: &Bound<'_, PyType>) -> PyResult { + Ok(format!("{}.method()!", cls.as_gil_ref().qualname()?)) + } + + #[classmethod] + /// Test class method. + fn method_gil_ref(cls: &PyType) -> PyResult { Ok(format!("{}.method()!", cls.qualname()?)) } @@ -108,8 +114,12 @@ struct ClassMethodWithArgs {} #[pymethods] impl ClassMethodWithArgs { #[classmethod] - fn method(cls: &PyType, input: &PyString) -> PyResult { - Ok(format!("{}.method({})", cls.qualname()?, input)) + fn method(cls: &Bound<'_, PyType>, input: &PyString) -> PyResult { + Ok(format!( + "{}.method({})", + cls.as_gil_ref().qualname()?, + input + )) } } @@ -915,7 +925,7 @@ impl r#RawIdents { } #[classmethod] - pub fn r#class_method(_: &PyType, r#type: PyObject) -> PyObject { + pub fn r#class_method(_: &Bound<'_, PyType>, r#type: PyObject) -> PyObject { r#type } @@ -1082,7 +1092,7 @@ issue_1506!( #[classmethod] fn issue_1506_class( - _cls: &PyType, + _cls: &Bound<'_, PyType>, _py: Python<'_>, _arg: &PyAny, _args: &PyTuple, diff --git a/tests/test_module.rs b/tests/test_module.rs index 1bd976e0..9a59770e 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pyo3::py_run; +use pyo3::types::PyString; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; #[path = "../src/tests/common.rs"] @@ -344,47 +345,59 @@ fn test_module_with_constant() { #[pyfunction] #[pyo3(pass_module)] -fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { +fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult> { module.name() } #[pyfunction] #[pyo3(pass_module)] -fn pyfunction_with_module_owned(module: Py) -> PyResult { - Python::with_gil(|gil| module.as_ref(gil).name().map(Into::into)) -} - -#[pyfunction] -#[pyo3(pass_module)] -fn pyfunction_with_module_and_py<'a>( - module: &'a PyModule, - _python: Python<'a>, -) -> PyResult<&'a str> { +fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> { module.name() } #[pyfunction] #[pyo3(pass_module)] -fn pyfunction_with_module_and_arg(module: &PyModule, string: String) -> PyResult<(&str, String)> { +fn pyfunction_with_module_owned( + module: Py, + py: Python<'_>, +) -> PyResult> { + module.bind(py).name() +} + +#[pyfunction] +#[pyo3(pass_module)] +fn pyfunction_with_module_and_py<'py>( + module: &Bound<'py, PyModule>, + _python: Python<'py>, +) -> PyResult> { + module.name() +} + +#[pyfunction] +#[pyo3(pass_module)] +fn pyfunction_with_module_and_arg<'py>( + module: &Bound<'py, PyModule>, + string: String, +) -> PyResult<(Bound<'py, PyString>, String)> { module.name().map(|s| (s, string)) } #[pyfunction(signature = (string="foo"))] #[pyo3(pass_module)] -fn pyfunction_with_module_and_default_arg<'a>( - module: &'a PyModule, +fn pyfunction_with_module_and_default_arg<'py>( + module: &Bound<'py, PyModule>, string: &str, -) -> PyResult<(&'a str, String)> { +) -> PyResult<(Bound<'py, PyString>, String)> { module.name().map(|s| (s, string.into())) } #[pyfunction(signature = (*args, **kwargs))] #[pyo3(pass_module)] -fn pyfunction_with_module_and_args_kwargs<'a>( - module: &'a PyModule, - args: &PyTuple, - kwargs: Option<&PyDict>, -) -> PyResult<(&'a str, usize, Option)> { +fn pyfunction_with_module_and_args_kwargs<'py>( + module: &Bound<'py, PyModule>, + args: &Bound<'py, PyTuple>, + kwargs: Option<&Bound<'py, PyDict>>, +) -> PyResult<(Bound<'py, PyString>, usize, Option)> { module .name() .map(|s| (s, args.len(), kwargs.map(|d| d.len()))) @@ -399,6 +412,7 @@ fn pyfunction_with_pass_module_in_attribute(module: &PyModule) -> PyResult<&str> #[pymodule] fn module_with_functions_with_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?; + m.add_function(wrap_pyfunction!(pyfunction_with_module_gil_ref, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_owned, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?; @@ -421,6 +435,11 @@ fn test_module_functions_with_module() { m, "m.pyfunction_with_module() == 'module_with_functions_with_module'" ); + py_assert!( + py, + m, + "m.pyfunction_with_module_gil_ref() == 'module_with_functions_with_module'" + ); py_assert!( py, m, diff --git a/tests/test_multiple_pymethods.rs b/tests/test_multiple_pymethods.rs index f78a9c60..13baeed3 100644 --- a/tests/test_multiple_pymethods.rs +++ b/tests/test_multiple_pymethods.rs @@ -35,7 +35,7 @@ impl PyClassWithMultiplePyMethods { #[pymethods] impl PyClassWithMultiplePyMethods { #[classmethod] - fn classmethod(_ty: &PyType) -> &'static str { + fn classmethod(_ty: &Bound<'_, PyType>) -> &'static str { "classmethod" } } diff --git a/tests/test_no_imports.rs b/tests/test_no_imports.rs index 1ff3dc94..4abdcdf2 100644 --- a/tests/test_no_imports.rs +++ b/tests/test_no_imports.rs @@ -60,7 +60,9 @@ impl BasicClass { /// Some documentation here #[classmethod] - fn classmethod(cls: &pyo3::types::PyType) -> &pyo3::types::PyType { + fn classmethod<'a, 'py>( + cls: &'a pyo3::Bound<'py, pyo3::types::PyType>, + ) -> &'a pyo3::Bound<'py, pyo3::types::PyType> { cls } @@ -132,8 +134,10 @@ struct NewClassMethod { impl NewClassMethod { #[new] #[classmethod] - fn new(cls: &pyo3::types::PyType) -> Self { - Self { cls: cls.into() } + fn new(cls: &pyo3::Bound<'_, pyo3::types::PyType>) -> Self { + Self { + cls: cls.clone().into_any().unbind(), + } } } diff --git a/tests/test_text_signature.rs b/tests/test_text_signature.rs index 5b0491d9..9056ca21 100644 --- a/tests/test_text_signature.rs +++ b/tests/test_text_signature.rs @@ -115,7 +115,7 @@ fn test_auto_test_signature_function() { } #[pyfunction(pass_module)] - fn my_function_2(module: &PyModule, a: i32, b: i32, c: i32) { + fn my_function_2(module: &Bound<'_, PyModule>, a: i32, b: i32, c: i32) { let _ = (module, a, b, c); } @@ -232,7 +232,7 @@ fn test_auto_test_signature_method() { } #[classmethod] - fn classmethod(cls: &PyType, a: i32, b: i32, c: i32) { + fn classmethod(cls: &Bound<'_, PyType>, a: i32, b: i32, c: i32) { let _ = (cls, a, b, c); } } @@ -311,7 +311,7 @@ fn test_auto_test_signature_opt_out() { #[classmethod] #[pyo3(text_signature = None)] - fn classmethod(cls: &PyType, a: i32, b: i32, c: i32) { + fn classmethod(cls: &Bound<'_, PyType>, a: i32, b: i32, c: i32) { let _ = (cls, a, b, c); } } @@ -372,7 +372,7 @@ fn test_methods() { } #[classmethod] #[pyo3(text_signature = "($cls, c)")] - fn class_method(_cls: &PyType, c: i32) { + fn class_method(_cls: &Bound<'_, PyType>, c: i32) { let _ = c; } #[staticmethod] diff --git a/tests/ui/invalid_pyfunctions.stderr b/tests/ui/invalid_pyfunctions.stderr index a3fd845d..6576997a 100644 --- a/tests/ui/invalid_pyfunctions.stderr +++ b/tests/ui/invalid_pyfunctions.stderr @@ -35,11 +35,11 @@ error: expected `&PyModule` or `Py` as first argument with `pass_modul 19 | fn pass_module_but_no_arguments<'py>() {} | ^^ -error[E0277]: the trait bound `&str: From<&pyo3::prelude::PyModule>` is not satisfied +error[E0277]: the trait bound `&str: From>` is not satisfied --> tests/ui/invalid_pyfunctions.rs:22:43 | 22 | fn first_argument_not_module<'py>(string: &str, module: &'py PyModule) -> PyResult<&'py str> { - | ^ the trait `From<&pyo3::prelude::PyModule>` is not implemented for `&str` + | ^ the trait `From>` is not implemented for `&str` | = help: the following other types implement trait `From`: > @@ -48,4 +48,4 @@ error[E0277]: the trait bound `&str: From<&pyo3::prelude::PyModule>` is not sati > > > - = note: required for `&pyo3::prelude::PyModule` to implement `Into<&str>` + = note: required for `BoundRef<'_, '_, pyo3::prelude::PyModule>` to implement `Into<&str>` diff --git a/tests/ui/invalid_pymethods.stderr b/tests/ui/invalid_pymethods.stderr index 1a50c4da..6a8d6eca 100644 --- a/tests/ui/invalid_pymethods.stderr +++ b/tests/ui/invalid_pymethods.stderr @@ -22,13 +22,13 @@ error: unexpected receiver 26 | fn staticmethod_with_receiver(&self) {} | ^ -error: Expected `&PyType` or `Py` as the first argument to `#[classmethod]` +error: Expected `&Bound` or `Py` as the first argument to `#[classmethod]` --> tests/ui/invalid_pymethods.rs:32:33 | 32 | fn classmethod_with_receiver(&self) {} | ^^^^^^^ -error: Expected `&PyType` or `Py` as the first argument to `#[classmethod]` +error: Expected `&Bound` or `Py` as the first argument to `#[classmethod]` --> tests/ui/invalid_pymethods.rs:38:36 | 38 | fn classmethod_missing_argument() -> Self { @@ -179,11 +179,11 @@ error: macros cannot be used as items in `#[pymethods]` impl blocks 197 | macro_invocation!(); | ^^^^^^^^^^^^^^^^ -error[E0277]: the trait bound `i32: From<&PyType>` is not satisfied +error[E0277]: the trait bound `i32: From>` is not satisfied --> tests/ui/invalid_pymethods.rs:46:45 | 46 | fn classmethod_wrong_first_argument(_x: i32) -> Self { - | ^^^ the trait `From<&PyType>` is not implemented for `i32` + | ^^^ the trait `From>` is not implemented for `i32` | = help: the following other types implement trait `From`: > @@ -192,4 +192,4 @@ error[E0277]: the trait bound `i32: From<&PyType>` is not satisfied > > > - = note: required for `&PyType` to implement `Into` + = note: required for `BoundRef<'_, '_, PyType>` to implement `Into`