pymethods: support inplace numerical operations
This commit is contained in:
parent
c090b6581d
commit
92e2156161
|
@ -32,7 +32,7 @@ pub fn gen_py_method(
|
|||
let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?;
|
||||
|
||||
if let Some(slot_def) = pyproto(&spec.python_name.to_string()) {
|
||||
let slot = slot_def.generate_type_slot(cls, &spec);
|
||||
let slot = slot_def.generate_type_slot(cls, &spec)?;
|
||||
return Ok(GeneratedPyMethod::Proto(slot));
|
||||
}
|
||||
|
||||
|
@ -399,7 +399,6 @@ const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc")
|
|||
));
|
||||
const __RICHCMP__: SlotDef =
|
||||
SlotDef::new("Py_tp_richcompare", "richcmpfunc").arguments(&[Ty::Object, Ty::CompareOp]);
|
||||
const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int);
|
||||
const __GET__: SlotDef =
|
||||
SlotDef::new("Py_tp_descr_get", "descrgetfunc").arguments(&[Ty::Object, Ty::Object]);
|
||||
const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc");
|
||||
|
@ -417,6 +416,55 @@ const __CONTAINS__: SlotDef = SlotDef::new("Py_sq_contains", "objobjproc")
|
|||
.ret_ty(Ty::Int);
|
||||
const __GETITEM__: SlotDef = SlotDef::new("Py_mp_subscript", "binaryfunc").arguments(&[Ty::Object]);
|
||||
|
||||
const __POS__: SlotDef = SlotDef::new("Py_nb_positive", "unaryfunc");
|
||||
const __NEG__: SlotDef = SlotDef::new("Py_nb_negative", "unaryfunc");
|
||||
const __ABS__: SlotDef = SlotDef::new("Py_nb_absolute", "unaryfunc");
|
||||
const __INVERT__: SlotDef = SlotDef::new("Py_nb_invert", "unaryfunc");
|
||||
const __INDEX__: SlotDef = SlotDef::new("Py_nb_index", "unaryfunc");
|
||||
const __INT__: SlotDef = SlotDef::new("Py_nb_int", "unaryfunc");
|
||||
const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc");
|
||||
const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int);
|
||||
|
||||
const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __ISUB__: SlotDef = SlotDef::new("Py_nb_inplace_subtract", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IMUL__: SlotDef = SlotDef::new("Py_nb_inplace_multiply", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IMATMUL__: SlotDef = SlotDef::new("Py_nb_inplace_matrix_multiply", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __ITRUEDIV__: SlotDef = SlotDef::new("Py_nb_inplace_true_divide", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IFLOORDIV__: SlotDef = SlotDef::new("Py_nb_inplace_floor_divide", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IMOD__: SlotDef = SlotDef::new("Py_nb_inplace_remainder", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "ternaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __ILSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_lshift", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IRSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_rshift", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IAND__: SlotDef = SlotDef::new("Py_nb_inplace_and", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IXOR__: SlotDef = SlotDef::new("Py_nb_inplace_xor", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
const __IOR__: SlotDef = SlotDef::new("Py_nb_inplace_or", "binaryfunc")
|
||||
.arguments(&[Ty::ObjectOrNotImplemented])
|
||||
.return_self();
|
||||
|
||||
fn pyproto(method_name: &str) -> Option<&'static SlotDef> {
|
||||
match method_name {
|
||||
"__getattr__" => Some(&__GETATTR__),
|
||||
|
@ -424,7 +472,6 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> {
|
|||
"__repr__" => Some(&__REPR__),
|
||||
"__hash__" => Some(&__HASH__),
|
||||
"__richcmp__" => Some(&__RICHCMP__),
|
||||
"__bool__" => Some(&__BOOL__),
|
||||
"__get__" => Some(&__GET__),
|
||||
"__iter__" => Some(&__ITER__),
|
||||
"__next__" => Some(&__NEXT__),
|
||||
|
@ -434,6 +481,27 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> {
|
|||
"__len__" => Some(&__LEN__),
|
||||
"__contains__" => Some(&__CONTAINS__),
|
||||
"__getitem__" => Some(&__GETITEM__),
|
||||
"__pos__" => Some(&__POS__),
|
||||
"__neg__" => Some(&__NEG__),
|
||||
"__abs__" => Some(&__ABS__),
|
||||
"__invert__" => Some(&__INVERT__),
|
||||
"__index__" => Some(&__INDEX__),
|
||||
"__int__" => Some(&__INT__),
|
||||
"__float__" => Some(&__FLOAT__),
|
||||
"__bool__" => Some(&__BOOL__),
|
||||
"__iadd__" => Some(&__IADD__),
|
||||
"__isub__" => Some(&__ISUB__),
|
||||
"__imul__" => Some(&__IMUL__),
|
||||
"__imatmul__" => Some(&__IMATMUL__),
|
||||
"__itruediv__" => Some(&__ITRUEDIV__),
|
||||
"__ifloordiv__" => Some(&__IFLOORDIV__),
|
||||
"__imod__" => Some(&__IMOD__),
|
||||
"__ipow__" => Some(&__IPOW__),
|
||||
"__ilshift__" => Some(&__ILSHIFT__),
|
||||
"__irshift__" => Some(&__IRSHIFT__),
|
||||
"__iand__" => Some(&__IAND__),
|
||||
"__ixor__" => Some(&__IXOR__),
|
||||
"__ior__" => Some(&__IOR__),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
@ -441,6 +509,7 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> {
|
|||
#[derive(Clone, Copy)]
|
||||
enum Ty {
|
||||
Object,
|
||||
ObjectOrNotImplemented,
|
||||
NonNullObject,
|
||||
CompareOp,
|
||||
Int,
|
||||
|
@ -451,7 +520,7 @@ enum Ty {
|
|||
impl Ty {
|
||||
fn ffi_type(self) -> TokenStream {
|
||||
match self {
|
||||
Ty::Object => quote! { *mut ::pyo3::ffi::PyObject },
|
||||
Ty::Object | Ty::ObjectOrNotImplemented => quote! { *mut ::pyo3::ffi::PyObject },
|
||||
Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> },
|
||||
Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int },
|
||||
Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t },
|
||||
|
@ -474,6 +543,29 @@ impl Ty {
|
|||
#extract
|
||||
}
|
||||
}
|
||||
Ty::ObjectOrNotImplemented => {
|
||||
let extract = if let syn::Type::Reference(tref) = unwrap_ty_group(target) {
|
||||
let (tref, mut_) = preprocess_tref(tref, cls);
|
||||
quote! {
|
||||
let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = match #ident.extract() {
|
||||
Ok(#ident) => #ident,
|
||||
Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()),
|
||||
};
|
||||
let #ident = &#mut_ *#ident;
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let #ident = match #ident.extract() {
|
||||
Ok(#ident) => #ident,
|
||||
Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()),
|
||||
};
|
||||
}
|
||||
};
|
||||
quote! {
|
||||
let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident);
|
||||
#extract
|
||||
}
|
||||
}
|
||||
Ty::NonNullObject => {
|
||||
let extract = extract_from_any(cls, target, ident);
|
||||
quote! {
|
||||
|
@ -502,12 +594,13 @@ fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -
|
|||
let #ident = #ident.extract()?;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Replace `Self`, remove lifetime and get mutability from the type
|
||||
fn preprocess_tref(
|
||||
/// Replace `Self`, remove lifetime and get mutability from the type
|
||||
fn preprocess_tref(
|
||||
tref: &syn::TypeReference,
|
||||
self_: &syn::Type,
|
||||
) -> (syn::TypeReference, Option<syn::token::Mut>) {
|
||||
) -> (syn::TypeReference, Option<syn::token::Mut>) {
|
||||
let mut tref = tref.to_owned();
|
||||
if let syn::Type::Path(tpath) = self_ {
|
||||
replace_self(&mut tref, &tpath.path);
|
||||
|
@ -515,10 +608,10 @@ fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -
|
|||
tref.lifetime = None;
|
||||
let mut_ = tref.mutability;
|
||||
(tref, mut_)
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace `Self` with the exact type name since it is used out of the impl block
|
||||
fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) {
|
||||
/// Replace `Self` with the exact type name since it is used out of the impl block
|
||||
fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) {
|
||||
match &mut *tref.elem {
|
||||
syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path),
|
||||
syn::Type::Path(tpath) => {
|
||||
|
@ -530,6 +623,27 @@ fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -
|
|||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
enum ReturnMode {
|
||||
ReturnSelf,
|
||||
Conversion(TokenGenerator),
|
||||
}
|
||||
|
||||
impl ReturnMode {
|
||||
fn return_call_output(&self, py: &syn::Ident, call: TokenStream) -> TokenStream {
|
||||
match self {
|
||||
ReturnMode::Conversion(conversion) => quote! {
|
||||
let _result: PyResult<#conversion> = #call;
|
||||
::pyo3::callback::convert(#py, _result)
|
||||
},
|
||||
ReturnMode::ReturnSelf => quote! {
|
||||
let _result: PyResult<()> = #call;
|
||||
_result?;
|
||||
::pyo3::ffi::Py_XINCREF(_raw_slf);
|
||||
Ok(_raw_slf)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -539,7 +653,7 @@ struct SlotDef {
|
|||
arguments: &'static [Ty],
|
||||
ret_ty: Ty,
|
||||
before_call_method: Option<TokenGenerator>,
|
||||
return_conversion: Option<TokenGenerator>,
|
||||
return_mode: Option<ReturnMode>,
|
||||
}
|
||||
|
||||
impl SlotDef {
|
||||
|
@ -550,7 +664,7 @@ impl SlotDef {
|
|||
arguments: &[],
|
||||
ret_ty: Ty::Object,
|
||||
before_call_method: None,
|
||||
return_conversion: None,
|
||||
return_mode: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -570,25 +684,31 @@ impl SlotDef {
|
|||
}
|
||||
|
||||
const fn return_conversion(mut self, return_conversion: TokenGenerator) -> Self {
|
||||
self.return_conversion = Some(return_conversion);
|
||||
self.return_mode = Some(ReturnMode::Conversion(return_conversion));
|
||||
self
|
||||
}
|
||||
|
||||
fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> TokenStream {
|
||||
const fn return_self(mut self) -> Self {
|
||||
self.return_mode = Some(ReturnMode::ReturnSelf);
|
||||
self
|
||||
}
|
||||
|
||||
fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> Result<TokenStream> {
|
||||
let SlotDef {
|
||||
slot,
|
||||
func_ty,
|
||||
before_call_method,
|
||||
arguments,
|
||||
ret_ty,
|
||||
return_conversion,
|
||||
return_mode,
|
||||
} = self;
|
||||
let py = syn::Ident::new("_py", Span::call_site());
|
||||
let method_arguments = generate_method_arguments(arguments);
|
||||
let ret_ty = ret_ty.ffi_type();
|
||||
let body = generate_method_body(cls, spec, &py, arguments, return_conversion.as_ref());
|
||||
quote!({
|
||||
unsafe extern "C" fn __wrap(_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty {
|
||||
let body = generate_method_body(cls, spec, &py, arguments, return_mode.as_ref())?;
|
||||
Ok(quote!({
|
||||
unsafe extern "C" fn __wrap(_raw_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty {
|
||||
let _slf = _raw_slf;
|
||||
#before_call_method
|
||||
::pyo3::callback::handle_panic(|#py| {
|
||||
#body
|
||||
|
@ -598,7 +718,7 @@ impl SlotDef {
|
|||
slot: ::pyo3::ffi::#slot,
|
||||
pfunc: __wrap as ::pyo3::ffi::#func_ty as _
|
||||
}
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -617,25 +737,22 @@ fn generate_method_body(
|
|||
spec: &FnSpec,
|
||||
py: &syn::Ident,
|
||||
arguments: &[Ty],
|
||||
return_conversion: Option<&TokenGenerator>,
|
||||
) -> TokenStream {
|
||||
return_mode: Option<&ReturnMode>,
|
||||
) -> Result<TokenStream> {
|
||||
let self_conversion = spec.tp.self_conversion(Some(cls));
|
||||
let rust_name = spec.name;
|
||||
let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments);
|
||||
let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments)?;
|
||||
let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) };
|
||||
let body = if let Some(return_conversion) = return_conversion {
|
||||
quote! {
|
||||
let _result: PyResult<#return_conversion> = #call;
|
||||
::pyo3::callback::convert(#py, _result)
|
||||
}
|
||||
let body = if let Some(return_mode) = return_mode {
|
||||
return_mode.return_call_output(py, call)
|
||||
} else {
|
||||
call
|
||||
};
|
||||
quote! {
|
||||
Ok(quote! {
|
||||
#self_conversion
|
||||
#conversions
|
||||
#body
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_pyproto_fragment(
|
||||
|
@ -643,14 +760,14 @@ fn generate_pyproto_fragment(
|
|||
spec: &FnSpec,
|
||||
fragment: &str,
|
||||
arguments: &[Ty],
|
||||
) -> TokenStream {
|
||||
) -> Result<TokenStream> {
|
||||
let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment);
|
||||
let implemented = format_ident!("{}implemented", fragment);
|
||||
let method = syn::Ident::new(fragment, Span::call_site());
|
||||
let py = syn::Ident::new("_py", Span::call_site());
|
||||
let method_arguments = generate_method_arguments(arguments);
|
||||
let body = generate_method_body(cls, spec, &py, arguments, None);
|
||||
quote! {
|
||||
let body = generate_method_body(cls, spec, &py, arguments, None)?;
|
||||
Ok(quote! {
|
||||
impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
|
||||
#[inline]
|
||||
fn #implemented(self) -> bool { true }
|
||||
|
@ -658,18 +775,19 @@ fn generate_pyproto_fragment(
|
|||
#[inline]
|
||||
unsafe fn #method(
|
||||
self,
|
||||
_slf: *mut ::pyo3::ffi::PyObject,
|
||||
_raw_slf: *mut ::pyo3::ffi::PyObject,
|
||||
#(#method_arguments),*
|
||||
) -> ::pyo3::PyResult<()> {
|
||||
let _slf = _raw_slf;
|
||||
let #py = ::pyo3::Python::assume_gil_acquired();
|
||||
#body
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result<Option<TokenStream>> {
|
||||
Ok(match spec.python_name.to_string().as_str() {
|
||||
match spec.python_name.to_string().as_str() {
|
||||
"__setattr__" => Some(generate_pyproto_fragment(
|
||||
cls,
|
||||
spec,
|
||||
|
@ -707,7 +825,8 @@ fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result<Option<TokenStream
|
|||
&[Ty::Object],
|
||||
)),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn extract_proto_arguments(
|
||||
|
@ -715,24 +834,28 @@ fn extract_proto_arguments(
|
|||
py: &syn::Ident,
|
||||
method_args: &[FnArg],
|
||||
proto_args: &[Ty],
|
||||
) -> (Vec<Ident>, TokenStream) {
|
||||
) -> Result<(Vec<Ident>, TokenStream)> {
|
||||
let mut arg_idents = Vec::with_capacity(method_args.len());
|
||||
let mut non_python_args = 0;
|
||||
|
||||
let args_conversion = method_args.iter().filter_map(|arg| {
|
||||
let mut args_conversions = Vec::with_capacity(proto_args.len());
|
||||
|
||||
for arg in method_args {
|
||||
if arg.py {
|
||||
arg_idents.push(py.clone());
|
||||
None
|
||||
} else {
|
||||
let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site());
|
||||
let conversions = proto_args[non_python_args].extract(cls, py, &ident, arg.ty);
|
||||
let conversions = proto_args.get(non_python_args)
|
||||
.ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))?
|
||||
.extract(cls, py, &ident, arg.ty);
|
||||
non_python_args += 1;
|
||||
args_conversions.push(conversions);
|
||||
arg_idents.push(ident);
|
||||
Some(conversions)
|
||||
}
|
||||
});
|
||||
let conversions = quote!(#(#args_conversion)*);
|
||||
(arg_idents, conversions)
|
||||
}
|
||||
|
||||
let conversions = quote!(#(#args_conversions)*);
|
||||
Ok((arg_idents, conversions))
|
||||
}
|
||||
|
||||
struct StaticIdent(&'static str);
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
#![allow(deprecated)] // for deprecated protocol methods
|
||||
|
||||
use pyo3::class::basic::CompareOp;
|
||||
use pyo3::class::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::py_run;
|
||||
|
||||
|
@ -16,17 +13,11 @@ impl UnaryArithmetic {
|
|||
fn new(value: f64) -> Self {
|
||||
UnaryArithmetic { inner: value }
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for UnaryArithmetic {
|
||||
fn __repr__(&self) -> String {
|
||||
format!("UA({})", self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for UnaryArithmetic {
|
||||
fn __neg__(&self) -> Self {
|
||||
Self::new(-self.inner)
|
||||
}
|
||||
|
@ -57,30 +48,17 @@ fn unary_arithmetic() {
|
|||
py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'");
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct BinaryArithmetic {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for BinaryArithmetic {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"BA"
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct InPlaceOperations {
|
||||
value: u32,
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for InPlaceOperations {
|
||||
#[pymethods]
|
||||
impl InPlaceOperations {
|
||||
fn __repr__(&self) -> String {
|
||||
format!("IPO({:?})", self.value)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for InPlaceOperations {
|
||||
fn __iadd__(&mut self, other: u32) {
|
||||
self.value += other;
|
||||
}
|
||||
|
@ -142,42 +120,49 @@ fn inplace_operations() {
|
|||
);
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for BinaryArithmetic {
|
||||
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} + {:?}", lhs, rhs)
|
||||
#[pyclass]
|
||||
struct BinaryArithmetic {}
|
||||
|
||||
#[pymethods]
|
||||
impl BinaryArithmetic {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"BA"
|
||||
}
|
||||
|
||||
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} - {:?}", lhs, rhs)
|
||||
fn __add__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA + {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} * {:?}", lhs, rhs)
|
||||
fn __sub__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA - {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} << {:?}", lhs, rhs)
|
||||
fn __mul__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA * {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} >> {:?}", lhs, rhs)
|
||||
fn __lshift__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA << {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __and__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} & {:?}", lhs, rhs)
|
||||
fn __rshift__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA >> {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} ^ {:?}", lhs, rhs)
|
||||
fn __and__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA & {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __or__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} | {:?}", lhs, rhs)
|
||||
fn __xor__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA ^ {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> String {
|
||||
format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_)
|
||||
fn __or__(&self, rhs: &PyAny) -> String {
|
||||
format!("BA | {:?}", rhs)
|
||||
}
|
||||
|
||||
fn __pow__(&self, rhs: &PyAny, mod_: Option<u32>) -> String {
|
||||
format!("BA ** {:?} (mod: {:?})", rhs, mod_)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,8 +200,8 @@ fn binary_arithmetic() {
|
|||
#[pyclass]
|
||||
struct RhsArithmetic {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for RhsArithmetic {
|
||||
#[pymethods]
|
||||
impl RhsArithmetic {
|
||||
fn __radd__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} + RA", other)
|
||||
}
|
||||
|
@ -249,7 +234,7 @@ impl PyNumberProtocol for RhsArithmetic {
|
|||
format!("{:?} | RA", other)
|
||||
}
|
||||
|
||||
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
|
||||
fn __rpow__(&self, other: &PyAny, _mod: Option<&PyAny>) -> String {
|
||||
format!("{:?} ** RA", other)
|
||||
}
|
||||
}
|
||||
|
@ -289,8 +274,12 @@ impl std::fmt::Debug for LhsAndRhs {
|
|||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for LhsAndRhs {
|
||||
#[pymethods]
|
||||
impl LhsAndRhs {
|
||||
// fn __repr__(&self) -> &'static str {
|
||||
// "BA"
|
||||
// }
|
||||
|
||||
fn __add__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} + {:?}", lhs, rhs)
|
||||
}
|
||||
|
@ -363,7 +352,7 @@ impl PyNumberProtocol for LhsAndRhs {
|
|||
format!("{:?} | RA", other)
|
||||
}
|
||||
|
||||
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
|
||||
fn __rpow__(&self, other: &PyAny, _mod: Option<&PyAny>) -> String {
|
||||
format!("{:?} ** RA", other)
|
||||
}
|
||||
|
||||
|
@ -372,13 +361,6 @@ impl PyNumberProtocol for LhsAndRhs {
|
|||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for LhsAndRhs {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"BA"
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lhs_fellback_to_rhs() {
|
||||
let gil = Python::acquire_gil();
|
||||
|
@ -412,8 +394,8 @@ fn lhs_fellback_to_rhs() {
|
|||
#[pyclass]
|
||||
struct RichComparisons {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for RichComparisons {
|
||||
#[pymethods]
|
||||
impl RichComparisons {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"RC"
|
||||
}
|
||||
|
@ -433,8 +415,8 @@ impl PyObjectProtocol for RichComparisons {
|
|||
#[pyclass]
|
||||
struct RichComparisons2 {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for RichComparisons2 {
|
||||
#[pymethods]
|
||||
impl RichComparisons2 {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"RC2"
|
||||
}
|
||||
|
@ -508,76 +490,73 @@ mod return_not_implemented {
|
|||
#[pyclass]
|
||||
struct RichComparisonToSelf {}
|
||||
|
||||
#[pyproto]
|
||||
impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf {
|
||||
#[pymethods]
|
||||
impl RichComparisonToSelf {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"RC_Self"
|
||||
}
|
||||
|
||||
fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject {
|
||||
fn __richcmp__(&self, other: PyRef<Self>, _op: CompareOp) -> PyObject {
|
||||
other.py().None()
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf {
|
||||
fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __add__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __sub__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __mul__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __matmul__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __truediv__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __floordiv__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __mod__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option<u8>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __pow__<'p>(slf: PyRef<'p, Self>, _other: u8, _modulo: Option<u8>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __lshift__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __rshift__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __divmod__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __and__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __or__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
fn __xor__<'p>(slf: PyRef<'p, Self>, _other: PyRef<'p, Self>) -> PyRef<'p, Self> {
|
||||
slf
|
||||
}
|
||||
|
||||
// Inplace assignments
|
||||
fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __iadd__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __isub__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __imul__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __imatmul__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __itruediv__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __ifloordiv__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __imod__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __ipow__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __ilshift__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __irshift__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __iand__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __ior__(&mut self, _other: PyRef<Self>) {}
|
||||
fn __ixor__(&mut self, _other: PyRef<Self>) {}
|
||||
}
|
||||
|
||||
fn _test_binary_dunder(dunder: &str) {
|
||||
|
|
|
@ -0,0 +1,683 @@
|
|||
#![allow(deprecated)] // for deprecated protocol methods
|
||||
|
||||
use pyo3::class::basic::CompareOp;
|
||||
use pyo3::class::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::py_run;
|
||||
|
||||
mod common;
|
||||
|
||||
#[pyclass]
|
||||
struct UnaryArithmetic {
|
||||
inner: f64,
|
||||
}
|
||||
|
||||
impl UnaryArithmetic {
|
||||
fn new(value: f64) -> Self {
|
||||
UnaryArithmetic { inner: value }
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for UnaryArithmetic {
|
||||
fn __repr__(&self) -> String {
|
||||
format!("UA({})", self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for UnaryArithmetic {
|
||||
fn __neg__(&self) -> Self {
|
||||
Self::new(-self.inner)
|
||||
}
|
||||
|
||||
fn __pos__(&self) -> Self {
|
||||
Self::new(self.inner)
|
||||
}
|
||||
|
||||
fn __abs__(&self) -> Self {
|
||||
Self::new(self.inner.abs())
|
||||
}
|
||||
|
||||
fn __round__(&self, _ndigits: Option<u32>) -> Self {
|
||||
Self::new(self.inner.round())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unary_arithmetic() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let c = PyCell::new(py, UnaryArithmetic::new(2.7)).unwrap();
|
||||
py_run!(py, c, "assert repr(-c) == 'UA(-2.7)'");
|
||||
py_run!(py, c, "assert repr(+c) == 'UA(2.7)'");
|
||||
py_run!(py, c, "assert repr(abs(c)) == 'UA(2.7)'");
|
||||
py_run!(py, c, "assert repr(round(c)) == 'UA(3)'");
|
||||
py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'");
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct BinaryArithmetic {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for BinaryArithmetic {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"BA"
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct InPlaceOperations {
|
||||
value: u32,
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for InPlaceOperations {
|
||||
fn __repr__(&self) -> String {
|
||||
format!("IPO({:?})", self.value)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for InPlaceOperations {
|
||||
fn __iadd__(&mut self, other: u32) {
|
||||
self.value += other;
|
||||
}
|
||||
|
||||
fn __isub__(&mut self, other: u32) {
|
||||
self.value -= other;
|
||||
}
|
||||
|
||||
fn __imul__(&mut self, other: u32) {
|
||||
self.value *= other;
|
||||
}
|
||||
|
||||
fn __ilshift__(&mut self, other: u32) {
|
||||
self.value <<= other;
|
||||
}
|
||||
|
||||
fn __irshift__(&mut self, other: u32) {
|
||||
self.value >>= other;
|
||||
}
|
||||
|
||||
fn __iand__(&mut self, other: u32) {
|
||||
self.value &= other;
|
||||
}
|
||||
|
||||
fn __ixor__(&mut self, other: u32) {
|
||||
self.value ^= other;
|
||||
}
|
||||
|
||||
fn __ior__(&mut self, other: u32) {
|
||||
self.value |= other;
|
||||
}
|
||||
|
||||
fn __ipow__(&mut self, other: u32) {
|
||||
self.value = self.value.pow(other);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inplace_operations() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let init = |value, code| {
|
||||
let c = PyCell::new(py, InPlaceOperations { value }).unwrap();
|
||||
py_run!(py, c, code);
|
||||
};
|
||||
|
||||
init(0, "d = c; c += 1; assert repr(c) == repr(d) == 'IPO(1)'");
|
||||
init(10, "d = c; c -= 1; assert repr(c) == repr(d) == 'IPO(9)'");
|
||||
init(3, "d = c; c *= 3; assert repr(c) == repr(d) == 'IPO(9)'");
|
||||
init(3, "d = c; c <<= 2; assert repr(c) == repr(d) == 'IPO(12)'");
|
||||
init(12, "d = c; c >>= 2; assert repr(c) == repr(d) == 'IPO(3)'");
|
||||
init(12, "d = c; c &= 10; assert repr(c) == repr(d) == 'IPO(8)'");
|
||||
init(12, "d = c; c |= 3; assert repr(c) == repr(d) == 'IPO(15)'");
|
||||
init(12, "d = c; c ^= 5; assert repr(c) == repr(d) == 'IPO(9)'");
|
||||
init(3, "d = c; c **= 4; assert repr(c) == repr(d) == 'IPO(81)'");
|
||||
init(
|
||||
3,
|
||||
"d = c; c.__ipow__(4); assert repr(c) == repr(d) == 'IPO(81)'",
|
||||
);
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for BinaryArithmetic {
|
||||
fn __add__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} + {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} - {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} * {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __lshift__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} << {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __rshift__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} >> {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __and__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} & {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __xor__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} ^ {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __or__(lhs: &PyAny, rhs: &PyAny) -> String {
|
||||
format!("{:?} | {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> String {
|
||||
format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_arithmetic() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let c = PyCell::new(py, BinaryArithmetic {}).unwrap();
|
||||
py_run!(py, c, "assert c + c == 'BA + BA'");
|
||||
py_run!(py, c, "assert c.__add__(c) == 'BA + BA'");
|
||||
py_run!(py, c, "assert c + 1 == 'BA + 1'");
|
||||
py_run!(py, c, "assert 1 + c == '1 + BA'");
|
||||
py_run!(py, c, "assert c - 1 == 'BA - 1'");
|
||||
py_run!(py, c, "assert 1 - c == '1 - BA'");
|
||||
py_run!(py, c, "assert c * 1 == 'BA * 1'");
|
||||
py_run!(py, c, "assert 1 * c == '1 * BA'");
|
||||
|
||||
py_run!(py, c, "assert c << 1 == 'BA << 1'");
|
||||
py_run!(py, c, "assert 1 << c == '1 << BA'");
|
||||
py_run!(py, c, "assert c >> 1 == 'BA >> 1'");
|
||||
py_run!(py, c, "assert 1 >> c == '1 >> BA'");
|
||||
py_run!(py, c, "assert c & 1 == 'BA & 1'");
|
||||
py_run!(py, c, "assert 1 & c == '1 & BA'");
|
||||
py_run!(py, c, "assert c ^ 1 == 'BA ^ 1'");
|
||||
py_run!(py, c, "assert 1 ^ c == '1 ^ BA'");
|
||||
py_run!(py, c, "assert c | 1 == 'BA | 1'");
|
||||
py_run!(py, c, "assert 1 | c == '1 | BA'");
|
||||
py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'");
|
||||
py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'");
|
||||
|
||||
py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'");
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct RhsArithmetic {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for RhsArithmetic {
|
||||
fn __radd__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} + RA", other)
|
||||
}
|
||||
|
||||
fn __rsub__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} - RA", other)
|
||||
}
|
||||
|
||||
fn __rmul__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} * RA", other)
|
||||
}
|
||||
|
||||
fn __rlshift__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} << RA", other)
|
||||
}
|
||||
|
||||
fn __rrshift__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} >> RA", other)
|
||||
}
|
||||
|
||||
fn __rand__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} & RA", other)
|
||||
}
|
||||
|
||||
fn __rxor__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} ^ RA", other)
|
||||
}
|
||||
|
||||
fn __ror__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} | RA", other)
|
||||
}
|
||||
|
||||
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
|
||||
format!("{:?} ** RA", other)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rhs_arithmetic() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let c = PyCell::new(py, RhsArithmetic {}).unwrap();
|
||||
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
|
||||
py_run!(py, c, "assert 1 + c == '1 + RA'");
|
||||
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
|
||||
py_run!(py, c, "assert 1 - c == '1 - RA'");
|
||||
py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'");
|
||||
py_run!(py, c, "assert 1 * c == '1 * RA'");
|
||||
py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'");
|
||||
py_run!(py, c, "assert 1 << c == '1 << RA'");
|
||||
py_run!(py, c, "assert c.__rrshift__(1) == '1 >> RA'");
|
||||
py_run!(py, c, "assert 1 >> c == '1 >> RA'");
|
||||
py_run!(py, c, "assert c.__rand__(1) == '1 & RA'");
|
||||
py_run!(py, c, "assert 1 & c == '1 & RA'");
|
||||
py_run!(py, c, "assert c.__rxor__(1) == '1 ^ RA'");
|
||||
py_run!(py, c, "assert 1 ^ c == '1 ^ RA'");
|
||||
py_run!(py, c, "assert c.__ror__(1) == '1 | RA'");
|
||||
py_run!(py, c, "assert 1 | c == '1 | RA'");
|
||||
py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'");
|
||||
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct LhsAndRhs {}
|
||||
|
||||
impl std::fmt::Debug for LhsAndRhs {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "LR")
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyNumberProtocol for LhsAndRhs {
|
||||
fn __add__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} + {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __sub__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} - {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __mul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} * {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __lshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} << {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __rshift__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} >> {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __and__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} & {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __xor__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} ^ {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __or__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} | {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __pow__(lhs: PyRef<Self>, rhs: &PyAny, _mod: Option<usize>) -> String {
|
||||
format!("{:?} ** {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __matmul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
|
||||
format!("{:?} @ {:?}", lhs, rhs)
|
||||
}
|
||||
|
||||
fn __radd__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} + RA", other)
|
||||
}
|
||||
|
||||
fn __rsub__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} - RA", other)
|
||||
}
|
||||
|
||||
fn __rmul__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} * RA", other)
|
||||
}
|
||||
|
||||
fn __rlshift__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} << RA", other)
|
||||
}
|
||||
|
||||
fn __rrshift__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} >> RA", other)
|
||||
}
|
||||
|
||||
fn __rand__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} & RA", other)
|
||||
}
|
||||
|
||||
fn __rxor__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} ^ RA", other)
|
||||
}
|
||||
|
||||
fn __ror__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} | RA", other)
|
||||
}
|
||||
|
||||
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> String {
|
||||
format!("{:?} ** RA", other)
|
||||
}
|
||||
|
||||
fn __rmatmul__(&self, other: &PyAny) -> String {
|
||||
format!("{:?} @ RA", other)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for LhsAndRhs {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"BA"
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lhs_fellback_to_rhs() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let c = PyCell::new(py, LhsAndRhs {}).unwrap();
|
||||
// If the light hand value is `LhsAndRhs`, LHS is used.
|
||||
py_run!(py, c, "assert c + 1 == 'LR + 1'");
|
||||
py_run!(py, c, "assert c - 1 == 'LR - 1'");
|
||||
py_run!(py, c, "assert c * 1 == 'LR * 1'");
|
||||
py_run!(py, c, "assert c << 1 == 'LR << 1'");
|
||||
py_run!(py, c, "assert c >> 1 == 'LR >> 1'");
|
||||
py_run!(py, c, "assert c & 1 == 'LR & 1'");
|
||||
py_run!(py, c, "assert c ^ 1 == 'LR ^ 1'");
|
||||
py_run!(py, c, "assert c | 1 == 'LR | 1'");
|
||||
py_run!(py, c, "assert c ** 1 == 'LR ** 1'");
|
||||
py_run!(py, c, "assert c @ 1 == 'LR @ 1'");
|
||||
// Fellback to RHS because of type mismatching
|
||||
py_run!(py, c, "assert 1 + c == '1 + RA'");
|
||||
py_run!(py, c, "assert 1 - c == '1 - RA'");
|
||||
py_run!(py, c, "assert 1 * c == '1 * RA'");
|
||||
py_run!(py, c, "assert 1 << c == '1 << RA'");
|
||||
py_run!(py, c, "assert 1 >> c == '1 >> RA'");
|
||||
py_run!(py, c, "assert 1 & c == '1 & RA'");
|
||||
py_run!(py, c, "assert 1 ^ c == '1 ^ RA'");
|
||||
py_run!(py, c, "assert 1 | c == '1 | RA'");
|
||||
py_run!(py, c, "assert 1 ** c == '1 ** RA'");
|
||||
py_run!(py, c, "assert 1 @ c == '1 @ RA'");
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct RichComparisons {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for RichComparisons {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"RC"
|
||||
}
|
||||
|
||||
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> String {
|
||||
match op {
|
||||
CompareOp::Lt => format!("{} < {:?}", self.__repr__(), other),
|
||||
CompareOp::Le => format!("{} <= {:?}", self.__repr__(), other),
|
||||
CompareOp::Eq => format!("{} == {:?}", self.__repr__(), other),
|
||||
CompareOp::Ne => format!("{} != {:?}", self.__repr__(), other),
|
||||
CompareOp::Gt => format!("{} > {:?}", self.__repr__(), other),
|
||||
CompareOp::Ge => format!("{} >= {:?}", self.__repr__(), other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct RichComparisons2 {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for RichComparisons2 {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"RC2"
|
||||
}
|
||||
|
||||
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyObject {
|
||||
match op {
|
||||
CompareOp::Eq => true.into_py(other.py()),
|
||||
CompareOp::Ne => false.into_py(other.py()),
|
||||
_ => other.py().NotImplemented(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rich_comparisons() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let c = PyCell::new(py, RichComparisons {}).unwrap();
|
||||
py_run!(py, c, "assert (c < c) == 'RC < RC'");
|
||||
py_run!(py, c, "assert (c < 1) == 'RC < 1'");
|
||||
py_run!(py, c, "assert (1 < c) == 'RC > 1'");
|
||||
py_run!(py, c, "assert (c <= c) == 'RC <= RC'");
|
||||
py_run!(py, c, "assert (c <= 1) == 'RC <= 1'");
|
||||
py_run!(py, c, "assert (1 <= c) == 'RC >= 1'");
|
||||
py_run!(py, c, "assert (c == c) == 'RC == RC'");
|
||||
py_run!(py, c, "assert (c == 1) == 'RC == 1'");
|
||||
py_run!(py, c, "assert (1 == c) == 'RC == 1'");
|
||||
py_run!(py, c, "assert (c != c) == 'RC != RC'");
|
||||
py_run!(py, c, "assert (c != 1) == 'RC != 1'");
|
||||
py_run!(py, c, "assert (1 != c) == 'RC != 1'");
|
||||
py_run!(py, c, "assert (c > c) == 'RC > RC'");
|
||||
py_run!(py, c, "assert (c > 1) == 'RC > 1'");
|
||||
py_run!(py, c, "assert (1 > c) == 'RC < 1'");
|
||||
py_run!(py, c, "assert (c >= c) == 'RC >= RC'");
|
||||
py_run!(py, c, "assert (c >= 1) == 'RC >= 1'");
|
||||
py_run!(py, c, "assert (1 >= c) == 'RC <= 1'");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rich_comparisons_python_3_type_error() {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let c2 = PyCell::new(py, RichComparisons2 {}).unwrap();
|
||||
py_expect_exception!(py, c2, "c2 < c2", PyTypeError);
|
||||
py_expect_exception!(py, c2, "c2 < 1", PyTypeError);
|
||||
py_expect_exception!(py, c2, "1 < c2", PyTypeError);
|
||||
py_expect_exception!(py, c2, "c2 <= c2", PyTypeError);
|
||||
py_expect_exception!(py, c2, "c2 <= 1", PyTypeError);
|
||||
py_expect_exception!(py, c2, "1 <= c2", PyTypeError);
|
||||
py_run!(py, c2, "assert (c2 == c2) == True");
|
||||
py_run!(py, c2, "assert (c2 == 1) == True");
|
||||
py_run!(py, c2, "assert (1 == c2) == True");
|
||||
py_run!(py, c2, "assert (c2 != c2) == False");
|
||||
py_run!(py, c2, "assert (c2 != 1) == False");
|
||||
py_run!(py, c2, "assert (1 != c2) == False");
|
||||
py_expect_exception!(py, c2, "c2 > c2", PyTypeError);
|
||||
py_expect_exception!(py, c2, "c2 > 1", PyTypeError);
|
||||
py_expect_exception!(py, c2, "1 > c2", PyTypeError);
|
||||
py_expect_exception!(py, c2, "c2 >= c2", PyTypeError);
|
||||
py_expect_exception!(py, c2, "c2 >= 1", PyTypeError);
|
||||
py_expect_exception!(py, c2, "1 >= c2", PyTypeError);
|
||||
}
|
||||
|
||||
// Checks that binary operations for which the arguments don't match the
|
||||
// required type, return NotImplemented.
|
||||
mod return_not_implemented {
|
||||
use super::*;
|
||||
|
||||
#[pyclass]
|
||||
struct RichComparisonToSelf {}
|
||||
|
||||
#[pyproto]
|
||||
impl<'p> PyObjectProtocol<'p> for RichComparisonToSelf {
|
||||
fn __repr__(&self) -> &'static str {
|
||||
"RC_Self"
|
||||
}
|
||||
|
||||
fn __richcmp__(&self, other: PyRef<'p, Self>, _op: CompareOp) -> PyObject {
|
||||
other.py().None()
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl<'p> PyNumberProtocol<'p> for RichComparisonToSelf {
|
||||
fn __add__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __sub__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __mul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __matmul__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __truediv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __floordiv__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __mod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __pow__(lhs: &'p PyAny, _other: u8, _modulo: Option<u8>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __lshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __rshift__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __divmod__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __and__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __or__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
fn __xor__(lhs: &'p PyAny, _other: PyRef<'p, Self>) -> &'p PyAny {
|
||||
lhs
|
||||
}
|
||||
|
||||
// Inplace assignments
|
||||
fn __iadd__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __isub__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __imul__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __imatmul__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __itruediv__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ifloordiv__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __imod__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ipow__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ilshift__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __irshift__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __iand__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ior__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
fn __ixor__(&'p mut self, _other: PyRef<'p, Self>) {}
|
||||
}
|
||||
|
||||
fn _test_binary_dunder(dunder: &str) {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap();
|
||||
py_run!(
|
||||
py,
|
||||
c2,
|
||||
&format!(
|
||||
"class Other: pass\nassert c2.__{}__(Other()) is NotImplemented",
|
||||
dunder
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
fn _test_binary_operator(operator: &str, dunder: &str) {
|
||||
_test_binary_dunder(dunder);
|
||||
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let c2 = PyCell::new(py, RichComparisonToSelf {}).unwrap();
|
||||
py_expect_exception!(
|
||||
py,
|
||||
c2,
|
||||
&format!("class Other: pass\nc2 {} Other()", operator),
|
||||
PyTypeError
|
||||
);
|
||||
}
|
||||
|
||||
fn _test_inplace_binary_operator(operator: &str, dunder: &str) {
|
||||
_test_binary_operator(operator, dunder);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equality() {
|
||||
_test_binary_dunder("eq");
|
||||
_test_binary_dunder("ne");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ordering() {
|
||||
_test_binary_operator("<", "lt");
|
||||
_test_binary_operator("<=", "le");
|
||||
_test_binary_operator(">", "gt");
|
||||
_test_binary_operator(">=", "ge");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bitwise() {
|
||||
_test_binary_operator("&", "and");
|
||||
_test_binary_operator("|", "or");
|
||||
_test_binary_operator("^", "xor");
|
||||
_test_binary_operator("<<", "lshift");
|
||||
_test_binary_operator(">>", "rshift");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arith() {
|
||||
_test_binary_operator("+", "add");
|
||||
_test_binary_operator("-", "sub");
|
||||
_test_binary_operator("*", "mul");
|
||||
_test_binary_operator("@", "matmul");
|
||||
_test_binary_operator("/", "truediv");
|
||||
_test_binary_operator("//", "floordiv");
|
||||
_test_binary_operator("%", "mod");
|
||||
_test_binary_operator("**", "pow");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn reverse_arith() {
|
||||
_test_binary_dunder("radd");
|
||||
_test_binary_dunder("rsub");
|
||||
_test_binary_dunder("rmul");
|
||||
_test_binary_dunder("rmatmul");
|
||||
_test_binary_dunder("rtruediv");
|
||||
_test_binary_dunder("rfloordiv");
|
||||
_test_binary_dunder("rmod");
|
||||
_test_binary_dunder("rpow");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inplace_bitwise() {
|
||||
_test_inplace_binary_operator("&=", "iand");
|
||||
_test_inplace_binary_operator("|=", "ior");
|
||||
_test_inplace_binary_operator("^=", "ixor");
|
||||
_test_inplace_binary_operator("<<=", "ilshift");
|
||||
_test_inplace_binary_operator(">>=", "irshift");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inplace_arith() {
|
||||
_test_inplace_binary_operator("+=", "iadd");
|
||||
_test_inplace_binary_operator("-=", "isub");
|
||||
_test_inplace_binary_operator("*=", "imul");
|
||||
_test_inplace_binary_operator("@=", "imatmul");
|
||||
_test_inplace_binary_operator("/=", "itruediv");
|
||||
_test_inplace_binary_operator("//=", "ifloordiv");
|
||||
_test_inplace_binary_operator("%=", "imod");
|
||||
_test_inplace_binary_operator("**=", "ipow");
|
||||
}
|
||||
}
|
|
@ -550,4 +550,5 @@ assert c.counter.count == 3
|
|||
|
||||
// TODO: test __delete__
|
||||
// TODO: test __anext__, __aiter__
|
||||
// TODO: test __index__, __int__, __float__, __invert__
|
||||
// TODO: better argument casting errors
|
||||
|
|
Loading…
Reference in New Issue