Speficy METH_COEXIST for some number methods except

This commit is contained in:
kngwyu 2020-03-29 22:40:20 +09:00
parent a76bd7c4e3
commit 25eda36353
6 changed files with 181 additions and 142 deletions

View File

@ -22,9 +22,29 @@ impl Proto {
}
}
// TODO(kngwyu): Currently only __radd__-like methods use METH_COEXIST to prevent
// __add__-like methods from overriding them.
pub struct PyMethod {
pub name: &'static str,
pub proto: &'static str,
pub can_coexist: bool,
}
impl PyMethod {
const fn coexist(name: &'static str, proto: &'static str) -> Self {
PyMethod {
name,
proto,
can_coexist: true,
}
}
const fn new(name: &'static str, proto: &'static str) -> Self {
PyMethod {
name,
proto,
can_coexist: false,
}
}
}
pub const OBJECT: Proto = Proto {
@ -88,18 +108,9 @@ pub const OBJECT: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__format__",
proto: "pyo3::class::basic::FormatProtocolImpl",
},
PyMethod {
name: "__bytes__",
proto: "pyo3::class::basic::BytesProtocolImpl",
},
PyMethod {
name: "__unicode__",
proto: "pyo3::class::basic::UnicodeProtocolImpl",
},
PyMethod::new("__format__", "pyo3::class::basic::FormatProtocolImpl"),
PyMethod::new("__bytes__", "pyo3::class::basic::BytesProtocolImpl"),
PyMethod::new("__unicode__", "pyo3::class::basic::UnicodeProtocolImpl"),
],
};
@ -135,14 +146,14 @@ pub const ASYNC: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__aenter__",
proto: "pyo3::class::pyasync::PyAsyncAenterProtocolImpl",
},
PyMethod {
name: "__aexit__",
proto: "pyo3::class::pyasync::PyAsyncAexitProtocolImpl",
},
PyMethod::new(
"__aenter__",
"pyo3::class::pyasync::PyAsyncAenterProtocolImpl",
),
PyMethod::new(
"__aexit__",
"pyo3::class::pyasync::PyAsyncAexitProtocolImpl",
),
],
};
@ -180,14 +191,14 @@ pub const CONTEXT: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__enter__",
proto: "pyo3::class::context::PyContextEnterProtocolImpl",
},
PyMethod {
name: "__exit__",
proto: "pyo3::class::context::PyContextExitProtocolImpl",
},
PyMethod::new(
"__enter__",
"pyo3::class::context::PyContextEnterProtocolImpl",
),
PyMethod::new(
"__exit__",
"pyo3::class::context::PyContextExitProtocolImpl",
),
],
};
@ -237,14 +248,11 @@ pub const DESCR: Proto = Proto {
},
],
py_methods: &[
PyMethod {
name: "__del__",
proto: "pyo3::class::context::PyDescrDelProtocolImpl",
},
PyMethod {
name: "__set_name__",
proto: "pyo3::class::context::PyDescrNameProtocolImpl",
},
PyMethod::new("__del__", "pyo3::class::context::PyDescrDelProtocolImpl"),
PyMethod::new(
"__set_name__",
"pyo3::class::context::PyDescrNameProtocolImpl",
),
],
};
@ -298,10 +306,10 @@ pub const MAPPING: Proto = Proto {
proto: "pyo3::class::mapping::PyMappingReversedProtocol",
},
],
py_methods: &[PyMethod {
name: "__reversed__",
proto: "pyo3::class::mapping::PyMappingReversedProtocolImpl",
}],
py_methods: &[PyMethod::new(
"__reversed__",
"pyo3::class::mapping::PyMappingReversedProtocolImpl",
)],
};
pub const SEQ: Proto = Proto {
@ -666,81 +674,58 @@ pub const NUM: Proto = Proto {
pyres: true,
proto: "pyo3::class::number::PyNumberFloatProtocol",
},
MethodProto::Unary {
name: "__round__",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
MethodProto::Unary {
name: "__index__",
pyres: true,
proto: "pyo3::class::number::PyNumberIndexProtocol",
},
MethodProto::Binary {
name: "__round__",
arg: "NDigits",
pyres: true,
proto: "pyo3::class::number::PyNumberRoundProtocol",
},
],
py_methods: &[
PyMethod {
name: "__radd__",
proto: "pyo3::class::number::PyNumberRAddProtocolImpl",
},
PyMethod {
name: "__rsub__",
proto: "pyo3::class::number::PyNumberRSubProtocolImpl",
},
PyMethod {
name: "__rmul__",
proto: "pyo3::class::number::PyNumberRMulProtocolImpl",
},
PyMethod {
name: "__rmatmul__",
proto: "pyo3::class::number::PyNumberRMatmulProtocolImpl",
},
PyMethod {
name: "__rtruediv__",
proto: "pyo3::class::number::PyNumberRTruedivProtocolImpl",
},
PyMethod {
name: "__rfloordiv__",
proto: "pyo3::class::number::PyNumberRFloordivProtocolImpl",
},
PyMethod {
name: "__rmod__",
proto: "pyo3::class::number::PyNumberRModProtocolImpl",
},
PyMethod {
name: "__rdivmod__",
proto: "pyo3::class::number::PyNumberRDivmodProtocolImpl",
},
PyMethod {
name: "__rpow__",
proto: "pyo3::class::number::PyNumberRPowProtocolImpl",
},
PyMethod {
name: "__rlshift__",
proto: "pyo3::class::number::PyNumberRLShiftProtocolImpl",
},
PyMethod {
name: "__rrshift__",
proto: "pyo3::class::number::PyNumberRRShiftProtocolImpl",
},
PyMethod {
name: "__rand__",
proto: "pyo3::class::number::PyNumberRAndProtocolImpl",
},
PyMethod {
name: "__rxor__",
proto: "pyo3::class::number::PyNumberRXorProtocolImpl",
},
PyMethod {
name: "__ror__",
proto: "pyo3::class::number::PyNumberROrProtocolImpl",
},
PyMethod {
name: "__complex__",
proto: "pyo3::class::number::PyNumberComplexProtocolImpl",
},
PyMethod {
name: "__round__",
proto: "pyo3::class::number::PyNumberRoundProtocolImpl",
},
PyMethod::coexist("__radd__", "pyo3::class::number::PyNumberRAddProtocolImpl"),
PyMethod::coexist("__rsub__", "pyo3::class::number::PyNumberRSubProtocolImpl"),
PyMethod::coexist("__rmul__", "pyo3::class::number::PyNumberRMulProtocolImpl"),
PyMethod::coexist(
"__rmatmul__",
"pyo3::class::number::PyNumberRMatmulProtocolImpl",
),
PyMethod::coexist(
"__rtruediv__",
"pyo3::class::number::PyNumberRTruedivProtocolImpl",
),
PyMethod::coexist(
"__rfloordiv__",
"pyo3::class::number::PyNumberRFloordivProtocolImpl",
),
PyMethod::coexist("__rmod__", "pyo3::class::number::PyNumberRModProtocolImpl"),
PyMethod::coexist(
"__rdivmod__",
"pyo3::class::number::PyNumberRDivmodProtocolImpl",
),
PyMethod::coexist("__rpow__", "pyo3::class::number::PyNumberRPowProtocolImpl"),
PyMethod::coexist(
"__rlshift__",
"pyo3::class::number::PyNumberRLShiftProtocolImpl",
),
PyMethod::coexist(
"__rrshift__",
"pyo3::class::number::PyNumberRRShiftProtocolImpl",
),
PyMethod::coexist("__rand__", "pyo3::class::number::PyNumberRAndProtocolImpl"),
PyMethod::coexist("__rxor__", "pyo3::class::number::PyNumberRXorProtocolImpl"),
PyMethod::coexist("__ror__", "pyo3::class::number::PyNumberROrProtocolImpl"),
PyMethod::new(
"__complex__",
"pyo3::class::number::PyNumberComplexProtocolImpl",
),
PyMethod::new(
"__round__",
"pyo3::class::number::PyNumberRoundProtocolImpl",
),
],
};

