From 04081bc8de8814519aebb59a0fc565629226c491 Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Sat, 7 May 2016 23:31:46 +0200 Subject: [PATCH] __contains__: if extraction fails with TypeError, return False instead. --- src/py_class/py_class.rs | 3 +++ src/py_class/py_class_impl.py | 16 ++++++++++------ src/py_class/py_class_impl2.rs | 6 +++--- src/py_class/py_class_impl3.rs | 6 +++--- src/py_class/slots.rs | 24 +++++++++++++++++++++--- tests/test_class.rs | 1 + 6 files changed, 41 insertions(+), 15 deletions(-) diff --git a/src/py_class/py_class.rs b/src/py_class/py_class.rs index b8b5e6be..b75c9009 100644 --- a/src/py_class/py_class.rs +++ b/src/py_class/py_class.rs @@ -321,6 +321,9 @@ TODO: implement support for `__cmp__`, `__lt__`, `__le__`, `__gt__`, `__ge__`, ` For mapping types, this should consider the keys of the mapping rather than the values or the key-item pairs. + If extraction of the `item` parameter fails with `TypeError`, + `__contains__` will return `Ok(false)`. + ## Other Special Methods * `def __bool__(&self) -> PyResult` diff --git a/src/py_class/py_class_impl.py b/src/py_class/py_class_impl.py index 4c179c05..81b57bfb 100644 --- a/src/py_class/py_class_impl.py +++ b/src/py_class/py_class_impl.py @@ -506,6 +506,10 @@ def special_class_method(special_name, *args, **kwargs): generate_class_method(special_name=special_name, *args, **kwargs) Argument = namedtuple('Argument', ['name', 'default_type']) +class Argument(object): + def __init__(self, name, extract_err='passthrough'): + self.name = name + self.extract_err = 'py_class_extract_error_%s' % extract_err @special_method def operator(special_name, slot, @@ -535,8 +539,8 @@ def operator(special_name, slot, new_slots = [(slot, 'py_class_unary_slot!($class::%s, %s, %s)' % (special_name, res_ffi_type, res_conv))] elif len(args) == 1: - new_slots = [(slot, 'py_class_binary_slot!($class::%s, $%s_type, %s, %s)' - % (special_name, args[0].name, res_ffi_type, res_conv))] + new_slots = [(slot, 'py_class_binary_slot!($class::%s, $%s_type, %s, %s, %s)' + % (special_name, args[0].name, args[0].extract_err, res_ffi_type, res_conv))] elif len(args) == 2: new_slots = [(slot, 'py_class_ternary_slot!($class::%s, $%s_type, $%s_type, %s, %s)' % (special_name, args[0].name, args[1].name, res_ffi_type, res_conv))] @@ -615,23 +619,23 @@ special_names = { ]), '__length_hint__': normal_method(), '__getitem__': operator('mp_subscript', - args=[Argument('key', '&PyObject')], + args=[Argument('key')], additional_slots=[ ('sq_item', 'Some($crate::py_class::slots::sq_item)') ]), '__missing__': normal_method(), '__setitem__': operator('sdi_setitem', - args=[Argument('key', '&PyObject'), Argument('value', '&PyObject')], + args=[Argument('key'), Argument('value')], res_type='()'), '__delitem__': operator('sdi_delitem', - args=[Argument('key', '&PyObject')], + args=[Argument('key')], res_type='()'), '__iter__': operator('tp_iter'), '__next__': operator('tp_iternext', res_conv='$crate::py_class::slots::IterNextResultConverter'), '__reversed__': normal_method(), '__contains__': operator('sq_contains', - args=[Argument('item', '&PyObject')], + args=[Argument('item', extract_err='false')], res_type='bool'), # Emulating numeric types diff --git a/src/py_class/py_class_impl2.rs b/src/py_class/py_class_impl2.rs index 1709b71d..11c60641 100644 --- a/src/py_class/py_class_impl2.rs +++ b/src/py_class/py_class_impl2.rs @@ -430,7 +430,7 @@ macro_rules! py_class_impl { $type_slots $as_number /* as_sequence */ [ $( $sq_slot_name : $sq_slot_value, )* - sq_contains: py_class_binary_slot!($class::__contains__, $item_type, $crate::_detail::libc::c_int, $crate::py_class::slots::BoolConverter), + sq_contains: py_class_binary_slot!($class::__contains__, $item_type, py_class_extract_error_false, $crate::_detail::libc::c_int, $crate::py_class::slots::BoolConverter), ] $as_mapping $setdelitem } @@ -474,7 +474,7 @@ macro_rules! py_class_impl { $type_slots $as_number $as_sequence $as_mapping /* setdelitem */ [ sdi_setitem: $sdi_setitem_slot_value, - sdi_delitem: { py_class_binary_slot!($class::__delitem__, $key_type, $crate::_detail::libc::c_int, $crate::py_class::slots::UnitCallbackConverter) }, + sdi_delitem: { py_class_binary_slot!($class::__delitem__, $key_type, py_class_extract_error_passthrough, $crate::_detail::libc::c_int, $crate::py_class::slots::UnitCallbackConverter) }, ] } /* impl: */ { @@ -556,7 +556,7 @@ macro_rules! py_class_impl { ] /* as_mapping */ [ $( $mp_slot_name : $mp_slot_value, )* - mp_subscript: py_class_binary_slot!($class::__getitem__, $key_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), + mp_subscript: py_class_binary_slot!($class::__getitem__, $key_type, py_class_extract_error_passthrough, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), ] $setdelitem } diff --git a/src/py_class/py_class_impl3.rs b/src/py_class/py_class_impl3.rs index ce6cef5a..ca9f8761 100644 --- a/src/py_class/py_class_impl3.rs +++ b/src/py_class/py_class_impl3.rs @@ -430,7 +430,7 @@ macro_rules! py_class_impl { $type_slots $as_number /* as_sequence */ [ $( $sq_slot_name : $sq_slot_value, )* - sq_contains: py_class_binary_slot!($class::__contains__, $item_type, $crate::_detail::libc::c_int, $crate::py_class::slots::BoolConverter), + sq_contains: py_class_binary_slot!($class::__contains__, $item_type, py_class_extract_error_false, $crate::_detail::libc::c_int, $crate::py_class::slots::BoolConverter), ] $as_mapping $setdelitem } @@ -474,7 +474,7 @@ macro_rules! py_class_impl { $type_slots $as_number $as_sequence $as_mapping /* setdelitem */ [ sdi_setitem: $sdi_setitem_slot_value, - sdi_delitem: { py_class_binary_slot!($class::__delitem__, $key_type, $crate::_detail::libc::c_int, $crate::py_class::slots::UnitCallbackConverter) }, + sdi_delitem: { py_class_binary_slot!($class::__delitem__, $key_type, py_class_extract_error_passthrough, $crate::_detail::libc::c_int, $crate::py_class::slots::UnitCallbackConverter) }, ] } /* impl: */ { @@ -556,7 +556,7 @@ macro_rules! py_class_impl { ] /* as_mapping */ [ $( $mp_slot_name : $mp_slot_value, )* - mp_subscript: py_class_binary_slot!($class::__getitem__, $key_type, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), + mp_subscript: py_class_binary_slot!($class::__getitem__, $key_type, py_class_extract_error_passthrough, *mut $crate::_detail::ffi::PyObject, $crate::_detail::PyObjectCallbackConverter), ] $setdelitem } diff --git a/src/py_class/slots.rs b/src/py_class/slots.rs index 10ecbd9a..9199a03e 100644 --- a/src/py_class/slots.rs +++ b/src/py_class/slots.rs @@ -253,7 +253,7 @@ macro_rules! py_class_unary_slot { #[macro_export] #[doc(hidden)] macro_rules! py_class_binary_slot { - ($class:ident :: $f:ident, $arg_type:ty, $res_type:ty, $conv:expr) => {{ + ($class:ident :: $f:ident, $arg_type:ty, $extract_err:ident, $res_type:ty, $conv:expr) => {{ unsafe extern "C" fn wrap_binary( slf: *mut $crate::_detail::ffi::PyObject, arg: *mut $crate::_detail::ffi::PyObject) @@ -269,10 +269,10 @@ macro_rules! py_class_binary_slot { Ok(prepared) => { match <$arg_type as $crate::ExtractPyObject>::extract(py, &prepared) { Ok(arg) => slf.$f(py, arg), - Err(e) => Err(e) + Err(e) => $extract_err!(py, e) } }, - Err(e) => Err(e) + Err(e) => $extract_err!(py, e) }; $crate::PyDrop::release_ref(arg, py); $crate::PyDrop::release_ref(slf, py); @@ -329,6 +329,24 @@ macro_rules! py_class_ternary_slot { }} } +#[macro_export] +#[doc(hidden)] +macro_rules! py_class_extract_error_passthrough { + ($py: ident, $e:ident) => (Err($e)); +} + +#[macro_export] +#[doc(hidden)] +macro_rules! py_class_extract_error_false { + ($py: ident, $e:ident) => { + if $e.matches($py, $py.get_type::<$crate::exc::TypeError>()) { + Ok(false) + } else { + Err($e) + } + }; +} + pub struct UnitCallbackConverter; impl CallbackConverter<()> for UnitCallbackConverter { diff --git a/tests/test_class.rs b/tests/test_class.rs index da45d274..8ad2d2ef 100644 --- a/tests/test_class.rs +++ b/tests/test_class.rs @@ -569,5 +569,6 @@ fn contains() { let c = Contains::create_instance(py).unwrap(); py_run!(py, c, "assert 1 in c"); py_run!(py, c, "assert -1 not in c"); + py_run!(py, c, "assert 'wrong type' not in c"); }