Fix `abi3` conversion of `__complex__` classes

Python classes that were not `complex` but implemented the `__complex__`
magic would have that method called via `PyComplex_AsCComplex` when
running against the full API, but the limited-API version
`PyComplex_RealAsDouble` does not attempt this conversion.  If the input
object is not already complex, we can call the magic before proceeding.
This commit is contained in:
Jake Lishman 2023-05-26 13:49:05 +01:00
parent 194d0c791f
commit 8d98b4248e
No known key found for this signature in database
GPG Key ID: F111E77FA4F6AF0D
2 changed files with 141 additions and 0 deletions

View File

@ -0,0 +1 @@
Fix conversion of classes implementing `__complex__` to `Complex` when using `abi3` or PyPy.

View File

@ -152,6 +152,18 @@ macro_rules! complex_conversion {
#[cfg(any(Py_LIMITED_API, PyPy))]
unsafe {
let obj = if obj.is_instance_of::<PyComplex>() {
obj
} else if let Some(method) =
obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
{
method.call0()?
} else {
// `obj` might still implement `__float__` or `__index__`, which will be
// handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any
// errors if those methods don't exist / raise exceptions.
obj
};
let ptr = obj.as_ptr();
let real = ffi::PyComplex_RealAsDouble(ptr);
if real == -1.0 {
@ -172,6 +184,7 @@ complex_conversion!(f64);
#[cfg(test)]
mod tests {
use super::*;
use crate::types::PyModule;
#[test]
fn from_complex() {
@ -197,4 +210,131 @@ mod tests {
assert!(obj.extract::<Complex<f64>>(py).is_err());
});
}
#[test]
fn from_python_magic() {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class A:
def __complex__(self): return 3.0+1.2j
class B:
def __float__(self): return 3.0
class C:
def __index__(self): return 3
"#,
"test.py",
"test",
)
.unwrap();
let from_complex = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
from_complex.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
let from_float = module.getattr("B").unwrap().call0().unwrap();
assert_eq!(
from_float.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
// Before Python 3.8, `__index__` wasn't tried by `float`/`complex`.
#[cfg(Py_3_8)]
{
let from_index = module.getattr("C").unwrap().call0().unwrap();
assert_eq!(
from_index.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
}
})
}
#[test]
fn from_python_inherited_magic() {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class First: pass
class ComplexMixin:
def __complex__(self): return 3.0+1.2j
class FloatMixin:
def __float__(self): return 3.0
class IndexMixin:
def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
"#,
"test.py",
"test",
)
.unwrap();
let from_complex = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
from_complex.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
let from_float = module.getattr("B").unwrap().call0().unwrap();
assert_eq!(
from_float.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
#[cfg(Py_3_8)]
{
let from_index = module.getattr("C").unwrap().call0().unwrap();
assert_eq!(
from_index.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
}
})
}
#[test]
fn from_python_noncallable_descriptor_magic() {
// Functions and lambdas implement the descriptor protocol in a way that makes
// `type(inst).attr(inst)` equivalent to `inst.attr()` for methods, but this isn't the only
// way the descriptor protocol might be implemented.
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class A:
@property
def __complex__(self):
return lambda: 3.0+1.2j
"#,
"test.py",
"test",
)
.unwrap();
let obj = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
obj.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
})
}
#[test]
fn from_python_nondescriptor_magic() {
// Magic methods don't need to implement the descriptor protocol, if they're callable.
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class MyComplex:
def __call__(self): return 3.0+1.2j
class A:
__complex__ = MyComplex()
"#,
"test.py",
"test",
)
.unwrap();
let obj = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
obj.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
})
}
}