diff --git a/lib/sets.bzl b/lib/sets.bzl index ff3c9cf..f62e03d 100644 --- a/lib/sets.bzl +++ b/lib/sets.bzl @@ -43,6 +43,20 @@ def _precondition_only_sets_or_lists(*args): fail("Expected arguments to be depset or list, but found type %s: %r" % (t, a)) +def _depset_to_list(val): + """Converts a depset to a list. + + If the given value is a depset, will return the list representation of + the depset. Otherwise, will return the value itself. + + Args: + val: The value to be optionally converted and returned. + """ + if type(val) == _depset_type: + return val.to_list() + else: + return val + def _is_equal(a, b): """Returns whether two sets are equal. @@ -54,7 +68,11 @@ def _is_equal(a, b): True if `a` is equal to `b`, False otherwise. """ _precondition_only_sets_or_lists(a, b) - return sorted(depset(a)) == sorted(depset(b)) + + # Convert both values to a depset then back to a list to remove duplicates. + a = _depset_to_list(depset(a)) + b = _depset_to_list(depset(b)) + return sorted(a) == sorted(b) def _is_subset(a, b): """Returns whether `a` is a subset of `b`. @@ -67,8 +85,8 @@ def _is_subset(a, b): True if `a` is a subset of `b`, False otherwise. """ _precondition_only_sets_or_lists(a, b) - for e in a: - if e not in b: + for e in _depset_to_list(a): + if e not in _depset_to_list(b): return False return True @@ -85,8 +103,8 @@ def _disjoint(a, b): True if `a` and `b` are disjoint, False otherwise. """ _precondition_only_sets_or_lists(a, b) - for e in a: - if e in b: + for e in _depset_to_list(a): + if e in _depset_to_list(b): return False return True @@ -101,7 +119,7 @@ def _intersection(a, b): A set containing the elements that are in both `a` and `b`. """ _precondition_only_sets_or_lists(a, b) - return depset([e for e in a if e in b]) + return depset([e for e in _depset_to_list(a) if e in _depset_to_list(b)]) def _union(*args): """Returns the union of several sets. @@ -127,7 +145,7 @@ def _difference(a, b): A set containing the elements that are in `a` but not in `b`. """ _precondition_only_sets_or_lists(a, b) - return depset([e for e in a if e not in b]) + return depset([e for e in _depset_to_list(a) if e not in _depset_to_list(b)]) sets = struct( difference = _difference,