aboutsummaryrefslogtreecommitdiff
path: root/java/com/google/setfilters/cuckoofilter/SemiSortedCuckooFilterTable.java
blob: f08b47c4c72c66bca0c0f57149d57b1d54e1f542 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.setfilters.cuckoofilter;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Comparator.comparingInt;

import com.google.common.collect.ImmutableMap;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Optional;
import java.util.Random;

/**
 * Implementation of the {@link CuckooFilterTable} using the semi-sorting bucket compression scheme
 * in the original paper by Fan et al (https://www.cs.cmu.edu/~dga/papers/cuckoo-conext2014.pdf) -
 * see section 5.2.
 *
 * <p>The main idea behind the compression algorithm is that the order of the fingerprints in each
 * bucket is irrelevant - that is, the fingerprints in each bucket form a multiset. For fingerprint
 * length f and bucket capacity b, the number of possible multisets of b fingerprints of f bits each
 * is C(2^f + b - 1, b), where C denotes the binomial coefficient. In particular, we can encode each
 * bucket with ceil(log2(C(2^f + b - 1, b))) bits. On the other hand, naively encoding the
 * fingerprints takes b * f bits. Thus, it is theoretically possible to save
 * b * f - ceil(log2(C(2^f + b - 1, b))) bits per bucket (note that this is not information
 * theoretically tight because the distribution of the multisets is not uniform).
 *
 * <p>For performance reasons, this only supports a table with bucket capacity of 4 and fingerprint
 * length >= 4 - in many cases this is not a limitation because, for many practical applications,
 * bucket capacity of 4 yields the optimal cuckoo filter size and fingerprint length < 4 will never
 * achieve a good enough false positive rate.
 *
 * <p>Compared to the {@link UncompressedCuckooFilterTable}, this implementation saves 1 bit per
 * element, at the cost of slower filter operations by a constant factor (asymptotically, it is the
 * same as the uncompressed one). Note that for bucket capacity of 4, saving 1 bit per element is
 * "optimal" up to rounding down, as the function 4 * f - ceil(log2(C(2^f + 3, 4))) < 5 for
 * reasonable values of f. However, this also incurs an additional fixed space overhead, so for
 * smaller filters the extra saving of 1 bit per element may not be worth it.
 */
final class SemiSortedCuckooFilterTable implements CuckooFilterTable {
  // Implementation type of the table, to be encoded in the serialization.
  public static final int TABLE_TYPE = 1;

  // Every sorted (non-decreasing) 4-tuple of 4 bit partial fingerprints, packed into 16 bits with
  // the smallest value in the most significant nibble, listed in lexicographic order. The position
  // of a tuple in this array is its 12 bit encoding (there are C(16 + 3, 4) = 3876 < 2^12 tuples).
  private static final short[] SORTED_PARTIAL_FINGERPRINTS = computeSortedPartialFingerprints();
  // Inverse map of SORTED_PARTIAL_FINGERPRINTS: packed sorted 4-tuple -> its index.
  private static final ImmutableMap<Short, Short> SORTED_PARTIAL_FINGERPRINTS_INDEX =
      computeSortedPartialFingerprintsIndex(SORTED_PARTIAL_FINGERPRINTS);

  private final CuckooFilterConfig.Size size;
  private final Random random;
  private final CuckooFilterArray cuckooFilterArray;

  /**
   * Creates a new semi-sorted cuckoo filter table of the given size.
   *
   * <p>Uses the given source of {@code random} to choose the replaced fingerprint in {@code
   * insertWithReplacement} method.
   *
   * @throws IllegalArgumentException if the bucket capacity is not 4 or the fingerprint length is
   *     less than 4
   */
  public SemiSortedCuckooFilterTable(CuckooFilterConfig.Size size, Random random) {
    checkSupportedSize(size);
    this.size = size;
    this.random = random;
    // One array element per slot, each f - 1 bits wide: the f - 4 bit fingerprint prefix plus one
    // 3 bit chunk of the bucket's 12 bit partial-fingerprint encoding. The element count is computed
    // in long arithmetic so that bucketCount * bucketCapacity cannot overflow an int.
    cuckooFilterArray =
        new CuckooFilterArray(
            (long) size.bucketCount() * size.bucketCapacity(), size.fingerprintLength() - 1);
  }

  /**
   * Creates {@link SemiSortedCuckooFilterTable} from {@link SerializedCuckooFilterTable}.
   *
   * <p>Applies the same size constraints as {@link #SemiSortedCuckooFilterTable(
   * CuckooFilterConfig.Size, Random)}, so that a serialized table with an unsupported size is
   * rejected up front instead of being misdecoded later.
   *
   * @throws IllegalArgumentException if the bucket capacity is not 4 or the fingerprint length is
   *     less than 4
   */
  public SemiSortedCuckooFilterTable(CuckooFilterConfig.Size size, byte[] bitArray, Random random) {
    checkSupportedSize(size);
    this.size = size;
    this.random = random;
    cuckooFilterArray =
        new CuckooFilterArray(
            (long) size.bucketCount() * size.bucketCapacity(),
            size.fingerprintLength() - 1,
            bitArray);
  }

