Add several missing wrappers to PyAnyMethods (#4264)

This commit is contained in:
WÁNG Xuěruì 2024-06-20 16:16:06 +08:00 committed by GitHub
parent 0e142f05dd
commit e6b2216b04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 97 additions and 0 deletions

View File

@ -0,0 +1 @@
Added `PyAnyMethods::{bitnot, matmul, floor_div, rem, divmod}` for completeness.

View File

@ -1143,6 +1143,9 @@ pub trait PyAnyMethods<'py>: crate::sealed::Sealed {
/// Equivalent to the Python expression `abs(self)`.
fn abs(&self) -> PyResult<Bound<'py, PyAny>>;
/// Computes `~self`.
fn bitnot(&self) -> PyResult<Bound<'py, PyAny>>;
/// Tests whether this object is less than another.
///
/// This is equivalent to the Python expression `self < other`.
@ -1200,11 +1203,31 @@ pub trait PyAnyMethods<'py>: crate::sealed::Sealed {
where
O: ToPyObject;
/// Computes `self @ other`.
fn matmul<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self / other`.
fn div<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self // other`.
fn floor_div<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self % other`.
fn rem<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `divmod(self, other)`.
fn divmod<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self << other`.
fn lshift<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
@ -1898,6 +1921,14 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
inner(self)
}
fn bitnot(&self) -> PyResult<Bound<'py, PyAny>> {
fn inner<'py>(any: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
unsafe { ffi::PyNumber_Invert(any.as_ptr()).assume_owned_or_err(any.py()) }
}
inner(self)
}
fn lt<O>(&self, other: O) -> PyResult<bool>
where
O: ToPyObject,
@ -1949,13 +1980,34 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
implement_binop!(add, PyNumber_Add, "+");
implement_binop!(sub, PyNumber_Subtract, "-");
implement_binop!(mul, PyNumber_Multiply, "*");
implement_binop!(matmul, PyNumber_MatrixMultiply, "@");
implement_binop!(div, PyNumber_TrueDivide, "/");
implement_binop!(floor_div, PyNumber_FloorDivide, "//");
implement_binop!(rem, PyNumber_Remainder, "%");
implement_binop!(lshift, PyNumber_Lshift, "<<");
implement_binop!(rshift, PyNumber_Rshift, ">>");
implement_binop!(bitand, PyNumber_And, "&");
implement_binop!(bitor, PyNumber_Or, "|");
implement_binop!(bitxor, PyNumber_Xor, "^");
/// Computes `divmod(self, other)`.
fn divmod<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject,
{
fn inner<'py>(
any: &Bound<'py, PyAny>,
other: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
unsafe {
ffi::PyNumber_Divmod(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py())
}
}
let py = self.py();
inner(self, other.to_object(py).into_bound(py))
}
/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
/// `py.None()` may be passed for the `modulus`.
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>

View File

