Merge pull request #3508 from mejrs/sub2

Create subinterpreter example
This commit is contained in:
David Hewitt 2023-11-26 10:42:31 +00:00 committed by GitHub
commit 0f34fcd4b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 510 additions and 0 deletions

View File

@ -11,6 +11,7 @@ Below is a brief description of each of these:
| `setuptools-rust-starter` | A template project which is configured to use [`setuptools_rust`](https://github.com/PyO3/setuptools-rust/) for development. |
| `word-count` | A quick performance comparison between word counter implementations written in each of Rust and Python. |
| `plugin` | Illustrates how to use Python as a scripting language within a Rust application |
| `sequential` | Illustrates how to use pyo3-ffi to write subinterpreter-safe modules |
## Creating new projects from these examples

View File

@ -0,0 +1,12 @@
[package]
authors = ["{{authors}}"]
name = "{{project-name}}"
version = "0.1.0"
edition = "2021"
[lib]
name = "sequential"
crate-type = ["cdylib", "lib"]
[dependencies]
pyo3-ffi = { version = "{{PYO3_VERSION}}", features = ["extension-module"] }

View File

@ -0,0 +1,4 @@
variable::set("PYO3_VERSION", "0.19.2");
file::rename(".template/Cargo.toml", "Cargo.toml");
file::rename(".template/pyproject.toml", "pyproject.toml");
file::delete(".template");

View File

@ -0,0 +1,7 @@
[build-system]
requires = ["maturin>=1,<2"]
build-backend = "maturin"
[project]
name = "{{project-name}}"
version = "0.1.0"

View File

@ -0,0 +1,13 @@
[package]
name = "sequential"
version = "0.1.0"
edition = "2021"
[lib]
name = "sequential"
crate-type = ["cdylib", "lib"]
[dependencies]
pyo3-ffi = { path = "../../pyo3-ffi", features = ["extension-module"] }
[workspace]

View File

@ -0,0 +1,2 @@
include pyproject.toml Cargo.toml
recursive-include src *

View File

@ -0,0 +1,36 @@
# sequential
A project built using only `pyo3_ffi`, without any of PyO3's safe api. It can be executed by subinterpreters that have their own GIL.
## Building and Testing
To build this package, first install `maturin`:
```shell
pip install maturin
```
To build and test use `maturin develop`:
```shell
pip install -r requirements-dev.txt
maturin develop
pytest
```
Alternatively, install nox and run the tests inside an isolated environment:
```shell
nox
```
## Copying this example
Use [`cargo-generate`](https://crates.io/crates/cargo-generate):
```bash
$ cargo install cargo-generate
$ cargo generate --git https://github.com/PyO3/pyo3 examples/sequential
```
(`cargo generate` will take a little while to clone the PyO3 repo first; be patient when waiting for the command to run.)

View File

@ -0,0 +1,5 @@
[template]
ignore = [".nox"]
[hooks]
pre = [".template/pre-script.rhai"]

View File

@ -0,0 +1,11 @@
import sys
import nox
@nox.session
def python(session):
if sys.version_info < (3, 12):
session.skip("Python 3.12+ is required")
session.env["MATURIN_PEP517_ARGS"] = "--profile=dev"
session.install(".[dev]")
session.run("pytest")

View File

@ -0,0 +1,20 @@
[build-system]
requires = ["maturin>=1,<2"]
build-backend = "maturin"
[project]
name = "sequential"
version = "0.1.0"
classifiers = [
"License :: OSI Approved :: MIT License",
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Programming Language :: Python",
"Programming Language :: Rust",
"Operating System :: POSIX",
"Operating System :: MacOS :: MacOS X",
]
requires-python = ">=3.12"
[project.optional-dependencies]
dev = ["pytest"]

View File

@ -0,0 +1,131 @@
use core::sync::atomic::{AtomicU64, Ordering};
use core::{mem, ptr};
use std::os::raw::{c_char, c_int, c_uint, c_ulonglong, c_void};
use pyo3_ffi::*;
#[repr(C)]
pub struct PyId {
_ob_base: PyObject,
id: Id,
}
static COUNT: AtomicU64 = AtomicU64::new(0);
#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd)]
pub struct Id(u64);
impl Id {
fn new() -> Self {
Id(COUNT.fetch_add(1, Ordering::Relaxed))
}
}
unsafe extern "C" fn id_new(
subtype: *mut PyTypeObject,
args: *mut PyObject,
kwds: *mut PyObject,
) -> *mut PyObject {
if PyTuple_Size(args) != 0 || !kwds.is_null() {
PyErr_SetString(
PyExc_TypeError,
"Id() takes no arguments\0".as_ptr().cast::<c_char>(),
);
return ptr::null_mut();
}
let f: allocfunc = (*subtype).tp_alloc.unwrap_or(PyType_GenericAlloc);
let slf = f(subtype, 0);
if slf.is_null() {
return ptr::null_mut();
} else {
let id = Id::new();
let slf = slf.cast::<PyId>();
ptr::addr_of_mut!((*slf).id).write(id);
}
slf
}
unsafe extern "C" fn id_repr(slf: *mut PyObject) -> *mut PyObject {
let slf = slf.cast::<PyId>();
let id = (*slf).id.0;
let string = format!("Id({})", id);
PyUnicode_FromStringAndSize(string.as_ptr().cast::<c_char>(), string.len() as Py_ssize_t)
}
unsafe extern "C" fn id_int(slf: *mut PyObject) -> *mut PyObject {
let slf = slf.cast::<PyId>();
let id = (*slf).id.0;
PyLong_FromUnsignedLongLong(id as c_ulonglong)
}
unsafe extern "C" fn id_richcompare(
slf: *mut PyObject,
other: *mut PyObject,
op: c_int,
) -> *mut PyObject {
let pytype = Py_TYPE(slf); // guaranteed to be `sequential.Id`
if Py_TYPE(other) != pytype {
return Py_NewRef(Py_NotImplemented());
}
let slf = (*slf.cast::<PyId>()).id;
let other = (*other.cast::<PyId>()).id;
let cmp = match op {
pyo3_ffi::Py_LT => slf < other,
pyo3_ffi::Py_LE => slf <= other,
pyo3_ffi::Py_EQ => slf == other,
pyo3_ffi::Py_NE => slf != other,
pyo3_ffi::Py_GT => slf > other,
pyo3_ffi::Py_GE => slf >= other,
unrecognized => {
let msg = format!("unrecognized richcompare opcode {}\0", unrecognized);
PyErr_SetString(PyExc_SystemError, msg.as_ptr().cast::<c_char>());
return ptr::null_mut();
}
};
if cmp {
Py_NewRef(Py_True())
} else {
Py_NewRef(Py_False())
}
}
static mut SLOTS: &[PyType_Slot] = &[
PyType_Slot {
slot: Py_tp_new,
pfunc: id_new as *mut c_void,
},
PyType_Slot {
slot: Py_tp_doc,
pfunc: "An id that is increased every time an instance is created\0".as_ptr()
as *mut c_void,
},
PyType_Slot {
slot: Py_tp_repr,
pfunc: id_repr as *mut c_void,
},
PyType_Slot {
slot: Py_nb_int,
pfunc: id_int as *mut c_void,
},
PyType_Slot {
slot: Py_tp_richcompare,
pfunc: id_richcompare as *mut c_void,
},
PyType_Slot {
slot: 0,
pfunc: ptr::null_mut(),
},
];
pub static mut ID_SPEC: PyType_Spec = PyType_Spec {
name: "sequential.Id\0".as_ptr().cast::<c_char>(),
basicsize: mem::size_of::<PyId>() as c_int,
itemsize: 0,
flags: (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE) as c_uint,
slots: unsafe { SLOTS as *const [PyType_Slot] as *mut PyType_Slot },
};

View File

@ -0,0 +1,14 @@
use std::ptr;
use pyo3_ffi::*;
mod id;
mod module;
use crate::module::MODULE_DEF;
// The module initialization function, which must be named `PyInit_<your_module>`.
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "C" fn PyInit_sequential() -> *mut PyObject {
PyModuleDef_Init(ptr::addr_of_mut!(MODULE_DEF))
}

View File

@ -0,0 +1,82 @@
use core::{mem, ptr};
use pyo3_ffi::*;
use std::os::raw::{c_char, c_int, c_void};
pub static mut MODULE_DEF: PyModuleDef = PyModuleDef {
m_base: PyModuleDef_HEAD_INIT,
m_name: "sequential\0".as_ptr().cast::<c_char>(),
m_doc: "A library for generating sequential ids, written in Rust.\0"
.as_ptr()
.cast::<c_char>(),
m_size: mem::size_of::<sequential_state>() as Py_ssize_t,
m_methods: std::ptr::null_mut(),
m_slots: unsafe { SEQUENTIAL_SLOTS as *const [PyModuleDef_Slot] as *mut PyModuleDef_Slot },
m_traverse: Some(sequential_traverse),
m_clear: Some(sequential_clear),
m_free: Some(sequential_free),
};
static mut SEQUENTIAL_SLOTS: &[PyModuleDef_Slot] = &[
PyModuleDef_Slot {
slot: Py_mod_exec,
value: sequential_exec as *mut c_void,
},
PyModuleDef_Slot {
slot: Py_mod_multiple_interpreters,
value: Py_MOD_PER_INTERPRETER_GIL_SUPPORTED,
},
PyModuleDef_Slot {
slot: 0,
value: ptr::null_mut(),
},
];
unsafe extern "C" fn sequential_exec(module: *mut PyObject) -> c_int {
let state: *mut sequential_state = PyModule_GetState(module).cast();
let id_type = PyType_FromModuleAndSpec(
module,
ptr::addr_of_mut!(crate::id::ID_SPEC),
ptr::null_mut(),
);
if id_type.is_null() {
PyErr_SetString(
PyExc_SystemError,
"cannot locate type object\0".as_ptr().cast::<c_char>(),
);
return -1;
}
(*state).id_type = id_type.cast::<PyTypeObject>();
PyModule_AddObjectRef(module, "Id\0".as_ptr().cast::<c_char>(), id_type)
}
unsafe extern "C" fn sequential_traverse(
module: *mut PyObject,
visit: visitproc,
arg: *mut c_void,
) -> c_int {
let state: *mut sequential_state = PyModule_GetState(module.cast()).cast();
let id_type: *mut PyObject = (*state).id_type.cast();
if id_type.is_null() {
0
} else {
(visit)(id_type, arg)
}
}
unsafe extern "C" fn sequential_clear(module: *mut PyObject) -> c_int {
let state: *mut sequential_state = PyModule_GetState(module.cast()).cast();
Py_CLEAR(ptr::addr_of_mut!((*state).id_type).cast());
0
}
unsafe extern "C" fn sequential_free(module: *mut c_void) {
sequential_clear(module.cast());
}
#[repr(C)]
struct sequential_state {
id_type: *mut PyTypeObject,
}

View File

@ -0,0 +1,151 @@
use core::ffi::{c_char, CStr};
use core::ptr;
use std::thread;
use pyo3_ffi::*;
use sequential::PyInit_sequential;
static COMMAND: &'static str = "
from sequential import Id
s = sum(int(Id()) for _ in range(12))
\0";
// Newtype to be able to pass it to another thread.
struct State(*mut PyThreadState);
unsafe impl Sync for State {}
unsafe impl Send for State {}
#[test]
fn lets_go_fast() -> Result<(), String> {
unsafe {
let ret = PyImport_AppendInittab(
"sequential\0".as_ptr().cast::<c_char>(),
Some(PyInit_sequential),
);
if ret == -1 {
return Err("could not add module to inittab".into());
}
Py_Initialize();
let main_state = PyThreadState_Swap(ptr::null_mut());
const NULL: State = State(ptr::null_mut());
let mut subs = [NULL; 12];
let config = PyInterpreterConfig {
use_main_obmalloc: 0,
allow_fork: 0,
allow_exec: 0,
allow_threads: 1,
allow_daemon_threads: 0,
check_multi_interp_extensions: 1,
gil: PyInterpreterConfig_OWN_GIL,
};
for State(state) in &mut subs {
let status = Py_NewInterpreterFromConfig(state, &config);
if PyStatus_IsError(status) == 1 {
let msg = if status.err_msg.is_null() {
"no error message".into()
} else {
CStr::from_ptr(status.err_msg).to_string_lossy()
};
PyThreadState_Swap(main_state);
Py_FinalizeEx();
return Err(format!("could not create new subinterpreter: {msg}"));
}
}
PyThreadState_Swap(main_state);
let main_state = PyEval_SaveThread(); // a PyInterpreterConfig with shared gil would deadlock otherwise
let ints: Vec<_> = thread::scope(move |s| {
let mut handles = vec![];
for state in subs {
let handle = s.spawn(move || {
let state = state;
PyEval_RestoreThread(state.0);
let ret = run_code();
Py_EndInterpreter(state.0);
ret
});
handles.push(handle);
}
handles.into_iter().map(|h| h.join().unwrap()).collect()
});
PyEval_RestoreThread(main_state);
let ret = Py_FinalizeEx();
if ret == -1 {
return Err("could not finalize interpreter".into());
}
let mut sum: u64 = 0;
for i in ints {
let i = i?;
sum += i;
}
assert_eq!(sum, (0..).take(12 * 12).sum());
}
Ok(())
}
unsafe fn fetch() -> String {
let err = PyErr_GetRaisedException();
let err_repr = PyObject_Str(err);
if !err_repr.is_null() {
let mut size = 0;
let p = PyUnicode_AsUTF8AndSize(err_repr, &mut size);
if !p.is_null() {
let s = std::str::from_utf8_unchecked(std::slice::from_raw_parts(
p.cast::<u8>(),
size as usize,
));
let s = String::from(s);
Py_DECREF(err_repr);
return s;
}
}
String::from("could not get error")
}
fn run_code() -> Result<u64, String> {
unsafe {
let code_obj = Py_CompileString(
COMMAND.as_ptr().cast::<c_char>(),
"program\0".as_ptr().cast::<c_char>(),
Py_file_input,
);
if code_obj.is_null() {
return Err(fetch());
}
let globals = PyDict_New();
let res_ptr = PyEval_EvalCode(code_obj, globals, ptr::null_mut());
Py_DECREF(code_obj);
if res_ptr.is_null() {
return Err(fetch());
} else {
Py_DECREF(res_ptr);
}
let sum = PyDict_GetItemString(globals, "s\0".as_ptr().cast::<c_char>()); /* borrowed reference */
if sum.is_null() {
Py_DECREF(globals);
return Err("globals did not have `s`".into());
}
let int = PyLong_AsUnsignedLongLong(sum) as u64;
Py_DECREF(globals);
Ok(int)
}
}

View File

@ -0,0 +1,21 @@
import pytest
from sequential import Id
def test_make_some():
for x in range(12):
i = Id()
assert x == int(i)
def test_args():
with pytest.raises(TypeError, match="Id\\(\\) takes no arguments"):
Id(3, 4)
def test_cmp():
a = Id()
b = Id()
assert a <= b
assert a < b
assert a == a