  // Rejects sizes the encoding cannot represent: the 12 bit encoding covers exactly 4 sorted
  // partial fingerprints per bucket, and the partial fingerprint is the low 4 bits of each
  // fingerprint, so fingerprints need at least 4 bits.
  private static void checkSupportedSize(CuckooFilterConfig.Size size) {
    checkArgument(
        size.bucketCapacity() == 4,
        "SemiSortedCuckooFilterTable only supports bucket capacity of 4.");
    checkArgument(
        size.fingerprintLength() >= 4,
        "SemiSortedCuckooFilterTable only supports fingerprint length >= 4.");
  }

  /**
   * Inserts {@code fingerprint} into the first empty slot of the bucket. If the bucket is full, a
   * uniformly random slot is overwritten and its previous fingerprint is returned.
   */
  @Override
  public Optional<Long> insertWithReplacement(int bucketIndex, long fingerprint) {
    long[] fingerprints = decodeBucket(bucketIndex);
    for (int i = 0; i < size.bucketCapacity(); i++) {
      if (fingerprints[i] == EMPTY_SLOT) {
        fingerprints[i] = fingerprint;
        encodeAndPut(bucketIndex, fingerprints);
        return Optional.empty();
      }
    }

    int replacedSlotIndex = random.nextInt(size.bucketCapacity());
    long replacedFingerprint = fingerprints[replacedSlotIndex];
    fingerprints[replacedSlotIndex] = fingerprint;
    encodeAndPut(bucketIndex, fingerprints);
    return Optional.of(replacedFingerprint);
  }

  /** Returns whether any slot of the bucket holds {@code fingerprint}. */
  @Override
  public boolean contains(int bucketIndex, long fingerprint) {
    long[] fingerprints = decodeBucket(bucketIndex);
    for (long fingerprintInBucket : fingerprints) {
      if (fingerprintInBucket == fingerprint) {
        return true;
      }
    }
    return false;
  }

  /**
   * Replaces one occurrence of {@code fingerprint} in the bucket with {@code EMPTY_SLOT}. Returns
   * whether an occurrence was found.
   */
  @Override
  public boolean delete(int bucketIndex, long fingerprint) {
    long[] fingerprints = decodeBucket(bucketIndex);
    for (int i = 0; i < fingerprints.length; i++) {
      if (fingerprints[i] == fingerprint) {
        fingerprints[i] = EMPTY_SLOT;
        encodeAndPut(bucketIndex, fingerprints);
        return true;
      }
    }
    return false;
  }

  /** Returns whether the bucket has no slot equal to {@code EMPTY_SLOT}. */
  @Override
  public boolean isFull(int bucketIndex) {
    return !contains(bucketIndex, CuckooFilterTable.EMPTY_SLOT);
  }

  @Override
  public CuckooFilterConfig.Size size() {
    return size;
  }

  @Override
  public SerializedCuckooFilterTable serialize() {
    byte[] serializedArray = cuckooFilterArray.toByteArray();

    // Layout: a 16 byte header of four big-endian ints (table type, bucketCount, bucketCapacity,
    // fingerprintLength) identifying the implementation and table size, followed by the raw bit
    // array.
    ByteBuffer encoded = ByteBuffer.allocate(16 + serializedArray.length);
    return SerializedCuckooFilterTable.createFromByteArray(
        encoded
            .putInt(TABLE_TYPE)
            .putInt(size.bucketCount())
            .putInt(size.bucketCapacity())
            .putInt(size.fingerprintLength())
            .put(serializedArray)
            .array());
  }

  // Maps (bucket, slot) to the flat element index in cuckooFilterArray.
  private long toArrayIndex(int bucketIndex, int slotIndex) {
    return (long) bucketIndex * size.bucketCapacity() + slotIndex;
  }

  // TODO: Check if encoding/decoding needs to be optimized.

