Update benchmark Python bindings for nanobind 2.0, and update to nanobind 2.0. (#1817)

Incorporates the nanobind_bazel change from https://github.com/google/benchmark/pull/1795.

nanobind 2.0 reworked the nanobind::enum_ class so it uses a real Python enum or intenum rather than its previous hand-rolled implementation.
https://nanobind.readthedocs.io/en/latest/changelog.html#version-2-0-0-may-23-2024

As a consequence of that change, nanobind now checks when casting an integer to a enum value that the integer corresponds to a valid enum. Counter::Flags is a bitmask, and many combinations are not valid enum members.

This change:
a) sets nb::is_arithmetic(), which means Counter::Flags becomes an IntEnum that can be freely cast to an integer.
b) defines the | operator for flags to return an integer, not an enum, avoiding the error.
c) changes Counter's constructor to accept an int, not a Counter::Flags enum. Since Counter::Flags is an IntEnum now, it can be freely coerced to an int.

If https://github.com/wjakob/nanobind/pull/599 is merged into nanobind, then we can perhaps use a flag enum here instead.
This commit is contained in:
Peter Hawkins 2024-07-18 11:54:02 -04:00 committed by GitHub
parent a6ad7fbbdc
commit 64b5d8cd11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 7 deletions

View File

@ -38,4 +38,4 @@ use_repo(pip, "tools_pip_deps")
# -- bazel_dep definitions -- # # -- bazel_dep definitions -- #
bazel_dep(name = "nanobind_bazel", version = "1.0.0", dev_dependency = True) bazel_dep(name = "nanobind_bazel", version = "2.0.0", dev_dependency = True)

View File

@ -118,7 +118,7 @@ NB_MODULE(_benchmark, m) {
using benchmark::Counter; using benchmark::Counter;
nb::class_<Counter> py_counter(m, "Counter"); nb::class_<Counter> py_counter(m, "Counter");
nb::enum_<Counter::Flags>(py_counter, "Flags") nb::enum_<Counter::Flags>(py_counter, "Flags", nb::is_arithmetic())
.value("kDefaults", Counter::Flags::kDefaults) .value("kDefaults", Counter::Flags::kDefaults)
.value("kIsRate", Counter::Flags::kIsRate) .value("kIsRate", Counter::Flags::kIsRate)
.value("kAvgThreads", Counter::Flags::kAvgThreads) .value("kAvgThreads", Counter::Flags::kAvgThreads)
@ -130,7 +130,9 @@ NB_MODULE(_benchmark, m) {
.value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate) .value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate)
.value("kInvert", Counter::Flags::kInvert) .value("kInvert", Counter::Flags::kInvert)
.export_values() .export_values()
.def(nb::self | nb::self); .def("__or__", [](Counter::Flags a, Counter::Flags b) {
return static_cast<int>(a) | static_cast<int>(b);
});
nb::enum_<Counter::OneK>(py_counter, "OneK") nb::enum_<Counter::OneK>(py_counter, "OneK")
.value("kIs1000", Counter::OneK::kIs1000) .value("kIs1000", Counter::OneK::kIs1000)
@ -138,10 +140,15 @@ NB_MODULE(_benchmark, m) {
.export_values(); .export_values();
py_counter py_counter
.def(nb::init<double, Counter::Flags, Counter::OneK>(), .def(
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults, "__init__",
nb::arg("k") = Counter::kIs1000) [](Counter* c, double value, int flags, Counter::OneK oneK) {
.def("__init__", ([](Counter *c, double value) { new (c) Counter(value); })) new (c) Counter(value, static_cast<Counter::Flags>(flags), oneK);
},
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
nb::arg("k") = Counter::kIs1000)
.def("__init__",
([](Counter* c, double value) { new (c) Counter(value); }))
.def_rw("value", &Counter::value) .def_rw("value", &Counter::value)
.def_rw("flags", &Counter::flags) .def_rw("flags", &Counter::flags)
.def_rw("oneK", &Counter::oneK) .def_rw("oneK", &Counter::oneK)