Merge pull request #3330 from davidhewitt/get-item-with-error

move PyDict::get_item_with_error to PyDict::get_item
This commit is contained in:
David Hewitt 2023-09-10 20:40:20 +00:00 committed by GitHub
commit 0ab00c7442
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 427 additions and 77 deletions

View File

@ -9,6 +9,56 @@ For a detailed list of all changes, see the [CHANGELOG](changelog.md).
PyO3 0.20 has increased minimum Rust version to 1.56. This enables use of newer language features and simplifies maintenance of the project.
### `PyDict::get_item` now returns a `Result`
`PyDict::get_item` in PyO3 0.19 and older was implemented using a Python API which would suppress all exceptions and return `None` in those cases. This included errors in `__hash__` and `__eq__` implementations of the key being looked up.
Newer recommendations by the Python core developers advise against using these APIs which suppress exceptions, instead allowing exceptions to bubble upwards. `PyDict::get_item_with_error` already implemented this recommended behavior, so that API has been renamed to `PyDict::get_item`.
Before:
```rust,ignore
use pyo3::prelude::*;
use pyo3::exceptions::PyTypeError;
use pyo3::types::{PyDict, IntoPyDict};
# fn main() {
# let _ =
Python::with_gil(|py| {
let dict: &PyDict = [("a", 1)].into_py_dict(py);
// `a` is in the dictionary, with value 1
assert!(dict.get_item("a").map_or(Ok(false), |x| x.eq(1))?);
// `b` is not in the dictionary
assert!(dict.get_item("b").is_none());
// `dict` is not hashable, so this fails with a `TypeError`
assert!(dict.get_item_with_error(dict).unwrap_err().is_instance_of::<PyTypeError>(py));
});
# }
```
After:
```rust
use pyo3::prelude::*;
use pyo3::exceptions::PyTypeError;
use pyo3::types::{PyDict, IntoPyDict};
# fn main() {
# let _ =
Python::with_gil(|py| -> PyResult<()> {
let dict: &PyDict = [("a", 1)].into_py_dict(py);
// `a` is in the dictionary, with value 1
assert!(dict.get_item("a")?.map_or(Ok(false), |x| x.eq(1))?);
// `b` is not in the dictionary
assert!(dict.get_item("b")?.is_none());
// `dict` is not hashable, so this fails with a `TypeError`
assert!(dict.get_item(dict).unwrap_err().is_instance_of::<PyTypeError>(py));
Ok(())
});
# }
```
### Required arguments are no longer accepted after optional arguments
[Trailing `Option<T>` arguments](./function/signature.md#trailing-optional-arguments) have an automatic default of `None`. To avoid unwanted changes when modifying function signatures, in PyO3 0.18 it was deprecated to have a required argument after an `Option<T>` argument without using `#[pyo3(signature = (...))]` to specify the intended defaults. In PyO3 0.20, this becomes a hard error.

View File

@ -0,0 +1 @@
Change `PyDict::get_item` to no longer suppress arbitrary exceptions (the return type is now `PyResult<Option<&PyAny>>` instead of `Option<&PyAny>`), and deprecate `PyDict::get_item_with_error`.

View File

@ -16,7 +16,7 @@ py_dec = decimal.Decimal("0.0")
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
b.iter(|| {
let _: Decimal = black_box(py_dec).extract().unwrap();

View File

@ -32,7 +32,12 @@ fn dict_get_item(b: &mut Bencher<'_>) {
let mut sum = 0;
b.iter(|| {
for i in 0..LEN {
sum += dict.get_item(i).unwrap().extract::<usize>().unwrap();
sum += dict
.get_item(i)
.unwrap()
.unwrap()
.extract::<usize>()
.unwrap();
}
});
});

View File

@ -390,7 +390,7 @@ mod tests {
Some(locals),
)
.unwrap();
let result: PyResult<FixedOffset> = locals.get_item("zi").unwrap().extract();
let result: PyResult<FixedOffset> = locals.get_item("zi").unwrap().unwrap().extract();
assert!(result.is_err());
let res = result.err().unwrap();
// Also check the error message is what we expect

View File

@ -112,7 +112,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(map, py_map.extract().unwrap());
});
}
@ -126,7 +134,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
@ -139,7 +155,15 @@ mod tests {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}

View File

