From 9648d595a5a9339f52a62821ad31f77706a09b2f Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Sun, 16 Jun 2024 09:19:21 +0100 Subject: [PATCH] implement `PartialEq` for `Bound<'py, PyString>` (#4245) * implement `PartialEq` for `Bound<'py, PyString>` * fixup conditional code * document equality semantics for `Bound<'_, PyString>` * fix doc example --- newsfragments/4245.added.md | 1 + pyo3-ffi/src/unicodeobject.rs | 9 ++ src/instance.rs | 8 +- src/types/bytearray.rs | 13 +-- src/types/module.rs | 10 +- src/types/string.rs | 176 +++++++++++++++++++++++++++++++++- tests/test_proto_methods.rs | 7 +- 7 files changed, 194 insertions(+), 30 deletions(-) create mode 100644 newsfragments/4245.added.md diff --git a/newsfragments/4245.added.md b/newsfragments/4245.added.md new file mode 100644 index 00000000..692fb277 --- /dev/null +++ b/newsfragments/4245.added.md @@ -0,0 +1 @@ +Implement `PartialEq` for `Bound<'py, PyString>`. diff --git a/pyo3-ffi/src/unicodeobject.rs b/pyo3-ffi/src/unicodeobject.rs index 087160a1..519bbf26 100644 --- a/pyo3-ffi/src/unicodeobject.rs +++ b/pyo3-ffi/src/unicodeobject.rs @@ -328,6 +328,15 @@ extern "C" { pub fn PyUnicode_Compare(left: *mut PyObject, right: *mut PyObject) -> c_int; #[cfg_attr(PyPy, link_name = "PyPyUnicode_CompareWithASCIIString")] pub fn PyUnicode_CompareWithASCIIString(left: *mut PyObject, right: *const c_char) -> c_int; + #[cfg(Py_3_13)] + pub fn PyUnicode_EqualToUTF8(unicode: *mut PyObject, string: *const c_char) -> c_int; + #[cfg(Py_3_13)] + pub fn PyUnicode_EqualToUTF8AndSize( + unicode: *mut PyObject, + string: *const c_char, + size: Py_ssize_t, + ) -> c_int; + pub fn PyUnicode_RichCompare( left: *mut PyObject, right: *mut PyObject, diff --git a/src/instance.rs b/src/instance.rs index bc9b68ef..4703dd12 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -2010,9 +2010,7 @@ impl PyObject { #[cfg(test)] mod tests { use super::{Bound, Py, PyObject}; - use crate::types::any::PyAnyMethods; - use 
crate::types::{dict::IntoPyDict, PyDict, PyString}; - use crate::types::{PyCapsule, PyStringMethods}; + use crate::types::{dict::IntoPyDict, PyAnyMethods, PyCapsule, PyDict, PyString}; use crate::{ffi, Borrowed, PyAny, PyResult, Python, ToPyObject}; #[test] @@ -2021,7 +2019,7 @@ mod tests { let obj = py.get_type_bound::().to_object(py); let assert_repr = |obj: &Bound<'_, PyAny>, expected: &str| { - assert_eq!(obj.repr().unwrap().to_cow().unwrap(), expected); + assert_eq!(obj.repr().unwrap(), expected); }; assert_repr(obj.call0(py).unwrap().bind(py), "{}"); @@ -2221,7 +2219,7 @@ a = A() let obj_unbound: Py = obj.unbind(); let obj: Bound<'_, PyString> = obj_unbound.into_bound(py); - assert_eq!(obj.to_cow().unwrap(), "hello world"); + assert_eq!(obj, "hello world"); }); } diff --git a/src/types/bytearray.rs b/src/types/bytearray.rs index 1a66c71b..c411e830 100644 --- a/src/types/bytearray.rs +++ b/src/types/bytearray.rs @@ -515,12 +515,8 @@ impl<'py> TryFrom<&Bound<'py, PyAny>> for Bound<'py, PyByteArray> { #[cfg(test)] mod tests { - use crate::types::any::PyAnyMethods; - use crate::types::bytearray::PyByteArrayMethods; - use crate::types::string::PyStringMethods; - use crate::types::PyByteArray; - use crate::{exceptions, Bound, PyAny}; - use crate::{PyObject, Python}; + use crate::types::{PyAnyMethods, PyByteArray, PyByteArrayMethods}; + use crate::{exceptions, Bound, PyAny, PyObject, Python}; #[test] fn test_len() { @@ -555,10 +551,7 @@ mod tests { slice[0..5].copy_from_slice(b"Hi..."); - assert_eq!( - bytearray.str().unwrap().to_cow().unwrap(), - "bytearray(b'Hi... Python')" - ); + assert_eq!(bytearray.str().unwrap(), "bytearray(b'Hi... 
Python')"); }); } diff --git a/src/types/module.rs b/src/types/module.rs index 20f8305a..e866ec9c 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -37,7 +37,7 @@ impl PyModule { /// Python::with_gil(|py| -> PyResult<()> { /// let module = PyModule::new_bound(py, "my_module")?; /// - /// assert_eq!(module.name()?.to_cow()?, "my_module"); + /// assert_eq!(module.name()?, "my_module"); /// Ok(()) /// })?; /// # Ok(())} @@ -728,7 +728,7 @@ fn __name__(py: Python<'_>) -> &Bound<'_, PyString> { #[cfg(test)] mod tests { use crate::{ - types::{module::PyModuleMethods, string::PyStringMethods, PyModule}, + types::{module::PyModuleMethods, PyModule}, Python, }; @@ -736,15 +736,13 @@ mod tests { fn module_import_and_name() { Python::with_gil(|py| { let builtins = PyModule::import_bound(py, "builtins").unwrap(); - assert_eq!( - builtins.name().unwrap().to_cow().unwrap().as_ref(), - "builtins" - ); + assert_eq!(builtins.name().unwrap(), "builtins"); }) } #[test] fn module_filename() { + use crate::types::string::PyStringMethods; Python::with_gil(|py| { let site = PyModule::import_bound(py, "site").unwrap(); assert!(site diff --git a/src/types/string.rs b/src/types/string.rs index 0582a900..8556e41d 100644 --- a/src/types/string.rs +++ b/src/types/string.rs @@ -123,7 +123,33 @@ impl<'a> PyStringData<'a> { /// Represents a Python `string` (a Unicode string object). /// -/// This type is immutable. +/// This type is only seen inside PyO3's smart pointers as [`Py<PyString>`], [`Bound<'py, PyString>`], +/// and [`Borrowed<'a, 'py, PyString>`]. +/// +/// All functionality on this type is implemented through the [`PyStringMethods`] trait. +/// +/// # Equality +/// +/// For convenience, [`Bound<'py, PyString>`] implements [`PartialEq<str>`] to allow comparing the +/// data in the Python string to a Rust UTF-8 string slice. +/// +/// This is not always the most appropriate way to compare Python strings, as Python string subclasses +/// may have different equality semantics.
In situations where subclasses overriding equality might be +/// relevant, use [`PyAnyMethods::eq`], at the cost of the additional overhead of a Python method call. +/// +/// ```rust +/// # use pyo3::prelude::*; +/// use pyo3::types::PyString; +/// +/// # Python::with_gil(|py| { +/// let py_string = PyString::new_bound(py, "foo"); +/// // via PartialEq +/// assert_eq!(py_string, "foo"); +/// +/// // via Python equality +/// assert!(py_string.as_any().eq("foo").unwrap()); +/// # }); +/// ``` #[repr(transparent)] pub struct PyString(PyAny); @@ -490,6 +516,118 @@ impl IntoPy> for &'_ Py { } } +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<str> for Bound<'_, PyString> { + #[inline] + fn eq(&self, other: &str) -> bool { + self.as_borrowed() == *other + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<&'_ str> for Bound<'_, PyString> { + #[inline] + fn eq(&self, other: &&str) -> bool { + self.as_borrowed() == **other + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<Bound<'_, PyString>> for str { + #[inline] + fn eq(&self, other: &Bound<'_, PyString>) -> bool { + *self == other.as_borrowed() + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<&'_ Bound<'_, PyString>> for str { + #[inline] + fn eq(&self, other: &&Bound<'_, PyString>) -> bool { + *self == other.as_borrowed() + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8.
+/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<Bound<'_, PyString>> for &'_ str { + #[inline] + fn eq(&self, other: &Bound<'_, PyString>) -> bool { + **self == other.as_borrowed() + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<str> for &'_ Bound<'_, PyString> { + #[inline] + fn eq(&self, other: &str) -> bool { + self.as_borrowed() == other + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<str> for Borrowed<'_, '_, PyString> { + #[inline] + fn eq(&self, other: &str) -> bool { + #[cfg(not(Py_3_13))] + { + self.to_cow().map_or(false, |s| s == other) + } + + #[cfg(Py_3_13)] + unsafe { + ffi::PyUnicode_EqualToUTF8AndSize( + self.as_ptr(), + other.as_ptr().cast(), + other.len() as _, + ) == 1 + } + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<&str> for Borrowed<'_, '_, PyString> { + #[inline] + fn eq(&self, other: &&str) -> bool { + *self == **other + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`]. +impl PartialEq<Borrowed<'_, '_, PyString>> for str { + #[inline] + fn eq(&self, other: &Borrowed<'_, '_, PyString>) -> bool { + other == self + } +} + +/// Compares whether the data in the Python string is equal to the given UTF8. +/// +/// In some cases Python equality might be more appropriate; see the note on [`PyString`].
+impl PartialEq<Borrowed<'_, '_, PyString>> for &'_ str { + #[inline] + fn eq(&self, other: &Borrowed<'_, '_, PyString>) -> bool { + other == self + } +} + #[cfg(test)] mod tests { use super::*; @@ -708,15 +846,15 @@ mod tests { fn test_intern_string() { Python::with_gil(|py| { let py_string1 = PyString::intern_bound(py, "foo"); - assert_eq!(py_string1.to_cow().unwrap(), "foo"); + assert_eq!(py_string1, "foo"); let py_string2 = PyString::intern_bound(py, "foo"); - assert_eq!(py_string2.to_cow().unwrap(), "foo"); + assert_eq!(py_string2, "foo"); assert_eq!(py_string1.as_ptr(), py_string2.as_ptr()); let py_string3 = PyString::intern_bound(py, "bar"); - assert_eq!(py_string3.to_cow().unwrap(), "bar"); + assert_eq!(py_string3, "bar"); assert_ne!(py_string1.as_ptr(), py_string3.as_ptr()); }); @@ -762,4 +900,34 @@ mod tests { assert_eq!(py_string.to_string_lossy(py), "🐈 Hello ���World"); }) } + + #[test] + fn test_comparisons() { + Python::with_gil(|py| { + let s = "hello, world"; + let py_string = PyString::new_bound(py, s); + + assert_eq!(py_string, "hello, world"); + + assert_eq!(py_string, s); + assert_eq!(&py_string, s); + assert_eq!(s, py_string); + assert_eq!(s, &py_string); + + assert_eq!(py_string, *s); + assert_eq!(&py_string, *s); + assert_eq!(*s, py_string); + assert_eq!(*s, &py_string); + + let py_string = py_string.as_borrowed(); + + assert_eq!(py_string, s); + assert_eq!(&py_string, s); + assert_eq!(s, py_string); + assert_eq!(s, &py_string); + + assert_eq!(py_string, *s); + assert_eq!(*s, py_string); + }) + } } diff --git a/tests/test_proto_methods.rs b/tests/test_proto_methods.rs index 5f0fa105..06e0d45e 100644 --- a/tests/test_proto_methods.rs +++ b/tests/test_proto_methods.rs @@ -131,7 +131,7 @@ fn test_delattr() { fn test_str() { Python::with_gil(|py| { let example_py = make_example(py); - assert_eq!(example_py.str().unwrap().to_cow().unwrap(), "5"); + assert_eq!(example_py.str().unwrap(), "5"); }) } @@ -139,10 +139,7 @@ fn test_str() { fn test_repr() { Python::with_gil(|py| {
let example_py = make_example(py); - assert_eq!( - example_py.repr().unwrap().to_cow().unwrap(), - "ExampleClass(value=5)" - ); + assert_eq!(example_py.repr().unwrap(), "ExampleClass(value=5)"); }) }