Merge pull request #49 from sciyoshi/rich-comparison

Add support for overloading comparison operators with __richcompare__
This commit is contained in:
Daniel Grunwald 2016-06-12 00:04:11 +02:00 committed by GitHub
commit 72e1e05835
8 changed files with 232 additions and 23 deletions

View File

@ -98,6 +98,7 @@ pub use objects::*;
pub use python::{Python, PythonObject, PythonObjectWithCheckedDowncast, PythonObjectDowncastError, PythonObjectWithTypeObject, PyClone, PyDrop};
pub use pythonrun::{GILGuard, GILProtected, prepare_freethreaded_python};
pub use conversion::{FromPyObject, RefFromPyObject, ToPyObject};
pub use py_class::{CompareOp};
pub use objectprotocol::{ObjectProtocol};
#[cfg(feature="python27-sys")]

View File

@ -32,6 +32,16 @@ use objects::{PyObject, PyType};
use err::{self, PyResult};
use ffi;
#[derive(Debug)]
pub enum CompareOp {
Lt = ffi::Py_LT as isize,
Le = ffi::Py_LE as isize,
Eq = ffi::Py_EQ as isize,
Ne = ffi::Py_NE as isize,
Gt = ffi::Py_GT as isize,
Ge = ffi::Py_GE as isize
}
/// Trait implemented by the types produced by the `py_class!()` macro.
pub trait PythonObjectFromPyClassMacro : python::PythonObjectWithTypeObject {
fn initialize(py: Python) -> PyResult<PyType>;

View File

@ -278,7 +278,11 @@ py_class!(class MyIterator |py| {
## Comparison operators
TODO: implement support for `__cmp__`, `__lt__`, `__le__`, `__gt__`, `__ge__`, `__eq__`, `__ne__`.
* `def __richcmp__(&self, other: impl ToPyObject, op: CompareOp) -> PyResult<impl ToPyObject>`
Overloads Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`). The `op`
argument indicates the comparison operation being performed.
The return type will normally be `PyResult<bool>`, but any Python object can be returned.
* `def __hash__(&self) -> PyResult<impl PrimInt>`

View File

@ -536,6 +536,9 @@ def operator(special_name, slot,
param_list.append('{{ ${0} : ${0}_type = {{}} }}'.format(arg.name))
if slot == 'sq_contains':
new_slots = [(slot, 'py_class_contains_slot!($class::%s, $%s_type)' % (special_name, args[0].name))]
elif slot == 'tp_richcompare':
new_slots = [(slot, 'py_class_richcompare_slot!($class::%s, $%s_type, %s, %s)'
% (special_name, args[0].name, res_ffi_type, res_conv))]
elif len(args) == 0:
new_slots = [(slot, 'py_class_unary_slot!($class::%s, %s, %s)'
% (special_name, res_ffi_type, res_conv))]
@ -594,13 +597,16 @@ special_names = {
'__bytes__': normal_method(),
'__format__': normal_method(),
# Comparison Operators
'__lt__': unimplemented(),
'__le__': unimplemented(),
'__gt__': unimplemented(),
'__ge__': unimplemented(),
'__eq__': unimplemented(),
'__ne__': unimplemented(),
'__cmp__': unimplemented(),
'__lt__': error('__lt__ is not supported by py_class! use __richcmp__ instead.'),
'__le__': error('__le__ is not supported by py_class! use __richcmp__ instead.'),
'__gt__': error('__gt__ is not supported by py_class! use __richcmp__ instead.'),
'__ge__': error('__ge__ is not supported by py_class! use __richcmp__ instead.'),
'__eq__': error('__eq__ is not supported by py_class! use __richcmp__ instead.'),
'__ne__': error('__ne__ is not supported by py_class! use __richcmp__ instead.'),
'__cmp__': error('__cmp__ is not supported by py_class! use __richcmp__ instead.'),
'__richcmp__': operator('tp_richcompare',
res_type='PyObject',
args=[Argument('other'), Argument('op')]),
'__hash__': operator('tp_hash',
res_conv='$crate::py_class::slots::HashConverter',
res_ffi_type='$crate::Py_hash_t'),
@ -652,7 +658,7 @@ special_names = {
res_conv='$crate::py_class::slots::IterNextResultConverter'),
'__reversed__': normal_method(),
'__contains__': operator('sq_contains', args=[Argument('item')]),
# Emulating numeric types
'__add__': binary_numeric_operator('nb_add'),
'__sub__': binary_numeric_operator('nb_subtract'),

View File

@ -483,7 +483,7 @@ macro_rules! py_class_impl {
}};
{ { def __cmp__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__cmp__ is not supported by py_class! yet." }
py_error! { "__cmp__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __coerce__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -580,7 +580,7 @@ macro_rules! py_class_impl {
};
{ { def __eq__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__eq__ is not supported by py_class! yet." }
py_error! { "__eq__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __float__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -592,7 +592,7 @@ macro_rules! py_class_impl {
};
{ { def __ge__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__ge__ is not supported by py_class! yet." }
py_error! { "__ge__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __get__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -643,7 +643,7 @@ macro_rules! py_class_impl {
};
{ { def __gt__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__gt__ is not supported by py_class! yet." }
py_error! { "__gt__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __hash__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
@ -809,7 +809,7 @@ macro_rules! py_class_impl {
};
{ { def __le__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__le__ is not supported by py_class! yet." }
py_error! { "__le__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __len__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
@ -882,7 +882,7 @@ macro_rules! py_class_impl {
};
{ { def __lt__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__lt__ is not supported by py_class! yet." }
py_error! { "__lt__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __matmul__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -924,7 +924,7 @@ macro_rules! py_class_impl {
};
{ { def __ne__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__ne__ is not supported by py_class! yet." }
py_error! { "__ne__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __neg__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
@ -1151,6 +1151,34 @@ macro_rules! py_class_impl {
{ { def __rfloordiv__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "Reflected numeric operator __rfloordiv__ is not supported by py_class! Use __floordiv__ instead!" }
};
{ { def __richcmp__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
/* slots: */ {
/* type_slots */ [ $( $tp_slot_name:ident : $tp_slot_value:expr, )* ]
$as_number:tt $as_sequence:tt $as_mapping:tt $setdelitem:tt
}
{ $( $imp:item )* }
$members:tt
} => { py_class_impl! {
{ $($tail)* }
$class $py $info
/* slots: */ {
/* type_slots */ [
$( $tp_slot_name : $tp_slot_value, )*
tp_richcompare: py_class_richcompare_slot!($class::__richcmp__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter),
]
$as_number $as_sequence $as_mapping $setdelitem
}
/* impl: */ {
$($imp)*
py_class_impl_item! { $class, $py, __richcmp__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] }
}
$members
}};
{ { def __richcmp__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "Invalid signature for operator __richcmp__" }
};
{ { def __rlshift__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "Reflected numeric operator __rlshift__ is not supported by py_class! Use __lshift__ instead!" }

View File

@ -483,7 +483,7 @@ macro_rules! py_class_impl {
}};
{ { def __cmp__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__cmp__ is not supported by py_class! yet." }
py_error! { "__cmp__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __coerce__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -580,7 +580,7 @@ macro_rules! py_class_impl {
};
{ { def __eq__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__eq__ is not supported by py_class! yet." }
py_error! { "__eq__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __float__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -592,7 +592,7 @@ macro_rules! py_class_impl {
};
{ { def __ge__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__ge__ is not supported by py_class! yet." }
py_error! { "__ge__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __get__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -643,7 +643,7 @@ macro_rules! py_class_impl {
};
{ { def __gt__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__gt__ is not supported by py_class! yet." }
py_error! { "__gt__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __hash__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
@ -809,7 +809,7 @@ macro_rules! py_class_impl {
};
{ { def __le__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__le__ is not supported by py_class! yet." }
py_error! { "__le__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __len__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
@ -882,7 +882,7 @@ macro_rules! py_class_impl {
};
{ { def __lt__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__lt__ is not supported by py_class! yet." }
py_error! { "__lt__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __matmul__ $($tail:tt)* } $( $stuff:tt )* } => {
@ -924,7 +924,7 @@ macro_rules! py_class_impl {
};
{ { def __ne__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "__ne__ is not supported by py_class! yet." }
py_error! { "__ne__ is not supported by py_class! use __richcmp__ instead." }
};
{ { def __neg__(&$slf:ident) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
@ -1151,6 +1151,34 @@ macro_rules! py_class_impl {
{ { def __rfloordiv__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "Reflected numeric operator __rfloordiv__ is not supported by py_class! Use __floordiv__ instead!" }
};
{ { def __richcmp__(&$slf:ident, $other:ident : $other_type:ty, $op:ident : $op_type:ty) -> $res_type:ty { $($body:tt)* } $($tail:tt)* }
$class:ident $py:ident $info:tt
/* slots: */ {
/* type_slots */ [ $( $tp_slot_name:ident : $tp_slot_value:expr, )* ]
$as_number:tt $as_sequence:tt $as_mapping:tt $setdelitem:tt
}
{ $( $imp:item )* }
$members:tt
} => { py_class_impl! {
{ $($tail)* }
$class $py $info
/* slots: */ {
/* type_slots */ [
$( $tp_slot_name : $tp_slot_value, )*
tp_richcompare: py_class_richcompare_slot!($class::__richcmp__, $other_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter),
]
$as_number $as_sequence $as_mapping $setdelitem
}
/* impl: */ {
$($imp)*
py_class_impl_item! { $class, $py, __richcmp__(&$slf,) $res_type; { $($body)* } [{ $other : $other_type = {} } { $op : $op_type = {} }] }
}
$members
}};
{ { def __richcmp__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "Invalid signature for operator __richcmp__" }
};
{ { def __rlshift__ $($tail:tt)* } $( $stuff:tt )* } => {
py_error! { "Reflected numeric operator __rlshift__ is not supported by py_class! Use __lshift__ instead!" }

View File

@ -24,6 +24,7 @@ use conversion::ToPyObject;
use objects::PyObject;
use function::CallbackConverter;
use err::{PyErr, PyResult};
use py_class::{CompareOp};
use exc;
use Py_hash_t;
@ -320,6 +321,53 @@ macro_rules! py_class_ternary_slot {
}}
}
pub fn extract_op(py: Python, op: c_int) -> PyResult<CompareOp> {
match op {
ffi::Py_LT => Ok(CompareOp::Lt),
ffi::Py_LE => Ok(CompareOp::Le),
ffi::Py_EQ => Ok(CompareOp::Eq),
ffi::Py_NE => Ok(CompareOp::Ne),
ffi::Py_GT => Ok(CompareOp::Gt),
ffi::Py_GE => Ok(CompareOp::Ge),
_ => Err(PyErr::new_lazy_init(
py.get_type::<exc::ValueError>(),
Some("tp_richcompare called with invalid comparison operator".to_py_object(py).into_object())))
}
}
// sq_richcompare is special-cased slot
#[macro_export]
#[doc(hidden)]
macro_rules! py_class_richcompare_slot {
($class:ident :: $f:ident, $arg_type:ty, $res_type:ty, $conv:expr) => {{
unsafe extern "C" fn tp_richcompare(
slf: *mut $crate::_detail::ffi::PyObject,
arg: *mut $crate::_detail::ffi::PyObject,
op: $crate::_detail::libc::c_int)
-> $res_type
{
const LOCATION: &'static str = concat!(stringify!($class), ".", stringify!($f), "()");
$crate::_detail::handle_callback(
LOCATION, $conv,
|py| {
let slf = $crate::PyObject::from_borrowed_ptr(py, slf).unchecked_cast_into::<$class>();
let arg = $crate::PyObject::from_borrowed_ptr(py, arg);
let ret = match $crate::py_class::slots::extract_op(py, op) {
Ok(op) => match <$arg_type as $crate::FromPyObject>::extract(py, &arg) {
Ok(arg) => slf.$f(py, arg, op).map(|res| { res.into_py_object(py).into_object() }),
Err(_) => Ok(py.NotImplemented())
},
Err(_) => Ok(py.NotImplemented())
};
$crate::PyDrop::release_ref(arg, py);
$crate::PyDrop::release_ref(slf, py);
ret
})
}
Some(tp_richcompare)
}}
}
// sq_contains is special-cased slot because it converts type errors to Ok(false)
#[macro_export]
#[doc(hidden)]

View File

@ -665,6 +665,90 @@ fn binary_arithmetic() {
py_run!(py, c, "assert 1 | c == '1 | BA'");
}
py_class!(class RichComparisons |py| {
def __repr__(&self) -> PyResult<&'static str> {
Ok("RC")
}
def __richcmp__(&self, other: &PyObject, op: CompareOp) -> PyResult<String> {
match op {
CompareOp::Lt => Ok(format!("{:?} < {:?}", self.as_object(), other)),
CompareOp::Le => Ok(format!("{:?} <= {:?}", self.as_object(), other)),
CompareOp::Eq => Ok(format!("{:?} == {:?}", self.as_object(), other)),
CompareOp::Ne => Ok(format!("{:?} != {:?}", self.as_object(), other)),
CompareOp::Gt => Ok(format!("{:?} > {:?}", self.as_object(), other)),
CompareOp::Ge => Ok(format!("{:?} >= {:?}", self.as_object(), other))
}
}
});
py_class!(class RichComparisons2 |py| {
def __repr__(&self) -> PyResult<&'static str> {
Ok("RC2")
}
def __richcmp__(&self, other: &PyObject, op: CompareOp) -> PyResult<PyObject> {
match op {
CompareOp::Eq => Ok(true.to_py_object(py).into_object()),
CompareOp::Ne => Ok(false.to_py_object(py).into_object()),
_ => Ok(py.NotImplemented())
}
}
});
#[test]
fn rich_comparisons() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = RichComparisons::create_instance(py).unwrap();
py_run!(py, c, "assert (c < c) == 'RC < RC'");
py_run!(py, c, "assert (c < 1) == 'RC < 1'");
py_run!(py, c, "assert (1 < c) == 'RC > 1'");
py_run!(py, c, "assert (c <= c) == 'RC <= RC'");
py_run!(py, c, "assert (c <= 1) == 'RC <= 1'");
py_run!(py, c, "assert (1 <= c) == 'RC >= 1'");
py_run!(py, c, "assert (c == c) == 'RC == RC'");
py_run!(py, c, "assert (c == 1) == 'RC == 1'");
py_run!(py, c, "assert (1 == c) == 'RC == 1'");
py_run!(py, c, "assert (c != c) == 'RC != RC'");
py_run!(py, c, "assert (c != 1) == 'RC != 1'");
py_run!(py, c, "assert (1 != c) == 'RC != 1'");
py_run!(py, c, "assert (c > c) == 'RC > RC'");
py_run!(py, c, "assert (c > 1) == 'RC > 1'");
py_run!(py, c, "assert (1 > c) == 'RC < 1'");
py_run!(py, c, "assert (c >= c) == 'RC >= RC'");
py_run!(py, c, "assert (c >= 1) == 'RC >= 1'");
py_run!(py, c, "assert (1 >= c) == 'RC <= 1'");
}
#[test]
#[cfg(feature="python3-sys")]
fn rich_comparisons_python_3_type_error() {
let gil = Python::acquire_gil();
let py = gil.python();
let c2 = RichComparisons2::create_instance(py).unwrap();
py_expect_exception!(py, c2, "c2 < c2", TypeError);
py_expect_exception!(py, c2, "c2 < 1", TypeError);
py_expect_exception!(py, c2, "1 < c2", TypeError);
py_expect_exception!(py, c2, "c2 <= c2", TypeError);
py_expect_exception!(py, c2, "c2 <= 1", TypeError);
py_expect_exception!(py, c2, "1 <= c2", TypeError);
py_run!(py, c2, "assert (c2 == c2) == True");
py_run!(py, c2, "assert (c2 == 1) == True");
py_run!(py, c2, "assert (1 == c2) == True");
py_run!(py, c2, "assert (c2 != c2) == False");
py_run!(py, c2, "assert (c2 != 1) == False");
py_run!(py, c2, "assert (1 != c2) == False");
py_expect_exception!(py, c2, "c2 > c2", TypeError);
py_expect_exception!(py, c2, "c2 > 1", TypeError);
py_expect_exception!(py, c2, "1 > c2", TypeError);
py_expect_exception!(py, c2, "c2 >= c2", TypeError);
py_expect_exception!(py, c2, "c2 >= 1", TypeError);
py_expect_exception!(py, c2, "1 >= c2", TypeError);
}
py_class!(class ContextManager |py| {
data exit_called : Cell<bool>;