Make it enable to take &PyClass as arguments as pyfunctions/methods

This commit is contained in:
kngwyu 2020-03-04 01:23:32 +09:00
parent 6307c25b81
commit e63e0cbf5a
8 changed files with 166 additions and 83 deletions

View File

@ -109,12 +109,11 @@ impl<'a> FnSpec<'a> {
let py = crate::utils::if_type_is_python(ty);
let opt = check_arg_ty_and_optional(name, ty);
let opt = check_ty_optional(ty);
arguments.push(FnArg {
name: ident,
by_ref,
mutability,
// mode: mode,
ty,
optional: opt,
py,
@ -305,55 +304,18 @@ pub fn is_ref(name: &syn::Ident, ty: &syn::Type) -> bool {
false
}
pub fn check_arg_ty_and_optional<'a>(
name: &'a syn::Ident,
ty: &'a syn::Type,
) -> Option<&'a syn::Type> {
match ty {
syn::Type::Path(syn::TypePath { ref path, .. }) => {
//if let Some(ref qs) = qs {
// panic!("explicit Self type in a 'qualified path' is not supported: {:?} - {:?}",
// name, qs);
//}
if let Some(segment) = path.segments.last() {
match segment.ident.to_string().as_str() {
"Option" => match segment.arguments {
syn::PathArguments::AngleBracketed(ref params) => {
if params.args.len() != 1 {
panic!("argument type is not supported by python method: {:?} ({:?}) {:?}",
name,
ty,
path);
}
match &params.args[0] {
syn::GenericArgument::Type(ref ty) => Some(ty),
_ => panic!("argument type is not supported by python method: {:?} ({:?}) {:?}",
name,
ty,
path),
}
}
_ => {
panic!(
"argument type is not supported by python method: {:?} ({:?}) {:?}",
name, ty, path
);
}
},
_ => None,
}
} else {
None
}
}
_ => {
None
//panic!("argument type is not supported by python method: {:?} ({:?})",
//name,
//ty);
}
pub(crate) fn check_ty_optional<'a>(ty: &'a syn::Type) -> Option<&'a syn::Type> {
let path = match ty {
syn::Type::Path(syn::TypePath { ref path, .. }) => path,
_ => return None,
};
let seg = path.segments.last().filter(|s| s.ident == "Option")?;
match seg.arguments {
syn::PathArguments::AngleBracketed(ref params) => match params.args.first() {
Some(syn::GenericArgument::Type(ref ty)) => Some(ty),
_ => None,
},
_ => None,
}
}

View File

@ -68,7 +68,7 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType, name: &'a Ident) -> syn::Result<m
};
let py = crate::utils::if_type_is_python(&cap.ty);
let opt = method::check_arg_ty_and_optional(&name, &cap.ty);
let opt = method::check_ty_optional(&cap.ty);
Ok(method::FnArg {
name: ident,
mutability,

View File

@ -412,6 +412,15 @@ fn impl_class(
type BaseNativeType = #base_nativetype;
}
impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls
{
type Target = pyo3::PyRef<'a, #cls>;
}
impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
{
type Target = pyo3::PyRefMut<'a, #cls>;
}
#into_pyobject
#inventory_impl

View File

@ -387,10 +387,7 @@ pub(crate) fn impl_wrap_setter(
};
match _result {
Ok(_) => 0,
Err(e) => {
e.restore(_py);
-1
}
Err(e) => e.restore_and_minus1(_py),
}
}
})
@ -523,45 +520,65 @@ fn impl_arg_param(
}
let arg_value = quote!(output[#option_pos]);
*option_pos += 1;
if arg.optional.is_some() {
let default = if let Some(d) = spec.default_value(name) {
if d.to_string() == "None" {
quote! { None }
} else {
quote! { Some(#d) }
}
return if let Some(ty) = arg.optional.as_ref() {
let default = if let Some(d) = spec.default_value(name).filter(|d| d.to_string() != "None")
{
quote! { Some(#d) }
} else {
quote! { None }
};
quote! {
let #arg_name = match #arg_value.as_ref() {
Some(_obj) => {
if _obj.is_none() {
#default
} else {
Some(_obj.extract()?)
}
},
None => #default
if let syn::Type::Reference(tref) = ty {
let (tref, mut_) = tref_preprocess(tref);
let as_deref = if mut_.is_some() {
quote! { as_deref_mut }
} else {
quote! { as_deref }
};
// Get Option<&T> from Option<PyRef<T>>
quote! {
let #mut_ _tmp = match #arg_value.as_ref().filter(|obj| !obj.is_none()) {
Some(_obj) => {
Some(_obj.extract::<<#tref as pyo3::derive_utils::ExtractExt>::Target>()?)
},
None => #default,
};
let #arg_name = _tmp.#as_deref();
}
} else {
quote! {
let #arg_name = match #arg_value.as_ref().filter(|obj| !obj.is_none()) {
Some(_obj) => Some(_obj.extract()?),
None => #default,
};
}
}
} else if let Some(default) = spec.default_value(name) {
quote! {
let #arg_name = match #arg_value.as_ref() {
Some(_obj) => {
if _obj.is_none() {
#default
} else {
_obj.extract()?
}
},
None => #default
let #arg_name = match #arg_value.as_ref().filter(|obj| !obj.is_none()) {
Some(_obj) => _obj.extract()?,
None => #default,
};
}
} else if let syn::Type::Reference(tref) = arg.ty {
let (tref, mut_) = tref_preprocess(tref);
// Get &T from PyRef<T>
quote! {
let #mut_ _tmp: <#tref as pyo3::derive_utils::ExtractExt>::Target
= #arg_value.unwrap().extract()?;
let #arg_name = &#mut_ *_tmp;
}
} else {
quote! {
let #arg_name = #arg_value.unwrap().extract()?;
}
};
fn tref_preprocess(tref: &syn::TypeReference) -> (syn::TypeReference, Option<syn::token::Mut>) {
let mut tref = tref.to_owned();
tref.lifetime = None;
let mut_ = tref.mutability;
(tref, mut_)
}
}

View File

@ -198,6 +198,7 @@ impl<T: PyClass, I: Into<PyClassInitializer<T>>> IntoPyNewResult<T, I> for PyRes
}
}
#[doc(hidden)]
pub trait GetPropertyValue {
fn get_property_value(&self, py: Python) -> PyObject;
}
@ -218,6 +219,7 @@ impl GetPropertyValue for PyObject {
}
/// Utilities for basetype
#[doc(hidden)]
pub trait PyBaseTypeUtils {
type Dict;
type WeakRef;
@ -231,3 +233,16 @@ impl<T: PyClass> PyBaseTypeUtils for T {
type LayoutAsBase = crate::pycell::PyCellInner<T>;
type BaseNativeType = T::BaseNativeType;
}
/// Utility trait to enable &PyClass as a pymethod/function argument
#[doc(hidden)]
pub trait ExtractExt<'a> {
type Target: crate::FromPyObject<'a>;
}
impl<'a, T> ExtractExt<'a> for T
where
T: crate::FromPyObject<'a>,
{
type Target = T;
}

0
tests/test_dunder.rs Executable file → Normal file
View File

View File

@ -405,6 +405,62 @@ fn method_with_lifetime() {
);
}
#[pyclass]
struct MethodWithPyClassArg {
#[pyo3(get)]
value: i64,
}
#[pymethods]
impl MethodWithPyClassArg {
fn add(&self, other: &MethodWithPyClassArg) -> MethodWithPyClassArg {
MethodWithPyClassArg {
value: self.value + other.value,
}
}
fn add_pyref(&self, other: PyRef<MethodWithPyClassArg>) -> MethodWithPyClassArg {
MethodWithPyClassArg {
value: self.value + other.value,
}
}
fn inplace_add(&self, other: &mut MethodWithPyClassArg) {
other.value += self.value;
}
fn optional_add(&self, other: Option<&MethodWithPyClassArg>) -> MethodWithPyClassArg {
MethodWithPyClassArg {
value: self.value + other.map(|o| o.value).unwrap_or(10),
}
}
}
#[test]
fn method_with_pyclassarg() {
let gil = Python::acquire_gil();
let py = gil.python();
let obj1 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap();
let obj2 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap();
py_run!(
py,
obj1 obj2,
"obj = obj1.add(obj2); assert obj.value == 20"
);
py_run!(
py,
obj1 obj2,
"obj = obj1.add_pyref(obj2); assert obj.value == 20"
);
py_run!(
py,
obj1 obj2,
"obj = obj1.optional_add(); assert obj.value == 20"
);
py_run!(
py,
obj1 obj2,
"obj1.inplace_add(obj2); assert obj2.value == 20"
);
}
#[pyclass]
#[cfg(unix)]
struct CfgStruct {}

View File

@ -7,6 +7,19 @@ mod common;
#[pyclass]
struct AnonClass {}
#[pyclass]
struct ValueClass {
value: usize,
}
#[pymethods]
impl ValueClass {
#[new]
fn new(value: usize) -> ValueClass {
ValueClass { value }
}
}
#[pyclass(module = "module")]
struct LocatedClass {}
@ -36,7 +49,13 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
Ok(42)
}
#[pyfn(m, "double_value")]
fn double_value(v: &ValueClass) -> usize {
v.value * 2
}
m.add_class::<AnonClass>().unwrap();
m.add_class::<ValueClass>().unwrap();
m.add_class::<LocatedClass>().unwrap();
m.add("foo", "bar").unwrap();
@ -60,7 +79,11 @@ fn test_module_with_functions() {
)]
.into_py_dict(py);
let run = |code| py.run(code, None, Some(d)).unwrap();
let run = |code| {
py.run(code, None, Some(d))
.map_err(|e| e.print(py))
.unwrap()
};
run("assert module_with_functions.__doc__ == 'This module is implemented in Rust.'");
run("assert module_with_functions.sum_as_string(1, 2) == '3'");
@ -73,6 +96,7 @@ fn test_module_with_functions() {
run("assert module_with_functions.double.__doc__ == 'Doubles the given value'");
run("assert module_with_functions.also_double(3) == 6");
run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'");
run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2");
}
#[pymodule(other_name)]