From cd6f7295a1e993b8e90cef3bb7b1ee5c2059ebdd Mon Sep 17 00:00:00 2001 From: Paul Ganssle Date: Thu, 9 Aug 2018 14:17:34 -0400 Subject: [PATCH] Add type checking FFI bindings --- src/ffi3/datetime.rs | 29 ++++- tests/rustapi_module/tests/test_datetime.py | 8 ++ tests/test_datetime.rs | 126 ++++++++++++++++++++ 3 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 tests/test_datetime.rs diff --git a/src/ffi3/datetime.rs b/src/ffi3/datetime.rs index f09ea3da..f90e0347 100644 --- a/src/ffi3/datetime.rs +++ b/src/ffi3/datetime.rs @@ -181,14 +181,19 @@ pub unsafe fn PyDate_Check(op: *mut PyObject) -> c_int { PyObject_TypeCheck(op, PyDateTimeAPI.DateType) as c_int } +#[inline(always)] +pub unsafe fn PyDate_CheckExact(op: *mut PyObject) -> c_int { + (Py_TYPE(op) == PyDateTimeAPI.DateType) as c_int +} + #[inline(always)] pub unsafe fn PyDateTime_Check(op: *mut PyObject) -> c_int { PyObject_TypeCheck(op, PyDateTimeAPI.DateTimeType) as c_int } #[inline(always)] -pub unsafe fn PyTZInfo_Check(op: *mut PyObject) -> c_int { - PyObject_TypeCheck(op, PyDateTimeAPI.TZInfoType) as c_int +pub unsafe fn PyDateTime_CheckExact(op: *mut PyObject) -> c_int { + (Py_TYPE(op) == PyDateTimeAPI.DateTimeType) as c_int } #[inline(always)] @@ -196,11 +201,31 @@ pub unsafe fn PyTime_Check(op: *mut PyObject) -> c_int { PyObject_TypeCheck(op, PyDateTimeAPI.TimeType) as c_int } +#[inline(always)] +pub unsafe fn PyTime_CheckExact(op: *mut PyObject) -> c_int { + (Py_TYPE(op) == PyDateTimeAPI.TimeType) as c_int +} + #[inline(always)] pub unsafe fn PyDelta_Check(op: *mut PyObject) -> c_int { PyObject_TypeCheck(op, PyDateTimeAPI.DeltaType) as c_int } +#[inline(always)] +pub unsafe fn PyDelta_CheckExact(op: *mut PyObject) -> c_int { + (Py_TYPE(op) == PyDateTimeAPI.DeltaType) as c_int +} + +#[inline(always)] +pub unsafe fn PyTZInfo_Check(op: *mut PyObject) -> c_int { + PyObject_TypeCheck(op, PyDateTimeAPI.TZInfoType) as c_int +} + +#[inline(always)] +pub unsafe fn PyTZInfo_CheckExact(op: *mut PyObject) -> c_int { + (Py_TYPE(op) == PyDateTimeAPI.TZInfoType) as c_int +} + // // Accessor functions // diff --git a/tests/rustapi_module/tests/test_datetime.py b/tests/rustapi_module/tests/test_datetime.py index 285ecb1f..93764f31 100644 --- a/tests/rustapi_module/tests/test_datetime.py +++ b/tests/rustapi_module/tests/test_datetime.py @@ -25,6 +25,13 @@ def get_timestamp(dt): return dt.timestamp() + + + + + + + # Tests def test_date(): assert rdt.make_date(2017, 9, 1) == pdt.date(2017, 9, 1) @@ -224,3 +231,4 @@ def test_delta_accessors(td): def test_delta_err(args, err_type): with pytest.raises(err_type): rdt.make_delta(*args) + diff --git a/tests/test_datetime.rs b/tests/test_datetime.rs new file mode 100644 index 00000000..deadb49e --- /dev/null +++ b/tests/test_datetime.rs @@ -0,0 +1,126 @@ +#![feature(concat_idents)] + +extern crate pyo3; + +use pyo3::prelude::*; + +use pyo3::ffi::*; + +#[cfg(Py_3)] +fn _get_subclasses<'p>(py: &'p Python, py_type: &str, args: &str) -> + (&'p PyObjectRef, &'p PyObjectRef, &'p PyObjectRef) { + macro_rules! unwrap_py { + ($e:expr) => { ($e).map_err(|e| e.print(*py)).unwrap() } + }; + + // Import the class from Python and create some subclasses + let datetime = unwrap_py!(py.import("datetime")); + + let locals = PyDict::new(*py); + locals.set_item(py_type, datetime.get(py_type).unwrap()) + .unwrap(); + + let make_subclass_py = + format!("class Subklass({}):\n pass", py_type); + + let make_sub_subclass_py = + "class SubSubklass(Subklass):\n pass"; + + unwrap_py!(py.run(&make_subclass_py, None, Some(&locals))); + unwrap_py!(py.run(&make_sub_subclass_py, None, Some(&locals))); + + // Construct an instance of the base class + let obj = unwrap_py!( + py.eval(&format!("{}({})", py_type, args), None, Some(&locals)) + ); + + // Construct an instance of the subclass + let sub_obj = unwrap_py!( + py.eval(&format!("Subklass({})", args), None, Some(&locals)) + ); + + // Construct an instance of the sub-subclass + let sub_sub_obj = unwrap_py!( + py.eval(&format!("SubSubklass({})", args), None, Some(&locals)) + ); + + (obj, sub_obj, sub_sub_obj) +} + +#[cfg(Py_3)] +macro_rules! assert_check_exact { + ($check_func:ident, $obj: expr) => { + unsafe { + assert!($check_func(($obj).as_ptr()) != 0); + assert!(concat_idents!($check_func, Exact)(($obj).as_ptr()) != 0); + } + } +} + +#[cfg(Py_3)] +macro_rules! assert_check_only { + ($check_func:ident, $obj: expr) => { + unsafe { + assert!($check_func(($obj).as_ptr()) != 0); + assert!(concat_idents!($check_func, Exact)(($obj).as_ptr()) == 0); + } + } +} + + +#[test] +#[cfg(Py_3)] +fn test_date_check() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let (obj, sub_obj, sub_sub_obj) = _get_subclasses(&py, + "date", "2018, 1, 1" + ); + + assert_check_exact!(PyDate_Check, obj); + assert_check_only!(PyDate_Check, sub_obj); + assert_check_only!(PyDate_Check, sub_sub_obj); +} + +#[test] +#[cfg(Py_3)] +fn test_time_check() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let (obj, sub_obj, sub_sub_obj) = _get_subclasses(&py, + "time", "12, 30, 15" + ); + + assert_check_exact!(PyTime_Check, obj); + assert_check_only!(PyTime_Check, sub_obj); + assert_check_only!(PyTime_Check, sub_sub_obj); +} + +#[test] +#[cfg(Py_3)] +fn test_datetime_check() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let (obj, sub_obj, sub_sub_obj) = _get_subclasses(&py, + "datetime", "2018, 1, 1, 13, 30, 15" + ); + + assert_check_only!(PyDate_Check, obj); + assert_check_exact!(PyDateTime_Check, obj); + assert_check_only!(PyDateTime_Check, sub_obj); + assert_check_only!(PyDateTime_Check, sub_sub_obj); +} + +#[test] +#[cfg(Py_3)] +fn test_delta_check() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let (obj, sub_obj, sub_sub_obj) = _get_subclasses(&py, + "timedelta", "1, -3" + ); + + assert_check_exact!(PyDelta_Check, obj); + assert_check_only!(PyDelta_Check, sub_obj); + assert_check_only!(PyDelta_Check, sub_sub_obj); +}