From c7a53611e032e5f19a576b30f54b5c4c5c5dc7a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Niederb=C3=BChl?= Date: Sat, 12 Oct 2019 01:59:24 +0200 Subject: [PATCH] Enable to give None as default value for an argument --- pyo3-derive-backend/src/pymethod.rs | 7 +++++-- tests/test_methods.rs | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 0245ce4f..6ed158fd 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -523,11 +523,14 @@ fn impl_arg_param( } } else if arg.optional.is_some() { let default = if let Some(d) = spec.default_value(name) { - quote! { Some(#d) } + if d.to_string() == "None" { + quote! { None } + } else { + quote! { Some(#d) } + } } else { quote! { None } }; - quote! { let #arg_name = match #arg_value.as_ref() { Some(_obj) => { diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 5906ba64..495c2e3f 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -214,6 +214,21 @@ impl MethArgs { fn get_optional(&self, test: Option) -> PyResult { Ok(test.unwrap_or(10)) } + fn get_optional2(&self, test: Option) -> PyResult> { + Ok(test) + } + #[args(test = "None")] + fn get_optional3(&self, test: Option) -> PyResult> { + Ok(test) + } + fn get_optional_positional( + &self, + _t1: Option, + t2: Option, + _t3: Option, + ) -> PyResult> { + Ok(t2) + } #[args(test = "10")] fn get_default(&self, test: i32) -> PyResult { @@ -264,6 +279,15 @@ fn meth_args() { py_run!(py, inst, "assert inst.get_optional() == 10"); py_run!(py, inst, "assert inst.get_optional(100) == 100"); + py_run!(py, inst, "assert inst.get_optional2() == None"); + py_run!(py, inst, "assert inst.get_optional(100) == 100"); + py_run!(py, inst, "assert inst.get_optional3() == None"); + py_run!( + py, + inst, + "assert inst.get_optional_positional(1, 2, 3) == 2" + ); + py_run!(py, inst, "assert inst.get_optional_positional(1) == None"); py_run!(py, inst, "assert inst.get_default() == 10"); py_run!(py, inst, "assert inst.get_default(100) == 100"); py_run!(py, inst, "assert inst.get_kwarg() == 10");