From f32277163a96e846316a73f62cad2ee097226a80 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Mon, 8 Jun 2020 12:04:01 +0900 Subject: [PATCH] Move nb_bool under PyObjectProtocol again --- pyo3-derive-backend/src/defs.rs | 12 ++++++------ src/class/basic.rs | 13 +++++++++++++ src/class/number.rs | 22 +++++----------------- src/pyclass.rs | 10 +++++++++- tests/test_dunder.rs | 4 ---- 5 files changed, 33 insertions(+), 28 deletions(-) diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index 19b6ab42..e1990a15 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -123,6 +123,11 @@ pub const OBJECT: Proto = Proto { pyres: true, proto: "pyo3::class::basic::PyObjectRichcmpProtocol", }, + MethodProto::Unary { + name: "__bool__", + pyres: false, + proto: "pyo3::class::basic::PyObjectBoolProtocol", + }, ], py_methods: &[ PyMethod::new("__format__", "pyo3::class::basic::FormatProtocolImpl"), @@ -142,6 +147,7 @@ pub const OBJECT: Proto = Proto { }, SlotSetter::new(&["__setattr__"], "set_setattr"), SlotSetter::new(&["__delattr__"], "set_delattr"), + SlotSetter::new(&["__bool__"], "set_bool"), ], }; @@ -784,11 +790,6 @@ pub const NUM: Proto = Proto { pyres: true, proto: "pyo3::class::number::PyNumberRoundProtocol", }, - MethodProto::Unary { - name: "__bool__", - pyres: false, - proto: "pyo3::class::number::PyNumberBoolProtocol", - }, ], py_methods: &[ PyMethod::coexist("__radd__", "pyo3::class::number::PyNumberRAddProtocolImpl"), @@ -867,7 +868,6 @@ pub const NUM: Proto = Proto { SlotSetter::new(&["__neg__"], "set_neg"), SlotSetter::new(&["__pos__"], "set_pos"), SlotSetter::new(&["__abs__"], "set_abs"), - SlotSetter::new(&["__bool__"], "set_bool"), SlotSetter::new(&["__invert__"], "set_invert"), SlotSetter::new(&["__rdivmod__"], "set_rdivmod"), SlotSetter { diff --git a/src/class/basic.rs b/src/class/basic.rs index c9827a77..38daf9cf 100644 --- a/src/class/basic.rs +++ b/src/class/basic.rs @@ -90,6 +90,12 @@ pub trait PyObjectProtocol<'p>: PyClass { { unimplemented!() } + fn __bool__(&'p self) -> Self::Result + where + Self: PyObjectBoolProtocol<'p>, + { + unimplemented!() + } } pub trait PyObjectGetAttrProtocol<'p>: PyObjectProtocol<'p> { @@ -144,6 +150,7 @@ pub struct PyObjectMethods { pub tp_getattro: Option, pub tp_richcompare: Option, pub tp_setattro: Option, + pub nb_bool: Option, } impl PyObjectMethods { @@ -215,6 +222,12 @@ impl PyObjectMethods { __delattr__ ) } + pub fn set_bool(&mut self) + where + T: for<'p> PyObjectBoolProtocol<'p>, + { + self.nb_bool = py_unary_func!(PyObjectBoolProtocol, T::__bool__, c_int); + } } fn tp_getattro() -> Option diff --git a/src/class/number.rs b/src/class/number.rs index 007f38ab..4441d15b 100644 --- a/src/class/number.rs +++ b/src/class/number.rs @@ -5,7 +5,6 @@ use crate::err::PyResult; use crate::{ffi, FromPyObject, IntoPy, PyClass, PyObject}; -use std::os::raw::c_int; /// Number interface #[allow(unused_variables)] @@ -314,12 +313,6 @@ pub trait PyNumberProtocol<'p>: PyClass { { unimplemented!() } - fn __bool__(&'p self) -> Self::Result - where - Self: PyNumberBoolProtocol<'p>, - { - unimplemented!() - } } pub trait PyNumberAddProtocol<'p>: PyNumberProtocol<'p> { @@ -622,11 +615,12 @@ pub trait PyNumberIndexProtocol<'p>: PyNumberProtocol<'p> { type Result: Into>; } -pub trait PyNumberBoolProtocol<'p>: PyNumberProtocol<'p> { - type Result: Into>; -} - impl ffi::PyNumberMethods { + pub(crate) fn from_nb_bool(nb_bool: ffi::inquiry) -> *mut Self { + let mut nm = ffi::PyNumberMethods_INIT; + nm.nb_bool = Some(nb_bool); + Box::into_raw(Box::new(nm)) + } pub fn set_add(&mut self) where T: for<'p> PyNumberAddProtocol<'p>, @@ -711,12 +705,6 @@ impl ffi::PyNumberMethods { { self.nb_absolute = py_unary_func!(PyNumberAbsProtocol, T::__abs__); } - pub fn set_bool(&mut self) - where - T: for<'p> PyNumberBoolProtocol<'p>, - { - self.nb_bool = py_unary_func!(PyNumberBoolProtocol, T::__bool__, c_int); - } pub fn set_invert(&mut self) where T: for<'p> PyNumberInvertProtocol<'p>, diff --git a/src/pyclass.rs b/src/pyclass.rs index bbec2cea..0f41076b 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -152,13 +152,21 @@ where unsafe { iter.as_ref() }.update_typeobj(type_object); } + // nb_bool is a part of PyObjectProtocol, but should be placed under tp_as_number + let mut nb_bool = None; // basic methods if let Some(basic) = T::basic_methods() { unsafe { basic.as_ref() }.update_typeobj(type_object); + nb_bool = unsafe { basic.as_ref() }.nb_bool; } // number methods - type_object.tp_as_number = T::number_methods().map_or_else(ptr::null_mut, |p| p.as_ptr()); + type_object.tp_as_number = T::number_methods() + .map(|mut p| { + unsafe { p.as_mut() }.nb_bool = nb_bool; + p.as_ptr() + }) + .unwrap_or_else(|| nb_bool.map_or_else(ptr::null_mut, ffi::PyNumberMethods::from_nb_bool)); // mapping methods type_object.tp_as_mapping = T::mapping_methods().map_or_else(ptr::null_mut, |p| p.as_ptr()); // sequence methods diff --git a/tests/test_dunder.rs b/tests/test_dunder.rs index 8ecc0101..421873f0 100644 --- a/tests/test_dunder.rs +++ b/tests/test_dunder.rs @@ -121,10 +121,6 @@ impl PyObjectProtocol for Comparisons { fn __hash__(&self) -> PyResult { Ok(self.val as isize) } -} - -#[pyproto] -impl pyo3::class::PyNumberProtocol for Comparisons { fn __bool__(&self) -> PyResult { Ok(self.val != 0) }