#1064: Comparisons with __eq__ should not raise TypeError (#1072)

* Add (failing) tests for issue #1064

* Return NotImplemented when richcmp doesn't match the expected type.

* Fix tests that expect TypeError when richcmp returns NotImplemented.

- The python code 'class Other: pass; c2 {} Other()' was raising a NameError:
  c2 not found

- eq and ne never raise a TypeError, so I split the those cases.

* Return NotImplemented for number-like binary operations.

* Add dummy impl PyNumberProtocol for the test struct.

* Rework tests of NotImplemented.

* Make py_ternary_num_func return NotImplemented when type mismatches.

* Return NotImplement for type mismatches in binary inplace operators.

* Reduce boilerplate with `extract_or_return_not_implemented!`

* Extract common definition 'Other' into a function.

* Test explicitly for NotImplemented in the __ipow__ test.

* Add entry in CHANGELOG for PR #1072.

* Add the section 'Emulating numeric types' to the guide.

* Ensure we're returning NotImplemented in tests.

* Simplify the tests: only test we return NotImplemented.

Our previous test were rather indirect: were relying that Python
behaves correctly when we return NotImplemented.

Now we only test that calling a pyclass dunder method returns NotImplemented
when the argument doesn't match the type signature.  This is the expected
behavior.

* Remove reverse operators in tests of NotImplemented

The won't be used because of #844.

* Apply suggestions from code review

Co-authored-by: Yuji Kanagawa <yuji.kngw.80s.revive@gmail.com>

* Add a note about #844 below the list of reflected operations.

Co-authored-by: Yuji Kanagawa <yuji.kngw.80s.revive@gmail.com>
This commit is contained in:
Manuel Vázquez Acosta 2020-08-05 09:53:16 -04:00 committed by GitHub
parent 3ac327c8e0
commit f2ba3e6da7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 307 additions and 15 deletions

View File

