amend code for `PyDict::get_item` change

This commit is contained in:
David Hewitt 2023-07-19 21:24:24 +01:00
parent 16728c4da2
commit 8a60540e25
12 changed files with 325 additions and 51 deletions

View File

@ -16,7 +16,7 @@ py_dec = decimal.Decimal("0.0")
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
b.iter(|| {
let _: Decimal = black_box(py_dec).extract().unwrap();

View File

@ -32,7 +32,12 @@ fn dict_get_item(b: &mut Bencher<'_>) {
let mut sum = 0;
b.iter(|| {
for i in 0..LEN {
sum += dict.get_item(i).unwrap().extract::<usize>().unwrap();
sum += dict
.get_item(i)
.unwrap()
.unwrap()
.extract::<usize>()
.unwrap();
}
});
});

View File

@ -390,7 +390,7 @@ mod tests {
Some(locals),
)
.unwrap();
let result: PyResult<FixedOffset> = locals.get_item("zi").unwrap().extract();
let result: PyResult<FixedOffset> = locals.get_item("zi").unwrap().unwrap().extract();
assert!(result.is_err());
let res = result.err().unwrap();
// Also check the error message is what we expect

View File

@ -112,7 +112,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(map, py_map.extract().unwrap());
});
}
@ -126,7 +134,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
@ -139,7 +155,15 @@ mod tests {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}

View File

@ -148,7 +148,15 @@ mod test_indexmap {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(
map,
py_map.extract::<indexmap::IndexMap::<i32, i32>>().unwrap()
@ -166,7 +174,15 @@ mod test_indexmap {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
@ -179,7 +195,15 @@ mod test_indexmap {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}

View File

@ -130,7 +130,7 @@ mod test_rust_decimal {
)
.unwrap();
// Checks if Python Decimal -> Rust Decimal conversion is correct
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
let py_result: Decimal = FromPyObject::extract(py_dec).unwrap();
assert_eq!(rs_orig, py_result);
})
@ -192,7 +192,7 @@ mod test_rust_decimal {
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})
@ -208,7 +208,7 @@ mod test_rust_decimal {
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})

View File

@ -122,7 +122,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(map, py_map.extract().unwrap());
});
}
@ -137,7 +145,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(map, py_map.extract().unwrap());
});
}
@ -152,7 +168,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
@ -166,7 +190,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
}

View File

@ -591,7 +591,7 @@ impl<'py> Python<'py> {
/// Some(locals),
/// )
/// .unwrap();
/// let ret = locals.get_item("ret").unwrap();
/// let ret = locals.get_item("ret").unwrap().unwrap();
/// let b64: &PyBytes = ret.downcast().unwrap();
/// assert_eq!(b64.as_bytes(), b"SGVsbG8gUnVzdCE=");
/// });
@ -1201,8 +1201,8 @@ mod tests {
let namespace = PyDict::new(py);
py.run("class Foo: pass", Some(namespace), Some(namespace))
.unwrap();
assert!(namespace.get_item("Foo").is_some());
assert!(namespace.get_item("__builtins__").is_some());
assert!(matches!(namespace.get_item("Foo"), Ok(Some(..))));
assert!(matches!(namespace.get_item("__builtins__"), Ok(Some(..))));
})
}
}

View File

@ -264,7 +264,14 @@ mod tests {
let dict = PyDict::new(py);
dict.set_item(foo1, 42_usize).unwrap();
assert!(dict.contains(foo2).unwrap());
assert_eq!(dict.get_item(foo3).unwrap().extract::<usize>().unwrap(), 42);
assert_eq!(
dict.get_item(foo3)
.unwrap()
.unwrap()
.extract::<usize>()
.unwrap(),
42
);
});
}

View File

