pyo3/examples/sequential/tests/test.rs

152 lines
4.1 KiB
Rust

use core::ffi::{c_char, CStr};
use core::ptr;
use std::thread;
use pyo3_ffi::*;
use sequential::PyInit_sequential;
/// Python source executed by every sub-interpreter.
///
/// It sums twelve `Id`s from the `sequential` extension module into the global `s`,
/// which `run_code` reads back afterwards. The trailing `\0` is there because the text
/// is handed to `Py_CompileString` as a NUL-terminated C string.
static COMMAND: &str = "
from sequential import Id
s = sum(int(Id()) for _ in range(12))
\0";
/// Wrapper that lets a raw `PyThreadState` pointer cross a thread boundary.
///
/// Raw pointers implement neither `Send` nor `Sync`. Wrapping the pointer in this
/// newtype lets `lets_go_fast` move each sub-interpreter's thread state into the
/// worker thread that runs it.
struct State(*mut PyThreadState);

// SAFETY: each `State` is moved into exactly one worker thread, and only that thread
// hands the pointer to the Python C API (`PyEval_RestoreThread`/`Py_EndInterpreter`).
unsafe impl Send for State {}

// SAFETY: no `&State` is shared between threads in this file.
// NOTE(review): only `Send` is required by `thread::Scope::spawn` here; this impl
// looks unnecessary — confirm nothing else relies on it before removing.
unsafe impl Sync for State {}
/// Runs `COMMAND` in 12 sub-interpreters, one per scoped OS thread, and checks that the
/// 12 × 12 `Id` values they summed add up to the sum of `0..144`.
/// (Presumably `Id` hands out consecutive integers shared across all interpreters —
/// confirm against the `sequential` crate.)
///
/// Every sub-interpreter is created with `gil: PyInterpreterConfig_OWN_GIL`.
///
/// # Errors
///
/// Returns `Err` if the `sequential` module cannot be added to the inittab, a
/// sub-interpreter cannot be created, `Py_FinalizeEx` reports failure, or any
/// worker's `run_code` returned an error.
#[test]
fn lets_go_fast() -> Result<(), String> {
    unsafe {
        // Register the extension module (before `Py_Initialize`) so that
        // `from sequential import Id` resolves in every interpreter.
        let ret = PyImport_AppendInittab(
            "sequential\0".as_ptr().cast::<c_char>(),
            Some(PyInit_sequential),
        );
        if ret == -1 {
            return Err("could not add module to inittab".into());
        }
        Py_Initialize();
        // Detach the main thread state; it is kept so it can be made current again below.
        let main_state = PyThreadState_Swap(ptr::null_mut());
        // `State` is not `Copy`, so a `const` item is needed for the `[NULL; 12]`
        // array-repeat expression.
        const NULL: State = State(ptr::null_mut());
        let mut subs = [NULL; 12];
        // Per-interpreter settings: separate obmalloc state, fork/exec/daemon threads
        // disallowed, threads allowed, multi-interpreter extension check enabled, own GIL.
        let config = PyInterpreterConfig {
            use_main_obmalloc: 0,
            allow_fork: 0,
            allow_exec: 0,
            allow_threads: 1,
            allow_daemon_threads: 0,
            check_multi_interp_extensions: 1,
            gil: PyInterpreterConfig_OWN_GIL,
        };
        // Matching `State(state)` through `&mut` binds `state` as `&mut *mut PyThreadState`,
        // the out-parameter that receives each new interpreter's thread state.
        for State(state) in &mut subs {
            let status = Py_NewInterpreterFromConfig(state, &config);
            if PyStatus_IsError(status) == 1 {
                let msg = if status.err_msg.is_null() {
                    "no error message".into()
                } else {
                    CStr::from_ptr(status.err_msg).to_string_lossy()
                };
                // NOTE(review): sub-interpreters created in earlier iterations are not
                // ended before finalizing on this path — confirm `Py_FinalizeEx`
                // tolerates leftover sub-interpreters on the targeted CPython versions.
                PyThreadState_Swap(main_state);
                Py_FinalizeEx();
                return Err(format!("could not create new subinterpreter: {msg}"));
            }
        }
        // Make the main interpreter's thread state current again.
        PyThreadState_Swap(main_state);
        // Release the main thread state (and its GIL) while the workers run; the
        // returned state is restored after the scope ends.
        let main_state = PyEval_SaveThread(); // a PyInterpreterConfig with shared gil would deadlock otherwise
        // Scoped threads: each `State` is moved into its own worker, and all workers are
        // joined before `thread::scope` returns.
        let ints: Vec<_> = thread::scope(move |s| {
            let mut handles = vec![];
            for state in subs {
                let handle = s.spawn(move || {
                    // Rebinding makes the closure capture the whole `State` (which is `Send`).
                    // Under edition-2021 disjoint captures, using only `state.0` would capture
                    // just the raw pointer, which is not `Send`.
                    let state = state;
                    // Make this sub-interpreter's thread state current on this thread.
                    PyEval_RestoreThread(state.0);
                    let ret = run_code();
                    // Tear the sub-interpreter down on the thread that ran it.
                    Py_EndInterpreter(state.0);
                    ret
                });
                handles.push(handle);
            }
            // A panicking worker panics the test here.
            handles.into_iter().map(|h| h.join().unwrap()).collect()
        });
        PyEval_RestoreThread(main_state);
        let ret = Py_FinalizeEx();
        if ret == -1 {
            return Err("could not finalize interpreter".into());
        }
        // Worker results are inspected only after finalization, so a failing worker
        // does not skip `Py_FinalizeEx`.
        let mut sum: u64 = 0;
        for i in ints {
            let i = i?;
            sum += i;
        }
        // 12 interpreters × 12 `Id()` calls each = 144 values, expected to be 0..144.
        assert_eq!(sum, (0..).take(12 * 12).sum());
    }
    Ok(())
}
unsafe fn fetch() -> String {
let err = PyErr_GetRaisedException();
let err_repr = PyObject_Str(err);
if !err_repr.is_null() {
let mut size = 0;
let p = PyUnicode_AsUTF8AndSize(err_repr, &mut size);
if !p.is_null() {
let s = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
p.cast::<u8>(),
size as usize,
));
let s = String::from(s);
Py_DECREF(err_repr);
return s;
}
}
String::from("could not get error")
}
fn run_code() -> Result<u64, String> {
unsafe {
let code_obj = Py_CompileString(
COMMAND.as_ptr().cast::<c_char>(),
"program\0".as_ptr().cast::<c_char>(),
Py_file_input,
);
if code_obj.is_null() {
return Err(fetch());
}
let globals = PyDict_New();
let res_ptr = PyEval_EvalCode(code_obj, globals, ptr::null_mut());
Py_DECREF(code_obj);
if res_ptr.is_null() {
return Err(fetch());
} else {
Py_DECREF(res_ptr);
}
let sum = PyDict_GetItemString(globals, "s\0".as_ptr().cast::<c_char>()); /* borrowed reference */
if sum.is_null() {
Py_DECREF(globals);
return Err("globals did not have `s`".into());
}
let int = PyLong_AsUnsignedLongLong(sum) as u64;
Py_DECREF(globals);
Ok(int)
}
}