Guard against PyUnicode_AsUTF8AndSize returning null

This commit is contained in:
Alexander Niederbühl 2019-10-23 00:38:13 +02:00
parent 45eb9f4b89
commit 7a4909bdc7
2 changed files with 56 additions and 17 deletions

View File

@ -7,8 +7,8 @@ use crate::instance::PyNativeType;
use crate::object::PyObject;
use crate::types::PyAny;
use crate::AsPyPointer;
use crate::IntoPy;
use crate::Python;
use crate::{exceptions, IntoPy};
use crate::{ffi, FromPy};
use std::borrow::Cow;
use std::ops::Index;
@ -59,29 +59,27 @@ impl PyString {
}
/// Get the Python string as a byte slice.
///
/// Returns a `UnicodeEncodeError` if the input is not valid unicode
/// (containing unpaired surrogates).
#[inline]
pub fn as_bytes(&self) -> &[u8] {
pub fn as_bytes(&self) -> PyResult<&[u8]> {
unsafe {
let mut size: ffi::Py_ssize_t = 0;
let data = ffi::PyUnicode_AsUTF8AndSize(self.0.as_ptr(), &mut size) as *const u8;
// PyUnicode_AsUTF8AndSize would return null if the pointer did not reference a valid
// unicode object, but because we have a valid PyString, assume success
debug_assert!(!data.is_null());
std::slice::from_raw_parts(data, size as usize)
if data.is_null() {
Err(PyErr::fetch(self.py()))
} else {
Ok(std::slice::from_raw_parts(data, size as usize))
}
}
}
/// Convert the `PyString` into a Rust string.
///
/// Returns a `UnicodeDecodeError` if the input is not valid unicode
/// (containing unpaired surrogates).
pub fn to_string(&self) -> PyResult<Cow<str>> {
match std::str::from_utf8(self.as_bytes()) {
Ok(s) => Ok(Cow::Borrowed(s)),
Err(e) => Err(PyErr::from_instance(
exceptions::UnicodeDecodeError::new_utf8(self.py(), self.as_bytes(), e)?,
)),
}
let bytes = self.as_bytes()?;
let string = std::str::from_utf8(bytes)?;
Ok(Cow::Borrowed(string))
}
/// Convert the `PyString` into a Rust string.
@ -89,7 +87,10 @@ impl PyString {
/// Unpaired surrogates invalid UTF-8 sequences are
/// replaced with U+FFFD REPLACEMENT CHARACTER.
pub fn to_string_lossy(&self) -> Cow<str> {
String::from_utf8_lossy(self.as_bytes())
// TODO: Handle error of `as_bytes`
// see https://github.com/PyO3/pyo3/pull/634
let bytes = self.as_bytes().unwrap();
String::from_utf8_lossy(bytes)
}
}
@ -273,7 +274,16 @@ mod test {
let s = "ascii 🐈";
let obj: PyObject = PyString::new(py, s).into();
let py_string = <PyString as PyTryFrom>::try_from(obj.as_ref(py)).unwrap();
assert_eq!(s.as_bytes(), py_string.as_bytes());
assert_eq!(s.as_bytes(), py_string.as_bytes().unwrap());
}
#[test]
fn test_as_bytes_surrogate() {
let gil = Python::acquire_gil();
let py = gil.python();
let obj: PyObject = py.eval(r#"'\ud800'"#, None, None).unwrap().into();
let py_string = <PyString as PyTryFrom>::try_from(obj.as_ref(py)).unwrap();
assert!(py_string.as_bytes().is_err());
}
#[test]

29
tests/test_string.rs Normal file
View File

@ -0,0 +1,29 @@
use pyo3::prelude::*;
use pyo3::py_run;
use pyo3::wrap_pyfunction;
mod common;
#[pyfunction]
fn take_str(_s: &str) -> PyResult<()> {
Ok(())
}
#[test]
fn test_unicode_encode_error() {
let gil = Python::acquire_gil();
let py = gil.python();
let take_str = wrap_pyfunction!(take_str)(py);
py_run!(
py,
take_str,
r#"
try:
take_str('\ud800')
except UnicodeEncodeError as e:
error_msg = "'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed"
assert str(e) == error_msg
"#
);
}