Some improvements to __richcmp__ on enums
- Implement __ne__ as well as __eq__. - Return NotImplemented when types cannot be converted, rather than throwing. - Compare the integer ids inside the __eq__/__ne__ implementation. Previously a match block was generated.
This commit is contained in:
parent
bdc1468baf
commit
26a9603519
|
@ -526,13 +526,6 @@ fn impl_enum_class(
|
|||
};
|
||||
|
||||
let (default_richcmp, default_richcmp_slot) = {
|
||||
let variants_eq = variants.iter().map(|variant| {
|
||||
let variant_name = variant.ident;
|
||||
quote! {
|
||||
(#cls::#variant_name, #cls::#variant_name) =>
|
||||
Ok(true.to_object(py)),
|
||||
}
|
||||
});
|
||||
let mut richcmp_impl: syn::ImplItemMethod = syn::parse_quote! {
|
||||
fn __pyo3__richcmp__(
|
||||
&self,
|
||||
|
@ -544,16 +537,26 @@ fn impl_enum_class(
|
|||
use ::core::result::Result::*;
|
||||
match op {
|
||||
_pyo3::basic::CompareOp::Eq => {
|
||||
let self_val = self.__pyo3__int__();
|
||||
if let Ok(i) = other.extract::<#repr_type>() {
|
||||
let self_val = self.__pyo3__int__();
|
||||
return Ok((self_val == i).to_object(py));
|
||||
}
|
||||
let other = other.extract::<_pyo3::PyRef<Self>>()?;
|
||||
let other = &*other;
|
||||
match (self, other) {
|
||||
#(#variants_eq)*
|
||||
_ => Ok(false.to_object(py)),
|
||||
if let Ok(other) = other.extract::<_pyo3::PyRef<Self>>() {
|
||||
return Ok((self_val == other.__pyo3__int__()).to_object(py));
|
||||
}
|
||||
|
||||
return Ok(py.NotImplemented());
|
||||
}
|
||||
_pyo3::basic::CompareOp::Ne => {
|
||||
let self_val = self.__pyo3__int__();
|
||||
if let Ok(i) = other.extract::<#repr_type>() {
|
||||
return Ok((self_val != i).to_object(py));
|
||||
}
|
||||
if let Ok(other) = other.extract::<_pyo3::PyRef<Self>>() {
|
||||
return Ok((self_val != other.__pyo3__int__()).to_object(py));
|
||||
}
|
||||
|
||||
return Ok(py.NotImplemented());
|
||||
}
|
||||
_ => Ok(py.NotImplemented()),
|
||||
}
|
||||
|
|
|
@ -52,23 +52,23 @@ fn test_enum_arg() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_enum_eq() {
|
||||
fn test_enum_eq_enum() {
|
||||
Python::with_gil(|py| {
|
||||
let var1 = Py::new(py, MyEnum::Variant).unwrap();
|
||||
let var2 = Py::new(py, MyEnum::Variant).unwrap();
|
||||
let other_var = Py::new(py, MyEnum::OtherVariant).unwrap();
|
||||
py_assert!(py, var1 var2, "var1 == var2");
|
||||
py_assert!(py, var1 other_var, "var1 != other_var");
|
||||
py_assert!(py, var1 var2, "(var1 != var2) == False");
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_repr_correct() {
|
||||
fn test_enum_eq_incomparable() {
|
||||
Python::with_gil(|py| {
|
||||
let var1 = Py::new(py, MyEnum::Variant).unwrap();
|
||||
let var2 = Py::new(py, MyEnum::OtherVariant).unwrap();
|
||||
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
|
||||
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
|
||||
py_assert!(py, var1, "(var1 == 'foo') == False");
|
||||
py_assert!(py, var1, "(var1 != 'foo') == True");
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue