fix incorrect `__richcmp__` for `eq_int` only simple enums (#4224)

* fix incorrect `__richcmp__` for `eq_int` only simple enums

* add tests for deprecated simple enum eq behavior

* only emit deprecation warning if neither `eq` nor `eq_int` were given

* require `eq` for `eq_int`
This commit is contained in:
Icxolu 2024-06-03 20:49:36 +02:00 committed by GitHub
parent 36cdeb29c1
commit 7e5884c40b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 126 additions and 29 deletions

View File

@ -808,7 +808,7 @@ fn impl_simple_enum(
};
let (default_richcmp, default_richcmp_slot) =
pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx);
pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx)?;
let (default_hash, default_hash_slot) = pyclass_hash(&args.options, &ty, ctx)?;
let mut default_slots = vec![default_repr_slot, default_int_slot];
@ -1670,14 +1670,16 @@ fn pyclass_richcmp_arms(options: &PyClassPyO3Options, ctx: &Ctx) -> TokenStream
let eq_arms = options
.eq
.map(|eq| {
quote_spanned! { eq.span() =>
.map(|eq| eq.span)
.or(options.eq_int.map(|eq_int| eq_int.span))
.map(|span| {
quote_spanned! { span =>
#pyo3_path::pyclass::CompareOp::Eq => {
::std::result::Result::Ok(#pyo3_path::conversion::IntoPy::into_py(self_val == other, py))
},
#pyo3_path::pyclass::CompareOp::Ne => {
::std::result::Result::Ok(#pyo3_path::conversion::IntoPy::into_py(self_val != other, py))
},
},
}
})
.unwrap_or_default();
@ -1692,15 +1694,15 @@ fn pyclass_richcmp_simple_enum(
cls: &syn::Type,
repr_type: &syn::Ident,
ctx: &Ctx,
) -> (Option<syn::ImplItemFn>, Option<MethodAndSlotDef>) {
) -> Result<(Option<syn::ImplItemFn>, Option<MethodAndSlotDef>)> {
let Ctx { pyo3_path } = ctx;
let arms = pyclass_richcmp_arms(options, ctx);
if let Some(eq_int) = options.eq_int {
ensure_spanned!(options.eq.is_some(), eq_int.span() => "The `eq_int` option requires the `eq` option.");
}
let deprecation = options
.eq_int
.map(|_| TokenStream::new())
.unwrap_or_else(|| {
let deprecation = (options.eq_int.is_none() && options.eq.is_none())
.then(|| {
quote! {
#[deprecated(
since = "0.22.0",
@ -1709,15 +1711,20 @@ fn pyclass_richcmp_simple_enum(
const DEPRECATION: () = ();
const _: () = DEPRECATION;
}
});
})
.unwrap_or_default();
let mut options = options.clone();
options.eq_int = Some(parse_quote!(eq_int));
if options.eq.is_none() {
options.eq_int = Some(parse_quote!(eq_int));
}
if options.eq.is_none() && options.eq_int.is_none() {
return (None, None);
return Ok((None, None));
}
let arms = pyclass_richcmp_arms(&options, ctx);
let eq = options.eq.map(|eq| {
quote_spanned! { eq.span() =>
let self_val = self;
@ -1766,7 +1773,7 @@ fn pyclass_richcmp_simple_enum(
} else {
generate_default_protocol_slot(cls, &mut richcmp_impl, &__RICHCMP__, ctx).unwrap()
};
(Some(richcmp_impl), Some(richcmp_slot))
Ok((Some(richcmp_impl), Some(richcmp_slot)))
}
fn pyclass_richcmp(

View File

@ -239,6 +239,24 @@ fn test_custom_module() {
});
}
#[pyclass(eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum EqOnly {
VariantA,
VariantB,
}
#[test]
fn test_simple_enum_eq_only() {
Python::with_gil(|py| {
let var1 = Py::new(py, EqOnly::VariantA).unwrap();
let var2 = Py::new(py, EqOnly::VariantA).unwrap();
let var3 = Py::new(py, EqOnly::VariantB).unwrap();
py_assert!(py, var1 var2, "var1 == var2");
py_assert!(py, var1 var3, "var1 != var3");
})
}
#[pyclass(frozen, eq, eq_int, hash)]
#[derive(PartialEq, Hash)]
enum SimpleEnumWithHash {
@ -298,3 +316,63 @@ fn test_complex_enum_with_hash() {
py_assert!(py, *env, "hash(obj) == hsh");
});
}
#[allow(deprecated)]
mod deprecated {
use crate::py_assert;
use pyo3::prelude::*;
use pyo3::py_run;
#[pyclass]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum MyEnum {
Variant,
OtherVariant,
}
#[test]
fn test_enum_eq_enum() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
let var2 = Py::new(py, MyEnum::Variant).unwrap();
let other_var = Py::new(py, MyEnum::OtherVariant).unwrap();
py_assert!(py, var1 var2, "var1 == var2");
py_assert!(py, var1 other_var, "var1 != other_var");
py_assert!(py, var1 var2, "(var1 != var2) == False");
})
}
#[test]
fn test_enum_eq_incomparable() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
py_assert!(py, var1, "(var1 == 'foo') == False");
py_assert!(py, var1, "(var1 != 'foo') == True");
})
}
#[pyclass]
enum CustomDiscriminant {
One = 1,
Two = 2,
}
#[test]
fn test_custom_discriminant() {
Python::with_gil(|py| {
#[allow(non_snake_case)]
let CustomDiscriminant = py.get_type_bound::<CustomDiscriminant>();
let one = Py::new(py, CustomDiscriminant::One).unwrap();
let two = Py::new(py, CustomDiscriminant::Two).unwrap();
py_run!(py, CustomDiscriminant one two, r#"
assert CustomDiscriminant.One == one
assert CustomDiscriminant.Two == two
assert CustomDiscriminant.One == 1
assert CustomDiscriminant.Two == 2
assert one != two
assert CustomDiscriminant.One != 2
assert CustomDiscriminant.Two != 1
"#);
})
}
}

