3157: Add support for `#[new]` which is also a `#[classmethod]` r=davidhewitt a=stuhood

Fixes #3077.

Co-authored-by: Stu Hood <stuhood@gmail.com>
This commit is contained in:
bors[bot] 2023-05-17 21:58:44 +00:00 committed by GitHub
commit 3b4c7d38c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 90 additions and 10 deletions

View file

@ -607,6 +607,27 @@ Declares a class method callable from Python.
* For details on `parameter-list`, see the documentation of `Method arguments` section.
* The return type must be `PyResult<T>` or `T` for some `T` that implements `IntoPy<PyObject>`.
### Constructors which accept a class argument
To create a constructor which takes a positional class argument, you can combine the `#[classmethod]` and `#[new]` modifiers:
```rust
# use pyo3::prelude::*;
# use pyo3::types::PyType;
# #[pyclass]
# struct BaseClass(PyObject);
#
#[pymethods]
impl BaseClass {
#[new]
#[classmethod]
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
// Get an abstract attribute (presumably) declared on a subclass of this class.
let subclass_attr = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.to_object(py)))
}
}
```
## Static methods
To create a static method for a custom class, the method needs to be annotated with the

View file

@ -0,0 +1 @@
Allow combining `#[new]` and `#[classmethod]` to create a constructor which receives a (subtype's) class/`PyType` as its first argument.

View file

@ -84,6 +84,8 @@ fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
pub enum MethodTypeAttribute {
/// `#[new]`
New,
/// `#[new]` && `#[classmethod]`
NewClassMethod,
/// `#[classmethod]`
ClassMethod,
/// `#[classattr]`
@ -102,6 +104,7 @@ pub enum FnType {
Setter(SelfType),
Fn(SelfType),
FnNew,
FnNewClass,
FnClass,
FnStatic,
FnModule,
@ -122,7 +125,7 @@ impl FnType {
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
quote!()
}
FnType::FnClass => {
FnType::FnClass | FnType::FnNewClass => {
quote! {
let _slf = _pyo3::types::PyType::from_type_ptr(_py, _slf as *mut _pyo3::ffi::PyTypeObject);
}
@ -368,12 +371,16 @@ impl<'a> FnSpec<'a> {
let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr {
Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None),
Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None),
Some(MethodTypeAttribute::New) => {
Some(MethodTypeAttribute::New) | Some(MethodTypeAttribute::NewClassMethod) => {
if let Some(name) = &python_name {
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
}
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
(FnType::FnNew, false, Some(CallingConvention::TpNew))
if matches!(fn_type_attr, Some(MethodTypeAttribute::New)) {
(FnType::FnNew, false, Some(CallingConvention::TpNew))
} else {
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
}
}
Some(MethodTypeAttribute::ClassMethod) => (FnType::FnClass, true, None),
Some(MethodTypeAttribute::Getter) => {
@ -496,7 +503,11 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::TpNew => {
let (arg_convert, args) = impl_arg_params(self, cls, &py, false)?;
let call = quote! { #rust_name(#(#args),*) };
let call = match &self.tp {
FnType::FnNew => quote! { #rust_name(#(#args),*) },
FnType::FnNewClass => quote! { #rust_name(PyType::from_type_ptr(#py, subtype), #(#args),*) },
x => panic!("Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}", x),
};
quote! {
unsafe fn #ident(
#py: _pyo3::Python<'_>,
@ -609,7 +620,7 @@ impl<'a> FnSpec<'a> {
FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None,
FnType::Fn(_) => Some("self"),
FnType::FnModule => Some("module"),
FnType::FnClass => Some("cls"),
FnType::FnClass | FnType::FnNewClass => Some("cls"),
FnType::FnStatic | FnType::FnNew => None,
};
@ -637,11 +648,22 @@ fn parse_method_attributes(
let mut deprecated_args = None;
let mut ty: Option<MethodTypeAttribute> = None;
macro_rules! set_compound_ty {
($new_ty:expr, $ident:expr) => {
ty = match (ty, $new_ty) {
(None, new_ty) => Some(new_ty),
(Some(MethodTypeAttribute::ClassMethod), MethodTypeAttribute::New) => Some(MethodTypeAttribute::NewClassMethod),
(Some(MethodTypeAttribute::New), MethodTypeAttribute::ClassMethod) => Some(MethodTypeAttribute::NewClassMethod),
(Some(_), _) => bail_spanned!($ident.span() => "can only combine `new` and `classmethod`"),
};
};
}
macro_rules! set_ty {
($new_ty:expr, $ident:expr) => {
ensure_spanned!(
ty.replace($new_ty).is_none(),
$ident.span() => "cannot specify a second method type"
$ident.span() => "cannot combine these method types"
);
};
}
@ -650,13 +672,13 @@ fn parse_method_attributes(
match attr.parse_meta() {
Ok(syn::Meta::Path(name)) => {
if name.is_ident("new") || name.is_ident("__new__") {
set_ty!(MethodTypeAttribute::New, name);
set_compound_ty!(MethodTypeAttribute::New, name);
} else if name.is_ident("init") || name.is_ident("__init__") {
bail_spanned!(name.span() => "#[init] is disabled since PyO3 0.9.0");
} else if name.is_ident("call") || name.is_ident("__call__") {
bail_spanned!(name.span() => "use `fn __call__` instead of `#[call]` attribute since PyO3 0.15.0");
} else if name.is_ident("classmethod") {
set_ty!(MethodTypeAttribute::ClassMethod, name);
set_compound_ty!(MethodTypeAttribute::ClassMethod, name);
} else if name.is_ident("staticmethod") {
set_ty!(MethodTypeAttribute::StaticMethod, name);
} else if name.is_ident("classattr") {

View file

@ -234,7 +234,9 @@ pub fn gen_py_method(
Some(quote!(_pyo3::ffi::METH_STATIC)),
)?),
// special prototypes
(_, FnType::FnNew) => GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?),
(_, FnType::FnNew) | (_, FnType::FnNewClass) => {
GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?)
}
(_, FnType::Getter(self_type)) => GeneratedPyMethod::Method(impl_py_getter_def(
cls,

View file

@ -1,5 +1,7 @@
use pyo3::exceptions::PyValueError;
use pyo3::iter::IterNextOutput;
use pyo3::prelude::*;
use pyo3::types::PyType;
#[pyclass]
struct EmptyClass {}
@ -35,9 +37,30 @@ impl PyClassIter {
}
}
/// Demonstrates a base class which can operate on the relevant subclass in its constructor.
#[pyclass(subclass)]
#[derive(Clone, Debug)]
struct AssertingBaseClass;
#[pymethods]
impl AssertingBaseClass {
#[new]
#[classmethod]
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
if !cls.is(expected_type) {
return Err(PyValueError::new_err(format!(
"{:?} != {:?}",
cls, expected_type
)));
}
Ok(Self)
}
}
#[pymodule]
pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<AssertingBaseClass>()?;
Ok(())
}

View file

@ -25,3 +25,14 @@ def test_iter():
with pytest.raises(StopIteration) as excinfo:
next(i)
assert excinfo.value.value == "Ended"
class AssertingSubClass(pyclasses.AssertingBaseClass):
pass
def test_new_classmethod():
# The `AssertingBaseClass` constructor errors if it is not passed the relevant subclass.
_ = AssertingSubClass(expected_type=AssertingSubClass)
with pytest.raises(ValueError):
_ = AssertingSubClass(expected_type=str)

View file

@ -88,7 +88,7 @@ error: `signature` not allowed with `classattr`
105 | #[pyo3(signature = ())]
| ^^^^^^^^^
error: cannot specify a second method type
error: cannot combine these method types
--> tests/ui/invalid_pymethods.rs:112:7
|
112 | #[staticmethod]