implement buffer protocol with proc macro

This commit is contained in:
Nikolay Kim 2017-05-14 12:52:30 -07:00
parent 9eae1523cb
commit 3e20979f3f
13 changed files with 312 additions and 31 deletions

2
.gitignore vendored
View File

@ -2,6 +2,8 @@
/Cargo.lock /Cargo.lock
/doc /doc
/gh-pages /gh-pages
/pyo3cls/target
/pyo3cls/Cargo.lock
*.so *.so
*.out *.out

View File

@ -5,6 +5,7 @@ description = "Bindings to Python"
authors = ["PyO3 Project and Contributors <https://github.com/PyO3"] authors = ["PyO3 Project and Contributors <https://github.com/PyO3"]
readme = "README.md" readme = "README.md"
keywords = [ keywords = [
"pyo3",
"python", "python",
"cpython", "cpython",
] ]
@ -16,20 +17,21 @@ license = "MIT/APACHE-2"
exclude = [ exclude = [
".gitignore", ".gitignore",
".travis.yml", ".travis.yml",
"appveyor.yml",
".cargo/config", ".cargo/config",
"appveyor.yml",
] ]
build = "build.rs" build = "build.rs"
[dependencies] [dependencies]
libc = "0.2" libc = "0.2"
num-traits = "0.1" num-traits = "0.1"
pyo3cls = { path = "pyo3cls" }
[build-dependencies] [build-dependencies]
regex = "0.1" regex = "0.1"
[features] [features]
default = [] default = ["nightly"]
# Enable additional features that require nightly rust # Enable additional features that require nightly rust
nightly = [] nightly = []

View File

