diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 255fc80c..1e7f29d8 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -808,7 +808,7 @@ fn impl_simple_enum( }; let (default_richcmp, default_richcmp_slot) = - pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx); + pyclass_richcmp_simple_enum(&args.options, &ty, repr_type, ctx)?; let (default_hash, default_hash_slot) = pyclass_hash(&args.options, &ty, ctx)?; let mut default_slots = vec![default_repr_slot, default_int_slot]; @@ -1670,14 +1670,16 @@ fn pyclass_richcmp_arms(options: &PyClassPyO3Options, ctx: &Ctx) -> TokenStream let eq_arms = options .eq - .map(|eq| { - quote_spanned! { eq.span() => + .map(|eq| eq.span) + .or(options.eq_int.map(|eq_int| eq_int.span)) + .map(|span| { + quote_spanned! { span => #pyo3_path::pyclass::CompareOp::Eq => { ::std::result::Result::Ok(#pyo3_path::conversion::IntoPy::into_py(self_val == other, py)) }, #pyo3_path::pyclass::CompareOp::Ne => { ::std::result::Result::Ok(#pyo3_path::conversion::IntoPy::into_py(self_val != other, py)) - }, + }, } }) .unwrap_or_default(); @@ -1692,15 +1694,15 @@ fn pyclass_richcmp_simple_enum( cls: &syn::Type, repr_type: &syn::Ident, ctx: &Ctx, -) -> (Option, Option) { +) -> Result<(Option, Option)> { let Ctx { pyo3_path } = ctx; - let arms = pyclass_richcmp_arms(options, ctx); + if let Some(eq_int) = options.eq_int { + ensure_spanned!(options.eq.is_some(), eq_int.span() => "The `eq_int` option requires the `eq` option."); + } - let deprecation = options - .eq_int - .map(|_| TokenStream::new()) - .unwrap_or_else(|| { + let deprecation = (options.eq_int.is_none() && options.eq.is_none()) + .then(|| { quote! { #[deprecated( since = "0.22.0", @@ -1709,15 +1711,20 @@ fn pyclass_richcmp_simple_enum( const DEPRECATION: () = (); const _: () = DEPRECATION; } - }); + }) + .unwrap_or_default(); let mut options = options.clone(); - options.eq_int = Some(parse_quote!(eq_int)); + if options.eq.is_none() { + options.eq_int = Some(parse_quote!(eq_int)); + } if options.eq.is_none() && options.eq_int.is_none() { - return (None, None); + return Ok((None, None)); } + let arms = pyclass_richcmp_arms(&options, ctx); + let eq = options.eq.map(|eq| { quote_spanned! { eq.span() => let self_val = self; @@ -1766,7 +1773,7 @@ fn pyclass_richcmp_simple_enum( } else { generate_default_protocol_slot(cls, &mut richcmp_impl, &__RICHCMP__, ctx).unwrap() }; - (Some(richcmp_impl), Some(richcmp_slot)) + Ok((Some(richcmp_impl), Some(richcmp_slot))) } fn pyclass_richcmp( diff --git a/tests/test_enum.rs b/tests/test_enum.rs index 1406ac7f..4b721435 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -239,6 +239,24 @@ fn test_custom_module() { }); } +#[pyclass(eq)] +#[derive(Debug, Clone, PartialEq)] +pub enum EqOnly { + VariantA, + VariantB, +} + +#[test] +fn test_simple_enum_eq_only() { + Python::with_gil(|py| { + let var1 = Py::new(py, EqOnly::VariantA).unwrap(); + let var2 = Py::new(py, EqOnly::VariantA).unwrap(); + let var3 = Py::new(py, EqOnly::VariantB).unwrap(); + py_assert!(py, var1 var2, "var1 == var2"); + py_assert!(py, var1 var3, "var1 != var3"); + }) +} + #[pyclass(frozen, eq, eq_int, hash)] #[derive(PartialEq, Hash)] enum SimpleEnumWithHash { @@ -298,3 +316,63 @@ fn test_complex_enum_with_hash() { py_assert!(py, *env, "hash(obj) == hsh"); }); } + +#[allow(deprecated)] +mod deprecated { + use crate::py_assert; + use pyo3::prelude::*; + use pyo3::py_run; + + #[pyclass] + #[derive(Debug, PartialEq, Eq, Clone)] + pub enum MyEnum { + Variant, + OtherVariant, + } + + #[test] + fn test_enum_eq_enum() { + Python::with_gil(|py| { + let var1 = Py::new(py, MyEnum::Variant).unwrap(); + let var2 = Py::new(py, MyEnum::Variant).unwrap(); + let other_var = Py::new(py, MyEnum::OtherVariant).unwrap(); + py_assert!(py, var1 var2, "var1 == var2"); + py_assert!(py, var1 other_var, "var1 != other_var"); + py_assert!(py, var1 var2, "(var1 != var2) == False"); + }) + } + + #[test] + fn test_enum_eq_incomparable() { + Python::with_gil(|py| { + let var1 = Py::new(py, MyEnum::Variant).unwrap(); + py_assert!(py, var1, "(var1 == 'foo') == False"); + py_assert!(py, var1, "(var1 != 'foo') == True"); + }) + } + + #[pyclass] + enum CustomDiscriminant { + One = 1, + Two = 2, + } + + #[test] + fn test_custom_discriminant() { + Python::with_gil(|py| { + #[allow(non_snake_case)] + let CustomDiscriminant = py.get_type_bound::(); + let one = Py::new(py, CustomDiscriminant::One).unwrap(); + let two = Py::new(py, CustomDiscriminant::Two).unwrap(); + py_run!(py, CustomDiscriminant one two, r#" + assert CustomDiscriminant.One == one + assert CustomDiscriminant.Two == two + assert CustomDiscriminant.One == 1 + assert CustomDiscriminant.Two == 2 + assert one != two + assert CustomDiscriminant.One != 2 + assert CustomDiscriminant.Two != 1 + "#); + }) + } +} diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs index f4b94a61..73bc9925 100644 --- a/tests/ui/invalid_pyclass_enum.rs +++ b/tests/ui/invalid_pyclass_enum.rs @@ -40,6 +40,12 @@ enum ComplexEqOptRequiresPartialEq { B { msg: String }, } +#[pyclass(eq_int)] +enum SimpleEqIntWithoutEq { + A, + B, +} + #[pyclass(eq_int)] enum NoEqInt { A(i32), diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr index d817e601..cfa3922e 100644 --- a/tests/ui/invalid_pyclass_enum.stderr +++ b/tests/ui/invalid_pyclass_enum.stderr @@ -30,28 +30,34 @@ error: `constructor` can't be used on a simple enum variant 26 | #[pyo3(constructor = (a, b))] | ^^^^^^^^^^^ -error: `eq_int` can only be used on simple enums. +error: The `eq_int` option requires the `eq` option. --> tests/ui/invalid_pyclass_enum.rs:43:11 | 43 | #[pyclass(eq_int)] | ^^^^^^ +error: `eq_int` can only be used on simple enums. + --> tests/ui/invalid_pyclass_enum.rs:49:11 + | +49 | #[pyclass(eq_int)] + | ^^^^^^ + error: The `hash` option requires the `frozen` option. - --> tests/ui/invalid_pyclass_enum.rs:63:11 + --> tests/ui/invalid_pyclass_enum.rs:69:11 | -63 | #[pyclass(hash)] +69 | #[pyclass(hash)] | ^^^^ error: The `hash` option requires the `eq` option. - --> tests/ui/invalid_pyclass_enum.rs:63:11 + --> tests/ui/invalid_pyclass_enum.rs:69:11 | -63 | #[pyclass(hash)] +69 | #[pyclass(hash)] | ^^^^ error: The `hash` option requires the `eq` option. - --> tests/ui/invalid_pyclass_enum.rs:70:11 + --> tests/ui/invalid_pyclass_enum.rs:76:11 | -70 | #[pyclass(hash)] +76 | #[pyclass(hash)] | ^^^^ error[E0369]: binary operation `==` cannot be applied to type `&SimpleEqOptRequiresPartialEq` @@ -123,25 +129,25 @@ help: consider annotating `ComplexEqOptRequiresPartialEq` with `#[derive(Partial | error[E0277]: the trait bound `SimpleHashOptRequiresHash: Hash` is not satisfied - --> tests/ui/invalid_pyclass_enum.rs:49:31 + --> tests/ui/invalid_pyclass_enum.rs:55:31 | -49 | #[pyclass(frozen, eq, eq_int, hash)] +55 | #[pyclass(frozen, eq, eq_int, hash)] | ^^^^ the trait `Hash` is not implemented for `SimpleHashOptRequiresHash` | help: consider annotating `SimpleHashOptRequiresHash` with `#[derive(Hash)]` | -51 + #[derive(Hash)] -52 | enum SimpleHashOptRequiresHash { +57 + #[derive(Hash)] +58 | enum SimpleHashOptRequiresHash { | error[E0277]: the trait bound `ComplexHashOptRequiresHash: Hash` is not satisfied - --> tests/ui/invalid_pyclass_enum.rs:56:23 + --> tests/ui/invalid_pyclass_enum.rs:62:23 | -56 | #[pyclass(frozen, eq, hash)] +62 | #[pyclass(frozen, eq, hash)] | ^^^^ the trait `Hash` is not implemented for `ComplexHashOptRequiresHash` | help: consider annotating `ComplexHashOptRequiresHash` with `#[derive(Hash)]` | -58 + #[derive(Hash)] -59 | enum ComplexHashOptRequiresHash { +64 + #[derive(Hash)] +65 | enum ComplexHashOptRequiresHash { |