Merge #3157
3157: Add support for `#[new]` which is also a `#[classmethod]` r=davidhewitt a=stuhood Fixes #3077. Co-authored-by: Stu Hood <stuhood@gmail.com>
This commit is contained in:
commit
3b4c7d38c7
|
@ -607,6 +607,27 @@ Declares a class method callable from Python.
|
|||
* For details on `parameter-list`, see the documentation of `Method arguments` section.
|
||||
* The return type must be `PyResult<T>` or `T` for some `T` that implements `IntoPy<PyObject>`.
|
||||
|
||||
### Constructors which accept a class argument
|
||||
|
||||
To create a constructor which takes a positional class argument, you can combine the `#[classmethod]` and `#[new]` modifiers:
|
||||
```rust
|
||||
# use pyo3::prelude::*;
|
||||
# use pyo3::types::PyType;
|
||||
# #[pyclass]
|
||||
# struct BaseClass(PyObject);
|
||||
#
|
||||
#[pymethods]
|
||||
impl BaseClass {
|
||||
#[new]
|
||||
#[classmethod]
|
||||
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
|
||||
// Get an abstract attribute (presumably) declared on a subclass of this class.
|
||||
let subclass_attr = cls.getattr("a_class_attr")?;
|
||||
Ok(Self(subclass_attr.to_object(py)))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Static methods
|
||||
|
||||
To create a static method for a custom class, the method needs to be annotated with the
|
||||
|
|
1
newsfragments/3157.added.md
Normal file
1
newsfragments/3157.added.md
Normal file
|
@ -0,0 +1 @@
|
|||
Allow combining `#[new]` and `#[classmethod]` to create a constructor which receives a (subtype's) class/`PyType` as its first argument.
|
|
@ -84,6 +84,8 @@ fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
|
|||
pub enum MethodTypeAttribute {
|
||||
/// `#[new]`
|
||||
New,
|
||||
/// `#[new]` && `#[classmethod]`
|
||||
NewClassMethod,
|
||||
/// `#[classmethod]`
|
||||
ClassMethod,
|
||||
/// `#[classattr]`
|
||||
|
@ -102,6 +104,7 @@ pub enum FnType {
|
|||
Setter(SelfType),
|
||||
Fn(SelfType),
|
||||
FnNew,
|
||||
FnNewClass,
|
||||
FnClass,
|
||||
FnStatic,
|
||||
FnModule,
|
||||
|
@ -122,7 +125,7 @@ impl FnType {
|
|||
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
|
||||
quote!()
|
||||
}
|
||||
FnType::FnClass => {
|
||||
FnType::FnClass | FnType::FnNewClass => {
|
||||
quote! {
|
||||
let _slf = _pyo3::types::PyType::from_type_ptr(_py, _slf as *mut _pyo3::ffi::PyTypeObject);
|
||||
}
|
||||
|
@ -368,12 +371,16 @@ impl<'a> FnSpec<'a> {
|
|||
let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr {
|
||||
Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None),
|
||||
Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None),
|
||||
Some(MethodTypeAttribute::New) => {
|
||||
Some(MethodTypeAttribute::New) | Some(MethodTypeAttribute::NewClassMethod) => {
|
||||
if let Some(name) = &python_name {
|
||||
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
|
||||
}
|
||||
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
|
||||
(FnType::FnNew, false, Some(CallingConvention::TpNew))
|
||||
if matches!(fn_type_attr, Some(MethodTypeAttribute::New)) {
|
||||
(FnType::FnNew, false, Some(CallingConvention::TpNew))
|
||||
} else {
|
||||
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
|
||||
}
|
||||
}
|
||||
Some(MethodTypeAttribute::ClassMethod) => (FnType::FnClass, true, None),
|
||||
Some(MethodTypeAttribute::Getter) => {
|
||||
|
@ -496,7 +503,11 @@ impl<'a> FnSpec<'a> {
|
|||
}
|
||||
CallingConvention::TpNew => {
|
||||
let (arg_convert, args) = impl_arg_params(self, cls, &py, false)?;
|
||||
let call = quote! { #rust_name(#(#args),*) };
|
||||
let call = match &self.tp {
|
||||
FnType::FnNew => quote! { #rust_name(#(#args),*) },
|
||||
FnType::FnNewClass => quote! { #rust_name(PyType::from_type_ptr(#py, subtype), #(#args),*) },
|
||||
x => panic!("Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}", x),
|
||||
};
|
||||
quote! {
|
||||
unsafe fn #ident(
|
||||
#py: _pyo3::Python<'_>,
|
||||
|
@ -609,7 +620,7 @@ impl<'a> FnSpec<'a> {
|
|||
FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None,
|
||||
FnType::Fn(_) => Some("self"),
|
||||
FnType::FnModule => Some("module"),
|
||||
FnType::FnClass => Some("cls"),
|
||||
FnType::FnClass | FnType::FnNewClass => Some("cls"),
|
||||
FnType::FnStatic | FnType::FnNew => None,
|
||||
};
|
||||
|
||||
|
@ -637,11 +648,22 @@ fn parse_method_attributes(
|
|||
let mut deprecated_args = None;
|
||||
let mut ty: Option<MethodTypeAttribute> = None;
|
||||
|
||||
macro_rules! set_compound_ty {
|
||||
($new_ty:expr, $ident:expr) => {
|
||||
ty = match (ty, $new_ty) {
|
||||
(None, new_ty) => Some(new_ty),
|
||||
(Some(MethodTypeAttribute::ClassMethod), MethodTypeAttribute::New) => Some(MethodTypeAttribute::NewClassMethod),
|
||||
(Some(MethodTypeAttribute::New), MethodTypeAttribute::ClassMethod) => Some(MethodTypeAttribute::NewClassMethod),
|
||||
(Some(_), _) => bail_spanned!($ident.span() => "can only combine `new` and `classmethod`"),
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! set_ty {
|
||||
($new_ty:expr, $ident:expr) => {
|
||||
ensure_spanned!(
|
||||
ty.replace($new_ty).is_none(),
|
||||
$ident.span() => "cannot specify a second method type"
|
||||
$ident.span() => "cannot combine these method types"
|
||||
);
|
||||
};
|
||||
}
|
||||
|
@ -650,13 +672,13 @@ fn parse_method_attributes(
|
|||
match attr.parse_meta() {
|
||||
Ok(syn::Meta::Path(name)) => {
|
||||
if name.is_ident("new") || name.is_ident("__new__") {
|
||||
set_ty!(MethodTypeAttribute::New, name);
|
||||
set_compound_ty!(MethodTypeAttribute::New, name);
|
||||
} else if name.is_ident("init") || name.is_ident("__init__") {
|
||||
bail_spanned!(name.span() => "#[init] is disabled since PyO3 0.9.0");
|
||||
} else if name.is_ident("call") || name.is_ident("__call__") {
|
||||
bail_spanned!(name.span() => "use `fn __call__` instead of `#[call]` attribute since PyO3 0.15.0");
|
||||
} else if name.is_ident("classmethod") {
|
||||
set_ty!(MethodTypeAttribute::ClassMethod, name);
|
||||
set_compound_ty!(MethodTypeAttribute::ClassMethod, name);
|
||||
} else if name.is_ident("staticmethod") {
|
||||
set_ty!(MethodTypeAttribute::StaticMethod, name);
|
||||
} else if name.is_ident("classattr") {
|
||||
|
|
|
@ -234,7 +234,9 @@ pub fn gen_py_method(
|
|||
Some(quote!(_pyo3::ffi::METH_STATIC)),
|
||||
)?),
|
||||
// special prototypes
|
||||
(_, FnType::FnNew) => GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?),
|
||||
(_, FnType::FnNew) | (_, FnType::FnNewClass) => {
|
||||
GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?)
|
||||
}
|
||||
|
||||
(_, FnType::Getter(self_type)) => GeneratedPyMethod::Method(impl_py_getter_def(
|
||||
cls,
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::iter::IterNextOutput;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyType;
|
||||
|
||||
#[pyclass]
|
||||
struct EmptyClass {}
|
||||
|
@ -35,9 +37,30 @@ impl PyClassIter {
|
|||
}
|
||||
}
|
||||
|
||||
/// Demonstrates a base class which can operate on the relevant subclass in its constructor.
|
||||
#[pyclass(subclass)]
|
||||
#[derive(Clone, Debug)]
|
||||
struct AssertingBaseClass;
|
||||
|
||||
#[pymethods]
|
||||
impl AssertingBaseClass {
|
||||
#[new]
|
||||
#[classmethod]
|
||||
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
|
||||
if !cls.is(expected_type) {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"{:?} != {:?}",
|
||||
cls, expected_type
|
||||
)));
|
||||
}
|
||||
Ok(Self)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<EmptyClass>()?;
|
||||
m.add_class::<PyClassIter>()?;
|
||||
m.add_class::<AssertingBaseClass>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -25,3 +25,14 @@ def test_iter():
|
|||
with pytest.raises(StopIteration) as excinfo:
|
||||
next(i)
|
||||
assert excinfo.value.value == "Ended"
|
||||
|
||||
|
||||
class AssertingSubClass(pyclasses.AssertingBaseClass):
|
||||
pass
|
||||
|
||||
|
||||
def test_new_classmethod():
|
||||
# The `AssertingBaseClass` constructor errors if it is not passed the relevant subclass.
|
||||
_ = AssertingSubClass(expected_type=AssertingSubClass)
|
||||
with pytest.raises(ValueError):
|
||||
_ = AssertingSubClass(expected_type=str)
|
||||
|
|
|
@ -88,7 +88,7 @@ error: `signature` not allowed with `classattr`
|
|||
105 | #[pyo3(signature = ())]
|
||||
| ^^^^^^^^^
|
||||
|
||||
error: cannot specify a second method type
|
||||
error: cannot combine these method types
|
||||
--> tests/ui/invalid_pymethods.rs:112:7
|
||||
|
|
||||
112 | #[staticmethod]
|
||||
|
|
Loading…
Reference in a new issue