Some improvements to __richcmp__ on enums

- Implement __ne__ as well as __eq__.
 - Return NotImplemented when types cannot be converted, rather than
   throwing.
 - Compare the integer ids inside the __eq__/__ne__ implementation.
   Previously a match block was generated.
This commit is contained in:
Jonathan Coates 2022-09-14 12:24:17 +01:00
parent bdc1468baf
commit 26a9603519
2 changed files with 21 additions and 18 deletions

View File

@ -526,13 +526,6 @@ fn impl_enum_class(
};
let (default_richcmp, default_richcmp_slot) = {
let variants_eq = variants.iter().map(|variant| {
let variant_name = variant.ident;
quote! {
(#cls::#variant_name, #cls::#variant_name) =>
Ok(true.to_object(py)),
}
});
let mut richcmp_impl: syn::ImplItemMethod = syn::parse_quote! {
fn __pyo3__richcmp__(
&self,
@ -544,16 +537,26 @@ fn impl_enum_class(
use ::core::result::Result::*;
match op {
_pyo3::basic::CompareOp::Eq => {
let self_val = self.__pyo3__int__();
if let Ok(i) = other.extract::<#repr_type>() {
let self_val = self.__pyo3__int__();
return Ok((self_val == i).to_object(py));
}
let other = other.extract::<_pyo3::PyRef<Self>>()?;
let other = &*other;
match (self, other) {
#(#variants_eq)*
_ => Ok(false.to_object(py)),
if let Ok(other) = other.extract::<_pyo3::PyRef<Self>>() {
return Ok((self_val == other.__pyo3__int__()).to_object(py));
}
return Ok(py.NotImplemented());
}
_pyo3::basic::CompareOp::Ne => {
let self_val = self.__pyo3__int__();
if let Ok(i) = other.extract::<#repr_type>() {
return Ok((self_val != i).to_object(py));
}
if let Ok(other) = other.extract::<_pyo3::PyRef<Self>>() {
return Ok((self_val != other.__pyo3__int__()).to_object(py));
}
return Ok(py.NotImplemented());
}
_ => Ok(py.NotImplemented()),
}

View File

@ -52,23 +52,23 @@ fn test_enum_arg() {
}
#[test]
fn test_enum_eq() {
fn test_enum_eq_enum() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
let var2 = Py::new(py, MyEnum::Variant).unwrap();
let other_var = Py::new(py, MyEnum::OtherVariant).unwrap();
py_assert!(py, var1 var2, "var1 == var2");
py_assert!(py, var1 other_var, "var1 != other_var");
py_assert!(py, var1 var2, "(var1 != var2) == False");
})
}
#[test]
fn test_default_repr_correct() {
fn test_enum_eq_incomparable() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
let var2 = Py::new(py, MyEnum::OtherVariant).unwrap();
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
py_assert!(py, var1, "(var1 == 'foo') == False");
py_assert!(py, var1, "(var1 != 'foo') == True");
})
}