@ -447,8 +447,15 @@ mod tests {
fn test_new() {
Python::with_gil(|py| {
let dict = [(7, 32)].into_py_dict(py);
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert!(dict.get_item(8i32).is_none());
assert_eq!(
32,
dict.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert!(dict.get_item(8i32).unwrap().is_none());
let map: HashMap<i32, i32> = [(7, 32)].iter().cloned().collect();
assert_eq!(map, dict.extract().unwrap());
let map: BTreeMap<i32, i32> = [(7, 32)].iter().cloned().collect();
@ -462,8 +469,22 @@ mod tests {
Python::with_gil(|py| {
let items = PyList::new(py, &vec![("a", 1), ("b", 2)]);
let dict = PyDict::from_sequence(py, items.to_object(py)).unwrap();
assert_eq!(1, dict.get_item("a").unwrap().extract::<i32>().unwrap());
assert_eq!(2, dict.get_item("b").unwrap().extract::<i32>().unwrap());
assert_eq!(
1,
dict.get_item("a")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert_eq!(
2,
dict.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
let map: HashMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect();
assert_eq!(map, dict.extract().unwrap());
let map: BTreeMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect();
@ -486,8 +507,16 @@ mod tests {
let dict = [(7, 32)].into_py_dict(py);
let ndict = dict.copy().unwrap();
assert_eq!(32, ndict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert!(ndict.get_item(8i32).is_none());
assert_eq!(
32,
ndict
.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert!(ndict.get_item(8i32).unwrap().is_none());
});
}
@ -524,12 +553,20 @@ mod tests {
v.insert(7, 32);
let ob = v.to_object(py);
let dict: &PyDict = ob.downcast(py).unwrap();
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert!(dict.get_item(8i32).is_none());
assert_eq!(
32,
dict.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert!(dict.get_item(8i32).unwrap().is_none());
});
}
#[test]
#[allow(deprecated)]
#[cfg(not(PyPy))]
fn test_get_item_with_error() {
Python::with_gil(|py| {
@ -564,11 +601,19 @@ mod tests {
assert!(dict.set_item(8i32, 123i32).is_ok()); // insert
assert_eq!(
42i32,
dict.get_item(7i32).unwrap().extract::<i32>().unwrap()
dict.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert_eq!(
123i32,
dict.get_item(8i32).unwrap().extract::<i32>().unwrap()
dict.get_item(8i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
});
}
@ -612,7 +657,7 @@ mod tests {
let dict: &PyDict = ob.downcast(py).unwrap();
assert!(dict.del_item(7i32).is_ok());
assert_eq!(0, dict.len());
assert!(dict.get_item(7i32).is_none());
assert!(dict.get_item(7i32).unwrap().is_none());
});
}
@ -829,7 +874,15 @@ mod tests {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}
@ -842,7 +895,15 @@ mod tests {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}
@ -853,7 +914,15 @@ mod tests {
let py_map = vec.into_py_dict(py);
assert_eq!(py_map.len(), 3);
assert_eq!(py_map.get_item("b").unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(
py_map
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
2
);
});
}
@ -864,7 +933,15 @@ mod tests {
let py_map = arr.into_py_dict(py);
assert_eq!(py_map.len(), 3);
assert_eq!(py_map.get_item("b").unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(
py_map
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
2
);
});
}
@ -935,15 +1012,67 @@ mod tests {
let other = [("b", 4), ("c", 5), ("d", 6)].into_py_dict(py);
dict.update(other.as_mapping()).unwrap();
assert_eq!(dict.len(), 4);
assert_eq!(dict.get_item("a").unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(dict.get_item("b").unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(dict.get_item("c").unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(dict.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
dict.get_item("a")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
assert_eq!(
dict.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
4
);
assert_eq!(
dict.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
5
);
assert_eq!(
dict.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
assert_eq!(other.len(), 3);
assert_eq!(other.get_item("b").unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(other.get_item("c").unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(other.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
other
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
4
);
assert_eq!(
other
.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
5
);
assert_eq!(
other
.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
})
}
@ -954,15 +1083,67 @@ mod tests {
let other = [("b", 4), ("c", 5), ("d", 6)].into_py_dict(py);
dict.update_if_missing(other.as_mapping()).unwrap();
assert_eq!(dict.len(), 4);
assert_eq!(dict.get_item("a").unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(dict.get_item("b").unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(dict.get_item("c").unwrap().extract::<i32>().unwrap(), 3);
assert_eq!(dict.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
dict.get_item("a")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
assert_eq!(
dict.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
2
);
assert_eq!(
dict.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
3
);
assert_eq!(
dict.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
assert_eq!(other.len(), 3);
assert_eq!(other.get_item("b").unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(other.get_item("c").unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(other.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
other
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
4
);
assert_eq!(
other
.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
5
);
assert_eq!(
other
.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
})
}
}

View File

@ -97,7 +97,7 @@ except Exception as e:
Some(locals),
)
.unwrap();
let err = PyErr::from_value(locals.get_item("err").unwrap());
let err = PyErr::from_value(locals.get_item("err").unwrap().unwrap());
let traceback = err.value(py).getattr("__traceback__").unwrap();
assert!(err.traceback(py).unwrap().is(traceback));
})
@ -117,7 +117,7 @@ def f():
Some(locals),
)
.unwrap();
let f = locals.get_item("f").unwrap();
let f = locals.get_item("f").unwrap().unwrap();
let err = f.call0().unwrap_err();
let traceback = err.traceback(py).unwrap();
let err_object = err.clone_ref(py).into_py(py).into_ref(py);

View File

@ -137,6 +137,7 @@ fn add_module(py: Python<'_>, module: &PyModule) -> PyResult<()> {
.dict()
.get_item("modules")
.unwrap()
.unwrap()
.downcast::<PyDict>()?
.set_item(module.name()?, module)
}