From 7ea875fc496bc73554551e87bc02ae21351ab707 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Fri, 21 Sep 2018 12:48:42 +0900 Subject: [PATCH] Implement Add/Sub/Mul/Div for &PyComplex --- Cargo.toml | 1 + src/ffi3/complexobject.rs | 30 +++++++++++ src/lib.rs | 3 ++ src/objects/complex.rs | 107 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 141 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index a1b665be..1d7ec8b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ pyo3cls = { path = "pyo3cls", version = "=0.5.0-alpha.1" } mashup = "0.1.7" [dev-dependencies] +assert_approx_eq = "1.0" docmatic = "0.1.2" [build-dependencies] diff --git a/src/ffi3/complexobject.rs b/src/ffi3/complexobject.rs index 8b3ff0f5..3f817556 100644 --- a/src/ffi3/complexobject.rs +++ b/src/ffi3/complexobject.rs @@ -22,3 +22,33 @@ extern "C" { pub fn PyComplex_RealAsDouble(op: *mut PyObject) -> c_double; pub fn PyComplex_ImagAsDouble(op: *mut PyObject) -> c_double; } + +#[cfg(not(Py_LIMITED_API))] +#[repr(C)] +#[derive(Copy, Clone)] +pub struct Py_complex { + pub real: c_double, + pub imag: c_double, +} + +#[cfg(not(Py_LIMITED_API))] +#[repr(C)] +#[derive(Copy, Clone)] +pub struct PyComplexObject { + _ob_base: PyObject, + pub cval: Py_complex, +} + +#[cfg(not(Py_LIMITED_API))] +#[cfg_attr(windows, link(name = "pythonXY"))] +extern "C" { + pub fn _Py_c_sum(left: Py_complex, right: Py_complex) -> Py_complex; + pub fn _Py_c_diff(left: Py_complex, right: Py_complex) -> Py_complex; + pub fn _Py_c_neg(complex: Py_complex) -> Py_complex; + pub fn _Py_c_prod(left: Py_complex, right: Py_complex) -> Py_complex; + pub fn _Py_c_quot(dividend: Py_complex, divisor: Py_complex) -> Py_complex; + pub fn _Py_c_pow(num: Py_complex, exp: Py_complex) -> Py_complex; + pub fn _Py_c_abs(arg: Py_complex) -> c_double; + pub fn PyComplex_FromCComplex(v: Py_complex) -> *mut PyObject; + pub fn PyComplex_AsCComplex(op: *mut PyObject) -> Py_complex; +} diff --git a/src/lib.rs b/src/lib.rs index 20d02832..271bf2a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -128,6 +128,9 @@ extern crate spin; // We need that reexport for wrap_function #[doc(hidden)] pub extern crate mashup; +#[cfg(test)] +#[macro_use] +extern crate assert_approx_eq; /// Rust FFI declarations for Python pub mod ffi; diff --git a/src/objects/complex.rs b/src/objects/complex.rs index 49cf4466..cd0edc49 100644 --- a/src/objects/complex.rs +++ b/src/objects/complex.rs @@ -1,6 +1,9 @@ use ffi; +use instance::PyObjectWithToken; use object::PyObject; use python::{Python, ToPyPointer}; +#[cfg(any(not(Py_LIMITED_API), not(Py_3)))] +use std::ops::*; use std::os::raw::c_double; /// Represents a Python `complex`. @@ -24,6 +27,62 @@ impl PyComplex { } } +#[cfg(any(not(Py_LIMITED_API), not(Py_3)))] +#[inline(always)] +unsafe fn complex_operation( + l: &PyComplex, + r: &PyComplex, + operation: unsafe extern "C" fn(ffi::Py_complex, ffi::Py_complex) -> ffi::Py_complex, +) -> *mut ffi::PyObject { + let l_val = (*(l.as_ptr() as *mut ffi::PyComplexObject)).cval; + let r_val = (*(r.as_ptr() as *mut ffi::PyComplexObject)).cval; + ffi::PyComplex_FromCComplex(operation(l_val, r_val)) +} + +#[cfg(any(not(Py_LIMITED_API), not(Py_3)))] +impl<'py> Add for &'py PyComplex { + type Output = &'py PyComplex; + fn add(self, other: &'py PyComplex) -> &'py PyComplex { + unsafe { + self.py() + .from_owned_ptr(complex_operation(self, other, ffi::_Py_c_sum)) + } + } +} + +#[cfg(any(not(Py_LIMITED_API), not(Py_3)))] +impl<'py> Sub for &'py PyComplex { + type Output = &'py PyComplex; + fn sub(self, other: &'py PyComplex) -> &'py PyComplex { + unsafe { + self.py() + .from_owned_ptr(complex_operation(self, other, ffi::_Py_c_diff)) + } + } +} + +#[cfg(any(not(Py_LIMITED_API), not(Py_3)))] +impl<'py> Mul for &'py PyComplex { + type Output = &'py PyComplex; + fn mul(self, other: &'py PyComplex) -> &'py PyComplex { + unsafe { + self.py() + .from_owned_ptr(complex_operation(self, other, ffi::_Py_c_prod)) + } + } +} + +#[cfg(any(not(Py_LIMITED_API), not(Py_3)))] +impl<'py> Div for &'py PyComplex { + type Output = &'py PyComplex; + fn div(self, other: &'py PyComplex) -> &'py PyComplex { + unsafe { + self.py() + .from_owned_ptr(complex_operation(self, other, ffi::_Py_c_quot)) + } + } +} + #[cfg(test)] mod test { use super::PyComplex; @@ -36,4 +95,52 @@ mod test { assert_eq!(complex.real(), 3.0); assert_eq!(complex.imag(), 1.2); } + + #[cfg(any(not(Py_LIMITED_API), not(Py_3)))] + #[test] + fn test_add() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let l = PyComplex::from_doubles(py, 3.0, 1.2); + let r = PyComplex::from_doubles(py, 1.0, 2.6); + let res = l + r; + assert_approx_eq!(res.real(), 4.0); + assert_approx_eq!(res.imag(), 3.8); + } + + #[cfg(any(not(Py_LIMITED_API), not(Py_3)))] + #[test] + fn test_sub() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let l = PyComplex::from_doubles(py, 3.0, 1.2); + let r = PyComplex::from_doubles(py, 1.0, 2.6); + let res = l - r; + assert_approx_eq!(res.real(), 2.0); + assert_approx_eq!(res.imag(), -1.4); + } + + #[cfg(any(not(Py_LIMITED_API), not(Py_3)))] + #[test] + fn test_mul() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let l = PyComplex::from_doubles(py, 3.0, 1.2); + let r = PyComplex::from_doubles(py, 1.0, 2.6); + let res = l * r; + assert_approx_eq!(res.real(), -0.12); + assert_approx_eq!(res.imag(), 9.0); + } + + #[cfg(any(not(Py_LIMITED_API), not(Py_3)))] + #[test] + fn test_div() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let l = PyComplex::from_doubles(py, 3.0, 1.2); + let r = PyComplex::from_doubles(py, 1.0, 2.6); + let res = l / r; + assert_approx_eq!(res.real(), 0.7886597938144329); + assert_approx_eq!(res.imag(), -0.8505154639175257); + } }