Implement flag_group in the new rule-based toolchain.

BEGIN_PUBLIC
Implement flag_group in the new rule-based toolchain.
END_PUBLIC

PiperOrigin-RevId: 622107179
Change-Id: I9e1971e279f313ce85537c899bcf80860616f8b7
This commit is contained in:
Googler 2024-04-05 01:47:05 -07:00 committed by Copybara-Service
parent 54677903cf
commit 6e490f79ee
6 changed files with 458 additions and 35 deletions

View File

@ -13,43 +13,50 @@
# limitations under the License.
"""All providers for rule-based bazel toolchain config."""
load("//cc:cc_toolchain_config_lib.bzl", "flag_group")
load("//cc/toolchains/impl:args_utils.bzl", "validate_nested_args")
load(
"//cc/toolchains/impl:collect.bzl",
"collect_action_types",
"collect_files",
"collect_provider",
)
load(
"//cc/toolchains/impl:nested_args.bzl",
"NESTED_ARGS_ATTRS",
"args_wrapper_macro",
"nested_args_provider_from_ctx",
)
load(
":cc_toolchain_info.bzl",
"ActionTypeSetInfo",
"ArgsInfo",
"ArgsListInfo",
"BuiltinVariablesInfo",
"FeatureConstraintInfo",
"NestedArgsInfo",
)
visibility("public")
def _cc_args_impl(ctx):
if not ctx.attr.args and not ctx.attr.env:
fail("cc_args requires at least one of args and env")
actions = collect_action_types(ctx.attr.actions)
files = collect_files(ctx.attr.data)
requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)
if not ctx.attr.args and not ctx.attr.nested and not ctx.attr.env:
fail("cc_args requires at least one of args, nested, and env")
nested = None
if ctx.attr.args:
# TODO: This is temporary until cc_nested_args is implemented.
nested = NestedArgsInfo(
if ctx.attr.args or ctx.attr.nested:
nested = nested_args_provider_from_ctx(ctx)
validate_nested_args(
variables = ctx.attr._variables[BuiltinVariablesInfo].variables,
nested_args = nested,
actions = actions.to_list(),
label = ctx.label,
nested = tuple(),
iterate_over = None,
files = files,
requires_types = {},
legacy_flag_group = flag_group(flags = ctx.attr.args),
)
files = nested.files
else:
files = collect_files(ctx.attr.data)
requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)
args = ArgsInfo(
label = ctx.label,
@ -72,7 +79,7 @@ def _cc_args_impl(ctx):
),
]
cc_args = rule(
_cc_args = rule(
implementation = _cc_args_impl,
attrs = {
"actions": attr.label_list(
@ -82,21 +89,6 @@ cc_args = rule(
See @rules_cc//cc/toolchains/actions:all for valid options.
""",
),
"args": attr.string_list(
doc = """Arguments that should be added to the command-line.
These are evaluated in order, with earlier args appearing earlier in the
invocation of the underlying tool.
""",
),
"data": attr.label_list(
allow_files = True,
doc = """Files required to add this argument to the command-line.
For example, a flag that sets the header directory might add the headers in that
directory as additional files.
""",
),
"env": attr.string_dict(
doc = "Environment variables to be added to the command-line.",
@ -108,7 +100,10 @@ directory as additional files.
If omitted, this flag set will be enabled unconditionally.
""",
),
},
"_variables": attr.label(
default = "//cc/toolchains/variables:variables",
),
} | NESTED_ARGS_ATTRS,
provides = [ArgsInfo],
doc = """Declares a list of arguments bound to a set of actions.
@ -121,3 +116,5 @@ Examples:
)
""",
)
cc_args = lambda **kwargs: args_wrapper_macro(rule = _cc_args, **kwargs)

View File

@ -11,7 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""."""
"""Helper functions for working with args."""
load(":variables.bzl", "get_type")
visibility([
"//cc/toolchains",
"//tests/rule_based_toolchain/...",
])
def get_action_type(args_list, action_type):
"""Returns the corresponding entry in ArgsListInfo.by_action.
@ -28,3 +35,87 @@ def get_action_type(args_list, action_type):
return args
return struct(action = action_type, args = tuple(), files = depset([]))
def validate_nested_args(*, nested_args, variables, actions, label, fail = fail):
"""Validates the typing for an nested_args invocation.
Args:
nested_args: (NestedArgsInfo) The nested_args to validate
variables: (Dict[str, VariableInfo]) A mapping from variable name to
the metadata (variable type and valid actions).
actions: (List[ActionTypeInfo]) The actions we require these variables
to be valid for.
label: (Label) The label of the rule we're currently validating.
Used for error messages.
fail: The fail function. Use for testing only.
"""
stack = [(nested_args, {})]
for _ in range(9999999):
if not stack:
break
nested_args, overrides = stack.pop()
if nested_args.iterate_over != None or nested_args.unwrap_options:
# Make sure we don't keep using the same object.
overrides = dict(**overrides)
if nested_args.iterate_over != None:
type = get_type(
name = nested_args.iterate_over,
variables = variables,
overrides = overrides,
actions = actions,
args_label = label,
nested_label = nested_args.label,
fail = fail,
)
if type["name"] == "list":
# Rewrite the type of the thing we iterate over from a List[T]
# to a T.
overrides[nested_args.iterate_over] = type["elements"]
elif type["name"] == "option" and type["elements"]["name"] == "list":
# Rewrite Option[List[T]] to T.
overrides[nested_args.iterate_over] = type["elements"]["elements"]
else:
fail("Attempting to iterate over %s, but it was not a list - it was a %s" % (nested_args.iterate_over, type["repr"]))
# 1) Validate variables marked with after_option_unwrap = False.
# 2) Unwrap Option[T] to T as required.
# 3) Validate variables marked with after_option_unwrap = True.
for after_option_unwrap in [False, True]:
for var_name, requirements in nested_args.requires_types.items():
for requirement in requirements:
if requirement.after_option_unwrap == after_option_unwrap:
type = get_type(
name = var_name,
variables = variables,
overrides = overrides,
actions = actions,
args_label = label,
nested_label = nested_args.label,
fail = fail,
)
if type["name"] not in requirement.valid_types:
fail("{msg}, but {var_name} has type {type}".format(
var_name = var_name,
msg = requirement.msg,
type = type["repr"],
))
# Only unwrap the options after the first iteration of this loop.
if not after_option_unwrap:
for var in nested_args.unwrap_options:
type = get_type(
name = var,
variables = variables,
overrides = overrides,
actions = actions,
args_label = label,
nested_label = nested_args.label,
fail = fail,
)
if type["name"] == "option":
overrides[var] = type["elements"]
for child in nested_args.nested:
stack.append((child, overrides))

View File

@ -13,8 +13,10 @@
# limitations under the License.
"""Helper functions for working with args."""
load("@bazel_skylib//lib:structs.bzl", "structs")
load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo")
load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo", "VariableInfo")
load(":collect.bzl", "collect_files", "collect_provider")
visibility([
"//cc/toolchains",
@ -48,6 +50,126 @@ cc_args(
iterate_over = "//toolchains/variables:foo_list",
"""
# @unsorted-dict-items.
NESTED_ARGS_ATTRS = {
"args": attr.string_list(
doc = """json-encoded arguments to be added to the command-line.
Usage:
cc_args(
...,
args = ["--foo", format_arg("%s", "//cc/toolchains/variables:foo")]
)
This is equivalent to flag_group(flags = ["--foo", "%{foo}"])
Mutually exclusive with nested.
""",
),
"nested": attr.label_list(
providers = [NestedArgsInfo],
doc = """nested_args that should be added on the command-line.
Mutually exclusive with args.""",
),
"data": attr.label_list(
allow_files = True,
doc = """Files required to add this argument to the command-line.
For example, a flag that sets the header directory might add the headers in that
directory as additional files.
""",
),
"variables": attr.label_list(
providers = [VariableInfo],
doc = "Variables to be used in substitutions",
),
"iterate_over": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.iterate_over"),
"requires_not_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_available"),
"requires_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_not_available"),
"requires_true": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_true"),
"requires_false": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_false"),
"requires_equal": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_equal"),
"requires_equal_value": attr.string(),
}
def args_wrapper_macro(*, name, rule, args = [], **kwargs):
"""Invokes a rule by converting args to attributes.
Args:
name: (str) The name of the target.
rule: (rule) The rule to invoke. Either cc_args or cc_nested_args.
args: (List[str|Formatted]) A list of either strings, or function calls
from format.bzl. For example:
["--foo", format_arg("--sysroot=%s", "//cc/toolchains/variables:sysroot")]
**kwargs: kwargs to pass through into the rule invocation.
"""
out_args = []
vars = []
if type(args) != "list":
fail("Args must be a list in %s" % native.package_relative_label(name))
for arg in args:
if type(arg) == "string":
out_args.append(raw_string(arg))
elif getattr(arg, "format_type") == "format_arg":
arg = structs.to_dict(arg)
if arg["value"] == None:
out_args.append(arg)
else:
var = arg.pop("value")
# Swap the variable from a label to an index. This allows us to
# actually get the providers in a rule.
out_args.append(struct(value = len(vars), **arg))
vars.append(var)
else:
fail("Invalid type of args in %s. Expected either a string or format_args(format_string, variable_label), got value %r" % (native.package_relative_label(name), arg))
rule(
name = name,
args = [json.encode(arg) for arg in out_args],
variables = vars,
**kwargs
)
def _var(target):
if target == None:
return None
return target[VariableInfo].name
# TODO: Consider replacing this with a subrule in the future. However, maybe not
# for a long time, since it'll break compatibility with all bazel versions < 7.
def nested_args_provider_from_ctx(ctx):
"""Gets the nested args provider from a rule that has NESTED_ARGS_ATTRS.
Args:
ctx: The rule context
Returns:
NestedArgsInfo
"""
variables = collect_provider(ctx.attr.variables, VariableInfo)
args = []
for arg in ctx.attr.args:
arg = json.decode(arg)
if "value" in arg:
if arg["value"] != None:
arg["value"] = variables[arg["value"]]
args.append(struct(**arg))
return nested_args_provider(
label = ctx.label,
args = args,
nested = collect_provider(ctx.attr.nested, NestedArgsInfo),
files = collect_files(ctx.attr.data),
iterate_over = _var(ctx.attr.iterate_over),
requires_not_none = _var(ctx.attr.requires_not_none),
requires_none = _var(ctx.attr.requires_none),
requires_true = _var(ctx.attr.requires_true),
requires_false = _var(ctx.attr.requires_false),
requires_equal = _var(ctx.attr.requires_equal),
requires_equal_value = ctx.attr.requires_equal_value,
)
def raw_string(s):
"""Constructs metadata for creating a raw string.

View File

@ -0,0 +1,45 @@
# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All providers for rule-based bazel toolchain config."""
load(
"//cc/toolchains/impl:nested_args.bzl",
"NESTED_ARGS_ATTRS",
"args_wrapper_macro",
"nested_args_provider_from_ctx",
)
load(
":cc_toolchain_info.bzl",
"NestedArgsInfo",
)
visibility("public")
_cc_nested_args = rule(
implementation = lambda ctx: [nested_args_provider_from_ctx(ctx)],
attrs = NESTED_ARGS_ATTRS,
provides = [NestedArgsInfo],
doc = """Declares a list of arguments bound to a set of actions.
Roughly equivalent to ctx.actions.args()
Examples:
cc_nested_args(
name = "warnings_as_errors",
args = ["-Werror"],
)
""",
)
cc_nested_args = lambda **kwargs: args_wrapper_macro(rule = _cc_nested_args, **kwargs)

View File

@ -1,3 +1,5 @@
load("//cc/toolchains:format.bzl", "format_arg")
load("//cc/toolchains:nested_args.bzl", "cc_nested_args")
load("//cc/toolchains/impl:variables.bzl", "cc_builtin_variables", "cc_variable", "types")
load("//tests/rule_based_toolchain:analysis_test_suite.bzl", "analysis_test_suite")
load(":variables_test.bzl", "TARGETS", "TESTS")
@ -7,6 +9,11 @@ cc_variable(
type = types.string,
)
cc_variable(
name = "optional_list",
type = types.option(types.list(types.string)),
)
cc_variable(
name = "str_list",
type = types.list(types.string),
@ -28,15 +35,104 @@ cc_variable(
cc_variable(
name = "struct_list",
actions = ["//tests/rule_based_toolchain/actions:c_compile"],
type = types.list(types.struct(
nested_str = types.string,
nested_str_list = types.list(types.string),
)),
)
cc_variable(
name = "struct_list.nested_str_list",
type = types.unknown,
)
# Dots in the name confuse the test rules.
# It would end up generating targets.struct_list.nested_str_list.
alias(
name = "nested_str_list",
actual = ":struct_list.nested_str_list",
)
cc_nested_args(
name = "simple_str",
args = [format_arg("%s", ":str")],
)
cc_nested_args(
name = "list_not_allowed",
args = [format_arg("%s", ":str_list")],
)
cc_nested_args(
name = "iterate_over_list",
args = [format_arg("%s")],
iterate_over = ":str_list",
)
cc_nested_args(
name = "iterate_over_non_list",
args = ["--foo"],
iterate_over = ":str",
)
cc_nested_args(
name = "str_not_a_bool",
args = ["--foo"],
requires_true = ":str",
)
cc_nested_args(
name = "str_equal",
args = ["--foo"],
requires_equal = ":str",
requires_equal_value = "bar",
)
cc_nested_args(
name = "inner_iter",
args = [format_arg("%s")],
iterate_over = ":struct_list.nested_str_list",
)
cc_nested_args(
name = "outer_iter",
iterate_over = ":struct_list",
nested = [":inner_iter"],
)
cc_nested_args(
name = "bad_inner_iter",
args = [format_arg("%s", ":struct_list.nested_str_list")],
)
cc_nested_args(
name = "bad_outer_iter",
iterate_over = ":struct_list",
nested = [":bad_inner_iter"],
)
cc_nested_args(
name = "bad_nested_optional",
args = [format_arg("%s", ":str_option")],
)
cc_nested_args(
name = "good_nested_optional",
args = [format_arg("%s", ":str_option")],
requires_not_none = ":str_option",
)
cc_nested_args(
name = "optional_list_iter",
args = ["--foo"],
iterate_over = ":optional_list",
)
cc_builtin_variables(
name = "variables",
srcs = [
":optional_list",
":str",
":str_list",
":str_option",

View File

@ -13,13 +13,20 @@
# limitations under the License.
"""Tests for variables rule."""
load("//cc/toolchains:cc_toolchain_info.bzl", "ActionTypeInfo", "BuiltinVariablesInfo", "VariableInfo")
load("//cc/toolchains:cc_toolchain_info.bzl", "ActionTypeInfo", "BuiltinVariablesInfo", "NestedArgsInfo", "VariableInfo")
load("//cc/toolchains/impl:args_utils.bzl", _validate_nested_args = "validate_nested_args")
load(
"//cc/toolchains/impl:nested_args.bzl",
"FORMAT_ARGS_ERR",
"REQUIRES_TRUE_ERR",
)
load("//cc/toolchains/impl:variables.bzl", "types", _get_type = "get_type")
load("//tests/rule_based_toolchain:subjects.bzl", "result_fn_wrapper", "subjects")
visibility("private")
get_type = result_fn_wrapper(_get_type)
validate_nested_args = result_fn_wrapper(_validate_nested_args)
_ARGS_LABEL = Label("//:args")
_NESTED_LABEL = Label("//:nested_vars")
@ -56,6 +63,7 @@ def _get_types_test(env, targets):
expect_type("unknown").err().contains(
"""The variable unknown does not exist. Did you mean one of the following?
optional_list
str
str_list
""",
@ -110,11 +118,74 @@ nested_str_list: List[string]""")
},
).ok().equals(types.string)
def _variable_validation_test(env, targets):
c_compile = targets.c_compile[ActionTypeInfo]
cpp_compile = targets.cpp_compile[ActionTypeInfo]
variables = targets.variables[BuiltinVariablesInfo].variables
def _expect_validated(target, expr = None, actions = []):
return env.expect.that_value(
validate_nested_args(
nested_args = target[NestedArgsInfo],
variables = variables,
actions = actions,
label = _ARGS_LABEL,
),
expr = expr,
# Type is Result[None]
factory = subjects.result(subjects.unknown),
)
_expect_validated(targets.simple_str, expr = "simple_str").ok()
_expect_validated(targets.list_not_allowed).err().equals(
FORMAT_ARGS_ERR + ", but str_list has type List[string]",
)
_expect_validated(targets.iterate_over_list, expr = "iterate_over_list").ok()
_expect_validated(targets.iterate_over_non_list, expr = "iterate_over_non_list").err().equals(
"Attempting to iterate over str, but it was not a list - it was a string",
)
_expect_validated(targets.str_not_a_bool, expr = "str_not_a_bool").err().equals(
REQUIRES_TRUE_ERR + ", but str has type string",
)
_expect_validated(targets.str_equal, expr = "str_equal").ok()
_expect_validated(targets.inner_iter, expr = "inner_iter_standalone").err().equals(
'Attempted to access "struct_list.nested_str_list", but "struct_list" was not a struct - it had type List[struct(nested_str=string, nested_str_list=List[string])]. Maybe you meant to use iterate_over.',
)
_expect_validated(targets.outer_iter, actions = [c_compile], expr = "outer_iter_valid_action").ok()
_expect_validated(targets.outer_iter, actions = [c_compile, cpp_compile], expr = "outer_iter_missing_action").err().equals(
"The variable %s is inaccessible from the action %s. This is required because it is referenced in %s, which is included by %s, which references that action" % (targets.struct_list.label, cpp_compile.label, targets.outer_iter.label, _ARGS_LABEL),
)
_expect_validated(targets.bad_outer_iter, expr = "bad_outer_iter").err().equals(
FORMAT_ARGS_ERR + ", but struct_list.nested_str_list has type List[string]",
)
_expect_validated(targets.optional_list_iter, expr = "optional_list_iter").ok()
_expect_validated(targets.bad_nested_optional, expr = "bad_nested_optional").err().equals(
FORMAT_ARGS_ERR + ", but str_option has type Option[string]",
)
_expect_validated(targets.good_nested_optional, expr = "good_nested_optional").ok()
TARGETS = [
"//tests/rule_based_toolchain/actions:c_compile",
"//tests/rule_based_toolchain/actions:cpp_compile",
":bad_nested_optional",
":bad_outer_iter",
":good_nested_optional",
":inner_iter",
":iterate_over_list",
":iterate_over_non_list",
":list_not_allowed",
":nested_str_list",
":optional_list_iter",
":outer_iter",
":simple_str",
":str",
":str_equal",
":str_list",
":str_not_a_bool",
":str_option",
":struct",
":struct_list",
@ -125,4 +196,5 @@ TARGETS = [
TESTS = {
"types_represent_correctly_test": _types_represent_correctly_test,
"get_types_test": _get_types_test,
"variable_validation_test": _variable_validation_test,
}