View File

@ -75,7 +75,11 @@ fn impl_proto_impl(
Err(err) => return err.to_compile_error(),
};
let meth = pymethod::impl_proto_wrap(ty, &fn_spec);
let coexist = if m.can_coexist {
quote!(pyo3::ffi::METH_COEXIST)
} else {
quote!(0)
};
py_methods.push(quote! {
impl #proto for #ty
{
@ -86,7 +90,8 @@ fn impl_proto_impl(
Some(pyo3::class::PyMethodDef {
ml_name: stringify!(#name),
ml_meth: pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap),
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS,
// We need METH_COEXIST here to prevent __add__ from overriding __radd__
ml_flags: pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS | #coexist,
ml_doc: ""
})
}

View File

@ -323,7 +323,11 @@ macro_rules! py_ternary_self_func {
let arg2 = py.from_borrowed_ptr::<$crate::PyAny>(arg2);
let result = call_mut!(slf_cell, $f, arg1, arg2);
match result {
Ok(_) => slf,
Ok(_) => {
// Without this INCREF, SIGSEGV happens...
ffi::Py_INCREF(slf);
slf
}
Err(e) => e.restore_and_null(py),
}
}

View File

@ -60,7 +60,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __pow__(lhs: Self::Left, rhs: Self::Right, modulo: Self::Modulo) -> Self::Result
fn __pow__(lhs: Self::Left, rhs: Self::Right, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberPowProtocol<'p>,
{
@ -145,7 +145,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __rpow__(&'p self, other: Self::Other, module: Self::Modulo) -> Self::Result
fn __rpow__(&'p self, other: Self::Other, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberRPowProtocol<'p>,
{
@ -224,7 +224,7 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __ipow__(&'p mut self, other: Self::Other, modulo: Self::Modulo) -> Self::Result
fn __ipow__(&'p mut self, other: Self::Other, modulo: Option<Self::Modulo>) -> Self::Result
where
Self: PyNumberIPowProtocol<'p>,
{
@ -304,18 +304,18 @@ pub trait PyNumberProtocol<'p>: PyClass {
{
unimplemented!()
}
fn __round__(&'p self) -> Self::Result
where
Self: PyNumberRoundProtocol<'p>,
{
unimplemented!()
}
fn __index__(&'p self) -> Self::Result
where
Self: PyNumberIndexProtocol<'p>,
{
unimplemented!()
}
fn __round__(&'p self, ndigits: Option<Self::NDigits>) -> Self::Result
where
Self: PyNumberRoundProtocol<'p>,
{
unimplemented!()
}
}
pub trait PyNumberAddProtocol<'p>: PyNumberProtocol<'p> {
@ -610,6 +610,7 @@ pub trait PyNumberFloatProtocol<'p>: PyNumberProtocol<'p> {
pub trait PyNumberRoundProtocol<'p>: PyNumberProtocol<'p> {
type Success: IntoPy<PyObject>;
type NDigits: FromPyObject<'p>;
type Result: Into<PyResult<Self::Success>>;
}
@ -2137,7 +2138,7 @@ where
}
}
trait PyNumberComplexProtocolImpl {
pub trait PyNumberComplexProtocolImpl {
fn __complex__() -> Option<PyMethodDef>;
}
@ -2150,7 +2151,7 @@ where
}
}
trait PyNumberRoundProtocolImpl {
pub trait PyNumberRoundProtocolImpl {
fn __round__() -> Option<PyMethodDef>;
}

View File

@ -108,7 +108,6 @@ impl Drop for GILGuard {
unsafe {
let pool: &'static mut ReleasePool = &mut *POOL;
pool.drain(self.python(), self.owned, self.borrowed);
ffi::PyGILState_Release(self.gstate);
}
}

View File

@ -8,24 +8,39 @@ use pyo3::py_run;
mod common;
#[pyclass]
struct UnaryArithmetic {}
struct UnaryArithmetic {
inner: f64,
}
impl UnaryArithmetic {
fn new(value: f64) -> Self {
UnaryArithmetic { inner: value }
}
}
#[pyproto]
impl PyObjectProtocol for UnaryArithmetic {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("UA({})", self.inner))
}
}
#[pyproto]
impl PyNumberProtocol for UnaryArithmetic {
fn __neg__(&self) -> PyResult<&'static str> {
Ok("neg")
fn __neg__(&self) -> PyResult<Self> {
Ok(Self::new(-self.inner))
}
fn __pos__(&self) -> PyResult<&'static str> {
Ok("pos")
fn __pos__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner))
}
fn __abs__(&self) -> PyResult<&'static str> {
Ok("abs")
fn __abs__(&self) -> PyResult<Self> {
Ok(Self::new(self.inner.abs()))
}
fn __invert__(&self) -> PyResult<&'static str> {
Ok("invert")
fn __round__(&self, _ndigits: Option<u32>) -> PyResult<Self> {
Ok(Self::new(self.inner.round()))
}
}
@ -34,11 +49,12 @@ fn unary_arithmetic() {
let gil = Python::acquire_gil();
let py = gil.python();
let c = PyCell::new(py, UnaryArithmetic {}).unwrap();
py_run!(py, c, "assert -c == 'neg'");
py_run!(py, c, "assert +c == 'pos'");
py_run!(py, c, "assert abs(c) == 'abs'");
py_run!(py, c, "assert ~c == 'invert'");
let c = PyCell::new(py, UnaryArithmetic::new(2.718281)).unwrap();
py_run!(py, c, "assert repr(-c) == 'UA(-2.718281)'");
py_run!(py, c, "assert repr(+c) == 'UA(2.718281)'");
py_run!(py, c, "assert repr(abs(c)) == 'UA(2.718281)'");
py_run!(py, c, "assert repr(round(c)) == 'UA(3)'");
py_run!(py, c, "assert repr(round(c, 1)) == 'UA(3)'");
}
#[pyclass]
@ -104,13 +120,17 @@ impl PyNumberProtocol for InPlaceOperations {
self.value |= other;
Ok(())
}
fn __ipow__(&mut self, other: u32, _mod: Option<u32>) -> PyResult<()> {
self.value = self.value.pow(other);
Ok(())
}
}
#[test]
fn inplace_operations() {
let gil = Python::acquire_gil();
let py = gil.python();
let init = |value, code| {
let c = PyCell::new(py, InPlaceOperations { value }).unwrap();
py_run!(py, c, code);
@ -124,6 +144,11 @@ fn inplace_operations() {
init(12, "d = c; c &= 10; assert repr(c) == repr(d) == 'IPO(8)'");
init(12, "d = c; c |= 3; assert repr(c) == repr(d) == 'IPO(15)'");
init(12, "d = c; c ^= 5; assert repr(c) == repr(d) == 'IPO(9)'");
init(3, "d = c; c **= 4; assert repr(c) == repr(d) == 'IPO(81)'");
init(
3,
"d = c; c.__ipow__(4); assert repr(c) == repr(d) == 'IPO(81)'",
);
}
#[pyproto]
@ -159,6 +184,10 @@ impl PyNumberProtocol for BinaryArithmetic {
fn __or__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} | {:?}", lhs, rhs))
}
fn __pow__(lhs: &PyAny, rhs: &PyAny, mod_: Option<u32>) -> PyResult<String> {
Ok(format!("{:?} ** {:?} (mod: {:?})", lhs, rhs, mod_))
}
}
#[test]
@ -186,6 +215,10 @@ fn binary_arithmetic() {
py_run!(py, c, "assert 1 ^ c == '1 ^ BA'");
py_run!(py, c, "assert c | 1 == 'BA | 1'");
py_run!(py, c, "assert 1 | c == '1 | BA'");
py_run!(py, c, "assert c ** 1 == 'BA ** 1 (mod: None)'");
py_run!(py, c, "assert 1 ** c == '1 ** BA (mod: None)'");
py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'");
}
#[pyclass]
@ -225,7 +258,7 @@ impl PyNumberProtocol for RhsArithmetic {
Ok(format!("{:?} | RA", other))
}
fn __rpow__(&self, other: &PyAny, _module: &PyAny) -> PyResult<String> {
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
}
}
@ -269,6 +302,10 @@ impl PyNumberProtocol for LhsAndRhsArithmetic {
Ok(format!("{:?} - RA", other))
}
fn __rpow__(&self, other: &PyAny, _mod: Option<&'p PyAny>) -> PyResult<String> {
Ok(format!("{:?} ** RA", other))
}
fn __add__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} + {:?}", lhs, rhs))
}
@ -276,6 +313,10 @@ impl PyNumberProtocol for LhsAndRhsArithmetic {
fn __sub__(lhs: &PyAny, rhs: &PyAny) -> PyResult<String> {
Ok(format!("{:?} - {:?}", lhs, rhs))
}
fn __pow__(lhs: &PyAny, rhs: &PyAny, _mod: Option<u32>) -> PyResult<String> {
Ok(format!("{:?} ** {:?}", lhs, rhs))
}
}
#[pyproto]
@ -291,10 +332,14 @@ fn lhs_override_rhs() {
let py = gil.python();
let c = PyCell::new(py, LhsAndRhsArithmetic {}).unwrap();
py_run!(py, c, "assert c.__radd__(1) == '1 + BA'");
// Not overrided
py_run!(py, c, "assert c.__radd__(1) == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
py_run!(py, c, "assert c.__rpow__(1) == '1 ** RA'");
// Overrided
py_run!(py, c, "assert 1 + c == '1 + BA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - BA'");
py_run!(py, c, "assert 1 - c == '1 - BA'");
py_run!(py, c, "assert 1 ** c == '1 ** BA'");
}
#[pyclass]