  // Decodes the fingerprints stored at bucketIndex (the inverse of encodeAndPut).
  //
  // Each slot's element is (prefix << 3) | chunk. Concatenating the 3 bit chunks with slot 0 most
  // significant yields the 12 bit index into SORTED_PARTIAL_FINGERPRINTS, whose entry holds the
  // sorted partial fingerprints with slot 0's in the most significant nibble.
  private long[] decodeBucket(int bucketIndex) {
    int encodedSortedPartialFingerprintsIndex = 0;
    long[] fingerprintPrefixes = new long[size.bucketCapacity()];
    for (int i = 0; i < size.bucketCapacity(); i++) {
      long arrayIndex = toArrayIndex(bucketIndex, i);
      long n = cuckooFilterArray.getAsLong(arrayIndex);
      encodedSortedPartialFingerprintsIndex <<= 3;
      encodedSortedPartialFingerprintsIndex |= (int) (n & 0x7);
      fingerprintPrefixes[i] = n >>> 3;
    }

    // Mask to the 16 packed bits so the short -> int widening does not sign-extend (only the low
    // 16 bits are consumed below either way).
    int encodedSortedPartialFingerprints =
        SORTED_PARTIAL_FINGERPRINTS[encodedSortedPartialFingerprintsIndex] & 0xFFFF;
    long[] fingerprints = new long[size.bucketCapacity()];
    for (int i = size.bucketCapacity() - 1; i >= 0; i--) {
      fingerprints[i] = (fingerprintPrefixes[i] << 4) | (encodedSortedPartialFingerprints & 0xF);
      encodedSortedPartialFingerprints >>>= 4;
    }
    return fingerprints;
  }

  /**
   * Encodes fingerprints and puts them into bucketIndex.
   *
   * <p>Encoding works as follows.
   *
   * <p>Suppose each fingerprint is logically f bits. First, sort the fingerprints by their least
   * significant 4 bits. Call the most significant f - 4 bits of each fingerprint its fingerprint
   * prefix, and the least significant 4 bits its partial fingerprint. The sorted partial
   * fingerprints are encoded as a 12 bit value via the SORTED_PARTIAL_FINGERPRINTS_INDEX map.
   * Partition that 12 bit value into four 3 bit chunks. Group each f - 4 bit prefix (in sorted
   * order) with one 3 bit chunk (f - 1 bits total) and store it as a cuckoo filter array element.
   */
  private void encodeAndPut(int bucketIndex, long[] fingerprints) {
    long[] fingerprintPrefixes = new long[size.bucketCapacity()];
    int[] partialFingerprints = new int[size.bucketCapacity()];
    for (int i = 0; i < size.bucketCapacity(); i++) {
      fingerprintPrefixes[i] = fingerprints[i] >>> 4;
      partialFingerprints[i] = (int) (fingerprints[i] & 0xF);
    }
    // Slot order after sorting by partial fingerprint; the prefixes are permuted the same way so
    // each prefix stays paired with its own partial fingerprint after decoding.
    Integer[] indices = {0, 1, 2, 3};
    Arrays.sort(indices, comparingInt((Integer i) -> partialFingerprints[i]));
    short encodedSortedPartialFingerprints =
        (short)
            ((partialFingerprints[indices[0]] << 12)
                | (partialFingerprints[indices[1]] << 8)
                | (partialFingerprints[indices[2]] << 4)
                | partialFingerprints[indices[3]]);
    int encodedSortedPartialFingerprintsIndex =
        SORTED_PARTIAL_FINGERPRINTS_INDEX.get(encodedSortedPartialFingerprints);
    // Emit chunks from the least significant end so slot 0 receives the most significant chunk,
    // matching the order in which decodeBucket reassembles them.
    for (int i = size.bucketCapacity() - 1; i >= 0; i--) {
      long arrayIndex = toArrayIndex(bucketIndex, i);
      cuckooFilterArray.set(
          arrayIndex,
          (fingerprintPrefixes[indices[i]] << 3) | (encodedSortedPartialFingerprintsIndex & 0x7));
      encodedSortedPartialFingerprintsIndex >>>= 3;
    }
  }

  // Enumerates every non-decreasing 4-tuple (a, b, c, d) over [0, 16) in lexicographic order,
  // packing each as a << 12 | b << 8 | c << 4 | d.
  private static short[] computeSortedPartialFingerprints() {
    // (2^4 + 3 choose 4) = 3876 counts the number of multisets of size 4, with each element in
    // [0, 16).
    short[] sortedPartialFingerprints = new short[3876];

    final short fingerprintUpperBound = 16;

    int i = 0;
    for (short a = 0; a < fingerprintUpperBound; a++) {
      for (short b = a; b < fingerprintUpperBound; b++) {
        for (short c = b; c < fingerprintUpperBound; c++) {
          for (short d = c; d < fingerprintUpperBound; d++) {
            sortedPartialFingerprints[i] = (short) ((a << 12) | (b << 8) | (c << 4) | d);
            i++;
          }
        }
      }
    }
    return sortedPartialFingerprints;
  }

  // Builds the inverse of sortedPartialFingerprints (packed tuple -> array index). The index fits
  // in a short because the array has 3876 entries.
  private static ImmutableMap<Short, Short> computeSortedPartialFingerprintsIndex(
      short[] sortedPartialFingerprints) {
    ImmutableMap.Builder<Short, Short> map = ImmutableMap.builder();
    for (short i = 0; i < sortedPartialFingerprints.length; i++) {
      map.put(sortedPartialFingerprints[i], i);
    }
    return map.buildOrThrow();
  }
}