rules_cc/tests/rule_based_toolchain/generate_factory.bzl

131 lines
4.6 KiB
Python

# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates provider factories."""
load("@bazel_skylib//lib:structs.bzl", "structs")
load("@rules_testing//lib:truth.bzl", "subjects")
visibility("private")
def generate_factory(type, name, attrs):
    """Generates a factory for a custom struct.

    There are three reasons we need to do so:
    1. It's very difficult to read providers printed by these types.
        eg. If you have a 10 layer deep diamond dependency graph, and try to
        print the top value, the bottom value will be printed 2^10 times.
    2. Collections of subjects are not well supported by rules_testing
        eg. `FeatureInfo(flag_sets = [FlagSetInfo(...)])`
        (You can do it, but the inner values are just regular bazel structs and
        you can't do fluent assertions on them).
    3. Recursive types are not supported at all
        eg. `FeatureInfo(implies = depset([FeatureInfo(...)]))`

    To solve this, we create a factory that:
    * Validates that the types of the children are correct.
    * Inlines providers to their labels when unambiguous.

    For example, given:
    ```
    foo = FeatureInfo(name = "foo", label = Label("//:foo"))
    bar = FeatureInfo(..., implies = depset([foo]))
    ```
    It would convert itself into a subject for the following struct:
    `FeatureInfo(..., implies = depset([Label("//:foo")]))`

    Args:
        type: (type) The type to create a factory for (eg. FooInfo)
        name: (str) The name of the type (eg. "FooInfo")
        attrs: (dict[str, Factory]) The attributes associated with this type.
            A `label` attribute (checked with `subjects.label`) is always
            added; the caller's dict itself is not modified.

    Returns:
        A struct `FooFactory` suitable for use with
        * `analysis_test(provider_subject_factories=[FooFactory])`
        * `generate_factory(..., attrs=dict(foo = FooFactory))`
        * `ProviderSequence(FooFactory)`
        * `ProviderDepset(FooFactory)`
    """

    # Work on a copy: the caller's dict may be frozen (e.g. a module-level
    # constant) or reused elsewhere, so it must not be mutated.
    all_attrs = dict(attrs)

    # Every generated type carries a `label`; nested providers are inlined to it.
    all_attrs["label"] = subjects.label
    want_keys = sorted(all_attrs.keys())

    def validate(*, value, meta):
        # Records a failure on `meta` if `value` is None, or if its set of
        # fields differs from the expected `want_keys`.
        if value == None:
            meta.add_failure("Wanted a %s but got" % name, value)

            # structs.to_dict(None) would raise and hide the failure recorded
            # above, so stop checking here.
            return
        got_keys = sorted(structs.to_dict(value).keys())
        subjects.collection(got_keys, meta = meta.derive(details = [
            "Value was not a %s - it has a different set of fields" % name,
        ])).contains_exactly(want_keys).in_order()

    def type_factory(value, *, meta):
        # Builds a struct subject for `value`, replacing each field whose
        # factory was itself produced by generate_factory with that field's
        # `label`.
        validate(value = value, meta = meta)
        transformed_value = {}
        transformed_factories = {}
        for field, factory in all_attrs.items():
            field_value = getattr(value, field)

            # If it's a type generated by generate_factory, inline it.
            if hasattr(factory, "factory"):
                factory.validate(value = field_value, meta = meta.derive(field))

                # A None child has already been reported by validate(); keep
                # going instead of crashing on `.label`.
                transformed_value[field] = None if field_value == None else field_value.label
                transformed_factories[field] = subjects.label
            else:
                transformed_value[field] = field_value
                transformed_factories[field] = factory
        return subjects.struct(
            struct(**transformed_value),
            meta = meta,
            attrs = transformed_factories,
        )

    return struct(
        type = type,
        name = name,
        factory = type_factory,
        validate = validate,
    )
def _provider_collection(element_factory, fn):
    """Builds a subject factory for a collection of generated-factory values.

    Args:
        element_factory: A factory struct returned by `generate_factory`; its
            `validate` function is applied to every element.
        fn: A function converting the raw collection value into a list
            (e.g. `list` for sequences, or a `to_list()` call for depsets).

    Returns:
        A `factory(value, *, meta)` function. It validates each element and
        returns a `subjects.collection` of the elements' `label` fields.
    """

    def factory(value, *, meta):
        elements = fn(value)

        # Validate that each element really is the correct type, tagging any
        # failure with the element's position.
        for i, element in enumerate(elements):
            element_factory.validate(
                value = element,
                meta = meta.derive("offset({})".format(i)),
            )

        # Inline the providers to just labels.
        return subjects.collection([e.label for e in elements], meta = meta)

    return factory
# These act like classes (each builds a subject factory), so they are named
# like classes.

# buildifier: disable=name-conventions
def ProviderSequence(element_factory):
    """Returns a subject factory for a sequence of `element_factory` values."""
    return _provider_collection(element_factory, fn = list)

# buildifier: disable=name-conventions
def ProviderDepset(element_factory):
    """Returns a subject factory for a depset of `element_factory` values."""
    return _provider_collection(
        element_factory,
        fn = lambda d: d.to_list(),
    )