add PyAnyMethods for binary operators

also pow

fixes #3709
This commit is contained in:
Alex Gaynor 2023-12-29 16:09:30 -05:00
parent 6776b90e15
commit 339660c117
4 changed files with 132 additions and 0 deletions

View File

@ -0,0 +1 @@
Added methods to `PyAnyMethods` for binary operators (`add`, `sub`, etc.)

View File

@ -23,6 +23,13 @@ mod inner {
};
}
#[macro_export]
macro_rules! assert_py_eq {
($val:expr, $expected:expr) => {
assert!($val.eq($expected).unwrap());
};
}
#[macro_export]
macro_rules! py_expect_exception {
// Case1: idents & no err_msg

View File

@ -1208,6 +1208,58 @@ pub trait PyAnyMethods<'py> {
where
O: ToPyObject;
/// Computes `self + other`.
fn add<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self - other`.
fn sub<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self * other`.
fn mul<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self / other`.
fn div<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self << other`.
fn lshift<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self >> other`.
fn rshift<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
/// `py.None()` may be passed for the `modulus`.
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
where
O1: ToPyObject,
O2: ToPyObject;
/// Computes `self & other`.
fn bitand<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self | other`.
fn bitor<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Computes `self ^ other`.
fn bitxor<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject;
/// Determines whether this object appears callable.
///
/// This is equivalent to Python's [`callable()`][1] function.
@ -1680,6 +1732,26 @@ pub trait PyAnyMethods<'py> {
fn py_super(&self) -> PyResult<Bound<'py, PySuper>>;
}
macro_rules! implement_binop {
($name:ident, $c_api:ident, $op:expr) => {
#[doc = concat!("Computes `self ", $op, " other`.")]
fn $name<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
where
O: ToPyObject,
{
fn inner<'py>(
any: &Bound<'py, PyAny>,
other: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
unsafe { ffi::$c_api(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) }
}
let py = self.py();
inner(self, other.to_object(py).into_bound(py))
}
};
}
impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
#[inline]
fn is<T: AsPyPointer>(&self, other: &T) -> bool {
@ -1855,6 +1927,42 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
.and_then(|any| any.is_truthy())
}
implement_binop!(add, PyNumber_Add, "+");
implement_binop!(sub, PyNumber_Subtract, "-");
implement_binop!(mul, PyNumber_Multiply, "*");
implement_binop!(div, PyNumber_TrueDivide, "/");
implement_binop!(lshift, PyNumber_Lshift, "<<");
implement_binop!(rshift, PyNumber_Rshift, ">>");
implement_binop!(bitand, PyNumber_And, "&");
implement_binop!(bitor, PyNumber_Or, "|");
implement_binop!(bitxor, PyNumber_Xor, "^");
/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
/// `py.None()` may be passed for the `modulus`.
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
where
O1: ToPyObject,
O2: ToPyObject,
{
fn inner<'py>(
any: &Bound<'py, PyAny>,
other: Bound<'_, PyAny>,
modulus: Bound<'_, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
unsafe {
ffi::PyNumber_Power(any.as_ptr(), other.as_ptr(), modulus.as_ptr())
.assume_owned_or_err(any.py())
}
}
let py = self.py();
inner(
self,
other.to_object(py).into_bound(py),
modulus.to_object(py).into_bound(py),
)
}
fn is_callable(&self) -> bool {
unsafe { ffi::PyCallable_Check(self.as_ptr()) != 0 }
}

View File

@ -178,6 +178,10 @@ impl BinaryArithmetic {
format!("BA * {:?}", rhs)
}
fn __truediv__(&self, rhs: &PyAny) -> String {
format!("BA / {:?}", rhs)
}
fn __lshift__(&self, rhs: &PyAny) -> String {
format!("BA << {:?}", rhs)
}
@ -233,6 +237,18 @@ fn binary_arithmetic() {
py_expect_exception!(py, c, "1 ** c", PyTypeError);
py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'");
let c: Bound<'_, PyAny> = c.extract().unwrap();
assert_py_eq!(c.add(&c).unwrap(), "BA + BA");
assert_py_eq!(c.sub(&c).unwrap(), "BA - BA");
assert_py_eq!(c.mul(&c).unwrap(), "BA * BA");
assert_py_eq!(c.div(&c).unwrap(), "BA / BA");
assert_py_eq!(c.lshift(&c).unwrap(), "BA << BA");
assert_py_eq!(c.rshift(&c).unwrap(), "BA >> BA");
assert_py_eq!(c.bitand(&c).unwrap(), "BA & BA");
assert_py_eq!(c.bitor(&c).unwrap(), "BA | BA");
assert_py_eq!(c.bitxor(&c).unwrap(), "BA ^ BA");
assert_py_eq!(c.pow(&c, py.None()).unwrap(), "BA ** BA (mod: None)");
});
}