pymethods: support most numerical methods

This commit is contained in:
David Hewitt 2021-09-18 00:31:17 +01:00
parent 75c0116f6a
commit 43eb762346
5 changed files with 379 additions and 91 deletions

View file

@ -93,10 +93,10 @@ pub enum FnType {
}
impl FnType {
pub fn self_conversion(&self, cls: Option<&syn::Type>) -> TokenStream {
pub fn self_conversion(&self, cls: Option<&syn::Type>, error_mode: ExtractErrorMode) -> TokenStream {
match self {
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => {
st.receiver(cls.expect("no class given for Fn with a \"self\" receiver"))
st.receiver(cls.expect("no class given for Fn with a \"self\" receiver"), error_mode)
}
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
quote!()
@ -128,26 +128,44 @@ pub enum SelfType {
TryFromPyCell(Span),
}
pub enum ExtractErrorMode {
NotImplemented,
Raise,
}
impl SelfType {
pub fn receiver(&self, cls: &syn::Type) -> TokenStream {
pub fn receiver(&self, cls: &syn::Type, error_mode: ExtractErrorMode) -> TokenStream {
let cell = match error_mode {
ExtractErrorMode::Raise => {
quote! { _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>()? }
},
ExtractErrorMode::NotImplemented => {
quote! {
match _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>() {
::std::result::Result::Ok(cell) => cell,
::std::result::Result::Err(_) => return ::pyo3::callback::convert(_py, _py.NotImplemented()),
}
}
},
};
match self {
SelfType::Receiver { mutable: false } => {
quote! {
let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf);
let _cell = #cell;
let _ref = _cell.try_borrow()?;
let _slf = &_ref;
}
}
SelfType::Receiver { mutable: true } => {
quote! {
let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf);
let _cell = #cell;
let mut _ref = _cell.try_borrow_mut()?;
let _slf = &mut _ref;
}
}
SelfType::TryFromPyCell(span) => {
quote_spanned! { *span =>
let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf);
let _cell = #cell;
#[allow(clippy::useless_conversion)] // In case _slf is PyCell<Self>
let _slf = std::convert::TryFrom::try_from(_cell)?;
}
@ -442,7 +460,7 @@ impl<'a> FnSpec<'a> {
cls: Option<&syn::Type>,
) -> Result<TokenStream> {
let deprecations = &self.deprecations;
let self_conversion = self.tp.self_conversion(cls);
let self_conversion = self.tp.self_conversion(cls, ExtractErrorMode::Raise);
let self_arg = self.tp.self_arg();
let arg_names = (0..self.args.len())
.map(|pos| syn::Ident::new(&format!("arg{}", pos), Span::call_site()))

View file

@ -602,6 +602,39 @@ fn impl_class(
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_setitem_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_add_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_sub_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_mul_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_mod_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_divmod_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_lshift_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_rshift_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_and_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_or_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_xor_slot!(#cls) {
generated_slots.push(setdescr);
}
if let ::std::option::Option::Some(setdescr) = ::pyo3::generate_pyclass_matmul_slot!(#cls) {
generated_slots.push(setdescr);
}
visitor(&generated_slots);
}

View file

@ -3,6 +3,7 @@
use std::borrow::Cow;
use crate::attributes::NameAttribute;
use crate::method::ExtractErrorMode;
use crate::utils::{ensure_not_async_fn, unwrap_ty_group, PythonDoc};
use crate::{deprecations::Deprecations, utils};
use crate::{
@ -31,12 +32,15 @@ pub fn gen_py_method(
ensure_function_options_valid(&options)?;
let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?;
if let Some(slot_def) = pyproto(&spec.python_name.to_string()) {
let method_name = spec.python_name.to_string();
if let Some(slot_def) = pyproto(&method_name) {
let slot = slot_def.generate_type_slot(cls, &spec)?;
return Ok(GeneratedPyMethod::Proto(slot));
}
if let Some(proto) = pyproto_fragment(cls, &spec)? {
if let Some(slot_fragment_def) = pyproto_fragment(&method_name) {
let proto = slot_fragment_def.generate_pyproto_fragment(cls, &spec)?;
return Ok(GeneratedPyMethod::TraitImpl(proto));
}
@ -212,8 +216,8 @@ pub fn impl_py_setter_def(cls: &syn::Type, property_type: PropertyType) -> Resul
};
let slf = match property_type {
PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: true }.receiver(cls),
PropertyType::Function { self_type, .. } => self_type.receiver(cls),
PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: true }.receiver(cls, ExtractErrorMode::Raise),
PropertyType::Function { self_type, .. } => self_type.receiver(cls, ExtractErrorMode::Raise),
};
Ok(quote! {
::pyo3::class::PyMethodDefType::Setter({
@ -288,8 +292,8 @@ pub fn impl_py_getter_def(cls: &syn::Type, property_type: PropertyType) -> Resul
};
let slf = match property_type {
PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: false }.receiver(cls),
PropertyType::Function { self_type, .. } => self_type.receiver(cls),
PropertyType::Descriptor { .. } => SelfType::Receiver { mutable: false }.receiver(cls, ExtractErrorMode::Raise),
PropertyType::Function { self_type, .. } => self_type.receiver(cls, ExtractErrorMode::Raise),
};
Ok(quote! {
::pyo3::class::PyMethodDefType::Getter({
@ -515,6 +519,7 @@ enum Ty {
Int,
PyHashT,
PySsizeT,
Void,
}
impl Ty {
@ -525,6 +530,7 @@ impl Ty {
Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int },
Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t },
Ty::PySsizeT => quote! { ::pyo3::ffi::Py_ssize_t },
Ty::Void => quote! { () },
}
}
@ -577,7 +583,7 @@ impl Ty {
let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident)
.ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?;
},
Ty::Int | Ty::PyHashT | Ty::PySsizeT => todo!(),
Ty::Int | Ty::PyHashT | Ty::PySsizeT | Ty::Void => todo!(),
}
}
}
@ -705,7 +711,7 @@ impl SlotDef {
let py = syn::Ident::new("_py", Span::call_site());
let method_arguments = generate_method_arguments(arguments);
let ret_ty = ret_ty.ffi_type();
let body = generate_method_body(cls, spec, &py, arguments, return_mode.as_ref())?;
let body = generate_method_body(cls, spec, &py, arguments, ExtractErrorMode::Raise, return_mode.as_ref())?;
Ok(quote!({
unsafe extern "C" fn __wrap(_raw_slf: *mut ::pyo3::ffi::PyObject, #(#method_arguments),*) -> #ret_ty {
let _slf = _raw_slf;
@ -737,9 +743,10 @@ fn generate_method_body(
spec: &FnSpec,
py: &syn::Ident,
arguments: &[Ty],
extract_error_mode: ExtractErrorMode,
return_mode: Option<&ReturnMode>,
) -> Result<TokenStream> {
let self_conversion = spec.tp.self_conversion(Some(cls));
let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode);
let rust_name = spec.name;
let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments)?;
let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) };
@ -755,78 +762,89 @@ fn generate_method_body(
})
}
fn generate_pyproto_fragment(
cls: &syn::Type,
spec: &FnSpec,
fragment: &str,
arguments: &[Ty],
) -> Result<TokenStream> {
let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment);
let implemented = format_ident!("{}implemented", fragment);
let method = syn::Ident::new(fragment, Span::call_site());
let py = syn::Ident::new("_py", Span::call_site());
let method_arguments = generate_method_arguments(arguments);
let body = generate_method_body(cls, spec, &py, arguments, None)?;
Ok(quote! {
impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
#[inline]
fn #implemented(self) -> bool { true }
#[inline]
unsafe fn #method(
self,
_raw_slf: *mut ::pyo3::ffi::PyObject,
#(#method_arguments),*
) -> ::pyo3::PyResult<()> {
let _slf = _raw_slf;
let #py = ::pyo3::Python::assume_gil_acquired();
#body
}
}
})
struct SlotFragmentDef {
fragment: &'static str,
arguments: &'static [Ty],
ret_ty: Ty,
}
fn pyproto_fragment(cls: &syn::Type, spec: &FnSpec) -> Result<Option<TokenStream>> {
match spec.python_name.to_string().as_str() {
"__setattr__" => Some(generate_pyproto_fragment(
cls,
spec,
"__setattr__",
&[Ty::Object, Ty::NonNullObject],
)),
"__delattr__" => Some(generate_pyproto_fragment(
cls,
spec,
"__delattr__",
&[Ty::Object],
)),
"__set__" => Some(generate_pyproto_fragment(
cls,
spec,
"__set__",
&[Ty::Object, Ty::NonNullObject],
)),
"__delete__" => Some(generate_pyproto_fragment(
cls,
spec,
"__delete__",
&[Ty::Object],
)),
"__setitem__" => Some(generate_pyproto_fragment(
cls,
spec,
"__setitem__",
&[Ty::Object, Ty::NonNullObject],
)),
"__delitem__" => Some(generate_pyproto_fragment(
cls,
spec,
"__delitem__",
&[Ty::Object],
)),
impl SlotFragmentDef {
const fn new(fragment: &'static str, arguments: &'static [Ty]) -> Self {
SlotFragmentDef {
fragment,
arguments,
ret_ty: Ty::Void,
}
}
const fn ret_ty(mut self, ret_ty: Ty) -> Self {
self.ret_ty = ret_ty;
self
}
fn generate_pyproto_fragment(&self, cls: &syn::Type, spec: &FnSpec) -> Result<TokenStream> {
let SlotFragmentDef {
fragment,
arguments,
ret_ty,
} = self;
let fragment_trait = format_ident!("PyClass{}SlotFragment", fragment);
let implemented = format_ident!("{}implemented", fragment);
let method = syn::Ident::new(fragment, Span::call_site());
let py = syn::Ident::new("_py", Span::call_site());
let method_arguments = generate_method_arguments(arguments);
let body = generate_method_body(cls, spec, &py, arguments, ExtractErrorMode::NotImplemented, None)?;
let ret_ty = ret_ty.ffi_type();
Ok(quote! {
impl ::pyo3::class::impl_::#fragment_trait<#cls> for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
#[inline]
fn #implemented(self) -> bool { true }
#[inline]
unsafe fn #method(
self,
#py: ::pyo3::Python,
_raw_slf: *mut ::pyo3::ffi::PyObject,
#(#method_arguments),*
) -> ::pyo3::PyResult<#ret_ty> {
let _slf = _raw_slf;
#body
}
}
})
}
}
const __SETATTR__: SlotFragmentDef =
SlotFragmentDef::new("__setattr__", &[Ty::Object, Ty::NonNullObject]);
const __DELATTR__: SlotFragmentDef =
SlotFragmentDef::new("__delattr__", &[Ty::Object]);
const __SET__: SlotFragmentDef =
SlotFragmentDef::new("__set__", &[Ty::Object, Ty::NonNullObject]);
const __DELETE__: SlotFragmentDef =
SlotFragmentDef::new("__delete__", &[Ty::Object]);
const __SETITEM__: SlotFragmentDef =
SlotFragmentDef::new("__setitem__", &[Ty::Object, Ty::NonNullObject]);
const __DELITEM__: SlotFragmentDef =
SlotFragmentDef::new("__delitem__", &[Ty::Object]);
const __ADD__: SlotFragmentDef =
SlotFragmentDef::new("__add__", &[Ty::ObjectOrNotImplemented]).ret_ty(Ty::Object);
const __RADD__: SlotFragmentDef =
SlotFragmentDef::new("__radd__", &[Ty::ObjectOrNotImplemented]).ret_ty(Ty::Object);
fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> {
match method_name {
"__setattr__" => Some(&__SETATTR__),
"__delattr__" => Some(&__DELATTR__),
"__set__" => Some(&__SET__),
"__delete__" => Some(&__DELETE__),
"__setitem__" => Some(&__SETITEM__),
"__delitem__" => Some(&__DELITEM__),
"__add__" => Some(&__ADD__),
"__radd__" => Some(&__RADD__),
_ => None,
}
.transpose()
}
fn extract_proto_arguments(

View file

@ -151,6 +151,7 @@ macro_rules! define_pyclass_setattr_slot {
#[inline]
unsafe fn $set(
self,
_py: Python,
_slf: *mut ffi::PyObject,
_attr: *mut ffi::PyObject,
_value: NonNull<ffi::PyObject>,
@ -167,6 +168,7 @@ macro_rules! define_pyclass_setattr_slot {
#[inline]
unsafe fn $del(
self,
_py: Python,
_slf: *mut ffi::PyObject,
_attr: *mut ffi::PyObject,
) -> PyResult<()> {
@ -187,15 +189,14 @@ macro_rules! define_pyclass_setattr_slot {
attr: *mut $crate::ffi::PyObject,
value: *mut $crate::ffi::PyObject,
) -> ::std::os::raw::c_int {
use $crate::callback::IntoPyCallbackOutput;
$crate::callback::handle_panic(|py| {
let collector = PyClassImplCollector::<$cls>::new();
$crate::callback::convert(py, {
if let Some(value) = ::std::ptr::NonNull::new(value) {
collector.$set(_slf, attr, value)
} else {
collector.$del(_slf, attr)
}
})
if let Some(value) = ::std::ptr::NonNull::new(value) {
collector.$set(py, _slf, attr, value).convert(py)
} else {
collector.$del(py, _slf, attr).convert(py)
}
})
}
Some($crate::ffi::PyType_Slot {
@ -252,6 +253,223 @@ define_pyclass_setattr_slot! {
objobjargproc,
}
/// Macro which expands to three items
/// - Trait for a lhs dunder e.g. __add__
/// - Trait for the corresponding rhs e.g. __radd__
/// - A macro which will use dtolnay specialisation to generate the shared slot for the two dunders
macro_rules! define_pyclass_binary_operator_slot {
(
$lhs_trait:ident,
$rhs_trait:ident,
$lhs_implemented:ident,
$rhs_implemented:ident,
$lhs:ident,
$rhs:ident,
$generate_macro:ident,
$slot:ident,
$func_ty:ident,
) => {
slot_fragment_trait! {
$lhs_trait,
$lhs_implemented,
/// # Safety: _slf and _attr must be valid non-null Python objects
#[inline]
unsafe fn $lhs(
self,
_py: Python,
_slf: *mut ffi::PyObject,
_other: *mut ffi::PyObject,
) -> PyResult<*mut ffi::PyObject> {
ffi::Py_INCREF(ffi::Py_NotImplemented());
Ok(ffi::Py_NotImplemented())
}
}
slot_fragment_trait! {
$rhs_trait,
$rhs_implemented,
/// # Safety: _slf and _attr must be valid non-null Python objects
#[inline]
unsafe fn $rhs(
self,
_py: Python,
_slf: *mut ffi::PyObject,
_other: *mut ffi::PyObject,
) -> PyResult<*mut ffi::PyObject> {
ffi::Py_INCREF(ffi::Py_NotImplemented());
Ok(ffi::Py_NotImplemented())
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! $generate_macro {
($cls:ty) => {{
use ::std::option::Option::*;
use $crate::class::impl_::*;
let collector = PyClassImplCollector::<$cls>::new();
if collector.$lhs_implemented() || collector.$rhs_implemented() {
unsafe extern "C" fn __wrap(
_slf: *mut $crate::ffi::PyObject,
_other: *mut $crate::ffi::PyObject,
) -> *mut $crate::ffi::PyObject {
$crate::callback::handle_panic(|py| {
let collector = PyClassImplCollector::<$cls>::new();
let lhs_result = collector.$lhs(py, _slf, _other)?;
if lhs_result == $crate::ffi::Py_NotImplemented() {
$crate::ffi::Py_DECREF(lhs_result);
collector.$rhs(py, _other, _slf)
} else {
::std::result::Result::Ok(lhs_result)
}
})
}
Some($crate::ffi::PyType_Slot {
slot: $crate::ffi::$slot,
pfunc: __wrap as $crate::ffi::$func_ty as _,
})
} else {
None
}
}};
}
};
}
define_pyclass_binary_operator_slot! {
PyClass__add__SlotFragment,
PyClass__radd__SlotFragment,
__add__implemented,
__radd__implemented,
__add__,
__radd__,
generate_pyclass_add_slot,
Py_nb_add,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__sub__SlotFragment,
PyClass__rsub__SlotFragment,
__sub__implemented,
__rsub__implemented,
__sub__,
__rsub__,
generate_pyclass_sub_slot,
Py_nb_subtract,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__mul__SlotFragment,
PyClass__rmul__SlotFragment,
__mul__implemented,
__rmul__implemented,
__mul__,
__rmul__,
generate_pyclass_mul_slot,
Py_nb_multiply,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__mod__SlotFragment,
PyClass__rmod__SlotFragment,
__mod__implemented,
__rmod__implemented,
__mod__,
__rmod__,
generate_pyclass_mod_slot,
Py_nb_remainder,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__divmod__SlotFragment,
PyClass__rdivmod__SlotFragment,
__divmod__implemented,
__rdivmod__implemented,
__divmod__,
__rdivmod__,
generate_pyclass_divmod_slot,
Py_nb_divmod,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__lshift__SlotFragment,
PyClass__rlshift__SlotFragment,
__lshift__implemented,
__rlshift__implemented,
__lshift__,
__rlshift__,
generate_pyclass_lshift_slot,
Py_nb_lshift,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__rshift__SlotFragment,
PyClass__rrshift__SlotFragment,
__rshift__implemented,
__rrshift__implemented,
__rshift__,
__rrshift__,
generate_pyclass_rshift_slot,
Py_nb_rshift,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__and__SlotFragment,
PyClass__rand__SlotFragment,
__and__implemented,
__rand__implemented,
__and__,
__rand__,
generate_pyclass_and_slot,
Py_nb_and,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__or__SlotFragment,
PyClass__ror__SlotFragment,
__or__implemented,
__ror__implemented,
__or__,
__ror__,
generate_pyclass_or_slot,
Py_nb_or,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__xor__SlotFragment,
PyClass__rxor__SlotFragment,
__xor__implemented,
__rxor__implemented,
__xor__,
__rxor__,
generate_pyclass_xor_slot,
Py_nb_xor,
binaryfunc,
}
define_pyclass_binary_operator_slot! {
PyClass__matmul__SlotFragment,
PyClass__rmatmul__SlotFragment,
__matmul__implemented,
__rmatmul__implemented,
__matmul__,
__rmatmul__,
generate_pyclass_matmul_slot,
Py_nb_matrix_multiply,
binaryfunc,
}
pub trait PyClassAllocImpl<T> {
fn alloc_impl(self) -> Option<ffi::allocfunc>;
}

View file

@ -551,4 +551,5 @@ assert c.counter.count == 3
// TODO: test __delete__
// TODO: test __anext__, __aiter__
// TODO: test __index__, __int__, __float__, __invert__
// TODO: __floordiv__, __truediv__
// TODO: better argument casting errors