diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 626f023a5b99..ec0bd020a1ed 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -39,6 +39,10 @@ org.apache.spark spark-tags_${scala.binary.version} + + org.roaringbitmap + RoaringBitmap + diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java deleted file mode 100644 index 480a0a79db32..000000000000 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util.sketch; - -import java.io.DataInputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.util.Arrays; - -final class BitArray { - private final long[] data; - private long bitCount; - - static int numWords(long numBits) { - if (numBits <= 0) { - throw new IllegalArgumentException("numBits must be positive, but got " + numBits); - } - long numWords = (long) Math.ceil(numBits / 64.0); - if (numWords > Integer.MAX_VALUE) { - throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); - } - return (int) numWords; - } - - BitArray(long numBits) { - this(new long[numWords(numBits)]); - } - - private BitArray(long[] data) { - this.data = data; - long bitCount = 0; - for (long word : data) { - bitCount += Long.bitCount(word); - } - this.bitCount = bitCount; - } - - /** Returns true if the bit changed value. */ - boolean set(long index) { - if (!get(index)) { - data[(int) (index >>> 6)] |= (1L << index); - bitCount++; - return true; - } - return false; - } - - boolean get(long index) { - return (data[(int) (index >>> 6)] & (1L << index)) != 0; - } - - /** Number of bits */ - long bitSize() { - return (long) data.length * Long.SIZE; - } - - /** Number of set bits (1s) */ - long cardinality() { - return bitCount; - } - - /** Combines the two BitArrays using bitwise OR. 
*/ - void putAll(BitArray array) { - assert data.length == array.data.length : "BitArrays must be of equal length when merging"; - long bitCount = 0; - for (int i = 0; i < data.length; i++) { - data[i] |= array.data[i]; - bitCount += Long.bitCount(data[i]); - } - this.bitCount = bitCount; - } - - void writeTo(DataOutputStream out) throws IOException { - out.writeInt(data.length); - for (long datum : data) { - out.writeLong(datum); - } - } - - static BitArray readFrom(DataInputStream in) throws IOException { - int numWords = in.readInt(); - long[] data = new long[numWords]; - for (int i = 0; i < numWords; i++) { - data[i] = in.readLong(); - } - return new BitArray(data); - } - - @Override - public boolean equals(Object other) { - if (this == other) return true; - if (other == null || !(other instanceof BitArray)) return false; - BitArray that = (BitArray) other; - return Arrays.equals(data, that.data); - } - - @Override - public int hashCode() { - return Arrays.hashCode(data); - } -} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index c0b425e72959..f7e9143064ac 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -51,7 +51,7 @@ public enum Version { *
  • The words/longs (numWords * 64 bit)
  • * */ - V1(1); + V1(2); private final int versionNumber; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 92c28bcb56a5..fde2d14d20cb 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -23,13 +23,13 @@ class BloomFilterImpl extends BloomFilter implements Serializable { private int numHashFunctions; - private BitArray bits; + private RoaringBitmapArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { - this(new BitArray(numBits), numHashFunctions); + this(new RoaringBitmapArray(numBits), numHashFunctions); } - private BloomFilterImpl(BitArray bits, int numHashFunctions) { + private BloomFilterImpl(RoaringBitmapArray bits, int numHashFunctions) { this.bits = bits; this.numHashFunctions = numHashFunctions; } @@ -48,7 +48,7 @@ public boolean equals(Object other) { BloomFilterImpl that = (BloomFilterImpl) other; - return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); + return (this.numHashFunctions == that.numHashFunctions) && this.bits.equals(that.bits); } @Override @@ -84,18 +84,19 @@ public boolean putString(String item) { @Override public boolean putBinary(byte[] item) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + // Strategy is taken from guava`s BloomFilterStrategies.MURMUR128_MITZ_64 + long[] hashes = new long[2]; + Murmur3_128.hashBytes(item, 0, hashes); + long h1 = hashes[0]; + long h2 = hashes[1]; long bitSize = bits.bitSize(); boolean bitsChanged = false; + long combinedHash = h1; for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive 
number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - bitsChanged |= bits.set(combinedHash % bitSize); + // Make combinedHash positive and indexable + bitsChanged |= bits.set((combinedHash & Long.MAX_VALUE) % bitSize); + combinedHash += h2; } return bitsChanged; } @@ -107,61 +108,59 @@ public boolean mightContainString(String item) { @Override public boolean mightContainBinary(byte[] item) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + // Strategy is taken from guava`s BloomFilterStrategies.MURMUR128_MITZ_64 + long[] hashes = new long[2]; + Murmur3_128.hashBytes(item, 0, hashes); + + long h1 = hashes[0]; + long h2 = hashes[1]; long bitSize = bits.bitSize(); + long combinedHash = h1; for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - if (!bits.get(combinedHash % bitSize)) { + // Make combinedHash positive and indexable + if (!bits.get((combinedHash & Long.MAX_VALUE) % bitSize)) { return false; } + combinedHash += h2; } return true; } @Override public boolean putLong(long item) { - // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n - // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. - // Note that `CountMinSketch` use a different strategy, it hash the input long element with - // every i to produce n hash values. - // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? 
- int h1 = Murmur3_x86_32.hashLong(item, 0); - int h2 = Murmur3_x86_32.hashLong(item, h1); + // Strategy is taken from guava`s BloomFilterStrategies.MURMUR128_MITZ_64 + long[] hashes = new long[2]; + Murmur3_128.hashLong(item, 0, hashes); + long h1 = hashes[0]; + long h2 = hashes[1]; long bitSize = bits.bitSize(); boolean bitsChanged = false; + long combinedHash = h1; for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - bitsChanged |= bits.set(combinedHash % bitSize); + // Make combinedHash positive and indexable + bitsChanged |= bits.set((combinedHash & Long.MAX_VALUE) % bitSize); + combinedHash += h2; } return bitsChanged; } @Override public boolean mightContainLong(long item) { - int h1 = Murmur3_x86_32.hashLong(item, 0); - int h2 = Murmur3_x86_32.hashLong(item, h1); + // Strategy is taken from guava`s BloomFilterStrategies.MURMUR128_MITZ_64 + long[] hashes = new long[2]; + Murmur3_128.hashLong(item, 0, hashes); + long h1 = hashes[0]; + long h2 = hashes[1]; long bitSize = bits.bitSize(); + long combinedHash = h1; for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - if (!bits.get(combinedHash % bitSize)) { + if (!bits.get((combinedHash & Long.MAX_VALUE) % bitSize)) { return false; } + combinedHash += h2; } return true; } @@ -238,7 +237,7 @@ private void readFrom0(InputStream in) throws IOException { } this.numHashFunctions = dis.readInt(); - this.bits = BitArray.readFrom(dis); + this.bits = RoaringBitmapArray.readFrom(dis); } public static BloomFilterImpl readFrom(InputStream in) throws IOException { diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_128.java 
b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_128.java
new file mode 100644
index 000000000000..9041e203966d
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_128.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+/**
+ * 128-bit Murmur3 hasher.
+ * Best performance is on x86_64 platform.
+ * Based on the smhasher reference implementation and the SOLR implementation.
+ */
+final class Murmur3_128 {
+
+  private static final long C1 = 0x87c37b91114253d5L;
+  private static final long C2 = 0x4cf5ad432745937fL;
+
+  /** Mask selecting the low 32 bits of a seed, the only part that is used. */
+  private static final long SEED_MASK = 0x00000000FFFFFFFFL;
+
+  private Murmur3_128() {
+    // static utility class, never instantiated
+  }
+
+  /**
+   * Hashes the whole byte array.
+   *
+   * @param data bytes to hash
+   * @param seed hash seed; only the low 32 bits are used
+   * @param hashes output array; elements 0 and 1 receive the two 64-bit halves of the hash
+   */
+  static void hashBytes(byte[] data, long seed, long[] hashes) {
+    hash(data, Platform.BYTE_ARRAY_OFFSET, data.length, seed, hashes);
+  }
+
+  /**
+   * Hashes a single long as the 8 bytes of its little-endian encoding.
+   * The value is mixed in directly, so no temporary array is allocated and the result
+   * does not depend on the byte order of the platform.
+   *
+   * @param data value to hash
+   * @param seed hash seed; only the low 32 bits are used
+   * @param hashes output array; elements 0 and 1 receive the two 64-bit halves of the hash
+   */
+  static void hashLong(long data, long seed, long[] hashes) {
+    long h1 = seed & SEED_MASK;
+    long h2 = h1;
+    // 8 bytes never fill a 16 byte block, so the input is entirely tail: k1 = data, k2 = 0.
+    h1 ^= mixK1(data);
+    finish(h1, h2, 8, hashes); // 8 - size of a long in bytes
+  }
+
+  /**
+   * Hashes {@code length} bytes of {@code key} starting at {@code offset}, read through
+   * {@code Platform.getByte}, and stores the two 64-bit halves in {@code result[0]} and [1].
+   */
+  @SuppressWarnings("fallthrough")
+  private static void hash(Object key, int offset, int length, long seed, long[] result) {
+    long h1 = seed & SEED_MASK;
+    long h2 = h1;
+
+    int roundedEnd = offset + (length & 0xFFFFFFF0); // round down to 16 byte block
+    for (int i = offset; i < roundedEnd; i += 16) {
+      long k1 = getLongLittleEndian(key, i);
+      long k2 = getLongLittleEndian(key, i + 8);
+
+      h1 ^= mixK1(k1);
+
+      h1 = Long.rotateLeft(h1, 27);
+      h1 += h2;
+      h1 = h1 * 5 + 0x52dce729;
+
+      h2 ^= mixK2(k2);
+
+      h2 = Long.rotateLeft(h2, 31);
+      h2 += h1;
+      h2 = h2 * 5 + 0x38495ab5;
+    }
+    long k1 = 0;
+    long k2 = 0;
+
+    // Tail: the up to 15 bytes after the last full block, assembled in little endian order.
+    switch (length & 15) {
+      case 15:
+        k2 = (Platform.getByte(key, roundedEnd + 14) & 0xFFL) << 48; // fall through
+      case 14:
+        k2 |= (Platform.getByte(key, roundedEnd + 13) & 0xFFL) << 40; // fall through
+      case 13:
+        k2 |= (Platform.getByte(key, roundedEnd + 12) & 0xFFL) << 32; // fall through
+      case 12:
+        k2 |= (Platform.getByte(key, roundedEnd + 11) & 0xFFL) << 24; // fall through
+      case 11:
+        k2 |= (Platform.getByte(key, roundedEnd + 10) & 0xFFL) << 16; // fall through
+      case 10:
+        k2 |= (Platform.getByte(key, roundedEnd + 9) & 0xFFL) << 8; // fall through
+      case 9:
+        k2 |= (Platform.getByte(key, roundedEnd + 8) & 0xFFL);
+        h2 ^= mixK2(k2); // fall through
+      case 8:
+        k1 = ((long) Platform.getByte(key, roundedEnd + 7)) << 56; // fall through
+      case 7:
+        k1 |= (Platform.getByte(key, roundedEnd + 6) & 0xFFL) << 48; // fall through
+      case 6:
+        k1 |= (Platform.getByte(key, roundedEnd + 5) & 0xFFL) << 40; // fall through
+      case 5:
+        k1 |= (Platform.getByte(key, roundedEnd + 4) & 0xFFL) << 32; // fall through
+      case 4:
+        k1 |= (Platform.getByte(key, roundedEnd + 3) & 0xFFL) << 24; // fall through
+      case 3:
+        k1 |= (Platform.getByte(key, roundedEnd + 2) & 0xFFL) << 16; // fall through
+      case 2:
+        k1 |= (Platform.getByte(key, roundedEnd + 1) & 0xFFL) << 8; // fall through
+      case 1:
+        k1 |= (Platform.getByte(key, roundedEnd) & 0xFFL);
+        h1 ^= mixK1(k1);
+    }
+    finish(h1, h2, length, result);
+  }
+
+  /**
+   * Folds the input length into both running states, applies the final mixing and writes
+   * the two halves to {@code result[0]} and {@code result[1]}.
+   */
+  private static void finish(long h1, long h2, int length, long[] result) {
+    h1 ^= length;
+    h2 ^= length;
+
+    h1 += h2;
+    h2 += h1;
+
+    h1 = fmix64(h1);
+    h2 = fmix64(h2);
+
+    h1 += h2;
+    h2 += h1;
+
+    result[0] = h1;
+    result[1] = h2;
+  }
+
+  /**
+   * Gets a long from a byte buffer in little endian byte order.
+   */
+  private static long getLongLittleEndian(Object key, int offset) {
+    return (Platform.getByte(key, offset) & 0xFFL)
+      | ((Platform.getByte(key, offset + 1) & 0xFFL) << 8)
+      | ((Platform.getByte(key, offset + 2) & 0xFFL) << 16)
+      | ((Platform.getByte(key, offset + 3) & 0xFFL) << 24)
+      | ((Platform.getByte(key, offset + 4) & 0xFFL) << 32)
+      | ((Platform.getByte(key, offset + 5) & 0xFFL) << 40)
+      | ((Platform.getByte(key, offset + 6) & 0xFFL) << 48)
+      | (((long) Platform.getByte(key, offset + 7)) << 56);
+  }
+
+  /** Finalization mix applied to each 64-bit half of the state. */
+  private static long fmix64(long k) {
+    k ^= k >>> 33;
+    k *= 0xff51afd7ed558ccdL;
+    k ^= k >>> 33;
+    k *= 0xc4ceb9fe1a85ec53L;
+    k ^= k >>> 33;
+    return k;
+  }
+
+  /** Scrambles an input word before it is xor-ed into the first half of the state. */
+  private static long mixK1(long k1) {
+    k1 *= C1;
+    k1 = Long.rotateLeft(k1, 31);
+    k1 *= C2;
+    return k1;
+  }
+
+  /** Scrambles an input word before it is xor-ed into the second half of the state. */
+  private static long mixK2(long k2) {
+    k2 *= C2;
+    k2 = Long.rotateLeft(k2, 33);
+    k2 *= C1;
+    return k2;
+  }
+}
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/RoaringBitmapArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/RoaringBitmapArray.java
new file mode 100644
index 000000000000..fa973ac48f21
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/RoaringBitmapArray.java
@@ -0,0 +1,160 @@
+/*
+ * 
Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.roaringbitmap.RoaringBitmap;
+
+/**
+ * Bit vector addressed by {@code long} indexes with
+ * {@link #set(long index) set} and {@link #get(long index) get} methods.
+ * It is memory efficient and faster in operations {@link #set(long index) set}
+ * and {@link #get(long index) get} than {@link java.util.BitSet} since it is backed by
+ * {@link org.roaringbitmap.RoaringBitmap}.
+ * Unfortunately, the current version of {@link org.roaringbitmap.RoaringBitmap} supports
+ * only {@code int} indexes and is limited to {@code Integer.MAX_VALUE} in size, see
+ * https://github.com/RoaringBitmap/RoaringBitmap/issues/109.
+ * To support larger sizes we maintain an array of {@link org.roaringbitmap.RoaringBitmap}:
+ * bit {@code i} is stored in bucket {@code i / Integer.MAX_VALUE} at position
+ * {@code i % Integer.MAX_VALUE}.
+ */
+class RoaringBitmapArray {
+
+  private final RoaringBitmap[] data;
+  private final long numBits; // size of bit vector
+
+  /**
+   * Number of buckets needed for {@code numBits} bits, {@code Integer.MAX_VALUE} bits each.
+   *
+   * @throws IllegalArgumentException if {@code numBits} is not positive or would need more
+   *                                  than {@code Integer.MAX_VALUE} buckets
+   */
+  private static int numOfBuckets(long numBits) {
+    if (numBits <= 0) {
+      throw new IllegalArgumentException("numBits must be positive, but got " + numBits);
+    }
+    // Exact integer ceiling division: dividing as doubles loses precision for very large
+    // sizes and can come up one bucket short.
+    long buckets = (numBits - 1) / Integer.MAX_VALUE + 1;
+    if (buckets > Integer.MAX_VALUE) {
+      throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
+    }
+    return (int) buckets;
+  }
+
+  /** Creates {@code numOfBuckets} empty bitmaps. */
+  private static RoaringBitmap[] initialVector(int numOfBuckets) {
+    RoaringBitmap[] vector = new RoaringBitmap[numOfBuckets];
+    for (int i = 0; i < numOfBuckets; i++) {
+      vector[i] = new RoaringBitmap();
+    }
+    return vector;
+  }
+
+  /**
+   * Creates an empty bit vector of {@code numBits} bits.
+   *
+   * @throws IllegalArgumentException if {@code numBits} is not positive or is too large
+   */
+  RoaringBitmapArray(long numBits) {
+    this(initialVector(numOfBuckets(numBits)), numBits);
+  }
+
+  private RoaringBitmapArray(RoaringBitmap[] data, long numBits) {
+    this.data = data;
+    this.numBits = numBits;
+  }
+
+  /** Index of the bucket that holds bit {@code index}. */
+  private static int bucketOf(long index) {
+    return (int) (index / Integer.MAX_VALUE);
+  }
+
+  /** Position of bit {@code index} inside its bucket. */
+  private static int offsetInBucket(long index) {
+    return (int) (index % Integer.MAX_VALUE);
+  }
+
+  /**
+   * Sets the bit at {@code index}.
+   *
+   * @return true if the bit changed value
+   */
+  boolean set(long index) {
+    // checkedAdd() adds the value and reports whether it was absent, in a single lookup.
+    return data[bucketOf(index)].checkedAdd(offsetInBucket(index));
+  }
+
+  /**
+   * Returns true if the bit at {@code index} is set.
+   */
+  boolean get(long index) {
+    return data[bucketOf(index)].contains(offsetInBucket(index));
+  }
+
+  /**
+   * Size of bit vector
+   */
+  long bitSize() {
+    return numBits;
+  }
+
+  /**
+   * Number of set bits (1s)
+   */
+  long cardinality() {
+    long bitCount = 0;
+    for (RoaringBitmap bucket : data) {
+      bitCount += bucket.getCardinality();
+    }
+    return bitCount;
+  }
+
+  /**
+   * Combines the two RoaringBitmapArray using bitwise OR.
+   */
+  void putAll(RoaringBitmapArray bitmap) {
+    assert data.length == bitmap.data.length :
+      "RoaringBitmapArray`s must be of equal length when merging";
+    for (int i = 0; i < data.length; i++) {
+      data[i].or(bitmap.data[i]);
+    }
+  }
+
+  /**
+   * Serializes the bit vector: the number of buckets, the size in bits, then every bucket.
+   * {@code runOptimize()} is called on each bucket before it is written.
+   *
+   * @param out where to save the bit vector
+   * @throws IOException if writing to {@code out} fails
+   */
+  void writeTo(DataOutputStream out) throws IOException {
+    out.writeInt(data.length);
+    out.writeLong(numBits);
+    for (RoaringBitmap datum : data) {
+      datum.runOptimize();
+      datum.serialize(out);
+    }
+  }
+
+  /**
+   * Reads a bit vector in the layout produced by {@link #writeTo(DataOutputStream)}.
+   *
+   * @param in stream positioned at the start of a serialized bit vector
+   * @throws IOException if reading fails or the bucket count does not match the size in bits
+   */
+  static RoaringBitmapArray readFrom(DataInputStream in) throws IOException {
+    int bucketCount = in.readInt();
+    long numBits = in.readLong();
+    // Reject inconsistent headers up front instead of failing later with an index error.
+    if (numBits <= 0 || bucketCount != numOfBuckets(numBits)) {
+      throw new IOException(
+        "Corrupted bit vector: " + bucketCount + " buckets for " + numBits + " bits");
+    }
+    RoaringBitmap[] data = new RoaringBitmap[bucketCount];
+    for (int i = 0; i < bucketCount; i++) {
+      data[i] = new RoaringBitmap();
+      data[i].deserialize(in);
+    }
+    return new RoaringBitmapArray(data, numBits);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) {
+      return true;
+    }
+    if (other == null || !(other instanceof RoaringBitmapArray)) {
+      return false;
+    }
+    RoaringBitmapArray that = (RoaringBitmapArray) other;
+    return (this.numBits == that.numBits) && Arrays.equals(data, that.data);
+  }
+
+  @Override
+  public int hashCode() {
+    return Arrays.hashCode(data);
+  }
+
+}
diff --git a/common/sketch/src/test/java/org/apache/spark/util/sketch/Murmur3_128Suite.java b/common/sketch/src/test/java/org/apache/spark/util/sketch/Murmur3_128Suite.java
new file mode 100644
index 000000000000..e3b6c491d33a
--- /dev/null
+++ b/common/sketch/src/test/java/org/apache/spark/util/sketch/Murmur3_128Suite.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.nio.charset.StandardCharsets;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Testing Murmur3 128 bit hasher.
+ * All tests are taken from {@code com.google.common.hash.Murmur3Hash128Test}.
+ */
+public class Murmur3_128Suite {
+
+  @Test
+  public void testKnownValues() {
+    assertHash(0, 0x629942693e10f867L, 0x92db0b82baeb5347L, "hell");
+    assertHash(1, 0xa78ddff5adae8d10L, 0x128900ef20900135L, "hello");
+    assertHash(2, 0x8a486b23f422e826L, 0xf962a2c58947765fL, "hello ");
+    assertHash(3, 0x2ea59f466f6bed8cL, 0xc610990acc428a17L, "hello w");
+    assertHash(4, 0x79f6305a386c572cL, 0x46305aed3483b94eL, "hello wo");
+    assertHash(5, 0xc2219d213ec1f1b5L, 0xa1d8e2e0a52785bdL, "hello wor");
+    assertHash(0, 0xe34bbc7bbc071b6cL, 0x7a433ca9c49a9347L,
+      "The quick brown fox jumps over the lazy dog");
+    assertHash(0, 0x658ca970ff85269aL, 0x43fee3eaa68e5c3eL,
+      "The quick brown fox jumps over the lazy cog");
+  }
+
+  /**
+   * Hashes the UTF-8 bytes of {@code stringInput} with {@code seed} and checks both
+   * 64-bit halves of the result against the expected values.
+   */
+  private static void assertHash(
+      int seed, long expected1, long expected2, String stringInput) {
+    long[] hash128bit = new long[2];
+    // A Charset constant instead of a charset name: no checked exception to declare.
+    byte[] data = stringInput.getBytes(StandardCharsets.UTF_8);
+    Murmur3_128.hashBytes(data, seed, hash128bit);
+    Assert.assertEquals(expected1, hash128bit[0]);
+    Assert.assertEquals(expected2, hash128bit[1]);
+  }
+}
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
index a0408d2da4df..cee8370256e9 100644
--- 
a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -104,9 +104,7 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite testMergeInPlace[T](typeName, numItems)(itemGen) } - testItemType[Byte]("Byte", 160) { _.nextInt().toByte } - - testItemType[Short]("Short", 1000) { _.nextInt().toShort } + testItemType[Short]("Short", 1000) { _.nextInt(Short.MaxValue + 1).toShort } testItemType[Int]("Int", 100000) { _.nextInt() } @@ -131,4 +129,59 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite filter1.mergeInPlace(filter2) } } + + // Separate test for Byte type, since we can enumerate all possible values + test(s"accuracy - Byte") { + // Byte is from -128 to 127 inclusive + val allBytes = (-128 to 127).map(_.toByte) + val fpp = 0.05 + val filter = BloomFilter.create(129, fpp) + + val isEven: Byte => Boolean = _ % 2 == 0 + val even = allBytes.filter(isEven) + val odd = allBytes.filterNot(isEven) + // insert first `numInsertion` items. + even.foreach(filter.put) + + // false negative is not allowed. + assert(even.forall(filter.mightContain)) + + // The number of inserted items doesn't exceed `expectedNumItems`, so the `expectedFpp` + // should not be significantly higher than the one we passed in to create this bloom filter. + assert(filter.expectedFpp() - fpp < EPSILON) + + val errorCount = odd.count(filter.mightContain) + // Also check the actual fpp is not significantly higher than we expected. 
+ val actualFpp = errorCount.toDouble / odd.length + assert(actualFpp - fpp < EPSILON) + + checkSerDe(filter) + } + + // Separate test for Byte type, since we can enumerate all possible values + test(s"mergeInPlace - Byte") { + val allBytes = (-128 to 127).map(_.toByte) + val fpp = 0.05 + + val isEven: Byte => Boolean = _ % 2 == 0 + val even = allBytes.filter(isEven) + val odd = allBytes.filterNot(isEven) + + val filter1 = BloomFilter.create(256) + even.foreach(filter1.put) + + val filter2 = BloomFilter.create(256) + odd.foreach(filter2.put) + + filter1.mergeInPlace(filter2) + + // After merge, `filter1` has `numItems` items which doesn't exceed `expectedNumItems`, so the + // `expectedFpp` should not be significantly higher than the default one. + assert(filter1.expectedFpp() - BloomFilter.DEFAULT_FPP < EPSILON) + + even.foreach(i => assert(filter1.mightContain(i))) + odd.foreach(i => assert(filter1.mightContain(i))) + + checkSerDe(filter1) + } } diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/RoaringBitmapArraySuite.scala similarity index 73% rename from common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala rename to common/sketch/src/test/scala/org/apache/spark/util/sketch/RoaringBitmapArraySuite.scala index ff728f0ebcb8..16ded6af3a87 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/RoaringBitmapArraySuite.scala @@ -21,34 +21,17 @@ import scala.util.Random import org.scalatest.FunSuite // scalastyle:ignore funsuite -class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite +class RoaringBitmapArraySuite extends FunSuite { // scalastyle:ignore funsuite - test("error case when create BitArray") { - intercept[IllegalArgumentException](new BitArray(0)) - intercept[IllegalArgumentException](new BitArray(64L * Integer.MAX_VALUE 
+ 1)) - } - - test("bitSize") { - assert(new BitArray(64).bitSize() == 64) - // BitArray is word-aligned, so 65~128 bits need 2 long to store, which is 128 bits. - assert(new BitArray(65).bitSize() == 128) - assert(new BitArray(127).bitSize() == 128) - assert(new BitArray(128).bitSize() == 128) - } - - test("set") { - val bitArray = new BitArray(64) - assert(bitArray.set(1)) - // Only returns true if the bit changed. - assert(!bitArray.set(1)) - assert(bitArray.set(2)) + test("error case when create RoaringBitmapArray") { + intercept[IllegalArgumentException](new RoaringBitmapArray(0)) } test("normal operation") { // use a fixed seed to make the test predictable. val r = new Random(37) - val bitArray = new BitArray(320) + val bitArray = new RoaringBitmapArray(320) val indexes = (1 to 100).map(_ => r.nextInt(320).toLong).distinct indexes.foreach(bitArray.set) @@ -56,12 +39,20 @@ class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite assert(bitArray.cardinality() == indexes.length) } + test("set") { + val bitArray = new RoaringBitmapArray(64) + assert(bitArray.set(1)) + // Only returns true if the bit changed. + assert(!bitArray.set(1)) + assert(bitArray.set(2)) + } + test("merge") { // use a fixed seed to make the test predictable. val r = new Random(37) - val bitArray1 = new BitArray(64 * 6) - val bitArray2 = new BitArray(64 * 6) + val bitArray1 = new RoaringBitmapArray(64 * 6) + val bitArray2 = new RoaringBitmapArray(64 * 6) val indexes1 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct val indexes2 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct