fix incorrect `__richcmp__` for `eq_int` only simple enums (#4224)
* fix incorrect `__richcmp__` for `eq_int` only simple enums * add tests for deprecated simple enum eq behavior * only emit deprecation warning if neither `eq` nor `eq_int` were given * require `eq` for `eq_int`
This commit is contained in:
parent
36cdeb29c1
commit
7e5884c40b
|
@ -808,7 +808,7 @@ fn impl_simple_enum(
|
|||
};
|
||||
|
||||
let (default_richcmp, default_richcmp_slot) =
|
||||
pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx);
|
||||
pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx)?;
|
||||
let (default_hash, default_hash_slot) = pyclass_hash(&args.options, &ty, ctx)?;
|
||||
|
||||
let mut default_slots = vec![default_repr_slot, default_int_slot];
|
||||
|
@ -1670,14 +1670,16 @@ fn pyclass_richcmp_arms(options: &PyClassPyO3Options, ctx: &Ctx) -> TokenStream
|
|||
|
||||
let eq_arms = options
|
||||
.eq
|
||||
.map(|eq| {
|
||||
quote_spanned! { eq.span() =>
|
||||
.map(|eq| eq.span)
|
||||
.or(options.eq_int.map(|eq_int| eq_int.span))
|
||||
.map(|span| {
|
||||
quote_spanned! { span =>
|
||||
#pyo3_path::pyclass::CompareOp::Eq => {
|
||||
::std::result::Result::Ok(#pyo3_path::conversion::IntoPy::into_py(self_val == other, py))
|
||||
},
|
||||
#pyo3_path::pyclass::CompareOp::Ne => {
|
||||
::std::result::Result::Ok(#pyo3_path::conversion::IntoPy::into_py(self_val != other, py))
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
@ -1692,15 +1694,15 @@ fn pyclass_richcmp_simple_enum(
|
|||
cls: &syn::Type,
|
||||
repr_type: &syn::Ident,
|
||||
ctx: &Ctx,
|
||||
) -> (Option<syn::ImplItemFn>, Option<MethodAndSlotDef>) {
|
||||
) -> Result<(Option<syn::ImplItemFn>, Option<MethodAndSlotDef>)> {
|
||||
let Ctx { pyo3_path } = ctx;
|
||||
|
||||
let arms = pyclass_richcmp_arms(options, ctx);
|
||||
if let Some(eq_int) = options.eq_int {
|
||||
ensure_spanned!(options.eq.is_some(), eq_int.span() => "The `eq_int` option requires the `eq` option.");
|
||||
}
|
||||
|
||||
let deprecation = options
|
||||
.eq_int
|
||||
.map(|_| TokenStream::new())
|
||||
.unwrap_or_else(|| {
|
||||
let deprecation = (options.eq_int.is_none() && options.eq.is_none())
|
||||
.then(|| {
|
||||
quote! {
|
||||
#[deprecated(
|
||||
since = "0.22.0",
|
||||
|
@ -1709,15 +1711,20 @@ fn pyclass_richcmp_simple_enum(
|
|||
const DEPRECATION: () = ();
|
||||
const _: () = DEPRECATION;
|
||||
}
|
||||
});
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut options = options.clone();
|
||||
options.eq_int = Some(parse_quote!(eq_int));
|
||||
if options.eq.is_none() {
|
||||
options.eq_int = Some(parse_quote!(eq_int));
|
||||
}
|
||||
|
||||
if options.eq.is_none() && options.eq_int.is_none() {
|
||||
return (None, None);
|
||||
return Ok((None, None));
|
||||
}
|
||||
|
||||
let arms = pyclass_richcmp_arms(&options, ctx);
|
||||
|
||||
let eq = options.eq.map(|eq| {
|
||||
quote_spanned! { eq.span() =>
|
||||
let self_val = self;
|
||||
|
@ -1766,7 +1773,7 @@ fn pyclass_richcmp_simple_enum(
|
|||
} else {
|
||||
generate_default_protocol_slot(cls, &mut richcmp_impl, &__RICHCMP__, ctx).unwrap()
|
||||
};
|
||||
(Some(richcmp_impl), Some(richcmp_slot))
|
||||
Ok((Some(richcmp_impl), Some(richcmp_slot)))
|
||||
}
|
||||
|
||||
fn pyclass_richcmp(
|
||||
|
|
|
@ -239,6 +239,24 @@ fn test_custom_module() {
|
|||
});
|
||||
}
|
||||
|
||||
#[pyclass(eq)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum EqOnly {
|
||||
VariantA,
|
||||
VariantB,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_enum_eq_only() {
|
||||
Python::with_gil(|py| {
|
||||
let var1 = Py::new(py, EqOnly::VariantA).unwrap();
|
||||
let var2 = Py::new(py, EqOnly::VariantA).unwrap();
|
||||
let var3 = Py::new(py, EqOnly::VariantB).unwrap();
|
||||
py_assert!(py, var1 var2, "var1 == var2");
|
||||
py_assert!(py, var1 var3, "var1 != var3");
|
||||
})
|
||||
}
|
||||
|
||||
#[pyclass(frozen, eq, eq_int, hash)]
|
||||
#[derive(PartialEq, Hash)]
|
||||
enum SimpleEnumWithHash {
|
||||
|
@ -298,3 +316,63 @@ fn test_complex_enum_with_hash() {
|
|||
py_assert!(py, *env, "hash(obj) == hsh");
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(deprecated)]
|
||||
mod deprecated {
|
||||
use crate::py_assert;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::py_run;
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum MyEnum {
|
||||
Variant,
|
||||
OtherVariant,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enum_eq_enum() {
|
||||
Python::with_gil(|py| {
|
||||
let var1 = Py::new(py, MyEnum::Variant).unwrap();
|
||||
let var2 = Py::new(py, MyEnum::Variant).unwrap();
|
||||
let other_var = Py::new(py, MyEnum::OtherVariant).unwrap();
|
||||
py_assert!(py, var1 var2, "var1 == var2");
|
||||
py_assert!(py, var1 other_var, "var1 != other_var");
|
||||
py_assert!(py, var1 var2, "(var1 != var2) == False");
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enum_eq_incomparable() {
|
||||
Python::with_gil(|py| {
|
||||
let var1 = Py::new(py, MyEnum::Variant).unwrap();
|
||||
py_assert!(py, var1, "(var1 == 'foo') == False");
|
||||
py_assert!(py, var1, "(var1 != 'foo') == True");
|
||||
})
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
enum CustomDiscriminant {
|
||||
One = 1,
|
||||
Two = 2,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_discriminant() {
|
||||
Python::with_gil(|py| {
|
||||
#[allow(non_snake_case)]
|
||||
let CustomDiscriminant = py.get_type_bound::<CustomDiscriminant>();
|
||||
let one = Py::new(py, CustomDiscriminant::One).unwrap();
|
||||
let two = Py::new(py, CustomDiscriminant::Two).unwrap();
|
||||
py_run!(py, CustomDiscriminant one two, r#"
|
||||
assert CustomDiscriminant.One == one
|
||||
assert CustomDiscriminant.Two == two
|
||||
assert CustomDiscriminant.One == 1
|
||||
assert CustomDiscriminant.Two == 2
|
||||
assert one != two
|
||||
assert CustomDiscriminant.One != 2
|
||||
assert CustomDiscriminant.Two != 1
|
||||
"#);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,6 +40,12 @@ enum ComplexEqOptRequiresPartialEq {
|
|||
B { msg: String },
|
||||
}
|
||||
|
||||
#[pyclass(eq_int)]
|
||||
enum SimpleEqIntWithoutEq {
|
||||
A,
|
||||
B,
|
||||
}
|
||||
|
||||
#[pyclass(eq_int)]
|
||||
enum NoEqInt {
|
||||
A(i32),
|
||||
|
|
|
@ -30,28 +30,34 @@ error: `constructor` can't be used on a simple enum variant
|
|||
26 | #[pyo3(constructor = (a, b))]
|
||||
| ^^^^^^^^^^^
|
||||
|
||||
error: `eq_int` can only be used on simple enums.
|
||||
error: The `eq_int` option requires the `eq` option.
|
||||
--> tests/ui/invalid_pyclass_enum.rs:43:11
|
||||
|
|
||||
43 | #[pyclass(eq_int)]
|
||||
| ^^^^^^
|
||||
|
||||
error: `eq_int` can only be used on simple enums.
|
||||
--> tests/ui/invalid_pyclass_enum.rs:49:11
|
||||
|
|
||||
49 | #[pyclass(eq_int)]
|
||||
| ^^^^^^
|
||||
|
||||
error: The `hash` option requires the `frozen` option.
|
||||
--> tests/ui/invalid_pyclass_enum.rs:63:11
|
||||
--> tests/ui/invalid_pyclass_enum.rs:69:11
|
||||
|
|
||||
63 | #[pyclass(hash)]
|
||||
69 | #[pyclass(hash)]
|
||||
| ^^^^
|
||||
|
||||
error: The `hash` option requires the `eq` option.
|
||||
--> tests/ui/invalid_pyclass_enum.rs:63:11
|
||||
--> tests/ui/invalid_pyclass_enum.rs:69:11
|
||||
|
|
||||
63 | #[pyclass(hash)]
|
||||
69 | #[pyclass(hash)]
|
||||
| ^^^^
|
||||
|
||||
error: The `hash` option requires the `eq` option.
|
||||
--> tests/ui/invalid_pyclass_enum.rs:70:11
|
||||
--> tests/ui/invalid_pyclass_enum.rs:76:11
|
||||
|
|
||||
70 | #[pyclass(hash)]
|
||||
76 | #[pyclass(hash)]
|
||||
| ^^^^
|
||||
|
||||
error[E0369]: binary operation `==` cannot be applied to type `&SimpleEqOptRequiresPartialEq`
|
||||
|
@ -123,25 +129,25 @@ help: consider annotating `ComplexEqOptRequiresPartialEq` with `#[derive(Partial
|
|||
|
|
||||
|
||||
error[E0277]: the trait bound `SimpleHashOptRequiresHash: Hash` is not satisfied
|
||||
--> tests/ui/invalid_pyclass_enum.rs:49:31
|
||||
--> tests/ui/invalid_pyclass_enum.rs:55:31
|
||||
|
|
||||
49 | #[pyclass(frozen, eq, eq_int, hash)]
|
||||
55 | #[pyclass(frozen, eq, eq_int, hash)]
|
||||
| ^^^^ the trait `Hash` is not implemented for `SimpleHashOptRequiresHash`
|
||||
|
|
||||
help: consider annotating `SimpleHashOptRequiresHash` with `#[derive(Hash)]`
|
||||
|
|
||||
51 + #[derive(Hash)]
|
||||
52 | enum SimpleHashOptRequiresHash {
|
||||
57 + #[derive(Hash)]
|
||||
58 | enum SimpleHashOptRequiresHash {
|
||||
|
|
||||
|
||||
error[E0277]: the trait bound `ComplexHashOptRequiresHash: Hash` is not satisfied
|
||||
--> tests/ui/invalid_pyclass_enum.rs:56:23
|
||||
--> tests/ui/invalid_pyclass_enum.rs:62:23
|
||||
|
|
||||
56 | #[pyclass(frozen, eq, hash)]
|
||||
62 | #[pyclass(frozen, eq, hash)]
|
||||
| ^^^^ the trait `Hash` is not implemented for `ComplexHashOptRequiresHash`
|
||||
|
|
||||
help: consider annotating `ComplexHashOptRequiresHash` with `#[derive(Hash)]`
|
||||
|
|
||||
58 + #[derive(Hash)]
|
||||
59 | enum ComplexHashOptRequiresHash {
|
||||
64 + #[derive(Hash)]
|
||||
65 | enum ComplexHashOptRequiresHash {
|
||||
|
|
||||
|
|
Loading…
Reference in New Issue