@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- `PyType::as_type_ptr` is no longer `unsafe`. [#1047](https://github.com/PyO3/pyo3/pull/1047)
- Change `PyIterator::from_object` to return `PyResult<PyIterator>` instead of `Result<PyIterator, PyDowncastError>`. [#1051](https://github.com/PyO3/pyo3/pull/1051)
- Implement `Send + Sync` for `PyErr`. `PyErr::new`, `PyErr::from_type`, `PyException::py_err` and `PyException::into` have had these bounds added to their arguments. [#1067](https://github.com/PyO3/pyo3/pull/1067)
- Change `#[pyproto]` to return NotImplemented for operators for which Python can try a reversed operation. [1072](https://github.com/PyO3/pyo3/pull/1072)
### Removed
- Remove `PyString::as_bytes`. [#1023](https://github.com/PyO3/pyo3/pull/1023)

View File

@ -756,6 +756,94 @@ Each method corresponds to Python's `self.attr`, `self.attr = value` and `del se
Determines the "truthyness" of the object.
### Emulating numeric types
The [`PyNumberProtocol`] trait allows [emulate numeric types](https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types).
* `fn __add__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __sub__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __mul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __matmul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __truediv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __floordiv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __mod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __divmod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __pow__(lhs: impl FromPyObject, rhs: impl FromPyObject, modulo: Option<impl FromPyObject>) -> PyResult<impl ToPyObject>`
* `fn __lshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __and__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __or__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __xor__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
These methods are called to implement the binary arithmetic operations
(`+`, `-`, `*`, `@`, `/`, `//`, `%`, `divmod()`, `pow()` and `**`, `<<`, `>>`, `&`, `^`, and `|`).
If `rhs` is not of the type specified in the signature, the generated code
will automatically `return NotImplemented`. This is not the case for `lhs`
which must match signature or else raise a TypeError.
The reflected operations are also available:
* `fn __radd__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rsub__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rmul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rmatmul__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rtruediv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rfloordiv__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rmod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rdivmod__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rpow__(lhs: impl FromPyObject, rhs: impl FromPyObject, modulo: Option<impl FromPyObject>) -> PyResult<impl ToPyObject>`
* `fn __rlshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rrshift__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rand__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __ror__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
* `fn __rxor__(lhs: impl FromPyObject, rhs: impl FromPyObject) -> PyResult<impl ToPyObject>`
The code generated for these methods expect that all arguments match the
signature, or raise a TypeError.
*Note*: Currently implementing the method for a binary arithmetic operations
(e.g, `__add__`) shadows the reflected operation (e.g, `__radd__`). This is
being addressed in [#844](https://github.com/PyO3/pyo3/issues/844). to make
these methods
This trait also has support the augmented arithmetic assignments (`+=`, `-=`,
`*=`, `@=`, `/=`, `//=`, `%=`, `**=`, `<<=`, `>>=`, `&=`, `^=`, `|=`):
* `fn __iadd__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __isub__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imul__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imatmul__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __itruediv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ifloordiv__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __imod__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ipow__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ilshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __irshift__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __iand__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ior__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
* `fn __ixor__(&'p mut self, other: impl FromPyObject) -> PyResult<()>`
The following methods implement the unary arithmetic operations (`-`, `+`, `abs()` and `~`):
* `fn __neg__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __pos__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __abs__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __invert__(&'p self) -> PyResult<impl ToPyObject>`
Support for coercions:
* `fn __complex__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __int__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __float__(&'p self) -> PyResult<impl ToPyObject>`
Other:
* `fn __index__(&'p self) -> PyResult<impl ToPyObject>`
* `fn __round__(&'p self, ndigits: Option<impl FromPyObject>) -> PyResult<impl ToPyObject>`
### Garbage Collector Integration
If your type owns references to other Python objects, you will need to

View File

@ -275,10 +275,8 @@ where
{
crate::callback_body!(py, {
let slf = py.from_borrowed_ptr::<crate::PyCell<T>>(slf);
let arg = py.from_borrowed_ptr::<PyAny>(arg);
let arg = extract_or_return_not_implemented!(py, arg);
let op = extract_op(op)?;
let arg = arg.extract()?;
slf.try_borrow()?.__richcmp__(arg, op).convert(py)
})

View File

@ -91,9 +91,8 @@ macro_rules! py_binary_num_func {
{
$crate::callback_body!(py, {
let lhs = py.from_borrowed_ptr::<$crate::PyAny>(lhs);
let rhs = py.from_borrowed_ptr::<$crate::PyAny>(rhs);
$class::$f(lhs.extract()?, rhs.extract()?).convert(py)
let rhs = extract_or_return_not_implemented!(py, rhs);
$class::$f(lhs.extract()?, rhs).convert(py)
})
}
Some(wrap::<$class>)
@ -138,7 +137,7 @@ macro_rules! py_binary_self_func {
$crate::callback_body!(py, {
let slf_ = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let arg = py.from_borrowed_ptr::<$crate::PyAny>(arg);
call_mut!(slf_, $f, arg).convert(py)?;
call_operator_mut!(py, slf_, $f, arg).convert(py)?;
ffi::Py_INCREF(slf);
Ok(slf)
})
@ -222,13 +221,8 @@ macro_rules! py_ternary_num_func {
let arg1 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg1)
.extract()?;
let arg2 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg2)
.extract()?;
let arg3 = py
.from_borrowed_ptr::<$crate::types::PyAny>(arg3)
.extract()?;
let arg2 = extract_or_return_not_implemented!(py, arg2);
let arg3 = extract_or_return_not_implemented!(py, arg3);
$class::$f(arg1, arg2, arg3).convert(py)
})
}
@ -279,7 +273,7 @@ macro_rules! py_dummy_ternary_self_func {
$crate::callback_body!(py, {
let slf_cell = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let arg1 = py.from_borrowed_ptr::<$crate::PyAny>(arg1);
call_mut!(slf_cell, $f, arg1).convert(py)?;
call_operator_mut!(py, slf_cell, $f, arg1).convert(py)?;
ffi::Py_INCREF(slf);
Ok(slf)
})
@ -375,6 +369,18 @@ macro_rules! py_func_set_del {
}};
}
macro_rules! extract_or_return_not_implemented {
($py: ident, $arg: ident) => {
match $py
.from_borrowed_ptr::<$crate::types::PyAny>($arg)
.extract()
{
Ok(value) => value,
Err(_) => return $py.NotImplemented().convert($py),
}
};
}
macro_rules! _call_impl {
($slf: expr, $fn: ident $(; $args: expr)*) => {
$slf.$fn($($args,)*)
@ -382,6 +388,16 @@ macro_rules! _call_impl {
($slf: expr, $fn: ident, $raw_arg: expr $(,$raw_args: expr)* $(; $args: expr)*) => {
_call_impl!($slf, $fn $(,$raw_args)* $(;$args)* ;$raw_arg.extract()?)
};
(op $py:ident; $slf: expr, $fn: ident, $raw_arg: expr $(,$raw_args: expr)* $(; $args: expr)*) => {
_call_impl!(
$slf, $fn ;
(match $raw_arg.extract() {
Ok(res) => res,
_=> return Ok($py.NotImplemented().convert($py)?)
})
$(;$args)*
)
}
}
/// Call `slf.try_borrow()?.$fn(...)`
@ -397,3 +413,9 @@ macro_rules! call_mut {
_call_impl!($slf.try_borrow_mut()?, $fn $(,$raw_args)* $(;$args)*)
};
}
macro_rules! call_operator_mut {
($py:ident, $slf: expr, $fn: ident $(,$raw_args: expr)* $(; $args: expr)*) => {
_call_impl!(op $py; $slf.try_borrow_mut()?, $fn $(,$raw_args)* $(;$args)*)
};
}

View File

@ -169,6 +169,7 @@ pub mod buffer;
pub mod callback;
pub mod class;
pub mod conversion;
#[macro_use]
#[doc(hidden)]
pub mod derive_utils;
mod err;

View File

@ -423,3 +423,185 @@ fn rich_comparisons_python_3_type_error() {
py_expect_exception!(py, c2, "c2 >= 1", PyTypeError);
py_expect_exception!(py, c2, "1 >= c2", PyTypeError);
}
// Checks that binary operations for which the arguments don't match the
// required type, return NotImplemented.
mod return_not_implemented {
use super::*;
#[pyclass]
struct RichComparisonToSelf {}
#[pyproto]
impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf {
fn __repr__(&self) -> &'static str {
"RC_Self"
}
fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject {
other.py().None()
}
}
#[pyproto]
impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf {
fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option<u8>) -> &'p PyAny {
lhs
}
fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
lhs
}
// Inplace assignments
fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {}
fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {}
}
fn _test_binary_dunder(dunder: &str) {
let gil = Python::acquire_gil();
let py = gil.python();
let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap();
py_run!(
py,
c2,
&format!(
"class Other: pass\nassert c2.__{}__(Other()) is NotImplemented",
dunder
)
);
}
fn _test_binary_operator(operator: &str, dunder: &str) {
_test_binary_dunder(dunder);
let gil = Python::acquire_gil();
let py = gil.python();
let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap();
py_expect_exception!(
py,
c2,
&format!("class Other: pass\nc2 {} Other()", operator),
PyTypeError
)
}
fn _test_inplace_binary_operator(operator: &str, dunder: &str) {
_test_binary_operator(operator, dunder);
}
#[test]
fn equality() {
_test_binary_dunder("eq");
_test_binary_dunder("ne");
}
#[test]
fn ordering() {
_test_binary_operator("<", "lt");
_test_binary_operator("<=", "le");
_test_binary_operator(">", "gt");
_test_binary_operator(">=", "ge");
}
#[test]
fn bitwise() {
_test_binary_operator("&", "and");
_test_binary_operator("|", "or");
_test_binary_operator("^", "xor");
_test_binary_operator("<<", "lshift");
_test_binary_operator(">>", "rshift");
}
#[test]
fn arith() {
_test_binary_operator("+", "add");
_test_binary_operator("-", "sub");
_test_binary_operator("*", "mul");
_test_binary_operator("@", "matmul");
_test_binary_operator("/", "truediv");
_test_binary_operator("//", "floordiv");
_test_binary_operator("%", "mod");
_test_binary_operator("**", "pow");
}
#[test]
#[ignore]
fn reverse_arith() {
_test_binary_dunder("radd");
_test_binary_dunder("rsub");
_test_binary_dunder("rmul");
_test_binary_dunder("rmatmul");
_test_binary_dunder("rtruediv");
_test_binary_dunder("rfloordiv");
_test_binary_dunder("rmod");
_test_binary_dunder("rpow");
}
#[test]
fn inplace_bitwise() {
_test_inplace_binary_operator("&=", "iand");
_test_inplace_binary_operator("|=", "ior");
_test_inplace_binary_operator("^=", "ixor");
_test_inplace_binary_operator("<<=", "ilshift");
_test_inplace_binary_operator(">>=", "irshift");
}
#[test]
fn inplace_arith() {
_test_inplace_binary_operator("+=", "iadd");
_test_inplace_binary_operator("-=", "isub");
_test_inplace_binary_operator("*=", "imul");
_test_inplace_binary_operator("@=", "imatmul");
_test_inplace_binary_operator("/=", "itruediv");
_test_inplace_binary_operator("//=", "ifloordiv");
_test_inplace_binary_operator("%=", "imod");
_test_inplace_binary_operator("**=", "ipow");
}
}