Add `nb::is_flag()` annotation to Counter::Flags (#1870)

This saves us the definition of `__or__`, because we can just use the
one from `enum.IntFlag`.
This commit is contained in:
Nicholas Junge 2024-10-28 19:18:40 +01:00 committed by GitHub
parent 4e3f2d8b67
commit d99cdd7356
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 12 deletions

View File

@ -118,7 +118,7 @@ NB_MODULE(_benchmark, m) {
using benchmark::Counter; using benchmark::Counter;
nb::class_<Counter> py_counter(m, "Counter"); nb::class_<Counter> py_counter(m, "Counter");
nb::enum_<Counter::Flags>(py_counter, "Flags", nb::is_arithmetic()) nb::enum_<Counter::Flags>(py_counter, "Flags", nb::is_arithmetic(), nb::is_flag())
.value("kDefaults", Counter::Flags::kDefaults) .value("kDefaults", Counter::Flags::kDefaults)
.value("kIsRate", Counter::Flags::kIsRate) .value("kIsRate", Counter::Flags::kIsRate)
.value("kAvgThreads", Counter::Flags::kAvgThreads) .value("kAvgThreads", Counter::Flags::kAvgThreads)
@ -129,10 +129,7 @@ NB_MODULE(_benchmark, m) {
.value("kAvgIterations", Counter::Flags::kAvgIterations) .value("kAvgIterations", Counter::Flags::kAvgIterations)
.value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate) .value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate)
.value("kInvert", Counter::Flags::kInvert) .value("kInvert", Counter::Flags::kInvert)
.export_values() .export_values();
.def("__or__", [](Counter::Flags a, Counter::Flags b) {
return static_cast<int>(a) | static_cast<int>(b);
});
nb::enum_<Counter::OneK>(py_counter, "OneK") nb::enum_<Counter::OneK>(py_counter, "OneK")
.value("kIs1000", Counter::OneK::kIs1000) .value("kIs1000", Counter::OneK::kIs1000)
@ -140,11 +137,7 @@ NB_MODULE(_benchmark, m) {
.export_values(); .export_values();
py_counter py_counter
.def( .def(nb::init<double, Counter::Flags, Counter::OneK>(),
"__init__",
[](Counter* c, double value, int flags, Counter::OneK oneK) {
new (c) Counter(value, static_cast<Counter::Flags>(flags), oneK);
},
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults, nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
nb::arg("k") = Counter::kIs1000) nb::arg("k") = Counter::kIs1000)
.def("__init__", .def("__init__",