Refactor AddArgsInfo into ExpandArgsInfo

BEGIN_PUBLIC
Refactor AddArgsInfo into ExpandArgsInfo

This allows us to create a similar mechanism to the current toolchain, while maintaining type safety.
END_PUBLIC

PiperOrigin-RevId: 615939056
Change-Id: I9b6763150194f8a76dfd8da730a3e2d45accbe20
This commit is contained in:
Googler 2024-03-14 16:26:51 -07:00 committed by Copybara-Service
parent 69c9748afb
commit bbb0615a87
6 changed files with 135 additions and 34 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""All providers for rule-based bazel toolchain config.""" """All providers for rule-based bazel toolchain config."""
load("//cc:cc_toolchain_config_lib.bzl", "flag_group")
load( load(
"//cc/toolchains/impl:collect.bzl", "//cc/toolchains/impl:collect.bzl",
"collect_action_types", "collect_action_types",
@ -22,32 +23,41 @@ load(
load( load(
":cc_toolchain_info.bzl", ":cc_toolchain_info.bzl",
"ActionTypeSetInfo", "ActionTypeSetInfo",
"AddArgsInfo",
"ArgsInfo", "ArgsInfo",
"ArgsListInfo", "ArgsListInfo",
"ExpandArgsInfo",
"FeatureConstraintInfo", "FeatureConstraintInfo",
) )
visibility("public") visibility("public")
def _cc_args_impl(ctx): def _cc_args_impl(ctx):
add_args = [AddArgsInfo( if not ctx.attr.args and not ctx.attr.env:
label = ctx.label, fail("cc_args requires at least one of args and env")
args = tuple(ctx.attr.args),
files = depset([]),
)]
actions = collect_action_types(ctx.attr.actions) actions = collect_action_types(ctx.attr.actions)
files = collect_files(ctx.attr.data) files = collect_files(ctx.attr.data)
requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo) requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)
expand = None
if ctx.attr.args:
# TODO: This is temporary until cc_expand_args is implemented.
expand = ExpandArgsInfo(
label = ctx.label,
expand = tuple(),
iterate_over = None,
files = files,
requires_types = {},
legacy_flag_group = flag_group(flags = ctx.attr.args),
)
args = ArgsInfo( args = ArgsInfo(
label = ctx.label, label = ctx.label,
actions = actions, actions = actions,
requires_any_of = tuple(requires), requires_any_of = tuple(requires),
files = files, expand = expand,
args = add_args,
env = ctx.attr.env, env = ctx.attr.env,
files = files,
) )
return [ return [
args, args,
@ -74,7 +84,6 @@ See @rules_cc//cc/toolchains/actions:all for valid options.
""", """,
), ),
"args": attr.string_list( "args": attr.string_list(
mandatory = True,
doc = """Arguments that should be added to the command-line. doc = """Arguments that should be added to the command-line.
These are evaluated in order, with earlier args appearing earlier in the These are evaluated in order, with earlier args appearing earlier in the

View File

@ -45,13 +45,16 @@ ActionTypeSetInfo = provider(
}, },
) )
AddArgsInfo = provider( ExpandArgsInfo = provider(
doc = "A provider representation of Args.add/add_all/add_joined parameters", doc = "A provider representation of Args.add/add_all/add_joined parameters",
# @unsorted-dict-items # @unsorted-dict-items
fields = { fields = {
"label": "(Label) The label defining this provider. Place in error messages to simplify debugging", "label": "(Label) The label defining this provider. Place in error messages to simplify debugging",
"args": "(Sequence[str]) The command-line arguments to add", "expand": "(Sequence[ExpandArgsInfo]) The nested arg expansion. Mutually exclusive with args",
"iterate_over": "(Optional[str]) The variable to iterate over",
"files": "(depset[File]) The files required to use this variable", "files": "(depset[File]) The files required to use this variable",
"requires_types": "(dict[str, str]) A mapping from variables to their expected type name (not type). This means that we can require the generic type Option, rather than an Option[T]",
"legacy_flag_group": "(flag_group) The flag_group this corresponds to",
}, },
) )
@ -62,7 +65,7 @@ ArgsInfo = provider(
"label": "(Label) The label defining this provider. Place in error messages to simplify debugging", "label": "(Label) The label defining this provider. Place in error messages to simplify debugging",
"actions": "(depset[ActionTypeInfo]) The set of actions this is associated with", "actions": "(depset[ActionTypeInfo]) The set of actions this is associated with",
"requires_any_of": "(Sequence[FeatureConstraintInfo]) This will be enabled if any of the listed predicates are met. Equivalent to with_features", "requires_any_of": "(Sequence[FeatureConstraintInfo]) This will be enabled if any of the listed predicates are met. Equivalent to with_features",
"args": "(Sequence[AddArgsInfo]) The command-line arguments to add.", "expand": "(Optional[ExpandArgsInfo]) The args to expand. Equivalent to a flag group.",
"files": "(depset[File]) Files required for the args", "files": "(depset[File]) Files required for the args",
"env": "(dict[str, str]) Environment variables to apply", "env": "(dict[str, str]) Environment variables to apply",
}, },

View File

@ -20,7 +20,6 @@ load(
legacy_env_set = "env_set", legacy_env_set = "env_set",
legacy_feature = "feature", legacy_feature = "feature",
legacy_feature_set = "feature_set", legacy_feature_set = "feature_set",
legacy_flag_group = "flag_group",
legacy_flag_set = "flag_set", legacy_flag_set = "flag_set",
legacy_tool = "tool", legacy_tool = "tool",
legacy_with_feature_set = "with_feature_set", legacy_with_feature_set = "with_feature_set",
@ -50,10 +49,14 @@ def convert_feature_constraint(constraint):
not_features = sorted([ft.name for ft in constraint.none_of.to_list()]), not_features = sorted([ft.name for ft in constraint.none_of.to_list()]),
) )
def _convert_add_arg(add_arg): def convert_args(args):
return [legacy_flag_group(flags = list(add_arg.args))] """Converts an ArgsInfo to flag_sets and env_sets.
def _convert_args(args): Args:
args: (ArgsInfo) The args to convert
Returns:
struct(flag_sets = List[flag_set], env_sets = List[env_sets])
"""
actions = _convert_actions(args.actions) actions = _convert_actions(args.actions)
with_features = [ with_features = [
convert_feature_constraint(fc) convert_feature_constraint(fc)
@ -61,14 +64,11 @@ def _convert_args(args):
] ]
flag_sets = [] flag_sets = []
if args.args: if args.expand != None:
flag_groups = []
for add_args in args.args:
flag_groups.extend(_convert_add_arg(add_args))
flag_sets.append(legacy_flag_set( flag_sets.append(legacy_flag_set(
actions = actions, actions = actions,
with_features = with_features, with_features = with_features,
flag_groups = flag_groups, flag_groups = [args.expand.legacy_flag_group],
)) ))
env_sets = [] env_sets = []
@ -93,7 +93,7 @@ def _convert_args_sequence(args_sequence):
flag_sets = [] flag_sets = []
env_sets = [] env_sets = []
for args in args_sequence: for args in args_sequence:
legacy_args = _convert_args(args) legacy_args = convert_args(args)
flag_sets.extend(legacy_args.flag_sets) flag_sets.extend(legacy_args.flag_sets)
env_sets.extend(legacy_args.env_sets) env_sets.extend(legacy_args.env_sets)

View File

@ -18,6 +18,17 @@ util.helper_target(
env = {"BAR": "bar"}, env = {"BAR": "bar"},
) )
util.helper_target(
cc_args,
name = "env_only",
actions = ["//tests/rule_based_toolchain/actions:all_compile"],
data = [
"//tests/rule_based_toolchain/testdata:file1",
"//tests/rule_based_toolchain/testdata:multiple",
],
env = {"BAR": "bar"},
)
analysis_test_suite( analysis_test_suite(
name = "test_suite", name = "test_suite",
targets = TARGETS, targets = TARGETS,

View File

@ -13,12 +13,24 @@
# limitations under the License. # limitations under the License.
"""Tests for the cc_args rule.""" """Tests for the cc_args rule."""
load(
"//cc:cc_toolchain_config_lib.bzl",
"env_entry",
"env_set",
"flag_group",
"flag_set",
)
load( load(
"//cc/toolchains:cc_toolchain_info.bzl", "//cc/toolchains:cc_toolchain_info.bzl",
"ActionTypeInfo", "ActionTypeInfo",
"ArgsInfo", "ArgsInfo",
"ArgsListInfo", "ArgsListInfo",
) )
load(
"//cc/toolchains/impl:legacy_converter.bzl",
"convert_args",
)
load("//tests/rule_based_toolchain:subjects.bzl", "subjects")
visibility("private") visibility("private")
@ -28,13 +40,17 @@ _SIMPLE_FILES = [
"tests/rule_based_toolchain/testdata/multiple2", "tests/rule_based_toolchain/testdata/multiple2",
] ]
def _test_simple_args_impl(env, targets): _CONVERTED_ARGS = subjects.struct(
flag_sets = subjects.collection,
env_sets = subjects.collection,
)
def _simple_test(env, targets):
simple = env.expect.that_target(targets.simple).provider(ArgsInfo) simple = env.expect.that_target(targets.simple).provider(ArgsInfo)
simple.actions().contains_exactly([ simple.actions().contains_exactly([
targets.c_compile.label, targets.c_compile.label,
targets.cpp_compile.label, targets.cpp_compile.label,
]) ])
simple.args().contains_exactly([targets.simple.label])
simple.env().contains_exactly({"BAR": "bar"}) simple.env().contains_exactly({"BAR": "bar"})
simple.files().contains_exactly(_SIMPLE_FILES) simple.files().contains_exactly(_SIMPLE_FILES)
@ -44,12 +60,54 @@ def _test_simple_args_impl(env, targets):
c_compile.args().contains_exactly([targets.simple[ArgsInfo]]) c_compile.args().contains_exactly([targets.simple[ArgsInfo]])
c_compile.files().contains_exactly(_SIMPLE_FILES) c_compile.files().contains_exactly(_SIMPLE_FILES)
converted = env.expect.that_value(
convert_args(targets.simple[ArgsInfo]),
factory = _CONVERTED_ARGS,
)
converted.env_sets().contains_exactly([env_set(
actions = ["c_compile", "cpp_compile"],
env_entries = [env_entry(key = "BAR", value = "bar")],
)])
converted.flag_sets().contains_exactly([flag_set(
actions = ["c_compile", "cpp_compile"],
flag_groups = [flag_group(flags = ["--foo", "foo"])],
)])
def _env_only_test(env, targets):
env_only = env.expect.that_target(targets.env_only).provider(ArgsInfo)
env_only.actions().contains_exactly([
targets.c_compile.label,
targets.cpp_compile.label,
])
env_only.env().contains_exactly({"BAR": "bar"})
env_only.files().contains_exactly(_SIMPLE_FILES)
c_compile = env.expect.that_target(targets.simple).provider(ArgsListInfo).by_action().get(
targets.c_compile[ActionTypeInfo],
)
c_compile.files().contains_exactly(_SIMPLE_FILES)
converted = env.expect.that_value(
convert_args(targets.env_only[ArgsInfo]),
factory = _CONVERTED_ARGS,
)
converted.env_sets().contains_exactly([env_set(
actions = ["c_compile", "cpp_compile"],
env_entries = [env_entry(key = "BAR", value = "bar")],
)])
converted.flag_sets().contains_exactly([])
TARGETS = [ TARGETS = [
":simple", ":simple",
":env_only",
"//tests/rule_based_toolchain/actions:c_compile", "//tests/rule_based_toolchain/actions:c_compile",
"//tests/rule_based_toolchain/actions:cpp_compile", "//tests/rule_based_toolchain/actions:cpp_compile",
] ]
# @unsorted-dict-items
TESTS = { TESTS = {
"simple_test": _test_simple_args_impl, "simple_test": _simple_test,
"env_only_test_test": _env_only_test,
} }

View File

@ -21,9 +21,9 @@ load(
"ActionTypeConfigSetInfo", "ActionTypeConfigSetInfo",
"ActionTypeInfo", "ActionTypeInfo",
"ActionTypeSetInfo", "ActionTypeSetInfo",
"AddArgsInfo",
"ArgsInfo", "ArgsInfo",
"ArgsListInfo", "ArgsListInfo",
"ExpandArgsInfo",
"FeatureConstraintInfo", "FeatureConstraintInfo",
"FeatureInfo", "FeatureInfo",
"FeatureSetInfo", "FeatureSetInfo",
@ -40,6 +40,10 @@ visibility("//tests/rule_based_toolchain/...")
# This makes it rather awkward for copybara. # This makes it rather awkward for copybara.
runfiles_subject = lambda value, meta: _subjects.depset_file(value.files, meta = meta) runfiles_subject = lambda value, meta: _subjects.depset_file(value.files, meta = meta)
# The string type has .equals(), which is all we can really do for an unknown
# type.
unknown_subject = _subjects.str
# buildifier: disable=name-conventions # buildifier: disable=name-conventions
_ActionTypeFactory = generate_factory( _ActionTypeFactory = generate_factory(
ActionTypeInfo, ActionTypeInfo,
@ -102,13 +106,27 @@ _FeatureConstraintFactory = generate_factory(
), ),
) )
_EXPAND_ARGS_FLAGS = dict(
expand = None,
files = _subjects.depset_file,
iterate_over = optional_subject(_subjects.str),
legacy_flag_group = unknown_subject,
requires_types = _subjects.dict,
)
# buildifier: disable=name-conventions # buildifier: disable=name-conventions
_AddArgsFactory = generate_factory( _FakeExpandArgsFactory = generate_factory(
AddArgsInfo, ExpandArgsInfo,
"AddArgsInfo", "ExpandArgsInfo",
dict( _EXPAND_ARGS_FLAGS,
args = _subjects.collection, )
files = _subjects.depset_file,
# buildifier: disable=name-conventions
_ExpandArgsFactory = generate_factory(
ExpandArgsInfo,
"ExpandArgsInfo",
_EXPAND_ARGS_FLAGS | dict(
expand = ProviderSequence(_FakeExpandArgsFactory),
), ),
) )
@ -118,9 +136,10 @@ _ArgsFactory = generate_factory(
"ArgsInfo", "ArgsInfo",
dict( dict(
actions = ProviderDepset(_ActionTypeFactory), actions = ProviderDepset(_ActionTypeFactory),
args = ProviderSequence(_AddArgsFactory),
env = _subjects.dict, env = _subjects.dict,
files = _subjects.depset_file, files = _subjects.depset_file,
# Use .factory so it's not inlined.
expand = optional_subject(_ExpandArgsFactory.factory),
requires_any_of = ProviderSequence(_FeatureConstraintFactory), requires_any_of = ProviderSequence(_FeatureConstraintFactory),
), ),
) )
@ -201,7 +220,7 @@ _ToolchainConfigFactory = generate_factory(
FACTORIES = [ FACTORIES = [
_ActionTypeFactory, _ActionTypeFactory,
_ActionTypeSetFactory, _ActionTypeSetFactory,
_AddArgsFactory, _ExpandArgsFactory,
_ArgsFactory, _ArgsFactory,
_ArgsListFactory, _ArgsListFactory,
_MutuallyExclusiveCategoryFactory, _MutuallyExclusiveCategoryFactory,
@ -217,6 +236,7 @@ result_fn_wrapper = _result_fn_wrapper
subjects = struct( subjects = struct(
**(structs.to_dict(_subjects) | dict( **(structs.to_dict(_subjects) | dict(
unknown = unknown_subject,
result = result_subject, result = result_subject,
optional = optional_subject, optional = optional_subject,
struct = struct_subject, struct = struct_subject,