diff --git a/newsfragments/3461.fixed.md b/newsfragments/3461.fixed.md new file mode 100644 index 00000000..9b59bcd6 --- /dev/null +++ b/newsfragments/3461.fixed.md @@ -0,0 +1 @@ +Some-wrapping of `Option` default arguments will no longer re-wrap `Some(T)` or expressions evaluating to `None`. diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index 71bced06..e511ca75 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -1,6 +1,7 @@ use crate::{ method::{FnArg, FnSpec}, pyfunction::FunctionSignature, + quotes::some_wrap, }; use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; @@ -192,12 +193,7 @@ fn impl_arg_param( // Option arguments have special treatment: the default should be specified _without_ the // Some() wrapper. Maybe this should be changed in future?! if arg.optional.is_some() { - default = Some(match &default { - Some(expression) if expression.to_string() != "None" => { - quote!(::std::option::Option::Some(#expression)) - } - _ => quote!(::std::option::Option::None), - }) + default = Some(default.map_or_else(|| quote!(::std::option::Option::None), some_wrap)); } let tokens = if let Some(expr_path) = arg.attrs.from_py_with.as_ref().map(|attr| &attr.value) { diff --git a/pyo3-macros-backend/src/quotes.rs b/pyo3-macros-backend/src/quotes.rs index 3404c271..966564b1 100644 --- a/pyo3-macros-backend/src/quotes.rs +++ b/pyo3-macros-backend/src/quotes.rs @@ -1,10 +1,16 @@ use proc_macro2::TokenStream; use quote::quote; +pub(crate) fn some_wrap(obj: TokenStream) -> TokenStream { + quote! { + _pyo3::impl_::wrap::SomeWrap::wrap(#obj) + } +} + pub(crate) fn ok_wrap(obj: TokenStream) -> TokenStream { quote! { - _pyo3::impl_::pymethods::OkWrap::wrap(#obj, py) - .map_err(::core::convert::Into::into) + _pyo3::impl_::wrap::OkWrap::wrap(#obj, py) + .map_err(::core::convert::Into::<_pyo3::PyErr>::into) } } diff --git a/src/impl_.rs b/src/impl_.rs index e0c06c5a..118d62d9 100644 --- a/src/impl_.rs +++ b/src/impl_.rs @@ -19,3 +19,4 @@ pub mod pymethods; pub mod pymodule; #[doc(hidden)] pub mod trampoline; +pub mod wrap; diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index 98089d20..ff2857ec 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -1,10 +1,7 @@ use crate::gil::LockGIL; use crate::impl_::panic::PanicTrap; use crate::internal_tricks::extract_c_string; -use crate::{ - ffi, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, - Python, -}; +use crate::{ffi, PyAny, PyCell, PyClass, PyObject, PyResult, PyTraverseError, PyVisit, Python}; use std::borrow::Cow; use std::ffi::CStr; use std::fmt; @@ -295,32 +292,6 @@ pub(crate) struct PyMethodDefDestructor { doc: Cow<'static, CStr>, } -// The macros need to Ok-wrap the output of user defined functions; i.e. if they're not a result, make them into one. -pub trait OkWrap { - type Error; - fn wrap(self, py: Python<'_>) -> Result, Self::Error>; -} - -impl OkWrap for T -where - T: IntoPy, -{ - type Error = PyErr; - fn wrap(self, py: Python<'_>) -> PyResult> { - Ok(self.into_py(py)) - } -} - -impl OkWrap for Result -where - T: IntoPy, -{ - type Error = E; - fn wrap(self, py: Python<'_>) -> Result, Self::Error> { - self.map(|o| o.into_py(py)) - } -} - pub(crate) fn get_name(name: &'static str) -> PyResult> { extract_c_string(name, "function name cannot contain NUL byte.") } diff --git a/src/impl_/wrap.rs b/src/impl_/wrap.rs new file mode 100644 index 00000000..a73e3597 --- /dev/null +++ b/src/impl_/wrap.rs @@ -0,0 +1,60 @@ +use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python}; + +/// Used to wrap values in `Option` for default arguments. +pub trait SomeWrap { + fn wrap(self) -> T; +} + +impl SomeWrap> for T { + fn wrap(self) -> Option { + Some(self) + } +} + +impl SomeWrap> for Option { + fn wrap(self) -> Self { + self + } +} + +/// Used to wrap the result of `#[pyfunction]` and `#[pymethods]`. +pub trait OkWrap { + type Error; + fn wrap(self, py: Python<'_>) -> Result, Self::Error>; +} + +// The T: IntoPy bound here is necessary to prevent the +// implementation for Result from conflicting +impl OkWrap for T +where + T: IntoPy, +{ + type Error = PyErr; + fn wrap(self, py: Python<'_>) -> PyResult> { + Ok(self.into_py(py)) + } +} + +impl OkWrap for Result +where + T: IntoPy, +{ + type Error = E; + fn wrap(self, py: Python<'_>) -> Result, Self::Error> { + self.map(|o| o.into_py(py)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn wrap_option() { + let a: Option = SomeWrap::wrap(42); + assert_eq!(a, Some(42)); + + let b: Option = SomeWrap::wrap(None); + assert_eq!(b, None); + } +} diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 481ae0e8..f39c03c7 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -508,3 +508,23 @@ fn test_return_value_borrows_from_arguments() { py_assert!(py, function key value, "function(key, value) == { \"key\": 42 }"); }); } + +#[test] +fn test_some_wrap_arguments() { + // https://github.com/PyO3/pyo3/issues/3460 + const NONE: Option = None; + #[pyfunction(signature = (a = 1, b = Some(2), c = None, d = NONE))] + fn some_wrap_arguments( + a: Option, + b: Option, + c: Option, + d: Option, + ) -> [Option; 4] { + [a, b, c, d] + } + + Python::with_gil(|py| { + let function = wrap_pyfunction!(some_wrap_arguments, py).unwrap(); + py_assert!(py, function, "function() == [1, 2, None, None]"); + }) +}