Make it enable to take &PyClass as arguments as pyfunctions/methods
This commit is contained in:
parent
6307c25b81
commit
e63e0cbf5a
|
@ -109,12 +109,11 @@ impl<'a> FnSpec<'a> {
|
|||
|
||||
let py = crate::utils::if_type_is_python(ty);
|
||||
|
||||
let opt = check_arg_ty_and_optional(name, ty);
|
||||
let opt = check_ty_optional(ty);
|
||||
arguments.push(FnArg {
|
||||
name: ident,
|
||||
by_ref,
|
||||
mutability,
|
||||
// mode: mode,
|
||||
ty,
|
||||
optional: opt,
|
||||
py,
|
||||
|
@ -305,55 +304,18 @@ pub fn is_ref(name: &syn::Ident, ty: &syn::Type) -> bool {
|
|||
false
|
||||
}
|
||||
|
||||
pub fn check_arg_ty_and_optional<'a>(
|
||||
name: &'a syn::Ident,
|
||||
ty: &'a syn::Type,
|
||||
) -> Option<&'a syn::Type> {
|
||||
match ty {
|
||||
syn::Type::Path(syn::TypePath { ref path, .. }) => {
|
||||
//if let Some(ref qs) = qs {
|
||||
// panic!("explicit Self type in a 'qualified path' is not supported: {:?} - {:?}",
|
||||
// name, qs);
|
||||
//}
|
||||
|
||||
if let Some(segment) = path.segments.last() {
|
||||
match segment.ident.to_string().as_str() {
|
||||
"Option" => match segment.arguments {
|
||||
syn::PathArguments::AngleBracketed(ref params) => {
|
||||
if params.args.len() != 1 {
|
||||
panic!("argument type is not supported by python method: {:?} ({:?}) {:?}",
|
||||
name,
|
||||
ty,
|
||||
path);
|
||||
}
|
||||
|
||||
match ¶ms.args[0] {
|
||||
syn::GenericArgument::Type(ref ty) => Some(ty),
|
||||
_ => panic!("argument type is not supported by python method: {:?} ({:?}) {:?}",
|
||||
name,
|
||||
ty,
|
||||
path),
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!(
|
||||
"argument type is not supported by python method: {:?} ({:?}) {:?}",
|
||||
name, ty, path
|
||||
);
|
||||
}
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
None
|
||||
//panic!("argument type is not supported by python method: {:?} ({:?})",
|
||||
//name,
|
||||
//ty);
|
||||
}
|
||||
pub(crate) fn check_ty_optional<'a>(ty: &'a syn::Type) -> Option<&'a syn::Type> {
|
||||
let path = match ty {
|
||||
syn::Type::Path(syn::TypePath { ref path, .. }) => path,
|
||||
_ => return None,
|
||||
};
|
||||
let seg = path.segments.last().filter(|s| s.ident == "Option")?;
|
||||
match seg.arguments {
|
||||
syn::PathArguments::AngleBracketed(ref params) => match params.args.first() {
|
||||
Some(syn::GenericArgument::Type(ref ty)) => Some(ty),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType, name: &'a Ident) -> syn::Result<m
|
|||
};
|
||||
|
||||
let py = crate::utils::if_type_is_python(&cap.ty);
|
||||
let opt = method::check_arg_ty_and_optional(&name, &cap.ty);
|
||||
let opt = method::check_ty_optional(&cap.ty);
|
||||
Ok(method::FnArg {
|
||||
name: ident,
|
||||
mutability,
|
||||
|
|
|
@ -412,6 +412,15 @@ fn impl_class(
|
|||
type BaseNativeType = #base_nativetype;
|
||||
}
|
||||
|
||||
impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls
|
||||
{
|
||||
type Target = pyo3::PyRef<'a, #cls>;
|
||||
}
|
||||
impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
|
||||
{
|
||||
type Target = pyo3::PyRefMut<'a, #cls>;
|
||||
}
|
||||
|
||||
#into_pyobject
|
||||
|
||||
#inventory_impl
|
||||
|
|
|
@ -387,10 +387,7 @@ pub(crate) fn impl_wrap_setter(
|
|||
};
|
||||
match _result {
|
||||
Ok(_) => 0,
|
||||
Err(e) => {
|
||||
e.restore(_py);
|
||||
-1
|
||||
}
|
||||
Err(e) => e.restore_and_minus1(_py),
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -523,45 +520,65 @@ fn impl_arg_param(
|
|||
}
|
||||
let arg_value = quote!(output[#option_pos]);
|
||||
*option_pos += 1;
|
||||
if arg.optional.is_some() {
|
||||
let default = if let Some(d) = spec.default_value(name) {
|
||||
if d.to_string() == "None" {
|
||||
quote! { None }
|
||||
} else {
|
||||
quote! { Some(#d) }
|
||||
}
|
||||
|
||||
return if let Some(ty) = arg.optional.as_ref() {
|
||||
let default = if let Some(d) = spec.default_value(name).filter(|d| d.to_string() != "None")
|
||||
{
|
||||
quote! { Some(#d) }
|
||||
} else {
|
||||
quote! { None }
|
||||
};
|
||||
quote! {
|
||||
let #arg_name = match #arg_value.as_ref() {
|
||||
Some(_obj) => {
|
||||
if _obj.is_none() {
|
||||
#default
|
||||
} else {
|
||||
Some(_obj.extract()?)
|
||||
}
|
||||
},
|
||||
None => #default
|
||||
if let syn::Type::Reference(tref) = ty {
|
||||
let (tref, mut_) = tref_preprocess(tref);
|
||||
let as_deref = if mut_.is_some() {
|
||||
quote! { as_deref_mut }
|
||||
} else {
|
||||
quote! { as_deref }
|
||||
};
|
||||
// Get Option<&T> from Option<PyRef<T>>
|
||||
quote! {
|
||||
let #mut_ _tmp = match #arg_value.as_ref().filter(|obj| !obj.is_none()) {
|
||||
Some(_obj) => {
|
||||
Some(_obj.extract::<<#tref as pyo3::derive_utils::ExtractExt>::Target>()?)
|
||||
},
|
||||
None => #default,
|
||||
};
|
||||
let #arg_name = _tmp.#as_deref();
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let #arg_name = match #arg_value.as_ref().filter(|obj| !obj.is_none()) {
|
||||
Some(_obj) => Some(_obj.extract()?),
|
||||
None => #default,
|
||||
};
|
||||
}
|
||||
}
|
||||
} else if let Some(default) = spec.default_value(name) {
|
||||
quote! {
|
||||
let #arg_name = match #arg_value.as_ref() {
|
||||
Some(_obj) => {
|
||||
if _obj.is_none() {
|
||||
#default
|
||||
} else {
|
||||
_obj.extract()?
|
||||
}
|
||||
},
|
||||
None => #default
|
||||
let #arg_name = match #arg_value.as_ref().filter(|obj| !obj.is_none()) {
|
||||
Some(_obj) => _obj.extract()?,
|
||||
None => #default,
|
||||
};
|
||||
}
|
||||
} else if let syn::Type::Reference(tref) = arg.ty {
|
||||
let (tref, mut_) = tref_preprocess(tref);
|
||||
// Get &T from PyRef<T>
|
||||
quote! {
|
||||
let #mut_ _tmp: <#tref as pyo3::derive_utils::ExtractExt>::Target
|
||||
= #arg_value.unwrap().extract()?;
|
||||
let #arg_name = &#mut_ *_tmp;
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let #arg_name = #arg_value.unwrap().extract()?;
|
||||
}
|
||||
};
|
||||
|
||||
fn tref_preprocess(tref: &syn::TypeReference) -> (syn::TypeReference, Option<syn::token::Mut>) {
|
||||
let mut tref = tref.to_owned();
|
||||
tref.lifetime = None;
|
||||
let mut_ = tref.mutability;
|
||||
(tref, mut_)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -198,6 +198,7 @@ impl<T: PyClass, I: Into<PyClassInitializer<T>>> IntoPyNewResult<T, I> for PyRes
|
|||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub trait GetPropertyValue {
|
||||
fn get_property_value(&self, py: Python) -> PyObject;
|
||||
}
|
||||
|
@ -218,6 +219,7 @@ impl GetPropertyValue for PyObject {
|
|||
}
|
||||
|
||||
/// Utilities for basetype
|
||||
#[doc(hidden)]
|
||||
pub trait PyBaseTypeUtils {
|
||||
type Dict;
|
||||
type WeakRef;
|
||||
|
@ -231,3 +233,16 @@ impl<T: PyClass> PyBaseTypeUtils for T {
|
|||
type LayoutAsBase = crate::pycell::PyCellInner<T>;
|
||||
type BaseNativeType = T::BaseNativeType;
|
||||
}
|
||||
|
||||
/// Utility trait to enable &PyClass as a pymethod/function argument
|
||||
#[doc(hidden)]
|
||||
pub trait ExtractExt<'a> {
|
||||
type Target: crate::FromPyObject<'a>;
|
||||
}
|
||||
|
||||
impl<'a, T> ExtractExt<'a> for T
|
||||
where
|
||||
T: crate::FromPyObject<'a>,
|
||||
{
|
||||
type Target = T;
|
||||
}
|
||||
|
|
|
@ -405,6 +405,62 @@ fn method_with_lifetime() {
|
|||
);
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct MethodWithPyClassArg {
|
||||
#[pyo3(get)]
|
||||
value: i64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl MethodWithPyClassArg {
|
||||
fn add(&self, other: &MethodWithPyClassArg) -> MethodWithPyClassArg {
|
||||
MethodWithPyClassArg {
|
||||
value: self.value + other.value,
|
||||
}
|
||||
}
|
||||
fn add_pyref(&self, other: PyRef<MethodWithPyClassArg>) -> MethodWithPyClassArg {
|
||||
MethodWithPyClassArg {
|
||||
value: self.value + other.value,
|
||||
}
|
||||
}
|
||||
fn inplace_add(&self, other: &mut MethodWithPyClassArg) {
|
||||
other.value += self.value;
|
||||
}
|
||||
fn optional_add(&self, other: Option<&MethodWithPyClassArg>) -> MethodWithPyClassArg {
|
||||
MethodWithPyClassArg {
|
||||
value: self.value + other.map(|o| o.value).unwrap_or(10),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn method_with_pyclassarg() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let obj1 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap();
|
||||
let obj2 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap();
|
||||
py_run!(
|
||||
py,
|
||||
obj1 obj2,
|
||||
"obj = obj1.add(obj2); assert obj.value == 20"
|
||||
);
|
||||
py_run!(
|
||||
py,
|
||||
obj1 obj2,
|
||||
"obj = obj1.add_pyref(obj2); assert obj.value == 20"
|
||||
);
|
||||
py_run!(
|
||||
py,
|
||||
obj1 obj2,
|
||||
"obj = obj1.optional_add(); assert obj.value == 20"
|
||||
);
|
||||
py_run!(
|
||||
py,
|
||||
obj1 obj2,
|
||||
"obj1.inplace_add(obj2); assert obj2.value == 20"
|
||||
);
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[cfg(unix)]
|
||||
struct CfgStruct {}
|
||||
|
|
|
@ -7,6 +7,19 @@ mod common;
|
|||
#[pyclass]
|
||||
struct AnonClass {}
|
||||
|
||||
#[pyclass]
|
||||
struct ValueClass {
|
||||
value: usize,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl ValueClass {
|
||||
#[new]
|
||||
fn new(value: usize) -> ValueClass {
|
||||
ValueClass { value }
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(module = "module")]
|
||||
struct LocatedClass {}
|
||||
|
||||
|
@ -36,7 +49,13 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
|
|||
Ok(42)
|
||||
}
|
||||
|
||||
#[pyfn(m, "double_value")]
|
||||
fn double_value(v: &ValueClass) -> usize {
|
||||
v.value * 2
|
||||
}
|
||||
|
||||
m.add_class::<AnonClass>().unwrap();
|
||||
m.add_class::<ValueClass>().unwrap();
|
||||
m.add_class::<LocatedClass>().unwrap();
|
||||
|
||||
m.add("foo", "bar").unwrap();
|
||||
|
@ -60,7 +79,11 @@ fn test_module_with_functions() {
|
|||
)]
|
||||
.into_py_dict(py);
|
||||
|
||||
let run = |code| py.run(code, None, Some(d)).unwrap();
|
||||
let run = |code| {
|
||||
py.run(code, None, Some(d))
|
||||
.map_err(|e| e.print(py))
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
run("assert module_with_functions.__doc__ == 'This module is implemented in Rust.'");
|
||||
run("assert module_with_functions.sum_as_string(1, 2) == '3'");
|
||||
|
@ -73,6 +96,7 @@ fn test_module_with_functions() {
|
|||
run("assert module_with_functions.double.__doc__ == 'Doubles the given value'");
|
||||
run("assert module_with_functions.also_double(3) == 6");
|
||||
run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'");
|
||||
run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2");
|
||||
}
|
||||
|
||||
#[pymodule(other_name)]
|
||||
|
|
Loading…
Reference in New Issue