@ -257,7 +257,7 @@ fn find_interpreter_and_get_config() -> Result<(PythonVersion, String, Vec<Strin
.expect("Unable to get PYTHON_SYS_EXECUTABLE value"); .expect("Unable to get PYTHON_SYS_EXECUTABLE value");
let (interpreter_version, lines) = try!(get_config_from_interpreter(interpreter_path)); let (interpreter_version, lines) = try!(get_config_from_interpreter(interpreter_path));
if MIN_MINOR > interpreter_version.minor.unwrap_or(0) { if interpreter_version.major < 3 || MIN_MINOR > interpreter_version.minor.unwrap_or(0) {
return Err(format!("Unsupported python version in PYTHON_SYS_EXECUTABLE={}\n\ return Err(format!("Unsupported python version in PYTHON_SYS_EXECUTABLE={}\n\
\tmin version 3.4 != found {}", \tmin version 3.4 != found {}",
interpreter_path, interpreter_path,
@ -270,8 +270,9 @@ fn find_interpreter_and_get_config() -> Result<(PythonVersion, String, Vec<Strin
{ {
let interpreter_path = "python"; let interpreter_path = "python";
let (interpreter_version, lines) = try!(get_config_from_interpreter(interpreter_path)); let (interpreter_version, lines) = try!(get_config_from_interpreter(interpreter_path));
if MIN_MINOR <= interpreter_version.minor.unwrap_or(0) { if MIN_MINOR <= interpreter_version.minor.unwrap_or(0) &&
return Ok((interpreter_version, interpreter_path.to_owned(), lines)); interpreter_version.major == 3 {
return Ok((interpreter_version, interpreter_path.to_owned(), lines));
} }
} }
@ -318,7 +319,6 @@ fn configure_from_path() -> Result<(String, String), String> {
let mut flags = String::new(); let mut flags = String::new();
println!("test: {:?}", interpreter_version);
if let PythonVersion { major: 3, minor: some_minor} = interpreter_version { if let PythonVersion { major: 3, minor: some_minor} = interpreter_version {
if env::var_os("CARGO_FEATURE_PEP_384").is_some() { if env::var_os("CARGO_FEATURE_PEP_384").is_some() {
println!("cargo:rustc-cfg=Py_LIMITED_API"); println!("cargo:rustc-cfg=Py_LIMITED_API");
@ -346,7 +346,6 @@ fn main() {
// try using 'env' (sorry but this isn't our fault - it just has to // try using 'env' (sorry but this isn't our fault - it just has to
// match the pkg-config package name, which is going to have a . in it). // match the pkg-config package name, which is going to have a . in it).
let (python_interpreter_path, flags) = configure_from_path().unwrap(); let (python_interpreter_path, flags) = configure_from_path().unwrap();
println!('3');
let config_map = get_config_vars(&python_interpreter_path).unwrap(); let config_map = get_config_vars(&python_interpreter_path).unwrap();
for (key, val) in &config_map { for (key, val) in &config_map {
match cfg_line_for_var(key, val) { match cfg_line_for_var(key, val) {
@ -380,5 +379,6 @@ fn main() {
}) + flags.as_str(); }) + flags.as_str();
println!("cargo:python_flags={}", println!("cargo:python_flags={}",
if flags.len() > 0 { &flags[..flags.len()-1] } else { "" }); if flags.len() > 0 { &flags[..flags.len()-1] } else { "" });
} }

15
pyo3cls/Cargo.toml Normal file
View File

@ -0,0 +1,15 @@
[package]
name = "pyo3cls"
version = "0.1.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
[lib]
proc-macro = true
[dependencies]
quote="0.3"
[dependencies.syn]
version="0.11"
features=["full"]
#git = "https://github.com/dtolnay/syn.git"

30
pyo3cls/src/lib.rs Normal file
View File

@ -0,0 +1,30 @@
#![feature(proc_macro)]
extern crate proc_macro;
extern crate syn;
#[macro_use] extern crate quote;
use std::str::FromStr;
use proc_macro::TokenStream;
mod py_impl;
use py_impl::build_py_impl;
#[proc_macro_attribute]
pub fn py_impl(_: TokenStream, input: TokenStream) -> TokenStream {
// Construct a string representation of the type definition
let source = input.to_string();
// Parse the string representation into a syntax tree
//let ast: syn::Crate = source.parse().unwrap();
let ast = syn::parse_item(&source).unwrap();
// Build the output
let expanded = build_py_impl(&ast);
// Return the generated impl as a TokenStream
let s = source + expanded.as_str();
TokenStream::from_str(s.as_str()).unwrap()
}

62
pyo3cls/src/py_impl.rs Normal file
View File

@ -0,0 +1,62 @@
use syn;
use quote;
enum ImplType {
Buffer,
}
pub fn build_py_impl(ast: &syn::Item) -> quote::Tokens {
match ast.node {
syn::ItemKind::Impl(_, _, _, ref path, ref ty, ref impl_items) => {
if let &Some(ref path) = path {
match process_path(path) {
ImplType::Buffer => {
impl_protocol("PyBufferProtocolImpl", path.clone(), ty, impl_items)
}
}
} else {
//ImplType::Impl
unimplemented!()
}
},
_ => panic!("#[py_impl] can only be used with Impl blocks"),
}
}
fn process_path(path: &syn::Path) -> ImplType {
if let Some(segment) = path.segments.last() {
match segment.ident.as_ref() {
"PyBufferProtocol" => ImplType::Buffer,
_ => panic!("#[py_impl] can not be used with this block"),
}
} else {
panic!("#[py_impl] can not be used with this block");
}
}
fn impl_protocol(name: &'static str,
path: syn::Path, ty: &Box<syn::Ty>,
impls: &Vec<syn::ImplItem>) -> quote::Tokens {
// get method names in impl block
let mut meth = Vec::new();
for iimpl in impls.iter() {
meth.push(String::from(iimpl.ident.as_ref()))
}
// set trait name
let mut path = path;
{
let mut last = path.segments.last_mut().unwrap();
last.ident = syn::Ident::from(name);
}
quote! {
impl #path for #ty {
fn methods() -> &'static [&'static str] {
static METHODS: &'static [&'static str] = &[#(#meth,),*];
METHODS
}
}
}
}

85
src/class/buffer.rs Normal file
View File

@ -0,0 +1,85 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
use std;
use std::os::raw::c_int;
use ffi;
use err::{PyErr, PyResult};
use python::{self, Python, PythonObject};
use conversion::ToPyObject;
use objects::{PyObject, PyType, PyModule};
use py_class::slots::UnitCallbackConverter;
use function::handle_callback;
pub trait PyBufferProtocolImpl {
fn methods() -> &'static [&'static str];
}
impl<T> PyBufferProtocolImpl for T {
default fn methods() -> &'static [&'static str] {
static METHODS: &'static [&'static str] = &[];
METHODS
}
}
pub trait PyBufferProtocol {
fn bf_getbuffer(&self, py: Python, view: *mut ffi::Py_buffer, flags: c_int) -> PyResult<()>;
fn bf_releasebuffer(&self, py: Python, view: *mut ffi::Py_buffer) -> PyResult<()>;
}
impl<T> PyBufferProtocol for T {
default fn bf_getbuffer(&self, _py: Python,
_view: *mut ffi::Py_buffer, _flags: c_int) -> PyResult<()> {
Ok(())
}
default fn bf_releasebuffer(&self, _py: Python,
_view: *mut ffi::Py_buffer) -> PyResult<()> {
Ok(())
}
}
impl ffi::PyBufferProcs {
pub fn new<T>() -> Option<ffi::PyBufferProcs>
where T: PyBufferProtocol + PyBufferProtocolImpl + PythonObject
{
let methods = T::methods();
if methods.is_empty() {
return None
}
let mut buf_procs: ffi::PyBufferProcs = ffi::PyBufferProcs_INIT;
for name in methods {
match name {
&"bf_getbuffer" => {
buf_procs.bf_getbuffer = {
unsafe extern "C" fn wrap<T>(slf: *mut ffi::PyObject, arg1: *mut ffi::Py_buffer, arg2: c_int) -> c_int
where T: PyBufferProtocol + PythonObject
{
const LOCATION: &'static str = concat!(stringify!(T), ".buffer_get::<PyBufferProtocol>()");
handle_callback(LOCATION, UnitCallbackConverter,
|py| {
let slf = PyObject::from_borrowed_ptr(py, slf).unchecked_cast_into::<T>();
let result = slf.bf_getbuffer(py, arg1, arg2);
::PyDrop::release_ref(slf, py);
result
}
)
}
Some(wrap::<T>)
}
},
_ => ()
}
}
Some(buf_procs)
}
}

4
src/class/mod.rs Normal file
View File

@ -0,0 +1,4 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
pub mod buffer;
pub use self::buffer::*;

View File

@ -95,7 +95,7 @@ mod bufferinfo {
#[derive(Copy)] #[derive(Copy)]
pub struct Py_buffer { pub struct Py_buffer {
pub buf: *mut c_void, pub buf: *mut c_void,
pub obj: *mut ::PyObject, pub obj: *mut ::ffi::PyObject,
pub len: Py_ssize_t, pub len: Py_ssize_t,
pub itemsize: Py_ssize_t, pub itemsize: Py_ssize_t,
pub readonly: c_int, pub readonly: c_int,
@ -114,12 +114,12 @@ mod bufferinfo {
} }
pub type getbufferproc = pub type getbufferproc =
unsafe extern "C" fn(arg1: *mut ::PyObject, unsafe extern "C" fn(arg1: *mut ::ffi::PyObject,
arg2: *mut Py_buffer, arg2: *mut Py_buffer,
arg3: c_int) arg3: c_int)
-> c_int; -> c_int;
pub type releasebufferproc = pub type releasebufferproc =
unsafe extern "C" fn(arg1: *mut ::PyObject, unsafe extern "C" fn(arg1: *mut ::ffi::PyObject,
arg2: *mut Py_buffer) -> (); arg2: *mut Py_buffer) -> ();
/// Maximum number of dimensions /// Maximum number of dimensions
@ -407,7 +407,7 @@ mod typeobject {
am_anext: None, am_anext: None,
}; };
#[repr(C)] #[repr(C)]
#[derive(Copy)] #[derive(Copy, Debug)]
pub struct PyBufferProcs { pub struct PyBufferProcs {
pub bf_getbuffer: Option<object::getbufferproc>, pub bf_getbuffer: Option<object::getbufferproc>,
pub bf_releasebuffer: Option<object::releasebufferproc>, pub bf_releasebuffer: Option<object::releasebufferproc>,

View File

@ -80,7 +80,12 @@
extern crate libc; extern crate libc;
#[macro_use] pub extern crate pyo3cls;
pub use pyo3cls as cls;
pub mod ffi; pub mod ffi;
pub mod class;
pub use ffi::Py_ssize_t; pub use ffi::Py_ssize_t;
pub use err::{PyErr, PyResult}; pub use err::{PyErr, PyResult};
pub use objects::*; pub use objects::*;
@ -89,6 +94,7 @@ pub use pythonrun::{GILGuard, GILProtected, prepare_freethreaded_python};
pub use conversion::{FromPyObject, RefFromPyObject, ToPyObject}; pub use conversion::{FromPyObject, RefFromPyObject, ToPyObject};
pub use py_class::{CompareOp}; pub use py_class::{CompareOp};
pub use objectprotocol::{ObjectProtocol}; pub use objectprotocol::{ObjectProtocol};
pub use class::*;
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
pub type Py_hash_t = ffi::Py_hash_t; pub type Py_hash_t = ffi::Py_hash_t;

View File

@ -79,7 +79,7 @@ unsafe fn make_shared(ptr: *mut ffi::PyObject) -> *mut ffi::PyObject {
#[inline] #[inline]
#[cfg(feature="nightly")] #[cfg(feature="nightly")]
fn unpack_shared(ptr: ptr::Shared<ffi::PyObject>) -> *mut ffi::PyObject { fn unpack_shared(ptr: ptr::Shared<ffi::PyObject>) -> *mut ffi::PyObject {
*ptr ptr.as_ptr()
} }
#[inline] #[inline]

View File

@ -26,6 +26,7 @@ use objects::PyObject;
use function::CallbackConverter; use function::CallbackConverter;
use err::{PyErr, PyResult}; use err::{PyErr, PyResult};
use py_class::{CompareOp}; use py_class::{CompareOp};
use class::PyBufferProtocol;
use exc; use exc;
use Py_hash_t; use Py_hash_t;
@ -49,7 +50,7 @@ macro_rules! py_class_type_object_static_init {
tp_flags: py_class_type_object_flags!($gc), tp_flags: py_class_type_object_flags!($gc),
tp_traverse: py_class_tp_traverse!($class_name, $gc), tp_traverse: py_class_tp_traverse!($class_name, $gc),
.. ..
$crate::_detail::ffi::PyTypeObject_INIT $crate::_detail::ffi::PyTypeObject_INIT
} }
); );
} }
@ -74,6 +75,8 @@ macro_rules! py_class_type_object_flags {
pub const TPFLAGS_DEFAULT : ::libc::c_ulong = ffi::Py_TPFLAGS_DEFAULT; pub const TPFLAGS_DEFAULT : ::libc::c_ulong = ffi::Py_TPFLAGS_DEFAULT;
use class::buffer::*;
#[macro_export] #[macro_export]
#[doc(hidden)] #[doc(hidden)]
macro_rules! py_class_type_object_dynamic_init { macro_rules! py_class_type_object_dynamic_init {
@ -94,15 +97,25 @@ macro_rules! py_class_type_object_dynamic_init {
$type_object.tp_basicsize = <$class as $crate::py_class::BaseObject>::size() $type_object.tp_basicsize = <$class as $crate::py_class::BaseObject>::size()
as $crate::_detail::ffi::Py_ssize_t; as $crate::_detail::ffi::Py_ssize_t;
} }
// call slot macros outside of unsafe block // call slot macros outside of unsafe block
*(unsafe { &mut $type_object.tp_as_async }) = py_class_as_async!($as_async); *(unsafe { &mut $type_object.tp_as_async }) = py_class_as_async!($as_async);
*(unsafe { &mut $type_object.tp_as_sequence }) = py_class_as_sequence!($as_sequence); *(unsafe { &mut $type_object.tp_as_sequence }) = py_class_as_sequence!($as_sequence);
*(unsafe { &mut $type_object.tp_as_number }) = py_class_as_number!($as_number); *(unsafe { &mut $type_object.tp_as_number }) = py_class_as_number!($as_number);
*(unsafe { &mut $type_object.tp_as_buffer }) = py_class_as_buffer!($as_buffer);
if let Some(buf) = $crate::ffi::PyBufferProcs::new::<$class>() {
static mut BUFFER_PROCS: $crate::ffi::PyBufferProcs = $crate::ffi::PyBufferProcs_INIT;
*(unsafe { &mut BUFFER_PROCS }) = buf;
*(unsafe { &mut $type_object.tp_as_buffer }) = unsafe { &mut BUFFER_PROCS };
} else {
*(unsafe { &mut $type_object.tp_as_buffer }) = 0 as *mut $crate::ffi::PyBufferProcs;
}
py_class_as_mapping!($type_object, $as_mapping, $setdelitem); py_class_as_mapping!($type_object, $as_mapping, $setdelitem);
} }
} }
pub fn build_tp_name(module_name: Option<&str>, type_name: &str) -> *mut c_char { pub fn build_tp_name(module_name: Option<&str>, type_name: &str) -> *mut c_char {
let name = match module_name { let name = match module_name {
Some(module_name) => CString::new(format!("{}.{}", module_name, type_name)), Some(module_name) => CString::new(format!("{}.{}", module_name, type_name)),
@ -196,22 +209,6 @@ macro_rules! py_class_as_async {
} }
#[macro_export]
#[doc(hidden)]
macro_rules! py_class_as_buffer {
([]) => (0 as *mut $crate::_detail::ffi::PyBufferProcs);
([$( $slot_name:ident : $slot_value:expr ,)+]) => {{
static mut BUFFER_PROCS : $crate::_detail::ffi::PyBufferProcs
= $crate::_detail::ffi::PyBufferProcs {
$( $slot_name : $slot_value, )*
..
$crate::_detail::ffi::PyBufferProcs_INIT
};
unsafe { &mut BUFFER_PROCS }
}}
}
#[macro_export] #[macro_export]
#[doc(hidden)] #[doc(hidden)]
macro_rules! py_class_as_mapping { macro_rules! py_class_as_mapping {

View File

@ -0,0 +1,78 @@
#![allow(dead_code, unused_variables)]
#![feature(proc_macro, specialization)]
#[macro_use] extern crate pyo3;
use std::ptr;
use std::os::raw::{c_int, c_void};
use pyo3::*;
use pyo3::cls;
py_class!(class TestClass |py| {
data vec: Vec<u8>;
});
#[cls::py_impl]
impl class::PyBufferProtocol for TestClass {
fn bf_getbuffer(&self, py: Python, view: *mut ffi::Py_buffer, flags: c_int) -> PyResult<()> {
if view == ptr::null_mut() {
return Err(PyErr::new::<exc::BufferError, _>(py, "View is null"))
}
unsafe {
(*view).obj = ptr::null_mut();
}
if (flags & ffi::PyBUF_WRITABLE) == ffi::PyBUF_WRITABLE {
return Err(PyErr::new::<exc::BufferError, _>(py, "Object is not writable"))
}
let bytes = self.vec(py);
unsafe {
(*view).buf = bytes.as_ptr() as *mut c_void;
(*view).len = bytes.len() as isize;
(*view).readonly = 1;
(*view).itemsize = 1;
(*view).format = ptr::null_mut();
if (flags & ffi::PyBUF_FORMAT) == ffi::PyBUF_FORMAT {
let msg = ::std::ffi::CStr::from_ptr("B\0".as_ptr() as *const _);
(*view).format = msg.as_ptr() as *mut _;
}
(*view).ndim = 1;
(*view).shape = ptr::null_mut();
if (flags & ffi::PyBUF_ND) == ffi::PyBUF_ND {
(*view).shape = (&((*view).len)) as *const _ as *mut _;
}
(*view).strides = ptr::null_mut();
if (flags & ffi::PyBUF_STRIDES) == ffi::PyBUF_STRIDES {
(*view).strides = &((*view).itemsize) as *const _ as *mut _;
}
(*view).suboffsets = ptr::null_mut();
(*view).internal = ptr::null_mut();
}
Ok(())
}
}
#[test]
fn test_buffer() {
let gil = Python::acquire_gil();
let py = gil.python();
let t = TestClass::create_instance(py, vec![b' ', b'2', b'3']).unwrap();
let d = PyDict::new(py);
let _ = d.set_item(py, "ob", t);
py.run("assert bytes(ob) == b' 23'", None, Some(&d)).unwrap();
}