View File

@ -40,6 +40,12 @@ enum ComplexEqOptRequiresPartialEq {
B { msg: String },
}
#[pyclass(eq_int)]
enum SimpleEqIntWithoutEq {
A,
B,
}
#[pyclass(eq_int)]
enum NoEqInt {
A(i32),

View File

@ -30,28 +30,34 @@ error: `constructor` can't be used on a simple enum variant
26 | #[pyo3(constructor = (a, b))]
| ^^^^^^^^^^^
error: `eq_int` can only be used on simple enums.
error: The `eq_int` option requires the `eq` option.
--> tests/ui/invalid_pyclass_enum.rs:43:11
|
43 | #[pyclass(eq_int)]
| ^^^^^^
error: `eq_int` can only be used on simple enums.
--> tests/ui/invalid_pyclass_enum.rs:49:11
|
49 | #[pyclass(eq_int)]
| ^^^^^^
error: The `hash` option requires the `frozen` option.
--> tests/ui/invalid_pyclass_enum.rs:63:11
--> tests/ui/invalid_pyclass_enum.rs:69:11
|
63 | #[pyclass(hash)]
69 | #[pyclass(hash)]
| ^^^^
error: The `hash` option requires the `eq` option.
--> tests/ui/invalid_pyclass_enum.rs:63:11
--> tests/ui/invalid_pyclass_enum.rs:69:11
|
63 | #[pyclass(hash)]
69 | #[pyclass(hash)]
| ^^^^
error: The `hash` option requires the `eq` option.
--> tests/ui/invalid_pyclass_enum.rs:70:11
--> tests/ui/invalid_pyclass_enum.rs:76:11
|
70 | #[pyclass(hash)]
76 | #[pyclass(hash)]
| ^^^^
error[E0369]: binary operation `==` cannot be applied to type `&SimpleEqOptRequiresPartialEq`
@ -123,25 +129,25 @@ help: consider annotating `ComplexEqOptRequiresPartialEq` with `#[derive(Partial
|
error[E0277]: the trait bound `SimpleHashOptRequiresHash: Hash` is not satisfied
--> tests/ui/invalid_pyclass_enum.rs:49:31
--> tests/ui/invalid_pyclass_enum.rs:55:31
|
49 | #[pyclass(frozen, eq, eq_int, hash)]
55 | #[pyclass(frozen, eq, eq_int, hash)]
| ^^^^ the trait `Hash` is not implemented for `SimpleHashOptRequiresHash`
|
help: consider annotating `SimpleHashOptRequiresHash` with `#[derive(Hash)]`
|
51 + #[derive(Hash)]
52 | enum SimpleHashOptRequiresHash {
57 + #[derive(Hash)]
58 | enum SimpleHashOptRequiresHash {
|
error[E0277]: the trait bound `ComplexHashOptRequiresHash: Hash` is not satisfied
--> tests/ui/invalid_pyclass_enum.rs:56:23
--> tests/ui/invalid_pyclass_enum.rs:62:23
|
56 | #[pyclass(frozen, eq, hash)]
62 | #[pyclass(frozen, eq, hash)]
| ^^^^ the trait `Hash` is not implemented for `ComplexHashOptRequiresHash`
|
help: consider annotating `ComplexHashOptRequiresHash` with `#[derive(Hash)]`
|
58 + #[derive(Hash)]
59 | enum ComplexHashOptRequiresHash {
64 + #[derive(Hash)]
65 | enum ComplexHashOptRequiresHash {
|