From 56b7c38e24215bcefb9b084f1ef83838f93160a4 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 11 Jul 2023 20:13:35 +0100 Subject: [PATCH] improve error span for mutable access to `#[pyclass(frozen)]` --- pyo3-macros-backend/src/method.rs | 76 ++++++++++++------- pyo3-macros-backend/src/pyclass.rs | 7 +- pyo3-macros-backend/src/pymethod.rs | 37 +++------ tests/ui/invalid_frozen_pyclass_borrow.rs | 11 +++ tests/ui/invalid_frozen_pyclass_borrow.stderr | 36 ++++++--- 5 files changed, 101 insertions(+), 66 deletions(-) diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 4ca8ec0e..79da8574 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -8,7 +8,7 @@ use quote::ToTokens; use quote::{quote, quote_spanned}; use syn::ext::IdentExt; use syn::spanned::Spanned; -use syn::Result; +use syn::{Result, Token}; #[derive(Clone, Debug)] pub struct FnArg<'a> { @@ -145,7 +145,7 @@ impl FnType { #[derive(Clone, Debug)] pub enum SelfType { - Receiver { mutable: bool }, + Receiver { mutable: bool, span: Span }, TryFromPyCell(Span), } @@ -155,38 +155,55 @@ pub enum ExtractErrorMode { Raise, } +impl ExtractErrorMode { + pub fn handle_error(self, py: &syn::Ident, extract: TokenStream) -> TokenStream { + match self { + ExtractErrorMode::Raise => quote! { #extract? }, + ExtractErrorMode::NotImplemented => quote! { + match #extract { + ::std::result::Result::Ok(value) => value, + ::std::result::Result::Err(_) => { return _pyo3::callback::convert(#py, #py.NotImplemented()); }, + } + }, + } + } +} + impl SelfType { pub fn receiver(&self, cls: &syn::Type, error_mode: ExtractErrorMode) -> TokenStream { - let cell = match error_mode { - ExtractErrorMode::Raise => { - quote! { _py.from_borrowed_ptr::<_pyo3::PyAny>(_slf).downcast::<_pyo3::PyCell<#cls>>()? } - } - ExtractErrorMode::NotImplemented => { - quote! { - match _py.from_borrowed_ptr::<_pyo3::PyAny>(_slf).downcast::<_pyo3::PyCell<#cls>>() { - ::std::result::Result::Ok(cell) => cell, - ::std::result::Result::Err(_) => return _pyo3::callback::convert(_py, _py.NotImplemented()), - } - } - } - }; + let py = syn::Ident::new("_py", Span::call_site()); + let _slf = syn::Ident::new("_slf", Span::call_site()); match self { - SelfType::Receiver { mutable: false } => { - quote! { - let _cell = #cell; - let _ref = _cell.try_borrow()?; - let _slf: &#cls = &*_ref; - } - } - SelfType::Receiver { mutable: true } => { - quote! { - let _cell = #cell; - let mut _ref = _cell.try_borrow_mut()?; - let _slf: &mut #cls = &mut *_ref; + SelfType::Receiver { span, mutable } => { + let (method, mutability) = if *mutable { + ( + quote_spanned! { *span => extract_pyclass_ref_mut }, + Some(Token![mut](*span)), + ) + } else { + (quote_spanned! { *span => extract_pyclass_ref }, None) + }; + let extract = error_mode.handle_error( + &py, + quote_spanned! { *span => + _pyo3::impl_::extract_argument::#method( + #py.from_borrowed_ptr::<_pyo3::PyAny>(#_slf), + &mut holder, + ) + }, + ); + quote_spanned! { *span => + let mut holder = _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT; + let #_slf: &#mutability #cls = #extract; } } SelfType::TryFromPyCell(span) => { - let _slf = quote! { _slf }; + let cell = error_mode.handle_error( + &py, + quote!{ + _py.from_borrowed_ptr::<_pyo3::PyAny>(_slf).downcast::<_pyo3::PyCell<#cls>>() + } + ); quote_spanned! { *span => let _cell = #cell; #[allow(clippy::useless_conversion)] // In case _slf is PyCell @@ -256,8 +273,9 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> Result { ) => { bail_spanned!(recv.span() => RECEIVER_BY_VALUE_ERR); } - syn::FnArg::Receiver(syn::Receiver { mutability, .. }) => Ok(SelfType::Receiver { + syn::FnArg::Receiver(recv @ syn::Receiver { mutability, .. }) => Ok(SelfType::Receiver { mutable: mutability.is_some(), + span: recv.span(), }), syn::FnArg::Typed(syn::PatType { ty, .. }) => { if let syn::Type::ImplTrait(_) = &**ty { diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index d944bff0..440e300d 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; +use crate::attributes::kw::frozen; use crate::attributes::{ self, kw, take_pyo3_options, CrateAttribute, ExtendsAttribute, FreelistAttribute, ModuleAttribute, NameAttribute, NameLitStr, TextSignatureAttribute, @@ -355,7 +356,7 @@ fn impl_class( cls, args, methods_type, - descriptors_to_items(cls, field_options)?, + descriptors_to_items(cls, args.options.frozen, field_options)?, vec![], ) .doc(doc) @@ -674,6 +675,7 @@ fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result, field_options: Vec<(&syn::Field, FieldPyO3Options)>, ) -> syn::Result> { let ty = syn::parse_quote!(#cls); @@ -700,7 +702,8 @@ fn descriptors_to_items( items.push(getter); } - if options.set.is_some() { + if let Some(set) = options.set { + ensure_spanned!(frozen.is_none(), set.span() => "cannot use `#[pyo3(set)]` on a `frozen` class"); let setter = impl_py_setter_def( &ty, PropertyType::Descriptor { diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index b6b6f3e8..0e3861ea 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -523,9 +523,11 @@ pub fn impl_py_setter_def( }; let slf = match property_type { - PropertyType::Descriptor { .. } => { - SelfType::Receiver { mutable: true }.receiver(cls, ExtractErrorMode::Raise) + PropertyType::Descriptor { .. } => SelfType::Receiver { + mutable: true, + span: Span::call_site(), } + .receiver(cls, ExtractErrorMode::Raise), PropertyType::Function { self_type, .. } => { self_type.receiver(cls, ExtractErrorMode::Raise) } @@ -638,9 +640,11 @@ pub fn impl_py_getter_def( }; let slf = match property_type { - PropertyType::Descriptor { .. } => { - SelfType::Receiver { mutable: false }.receiver(cls, ExtractErrorMode::Raise) + PropertyType::Descriptor { .. } => SelfType::Receiver { + mutable: false, + span: Span::call_site(), } + .receiver(cls, ExtractErrorMode::Raise), PropertyType::Function { self_type, .. } => { self_type.receiver(cls, ExtractErrorMode::Raise) } @@ -949,8 +953,7 @@ impl Ty { #ident.to_borrowed_any(#py) }, ), - Ty::CompareOp => handle_error( - extract_error_mode, + Ty::CompareOp => extract_error_mode.handle_error( py, quote! { _pyo3::class::basic::CompareOp::from_raw(#ident) @@ -959,8 +962,7 @@ impl Ty { ), Ty::PySsizeT => { let ty = arg.ty; - handle_error( - extract_error_mode, + extract_error_mode.handle_error( py, quote! { ::std::convert::TryInto::<#ty>::try_into(#ident).map_err(|e| _pyo3::exceptions::PyValueError::new_err(e.to_string())) @@ -973,30 +975,13 @@ impl Ty { } } -fn handle_error( - extract_error_mode: ExtractErrorMode, - py: &syn::Ident, - extract: TokenStream, -) -> TokenStream { - match extract_error_mode { - ExtractErrorMode::Raise => quote! { #extract? }, - ExtractErrorMode::NotImplemented => quote! { - match #extract { - ::std::result::Result::Ok(value) => value, - ::std::result::Result::Err(_) => { return _pyo3::callback::convert(#py, #py.NotImplemented()); }, - } - }, - } -} - fn extract_object( extract_error_mode: ExtractErrorMode, py: &syn::Ident, name: &str, source: TokenStream, ) -> TokenStream { - handle_error( - extract_error_mode, + extract_error_mode.handle_error( py, quote! { _pyo3::impl_::extract_argument::extract_argument( diff --git a/tests/ui/invalid_frozen_pyclass_borrow.rs b/tests/ui/invalid_frozen_pyclass_borrow.rs index c7b2f27b..1f18eab6 100644 --- a/tests/ui/invalid_frozen_pyclass_borrow.rs +++ b/tests/ui/invalid_frozen_pyclass_borrow.rs @@ -6,6 +6,11 @@ pub struct Foo { field: u32, } +#[pymethods] +impl Foo { + fn mut_method(&mut self) {} +} + fn borrow_mut_fails(foo: Py, py: Python) { let borrow = foo.as_ref(py).borrow_mut(); } @@ -28,4 +33,10 @@ fn pyclass_get_of_mutable_class_fails(class: &PyCell) { class.get(); } +#[pyclass(frozen)] +pub struct SetOnFrozenClass { + #[pyo3(set)] + field: u32, +} + fn main() {} diff --git a/tests/ui/invalid_frozen_pyclass_borrow.stderr b/tests/ui/invalid_frozen_pyclass_borrow.stderr index b91d5c0c..e68c2a58 100644 --- a/tests/ui/invalid_frozen_pyclass_borrow.stderr +++ b/tests/ui/invalid_frozen_pyclass_borrow.stderr @@ -1,7 +1,25 @@ -error[E0271]: type mismatch resolving `::Frozen == False` - --> tests/ui/invalid_frozen_pyclass_borrow.rs:10:33 +error: cannot use `#[pyo3(set)]` on a `frozen` class + --> tests/ui/invalid_frozen_pyclass_borrow.rs:38:12 | -10 | let borrow = foo.as_ref(py).borrow_mut(); +38 | #[pyo3(set)] + | ^^^ + +error[E0271]: type mismatch resolving `::Frozen == False` + --> tests/ui/invalid_frozen_pyclass_borrow.rs:11:19 + | +11 | fn mut_method(&mut self) {} + | ^ expected `False`, found `True` + | +note: required by a bound in `extract_pyclass_ref_mut` + --> src/impl_/extract_argument.rs + | + | pub fn extract_pyclass_ref_mut<'a, 'py: 'a, T: PyClass>( + | ^^^^^^^^^^^^^^ required by this bound in `extract_pyclass_ref_mut` + +error[E0271]: type mismatch resolving `::Frozen == False` + --> tests/ui/invalid_frozen_pyclass_borrow.rs:15:33 + | +15 | let borrow = foo.as_ref(py).borrow_mut(); | ^^^^^^^^^^ expected `False`, found `True` | note: required by a bound in `pyo3::PyCell::::borrow_mut` @@ -11,9 +29,9 @@ note: required by a bound in `pyo3::PyCell::::borrow_mut` | ^^^^^^^^^^^^^^ required by this bound in `PyCell::::borrow_mut` error[E0271]: type mismatch resolving `::Frozen == False` - --> tests/ui/invalid_frozen_pyclass_borrow.rs:20:35 + --> tests/ui/invalid_frozen_pyclass_borrow.rs:25:35 | -20 | let borrow = child.as_ref(py).borrow_mut(); +25 | let borrow = child.as_ref(py).borrow_mut(); | ^^^^^^^^^^ expected `False`, found `True` | note: required by a bound in `pyo3::PyCell::::borrow_mut` @@ -23,9 +41,9 @@ note: required by a bound in `pyo3::PyCell::::borrow_mut` | ^^^^^^^^^^^^^^ required by this bound in `PyCell::::borrow_mut` error[E0271]: type mismatch resolving `::Frozen == True` - --> tests/ui/invalid_frozen_pyclass_borrow.rs:24:11 + --> tests/ui/invalid_frozen_pyclass_borrow.rs:29:11 | -24 | class.get(); +29 | class.get(); | ^^^ expected `True`, found `False` | note: required by a bound in `pyo3::Py::::get` @@ -35,9 +53,9 @@ note: required by a bound in `pyo3::Py::::get` | ^^^^^^^^^^^^^ required by this bound in `Py::::get` error[E0271]: type mismatch resolving `::Frozen == True` - --> tests/ui/invalid_frozen_pyclass_borrow.rs:28:11 + --> tests/ui/invalid_frozen_pyclass_borrow.rs:33:11 | -28 | class.get(); +33 | class.get(); | ^^^ expected `True`, found `False` | note: required by a bound in `pyo3::PyCell::::get`