@ -148,7 +148,15 @@ mod test_indexmap {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(
map,
py_map.extract::<indexmap::IndexMap::<i32, i32>>().unwrap()
@ -166,7 +174,15 @@ mod test_indexmap {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
@ -179,7 +195,15 @@ mod test_indexmap {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}

View File

@ -130,7 +130,7 @@ mod test_rust_decimal {
)
.unwrap();
// Checks if Python Decimal -> Rust Decimal conversion is correct
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
let py_result: Decimal = FromPyObject::extract(py_dec).unwrap();
assert_eq!(rs_orig, py_result);
})
@ -192,7 +192,7 @@ mod test_rust_decimal {
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})
@ -208,7 +208,7 @@ mod test_rust_decimal {
Some(locals),
)
.unwrap();
let py_dec = locals.get_item("py_dec").unwrap();
let py_dec = locals.get_item("py_dec").unwrap().unwrap();
let roundtripped: Result<Decimal, PyErr> = FromPyObject::extract(py_dec);
assert!(roundtripped.is_err());
})

View File

@ -122,7 +122,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(map, py_map.extract().unwrap());
});
}
@ -137,7 +145,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
assert_eq!(map, py_map.extract().unwrap());
});
}
@ -152,7 +168,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
@ -166,7 +190,15 @@ mod tests {
let py_map: &PyDict = m.downcast(py).unwrap();
assert!(py_map.len() == 1);
assert!(py_map.get_item(1).unwrap().extract::<i32>().unwrap() == 1);
assert!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
== 1
);
});
}
}

View File

@ -591,7 +591,7 @@ impl<'py> Python<'py> {
/// Some(locals),
/// )
/// .unwrap();
/// let ret = locals.get_item("ret").unwrap();
/// let ret = locals.get_item("ret").unwrap().unwrap();
/// let b64: &PyBytes = ret.downcast().unwrap();
/// assert_eq!(b64.as_bytes(), b"SGVsbG8gUnVzdCE=");
/// });
@ -1201,8 +1201,8 @@ mod tests {
let namespace = PyDict::new(py);
py.run("class Foo: pass", Some(namespace), Some(namespace))
.unwrap();
assert!(namespace.get_item("Foo").is_some());
assert!(namespace.get_item("__builtins__").is_some());
assert!(matches!(namespace.get_item("Foo"), Ok(Some(..))));
assert!(matches!(namespace.get_item("__builtins__"), Ok(Some(..))));
})
}
}

View File

@ -264,7 +264,14 @@ mod tests {
let dict = PyDict::new(py);
dict.set_item(foo1, 42_usize).unwrap();
assert!(dict.contains(foo2).unwrap());
assert_eq!(dict.get_item(foo3).unwrap().extract::<usize>().unwrap(), 42);
assert_eq!(
dict.get_item(foo3)
.unwrap()
.unwrap()
.extract::<usize>()
.unwrap(),
42
);
});
}

View File

