From 305bc4324d39f7f252d4ec6cfeb0d0428ff5cc6f Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Mon, 6 Jun 2016 16:08:48 -0400 Subject: [PATCH 1/2] Add support for overloading comparison operators with __richcompare__ --- src/lib.rs | 1 + src/py_class/mod.rs | 25 +++++++++++++++++++++ src/py_class/py_class.rs | 8 ++++++- src/py_class/py_class_impl.py | 20 +++++++++++------ src/py_class/py_class_impl2.rs | 40 +++++++++++++++++++++++++++++----- src/py_class/py_class_impl3.rs | 40 +++++++++++++++++++++++++++++----- src/py_class/slots.rs | 31 ++++++++++++++++++++++++++ 7 files changed, 145 insertions(+), 20 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3e383353..d0dc357d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,6 +98,7 @@ pub use objects::*; pub use python::{Python, PythonObject, PythonObjectWithCheckedDowncast, PythonObjectDowncastError, PythonObjectWithTypeObject, PyClone, PyDrop}; pub use pythonrun::{GILGuard, GILProtected, prepare_freethreaded_python}; pub use conversion::{FromPyObject, RefFromPyObject, ToPyObject}; +pub use py_class::{CompareOp}; pub use objectprotocol::{ObjectProtocol}; #[cfg(feature="python27-sys")] diff --git a/src/py_class/mod.rs b/src/py_class/mod.rs index 03f436b8..a1485777 100644 --- a/src/py_class/mod.rs +++ b/src/py_class/mod.rs @@ -32,6 +32,31 @@ use objects::{PyObject, PyType}; use err::{self, PyResult}; use ffi; +#[derive(Debug)] +pub enum CompareOp { + Lt = ffi::Py_LT as isize, + Le = ffi::Py_LE as isize, + Eq = ffi::Py_EQ as isize, + Ne = ffi::Py_NE as isize, + Gt = ffi::Py_GT as isize, + Ge = ffi::Py_GE as isize, + Other +} + +impl> From for CompareOp { + fn from(val: T) -> Self { + match val.into() as libc::c_int { + ffi::Py_LT => CompareOp::Lt, + ffi::Py_LE => CompareOp::Le, + ffi::Py_EQ => CompareOp::Eq, + ffi::Py_NE => CompareOp::Ne, + ffi::Py_GT => CompareOp::Gt, + ffi::Py_GE => CompareOp::Ge, + _ => CompareOp::Other + } + } +} + /// Trait implemented by the types produced by the `py_class!()` macro. pub trait PythonObjectFromPyClassMacro : python::PythonObjectWithTypeObject { fn initialize(py: Python) -> PyResult; diff --git a/src/py_class/py_class.rs b/src/py_class/py_class.rs index 8bac200e..3e87831d 100644 --- a/src/py_class/py_class.rs +++ b/src/py_class/py_class.rs @@ -278,7 +278,13 @@ py_class!(class MyIterator |py| { ## Comparison operators -TODO: implement support for `__cmp__`, `__lt__`, `__le__`, `__gt__`, `__ge__`, `__eq__`, `__ne__`. +TODO: implement support for `__cmp__` + + * `def __richcompare__(&self, other: impl ToPyObject, op: CompareOp) -> PyResult` + + Overloads Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`). The `op` + argument indicates the comparison operation being performed. + The return type will normally be `PyResult`, but any Python object can be returned. * `def __hash__(&self) -> PyResult` diff --git a/src/py_class/py_class_impl.py b/src/py_class/py_class_impl.py index 91de39e9..cd4cead5 100644 --- a/src/py_class/py_class_impl.py +++ b/src/py_class/py_class_impl.py @@ -536,6 +536,9 @@ def operator(special_name, slot, param_list.append('{{ ${0} : ${0}_type = {{}} }}'.format(arg.name)) if slot == 'sq_contains': new_slots = [(slot, 'py_class_contains_slot!($class::%s, $%s_type)' % (special_name, args[0].name))] + elif slot == 'tp_richcompare': + new_slots = [(slot, 'py_class_richcompare_slot!($class::%s, $%s_type, %s, %s)' + % (special_name, args[0].name, res_ffi_type, res_conv))] elif len(args) == 0: new_slots = [(slot, 'py_class_unary_slot!($class::%s, %s, %s)' % (special_name, res_ffi_type, res_conv))] @@ -594,13 +597,16 @@ special_names = { '__bytes__': normal_method(), '__format__': normal_method(), # Comparison Operators - '__lt__': unimplemented(), - '__le__': unimplemented(), - '__gt__': unimplemented(), - '__ge__': unimplemented(), - '__eq__': unimplemented(), - '__ne__': unimplemented(), + '__lt__': error('__lt__ is not supported by py_class! use __richcompare__ instead.'), + '__le__': error('__le__ is not supported by py_class! use __richcompare__ instead.'), + '__gt__': error('__gt__ is not supported by py_class! use __richcompare__ instead.'), + '__ge__': error('__ge__ is not supported by py_class! use __richcompare__ instead.'), + '__eq__': error('__eq__ is not supported by py_class! use __richcompare__ instead.'), + '__ne__': error('__ne__ is not supported by py_class! use __richcompare__ instead.'), '__cmp__': unimplemented(), + '__richcompare__': operator('tp_richcompare', + res_type='PyObject', + args=[Argument('other'), Argument('op')]), '__hash__': operator('tp_hash', res_conv='$crate::py_class::slots::HashConverter', res_ffi_type='$crate::Py_hash_t'), @@ -652,7 +658,7 @@ special_names = { res_conv='$crate::py_class::slots::IterNextResultConverter'), '__reversed__': normal_method(), '__contains__': operator('sq_contains', args=[Argument('item')]), - + # Emulating numeric types '__add__': binary_numeric_operator('nb_add'), '__sub__': binary_numeric_operator('nb_subtract'), diff --git a/src/py_class/py_class_impl2.rs b/src/py_class/py_class_impl2.rs index 52204af8..0f4b2d20 100644 --- a/src/py_class/py_class_impl2.rs +++ b/src/py_class/py_class_impl2.rs @@ -580,7 +580,7 @@ macro_rules! py_class_impl { }; { { def __eq__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__eq__ is not supported by py_class! yet." } + py_error! { "__eq__ is not supported by py_class! use __richcompare__ instead." } }; { { def __float__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -592,7 +592,7 @@ macro_rules! py_class_impl { }; { { def __ge__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ge__ is not supported by py_class! yet." } + py_error! { "__ge__ is not supported by py_class! use __richcompare__ instead." } }; { { def __get__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -643,7 +643,7 @@ macro_rules! py_class_impl { }; { { def __gt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__gt__ is not supported by py_class! yet." } + py_error! { "__gt__ is not supported by py_class! use __richcompare__ instead." } }; { { def __hash__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -809,7 +809,7 @@ macro_rules! py_class_impl { }; { { def __le__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__le__ is not supported by py_class! yet." } + py_error! { "__le__ is not supported by py_class! use __richcompare__ instead." } }; { { def __len__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -882,7 +882,7 @@ macro_rules! py_class_impl { }; { { def __lt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__lt__ is not supported by py_class! yet." } + py_error! { "__lt__ is not supported by py_class! use __richcompare__ instead." } }; { { def __matmul__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -924,7 +924,7 @@ macro_rules! py_class_impl { }; { { def __ne__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ne__ is not supported by py_class! yet." } + py_error! { "__ne__ is not supported by py_class! use __richcompare__ instead." } }; { { def __neg__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -1151,6 +1151,34 @@ macro_rules! py_class_impl { { { def __rfloordiv__ $($tail:tt)* } $( $stuff:tt )* } => { py_error! { "Reflected numeric operator __rfloordiv__ is not supported by py_class! Use __floordiv__ instead!" } }; + { { def __richcompare__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } + $class:ident $py:ident $info:tt + /* slots: */ { + /* type_slots */ [ $( $tp_slot_name:ident : $tp_slot_value:expr, )* ] + $as_number:tt $as_sequence:tt $as_mapping:tt $setdelitem:tt + } + { $( $imp:item )* } + $members:tt + } => { py_class_impl! { + { $($tail)* } + $class $py $info + /* slots: */ { + /* type_slots */ [ + $( $tp_slot_name : $tp_slot_value, )* + tp_richcompare: py_class_richcompare_slot!($class::__richcompare__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), + ] + $as_number $as_sequence $as_mapping $setdelitem + } + /* impl: */ { + $($imp)* + py_class_impl_item! { $class, $py, __richcompare__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] } + } + $members + }}; + + { { def __richcompare__ $($tail:tt)* } $( $stuff:tt )* } => { + py_error! { "Invalid signature for operator __richcompare__" } + }; { { def __rlshift__ $($tail:tt)* } $( $stuff:tt )* } => { py_error! { "Reflected numeric operator __rlshift__ is not supported by py_class! Use __lshift__ instead!" } diff --git a/src/py_class/py_class_impl3.rs b/src/py_class/py_class_impl3.rs index d456aac8..0790a148 100644 --- a/src/py_class/py_class_impl3.rs +++ b/src/py_class/py_class_impl3.rs @@ -580,7 +580,7 @@ macro_rules! py_class_impl { }; { { def __eq__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__eq__ is not supported by py_class! yet." } + py_error! { "__eq__ is not supported by py_class! use __richcompare__ instead." } }; { { def __float__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -592,7 +592,7 @@ macro_rules! py_class_impl { }; { { def __ge__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ge__ is not supported by py_class! yet." } + py_error! { "__ge__ is not supported by py_class! use __richcompare__ instead." } }; { { def __get__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -643,7 +643,7 @@ macro_rules! py_class_impl { }; { { def __gt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__gt__ is not supported by py_class! yet." } + py_error! { "__gt__ is not supported by py_class! use __richcompare__ instead." } }; { { def __hash__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -809,7 +809,7 @@ macro_rules! py_class_impl { }; { { def __le__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__le__ is not supported by py_class! yet." } + py_error! { "__le__ is not supported by py_class! use __richcompare__ instead." } }; { { def __len__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -882,7 +882,7 @@ macro_rules! py_class_impl { }; { { def __lt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__lt__ is not supported by py_class! yet." } + py_error! { "__lt__ is not supported by py_class! use __richcompare__ instead." } }; { { def __matmul__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -924,7 +924,7 @@ macro_rules! py_class_impl { }; { { def __ne__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ne__ is not supported by py_class! yet." } + py_error! { "__ne__ is not supported by py_class! use __richcompare__ instead." } }; { { def __neg__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -1151,6 +1151,34 @@ macro_rules! py_class_impl { { { def __rfloordiv__ $($tail:tt)* } $( $stuff:tt )* } => { py_error! { "Reflected numeric operator __rfloordiv__ is not supported by py_class! Use __floordiv__ instead!" } }; + { { def __richcompare__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } + $class:ident $py:ident $info:tt + /* slots: */ { + /* type_slots */ [ $( $tp_slot_name:ident : $tp_slot_value:expr, )* ] + $as_number:tt $as_sequence:tt $as_mapping:tt $setdelitem:tt + } + { $( $imp:item )* } + $members:tt + } => { py_class_impl! { + { $($tail)* } + $class $py $info + /* slots: */ { + /* type_slots */ [ + $( $tp_slot_name : $tp_slot_value, )* + tp_richcompare: py_class_richcompare_slot!($class::__richcompare__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), + ] + $as_number $as_sequence $as_mapping $setdelitem + } + /* impl: */ { + $($imp)* + py_class_impl_item! { $class, $py, __richcompare__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] } + } + $members + }}; + + { { def __richcompare__ $($tail:tt)* } $( $stuff:tt )* } => { + py_error! { "Invalid signature for operator __richcompare__" } + }; { { def __rlshift__ $($tail:tt)* } $( $stuff:tt )* } => { py_error! { "Reflected numeric operator __rlshift__ is not supported by py_class! Use __lshift__ instead!" } diff --git a/src/py_class/slots.rs b/src/py_class/slots.rs index 0da42927..5a286787 100644 --- a/src/py_class/slots.rs +++ b/src/py_class/slots.rs @@ -320,6 +320,37 @@ macro_rules! py_class_ternary_slot { }} } +// sq_richcompare is special-cased slot +#[macro_export] +#[doc(hidden)] +macro_rules! py_class_richcompare_slot { + ($class:ident :: $f:ident, $arg_type:ty, $res_type:ty, $conv:expr) => {{ + unsafe extern "C" fn tp_richcompare( + slf: *mut $crate::_detail::ffi::PyObject, + arg: *mut $crate::_detail::ffi::PyObject, + op: $crate::_detail::libc::c_int) + -> $res_type + { + const LOCATION: &'static str = concat!(stringify!($class), ".", stringify!($f), "()"); + $crate::_detail::handle_callback( + LOCATION, $conv, + |py| { + let slf = $crate::PyObject::from_borrowed_ptr(py, slf).unchecked_cast_into::<$class>(); + let arg = $crate::PyObject::from_borrowed_ptr(py, arg); + let op = $crate::py_class::CompareOp::from(op as isize); + let ret = match <$arg_type as $crate::FromPyObject>::extract(py, &arg) { + Ok(arg) => slf.$f(py, arg, op), + Err(e) => Err(e) + }; + $crate::PyDrop::release_ref(arg, py); + $crate::PyDrop::release_ref(slf, py); + ret + }) + } + Some(tp_richcompare) + }} +} + // sq_contains is special-cased slot because it converts type errors to Ok(false) #[macro_export] #[doc(hidden)] From 291e08e29ab6ecd0114fade14cf536a6d10c8d26 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Mon, 6 Jun 2016 20:43:16 -0400 Subject: [PATCH 2/2] Remove CompareOp::Other, change to __richcmp__, and add tests --- src/py_class/mod.rs | 17 +------ src/py_class/py_class.rs | 4 +- src/py_class/py_class_impl.py | 16 +++---- src/py_class/py_class_impl2.rs | 24 +++++----- src/py_class/py_class_impl3.rs | 24 +++++----- src/py_class/slots.rs | 25 ++++++++-- tests/test_class.rs | 84 ++++++++++++++++++++++++++++++++++ 7 files changed, 139 insertions(+), 55 deletions(-) diff --git a/src/py_class/mod.rs b/src/py_class/mod.rs index a1485777..5019a097 100644 --- a/src/py_class/mod.rs +++ b/src/py_class/mod.rs @@ -39,22 +39,7 @@ pub enum CompareOp { Eq = ffi::Py_EQ as isize, Ne = ffi::Py_NE as isize, Gt = ffi::Py_GT as isize, - Ge = ffi::Py_GE as isize, - Other -} - -impl> From for CompareOp { - fn from(val: T) -> Self { - match val.into() as libc::c_int { - ffi::Py_LT => CompareOp::Lt, - ffi::Py_LE => CompareOp::Le, - ffi::Py_EQ => CompareOp::Eq, - ffi::Py_NE => CompareOp::Ne, - ffi::Py_GT => CompareOp::Gt, - ffi::Py_GE => CompareOp::Ge, - _ => CompareOp::Other - } - } + Ge = ffi::Py_GE as isize } /// Trait implemented by the types produced by the `py_class!()` macro. diff --git a/src/py_class/py_class.rs b/src/py_class/py_class.rs index 3e87831d..8ceb0e5e 100644 --- a/src/py_class/py_class.rs +++ b/src/py_class/py_class.rs @@ -278,9 +278,7 @@ py_class!(class MyIterator |py| { ## Comparison operators -TODO: implement support for `__cmp__` - - * `def __richcompare__(&self, other: impl ToPyObject, op: CompareOp) -> PyResult` + * `def __richcmp__(&self, other: impl ToPyObject, op: CompareOp) -> PyResult` Overloads Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`). The `op` argument indicates the comparison operation being performed. diff --git a/src/py_class/py_class_impl.py b/src/py_class/py_class_impl.py index cd4cead5..55d3562d 100644 --- a/src/py_class/py_class_impl.py +++ b/src/py_class/py_class_impl.py @@ -597,14 +597,14 @@ special_names = { '__bytes__': normal_method(), '__format__': normal_method(), # Comparison Operators - '__lt__': error('__lt__ is not supported by py_class! use __richcompare__ instead.'), - '__le__': error('__le__ is not supported by py_class! use __richcompare__ instead.'), - '__gt__': error('__gt__ is not supported by py_class! use __richcompare__ instead.'), - '__ge__': error('__ge__ is not supported by py_class! use __richcompare__ instead.'), - '__eq__': error('__eq__ is not supported by py_class! use __richcompare__ instead.'), - '__ne__': error('__ne__ is not supported by py_class! use __richcompare__ instead.'), - '__cmp__': unimplemented(), - '__richcompare__': operator('tp_richcompare', + '__lt__': error('__lt__ is not supported by py_class! use __richcmp__ instead.'), + '__le__': error('__le__ is not supported by py_class! use __richcmp__ instead.'), + '__gt__': error('__gt__ is not supported by py_class! use __richcmp__ instead.'), + '__ge__': error('__ge__ is not supported by py_class! use __richcmp__ instead.'), + '__eq__': error('__eq__ is not supported by py_class! use __richcmp__ instead.'), + '__ne__': error('__ne__ is not supported by py_class! use __richcmp__ instead.'), + '__cmp__': error('__cmp__ is not supported by py_class! use __richcmp__ instead.'), + '__richcmp__': operator('tp_richcompare', res_type='PyObject', args=[Argument('other'), Argument('op')]), '__hash__': operator('tp_hash', diff --git a/src/py_class/py_class_impl2.rs b/src/py_class/py_class_impl2.rs index 0f4b2d20..5e8a4c0e 100644 --- a/src/py_class/py_class_impl2.rs +++ b/src/py_class/py_class_impl2.rs @@ -483,7 +483,7 @@ macro_rules! py_class_impl { }}; { { def __cmp__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__cmp__ is not supported by py_class! yet." } + py_error! { "__cmp__ is not supported by py_class! use __richcmp__ instead." } }; { { def __coerce__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -580,7 +580,7 @@ macro_rules! py_class_impl { }; { { def __eq__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__eq__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__eq__ is not supported by py_class! use __richcmp__ instead." } }; { { def __float__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -592,7 +592,7 @@ macro_rules! py_class_impl { }; { { def __ge__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ge__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__ge__ is not supported by py_class! use __richcmp__ instead." } }; { { def __get__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -643,7 +643,7 @@ macro_rules! py_class_impl { }; { { def __gt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__gt__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__gt__ is not supported by py_class! use __richcmp__ instead." } }; { { def __hash__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -809,7 +809,7 @@ macro_rules! py_class_impl { }; { { def __le__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__le__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__le__ is not supported by py_class! use __richcmp__ instead." } }; { { def __len__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -882,7 +882,7 @@ macro_rules! py_class_impl { }; { { def __lt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__lt__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__lt__ is not supported by py_class! use __richcmp__ instead." } }; { { def __matmul__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -924,7 +924,7 @@ macro_rules! py_class_impl { }; { { def __ne__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ne__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__ne__ is not supported by py_class! use __richcmp__ instead." } }; { { def __neg__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -1151,7 +1151,7 @@ macro_rules! py_class_impl { { { def __rfloordiv__ $($tail:tt)* } $( $stuff:tt )* } => { py_error! { "Reflected numeric operator __rfloordiv__ is not supported by py_class! Use __floordiv__ instead!" } }; - { { def __richcompare__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } + { { def __richcmp__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt /* slots: */ { /* type_slots */ [ $( $tp_slot_name:ident : $tp_slot_value:expr, )* ] @@ -1165,19 +1165,19 @@ macro_rules! py_class_impl { /* slots: */ { /* type_slots */ [ $( $tp_slot_name : $tp_slot_value, )* - tp_richcompare: py_class_richcompare_slot!($class::__richcompare__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), + tp_richcompare: py_class_richcompare_slot!($class::__richcmp__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), ] $as_number $as_sequence $as_mapping $setdelitem } /* impl: */ { $($imp)* - py_class_impl_item! { $class, $py, __richcompare__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] } + py_class_impl_item! { $class, $py, __richcmp__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] } } $members }}; - { { def __richcompare__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "Invalid signature for operator __richcompare__" } + { { def __richcmp__ $($tail:tt)* } $( $stuff:tt )* } => { + py_error! { "Invalid signature for operator __richcmp__" } }; { { def __rlshift__ $($tail:tt)* } $( $stuff:tt )* } => { diff --git a/src/py_class/py_class_impl3.rs b/src/py_class/py_class_impl3.rs index 0790a148..c90ed012 100644 --- a/src/py_class/py_class_impl3.rs +++ b/src/py_class/py_class_impl3.rs @@ -483,7 +483,7 @@ macro_rules! py_class_impl { }}; { { def __cmp__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__cmp__ is not supported by py_class! yet." } + py_error! { "__cmp__ is not supported by py_class! use __richcmp__ instead." } }; { { def __coerce__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -580,7 +580,7 @@ macro_rules! py_class_impl { }; { { def __eq__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__eq__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__eq__ is not supported by py_class! use __richcmp__ instead." } }; { { def __float__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -592,7 +592,7 @@ macro_rules! py_class_impl { }; { { def __ge__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ge__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__ge__ is not supported by py_class! use __richcmp__ instead." } }; { { def __get__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -643,7 +643,7 @@ macro_rules! py_class_impl { }; { { def __gt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__gt__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__gt__ is not supported by py_class! use __richcmp__ instead." } }; { { def __hash__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -809,7 +809,7 @@ macro_rules! py_class_impl { }; { { def __le__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__le__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__le__ is not supported by py_class! use __richcmp__ instead." } }; { { def __len__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -882,7 +882,7 @@ macro_rules! py_class_impl { }; { { def __lt__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__lt__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__lt__ is not supported by py_class! use __richcmp__ instead." } }; { { def __matmul__ $($tail:tt)* } $( $stuff:tt )* } => { @@ -924,7 +924,7 @@ macro_rules! py_class_impl { }; { { def __ne__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "__ne__ is not supported by py_class! use __richcompare__ instead." } + py_error! { "__ne__ is not supported by py_class! use __richcmp__ instead." } }; { { def __neg__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt @@ -1151,7 +1151,7 @@ macro_rules! py_class_impl { { { def __rfloordiv__ $($tail:tt)* } $( $stuff:tt )* } => { py_error! { "Reflected numeric operator __rfloordiv__ is not supported by py_class! Use __floordiv__ instead!" } }; - { { def __richcompare__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } + { { def __richcmp__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* } $class:ident $py:ident $info:tt /* slots: */ { /* type_slots */ [ $( $tp_slot_name:ident : $tp_slot_value:expr, )* ] @@ -1165,19 +1165,19 @@ macro_rules! py_class_impl { /* slots: */ { /* type_slots */ [ $( $tp_slot_name : $tp_slot_value, )* - tp_richcompare: py_class_richcompare_slot!($class::__richcompare__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), + tp_richcompare: py_class_richcompare_slot!($class::__richcmp__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), ] $as_number $as_sequence $as_mapping $setdelitem } /* impl: */ { $($imp)* - py_class_impl_item! { $class, $py, __richcompare__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] } + py_class_impl_item! { $class, $py, __richcmp__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] } } $members }}; - { { def __richcompare__ $($tail:tt)* } $( $stuff:tt )* } => { - py_error! { "Invalid signature for operator __richcompare__" } + { { def __richcmp__ $($tail:tt)* } $( $stuff:tt )* } => { + py_error! { "Invalid signature for operator __richcmp__" } }; { { def __rlshift__ $($tail:tt)* } $( $stuff:tt )* } => { diff --git a/src/py_class/slots.rs b/src/py_class/slots.rs index 5a286787..c7624abd 100644 --- a/src/py_class/slots.rs +++ b/src/py_class/slots.rs @@ -24,6 +24,7 @@ use conversion::ToPyObject; use objects::PyObject; use function::CallbackConverter; use err::{PyErr, PyResult}; +use py_class::{CompareOp}; use exc; use Py_hash_t; @@ -320,6 +321,20 @@ macro_rules! py_class_ternary_slot { }} } +pub fn extract_op(py: Python, op: c_int) -> PyResult { + match op { + ffi::Py_LT => Ok(CompareOp::Lt), + ffi::Py_LE => Ok(CompareOp::Le), + ffi::Py_EQ => Ok(CompareOp::Eq), + ffi::Py_NE => Ok(CompareOp::Ne), + ffi::Py_GT => Ok(CompareOp::Gt), + ffi::Py_GE => Ok(CompareOp::Ge), + _ => Err(PyErr::new_lazy_init( + py.get_type::(), + Some("tp_richcompare called with invalid comparison operator".to_py_object(py).into_object()))) + } +} + // sq_richcompare is special-cased slot #[macro_export] #[doc(hidden)] @@ -337,10 +352,12 @@ macro_rules! py_class_richcompare_slot { |py| { let slf = $crate::PyObject::from_borrowed_ptr(py, slf).unchecked_cast_into::<$class>(); let arg = $crate::PyObject::from_borrowed_ptr(py, arg); - let op = $crate::py_class::CompareOp::from(op as isize); - let ret = match <$arg_type as $crate::FromPyObject>::extract(py, &arg) { - Ok(arg) => slf.$f(py, arg, op), - Err(e) => Err(e) + let ret = match $crate::py_class::slots::extract_op(py, op) { + Ok(op) => match <$arg_type as $crate::FromPyObject>::extract(py, &arg) { + Ok(arg) => slf.$f(py, arg, op).map(|res| { res.into_py_object(py).into_object() }), + Err(_) => Ok(py.NotImplemented()) + }, + Err(_) => Ok(py.NotImplemented()) }; $crate::PyDrop::release_ref(arg, py); $crate::PyDrop::release_ref(slf, py); diff --git a/tests/test_class.rs b/tests/test_class.rs index 1cb09e5c..2e428ea6 100644 --- a/tests/test_class.rs +++ b/tests/test_class.rs @@ -665,6 +665,90 @@ fn binary_arithmetic() { py_run!(py, c, "assert 1 | c == '1 | BA'"); } +py_class!(class RichComparisons |py| { + def __repr__(&self) -> PyResult<&'static str> { + Ok("RC") + } + + def __richcmp__(&self, other: &PyObject, op: CompareOp) -> PyResult { + match op { + CompareOp::Lt => Ok(format!("{:?} < {:?}", self.as_object(), other)), + CompareOp::Le => Ok(format!("{:?} <= {:?}", self.as_object(), other)), + CompareOp::Eq => Ok(format!("{:?} == {:?}", self.as_object(), other)), + CompareOp::Ne => Ok(format!("{:?} != {:?}", self.as_object(), other)), + CompareOp::Gt => Ok(format!("{:?} > {:?}", self.as_object(), other)), + CompareOp::Ge => Ok(format!("{:?} >= {:?}", self.as_object(), other)) + } + } +}); + +py_class!(class RichComparisons2 |py| { + def __repr__(&self) -> PyResult<&'static str> { + Ok("RC2") + } + + def __richcmp__(&self, other: &PyObject, op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(true.to_py_object(py).into_object()), + CompareOp::Ne => Ok(false.to_py_object(py).into_object()), + _ => Ok(py.NotImplemented()) + } + } +}); + +#[test] +fn rich_comparisons() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c = RichComparisons::create_instance(py).unwrap(); + py_run!(py, c, "assert (c < c) == 'RC < RC'"); + py_run!(py, c, "assert (c < 1) == 'RC < 1'"); + py_run!(py, c, "assert (1 < c) == 'RC > 1'"); + py_run!(py, c, "assert (c <= c) == 'RC <= RC'"); + py_run!(py, c, "assert (c <= 1) == 'RC <= 1'"); + py_run!(py, c, "assert (1 <= c) == 'RC >= 1'"); + py_run!(py, c, "assert (c == c) == 'RC == RC'"); + py_run!(py, c, "assert (c == 1) == 'RC == 1'"); + py_run!(py, c, "assert (1 == c) == 'RC == 1'"); + py_run!(py, c, "assert (c != c) == 'RC != RC'"); + py_run!(py, c, "assert (c != 1) == 'RC != 1'"); + py_run!(py, c, "assert (1 != c) == 'RC != 1'"); + py_run!(py, c, "assert (c > c) == 'RC > RC'"); + py_run!(py, c, "assert (c > 1) == 'RC > 1'"); + py_run!(py, c, "assert (1 > c) == 'RC < 1'"); + py_run!(py, c, "assert (c >= c) == 'RC >= RC'"); + py_run!(py, c, "assert (c >= 1) == 'RC >= 1'"); + py_run!(py, c, "assert (1 >= c) == 'RC <= 1'"); +} + +#[test] +#[cfg(feature="python3-sys")] +fn rich_comparisons_python_3_type_error() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let c2 = RichComparisons2::create_instance(py).unwrap(); + py_expect_exception!(py, c2, "c2 < c2", TypeError); + py_expect_exception!(py, c2, "c2 < 1", TypeError); + py_expect_exception!(py, c2, "1 < c2", TypeError); + py_expect_exception!(py, c2, "c2 <= c2", TypeError); + py_expect_exception!(py, c2, "c2 <= 1", TypeError); + py_expect_exception!(py, c2, "1 <= c2", TypeError); + py_run!(py, c2, "assert (c2 == c2) == True"); + py_run!(py, c2, "assert (c2 == 1) == True"); + py_run!(py, c2, "assert (1 == c2) == True"); + py_run!(py, c2, "assert (c2 != c2) == False"); + py_run!(py, c2, "assert (c2 != 1) == False"); + py_run!(py, c2, "assert (1 != c2) == False"); + py_expect_exception!(py, c2, "c2 > c2", TypeError); + py_expect_exception!(py, c2, "c2 > 1", TypeError); + py_expect_exception!(py, c2, "1 > c2", TypeError); + py_expect_exception!(py, c2, "c2 >= c2", TypeError); + py_expect_exception!(py, c2, "c2 >= 1", TypeError); + py_expect_exception!(py, c2, "1 >= c2", TypeError); +} + py_class!(class ContextManager |py| { data exit_called : Cell;