aboutsummaryrefslogtreecommitdiff
path: root/lib/new_sets.bzl
blob: cd90a30e3126a6a0c5701770b931563c927a982c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
# Copyright 2018 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Skylib module containing common hash-set algorithms.

  An empty set can be created using `sets.make()`, or a set can be created with some starting
  values by passing in a sequence: `sets.make([1, 2, 3])`. This returns a struct containing all of
  the values as keys in a dictionary - this means that all passed in values must be hashable. The
  values in the set can be retrieved using `sets.to_list(my_set)`.

  To test whether an arbitrary object is a set generated by `sets.make()`, use the
  `types.is_set()` method in types.bzl.
"""

load(":dicts.bzl", "dicts")

def _make(elements = None):
    """Builds a set, optionally seeded with initial members.

    Members are stored as dictionary keys, so every member must be hashable.

    Args:
      elements: Optional sequence of initial members. `None` or an empty sequence
        produces an empty set; repeated values are stored once.

    Returns:
      A new set holding each distinct value from `elements`.
    """

    # The _is_set method in types.bzl depends on this struct layout; if the structure of a
    # set changes here, update that method as well.
    values = {}
    if elements:
        for element in elements:
            values[element] = None
    return struct(_values = values)

def _copy(s):
    """Duplicates a set.

    The copy gets its own backing dictionary, so later inserts into or removals from
    the copy do not affect `s` (and vice versa).

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      A new set with the same elements as `s`.
    """
    duplicate = dict(s._values)
    return struct(_values = duplicate)

def _to_list(s):
    """Returns the members of a set as a list.

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      A list containing every value that is in the set.
    """
    members = s._values
    return members.keys()

def _insert(s, e):
    """Adds an element to a set, in place.

    The element must be hashable. The set `s` itself is modified; no copy is made.

    Args:
      s: A set, as returned by `sets.make()`.
      e: The element to add.

    Returns:
      The same set `s`, now containing `e`.
    """
    members = s._values
    members[e] = None
    return s

def _remove(s, e):
    """Deletes an element from a set, in place.

    The element must be hashable. The set `s` itself is modified; no copy is made.
    The element is popped from the backing dictionary without a default value, so
    removing an element that is not in the set is an error.

    Args:
      s: A set, as returned by `sets.make()`.
      e: The element to delete.

    Returns:
      The same set `s`, no longer containing `e`.
    """
    members = s._values
    members.pop(e)
    return s

def _contains(a, e):
    """Reports whether an element is a member of a set.

    Args:
      a: A set, as returned by `sets.make()`.
      e: The element to search for.

    Returns:
      True if `e` is in the set, False otherwise.
    """
    members = a._values
    return e in members

def _get_shorter_and_longer(a, b):
    """Orders two sets by size, smallest first.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      The pair `a`, `b` if `a` has strictly fewer elements than `b`; otherwise the
      pair `b`, `a` (this includes the case where both have the same length).
    """
    a_is_shorter = _length(a) < _length(b)
    return (a, b) if a_is_shorter else (b, a)

def _is_equal(a, b):
    """Reports whether two sets are equal.

    Equality is decided by comparing the backing dictionaries of the two sets.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      True if `a` and `b` are equal, False otherwise.
    """
    left, right = a._values, b._values
    return left == right

def _is_subset(a, b):
    """Reports whether every element of `a` is also in `b`.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      True if `a` is a subset of `b` (an empty `a` always is), False otherwise.
    """
    superset_values = b._values
    for candidate in a._values:
        if candidate in superset_values:
            continue

        # Found an element of `a` that `b` lacks; no need to look further.
        return False
    return True

def _disjoint(a, b):
    """Reports whether two sets share no elements.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      True if no element is in both `a` and `b`, False otherwise.
    """

    # Walk the smaller set and probe the larger one, so fewer elements are visited.
    smaller, larger = _get_shorter_and_longer(a, b)
    larger_values = larger._values
    for candidate in smaller._values:
        if candidate in larger_values:
            return False
    return True

def _intersection(a, b):
    """Computes the elements common to two sets.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      A new set containing exactly the elements that are in both `a` and `b`.
    """

    # Walk the smaller set and probe the larger one, so fewer elements are visited.
    smaller, larger = _get_shorter_and_longer(a, b)
    larger_values = larger._values
    common = {}
    for candidate in smaller._values:
        if candidate in larger_values:
            common[candidate] = None
    return struct(_values = common)

def _union(*args):
    """Computes the union of any number of sets.

    The backing dictionaries of all the given sets are merged with `dicts.add`.

    Args:
      *args: An arbitrary number of sets.

    Returns:
      A new set containing every element that is in at least one set of `*args`.
    """
    backing_dicts = [member_set._values for member_set in args]
    merged = dicts.add(*backing_dicts)
    return struct(_values = merged)

def _difference(a, b):
    """Computes the elements of `a` that do not appear in `b`.

    Args:
      a: A set, as returned by `sets.make()`.
      b: A set, as returned by `sets.make()`.

    Returns:
      A new set containing the elements that are in `a` but not in `b`.
    """
    excluded = b._values
    remaining = {}
    for candidate in a._values:
        if candidate not in excluded:
            remaining[candidate] = None
    return struct(_values = remaining)

def _length(s):
    """Counts the elements of a set.

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      The number of elements in the set, as an integer.
    """
    members = s._values
    return len(members)

def _repr(s):
    """Formats a set as a string.

    The result is the `repr` of the list of the set's elements.

    Args:
      s: A set, as returned by `sets.make()`.

    Returns:
      A string representation of the set.
    """
    members = s._values.keys()
    return repr(members)

# Public entry point of this module. Each field binds a public name to one of the
# private implementation functions defined in this file, so callers write e.g.
# `sets.make([1, 2])` or `sets.insert(s, e)`.
sets = struct(
    make = _make,
    copy = _copy,
    to_list = _to_list,
    insert = _insert,
    contains = _contains,
    is_equal = _is_equal,
    is_subset = _is_subset,
    disjoint = _disjoint,
    intersection = _intersection,
    union = _union,
    difference = _difference,
    length = _length,
    remove = _remove,
    repr = _repr,
    # `str` is bound to the same implementation as `repr`.
    str = _repr,
    # is_set is declared in types.bzl
)