From 8a60540e2580788147a5037d5425914291a2090c Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 19 Jul 2023 21:24:24 +0100 Subject: [PATCH] amend code for `PyDict::get_item` change --- pyo3-benches/benches/bench_decimal.rs | 2 +- pyo3-benches/benches/bench_dict.rs | 7 +- src/conversions/chrono.rs | 2 +- src/conversions/hashbrown.rs | 30 +++- src/conversions/indexmap.rs | 30 +++- src/conversions/rust_decimal.rs | 6 +- src/conversions/std/map.rs | 40 ++++- src/marker.rs | 6 +- src/sync.rs | 9 +- src/types/dict.rs | 239 ++++++++++++++++++++++---- src/types/traceback.rs | 4 +- tests/test_various.rs | 1 + 12 files changed, 325 insertions(+), 51 deletions(-) diff --git a/pyo3-benches/benches/bench_decimal.rs b/pyo3-benches/benches/bench_decimal.rs index 7a370ac3..c412c4b1 100644 --- a/pyo3-benches/benches/bench_decimal.rs +++ b/pyo3-benches/benches/bench_decimal.rs @@ -16,7 +16,7 @@ py_dec = decimal.Decimal("0.0") Some(locals), ) .unwrap(); - let py_dec = locals.get_item("py_dec").unwrap(); + let py_dec = locals.get_item("py_dec").unwrap().unwrap(); b.iter(|| { let _: Decimal = black_box(py_dec).extract().unwrap(); diff --git a/pyo3-benches/benches/bench_dict.rs b/pyo3-benches/benches/bench_dict.rs index 64398a65..62fd8820 100644 --- a/pyo3-benches/benches/bench_dict.rs +++ b/pyo3-benches/benches/bench_dict.rs @@ -32,7 +32,12 @@ fn dict_get_item(b: &mut Bencher<'_>) { let mut sum = 0; b.iter(|| { for i in 0..LEN { - sum += dict.get_item(i).unwrap().extract::().unwrap(); + sum += dict + .get_item(i) + .unwrap() + .unwrap() + .extract::() + .unwrap(); } }); }); diff --git a/src/conversions/chrono.rs b/src/conversions/chrono.rs index 041fa8d8..1554898e 100644 --- a/src/conversions/chrono.rs +++ b/src/conversions/chrono.rs @@ -390,7 +390,7 @@ mod tests { Some(locals), ) .unwrap(); - let result: PyResult = locals.get_item("zi").unwrap().extract(); + let result: PyResult = locals.get_item("zi").unwrap().unwrap().extract(); assert!(result.is_err()); let res = result.err().unwrap(); // Also check the error message is what we expect diff --git a/src/conversions/hashbrown.rs b/src/conversions/hashbrown.rs index d80e93d5..6e20db39 100644 --- a/src/conversions/hashbrown.rs +++ b/src/conversions/hashbrown.rs @@ -112,7 +112,15 @@ mod tests { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); assert_eq!(map, py_map.extract().unwrap()); }); } @@ -126,7 +134,15 @@ mod tests { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); }); } @@ -139,7 +155,15 @@ mod tests { let py_map = map.into_py_dict(py); assert_eq!(py_map.len(), 1); - assert_eq!(py_map.get_item(1).unwrap().extract::().unwrap(), 1); + assert_eq!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); }); } diff --git a/src/conversions/indexmap.rs b/src/conversions/indexmap.rs index 27324cbb..7c7303e6 100644 --- a/src/conversions/indexmap.rs +++ b/src/conversions/indexmap.rs @@ -148,7 +148,15 @@ mod test_indexmap { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); assert_eq!( map, py_map.extract::>().unwrap() @@ -166,7 +174,15 @@ mod test_indexmap { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); }); } @@ -179,7 +195,15 @@ mod test_indexmap { let py_map = map.into_py_dict(py); assert_eq!(py_map.len(), 1); - assert_eq!(py_map.get_item(1).unwrap().extract::().unwrap(), 1); + assert_eq!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); }); } diff --git a/src/conversions/rust_decimal.rs b/src/conversions/rust_decimal.rs index 9ee9c0ef..cdc8fd05 100644 --- a/src/conversions/rust_decimal.rs +++ b/src/conversions/rust_decimal.rs @@ -130,7 +130,7 @@ mod test_rust_decimal { ) .unwrap(); // Checks if Python Decimal -> Rust Decimal conversion is correct - let py_dec = locals.get_item("py_dec").unwrap(); + let py_dec = locals.get_item("py_dec").unwrap().unwrap(); let py_result: Decimal = FromPyObject::extract(py_dec).unwrap(); assert_eq!(rs_orig, py_result); }) @@ -192,7 +192,7 @@ mod test_rust_decimal { Some(locals), ) .unwrap(); - let py_dec = locals.get_item("py_dec").unwrap(); + let py_dec = locals.get_item("py_dec").unwrap().unwrap(); let roundtripped: Result = FromPyObject::extract(py_dec); assert!(roundtripped.is_err()); }) @@ -208,7 +208,7 @@ mod test_rust_decimal { Some(locals), ) .unwrap(); - let py_dec = locals.get_item("py_dec").unwrap(); + let py_dec = locals.get_item("py_dec").unwrap().unwrap(); let roundtripped: Result = FromPyObject::extract(py_dec); assert!(roundtripped.is_err()); }) diff --git a/src/conversions/std/map.rs b/src/conversions/std/map.rs index f79b415b..a53b0ce9 100644 --- a/src/conversions/std/map.rs +++ b/src/conversions/std/map.rs @@ -122,7 +122,15 @@ mod tests { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); assert_eq!(map, py_map.extract().unwrap()); }); } @@ -137,7 +145,15 @@ mod tests { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); assert_eq!(map, py_map.extract().unwrap()); }); } @@ -152,7 +168,15 @@ mod tests { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); }); } @@ -166,7 +190,15 @@ mod tests { let py_map: &PyDict = m.downcast(py).unwrap(); assert!(py_map.len() == 1); - assert!(py_map.get_item(1).unwrap().extract::().unwrap() == 1); + assert!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap() + == 1 + ); }); } } diff --git a/src/marker.rs b/src/marker.rs index 2d130ae4..55b48119 100644 --- a/src/marker.rs +++ b/src/marker.rs @@ -591,7 +591,7 @@ impl<'py> Python<'py> { /// Some(locals), /// ) /// .unwrap(); - /// let ret = locals.get_item("ret").unwrap(); + /// let ret = locals.get_item("ret").unwrap().unwrap(); /// let b64: &PyBytes = ret.downcast().unwrap(); /// assert_eq!(b64.as_bytes(), b"SGVsbG8gUnVzdCE="); /// }); @@ -1201,8 +1201,8 @@ mod tests { let namespace = PyDict::new(py); py.run("class Foo: pass", Some(namespace), Some(namespace)) .unwrap(); - assert!(namespace.get_item("Foo").is_some()); - assert!(namespace.get_item("__builtins__").is_some()); + assert!(matches!(namespace.get_item("Foo"), Ok(Some(..)))); + assert!(matches!(namespace.get_item("__builtins__"), Ok(Some(..)))); }) } } diff --git a/src/sync.rs b/src/sync.rs index 3cb4206d..50bb80da 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -264,7 +264,14 @@ mod tests { let dict = PyDict::new(py); dict.set_item(foo1, 42_usize).unwrap(); assert!(dict.contains(foo2).unwrap()); - assert_eq!(dict.get_item(foo3).unwrap().extract::().unwrap(), 42); + assert_eq!( + dict.get_item(foo3) + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 42 + ); }); } diff --git a/src/types/dict.rs b/src/types/dict.rs index 8664b450..79120299 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -447,8 +447,15 @@ mod tests { fn test_new() { Python::with_gil(|py| { let dict = [(7, 32)].into_py_dict(py); - assert_eq!(32, dict.get_item(7i32).unwrap().extract::().unwrap()); - assert!(dict.get_item(8i32).is_none()); + assert_eq!( + 32, + dict.get_item(7i32) + .unwrap() + .unwrap() + .extract::() + .unwrap() + ); + assert!(dict.get_item(8i32).unwrap().is_none()); let map: HashMap = [(7, 32)].iter().cloned().collect(); assert_eq!(map, dict.extract().unwrap()); let map: BTreeMap = [(7, 32)].iter().cloned().collect(); @@ -462,8 +469,22 @@ mod tests { Python::with_gil(|py| { let items = PyList::new(py, &vec![("a", 1), ("b", 2)]); let dict = PyDict::from_sequence(py, items.to_object(py)).unwrap(); - assert_eq!(1, dict.get_item("a").unwrap().extract::().unwrap()); - assert_eq!(2, dict.get_item("b").unwrap().extract::().unwrap()); + assert_eq!( + 1, + dict.get_item("a") + .unwrap() + .unwrap() + .extract::() + .unwrap() + ); + assert_eq!( + 2, + dict.get_item("b") + .unwrap() + .unwrap() + .extract::() + .unwrap() + ); let map: HashMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect(); assert_eq!(map, dict.extract().unwrap()); let map: BTreeMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect(); @@ -486,8 +507,16 @@ mod tests { let dict = [(7, 32)].into_py_dict(py); let ndict = dict.copy().unwrap(); - assert_eq!(32, ndict.get_item(7i32).unwrap().extract::().unwrap()); - assert!(ndict.get_item(8i32).is_none()); + assert_eq!( + 32, + ndict + .get_item(7i32) + .unwrap() + .unwrap() + .extract::() + .unwrap() + ); + assert!(ndict.get_item(8i32).unwrap().is_none()); }); } @@ -524,12 +553,20 @@ mod tests { v.insert(7, 32); let ob = v.to_object(py); let dict: &PyDict = ob.downcast(py).unwrap(); - assert_eq!(32, dict.get_item(7i32).unwrap().extract::().unwrap()); - assert!(dict.get_item(8i32).is_none()); + assert_eq!( + 32, + dict.get_item(7i32) + .unwrap() + .unwrap() + .extract::() + .unwrap() + ); + assert!(dict.get_item(8i32).unwrap().is_none()); }); } #[test] + #[allow(deprecated)] #[cfg(not(PyPy))] fn test_get_item_with_error() { Python::with_gil(|py| { @@ -564,11 +601,19 @@ mod tests { assert!(dict.set_item(8i32, 123i32).is_ok()); // insert assert_eq!( 42i32, - dict.get_item(7i32).unwrap().extract::().unwrap() + dict.get_item(7i32) + .unwrap() + .unwrap() + .extract::() + .unwrap() ); assert_eq!( 123i32, - dict.get_item(8i32).unwrap().extract::().unwrap() + dict.get_item(8i32) + .unwrap() + .unwrap() + .extract::() + .unwrap() ); }); } @@ -612,7 +657,7 @@ mod tests { let dict: &PyDict = ob.downcast(py).unwrap(); assert!(dict.del_item(7i32).is_ok()); assert_eq!(0, dict.len()); - assert!(dict.get_item(7i32).is_none()); + assert!(dict.get_item(7i32).unwrap().is_none()); }); } @@ -829,7 +874,15 @@ mod tests { let py_map = map.into_py_dict(py); assert_eq!(py_map.len(), 1); - assert_eq!(py_map.get_item(1).unwrap().extract::().unwrap(), 1); + assert_eq!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); }); } @@ -842,7 +895,15 @@ mod tests { let py_map = map.into_py_dict(py); assert_eq!(py_map.len(), 1); - assert_eq!(py_map.get_item(1).unwrap().extract::().unwrap(), 1); + assert_eq!( + py_map + .get_item(1) + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); }); } @@ -853,7 +914,15 @@ mod tests { let py_map = vec.into_py_dict(py); assert_eq!(py_map.len(), 3); - assert_eq!(py_map.get_item("b").unwrap().extract::().unwrap(), 2); + assert_eq!( + py_map + .get_item("b") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 2 + ); }); } @@ -864,7 +933,15 @@ mod tests { let py_map = arr.into_py_dict(py); assert_eq!(py_map.len(), 3); - assert_eq!(py_map.get_item("b").unwrap().extract::().unwrap(), 2); + assert_eq!( + py_map + .get_item("b") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 2 + ); }); } @@ -935,15 +1012,67 @@ mod tests { let other = [("b", 4), ("c", 5), ("d", 6)].into_py_dict(py); dict.update(other.as_mapping()).unwrap(); assert_eq!(dict.len(), 4); - assert_eq!(dict.get_item("a").unwrap().extract::().unwrap(), 1); - assert_eq!(dict.get_item("b").unwrap().extract::().unwrap(), 4); - assert_eq!(dict.get_item("c").unwrap().extract::().unwrap(), 5); - assert_eq!(dict.get_item("d").unwrap().extract::().unwrap(), 6); + assert_eq!( + dict.get_item("a") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); + assert_eq!( + dict.get_item("b") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 4 + ); + assert_eq!( + dict.get_item("c") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 5 + ); + assert_eq!( + dict.get_item("d") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 6 + ); assert_eq!(other.len(), 3); - assert_eq!(other.get_item("b").unwrap().extract::().unwrap(), 4); - assert_eq!(other.get_item("c").unwrap().extract::().unwrap(), 5); - assert_eq!(other.get_item("d").unwrap().extract::().unwrap(), 6); + assert_eq!( + other + .get_item("b") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 4 + ); + assert_eq!( + other + .get_item("c") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 5 + ); + assert_eq!( + other + .get_item("d") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 6 + ); }) } @@ -954,15 +1083,67 @@ mod tests { let other = [("b", 4), ("c", 5), ("d", 6)].into_py_dict(py); dict.update_if_missing(other.as_mapping()).unwrap(); assert_eq!(dict.len(), 4); - assert_eq!(dict.get_item("a").unwrap().extract::().unwrap(), 1); - assert_eq!(dict.get_item("b").unwrap().extract::().unwrap(), 2); - assert_eq!(dict.get_item("c").unwrap().extract::().unwrap(), 3); - assert_eq!(dict.get_item("d").unwrap().extract::().unwrap(), 6); + assert_eq!( + dict.get_item("a") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 1 + ); + assert_eq!( + dict.get_item("b") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 2 + ); + assert_eq!( + dict.get_item("c") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 3 + ); + assert_eq!( + dict.get_item("d") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 6 + ); assert_eq!(other.len(), 3); - assert_eq!(other.get_item("b").unwrap().extract::().unwrap(), 4); - assert_eq!(other.get_item("c").unwrap().extract::().unwrap(), 5); - assert_eq!(other.get_item("d").unwrap().extract::().unwrap(), 6); + assert_eq!( + other + .get_item("b") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 4 + ); + assert_eq!( + other + .get_item("c") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 5 + ); + assert_eq!( + other + .get_item("d") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + 6 + ); }) } } diff --git a/src/types/traceback.rs b/src/types/traceback.rs index 67e0e564..b9909435 100644 --- a/src/types/traceback.rs +++ b/src/types/traceback.rs @@ -97,7 +97,7 @@ except Exception as e: Some(locals), ) .unwrap(); - let err = PyErr::from_value(locals.get_item("err").unwrap()); + let err = PyErr::from_value(locals.get_item("err").unwrap().unwrap()); let traceback = err.value(py).getattr("__traceback__").unwrap(); assert!(err.traceback(py).unwrap().is(traceback)); }) @@ -117,7 +117,7 @@ def f(): Some(locals), ) .unwrap(); - let f = locals.get_item("f").unwrap(); + let f = locals.get_item("f").unwrap().unwrap(); let err = f.call0().unwrap_err(); let traceback = err.traceback(py).unwrap(); let err_object = err.clone_ref(py).into_py(py).into_ref(py); diff --git a/tests/test_various.rs b/tests/test_various.rs index 706c9b8d..976e61fc 100644 --- a/tests/test_various.rs +++ b/tests/test_various.rs @@ -137,6 +137,7 @@ fn add_module(py: Python<'_>, module: &PyModule) -> PyResult<()> { .dict() .get_item("modules") .unwrap() + .unwrap() .downcast::()? .set_item(module.name()?, module) }