From 20c56181602833b9e37ac987318401c367b8b86c Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Tue, 16 May 2023 11:39:53 -0700 Subject: [PATCH] Add support for combining the `#[new]` and `#[classmethod]` method types. --- guide/src/class.md | 21 ++++++++++++++++ newsfragments/3157.added.md | 1 + pyo3-macros-backend/src/method.rs | 38 +++++++++++++++++++++++------ pyo3-macros-backend/src/pymethod.rs | 4 ++- pytests/src/pyclasses.rs | 23 +++++++++++++++++ pytests/tests/test_pyclasses.py | 11 +++++++++ tests/ui/invalid_pymethods.stderr | 2 +- 7 files changed, 90 insertions(+), 10 deletions(-) create mode 100644 newsfragments/3157.added.md diff --git a/guide/src/class.md b/guide/src/class.md index e209c2e3..a5aadabe 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -607,6 +607,27 @@ Declares a class method callable from Python. * For details on `parameter-list`, see the documentation of `Method arguments` section. * The return type must be `PyResult` or `T` for some `T` that implements `IntoPy`. +### Constructors which accept a class argument + +To create a constructor which takes a positional class argument, you can combine the `#[classmethod]` and `#[new]` modifiers: +```rust +# use pyo3::prelude::*; +# use pyo3::types::PyType; +# #[pyclass] +# struct BaseClass(PyObject); +# +#[pymethods] +impl BaseClass { + #[new] + #[classmethod] + fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult { + // Get an abstract attribute (presumably) declared on a subclass of this class. + let subclass_attr = cls.getattr("a_class_attr")?; + Ok(Self(subclass_attr.to_object(py))) + } +} +``` + ## Static methods To create a static method for a custom class, the method needs to be annotated with the diff --git a/newsfragments/3157.added.md b/newsfragments/3157.added.md new file mode 100644 index 00000000..2719f081 --- /dev/null +++ b/newsfragments/3157.added.md @@ -0,0 +1 @@ +Allow combining `#[new]` and `#[classmethod]` to create a constructor which receives a (subtype's) class/`PyType` as its first argument. diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index b7f96747..0374df2f 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -84,6 +84,8 @@ fn handle_argument_error(pat: &syn::Pat) -> syn::Error { pub enum MethodTypeAttribute { /// `#[new]` New, + /// `#[new]` && `#[classmethod]` + NewClassMethod, /// `#[classmethod]` ClassMethod, /// `#[classattr]` @@ -102,6 +104,7 @@ pub enum FnType { Setter(SelfType), Fn(SelfType), FnNew, + FnNewClass, FnClass, FnStatic, FnModule, @@ -122,7 +125,7 @@ impl FnType { FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => { quote!() } - FnType::FnClass => { + FnType::FnClass | FnType::FnNewClass => { quote! { let _slf = _pyo3::types::PyType::from_type_ptr(_py, _slf as *mut _pyo3::ffi::PyTypeObject); } @@ -368,12 +371,16 @@ impl<'a> FnSpec<'a> { let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr { Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None), Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None), - Some(MethodTypeAttribute::New) => { + Some(MethodTypeAttribute::New) | Some(MethodTypeAttribute::NewClassMethod) => { if let Some(name) = &python_name { bail_spanned!(name.span() => "`name` not allowed with `#[new]`"); } *python_name = Some(syn::Ident::new("__new__", Span::call_site())); - (FnType::FnNew, false, Some(CallingConvention::TpNew)) + if matches!(fn_type_attr, Some(MethodTypeAttribute::New)) { + (FnType::FnNew, false, Some(CallingConvention::TpNew)) + } else { + (FnType::FnNewClass, true, Some(CallingConvention::TpNew)) + } } Some(MethodTypeAttribute::ClassMethod) => (FnType::FnClass, true, None), Some(MethodTypeAttribute::Getter) => { @@ -496,7 +503,11 @@ impl<'a> FnSpec<'a> { } CallingConvention::TpNew => { let (arg_convert, args) = impl_arg_params(self, cls, &py, false)?; - let call = quote! { #rust_name(#(#args),*) }; + let call = match &self.tp { + FnType::FnNew => quote! { #rust_name(#(#args),*) }, + FnType::FnNewClass => quote! { #rust_name(PyType::from_type_ptr(#py, subtype), #(#args),*) }, + x => panic!("Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}", x), + }; quote! { unsafe fn #ident( #py: _pyo3::Python<'_>, @@ -609,7 +620,7 @@ impl<'a> FnSpec<'a> { FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None, FnType::Fn(_) => Some("self"), FnType::FnModule => Some("module"), - FnType::FnClass => Some("cls"), + FnType::FnClass | FnType::FnNewClass => Some("cls"), FnType::FnStatic | FnType::FnNew => None, }; @@ -637,11 +648,22 @@ fn parse_method_attributes( let mut deprecated_args = None; let mut ty: Option = None; + macro_rules! set_compound_ty { + ($new_ty:expr, $ident:expr) => { + ty = match (ty, $new_ty) { + (None, new_ty) => Some(new_ty), + (Some(MethodTypeAttribute::ClassMethod), MethodTypeAttribute::New) => Some(MethodTypeAttribute::NewClassMethod), + (Some(MethodTypeAttribute::New), MethodTypeAttribute::ClassMethod) => Some(MethodTypeAttribute::NewClassMethod), + (Some(_), _) => bail_spanned!($ident.span() => "can only combine `new` and `classmethod`"), + }; + }; + } + macro_rules! set_ty { ($new_ty:expr, $ident:expr) => { ensure_spanned!( ty.replace($new_ty).is_none(), - $ident.span() => "cannot specify a second method type" + $ident.span() => "cannot combine these method types" ); }; } @@ -650,13 +672,13 @@ fn parse_method_attributes( match attr.parse_meta() { Ok(syn::Meta::Path(name)) => { if name.is_ident("new") || name.is_ident("__new__") { - set_ty!(MethodTypeAttribute::New, name); + set_compound_ty!(MethodTypeAttribute::New, name); } else if name.is_ident("init") || name.is_ident("__init__") { bail_spanned!(name.span() => "#[init] is disabled since PyO3 0.9.0"); } else if name.is_ident("call") || name.is_ident("__call__") { bail_spanned!(name.span() => "use `fn __call__` instead of `#[call]` attribute since PyO3 0.15.0"); } else if name.is_ident("classmethod") { - set_ty!(MethodTypeAttribute::ClassMethod, name); + set_compound_ty!(MethodTypeAttribute::ClassMethod, name); } else if name.is_ident("staticmethod") { set_ty!(MethodTypeAttribute::StaticMethod, name); } else if name.is_ident("classattr") { diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 04fb0211..bca3dab7 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -234,7 +234,9 @@ pub fn gen_py_method( Some(quote!(_pyo3::ffi::METH_STATIC)), )?), // special prototypes - (_, FnType::FnNew) => GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?), + (_, FnType::FnNew) | (_, FnType::FnNewClass) => { + GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?) + } (_, FnType::Getter(self_type)) => GeneratedPyMethod::Method(impl_py_getter_def( cls, diff --git a/pytests/src/pyclasses.rs b/pytests/src/pyclasses.rs index 362ce4a7..3ee61b34 100644 --- a/pytests/src/pyclasses.rs +++ b/pytests/src/pyclasses.rs @@ -1,5 +1,7 @@ +use pyo3::exceptions::PyValueError; use pyo3::iter::IterNextOutput; use pyo3::prelude::*; +use pyo3::types::PyType; #[pyclass] struct EmptyClass {} @@ -35,9 +37,30 @@ impl PyClassIter { } } +/// Demonstrates a base class which can operate on the relevant subclass in its constructor. +#[pyclass(subclass)] +#[derive(Clone, Debug)] +struct AssertingBaseClass; + +#[pymethods] +impl AssertingBaseClass { + #[new] + #[classmethod] + fn new(cls: &PyType, expected_type: &PyType) -> PyResult { + if !cls.is(expected_type) { + return Err(PyValueError::new_err(format!( + "{:?} != {:?}", + cls, expected_type + ))); + } + Ok(Self) + } +} + #[pymodule] pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/pytests/tests/test_pyclasses.py b/pytests/tests/test_pyclasses.py index c8697b4e..4a45b413 100644 --- a/pytests/tests/test_pyclasses.py +++ b/pytests/tests/test_pyclasses.py @@ -25,3 +25,14 @@ def test_iter(): with pytest.raises(StopIteration) as excinfo: next(i) assert excinfo.value.value == "Ended" + + +class AssertingSubClass(pyclasses.AssertingBaseClass): + pass + + +def test_new_classmethod(): + # The `AssertingBaseClass` constructor errors if it is not passed the relevant subclass. + _ = AssertingSubClass(expected_type=AssertingSubClass) + with pytest.raises(ValueError): + _ = AssertingSubClass(expected_type=str) diff --git a/tests/ui/invalid_pymethods.stderr b/tests/ui/invalid_pymethods.stderr index a4b14baa..29bd4cc9 100644 --- a/tests/ui/invalid_pymethods.stderr +++ b/tests/ui/invalid_pymethods.stderr @@ -88,7 +88,7 @@ error: `signature` not allowed with `classattr` 105 | #[pyo3(signature = ())] | ^^^^^^^^^ -error: cannot specify a second method type +error: cannot combine these method types --> tests/ui/invalid_pymethods.rs:112:7 | 112 | #[staticmethod]