# bazel-skylib/lib/new_sets.bzl
#
# 237 lines
# 5.9 KiB
# Python

# Copyright 2018 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Skylib module containing common hash-set algorithms.
An empty set can be created using: `sets.make()`, or it can be created with some starting values
if you pass it a sequence: `sets.make([1, 2, 3])`. This returns a struct containing all of the
values as keys in a dictionary - this means that all passed in values must be hashable. The
values in the set can be retrieved using `sets.to_list(my_set)`.
"""
load(":dicts.bzl", "dicts")
def _make(elements = None):
    """Builds a set, optionally seeded from a sequence.

    Every element must be hashable because elements are stored as dictionary
    keys. Repeated elements collapse into a single entry.

    Args:
      elements: Optional sequence of initial elements. `None` or an empty
        sequence produces an empty set.

    Returns:
      A struct whose `_values` dictionary has each element as a key.
    """
    values = {}
    for item in (elements or []):
        # Only the keys carry meaning; the stored value is always None.
        values[item] = None
    return struct(_values = values)
def _copy(s):
    """Duplicates a set.

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      A new set holding the same elements as `s`. It is backed by its own
      dictionary, so inserting into or removing from the copy does not
      affect `s`.
    """
    duplicate = {}
    duplicate.update(s._values)
    return struct(_values = duplicate)
def _to_list(s):
    """Lists the elements of a set.

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      A list of the elements stored in the set.
    """

    # Iterating a dictionary yields its keys, which are the set's elements.
    return list(s._values)
def _insert(s, e):
    """Adds an element to a set in place.

    The element must be hashable. This mutates the original set; inserting
    an element that is already present leaves the set unchanged.

    Args:
      s: A set, as returned by `sets.make()`.
      e: The element to be inserted.

    Returns:
      The same set `s`, now including `e`.
    """
    s._values.update({e: None})
    return s
def _remove(s, e):
    """Deletes an element from a set in place.

    The element must be hashable. This mutates the original set.

    Args:
      s: A set, as returned by `sets.make()`.
      e: The element to be removed.

    Returns:
      The same set `s`, with `e` removed.
    """
    members = s._values

    # `pop` is called without a default, so `e` is expected to be present.
    members.pop(e)
    return s
def _contains(a, e):
    """Tests whether an element is a member of a set.

    Args:
      a: A set, as returned by `sets.make()`.
      e: The element to look for.

    Returns:
      True if `e` is in the set, False otherwise.
    """
    members = a._values
    return e in members
def _get_shorter_and_longer(a, b):
    """Orders two sets by size, smaller first.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      The pair `a`, `b` if `a` has strictly fewer elements than `b`;
      otherwise the pair `b`, `a` (this includes the case of equal sizes).
    """
    return (a, b) if _length(a) < _length(b) else (b, a)
def _is_equal(a, b):
    """Tests two sets for equality.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      True if `a` and `b` contain the same elements, False otherwise.
    """
    left = a._values
    right = b._values

    # The backing dictionaries compare equal exactly when their keys match.
    return left == right
def _is_subset(a, b):
    """Tests whether every element of `a` is also in `b`.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      True if `a` is a subset of `b`, False otherwise. An empty `a` is a
      subset of any `b`.
    """
    candidates = b._values
    for item in a._values:
        if item in candidates:
            continue

        # One element missing from `b` is enough to decide the answer.
        return False
    return True
def _disjoint(a, b):
    """Tests whether two sets share no elements.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      True if `a` and `b` have no element in common, False otherwise.
    """

    # Walk the smaller set and probe the larger one.
    small, large = _get_shorter_and_longer(a, b)
    probe = large._values
    for item in small._values:
        if item in probe:
            return False
    return True
def _intersection(a, b):
    """Computes the elements common to two sets.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      A new set containing the elements that are in both `a` and `b`.
    """

    # Walk the smaller set and probe the larger one.
    small, large = _get_shorter_and_longer(a, b)
    probe = large._values
    common = {}
    for item in small._values:
        if item in probe:
            common[item] = None
    return struct(_values = common)
def _union(*args):
    """Returns the union of several sets.

    Args:
      *args: An arbitrary number of sets, as returned by `sets.make()`.
        Plain lists are not accepted, because every argument must expose
        the `_values` dictionary of a set; wrap a list with `sets.make()`
        before passing it in.

    Returns:
      The set union of all sets in `*args`.
    """

    # Elements are dictionary keys, so merging the dictionaries gives the union.
    return struct(_values = dicts.add(*[s._values for s in args]))
def _difference(a, b):
    """Computes the elements of `a` that are absent from `b`.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      A new set containing the elements that are in `a` but not in `b`.
    """
    excluded = b._values
    remaining = {}
    for item in a._values:
        if item not in excluded:
            remaining[item] = None
    return struct(_values = remaining)
def _length(s):
    """Counts the elements of a set.

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      An integer giving the number of elements in the set.
    """
    count = len(s._values)
    return count
def _repr(s):
    """Formats a set as a string.

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      The `repr` of the list of elements in the set.
    """
    elements = s._values.keys()
    return repr(elements)
# Public entry point: callers load `sets` and use these fields as functions.
sets = struct(
    # Construction and conversion.
    make = _make,
    copy = _copy,
    to_list = _to_list,
    # In-place mutation.
    insert = _insert,
    remove = _remove,
    # Queries and comparisons.
    contains = _contains,
    length = _length,
    is_equal = _is_equal,
    is_subset = _is_subset,
    disjoint = _disjoint,
    # Set algebra.
    intersection = _intersection,
    union = _union,
    difference = _difference,
    # String conversion; `repr` and `str` share one implementation.
    repr = _repr,
    str = _repr,
)