From 5b1104131f7aba50153bb19470ad4b073b3241fd Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Sat, 10 Feb 2024 13:57:20 +0000 Subject: [PATCH] fix segmentation fault when `datetime` module is invalid --- newsfragments/3818.fixed.md | 1 + src/types/datetime.rs | 56 ++++++++++++++++++----------------- tests/test_datetime_import.rs | 26 ++++++++++++++++ 3 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 newsfragments/3818.fixed.md create mode 100644 tests/test_datetime_import.rs diff --git a/newsfragments/3818.fixed.md b/newsfragments/3818.fixed.md new file mode 100644 index 00000000..76fe01a5 --- /dev/null +++ b/newsfragments/3818.fixed.md @@ -0,0 +1 @@ +Fix segmentation fault using `datetime` types when an invalid `datetime` module is on sys.path. diff --git a/src/types/datetime.rs b/src/types/datetime.rs index 354414b8..088e37d9 100644 --- a/src/types/datetime.rs +++ b/src/types/datetime.rs @@ -23,21 +23,27 @@ use crate::ffi_ptr_ext::FfiPtrExt; use crate::instance::PyNativeType; use crate::types::any::PyAnyMethods; use crate::types::PyTuple; -use crate::{Bound, IntoPy, Py, PyAny, Python}; +use crate::{Bound, IntoPy, Py, PyAny, PyErr, Python}; use std::os::raw::c_int; #[cfg(feature = "chrono")] use std::ptr; -fn ensure_datetime_api(_py: Python<'_>) -> &'static PyDateTime_CAPI { - unsafe { - if pyo3_ffi::PyDateTimeAPI().is_null() { - PyDateTime_IMPORT() +fn ensure_datetime_api(py: Python<'_>) -> PyResult<&'static PyDateTime_CAPI> { + if let Some(api) = unsafe { pyo3_ffi::PyDateTimeAPI().as_ref() } { + Ok(api) + } else { + unsafe { + PyDateTime_IMPORT(); + pyo3_ffi::PyDateTimeAPI().as_ref() } - - &*pyo3_ffi::PyDateTimeAPI() + .ok_or_else(|| PyErr::fetch(py)) } } +fn expect_datetime_api(py: Python<'_>) -> &'static PyDateTime_CAPI { + ensure_datetime_api(py).expect("failed to import `datetime` C API") +} + // Type Check macros // // These are bindings around the C API typecheck macros, all of them return @@ -189,7 +195,7 @@ pub struct PyDate(PyAny); pyobject_native_type!( PyDate, crate::ffi::PyDateTime_Date, - |py| ensure_datetime_api(py).DateType, + |py| expect_datetime_api(py).DateType, #module=Some("datetime"), #checkfunction=PyDate_Check ); @@ -197,13 +203,9 @@ pyobject_native_type!( impl PyDate { /// Creates a new `datetime.date`. pub fn new(py: Python<'_>, year: i32, month: u8, day: u8) -> PyResult<&PyDate> { + let api = ensure_datetime_api(py)?; unsafe { - let ptr = (ensure_datetime_api(py).Date_FromDate)( - year, - c_int::from(month), - c_int::from(day), - ensure_datetime_api(py).DateType, - ); + let ptr = (api.Date_FromDate)(year, c_int::from(month), c_int::from(day), api.DateType); py.from_owned_ptr_or_err(ptr) } } @@ -215,7 +217,7 @@ impl PyDate { let time_tuple = PyTuple::new_bound(py, [timestamp]); // safety ensure that the API is loaded - let _api = ensure_datetime_api(py); + let _api = ensure_datetime_api(py)?; unsafe { let ptr = PyDate_FromTimestamp(time_tuple.as_ptr()); @@ -258,7 +260,7 @@ pub struct PyDateTime(PyAny); pyobject_native_type!( PyDateTime, crate::ffi::PyDateTime_DateTime, - |py| ensure_datetime_api(py).DateTimeType, + |py| expect_datetime_api(py).DateTimeType, #module=Some("datetime"), #checkfunction=PyDateTime_Check ); @@ -277,7 +279,7 @@ impl PyDateTime { microsecond: u32, tzinfo: Option<&PyTzInfo>, ) -> PyResult<&'p PyDateTime> { - let api = ensure_datetime_api(py); + let api = ensure_datetime_api(py)?; unsafe { let ptr = (api.DateTime_FromDateAndTime)( year, @@ -314,7 +316,7 @@ impl PyDateTime { tzinfo: Option<&PyTzInfo>, fold: bool, ) -> PyResult<&'p PyDateTime> { - let api = ensure_datetime_api(py); + let api = ensure_datetime_api(py)?; unsafe { let ptr = (api.DateTime_FromDateAndTimeAndFold)( year, @@ -343,7 +345,7 @@ impl PyDateTime { let args: Py = (timestamp, tzinfo).into_py(py); // safety ensure API is loaded - let _api = ensure_datetime_api(py); + let _api = ensure_datetime_api(py)?; unsafe { let ptr = PyDateTime_FromTimestamp(args.as_ptr()); @@ -455,7 +457,7 @@ pub struct PyTime(PyAny); pyobject_native_type!( PyTime, crate::ffi::PyDateTime_Time, - |py| ensure_datetime_api(py).TimeType, + |py| expect_datetime_api(py).TimeType, #module=Some("datetime"), #checkfunction=PyTime_Check ); @@ -470,7 +472,7 @@ impl PyTime { microsecond: u32, tzinfo: Option<&PyTzInfo>, ) -> PyResult<&'p PyTime> { - let api = ensure_datetime_api(py); + let api = ensure_datetime_api(py)?; unsafe { let ptr = (api.Time_FromTime)( c_int::from(hour), @@ -494,7 +496,7 @@ impl PyTime { tzinfo: Option<&PyTzInfo>, fold: bool, ) -> PyResult<&'p PyTime> { - let api = ensure_datetime_api(py); + let api = ensure_datetime_api(py)?; unsafe { let ptr = (api.Time_FromTimeAndFold)( c_int::from(hour), @@ -589,14 +591,14 @@ pub struct PyTzInfo(PyAny); pyobject_native_type!( PyTzInfo, crate::ffi::PyObject, - |py| ensure_datetime_api(py).TZInfoType, + |py| expect_datetime_api(py).TZInfoType, #module=Some("datetime"), #checkfunction=PyTZInfo_Check ); /// Equivalent to `datetime.timezone.utc` pub fn timezone_utc(py: Python<'_>) -> &PyTzInfo { - unsafe { &*(ensure_datetime_api(py).TimeZone_UTC as *const PyTzInfo) } + unsafe { &*(expect_datetime_api(py).TimeZone_UTC as *const PyTzInfo) } } /// Equivalent to `datetime.timezone` constructor @@ -604,7 +606,7 @@ pub fn timezone_utc(py: Python<'_>) -> &PyTzInfo { /// Only used internally #[cfg(feature = "chrono")] pub fn timezone_from_offset<'a>(py: Python<'a>, offset: &PyDelta) -> PyResult<&'a PyTzInfo> { - let api = ensure_datetime_api(py); + let api = ensure_datetime_api(py)?; unsafe { let ptr = (api.TimeZone_FromTimeZone)(offset.as_ptr(), ptr::null_mut()); py.from_owned_ptr_or_err(ptr) @@ -617,7 +619,7 @@ pub struct PyDelta(PyAny); pyobject_native_type!( PyDelta, crate::ffi::PyDateTime_Delta, - |py| ensure_datetime_api(py).DeltaType, + |py| expect_datetime_api(py).DeltaType, #module=Some("datetime"), #checkfunction=PyDelta_Check ); @@ -631,7 +633,7 @@ impl PyDelta { microseconds: i32, normalize: bool, ) -> PyResult<&PyDelta> { - let api = ensure_datetime_api(py); + let api = ensure_datetime_api(py)?; unsafe { let ptr = (api.Delta_FromDelta)( days as c_int, diff --git a/tests/test_datetime_import.rs b/tests/test_datetime_import.rs new file mode 100644 index 00000000..fee99b07 --- /dev/null +++ b/tests/test_datetime_import.rs @@ -0,0 +1,26 @@ +#![cfg(not(Py_LIMITED_API))] + +use pyo3::{types::PyDate, Python}; + +#[test] +#[should_panic(expected = "module 'datetime' has no attribute 'datetime_CAPI'")] +fn test_bad_datetime_module_panic() { + // Create an empty temporary directory + // with an empty "datetime" module which we'll put on the sys.path + let tmpdir = std::env::temp_dir(); + let tmpdir = tmpdir.join("pyo3_test_date_check"); + let _ = std::fs::remove_dir_all(&tmpdir); + std::fs::create_dir(&tmpdir).unwrap(); + std::fs::File::create(tmpdir.join("datetime.py")).unwrap(); + + Python::with_gil(|py: Python<'_>| { + let sys = py.import("sys").unwrap(); + sys.getattr("path") + .unwrap() + .call_method1("insert", (0, tmpdir)) + .unwrap(); + + // This should panic because the "datetime" module is empty + PyDate::new(py, 2018, 1, 1).unwrap(); + }); +}