diff --git a/CHANGELOG.md b/CHANGELOG.md index afb180e7..17715a56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add an experimental `generate-abi3-import-lib` feature to auto-generate `python3.dll` import libraries for Windows. [#2282](https://github.com/PyO3/pyo3/pull/2282) - Add FFI definitions for `PyDateTime_BaseTime` and `PyDateTime_BaseDateTime`. [#2294](https://github.com/PyO3/pyo3/pull/2294) +- Added `PyTzInfoAccess` for safe access to time zone information. [#2263](https://github.com/PyO3/pyo3/pull/2263) ### Changed diff --git a/pytests/src/datetime.rs b/pytests/src/datetime.rs index f526ae0a..b21d3e69 100644 --- a/pytests/src/datetime.rs +++ b/pytests/src/datetime.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; use pyo3::types::{ PyDate, PyDateAccess, PyDateTime, PyDelta, PyDeltaAccess, PyTime, PyTimeAccess, PyTuple, - PyTzInfo, + PyTzInfo, PyTzInfoAccess, }; #[pyfunction] @@ -179,6 +179,16 @@ fn datetime_from_timestamp<'p>( PyDateTime::from_timestamp(py, ts, tz) } +#[pyfunction] +fn get_datetime_tzinfo(dt: &PyDateTime) -> Option<&PyTzInfo> { + dt.get_tzinfo() +} + +#[pyfunction] +fn get_time_tzinfo(dt: &PyTime) -> Option<&PyTzInfo> { + dt.get_tzinfo() +} + #[pyclass(extends=PyTzInfo)] pub struct TzClass {} @@ -214,6 +224,8 @@ pub fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(make_datetime, m)?)?; m.add_function(wrap_pyfunction!(get_datetime_tuple, m)?)?; m.add_function(wrap_pyfunction!(datetime_from_timestamp, m)?)?; + m.add_function(wrap_pyfunction!(get_datetime_tzinfo, m)?)?; + m.add_function(wrap_pyfunction!(get_time_tzinfo, m)?)?; // Functions not supported by PyPy #[cfg(not(PyPy))] diff --git a/pytests/tests/test_datetime.py b/pytests/tests/test_datetime.py index e70504e7..d4c1b60e 100644 --- a/pytests/tests/test_datetime.py +++ b/pytests/tests/test_datetime.py @@ -114,6 +114,7 @@ def test_time(args, kwargs): assert act == exp assert act.tzinfo is exp.tzinfo + assert rdt.get_time_tzinfo(act) == exp.tzinfo @given(t=st.times()) @@ -194,6 +195,7 @@ def test_datetime(args, kwargs): assert act == exp assert act.tzinfo is exp.tzinfo + assert rdt.get_datetime_tzinfo(act) == exp.tzinfo @given(dt=st.datetimes()) diff --git a/src/types/datetime.rs b/src/types/datetime.rs index c4053011..90782774 100644 --- a/src/types/datetime.rs +++ b/src/types/datetime.rs @@ -4,9 +4,8 @@ //! documentation](https://docs.python.org/3/library/datetime.html) use crate::err::PyResult; -use crate::ffi; use crate::ffi::{ - PyDateTime_CAPI, PyDateTime_FromTimestamp, PyDateTime_IMPORT, PyDate_FromTimestamp, + self, PyDateTime_CAPI, PyDateTime_FromTimestamp, PyDateTime_IMPORT, PyDate_FromTimestamp, }; #[cfg(not(PyPy))] use crate::ffi::{PyDateTime_DATE_GET_FOLD, PyDateTime_TIME_GET_FOLD}; @@ -22,6 +21,7 @@ use crate::ffi::{ PyDateTime_TIME_GET_HOUR, PyDateTime_TIME_GET_MICROSECOND, PyDateTime_TIME_GET_MINUTE, PyDateTime_TIME_GET_SECOND, }; +use crate::instance::PyNativeType; use crate::types::PyTuple; use crate::{AsPyPointer, PyAny, PyObject, Python, ToPyObject}; use std::os::raw::c_int; @@ -160,6 +160,16 @@ pub trait PyTimeAccess { fn get_fold(&self) -> bool; } +/// Trait for accessing the components of a struct containing a tzinfo. +pub trait PyTzInfoAccess { + /// Returns the tzinfo (which may be None). + /// + /// Implementations should conform to the upstream documentation: + /// + /// + fn get_tzinfo(&self) -> Option<&PyTzInfo>; +} + /// Bindings around `datetime.date` #[repr(transparent)] pub struct PyDate(PyAny); @@ -354,6 +364,19 @@ impl PyTimeAccess for PyDateTime { } } +impl PyTzInfoAccess for PyDateTime { + fn get_tzinfo(&self) -> Option<&PyTzInfo> { + let ptr = self.as_ptr() as *mut ffi::PyDateTime_DateTime; + unsafe { + if (*ptr).hastzinfo != 0 { + Some(self.py().from_borrowed_ptr((*ptr).tzinfo)) + } else { + None + } + } + } +} + /// Bindings for `datetime.time` #[repr(transparent)] pub struct PyTime(PyAny); @@ -439,6 +462,19 @@ impl PyTimeAccess for PyTime { } } +impl PyTzInfoAccess for PyTime { + fn get_tzinfo(&self) -> Option<&PyTzInfo> { + let ptr = self.as_ptr() as *mut ffi::PyDateTime_Time; + unsafe { + if (*ptr).hastzinfo != 0 { + Some(self.py().from_borrowed_ptr((*ptr).tzinfo)) + } else { + None + } + } + } +} + /// Bindings for `datetime.tzinfo` /// /// This is an abstract base class and should not be constructed directly. @@ -524,4 +560,33 @@ mod tests { assert!(b.unwrap().get_fold()); }); } + + #[cfg(not(PyPy))] + #[test] + fn test_get_tzinfo() { + crate::Python::with_gil(|py| { + use crate::conversion::ToPyObject; + use crate::types::{PyDateTime, PyTime, PyTzInfoAccess}; + + let datetime = py.import("datetime").map_err(|e| e.print(py)).unwrap(); + let timezone = datetime.getattr("timezone").unwrap(); + let utc = timezone.getattr("utc").unwrap().to_object(py); + + let dt = PyDateTime::new(py, 2018, 1, 1, 0, 0, 0, 0, Some(&utc)).unwrap(); + + assert!(dt.get_tzinfo().unwrap().eq(&utc).unwrap()); + + let dt = PyDateTime::new(py, 2018, 1, 1, 0, 0, 0, 0, None).unwrap(); + + assert!(dt.get_tzinfo().is_none()); + + let t = PyTime::new(py, 0, 0, 0, 0, Some(&utc)).unwrap(); + + assert!(t.get_tzinfo().unwrap().eq(&utc).unwrap()); + + let t = PyTime::new(py, 0, 0, 0, 0, None).unwrap(); + + assert!(t.get_tzinfo().is_none()); + }); + } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 192fa80a..ad79d062 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -11,6 +11,7 @@ pub use self::complex::PyComplex; #[cfg(not(Py_LIMITED_API))] pub use self::datetime::{ PyDate, PyDateAccess, PyDateTime, PyDelta, PyDeltaAccess, PyTime, PyTimeAccess, PyTzInfo, + PyTzInfoAccess, }; pub use self::dict::{IntoPyDict, PyDict}; pub use self::floatob::PyFloat;