From 53b83cccbfd4fece76978c4bb626dc8c691516e6 Mon Sep 17 00:00:00 2001 From: Georg Brandl Date: Tue, 21 Jun 2022 15:36:20 +0200 Subject: [PATCH] add `CompareOp::matches` (#2460) --- CHANGELOG.md | 2 ++ guide/src/class/object.md | 30 +++++++++++++++++++--- guide/src/class/protocols.md | 8 ++++-- src/pyclass.rs | 48 ++++++++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9850359a..7fc75ba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add FFI definitions `Py_fstring_input`, `sendfunc`, and `_PyErr_StackItem`. [#2423](https://github.com/PyO3/pyo3/pull/2423) - Add `PyDateTime::new_with_fold`, `PyTime::new_with_fold`, `PyTime::get_fold`, `PyDateTime::get_fold` for PyPy. [#2428](https://github.com/PyO3/pyo3/pull/2428) - Allow `#[classattr]` take `Python` argument. [#2383](https://github.com/PyO3/pyo3/issues/2383) +- Add `CompareOp::matches` to easily implement `__richcmp__` as the result of a + Rust `std::cmp::Ordering` comparison. [#2460](https://github.com/PyO3/pyo3/pull/2460) ### Changed diff --git a/guide/src/class/object.md b/guide/src/class/object.md index ade91da8..fa1f4ef7 100644 --- a/guide/src/class/object.md +++ b/guide/src/class/object.md @@ -128,15 +128,15 @@ impl Number { Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`, `__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`. This method will be called with a value of `CompareOp` depending on the operation. - + ```rust use pyo3::class::basic::CompareOp; # use pyo3::prelude::*; -# +# # #[pyclass] # struct Number(i32); -# +# #[pymethods] impl Number { fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { @@ -152,6 +152,28 @@ impl Number { } ``` +If you obtain the result by comparing two Rust values, as in this example, you +can take a shortcut using `CompareOp::matches`: + +```rust +use pyo3::class::basic::CompareOp; + +# use pyo3::prelude::*; +# +# #[pyclass] +# struct Number(i32); +# +#[pymethods] +impl Number { + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.0.cmp(&other.0)) + } +} +``` + +It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches +the given `CompareOp`. + ### Truthyness We'll consider `Number` to be `True` if it is nonzero: @@ -229,4 +251,4 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { [`Hash`]: https://doc.rust-lang.org/std/hash/trait.Hash.html [`Hasher`]: https://doc.rust-lang.org/std/hash/trait.Hasher.html [`DefaultHasher`]: https://doc.rust-lang.org/std/collections/hash_map/struct.DefaultHasher.html -[SipHash]: https://en.wikipedia.org/wiki/SipHash \ No newline at end of file +[SipHash]: https://en.wikipedia.org/wiki/SipHash diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index 716f45d5..6866ab67 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -70,8 +70,11 @@ given signatures should be interpreted as follows:
Return type The return type will normally be `PyResult`, but any Python object can be returned. - If the `object` is not of the type specified in the signature, the generated code will - automatically `return NotImplemented`. + If the second argument `object` is not of the type specified in the + signature, the generated code will automatically `return NotImplemented`. + + You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result + to the requested comparison.
- `__getattr__(, object) -> object` @@ -611,3 +614,4 @@ For details, look at the `#[pymethods]` regarding GC methods. [`PySequenceProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/sequence/trait.PySequenceProtocol.html [`PyIterProtocol`]: {{#PYO3_DOCS_URL}}/pyo3/class/iter/trait.PyIterProtocol.html [`PySequence`]: {{#PYO3_DOCS_URL}}/pyo3/types/struct.PySequence.html +[`CompareOp::matches`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.CompareOp.html#method.matches diff --git a/src/pyclass.rs b/src/pyclass.rs index f7bbb0cd..3a537157 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -10,6 +10,7 @@ use crate::{ IntoPy, IntoPyPointer, PyCell, PyErr, PyMethodDefType, PyObject, PyResult, PyTypeInfo, Python, }; use std::{ + cmp::Ordering, convert::TryInto, ffi::{CStr, CString}, os::raw::{c_char, c_int, c_uint, c_void}, @@ -452,6 +453,7 @@ pub enum CompareOp { } impl CompareOp { + /// Conversion from the C enum. pub fn from_raw(op: c_int) -> Option { match op { ffi::Py_LT => Some(CompareOp::Lt), @@ -463,6 +465,37 @@ impl CompareOp { _ => None, } } + + /// Returns if a Rust [`std::cmp::Ordering`] matches this ordering query. + /// + /// Usage example: + /// + /// ```rust + /// # use pyo3::prelude::*; + /// # use pyo3::class::basic::CompareOp; + /// + /// #[pyclass] + /// struct Size { + /// size: usize + /// } + /// + /// #[pymethods] + /// impl Size { + /// fn __richcmp__(&self, other: &Size, op: CompareOp) -> bool { + /// op.matches(self.size.cmp(&other.size)) + /// } + /// } + /// ``` + pub fn matches(&self, result: Ordering) -> bool { + match self { + CompareOp::Eq => result == Ordering::Equal, + CompareOp::Ne => result != Ordering::Equal, + CompareOp::Lt => result == Ordering::Less, + CompareOp::Le => result != Ordering::Greater, + CompareOp::Gt => result == Ordering::Greater, + CompareOp::Ge => result != Ordering::Less, + } + } } /// Output of `__next__` which can either `yield` the next value in the iteration, or @@ -597,3 +630,18 @@ pub trait Frozen: boolean_struct::private::Boolean {} impl Frozen for boolean_struct::True {} impl Frozen for boolean_struct::False {} + +mod tests { + #[test] + fn test_compare_op_matches() { + use super::CompareOp; + use std::cmp::Ordering; + + assert!(CompareOp::Eq.matches(Ordering::Equal)); + assert!(CompareOp::Ne.matches(Ordering::Less)); + assert!(CompareOp::Ge.matches(Ordering::Greater)); + assert!(CompareOp::Gt.matches(Ordering::Greater)); + assert!(CompareOp::Le.matches(Ordering::Equal)); + assert!(CompareOp::Lt.matches(Ordering::Less)); + } +}