@ -132,35 +132,47 @@ impl PyDict {
/// Gets an item from the dictionary.
///
/// Returns `None` if the item is not present, or if an error occurs.
/// Returns `Ok(None)` if the item is not present. To get a `KeyError` for
/// non-existing keys, use [`PyAny::get_item`].
///
/// To get a `KeyError` for non-existing keys, use `PyAny::get_item`.
pub fn get_item<K>(&self, key: K) -> Option<&PyAny>
where
K: ToPyObject,
{
fn inner(dict: &PyDict, key: PyObject) -> Option<&PyAny> {
let py = dict.py();
// PyDict_GetItem returns a borrowed ptr, must make it owned for safety (see #890).
// PyObject::from_borrowed_ptr_or_opt will take ownership in this way.
unsafe {
PyObject::from_borrowed_ptr_or_opt(
py,
ffi::PyDict_GetItem(dict.as_ptr(), key.as_ptr()),
)
}
.map(|pyobject| pyobject.into_ref(py))
}
inner(self, key.to_object(self.py()))
}
/// Gets an item from the dictionary,
/// Returns `Err(PyErr)` if Python magic methods `__hash__` or `__eq__` used in dictionary
/// lookup raise an exception, for example if the key `K` is not hashable. Usually it is
/// best to bubble this error up to the caller using the `?` operator.
///
/// returns `Ok(None)` if item is not present, or `Err(PyErr)` if an error occurs.
/// # Examples
///
/// To get a `KeyError` for non-existing keys, use `PyAny::get_item_with_error`.
pub fn get_item_with_error<K>(&self, key: K) -> PyResult<Option<&PyAny>>
/// The following example calls `get_item` for the dictionary `{"a": 1}` with various
/// keys.
/// - `get_item("a")` returns `Ok(Some(...))`, with the `PyAny` being a reference to the Python
/// int `1`.
/// - `get_item("b")` returns `Ok(None)`, because "b" is not in the dictionary.
/// - `get_item(dict)` returns an `Err(PyErr)`. The error will be a `TypeError` because a dict is not
/// hashable.
///
/// ```rust
/// use pyo3::prelude::*;
/// use pyo3::types::{PyDict, IntoPyDict};
/// use pyo3::exceptions::{PyTypeError, PyKeyError};
///
/// # fn main() {
/// # let _ =
/// Python::with_gil(|py| -> PyResult<()> {
/// let dict: &PyDict = [("a", 1)].into_py_dict(py);
/// // `a` is in the dictionary, with value 1
/// assert!(dict.get_item("a")?.map_or(Ok(false), |x| x.eq(1))?);
/// // `b` is not in the dictionary
/// assert!(dict.get_item("b")?.is_none());
/// // `dict` is not hashable, so this returns an error
/// assert!(dict.get_item(dict).unwrap_err().is_instance_of::<PyTypeError>(py));
///
/// // `PyAny::get_item("b")` will raise a `KeyError` instead of returning `None`
/// let any: &PyAny = dict.as_ref();
/// assert!(any.get_item("b").unwrap_err().is_instance_of::<PyKeyError>(py));
/// Ok(())
/// });
/// # }
/// ```
pub fn get_item<K>(&self, key: K) -> PyResult<Option<&PyAny>>
where
K: ToPyObject,
{
@ -182,6 +194,19 @@ impl PyDict {
inner(self, key.to_object(self.py()))
}
/// Deprecated version of `get_item`.
#[deprecated(
since = "0.20.0",
note = "this is now equivalent to `PyDict::get_item`"
)]
#[inline]
pub fn get_item_with_error<K>(&self, key: K) -> PyResult<Option<&PyAny>>
where
K: ToPyObject,
{
self.get_item(key)
}
/// Sets an item value.
///
/// This is equivalent to the Python statement `self[key] = value`.
@ -459,8 +484,15 @@ mod tests {
fn test_new() {
Python::with_gil(|py| {
let dict = [(7, 32)].into_py_dict(py);
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert!(dict.get_item(8i32).is_none());
assert_eq!(
32,
dict.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert!(dict.get_item(8i32).unwrap().is_none());
let map: HashMap<i32, i32> = [(7, 32)].iter().cloned().collect();
assert_eq!(map, dict.extract().unwrap());
let map: BTreeMap<i32, i32> = [(7, 32)].iter().cloned().collect();
@ -474,8 +506,22 @@ mod tests {
Python::with_gil(|py| {
let items = PyList::new(py, &vec![("a", 1), ("b", 2)]);
let dict = PyDict::from_sequence(py, items.to_object(py)).unwrap();
assert_eq!(1, dict.get_item("a").unwrap().extract::<i32>().unwrap());
assert_eq!(2, dict.get_item("b").unwrap().extract::<i32>().unwrap());
assert_eq!(
1,
dict.get_item("a")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert_eq!(
2,
dict.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
let map: HashMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect();
assert_eq!(map, dict.extract().unwrap());
let map: BTreeMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect();
@ -498,8 +544,16 @@ mod tests {
let dict = [(7, 32)].into_py_dict(py);
let ndict = dict.copy().unwrap();
assert_eq!(32, ndict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert!(ndict.get_item(8i32).is_none());
assert_eq!(
32,
ndict
.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert!(ndict.get_item(8i32).unwrap().is_none());
});
}
@ -536,12 +590,20 @@ mod tests {
v.insert(7, 32);
let ob = v.to_object(py);
let dict: &PyDict = ob.downcast(py).unwrap();
assert_eq!(32, dict.get_item(7i32).unwrap().extract::<i32>().unwrap());
assert!(dict.get_item(8i32).is_none());
assert_eq!(
32,
dict.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert!(dict.get_item(8i32).unwrap().is_none());
});
}
#[test]
#[allow(deprecated)]
#[cfg(not(PyPy))]
fn test_get_item_with_error() {
Python::with_gil(|py| {
@ -576,11 +638,19 @@ mod tests {
assert!(dict.set_item(8i32, 123i32).is_ok()); // insert
assert_eq!(
42i32,
dict.get_item(7i32).unwrap().extract::<i32>().unwrap()
dict.get_item(7i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
assert_eq!(
123i32,
dict.get_item(8i32).unwrap().extract::<i32>().unwrap()
dict.get_item(8i32)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap()
);
});
}
@ -624,7 +694,7 @@ mod tests {
let dict: &PyDict = ob.downcast(py).unwrap();
assert!(dict.del_item(7i32).is_ok());
assert_eq!(0, dict.len());
assert!(dict.get_item(7i32).is_none());
assert!(dict.get_item(7i32).unwrap().is_none());
});
}
@ -841,7 +911,15 @@ mod tests {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}
@ -854,7 +932,15 @@ mod tests {
let py_map = map.into_py_dict(py);
assert_eq!(py_map.len(), 1);
assert_eq!(py_map.get_item(1).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(
py_map
.get_item(1)
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
});
}
@ -865,7 +951,15 @@ mod tests {
let py_map = vec.into_py_dict(py);
assert_eq!(py_map.len(), 3);
assert_eq!(py_map.get_item("b").unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(
py_map
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
2
);
});
}
@ -876,7 +970,15 @@ mod tests {
let py_map = arr.into_py_dict(py);
assert_eq!(py_map.len(), 3);
assert_eq!(py_map.get_item("b").unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(
py_map
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
2
);
});
}
@ -947,15 +1049,67 @@ mod tests {
let other = [("b", 4), ("c", 5), ("d", 6)].into_py_dict(py);
dict.update(other.as_mapping()).unwrap();
assert_eq!(dict.len(), 4);
assert_eq!(dict.get_item("a").unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(dict.get_item("b").unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(dict.get_item("c").unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(dict.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
dict.get_item("a")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
assert_eq!(
dict.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
4
);
assert_eq!(
dict.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
5
);
assert_eq!(
dict.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
assert_eq!(other.len(), 3);
assert_eq!(other.get_item("b").unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(other.get_item("c").unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(other.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
other
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
4
);
assert_eq!(
other
.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
5
);
assert_eq!(
other
.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
})
}
@ -966,15 +1120,67 @@ mod tests {
let other = [("b", 4), ("c", 5), ("d", 6)].into_py_dict(py);
dict.update_if_missing(other.as_mapping()).unwrap();
assert_eq!(dict.len(), 4);
assert_eq!(dict.get_item("a").unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(dict.get_item("b").unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(dict.get_item("c").unwrap().extract::<i32>().unwrap(), 3);
assert_eq!(dict.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
dict.get_item("a")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
1
);
assert_eq!(
dict.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
2
);
assert_eq!(
dict.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
3
);
assert_eq!(
dict.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
assert_eq!(other.len(), 3);
assert_eq!(other.get_item("b").unwrap().extract::<i32>().unwrap(), 4);
assert_eq!(other.get_item("c").unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(other.get_item("d").unwrap().extract::<i32>().unwrap(), 6);
assert_eq!(
other
.get_item("b")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
4
);
assert_eq!(
other
.get_item("c")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
5
);
assert_eq!(
other
.get_item("d")
.unwrap()
.unwrap()
.extract::<i32>()
.unwrap(),
6
);
})
}
}

View File

@ -97,7 +97,7 @@ except Exception as e:
Some(locals),
)
.unwrap();
let err = PyErr::from_value(locals.get_item("err").unwrap());
let err = PyErr::from_value(locals.get_item("err").unwrap().unwrap());
let traceback = err.value(py).getattr("__traceback__").unwrap();
assert!(err.traceback(py).unwrap().is(traceback));
})
@ -117,7 +117,7 @@ def f():
Some(locals),
)
.unwrap();
let f = locals.get_item("f").unwrap();
let f = locals.get_item("f").unwrap().unwrap();
let err = f.call0().unwrap_err();
let traceback = err.traceback(py).unwrap();
let err_object = err.clone_ref(py).into_py(py).into_ref(py);

View File

@ -137,6 +137,7 @@ fn add_module(py: Python<'_>, module: &PyModule) -> PyResult<()> {
.dict()
.get_item("modules")
.unwrap()
.unwrap()
.downcast::<PyDict>()?
.set_item(module.name()?, module)
}