From 6e490f79ee52b97ab5da9315e3ff9bf3e80aa5f2 Mon Sep 17 00:00:00 2001 From: Googler Date: Fri, 5 Apr 2024 01:47:05 -0700 Subject: [PATCH] Implement flag_group in the new rule-based toolchain. BEGIN_PUBLIC Implement flag_group in the new rule-based toolchain. END_PUBLIC PiperOrigin-RevId: 622107179 Change-Id: I9e1971e279f313ce85537c899bcf80860616f8b7 --- cc/toolchains/args.bzl | 61 ++++----- cc/toolchains/impl/args_utils.bzl | 93 ++++++++++++- cc/toolchains/impl/nested_args.bzl | 124 +++++++++++++++++- cc/toolchains/nested_args.bzl | 45 +++++++ tests/rule_based_toolchain/variables/BUILD | 96 ++++++++++++++ .../variables/variables_test.bzl | 74 ++++++++++- 6 files changed, 458 insertions(+), 35 deletions(-) create mode 100644 cc/toolchains/nested_args.bzl diff --git a/cc/toolchains/args.bzl b/cc/toolchains/args.bzl index 29e3a1b..1df3333 100644 --- a/cc/toolchains/args.bzl +++ b/cc/toolchains/args.bzl @@ -13,43 +13,50 @@ # limitations under the License. """All providers for rule-based bazel toolchain config.""" -load("//cc:cc_toolchain_config_lib.bzl", "flag_group") +load("//cc/toolchains/impl:args_utils.bzl", "validate_nested_args") load( "//cc/toolchains/impl:collect.bzl", "collect_action_types", "collect_files", "collect_provider", ) +load( + "//cc/toolchains/impl:nested_args.bzl", + "NESTED_ARGS_ATTRS", + "args_wrapper_macro", + "nested_args_provider_from_ctx", +) load( ":cc_toolchain_info.bzl", "ActionTypeSetInfo", "ArgsInfo", "ArgsListInfo", + "BuiltinVariablesInfo", "FeatureConstraintInfo", - "NestedArgsInfo", ) visibility("public") def _cc_args_impl(ctx): - if not ctx.attr.args and not ctx.attr.env: - fail("cc_args requires at least one of args and env") - actions = collect_action_types(ctx.attr.actions) - files = collect_files(ctx.attr.data) - requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo) + + if not ctx.attr.args and not ctx.attr.nested and not ctx.attr.env: + fail("cc_args requires at least one of args, nested, and env") nested = None - if ctx.attr.args: - # TODO: This is temporary until cc_nested_args is implemented. - nested = NestedArgsInfo( + if ctx.attr.args or ctx.attr.nested: + nested = nested_args_provider_from_ctx(ctx) + validate_nested_args( + variables = ctx.attr._variables[BuiltinVariablesInfo].variables, + nested_args = nested, + actions = actions.to_list(), label = ctx.label, - nested = tuple(), - iterate_over = None, - files = files, - requires_types = {}, - legacy_flag_group = flag_group(flags = ctx.attr.args), ) + files = nested.files + else: + files = collect_files(ctx.attr.data) + + requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo) args = ArgsInfo( label = ctx.label, @@ -72,7 +79,7 @@ def _cc_args_impl(ctx): ), ] -cc_args = rule( +_cc_args = rule( implementation = _cc_args_impl, attrs = { "actions": attr.label_list( @@ -82,21 +89,6 @@ cc_args = rule( See @rules_cc//cc/toolchains/actions:all for valid options. """, - ), - "args": attr.string_list( - doc = """Arguments that should be added to the command-line. - -These are evaluated in order, with earlier args appearing earlier in the -invocation of the underlying tool. -""", - ), - "data": attr.label_list( - allow_files = True, - doc = """Files required to add this argument to the command-line. - -For example, a flag that sets the header directory might add the headers in that -directory as additional files. - """, ), "env": attr.string_dict( doc = "Environment variables to be added to the command-line.", @@ -108,7 +100,10 @@ directory as additional files. If omitted, this flag set will be enabled unconditionally. """, ), - }, + "_variables": attr.label( + default = "//cc/toolchains/variables:variables", + ), + } | NESTED_ARGS_ATTRS, provides = [ArgsInfo], doc = """Declares a list of arguments bound to a set of actions. @@ -121,3 +116,5 @@ Examples: ) """, ) + +cc_args = lambda **kwargs: args_wrapper_macro(rule = _cc_args, **kwargs) diff --git a/cc/toolchains/impl/args_utils.bzl b/cc/toolchains/impl/args_utils.bzl index 2ace6aa..55b4841 100644 --- a/cc/toolchains/impl/args_utils.bzl +++ b/cc/toolchains/impl/args_utils.bzl @@ -11,7 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""".""" +"""Helper functions for working with args.""" + +load(":variables.bzl", "get_type") + +visibility([ + "//cc/toolchains", + "//tests/rule_based_toolchain/...", +]) def get_action_type(args_list, action_type): """Returns the corresponding entry in ArgsListInfo.by_action. @@ -28,3 +35,87 @@ def get_action_type(args_list, action_type): return args return struct(action = action_type, args = tuple(), files = depset([])) + +def validate_nested_args(*, nested_args, variables, actions, label, fail = fail): + """Validates the typing for an nested_args invocation. + + Args: + nested_args: (NestedArgsInfo) The nested_args to validate + variables: (Dict[str, VariableInfo]) A mapping from variable name to + the metadata (variable type and valid actions). + actions: (List[ActionTypeInfo]) The actions we require these variables + to be valid for. + label: (Label) The label of the rule we're currently validating. + Used for error messages. + fail: The fail function. Use for testing only. + """ + stack = [(nested_args, {})] + + for _ in range(9999999): + if not stack: + break + nested_args, overrides = stack.pop() + if nested_args.iterate_over != None or nested_args.unwrap_options: + # Make sure we don't keep using the same object. + overrides = dict(**overrides) + + if nested_args.iterate_over != None: + type = get_type( + name = nested_args.iterate_over, + variables = variables, + overrides = overrides, + actions = actions, + args_label = label, + nested_label = nested_args.label, + fail = fail, + ) + if type["name"] == "list": + # Rewrite the type of the thing we iterate over from a List[T] + # to a T. + overrides[nested_args.iterate_over] = type["elements"] + elif type["name"] == "option" and type["elements"]["name"] == "list": + # Rewrite Option[List[T]] to T. + overrides[nested_args.iterate_over] = type["elements"]["elements"] + else: + fail("Attempting to iterate over %s, but it was not a list - it was a %s" % (nested_args.iterate_over, type["repr"])) + + # 1) Validate variables marked with after_option_unwrap = False. + # 2) Unwrap Option[T] to T as required. + # 3) Validate variables marked with after_option_unwrap = True. + for after_option_unwrap in [False, True]: + for var_name, requirements in nested_args.requires_types.items(): + for requirement in requirements: + if requirement.after_option_unwrap == after_option_unwrap: + type = get_type( + name = var_name, + variables = variables, + overrides = overrides, + actions = actions, + args_label = label, + nested_label = nested_args.label, + fail = fail, + ) + if type["name"] not in requirement.valid_types: + fail("{msg}, but {var_name} has type {type}".format( + var_name = var_name, + msg = requirement.msg, + type = type["repr"], + )) + + # Only unwrap the options after the first iteration of this loop. + if not after_option_unwrap: + for var in nested_args.unwrap_options: + type = get_type( + name = var, + variables = variables, + overrides = overrides, + actions = actions, + args_label = label, + nested_label = nested_args.label, + fail = fail, + ) + if type["name"] == "option": + overrides[var] = type["elements"] + + for child in nested_args.nested: + stack.append((child, overrides)) diff --git a/cc/toolchains/impl/nested_args.bzl b/cc/toolchains/impl/nested_args.bzl index dda7498..ed83cf1 100644 --- a/cc/toolchains/impl/nested_args.bzl +++ b/cc/toolchains/impl/nested_args.bzl @@ -13,8 +13,10 @@ # limitations under the License. """Helper functions for working with args.""" +load("@bazel_skylib//lib:structs.bzl", "structs") load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value") -load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo") +load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo", "VariableInfo") +load(":collect.bzl", "collect_files", "collect_provider") visibility([ "//cc/toolchains", @@ -48,6 +50,126 @@ cc_args( iterate_over = "//toolchains/variables:foo_list", """ +# @unsorted-dict-items. +NESTED_ARGS_ATTRS = { + "args": attr.string_list( + doc = """json-encoded arguments to be added to the command-line. + +Usage: +cc_args( + ..., + args = ["--foo", format_arg("%s", "//cc/toolchains/variables:foo")] +) + +This is equivalent to flag_group(flags = ["--foo", "%{foo}"]) + +Mutually exclusive with nested. +""", + ), + "nested": attr.label_list( + providers = [NestedArgsInfo], + doc = """nested_args that should be added on the command-line. + +Mutually exclusive with args.""", + ), + "data": attr.label_list( + allow_files = True, + doc = """Files required to add this argument to the command-line. + +For example, a flag that sets the header directory might add the headers in that +directory as additional files. +""", + ), + "variables": attr.label_list( + providers = [VariableInfo], + doc = "Variables to be used in substitutions", + ), + "iterate_over": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.iterate_over"), + "requires_not_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_available"), + "requires_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_not_available"), + "requires_true": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_true"), + "requires_false": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_false"), + "requires_equal": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_equal"), + "requires_equal_value": attr.string(), +} + +def args_wrapper_macro(*, name, rule, args = [], **kwargs): + """Invokes a rule by converting args to attributes. + + Args: + name: (str) The name of the target. + rule: (rule) The rule to invoke. Either cc_args or cc_nested_args. + args: (List[str|Formatted]) A list of either strings, or function calls + from format.bzl. For example: + ["--foo", format_arg("--sysroot=%s", "//cc/toolchains/variables:sysroot")] + **kwargs: kwargs to pass through into the rule invocation. + """ + out_args = [] + vars = [] + if type(args) != "list": + fail("Args must be a list in %s" % native.package_relative_label(name)) + for arg in args: + if type(arg) == "string": + out_args.append(raw_string(arg)) + elif getattr(arg, "format_type") == "format_arg": + arg = structs.to_dict(arg) + if arg["value"] == None: + out_args.append(arg) + else: + var = arg.pop("value") + + # Swap the variable from a label to an index. This allows us to + # actually get the providers in a rule. + out_args.append(struct(value = len(vars), **arg)) + vars.append(var) + else: + fail("Invalid type of args in %s. Expected either a string or format_args(format_string, variable_label), got value %r" % (native.package_relative_label(name), arg)) + + rule( + name = name, + args = [json.encode(arg) for arg in out_args], + variables = vars, + **kwargs + ) + +def _var(target): + if target == None: + return None + return target[VariableInfo].name + +# TODO: Consider replacing this with a subrule in the future. However, maybe not +# for a long time, since it'll break compatibility with all bazel versions < 7. +def nested_args_provider_from_ctx(ctx): + """Gets the nested args provider from a rule that has NESTED_ARGS_ATTRS. + + Args: + ctx: The rule context + Returns: + NestedArgsInfo + """ + variables = collect_provider(ctx.attr.variables, VariableInfo) + args = [] + for arg in ctx.attr.args: + arg = json.decode(arg) + if "value" in arg: + if arg["value"] != None: + arg["value"] = variables[arg["value"]] + args.append(struct(**arg)) + + return nested_args_provider( + label = ctx.label, + args = args, + nested = collect_provider(ctx.attr.nested, NestedArgsInfo), + files = collect_files(ctx.attr.data), + iterate_over = _var(ctx.attr.iterate_over), + requires_not_none = _var(ctx.attr.requires_not_none), + requires_none = _var(ctx.attr.requires_none), + requires_true = _var(ctx.attr.requires_true), + requires_false = _var(ctx.attr.requires_false), + requires_equal = _var(ctx.attr.requires_equal), + requires_equal_value = ctx.attr.requires_equal_value, + ) + def raw_string(s): """Constructs metadata for creating a raw string. diff --git a/cc/toolchains/nested_args.bzl b/cc/toolchains/nested_args.bzl new file mode 100644 index 0000000..e4e3d53 --- /dev/null +++ b/cc/toolchains/nested_args.bzl @@ -0,0 +1,45 @@ +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""All providers for rule-based bazel toolchain config.""" + +load( + "//cc/toolchains/impl:nested_args.bzl", + "NESTED_ARGS_ATTRS", + "args_wrapper_macro", + "nested_args_provider_from_ctx", +) +load( + ":cc_toolchain_info.bzl", + "NestedArgsInfo", +) + +visibility("public") + +_cc_nested_args = rule( + implementation = lambda ctx: [nested_args_provider_from_ctx(ctx)], + attrs = NESTED_ARGS_ATTRS, + provides = [NestedArgsInfo], + doc = """Declares a list of arguments bound to a set of actions. + +Roughly equivalent to ctx.actions.args() + +Examples: + cc_nested_args( + name = "warnings_as_errors", + args = ["-Werror"], + ) +""", +) + +cc_nested_args = lambda **kwargs: args_wrapper_macro(rule = _cc_nested_args, **kwargs) diff --git a/tests/rule_based_toolchain/variables/BUILD b/tests/rule_based_toolchain/variables/BUILD index 80928c7..5f7a5a6 100644 --- a/tests/rule_based_toolchain/variables/BUILD +++ b/tests/rule_based_toolchain/variables/BUILD @@ -1,3 +1,5 @@ +load("//cc/toolchains:format.bzl", "format_arg") +load("//cc/toolchains:nested_args.bzl", "cc_nested_args") load("//cc/toolchains/impl:variables.bzl", "cc_builtin_variables", "cc_variable", "types") load("//tests/rule_based_toolchain:analysis_test_suite.bzl", "analysis_test_suite") load(":variables_test.bzl", "TARGETS", "TESTS") @@ -7,6 +9,11 @@ cc_variable( type = types.string, ) +cc_variable( + name = "optional_list", + type = types.option(types.list(types.string)), +) + cc_variable( name = "str_list", type = types.list(types.string), @@ -28,15 +35,104 @@ cc_variable( cc_variable( name = "struct_list", + actions = ["//tests/rule_based_toolchain/actions:c_compile"], type = types.list(types.struct( nested_str = types.string, nested_str_list = types.list(types.string), )), ) +cc_variable( + name = "struct_list.nested_str_list", + type = types.unknown, +) + +# Dots in the name confuse the test rules. +# It would end up generating targets.struct_list.nested_str_list. +alias( + name = "nested_str_list", + actual = ":struct_list.nested_str_list", +) + +cc_nested_args( + name = "simple_str", + args = [format_arg("%s", ":str")], +) + +cc_nested_args( + name = "list_not_allowed", + args = [format_arg("%s", ":str_list")], +) + +cc_nested_args( + name = "iterate_over_list", + args = [format_arg("%s")], + iterate_over = ":str_list", +) + +cc_nested_args( + name = "iterate_over_non_list", + args = ["--foo"], + iterate_over = ":str", +) + +cc_nested_args( + name = "str_not_a_bool", + args = ["--foo"], + requires_true = ":str", +) + +cc_nested_args( + name = "str_equal", + args = ["--foo"], + requires_equal = ":str", + requires_equal_value = "bar", +) + +cc_nested_args( + name = "inner_iter", + args = [format_arg("%s")], + iterate_over = ":struct_list.nested_str_list", +) + +cc_nested_args( + name = "outer_iter", + iterate_over = ":struct_list", + nested = [":inner_iter"], +) + +cc_nested_args( + name = "bad_inner_iter", + args = [format_arg("%s", ":struct_list.nested_str_list")], +) + +cc_nested_args( + name = "bad_outer_iter", + iterate_over = ":struct_list", + nested = [":bad_inner_iter"], +) + +cc_nested_args( + name = "bad_nested_optional", + args = [format_arg("%s", ":str_option")], +) + +cc_nested_args( + name = "good_nested_optional", + args = [format_arg("%s", ":str_option")], + requires_not_none = ":str_option", +) + +cc_nested_args( + name = "optional_list_iter", + args = ["--foo"], + iterate_over = ":optional_list", +) + cc_builtin_variables( name = "variables", srcs = [ + ":optional_list", ":str", ":str_list", ":str_option", diff --git a/tests/rule_based_toolchain/variables/variables_test.bzl b/tests/rule_based_toolchain/variables/variables_test.bzl index a3cf843..98a64fd 100644 --- a/tests/rule_based_toolchain/variables/variables_test.bzl +++ b/tests/rule_based_toolchain/variables/variables_test.bzl @@ -13,13 +13,20 @@ # limitations under the License. """Tests for variables rule.""" -load("//cc/toolchains:cc_toolchain_info.bzl", "ActionTypeInfo", "BuiltinVariablesInfo", "VariableInfo") +load("//cc/toolchains:cc_toolchain_info.bzl", "ActionTypeInfo", "BuiltinVariablesInfo", "NestedArgsInfo", "VariableInfo") +load("//cc/toolchains/impl:args_utils.bzl", _validate_nested_args = "validate_nested_args") +load( + "//cc/toolchains/impl:nested_args.bzl", + "FORMAT_ARGS_ERR", + "REQUIRES_TRUE_ERR", +) load("//cc/toolchains/impl:variables.bzl", "types", _get_type = "get_type") load("//tests/rule_based_toolchain:subjects.bzl", "result_fn_wrapper", "subjects") visibility("private") get_type = result_fn_wrapper(_get_type) +validate_nested_args = result_fn_wrapper(_validate_nested_args) _ARGS_LABEL = Label("//:args") _NESTED_LABEL = Label("//:nested_vars") @@ -56,6 +63,7 @@ def _get_types_test(env, targets): expect_type("unknown").err().contains( """The variable unknown does not exist. Did you mean one of the following? +optional_list str str_list """, @@ -110,11 +118,74 @@ nested_str_list: List[string]""") }, ).ok().equals(types.string) +def _variable_validation_test(env, targets): + c_compile = targets.c_compile[ActionTypeInfo] + cpp_compile = targets.cpp_compile[ActionTypeInfo] + variables = targets.variables[BuiltinVariablesInfo].variables + + def _expect_validated(target, expr = None, actions = []): + return env.expect.that_value( + validate_nested_args( + nested_args = target[NestedArgsInfo], + variables = variables, + actions = actions, + label = _ARGS_LABEL, + ), + expr = expr, + # Type is Result[None] + factory = subjects.result(subjects.unknown), + ) + + _expect_validated(targets.simple_str, expr = "simple_str").ok() + _expect_validated(targets.list_not_allowed).err().equals( + FORMAT_ARGS_ERR + ", but str_list has type List[string]", + ) + _expect_validated(targets.iterate_over_list, expr = "iterate_over_list").ok() + _expect_validated(targets.iterate_over_non_list, expr = "iterate_over_non_list").err().equals( + "Attempting to iterate over str, but it was not a list - it was a string", + ) + _expect_validated(targets.str_not_a_bool, expr = "str_not_a_bool").err().equals( + REQUIRES_TRUE_ERR + ", but str has type string", + ) + _expect_validated(targets.str_equal, expr = "str_equal").ok() + _expect_validated(targets.inner_iter, expr = "inner_iter_standalone").err().equals( + 'Attempted to access "struct_list.nested_str_list", but "struct_list" was not a struct - it had type List[struct(nested_str=string, nested_str_list=List[string])]. Maybe you meant to use iterate_over.', + ) + + _expect_validated(targets.outer_iter, actions = [c_compile], expr = "outer_iter_valid_action").ok() + _expect_validated(targets.outer_iter, actions = [c_compile, cpp_compile], expr = "outer_iter_missing_action").err().equals( + "The variable %s is inaccessible from the action %s. This is required because it is referenced in %s, which is included by %s, which references that action" % (targets.struct_list.label, cpp_compile.label, targets.outer_iter.label, _ARGS_LABEL), + ) + + _expect_validated(targets.bad_outer_iter, expr = "bad_outer_iter").err().equals( + FORMAT_ARGS_ERR + ", but struct_list.nested_str_list has type List[string]", + ) + + _expect_validated(targets.optional_list_iter, expr = "optional_list_iter").ok() + + _expect_validated(targets.bad_nested_optional, expr = "bad_nested_optional").err().equals( + FORMAT_ARGS_ERR + ", but str_option has type Option[string]", + ) + _expect_validated(targets.good_nested_optional, expr = "good_nested_optional").ok() + TARGETS = [ "//tests/rule_based_toolchain/actions:c_compile", "//tests/rule_based_toolchain/actions:cpp_compile", + ":bad_nested_optional", + ":bad_outer_iter", + ":good_nested_optional", + ":inner_iter", + ":iterate_over_list", + ":iterate_over_non_list", + ":list_not_allowed", + ":nested_str_list", + ":optional_list_iter", + ":outer_iter", + ":simple_str", ":str", + ":str_equal", ":str_list", + ":str_not_a_bool", ":str_option", ":struct", ":struct_list", @@ -125,4 +196,5 @@ TARGETS = [ TESTS = { "types_represent_correctly_test": _types_represent_correctly_test, "get_types_test": _get_types_test, + "variable_validation_test": _variable_validation_test, }