From 7ec1fed7984ac7248d2fd6d842d3b9849438c582 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Sun, 14 Mar 2021 16:25:04 +0900 Subject: [PATCH 1/2] Extend py_run! macro to take dict as *d syntax --- src/lib.rs | 55 +++++++++++++++++++++++++++++++++++-------------- tests/common.rs | 44 ++++++++++++++++++++++----------------- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9b7bce50..1c14aa19 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -289,10 +289,10 @@ macro_rules! wrap_pymodule { /// # Example /// ``` /// use pyo3::{prelude::*, py_run, types::PyList}; -/// let gil = Python::acquire_gil(); -/// let py = gil.python(); -/// let list = PyList::new(py, &[1, 2, 3]); -/// py_run!(py, list, "assert list == [1, 2, 3]"); +/// Python::with_gil(|py| { +/// let list = PyList::new(py, &[1, 2, 3]); +/// py_run!(py, list, "assert list == [1, 2, 3]"); +/// }); /// ``` /// /// You can use this macro to test pyfunctions or pyclasses quickly. @@ -320,15 +320,32 @@ macro_rules! wrap_pymodule { /// (self.hour, self.minute, self.second) /// } /// } -/// let gil = Python::acquire_gil(); -/// let py = gil.python(); -/// let time = PyCell::new(py, Time {hour: 8, minute: 43, second: 16}).unwrap(); -/// let time_as_tuple = (8, 43, 16); -/// py_run!(py, time time_as_tuple, r#" -/// assert time.hour == 8 -/// assert time.repl_japanese() == "8時43分16秒" -/// assert time.as_tuple() == time_as_tuple -/// "#); +/// Python::with_gil(|py| { +/// let time = PyCell::new(py, Time {hour: 8, minute: 43, second: 16}).unwrap(); +/// let time_as_tuple = (8, 43, 16); +/// py_run!(py, time time_as_tuple, r#" +/// assert time.hour == 8 +/// assert time.repl_japanese() == "8時43分16秒" +/// assert time.as_tuple() == time_as_tuple +/// "#); +/// }); +/// ``` +/// +/// If you need to prepare the `locals` dict by yourself, you can pass it by `*locals`. +/// +/// ``` +/// # use pyo3::prelude::*; +/// #[pyclass] +/// struct MyClass {} +/// #[pymethods] +/// impl MyClass { +/// #[new] +/// fn new() -> Self { MyClass {} } +/// } +/// Python::with_gil(|py| { +/// let locals = [("C", py.get_type::())]; +/// pyo3::py_run!(py, *locals, "c = C()"); +/// }); /// ``` /// /// **Note** @@ -345,6 +362,12 @@ macro_rules! py_run { ($py:expr, $($val:ident)+, $code:expr) => {{ $crate::py_run_impl!($py, $($val)+, &$crate::unindent::unindent($code)) }}; + ($py:expr, *$dict:expr, $code:literal) => {{ + $crate::py_run_impl!($py, *$dict, $crate::indoc::indoc!($code)) + }}; + ($py:expr, *$dict:expr, $code:expr) => {{ + $crate::py_run_impl!($py, *$dict, &$crate::unindent::unindent($code)) + }}; } #[macro_export] @@ -355,8 +378,10 @@ macro_rules! py_run_impl { use $crate::types::IntoPyDict; use $crate::ToPyObject; let d = [$((stringify!($val), $val.to_object($py)),)+].into_py_dict($py); - - if let Err(e) = $py.run($code, None, Some(d)) { + $crate::py_run_impl!($py, *d, $code) + }}; + ($py:expr, *$dict:expr, $code:expr) => {{ + if let Err(e) = $py.run($code, None, Some($dict)) { e.print($py); // So when this c api function the last line called printed the error to stderr, // the output is only written into a buffer which is never flushed because we diff --git a/tests/common.rs b/tests/common.rs index c110f5d4..75a246a9 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -1,37 +1,43 @@ -//! Useful tips for writing tests: -//! - Tests are run in parallel; There's still a race condition in test_owned with some other test -//! - You need to use flush=True to get any output from print +//! Some common macros for tests #[macro_export] macro_rules! py_assert { - ($py:expr, $val:ident, $assertion:expr) => { - pyo3::py_run!($py, $val, concat!("assert ", $assertion)) + ($py:expr, $($val:ident)+, $assertion:literal) => { + pyo3::py_run!($py, $($val)+, concat!("assert ", $assertion)) + }; + ($py:expr, *$dict:expr, $assertion:literal) => { + pyo3::py_run!($py, *$dict, concat!("assert ", $assertion)) }; } #[macro_export] macro_rules! py_expect_exception { - ($py:expr, $val:ident, $code:expr, $err:ident) => {{ + // Case1: idents & no err_msg + ($py:expr, $($val:ident)+, $code:expr, $err:ident) => {{ use pyo3::types::IntoPyDict; - let d = [(stringify!($val), &$val)].into_py_dict($py); - - let res = $py.run($code, None, Some(d)); + let d = [$((stringify!($val), $val.to_object($py)),)+].into_py_dict($py); + py_expect_exception!($py, *d, $code, $err) + }}; + // Case2: dict & no err_msg + ($py:expr, *$dict:expr, $code:expr, $err:ident) => {{ + let res = $py.run($code, None, Some($dict)); let err = res.expect_err(&format!("Did not raise {}", stringify!($err))); if !err.matches($py, $py.get_type::()) { panic!("Expected {} but got {:?}", stringify!($err), err) } err }}; - ($py:expr, $val:ident, $code:expr, $err:ident, $err_msg:expr) => {{ - let err = py_expect_exception!($py, $val, $code, $err); - assert_eq!( - err.instance($py) - .str() - .expect("error str() failed") - .to_str() - .expect("message was not valid utf8"), - $err_msg - ); + // Case3: idents & err_msg + ($py:expr, $($val:ident)+, $code:expr, $err:ident, $err_msg:literal) => {{ + let err = py_expect_exception!($py, $($val)+, $code, $err); + // Suppose that the error message looks like 'TypeError: ~' + assert_eq!(format!("Py{}", err), concat!(stringify!($err), ": ", $err_msg)); + err + }}; + // Case4: dict & err_msg + ($py:expr, *$dict:expr, $code:expr, $err:ident, $err_msg:literal) => {{ + let err = py_expect_exception!($py, *$dict, $code, $err); + assert_eq!(format!("Py{}", err), concat!(stringify!($err), ": ", $err_msg)); err }}; } From 9b88a452e2779cef3b3c0884a27c31f40343b94b Mon Sep 17 00:00:00 2001 From: kngwyu Date: Sun, 14 Mar 2021 16:34:05 +0900 Subject: [PATCH 2/2] Refactor tests to use shorter macros --- src/lib.rs | 7 +- tests/test_buffer_protocol.rs | 5 +- tests/test_dunder.rs | 8 +-- tests/test_getter_setter.rs | 7 +- tests/test_mapping.rs | 60 ++++++++-------- tests/test_methods.rs | 105 ++++++++++++++------------- tests/test_module.rs | 65 ++++++++++------- tests/test_sequence.rs | 132 +++++++++++++++++++++------------- tests/test_various.rs | 5 +- 9 files changed, 216 insertions(+), 178 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1c14aa19..99879540 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -331,10 +331,11 @@ macro_rules! wrap_pymodule { /// }); /// ``` /// -/// If you need to prepare the `locals` dict by yourself, you can pass it by `*locals`. +/// If you need to prepare the `locals` dict by yourself, you can pass it as `*locals`. /// /// ``` -/// # use pyo3::prelude::*; +/// use pyo3::prelude::*; +/// use pyo3::types::IntoPyDict; /// #[pyclass] /// struct MyClass {} /// #[pymethods] @@ -343,7 +344,7 @@ macro_rules! wrap_pymodule { /// fn new() -> Self { MyClass {} } /// } /// Python::with_gil(|py| { -/// let locals = [("C", py.get_type::())]; +/// let locals = [("C", py.get_type::())].into_py_dict(py); /// pyo3::py_run!(py, *locals, "c = C()"); /// }); /// ``` diff --git a/tests/test_buffer_protocol.rs b/tests/test_buffer_protocol.rs index c40e751b..e5831d1b 100644 --- a/tests/test_buffer_protocol.rs +++ b/tests/test_buffer_protocol.rs @@ -13,6 +13,8 @@ use std::ptr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +mod common; + #[pyclass] struct TestBufferClass { vec: Vec, @@ -93,8 +95,7 @@ fn test_buffer() { ) .unwrap(); let env = [("ob", instance)].into_py_dict(py); - py.run("assert bytes(ob) == b' 23'", None, Some(env)) - .unwrap(); + py_assert!(py, *env, "bytes(ob) == b' 23'"); } assert!(drop_called.load(Ordering::Relaxed)); diff --git a/tests/test_dunder.rs b/tests/test_dunder.rs index 5285950f..d4933205 100644 --- a/tests/test_dunder.rs +++ b/tests/test_dunder.rs @@ -4,7 +4,7 @@ use pyo3::class::{ }; use pyo3::exceptions::{PyIndexError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PySlice, PyType}; +use pyo3::types::{PySlice, PyType}; use pyo3::{ffi, py_run, AsPyPointer, PyCell}; use std::convert::TryFrom; use std::{isize, iter}; @@ -446,11 +446,9 @@ fn test_cls_impl() { let py = gil.python(); let ob = Py::new(py, Test {}).unwrap(); - let d = [("ob", ob)].into_py_dict(py); - py.run("assert ob[1] == 'int'", None, Some(d)).unwrap(); - py.run("assert ob[100:200:1] == 'slice'", None, Some(d)) - .unwrap(); + py_assert!(py, ob, "ob[1] == 'int'"); + py_assert!(py, ob, "ob[100:200:1] == 'slice'"); } #[pyclass(dict, subclass)] diff --git a/tests/test_getter_setter.rs b/tests/test_getter_setter.rs index e3814ac7..4fe15ba7 100644 --- a/tests/test_getter_setter.rs +++ b/tests/test_getter_setter.rs @@ -60,12 +60,7 @@ fn class_with_properties() { py_run!(py, inst, "assert inst.data_list == [42]"); let d = [("C", py.get_type::())].into_py_dict(py); - py.run( - "assert C.DATA.__doc__ == 'a getter for data'", - None, - Some(d), - ) - .unwrap(); + py_assert!(py, *d, "C.DATA.__doc__ == 'a getter for data'"); } #[pyclass] diff --git a/tests/test_mapping.rs b/tests/test_mapping.rs index 2e5f42df..34e726dc 100644 --- a/tests/test_mapping.rs +++ b/tests/test_mapping.rs @@ -2,10 +2,13 @@ use std::collections::HashMap; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; +use pyo3::py_run; use pyo3::types::IntoPyDict; use pyo3::types::PyList; use pyo3::PyMappingProtocol; +mod common; + #[pyclass] struct Mapping { index: HashMap, @@ -66,35 +69,36 @@ impl PyMappingProtocol for Mapping { } } +/// Return a dict with `m = Mapping(['1', '2', '3'])`. +fn map_dict(py: Python) -> &pyo3::types::PyDict { + let d = [("Mapping", py.get_type::())].into_py_dict(py); + py_run!(py, *d, "m = Mapping(['1', '2', '3'])"); + d +} + #[test] fn test_getitem() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("Mapping", py.get_type::())].into_py_dict(py); + let d = map_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); - - run("m = Mapping(['1', '2', '3']); assert m['1'] == 0"); - run("m = Mapping(['1', '2', '3']); assert m['2'] == 1"); - run("m = Mapping(['1', '2', '3']); assert m['3'] == 2"); - err("m = Mapping(['1', '2', '3']); print(m['4'])"); + py_assert!(py, *d, "m['1'] == 0"); + py_assert!(py, *d, "m['2'] == 1"); + py_assert!(py, *d, "m['3'] == 2"); + py_expect_exception!(py, *d, "print(m['4'])", PyKeyError); } #[test] fn test_setitem() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("Mapping", py.get_type::())].into_py_dict(py); + let d = map_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); - - run("m = Mapping(['1', '2', '3']); m['1'] = 4; assert m['1'] == 4"); - run("m = Mapping(['1', '2', '3']); m['0'] = 0; assert m['0'] == 0"); - run("m = Mapping(['1', '2', '3']); len(m) == 4"); - err("m = Mapping(['1', '2', '3']); m[0] = 'hello'"); - err("m = Mapping(['1', '2', '3']); m[0] = -1"); + py_run!(py, *d, "m['1'] = 4; assert m['1'] == 4"); + py_run!(py, *d, "m['0'] = 0; assert m['0'] == 0"); + py_assert!(py, *d, "len(m) == 4"); + py_expect_exception!(py, *d, "m[0] = 'hello'", PyTypeError); + py_expect_exception!(py, *d, "m[0] = -1", PyTypeError); } #[test] @@ -102,16 +106,14 @@ fn test_delitem() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("Mapping", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); - - run( - "m = Mapping(['1', '2', '3']); del m['1']; assert len(m) == 2; \ - assert m['2'] == 1; assert m['3'] == 2", + let d = map_dict(py); + py_run!( + py, + *d, + "del m['1']; assert len(m) == 2 and m['2'] == 1 and m['3'] == 2" ); - err("m = Mapping(['1', '2', '3']); del m[-1]"); - err("m = Mapping(['1', '2', '3']); del m['4']"); + py_expect_exception!(py, *d, "del m[-1]", PyTypeError); + py_expect_exception!(py, *d, "del m['4']", PyKeyError); } #[test] @@ -119,8 +121,6 @@ fn test_reversed() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("Mapping", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - - run("m = Mapping(['1', '2']); assert set(reversed(m)) == {'1', '2'}"); + let d = map_dict(py); + py_assert!(py, *d, "set(reversed(m)) == {'1', '2', '3'}"); } diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 6df0e6d3..f395f228 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -56,10 +56,8 @@ fn instance_method_with_args() { let obj = PyCell::new(py, InstanceMethodWithArgs { member: 7 }).unwrap(); let obj_ref = obj.borrow(); assert_eq!(obj_ref.method(6), 42); - let d = [("obj", obj)].into_py_dict(py); - py.run("assert obj.method(3) == 21", None, Some(d)).unwrap(); - py.run("assert obj.method(multiplier=6) == 42", None, Some(d)) - .unwrap(); + py_assert!(py, obj, "obj.method(3) == 21"); + py_assert!(py, obj, "obj.method(multiplier=6) == 42"); } #[pyclass] @@ -85,15 +83,10 @@ fn class_method() { let py = gil.python(); let d = [("C", py.get_type::())].into_py_dict(py); - let run = |code| { - py.run(code, None, Some(d)) - .map_err(|e| e.print(py)) - .unwrap() - }; - run("assert C.method() == 'ClassMethod.method()!'"); - run("assert C().method() == 'ClassMethod.method()!'"); - run("assert C.method.__doc__ == 'Test class method.'"); - run("assert C().method.__doc__ == 'Test class method.'"); + py_assert!(py, *d, "C.method() == 'ClassMethod.method()!'"); + py_assert!(py, *d, "C().method() == 'ClassMethod.method()!'"); + py_assert!(py, *d, "C.method.__doc__ == 'Test class method.'"); + py_assert!(py, *d, "C().method.__doc__ == 'Test class method.'"); } #[pyclass] @@ -113,12 +106,11 @@ fn class_method_with_args() { let py = gil.python(); let d = [("C", py.get_type::())].into_py_dict(py); - py.run( - "assert C.method('abc') == 'ClassMethodWithArgs.method(abc)'", - None, - Some(d), - ) - .unwrap(); + py_assert!( + py, + *d, + "C.method('abc') == 'ClassMethodWithArgs.method(abc)'" + ); } #[pyclass] @@ -146,15 +138,10 @@ fn static_method() { assert_eq!(StaticMethod::method(py), "StaticMethod.method()!"); let d = [("C", py.get_type::())].into_py_dict(py); - let run = |code| { - py.run(code, None, Some(d)) - .map_err(|e| e.print(py)) - .unwrap() - }; - run("assert C.method() == 'StaticMethod.method()!'"); - run("assert C().method() == 'StaticMethod.method()!'"); - run("assert C.method.__doc__ == 'Test static method.'"); - run("assert C().method.__doc__ == 'Test static method.'"); + py_assert!(py, *d, "C.method() == 'StaticMethod.method()!'"); + py_assert!(py, *d, "C().method() == 'StaticMethod.method()!'"); + py_assert!(py, *d, "C.method.__doc__ == 'Test static method.'"); + py_assert!(py, *d, "C().method.__doc__ == 'Test static method.'"); } #[pyclass] @@ -176,8 +163,7 @@ fn static_method_with_args() { assert_eq!(StaticMethodWithArgs::method(py, 1234), "0x4d2"); let d = [("C", py.get_type::())].into_py_dict(py); - py.run("assert C.method(1337) == '0x539'", None, Some(d)) - .unwrap(); + py_assert!(py, *d, "C.method(1337) == '0x539'"); } #[pyclass] @@ -449,15 +435,17 @@ fn meth_doc() { let gil = Python::acquire_gil(); let py = gil.python(); let d = [("C", py.get_type::())].into_py_dict(py); - let run = |code| { - py.run(code, None, Some(d)) - .map_err(|e| e.print(py)) - .unwrap() - }; - - run("assert C.__doc__ == 'A class with \"documentation\".'"); - run("assert C.method.__doc__ == 'A method with \"documentation\" as well.'"); - run("assert C.x.__doc__ == '`int`: a very \"important\" member of \\'this\\' instance.'"); + py_assert!(py, *d, "C.__doc__ == 'A class with \"documentation\".'"); + py_assert!( + py, + *d, + "C.method.__doc__ == 'A method with \"documentation\" as well.'" + ); + py_assert!( + py, + *d, + "C.x.__doc__ == '`int`: a very \"important\" member of \\'this\\' instance.'" + ); } #[pyclass] @@ -530,20 +518,31 @@ fn method_with_pyclassarg() { let py = gil.python(); let obj1 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap(); let obj2 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap(); - let objs = [("obj1", obj1), ("obj2", obj2)].into_py_dict(py); - let run = |code| { - py.run(code, None, Some(objs)) - .map_err(|e| e.print(py)) - .unwrap() - }; - run("obj = obj1.add(obj2); assert obj.value == 20"); - run("obj = obj1.add_pyref(obj2); assert obj.value == 20"); - run("obj = obj1.optional_add(); assert obj.value == 20"); - run("obj = obj1.optional_add(obj2); assert obj.value == 20"); - run("obj1.inplace_add(obj2); assert obj.value == 20"); - run("obj1.inplace_add_pyref(obj2); assert obj2.value == 30"); - run("obj1.optional_inplace_add(); assert obj2.value == 30"); - run("obj1.optional_inplace_add(obj2); assert obj2.value == 40"); + let d = [("obj1", obj1), ("obj2", obj2)].into_py_dict(py); + py_run!(py, *d, "obj = obj1.add(obj2); assert obj.value == 20"); + py_run!(py, *d, "obj = obj1.add_pyref(obj2); assert obj.value == 20"); + py_run!(py, *d, "obj = obj1.optional_add(); assert obj.value == 20"); + py_run!( + py, + *d, + "obj = obj1.optional_add(obj2); assert obj.value == 20" + ); + py_run!(py, *d, "obj1.inplace_add(obj2); assert obj.value == 20"); + py_run!( + py, + *d, + "obj1.inplace_add_pyref(obj2); assert obj2.value == 30" + ); + py_run!( + py, + *d, + "obj1.optional_inplace_add(); assert obj2.value == 30" + ); + py_run!( + py, + *d, + "obj1.optional_inplace_add(obj2); assert obj2.value == 40" + ); } #[pyclass] diff --git a/tests/test_module.rs b/tests/test_module.rs index 68e01af2..20bf9c12 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -1,7 +1,7 @@ use pyo3::prelude::*; +use pyo3::py_run; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; - mod common; #[pyclass] @@ -83,25 +83,43 @@ fn test_module_with_functions() { )] .into_py_dict(py); - let run = |code| { - py.run(code, None, Some(d)) - .map_err(|e| e.print(py)) - .unwrap() - }; - - run("assert module_with_functions.__doc__ == 'This module is implemented in Rust.'"); - run("assert module_with_functions.sum_as_string(1, 2) == '3'"); - run("assert module_with_functions.no_parameters() == 42"); - run("assert module_with_functions.foo == 'bar'"); - run("assert module_with_functions.AnonClass != None"); - run("assert module_with_functions.LocatedClass != None"); - run("assert module_with_functions.LocatedClass.__module__ == 'module'"); - run("assert module_with_functions.double(3) == 6"); - run("assert module_with_functions.double.__doc__ == 'Doubles the given value'"); - run("assert module_with_functions.also_double(3) == 6"); - run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'"); - run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2"); - run("assert module_with_functions.with_module() == 'module_with_functions'"); + py_assert!( + py, + *d, + "module_with_functions.__doc__ == 'This module is implemented in Rust.'" + ); + py_assert!(py, *d, "module_with_functions.sum_as_string(1, 2) == '3'"); + py_assert!(py, *d, "module_with_functions.no_parameters() == 42"); + py_assert!(py, *d, "module_with_functions.foo == 'bar'"); + py_assert!(py, *d, "module_with_functions.AnonClass != None"); + py_assert!(py, *d, "module_with_functions.LocatedClass != None"); + py_assert!( + py, + *d, + "module_with_functions.LocatedClass.__module__ == 'module'" + ); + py_assert!(py, *d, "module_with_functions.double(3) == 6"); + py_assert!( + py, + *d, + "module_with_functions.double.__doc__ == 'Doubles the given value'" + ); + py_assert!(py, *d, "module_with_functions.also_double(3) == 6"); + py_assert!( + py, + *d, + "module_with_functions.also_double.__doc__ == 'Doubles the given value'" + ); + py_assert!( + py, + *d, + "module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2" + ); + py_assert!( + py, + *d, + "module_with_functions.with_module() == 'module_with_functions'" + ); } #[pymodule(other_name)] @@ -119,12 +137,7 @@ fn test_module_renaming() { let d = [("different_name", wrap_pymodule!(other_name)(py))].into_py_dict(py); - py.run( - "assert different_name.__name__ == 'other_name'", - None, - Some(d), - ) - .unwrap(); + py_run!(py, *d, "assert different_name.__name__ == 'other_name'"); } #[test] diff --git a/tests/test_sequence.rs b/tests/test_sequence.rs index a361c55c..0cb467ee 100644 --- a/tests/test_sequence.rs +++ b/tests/test_sequence.rs @@ -83,33 +83,35 @@ impl PySequenceProtocol for ByteSequence { } } +/// Return a dict with `s = ByteSequence([1, 2, 3])`. +fn seq_dict(py: Python) -> &pyo3::types::PyDict { + let d = [("ByteSequence", py.get_type::())].into_py_dict(py); + // Though we can construct `s` in Rust, let's test `__new__` works. + py_run!(py, *d, "s = ByteSequence([1, 2, 3])"); + d +} + #[test] fn test_getitem() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("ByteSequence", py.get_type::())].into_py_dict(py); + let d = seq_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); - - run("s = ByteSequence([1, 2, 3]); assert s[0] == 1"); - run("s = ByteSequence([1, 2, 3]); assert s[1] == 2"); - run("s = ByteSequence([1, 2, 3]); assert s[2] == 3"); - err("s = ByteSequence([1, 2, 3]); print(s[-4])"); - err("s = ByteSequence([1, 2, 3]); print(s[4])"); + py_assert!(py, *d, "s[0] == 1"); + py_assert!(py, *d, "s[1] == 2"); + py_assert!(py, *d, "s[2] == 3"); + py_expect_exception!(py, *d, "print(s[-4])", PyIndexError); + py_expect_exception!(py, *d, "print(s[4])", PyIndexError); } #[test] fn test_setitem() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("ByteSequence", py.get_type::())].into_py_dict(py); + let d = seq_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); - - run("s = ByteSequence([1, 2, 3]); s[0] = 4; assert list(s) == [4, 2, 3]"); - err("s = ByteSequence([1, 2, 3]); s[0] = 'hello'"); + py_run!(py, *d, "s[0] = 4; assert list(s) == [4, 2, 3]"); + py_expect_exception!(py, *d, "s[0] = 'hello'", PyTypeError); } #[test] @@ -118,15 +120,39 @@ fn test_delitem() { let py = gil.python(); let d = [("ByteSequence", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); - run("s = ByteSequence([1, 2, 3]); del s[0]; assert list(s) == [2, 3]"); - run("s = ByteSequence([1, 2, 3]); del s[1]; assert list(s) == [1, 3]"); - run("s = ByteSequence([1, 2, 3]); del s[-1]; assert list(s) == [1, 2]"); - run("s = ByteSequence([1, 2, 3]); del s[-2]; assert list(s) == [1, 3]"); - err("s = ByteSequence([1, 2, 3]); del s[-4]; print(list(s))"); - err("s = ByteSequence([1, 2, 3]); del s[4]"); + py_run!( + py, + *d, + "s = ByteSequence([1, 2, 3]); del s[0]; assert list(s) == [2, 3]" + ); + py_run!( + py, + *d, + "s = ByteSequence([1, 2, 3]); del s[1]; assert list(s) == [1, 3]" + ); + py_run!( + py, + *d, + "s = ByteSequence([1, 2, 3]); del s[-1]; assert list(s) == [1, 2]" + ); + py_run!( + py, + *d, + "s = ByteSequence([1, 2, 3]); del s[-2]; assert list(s) == [1, 3]" + ); + py_expect_exception!( + py, + *d, + "s = ByteSequence([1, 2, 3]); del s[-4]; print(list(s))", + PyIndexError + ); + py_expect_exception!( + py, + *d, + "s = ByteSequence([1, 2, 3]); del s[4]", + PyIndexError + ); } #[test] @@ -134,14 +160,13 @@ fn test_contains() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("ByteSequence", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); + let d = seq_dict(py); - run("s = ByteSequence([1, 2, 3]); assert 1 in s"); - run("s = ByteSequence([1, 2, 3]); assert 2 in s"); - run("s = ByteSequence([1, 2, 3]); assert 3 in s"); - run("s = ByteSequence([1, 2, 3]); assert 4 not in s"); - run("s = ByteSequence([1, 2, 3]); assert 'hello' not in s"); + py_assert!(py, *d, "1 in s"); + py_assert!(py, *d, "2 in s"); + py_assert!(py, *d, "3 in s"); + py_assert!(py, *d, "4 not in s"); + py_assert!(py, *d, "'hello' not in s"); } #[test] @@ -149,12 +174,19 @@ fn test_concat() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("ByteSequence", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); + let d = seq_dict(py); - run("s1 = ByteSequence([1, 2]); s2 = ByteSequence([3, 4]); assert list(s1+s2) == [1, 2, 3, 4]"); - err("s1 = ByteSequence([1, 2]); s2 = 'hello'; s1 + s2"); + py_run!( + py, + *d, + "s1 = ByteSequence([1, 2]); s2 = ByteSequence([3, 4]); assert list(s1 + s2) == [1, 2, 3, 4]" + ); + py_expect_exception!( + py, + *d, + "s1 = ByteSequence([1, 2]); s2 = 'hello'; s1 + s2", + PyTypeError + ); } #[test] @@ -162,12 +194,14 @@ fn test_inplace_concat() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("ByteSequence", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); + let d = seq_dict(py); - run("s = ByteSequence([1, 2]); s += ByteSequence([3, 4]); assert list(s) == [1, 2, 3, 4]"); - err("s = ByteSequence([1, 2]); s += 'hello'"); + py_run!( + py, + *d, + "s += ByteSequence([4, 5]); assert list(s) == [1, 2, 3, 4, 5]" + ); + py_expect_exception!(py, *d, "s += 'hello'", PyTypeError); } #[test] @@ -175,12 +209,10 @@ fn test_repeat() { let gil = Python::acquire_gil(); let py = gil.python(); - let d = [("ByteSequence", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); + let d = seq_dict(py); - run("s1 = ByteSequence([1, 2, 3]); s2 = s1*2; assert list(s2) == [1, 2, 3, 1, 2, 3]"); - err("s1 = ByteSequence([1, 2, 3]); s2 = s1*-1; assert list(s2) == [1, 2, 3, 1, 2, 3]"); + py_run!(py, *d, "s2 = s * 2; assert list(s2) == [1, 2, 3, 1, 2, 3]"); + py_expect_exception!(py, *d, "s2 = s * -1", PyValueError); } #[test] @@ -189,11 +221,13 @@ fn test_inplace_repeat() { let py = gil.python(); let d = [("ByteSequence", py.get_type::())].into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); - let err = |code| py.run(code, None, Some(d)).unwrap_err(); - run("s = ByteSequence([1, 2]); s *= 3; assert list(s) == [1, 2, 1, 2, 1, 2]"); - err("s = ByteSequence([1, 2); s *= -1"); + py_run!( + py, + *d, + "s = ByteSequence([1, 2]); s *= 3; assert list(s) == [1, 2, 1, 2, 1, 2]" + ); + py_expect_exception!(py, *d, "s = ByteSequence([1, 2]); s *= -1", PyValueError); } // Check that #[pyo3(get, set)] works correctly for Vec diff --git a/tests/test_various.rs b/tests/test_various.rs index 3b9722b8..04c8dc82 100644 --- a/tests/test_various.rs +++ b/tests/test_various.rs @@ -1,5 +1,4 @@ use pyo3::prelude::*; -use pyo3::types::IntoPyDict; use pyo3::types::{PyDict, PyTuple}; use pyo3::{py_run, wrap_pyfunction, PyCell}; @@ -29,9 +28,7 @@ fn mut_ref_arg() { let inst1 = Py::new(py, MutRefArg { n: 0 }).unwrap(); let inst2 = Py::new(py, MutRefArg { n: 0 }).unwrap(); - let d = [("inst1", &inst1), ("inst2", &inst2)].into_py_dict(py); - - py.run("inst1.set_other(inst2)", None, Some(d)).unwrap(); + py_run!(py, inst1 inst2, "inst1.set_other(inst2)"); let inst2 = inst2.as_ref(py).borrow(); assert_eq!(inst2.n, 100); }