Merge pull request #3461 from davidhewitt/some-wraps

better `Some`-wrapping for default arguments
This commit is contained in:
David Hewitt 2023-09-21 21:12:43 +00:00 committed by GitHub
commit aeb7a958dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 93 additions and 38 deletions

View File

@ -0,0 +1 @@
Some-wrapping of `Option<T>` default arguments will no longer re-wrap `Some(T)` or expressions evaluating to `None`.

View File

@ -1,6 +1,7 @@
use crate::{ use crate::{
method::{FnArg, FnSpec}, method::{FnArg, FnSpec},
pyfunction::FunctionSignature, pyfunction::FunctionSignature,
quotes::some_wrap,
}; };
use proc_macro2::{Span, TokenStream}; use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned}; use quote::{quote, quote_spanned};
@ -192,12 +193,7 @@ fn impl_arg_param(
// Option<T> arguments have special treatment: the default should be specified _without_ the // Option<T> arguments have special treatment: the default should be specified _without_ the
// Some() wrapper. Maybe this should be changed in future?! // Some() wrapper. Maybe this should be changed in future?!
if arg.optional.is_some() { if arg.optional.is_some() {
default = Some(match &default { default = Some(default.map_or_else(|| quote!(::std::option::Option::None), some_wrap));
Some(expression) if expression.to_string() != "None" => {
quote!(::std::option::Option::Some(#expression))
}
_ => quote!(::std::option::Option::None),
})
} }
let tokens = if let Some(expr_path) = arg.attrs.from_py_with.as_ref().map(|attr| &attr.value) { let tokens = if let Some(expr_path) = arg.attrs.from_py_with.as_ref().map(|attr| &attr.value) {

View File

@ -1,10 +1,16 @@
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::quote; use quote::quote;
pub(crate) fn some_wrap(obj: TokenStream) -> TokenStream {
quote! {
_pyo3::impl_::wrap::SomeWrap::wrap(#obj)
}
}
pub(crate) fn ok_wrap(obj: TokenStream) -> TokenStream { pub(crate) fn ok_wrap(obj: TokenStream) -> TokenStream {
quote! { quote! {
_pyo3::impl_::pymethods::OkWrap::wrap(#obj, py) _pyo3::impl_::wrap::OkWrap::wrap(#obj, py)
.map_err(::core::convert::Into::into) .map_err(::core::convert::Into::<_pyo3::PyErr>::into)
} }
} }

View File

@ -19,3 +19,4 @@ pub mod pymethods;
pub mod pymodule; pub mod pymodule;
#[doc(hidden)] #[doc(hidden)]
pub mod trampoline; pub mod trampoline;
pub mod wrap;

View File

@ -1,10 +1,7 @@
use crate::gil::LockGIL; use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap; use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string; use crate::internal_tricks::extract_c_string;
use crate::{ use crate::{ffi, PyAny, PyCell, PyClass, PyObject, PyResult, PyTraverseError, PyVisit, Python};
ffi, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit,
Python,
};
use std::borrow::Cow; use std::borrow::Cow;
use std::ffi::CStr; use std::ffi::CStr;
use std::fmt; use std::fmt;
@ -295,32 +292,6 @@ pub(crate) struct PyMethodDefDestructor {
doc: Cow<'static, CStr>, doc: Cow<'static, CStr>,
} }
// The macros need to Ok-wrap the output of user defined functions; i.e. if they're not a result, make them into one.
pub trait OkWrap<T> {
type Error;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error>;
}
impl<T> OkWrap<T> for T
where
T: IntoPy<PyObject>,
{
type Error = PyErr;
fn wrap(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
Ok(self.into_py(py))
}
}
impl<T, E> OkWrap<T> for Result<T, E>
where
T: IntoPy<PyObject>,
{
type Error = E;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error> {
self.map(|o| o.into_py(py))
}
}
pub(crate) fn get_name(name: &'static str) -> PyResult<Cow<'static, CStr>> { pub(crate) fn get_name(name: &'static str) -> PyResult<Cow<'static, CStr>> {
extract_c_string(name, "function name cannot contain NUL byte.") extract_c_string(name, "function name cannot contain NUL byte.")
} }

60
src/impl_/wrap.rs Normal file
View File

@ -0,0 +1,60 @@
use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python};
/// Used to wrap values in `Option<T>` for default arguments.
pub trait SomeWrap<T> {
fn wrap(self) -> T;
}
impl<T> SomeWrap<Option<T>> for T {
fn wrap(self) -> Option<T> {
Some(self)
}
}
impl<T> SomeWrap<Option<T>> for Option<T> {
fn wrap(self) -> Self {
self
}
}
/// Used to wrap the result of `#[pyfunction]` and `#[pymethods]`.
pub trait OkWrap<T> {
type Error;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error>;
}
// The T: IntoPy<PyObject> bound here is necessary to prevent the
// implementation for Result<T, E> from conflicting
impl<T> OkWrap<T> for T
where
T: IntoPy<PyObject>,
{
type Error = PyErr;
fn wrap(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
Ok(self.into_py(py))
}
}
impl<T, E> OkWrap<T> for Result<T, E>
where
T: IntoPy<PyObject>,
{
type Error = E;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error> {
self.map(|o| o.into_py(py))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn wrap_option() {
let a: Option<u8> = SomeWrap::wrap(42);
assert_eq!(a, Some(42));
let b: Option<u8> = SomeWrap::wrap(None);
assert_eq!(b, None);
}
}

View File

@ -508,3 +508,23 @@ fn test_return_value_borrows_from_arguments() {
py_assert!(py, function key value, "function(key, value) == { \"key\": 42 }"); py_assert!(py, function key value, "function(key, value) == { \"key\": 42 }");
}); });
} }
#[test]
fn test_some_wrap_arguments() {
// https://github.com/PyO3/pyo3/issues/3460
const NONE: Option<u8> = None;
#[pyfunction(signature = (a = 1, b = Some(2), c = None, d = NONE))]
fn some_wrap_arguments(
a: Option<u8>,
b: Option<u8>,
c: Option<u8>,
d: Option<u8>,
) -> [Option<u8>; 4] {
[a, b, c, d]
}
Python::with_gil(|py| {
let function = wrap_pyfunction!(some_wrap_arguments, py).unwrap();
py_assert!(py, function, "function() == [1, 2, None, None]");
})
}