Merge pull request #2675 from PyO3/relax-extract-array

Also relax the PySequence check when extracting fixed-sized arrays.
This commit is contained in:
David Hewitt 2022-10-13 20:57:43 +01:00 committed by GitHub
commit 61fd70c476
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 55 additions and 9 deletions

View File

@ -0,0 +1 @@
Fix regression of `impl FromPyObject for [T; N]` no longer accepting types passing `PySequence_Check`, e.g. NumPy arrays, since version 0.17.0. This the same fix that was applied `impl FromPyObject for Vec<T>` in version 0.17.1 extended to fixed-size arrays.

View File

@ -6,6 +6,11 @@ fn vec_to_vec_i32(vec: Vec<i32>) -> Vec<i32> {
vec
}
#[pyfunction]
fn array_to_array_i32(arr: [i32; 3]) -> [i32; 3] {
arr
}
#[pyfunction]
fn vec_to_vec_pystring(vec: Vec<&PyString>) -> Vec<&PyString> {
vec
@ -14,6 +19,7 @@ fn vec_to_vec_pystring(vec: Vec<&PyString>) -> Vec<&PyString> {
#[pymodule]
pub fn sequence(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(vec_to_vec_i32, m)?)?;
m.add_function(wrap_pyfunction!(array_to_array_i32, m)?)?;
m.add_function(wrap_pyfunction!(vec_to_vec_pystring, m)?)?;
Ok(())
}

View File

@ -29,3 +29,13 @@ def test_vec_from_array():
import numpy
assert sequence.vec_to_vec_i32(numpy.array([1, 2, 3])) == [1, 2, 3]
@pytest.mark.skipif(
platform.system() != "Linux" or platform.python_implementation() != "CPython",
reason="Binary NumPy wheels are not available for all platforms and Python implementations",
)
def test_rust_array_from_array():
import numpy
assert sequence.array_to_array_i32(numpy.array([1, 2, 3])) == [1, 2, 3]

View File

@ -3,9 +3,11 @@ use crate::{exceptions, PyErr};
#[cfg(min_const_generics)]
mod min_const_generics {
use super::invalid_sequence_length;
use crate::conversion::IntoPyPointer;
use crate::conversion::{AsPyPointer, IntoPyPointer};
use crate::types::PySequence;
use crate::{
ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject,
ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, PyTryFrom,
Python, ToPyObject,
};
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
@ -61,8 +63,16 @@ mod min_const_generics {
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let seq_len = seq.len()? as usize;
// Types that pass `PySequence_Check` usually implement enough of the sequence protocol
// to support this function and if not, we will only fail extraction safely.
let seq = unsafe {
if ffi::PySequence_Check(obj.as_ptr()) != 0 {
<PySequence as PyTryFrom>::try_from_unchecked(obj)
} else {
return Err(PyDowncastError::new(obj, "Sequence").into());
}
};
let seq_len = seq.len()?;
if seq_len != N {
return Err(invalid_sequence_length(N, seq_len));
}
@ -174,9 +184,11 @@ mod min_const_generics {
#[cfg(not(min_const_generics))]
mod array_impls {
use super::invalid_sequence_length;
use crate::conversion::IntoPyPointer;
use crate::conversion::{AsPyPointer, IntoPyPointer};
use crate::types::PySequence;
use crate::{
ffi, FromPyObject, IntoPy, Py, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject,
ffi, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyObject, PyResult, PyTryFrom,
Python, ToPyObject,
};
use std::mem::{transmute_copy, ManuallyDrop};
@ -274,8 +286,16 @@ mod array_impls {
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let seq_len = seq.len()? as usize;
// Types that pass `PySequence_Check` usually implement enough of the sequence protocol
// to support this function and if not, we will only fail extraction safely.
let seq = unsafe {
if ffi::PySequence_Check(obj.as_ptr()) != 0 {
<PySequence as PyTryFrom>::try_from_unchecked(obj)
} else {
return Err(PyDowncastError::new(obj, "Sequence").into());
}
};
let seq_len = seq.len()?;
if seq_len != slice.len() {
return Err(invalid_sequence_length(slice.len(), seq_len));
}
@ -348,6 +368,15 @@ mod tests {
});
}
#[test]
fn test_extract_non_iterable_to_array() {
Python::with_gil(|py| {
let v = py.eval("42", None, None).unwrap();
v.extract::<i32>().unwrap();
v.extract::<[i32; 1]>().unwrap_err();
});
}
#[cfg(feature = "macros")]
#[test]
fn test_pyclass_intopy_array_conversion() {

View File

@ -309,7 +309,7 @@ where
}
};
let mut v = Vec::with_capacity(seq.len().unwrap_or(0) as usize);
let mut v = Vec::with_capacity(seq.len().unwrap_or(0));
for item in seq.iter()? {
v.push(item?.extract::<T>()?);
}