further refactor num-bigint conversion
This commit is contained in:
parent
d1f0561036
commit
9604957c72
|
@ -51,54 +51,10 @@ use crate::{
|
||||||
ffi, types::*, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
|
ffi, types::*, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
|
||||||
};
|
};
|
||||||
|
|
||||||
use num_bigint::{BigInt, BigUint, Sign};
|
use num_bigint::{BigInt, BigUint};
|
||||||
use std::os::raw::c_int;
|
|
||||||
|
|
||||||
#[cfg(not(Py_LIMITED_API))]
|
#[cfg(not(Py_LIMITED_API))]
|
||||||
use std::os::raw::c_uchar;
|
use num_bigint::Sign;
|
||||||
|
|
||||||
#[cfg(Py_LIMITED_API)]
|
|
||||||
use std::slice;
|
|
||||||
|
|
||||||
#[cfg(not(Py_LIMITED_API))]
|
|
||||||
#[inline]
|
|
||||||
unsafe fn extract(ob: &PyLong, length: usize, is_signed: c_int) -> PyResult<Vec<u32>> {
|
|
||||||
let mut buffer = Vec::<u32>::with_capacity(length);
|
|
||||||
crate::err::error_on_minusone(
|
|
||||||
ob.py(),
|
|
||||||
ffi::_PyLong_AsByteArray(
|
|
||||||
ob.as_ptr() as *mut ffi::PyLongObject,
|
|
||||||
buffer.as_mut_ptr() as *mut u8,
|
|
||||||
length * 4,
|
|
||||||
1,
|
|
||||||
is_signed,
|
|
||||||
),
|
|
||||||
)?;
|
|
||||||
buffer.set_len(length);
|
|
||||||
|
|
||||||
Ok(buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(Py_LIMITED_API)]
|
|
||||||
#[inline]
|
|
||||||
unsafe fn extract(ob: &PyLong, length: usize, is_signed: c_int) -> PyResult<Vec<u32>> {
|
|
||||||
use crate::intern;
|
|
||||||
let py = ob.py();
|
|
||||||
let kwargs = if is_signed != 0 {
|
|
||||||
let kwargs = PyDict::new(py);
|
|
||||||
kwargs.set_item(intern!(py, "signed"), true)?;
|
|
||||||
Some(kwargs)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let bytes_obj = ob
|
|
||||||
.getattr(intern!(py, "to_bytes"))?
|
|
||||||
.call((length * 4, intern!(py, "little")), kwargs)?;
|
|
||||||
let bytes: &PyBytes = bytes_obj.downcast_unchecked();
|
|
||||||
let bytes_u32 = slice::from_raw_parts(bytes.as_bytes().as_ptr().cast(), length);
|
|
||||||
|
|
||||||
Ok(bytes_u32.to_vec())
|
|
||||||
}
|
|
||||||
|
|
||||||
// for identical functionality between BigInt and BigUint
|
// for identical functionality between BigInt and BigUint
|
||||||
macro_rules! bigint_conversion {
|
macro_rules! bigint_conversion {
|
||||||
|
@ -110,7 +66,7 @@ macro_rules! bigint_conversion {
|
||||||
let bytes = $to_bytes(self);
|
let bytes = $to_bytes(self);
|
||||||
unsafe {
|
unsafe {
|
||||||
let obj = ffi::_PyLong_FromByteArray(
|
let obj = ffi::_PyLong_FromByteArray(
|
||||||
bytes.as_ptr() as *const c_uchar,
|
bytes.as_ptr().cast(),
|
||||||
bytes.len(),
|
bytes.len(),
|
||||||
1,
|
1,
|
||||||
$is_signed,
|
$is_signed,
|
||||||
|
@ -153,41 +109,46 @@ bigint_conversion!(BigInt, 1, BigInt::to_signed_bytes_le);
|
||||||
impl<'source> FromPyObject<'source> for BigInt {
|
impl<'source> FromPyObject<'source> for BigInt {
|
||||||
fn extract(ob: &'source PyAny) -> PyResult<BigInt> {
|
fn extract(ob: &'source PyAny) -> PyResult<BigInt> {
|
||||||
let py = ob.py();
|
let py = ob.py();
|
||||||
unsafe {
|
// fast path - checking for subclass of `int` just checks a bit in the type object
|
||||||
let num: Py<PyLong> = Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))?;
|
let num_owned: Py<PyLong>;
|
||||||
let n_bits = {
|
let num = if let Ok(long) = ob.downcast::<PyLong>() {
|
||||||
cfg_if::cfg_if! {
|
long
|
||||||
if #[cfg(not(Py_LIMITED_API))] {
|
} else {
|
||||||
// fast path
|
num_owned = unsafe { Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))? };
|
||||||
let n_bits = ffi::_PyLong_NumBits(num.as_ptr());
|
num_owned.as_ref(py)
|
||||||
if n_bits == (-1isize as usize) {
|
};
|
||||||
return Err(crate::PyErr::fetch(py));
|
let n_bits = int_n_bits(num)?;
|
||||||
}
|
if n_bits == 0 {
|
||||||
n_bits
|
return Ok(BigInt::from(0isize));
|
||||||
} else {
|
}
|
||||||
// slow path
|
#[cfg(not(Py_LIMITED_API))]
|
||||||
let n_bits_obj = num.getattr(py, crate::intern!(py, "bit_length"))?.call0(py)?;
|
{
|
||||||
let n_bits_int: &PyLong = n_bits_obj.downcast_unchecked(py);
|
let mut buffer = int_to_u32_vec(num, (n_bits + 32) / 32, true)?;
|
||||||
n_bits_int.extract::<usize>()?
|
let sign = if buffer.last().copied().map_or(false, |last| last >> 31 != 0) {
|
||||||
|
// BigInt::new takes an unsigned array, so need to convert from two's complement
|
||||||
|
// flip all bits, 'subtract' 1 (by adding one to the unsigned array)
|
||||||
|
let mut elements = buffer.iter_mut();
|
||||||
|
for element in elements.by_ref() {
|
||||||
|
*element = (!*element).wrapping_add(1);
|
||||||
|
if *element != 0 {
|
||||||
|
// if the element didn't wrap over, no need to keep adding further ...
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
// ... so just two's complement the rest
|
||||||
|
for element in elements {
|
||||||
if n_bits == 0 {
|
*element = !*element;
|
||||||
return Ok(BigInt::from(0isize));
|
}
|
||||||
}
|
Sign::Minus
|
||||||
let n_digits = (n_bits + 32) / 32;
|
|
||||||
let mut buffer = extract(num.as_ref(py), n_digits, 1)?;
|
|
||||||
buffer
|
|
||||||
.iter_mut()
|
|
||||||
.for_each(|chunk| *chunk = u32::from_le(*chunk));
|
|
||||||
|
|
||||||
Ok(if buffer.last().unwrap() >> 31 != 0 {
|
|
||||||
buffer.iter_mut().for_each(|element| *element = !*element);
|
|
||||||
BigInt::new(Sign::Minus, buffer) - 1
|
|
||||||
} else {
|
} else {
|
||||||
BigInt::new(Sign::Plus, buffer)
|
Sign::Plus
|
||||||
})
|
};
|
||||||
|
Ok(BigInt::new(sign, buffer))
|
||||||
|
}
|
||||||
|
#[cfg(Py_LIMITED_API)]
|
||||||
|
{
|
||||||
|
let bytes = int_to_py_bytes(num, (n_bits + 8) / 8, true)?;
|
||||||
|
Ok(BigInt::from_signed_bytes_le(bytes.as_bytes()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,37 +157,92 @@ impl<'source> FromPyObject<'source> for BigInt {
|
||||||
impl<'source> FromPyObject<'source> for BigUint {
|
impl<'source> FromPyObject<'source> for BigUint {
|
||||||
fn extract(ob: &'source PyAny) -> PyResult<BigUint> {
|
fn extract(ob: &'source PyAny) -> PyResult<BigUint> {
|
||||||
let py = ob.py();
|
let py = ob.py();
|
||||||
unsafe {
|
// fast path - checking for subclass of `int` just checks a bit in the type object
|
||||||
let num: Py<PyLong> = Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))?;
|
let num_owned: Py<PyLong>;
|
||||||
let n_bits = {
|
let num = if let Ok(long) = ob.downcast::<PyLong>() {
|
||||||
cfg_if::cfg_if! {
|
long
|
||||||
if #[cfg(not(Py_LIMITED_API))] {
|
} else {
|
||||||
// fast path
|
num_owned = unsafe { Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))? };
|
||||||
let n_bits = ffi::_PyLong_NumBits(num.as_ptr());
|
num_owned.as_ref(py)
|
||||||
if n_bits == (-1isize as usize) {
|
};
|
||||||
return Err(crate::PyErr::fetch(py));
|
let n_bits = int_n_bits(num)?;
|
||||||
}
|
if n_bits == 0 {
|
||||||
n_bits
|
return Ok(BigUint::from(0usize));
|
||||||
} else {
|
}
|
||||||
// slow path
|
#[cfg(not(Py_LIMITED_API))]
|
||||||
let n_bits_obj = num.getattr(py, crate::intern!(py, "bit_length"))?.call0(py)?;
|
{
|
||||||
let n_bits_int: &PyLong = n_bits_obj.downcast_unchecked(py);
|
let buffer = int_to_u32_vec(num, (n_bits + 31) / 32, false)?;
|
||||||
n_bits_int.extract::<usize>()?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if n_bits == 0 {
|
|
||||||
return Ok(BigUint::from(0usize));
|
|
||||||
}
|
|
||||||
let n_digits = (n_bits + 31) / 32;
|
|
||||||
let mut buffer = extract(num.as_ref(py), n_digits, 0)?;
|
|
||||||
buffer
|
|
||||||
.iter_mut()
|
|
||||||
.for_each(|chunk| *chunk = u32::from_le(*chunk));
|
|
||||||
|
|
||||||
Ok(BigUint::new(buffer))
|
Ok(BigUint::new(buffer))
|
||||||
}
|
}
|
||||||
|
#[cfg(Py_LIMITED_API)]
|
||||||
|
{
|
||||||
|
let bytes = int_to_py_bytes(num, (n_bits + 7) / 8, false)?;
|
||||||
|
Ok(BigUint::from_bytes_le(bytes.as_bytes()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(Py_LIMITED_API))]
|
||||||
|
#[inline]
|
||||||
|
fn int_to_u32_vec(long: &PyLong, n_digits: usize, is_signed: bool) -> PyResult<Vec<u32>> {
|
||||||
|
let mut buffer = Vec::with_capacity(n_digits);
|
||||||
|
unsafe {
|
||||||
|
crate::err::error_on_minusone(
|
||||||
|
long.py(),
|
||||||
|
ffi::_PyLong_AsByteArray(
|
||||||
|
long.as_ptr().cast(),
|
||||||
|
buffer.as_mut_ptr() as *mut u8,
|
||||||
|
n_digits * 4,
|
||||||
|
1,
|
||||||
|
is_signed.into(),
|
||||||
|
),
|
||||||
|
)?;
|
||||||
|
buffer.set_len(n_digits)
|
||||||
|
};
|
||||||
|
buffer
|
||||||
|
.iter_mut()
|
||||||
|
.for_each(|chunk| *chunk = u32::from_le(*chunk));
|
||||||
|
|
||||||
|
Ok(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(Py_LIMITED_API)]
|
||||||
|
fn int_to_py_bytes(long: &PyLong, n_bytes: usize, is_signed: bool) -> PyResult<&PyBytes> {
|
||||||
|
use crate::intern;
|
||||||
|
let py = long.py();
|
||||||
|
let kwargs = if is_signed {
|
||||||
|
let kwargs = PyDict::new(py);
|
||||||
|
kwargs.set_item(intern!(py, "signed"), true)?;
|
||||||
|
Some(kwargs)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let bytes = long.call_method(
|
||||||
|
intern!(py, "to_bytes"),
|
||||||
|
(n_bytes, intern!(py, "little")),
|
||||||
|
kwargs,
|
||||||
|
)?;
|
||||||
|
Ok(bytes.downcast()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn int_n_bits(long: &PyLong) -> PyResult<usize> {
|
||||||
|
let py = long.py();
|
||||||
|
#[cfg(not(Py_LIMITED_API))]
|
||||||
|
{
|
||||||
|
// fast path
|
||||||
|
let n_bits = unsafe { ffi::_PyLong_NumBits(long.as_ptr()) };
|
||||||
|
if n_bits == (-1isize as usize) {
|
||||||
|
return Err(crate::PyErr::fetch(py));
|
||||||
|
}
|
||||||
|
Ok(n_bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(Py_LIMITED_API)]
|
||||||
|
{
|
||||||
|
// slow path
|
||||||
|
long.call_method0(crate::intern!(py, "bit_length"))
|
||||||
|
.and_then(PyAny::extract)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,82 +252,62 @@ mod tests {
|
||||||
use crate::types::{PyDict, PyModule};
|
use crate::types::{PyDict, PyModule};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
|
|
||||||
fn python_fib(py: Python<'_>) -> &PyModule {
|
fn rust_fib<T>() -> impl Iterator<Item = T>
|
||||||
let fib_code = indoc!(
|
|
||||||
r#"
|
|
||||||
def fib(n):
|
|
||||||
f0, f1 = 0, 1
|
|
||||||
for _ in range(n):
|
|
||||||
f0, f1 = f1, f0 + f1
|
|
||||||
return f0
|
|
||||||
|
|
||||||
def fib_neg(n):
|
|
||||||
return -fib(n)
|
|
||||||
"#
|
|
||||||
);
|
|
||||||
PyModule::from_code(py, fib_code, "fib.py", "fib").unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rust_fib<T>(n: usize) -> T
|
|
||||||
where
|
where
|
||||||
T: From<u16>,
|
T: From<u16>,
|
||||||
for<'a> &'a T: std::ops::Add<Output = T>,
|
for<'a> &'a T: std::ops::Add<Output = T>,
|
||||||
{
|
{
|
||||||
let mut f0: T = T::from(0);
|
let mut f0: T = T::from(1);
|
||||||
let mut f1: T = T::from(1);
|
let mut f1: T = T::from(1);
|
||||||
for _ in 0..n {
|
std::iter::from_fn(move || {
|
||||||
let f2 = &f0 + &f1;
|
let f2 = &f0 + &f1;
|
||||||
f0 = std::mem::replace(&mut f1, f2);
|
Some(std::mem::replace(&mut f0, std::mem::replace(&mut f1, f2)))
|
||||||
}
|
})
|
||||||
f0
|
}
|
||||||
|
|
||||||
|
fn python_fib(py: Python<'_>) -> impl Iterator<Item = PyObject> + '_ {
|
||||||
|
let mut f0 = 1.to_object(py);
|
||||||
|
let mut f1 = 1.to_object(py);
|
||||||
|
std::iter::from_fn(move || {
|
||||||
|
let f2 = f0.call_method1(py, "__add__", (f1.as_ref(py),)).unwrap();
|
||||||
|
Some(std::mem::replace(&mut f0, std::mem::replace(&mut f1, f2)))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn convert_biguint() {
|
fn convert_biguint() {
|
||||||
Python::with_gil(|py| {
|
Python::with_gil(|py| {
|
||||||
let rs_result: BigUint = rust_fib(400);
|
// check the first 2000 numbers in the fibonacci sequence
|
||||||
let fib = python_fib(py);
|
for (py_result, rs_result) in python_fib(py).zip(rust_fib::<BigUint>()).take(2000) {
|
||||||
let locals = PyDict::new(py);
|
// Python -> Rust
|
||||||
locals.set_item("rs_result", &rs_result).unwrap();
|
assert_eq!(py_result.extract::<BigUint>(py).unwrap(), rs_result);
|
||||||
locals.set_item("fib", fib).unwrap();
|
// Rust -> Python
|
||||||
// Checks if Rust BigUint -> Python Long conversion is correct
|
assert!(py_result.as_ref(py).eq(rs_result).unwrap());
|
||||||
py.run("assert fib.fib(400) == rs_result", None, Some(locals))
|
}
|
||||||
.unwrap();
|
|
||||||
// Checks if Python Long -> Rust BigUint conversion is correct if N is small
|
|
||||||
let py_result: BigUint =
|
|
||||||
FromPyObject::extract(fib.getattr("fib").unwrap().call1((400,)).unwrap()).unwrap();
|
|
||||||
assert_eq!(rs_result, py_result);
|
|
||||||
// Checks if Python Long -> Rust BigUint conversion is correct if N is large
|
|
||||||
let rs_result: BigUint = rust_fib(2000);
|
|
||||||
let py_result: BigUint =
|
|
||||||
FromPyObject::extract(fib.getattr("fib").unwrap().call1((2000,)).unwrap()).unwrap();
|
|
||||||
assert_eq!(rs_result, py_result);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn convert_bigint() {
|
fn convert_bigint() {
|
||||||
Python::with_gil(|py| {
|
Python::with_gil(|py| {
|
||||||
let rs_result = rust_fib::<BigInt>(400) * -1;
|
// check the first 2000 numbers in the fibonacci sequence
|
||||||
let fib = python_fib(py);
|
for (py_result, rs_result) in python_fib(py).zip(rust_fib::<BigInt>()).take(2000) {
|
||||||
let locals = PyDict::new(py);
|
// Python -> Rust
|
||||||
locals.set_item("rs_result", &rs_result).unwrap();
|
assert_eq!(py_result.extract::<BigInt>(py).unwrap(), rs_result);
|
||||||
locals.set_item("fib", fib).unwrap();
|
// Rust -> Python
|
||||||
// Checks if Rust BigInt -> Python Long conversion is correct
|
assert!(py_result.as_ref(py).eq(&rs_result).unwrap());
|
||||||
py.run("assert fib.fib_neg(400) == rs_result", None, Some(locals))
|
|
||||||
.unwrap();
|
// negate
|
||||||
// Checks if Python Long -> Rust BigInt conversion is correct if N is small
|
|
||||||
let py_result: BigInt =
|
let rs_result = rs_result * -1;
|
||||||
FromPyObject::extract(fib.getattr("fib_neg").unwrap().call1((400,)).unwrap())
|
let py_result = py_result.call_method0(py, "__neg__").unwrap();
|
||||||
.unwrap();
|
|
||||||
assert_eq!(rs_result, py_result);
|
// Python -> Rust
|
||||||
// Checks if Python Long -> Rust BigInt conversion is correct if N is large
|
assert_eq!(py_result.extract::<BigInt>(py).unwrap(), rs_result);
|
||||||
let rs_result = rust_fib::<BigInt>(2000) * -1;
|
// Rust -> Python
|
||||||
let py_result: BigInt =
|
assert!(py_result.as_ref(py).eq(rs_result).unwrap());
|
||||||
FromPyObject::extract(fib.getattr("fib_neg").unwrap().call1((2000,)).unwrap())
|
}
|
||||||
.unwrap();
|
});
|
||||||
assert_eq!(rs_result, py_result);
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn python_index_class(py: Python<'_>) -> &PyModule {
|
fn python_index_class(py: Python<'_>) -> &PyModule {
|
||||||
|
@ -341,9 +337,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn handle_zero() {
|
fn handle_zero() {
|
||||||
Python::with_gil(|py| {
|
Python::with_gil(|py| {
|
||||||
let fib = python_fib(py);
|
let zero: BigInt = 0.to_object(py).extract(py).unwrap();
|
||||||
let zero: BigInt =
|
|
||||||
FromPyObject::extract(fib.getattr("fib").unwrap().call1((0,)).unwrap()).unwrap();
|
|
||||||
assert_eq!(zero, BigInt::from(0));
|
assert_eq!(zero, BigInt::from(0));
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue