diff --git a/tests/rule_based_toolchain/generics.bzl b/tests/rule_based_toolchain/generics.bzl new file mode 100644 index 0000000..2244b0b --- /dev/null +++ b/tests/rule_based_toolchain/generics.bzl @@ -0,0 +1,122 @@ +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of a result type for use with rules_testing.""" + +load("@rules_testing//lib:truth.bzl", "subjects") + +visibility("//tests/rule_based_toolchain/...") + +def result_fn_wrapper(fn): + """Wraps a function that may fail in a type similar to rust's Result type. + + An example usage is the following: + # Implementation file + def get_only(value, fail=fail): + if len(value) == 1: + return value[0] + elif not value: + fail("Unexpectedly empty") + else: + fail("%r had length %d, expected 1" % (value, len(value)) + + # Test file + load("...", _fn=fn) + + fn = result_fn_wrapper(_fn) + int_result = result_subject(subjects.int) + + def my_test(env, _): + env.expect.that_value(fn([]), factory=int_result) + .err().equals("Unexpectedly empty") + env.expect.that_value(fn([1]), factory=int_result) + .ok().equals(1) + env.expect.that_value(fn([1, 2]), factory=int_result) + .err().contains("had length 2, expected 1") + + Args: + fn: A function that takes in a parameter fail and calls it on failure. + + Returns: + On success: struct(ok = , err = None) + On failure: struct(ok = None, err = + """ + + def new_fn(*args, **kwargs): + # Use a mutable type so that the fail_wrapper can modify this. + failures = [] + + def fail_wrapper(msg): + failures.append(msg) + + result = fn(fail = fail_wrapper, *args, **kwargs) + if failures: + return struct(ok = None, err = failures[0]) + else: + return struct(ok = result, err = None) + + return new_fn + +def result_subject(factory): + """A subject factory for Result. + + Args: + factory: A subject factory for T + Returns: + A subject factory for Result + """ + + def new_factory(value, *, meta): + def ok(): + if value.err != None: + meta.add_failure("Wanted a value, but got an error", value.err) + return factory(value.ok, meta = meta.derive("ok()")) + + def err(): + if value.err == None: + meta.add_failure("Wanted an error, but got a value", value.ok) + return subjects.str(value.err, meta = meta.derive("err()")) + + return struct(ok = ok, err = err) + + return new_factory + +def optional_subject(factory): + """A subject factory for Optional. + + Args: + factory: A subject factory for T + Returns: + A subject factory for Optional + """ + + def new_factory(value, *, meta): + def some(): + if value == None: + meta.add_failure("Wanted a value, but got None", None) + return factory(value, meta = meta) + + def is_none(): + if value != None: + meta.add_failure("Wanted None, but got a value", value) + + return struct(some = some, is_none = is_none) + + return new_factory + +# Curry subjects.struct so the type is actually generic. +struct_subject = lambda **attrs: lambda value, *, meta: subjects.struct( + value, + meta = meta, + attrs = attrs, +) diff --git a/tests/rule_based_toolchain/subjects.bzl b/tests/rule_based_toolchain/subjects.bzl index 53daeb5..87f6cb4 100644 --- a/tests/rule_based_toolchain/subjects.bzl +++ b/tests/rule_based_toolchain/subjects.bzl @@ -13,7 +13,8 @@ # limitations under the License. """Test subjects for cc_toolchain_info providers.""" -load("@rules_testing//lib:truth.bzl", "subjects") +load("@bazel_skylib//lib:structs.bzl", "structs") +load("@rules_testing//lib:truth.bzl", _subjects = "subjects") load( "//cc/toolchains:cc_toolchain_info.bzl", "ActionConfigInfo", @@ -29,15 +30,16 @@ load( "ToolInfo", ) load(":generate_factory.bzl", "ProviderDepset", "ProviderSequence", "generate_factory") +load(":generics.bzl", "optional_subject", "result_subject", "struct_subject", _result_fn_wrapper = "result_fn_wrapper") -visibility("private") +visibility("//tests/rule_based_toolchain/...") # buildifier: disable=name-conventions _ActionTypeFactory = generate_factory( ActionTypeInfo, "ActionTypeInfo", dict( - name = subjects.str, + name = _subjects.str, ), ) @@ -54,17 +56,17 @@ _ActionTypeSetFactory = generate_factory( _MutuallyExclusiveCategoryFactory = generate_factory( MutuallyExclusiveCategoryInfo, "MutuallyExclusiveCategoryInfo", - dict(name = subjects.str), + dict(name = _subjects.str), ) _FEATURE_FLAGS = dict( - name = subjects.str, - enabled = subjects.bool, + name = _subjects.str, + enabled = _subjects.bool, flag_sets = None, implies = None, requires_any_of = None, provides = ProviderSequence(_MutuallyExclusiveCategoryFactory), - known = subjects.bool, + known = _subjects.bool, overrides = None, ) @@ -98,8 +100,8 @@ _AddArgsFactory = generate_factory( AddArgsInfo, "AddArgsInfo", dict( - args = subjects.collection, - files = subjects.depset_file, + args = _subjects.collection, + files = _subjects.depset_file, ), ) @@ -110,8 +112,8 @@ _ArgsFactory = generate_factory( dict( actions = ProviderDepset(_ActionTypeFactory), args = ProviderSequence(_AddArgsFactory), - env = subjects.dict, - files = subjects.depset_file, + env = _subjects.dict, + files = _subjects.depset_file, requires_any_of = ProviderSequence(_FeatureConstraintFactory), ), ) @@ -132,9 +134,10 @@ _ToolFactory = generate_factory( ToolInfo, "ToolInfo", dict( - exe = subjects.file, - runifles = subjects.depset_file, + exe = _subjects.file, + runfiles = _subjects.depset_file, requires_any_of = ProviderSequence(_FeatureConstraintFactory), + execution_requirements = _subjects.collection, ), ) @@ -144,11 +147,11 @@ _ActionConfigFactory = generate_factory( "ActionConfigInfo", dict( action = _ActionTypeFactory, - enabled = subjects.bool, + enabled = _subjects.bool, tools = ProviderSequence(_ToolFactory), flag_sets = ProviderSequence(_ArgsFactory), implies = ProviderDepset(_FeatureFactory), - files = subjects.depset_file, + files = _subjects.depset_file, ), ) @@ -185,3 +188,13 @@ FACTORIES = [ _ToolFactory, _ActionConfigSetFactory, ] + +result_fn_wrapper = _result_fn_wrapper + +subjects = struct( + **(structs.to_dict(_subjects) | dict( + result = result_subject, + optional = optional_subject, + struct = struct_subject, + ) | {factory.name: factory.factory for factory in FACTORIES}) +)