@ -35,6 +35,10 @@ impl UnaryArithmetic {
Self::new(self.inner.abs())
}
fn __invert__(&self) -> Self {
Self::new(self.inner.recip())
}
#[pyo3(signature=(_ndigits=None))]
fn __round__(&self, _ndigits: Option<u32>) -> Self {
Self::new(self.inner.round())
@ -48,8 +52,18 @@ fn unary_arithmetic() {
py_run!(py, c, "assert repr(-c) == 'UA(-2.7)'");
py_run!(py, c, "assert repr(+c) == 'UA(2.7)'");
py_run!(py, c, "assert repr(abs(c)) == 'UA(2.7)'");
py_run!(py, c, "assert repr(~c) == 'UA(0.37037037037037035)'");
py_run!(py, c, "assert repr(round(c)) == 'UA(3)'");
py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'");
let c: Bound<'_, PyAny> = c.extract(py).unwrap();
assert_py_eq!(c.neg().unwrap().repr().unwrap().as_any(), "UA(-2.7)");
assert_py_eq!(c.pos().unwrap().repr().unwrap().as_any(), "UA(2.7)");
assert_py_eq!(c.abs().unwrap().repr().unwrap().as_any(), "UA(2.7)");
assert_py_eq!(
c.bitnot().unwrap().repr().unwrap().as_any(),
"UA(0.37037037037037035)"
);
});
}
@ -179,10 +193,26 @@ impl BinaryArithmetic {
format!("BA * {:?}", rhs)
}
fn __matmul__(&self, rhs: &Bound<'_, PyAny>) -> String {
format!("BA @ {:?}", rhs)
}
fn __truediv__(&self, rhs: &Bound<'_, PyAny>) -> String {
format!("BA / {:?}", rhs)
}
fn __floordiv__(&self, rhs: &Bound<'_, PyAny>) -> String {
format!("BA // {:?}", rhs)
}
fn __mod__(&self, rhs: &Bound<'_, PyAny>) -> String {
format!("BA % {:?}", rhs)
}
fn __divmod__(&self, rhs: &Bound<'_, PyAny>) -> String {
format!("divmod(BA, {:?})", rhs)
}
fn __lshift__(&self, rhs: &Bound<'_, PyAny>) -> String {
format!("BA << {:?}", rhs)
}
@ -217,6 +247,11 @@ fn binary_arithmetic() {
py_run!(py, c, "assert c + 1 == 'BA + 1'");
py_run!(py, c, "assert c - 1 == 'BA - 1'");
py_run!(py, c, "assert c * 1 == 'BA * 1'");
py_run!(py, c, "assert c @ 1 == 'BA @ 1'");
py_run!(py, c, "assert c / 1 == 'BA / 1'");
py_run!(py, c, "assert c // 1 == 'BA // 1'");
py_run!(py, c, "assert c % 1 == 'BA % 1'");
py_run!(py, c, "assert divmod(c, 1) == 'divmod(BA, 1)'");
py_run!(py, c, "assert c << 1 == 'BA << 1'");
py_run!(py, c, "assert c >> 1 == 'BA >> 1'");
py_run!(py, c, "assert c & 1 == 'BA & 1'");
@ -230,6 +265,11 @@ fn binary_arithmetic() {
py_expect_exception!(py, c, "1 + c", PyTypeError);
py_expect_exception!(py, c, "1 - c", PyTypeError);
py_expect_exception!(py, c, "1 * c", PyTypeError);
py_expect_exception!(py, c, "1 @ c", PyTypeError);
py_expect_exception!(py, c, "1 / c", PyTypeError);
py_expect_exception!(py, c, "1 // c", PyTypeError);
py_expect_exception!(py, c, "1 % c", PyTypeError);
py_expect_exception!(py, c, "divmod(1, c)", PyTypeError);
py_expect_exception!(py, c, "1 << c", PyTypeError);
py_expect_exception!(py, c, "1 >> c", PyTypeError);
py_expect_exception!(py, c, "1 & c", PyTypeError);
@ -243,7 +283,11 @@ fn binary_arithmetic() {
assert_py_eq!(c.add(&c).unwrap(), "BA + BA");
assert_py_eq!(c.sub(&c).unwrap(), "BA - BA");
assert_py_eq!(c.mul(&c).unwrap(), "BA * BA");
assert_py_eq!(c.matmul(&c).unwrap(), "BA @ BA");
assert_py_eq!(c.div(&c).unwrap(), "BA / BA");
assert_py_eq!(c.floor_div(&c).unwrap(), "BA // BA");
assert_py_eq!(c.rem(&c).unwrap(), "BA % BA");
assert_py_eq!(c.divmod(&c).unwrap(), "divmod(BA, BA)");
assert_py_eq!(c.lshift(&c).unwrap(), "BA << BA");
assert_py_eq!(c.rshift(&c).unwrap(), "BA >> BA");
assert_py_eq!(c.bitand(&c).unwrap(), "BA & BA");