diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index 626f023a5b99..ec0bd020a1ed 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -39,6 +39,10 @@
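 <groupId>org.apache.spark</groupId>
 <artifactId>spark-tags_${scala.binary.version}</artifactId>
 </dependency>
+ <dependency>
+ <groupId>org.roaringbitmap</groupId>
+ <artifactId>RoaringBitmap</artifactId>
+ </dependency>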
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
deleted file mode 100644
index 480a0a79db32..000000000000
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.sketch;
-
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.util.Arrays;
-
-final class BitArray {
- private final long[] data;
- private long bitCount;
-
- static int numWords(long numBits) {
- if (numBits <= 0) {
- throw new IllegalArgumentException("numBits must be positive, but got " + numBits);
- }
- long numWords = (long) Math.ceil(numBits / 64.0);
- if (numWords > Integer.MAX_VALUE) {
- throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
- }
- return (int) numWords;
- }
-
- BitArray(long numBits) {
- this(new long[numWords(numBits)]);
- }
-
- private BitArray(long[] data) {
- this.data = data;
- long bitCount = 0;
- for (long word : data) {
- bitCount += Long.bitCount(word);
- }
- this.bitCount = bitCount;
- }
-
- /** Returns true if the bit changed value. */
- boolean set(long index) {
- if (!get(index)) {
- data[(int) (index >>> 6)] |= (1L << index);
- bitCount++;
- return true;
- }
- return false;
- }
-
- boolean get(long index) {
- return (data[(int) (index >>> 6)] & (1L << index)) != 0;
- }
-
- /** Number of bits */
- long bitSize() {
- return (long) data.length * Long.SIZE;
- }
-
- /** Number of set bits (1s) */
- long cardinality() {
- return bitCount;
- }
-
- /** Combines the two BitArrays using bitwise OR. */
- void putAll(BitArray array) {
- assert data.length == array.data.length : "BitArrays must be of equal length when merging";
- long bitCount = 0;
- for (int i = 0; i < data.length; i++) {
- data[i] |= array.data[i];
- bitCount += Long.bitCount(data[i]);
- }
- this.bitCount = bitCount;
- }
-
- void writeTo(DataOutputStream out) throws IOException {
- out.writeInt(data.length);
- for (long datum : data) {
- out.writeLong(datum);
- }
- }
-
- static BitArray readFrom(DataInputStream in) throws IOException {
- int numWords = in.readInt();
- long[] data = new long[numWords];
- for (int i = 0; i < numWords; i++) {
- data[i] = in.readLong();
- }
- return new BitArray(data);
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) return true;
- if (other == null || !(other instanceof BitArray)) return false;
- BitArray that = (BitArray) other;
- return Arrays.equals(data, that.data);
- }
-
- @Override
- public int hashCode() {
- return Arrays.hashCode(data);
- }
-}
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index c0b425e72959..f7e9143064ac 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -51,7 +51,7 @@ public enum Version {
* The words/longs (numWords * 64 bit)
*
*/
- V1(1);
+ V1(2);
private final int versionNumber;
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index 92c28bcb56a5..fde2d14d20cb 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -23,13 +23,13 @@ class BloomFilterImpl extends BloomFilter implements Serializable {
private int numHashFunctions;
- private BitArray bits;
+ private RoaringBitmapArray bits;
BloomFilterImpl(int numHashFunctions, long numBits) {
- this(new BitArray(numBits), numHashFunctions);
+ this(new RoaringBitmapArray(numBits), numHashFunctions);
}
- private BloomFilterImpl(BitArray bits, int numHashFunctions) {
+ private BloomFilterImpl(RoaringBitmapArray bits, int numHashFunctions) {
this.bits = bits;
this.numHashFunctions = numHashFunctions;
}
@@ -48,7 +48,7 @@ public boolean equals(Object other) {
BloomFilterImpl that = (BloomFilterImpl) other;
- return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits);
+ return (this.numHashFunctions == that.numHashFunctions) && this.bits.equals(that.bits);
}
@Override
@@ -84,18 +84,19 @@ public boolean putString(String item) {
@Override
public boolean putBinary(byte[] item) {
- int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
- int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);
+ // Strategy is taken from Guava's BloomFilterStrategies.MURMUR128_MITZ_64
+ long[] hashes = new long[2];
+ Murmur3_128.hashBytes(item, 0, hashes);
+ long h1 = hashes[0];
+ long h2 = hashes[1];
long bitSize = bits.bitSize();
boolean bitsChanged = false;
+ long combinedHash = h1;
for (int i = 1; i <= numHashFunctions; i++) {
- int combinedHash = h1 + (i * h2);
- // Flip all the bits if it's negative (guaranteed positive number)
- if (combinedHash < 0) {
- combinedHash = ~combinedHash;
- }
- bitsChanged |= bits.set(combinedHash % bitSize);
+ // Make combinedHash positive and indexable
+ bitsChanged |= bits.set((combinedHash & Long.MAX_VALUE) % bitSize);
+ combinedHash += h2;
}
return bitsChanged;
}
@@ -107,61 +108,59 @@ public boolean mightContainString(String item) {
@Override
public boolean mightContainBinary(byte[] item) {
- int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
- int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);
+ // Strategy is taken from Guava's BloomFilterStrategies.MURMUR128_MITZ_64
+ long[] hashes = new long[2];
+ Murmur3_128.hashBytes(item, 0, hashes);
+
+ long h1 = hashes[0];
+ long h2 = hashes[1];
long bitSize = bits.bitSize();
+ long combinedHash = h1;
for (int i = 1; i <= numHashFunctions; i++) {
- int combinedHash = h1 + (i * h2);
- // Flip all the bits if it's negative (guaranteed positive number)
- if (combinedHash < 0) {
- combinedHash = ~combinedHash;
- }
- if (!bits.get(combinedHash % bitSize)) {
+ // Make combinedHash positive and indexable
+ if (!bits.get((combinedHash & Long.MAX_VALUE) % bitSize)) {
return false;
}
+ combinedHash += h2;
}
return true;
}
@Override
public boolean putLong(long item) {
- // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
- // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
- // Note that `CountMinSketch` use a different strategy, it hash the input long element with
- // every i to produce n hash values.
- // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
- int h1 = Murmur3_x86_32.hashLong(item, 0);
- int h2 = Murmur3_x86_32.hashLong(item, h1);
+ // Strategy is taken from Guava's BloomFilterStrategies.MURMUR128_MITZ_64
+ long[] hashes = new long[2];
+ Murmur3_128.hashLong(item, 0, hashes);
+ long h1 = hashes[0];
+ long h2 = hashes[1];
long bitSize = bits.bitSize();
boolean bitsChanged = false;
+ long combinedHash = h1;
for (int i = 1; i <= numHashFunctions; i++) {
- int combinedHash = h1 + (i * h2);
- // Flip all the bits if it's negative (guaranteed positive number)
- if (combinedHash < 0) {
- combinedHash = ~combinedHash;
- }
- bitsChanged |= bits.set(combinedHash % bitSize);
+ // Make combinedHash positive and indexable
+ bitsChanged |= bits.set((combinedHash & Long.MAX_VALUE) % bitSize);
+ combinedHash += h2;
}
return bitsChanged;
}
@Override
public boolean mightContainLong(long item) {
- int h1 = Murmur3_x86_32.hashLong(item, 0);
- int h2 = Murmur3_x86_32.hashLong(item, h1);
+ // Strategy is taken from Guava's BloomFilterStrategies.MURMUR128_MITZ_64
+ long[] hashes = new long[2];
+ Murmur3_128.hashLong(item, 0, hashes);
+ long h1 = hashes[0];
+ long h2 = hashes[1];
long bitSize = bits.bitSize();
+ long combinedHash = h1;
for (int i = 1; i <= numHashFunctions; i++) {
- int combinedHash = h1 + (i * h2);
- // Flip all the bits if it's negative (guaranteed positive number)
- if (combinedHash < 0) {
- combinedHash = ~combinedHash;
- }
- if (!bits.get(combinedHash % bitSize)) {
+ if (!bits.get((combinedHash & Long.MAX_VALUE) % bitSize)) {
return false;
}
+ combinedHash += h2;
}
return true;
}
@@ -238,7 +237,7 @@ private void readFrom0(InputStream in) throws IOException {
}
this.numHashFunctions = dis.readInt();
- this.bits = BitArray.readFrom(dis);
+ this.bits = RoaringBitmapArray.readFrom(dis);
}
public static BloomFilterImpl readFrom(InputStream in) throws IOException {
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_128.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_128.java
new file mode 100644
index 000000000000..9041e203966d
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_128.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+/**
+ * 128-bit Murmur3 hasher.
+ * Best performance is on the x86_64 platform.
+ * Based on the smhasher implementation
+ * and the Solr implementation.
+ */
+final class Murmur3_128 {
+
+ private static final long C1 = 0x87c37b91114253d5L;
+ private static final long C2 = 0x4cf5ad432745937fL;
+
+ static void hashBytes(byte[] data, long seed, long[] hashes) {
+ hash(data, Platform.BYTE_ARRAY_OFFSET, data.length, seed, hashes);
+ }
+
+ static void hashLong(long data, long seed, long[] hashes) {
+ hash(new long[]{data}, Platform.LONG_ARRAY_OFFSET, 8, seed, hashes); // 8 - long`s size in bytes
+ }
+
+ @SuppressWarnings("fallthrough")
+ private static void hash(Object key, int offset, int length, long seed, long[] result) {
+ long h1 = seed & 0x00000000FFFFFFFFL;
+ long h2 = seed & 0x00000000FFFFFFFFL;
+
+ int roundedEnd = offset + (length & 0xFFFFFFF0); // round down to 16 byte block
+ for (int i = offset; i < roundedEnd; i += 16) {
+ long k1 = getLongLittleEndian(key, i);
+ long k2 = getLongLittleEndian(key, i + 8);
+
+ h1 ^= mixK1(k1);
+
+ h1 = Long.rotateLeft(h1, 27);
+ h1 += h2;
+ h1 = h1 * 5 + 0x52dce729;
+
+ h2 ^= mixK2(k2);
+
+ h2 = Long.rotateLeft(h2, 31);
+ h2 += h1;
+ h2 = h2 * 5 + 0x38495ab5;
+ }
+ long k1 = 0;
+ long k2 = 0;
+
+ switch (length & 15) {
+ case 15:
+ k2 = (Platform.getByte(key, roundedEnd + 14) & 0xFFL) << 48; // fall through
+ case 14:
+ k2 |= (Platform.getByte(key, roundedEnd + 13) & 0xFFL) << 40; // fall through
+ case 13:
+ k2 |= (Platform.getByte(key, roundedEnd + 12) & 0xFFL) << 32; // fall through
+ case 12:
+ k2 |= (Platform.getByte(key, roundedEnd + 11) & 0xFFL) << 24; // fall through
+ case 11:
+ k2 |= (Platform.getByte(key, roundedEnd + 10) & 0xFFL) << 16; // fall through
+ case 10:
+ k2 |= (Platform.getByte(key, roundedEnd + 9) & 0xFFL) << 8; // fall through
+ case 9:
+ k2 |= (Platform.getByte(key, roundedEnd + 8) & 0xFFL); // fall through
+ h2 ^= mixK2(k2);
+ case 8:
+ k1 = ((long) Platform.getByte(key, roundedEnd + 7)) << 56; // fall through
+ case 7:
+ k1 |= (Platform.getByte(key, roundedEnd + 6) & 0xFFL) << 48; // fall through
+ case 6:
+ k1 |= (Platform.getByte(key, roundedEnd + 5) & 0xFFL) << 40; // fall through
+ case 5:
+ k1 |= (Platform.getByte(key, roundedEnd + 4) & 0xFFL) << 32; // fall through
+ case 4:
+ k1 |= (Platform.getByte(key, roundedEnd + 3) & 0xFFL) << 24; // fall through
+ case 3:
+ k1 |= (Platform.getByte(key, roundedEnd + 2) & 0xFFL) << 16; // fall through
+ case 2:
+ k1 |= (Platform.getByte(key, roundedEnd + 1) & 0xFFL) << 8; // fall through
+ case 1:
+ k1 |= (Platform.getByte(key, roundedEnd) & 0xFFL);
+ h1 ^= mixK1(k1);
+ }
+ h1 ^= length;
+ h2 ^= length;
+
+ h1 += h2;
+ h2 += h1;
+
+ h1 = fmix64(h1);
+ h2 = fmix64(h2);
+
+ h1 += h2;
+ h2 += h1;
+
+ result[0] = h1;
+ result[1] = h2;
+ }
+
+ /**
+ * Gets a long from a byte buffer in little endian byte order.
+ */
+ private static long getLongLittleEndian(Object key, int offset) {
+ return (Platform.getByte(key, offset) & 0xFFL)
+ | ((Platform.getByte(key, offset + 1) & 0xFFL) << 8)
+ | ((Platform.getByte(key, offset + 2) & 0xFFL) << 16)
+ | ((Platform.getByte(key, offset + 3) & 0xFFL) << 24)
+ | ((Platform.getByte(key, offset + 4) & 0xFFL) << 32)
+ | ((Platform.getByte(key, offset + 5) & 0xFFL) << 40)
+ | ((Platform.getByte(key, offset + 6) & 0xFFL) << 48)
+ | (((long) Platform.getByte(key, offset + 7)) << 56);
+ }
+
+ private static long fmix64(long k) {
+ k ^= k >>> 33;
+ k *= 0xff51afd7ed558ccdL;
+ k ^= k >>> 33;
+ k *= 0xc4ceb9fe1a85ec53L;
+ k ^= k >>> 33;
+ return k;
+ }
+
+ private static long mixK1(long k1) {
+ k1 *= C1;
+ k1 = Long.rotateLeft(k1, 31);
+ k1 *= C2;
+ return k1;
+ }
+
+ private static long mixK2(long k2) {
+ k2 *= C2;
+ k2 = Long.rotateLeft(k2, 33);
+ k2 *= C1;
+ return k2;
+ }
+}
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/RoaringBitmapArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/RoaringBitmapArray.java
new file mode 100644
index 000000000000..fa973ac48f21
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/RoaringBitmapArray.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import org.roaringbitmap.RoaringBitmap;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * This class represents a bit vector with
+ * {@link #set(long index) set} and {@link #get(long index) get} methods.
+ * It is memory efficient and faster in the operations {@link #set(long index) set}
+ * and {@link #get(long index) get} than {@link java.util.BitSet}, since it is backed by
+ * {@link org.roaringbitmap.RoaringBitmap}
+ * (see the RoaringBitmap GitHub repository).
+ * Unfortunately, the current version of {@link org.roaringbitmap.RoaringBitmap} supports
+ * only {@code int} indexes and is limited to {@code Integer.MAX_VALUE} in size
+ * (see https://github.com/RoaringBitmap/RoaringBitmap/issues/109).
+ * To support sizes up to {@code Long.MAX_VALUE} we have to maintain
+ * an array of {@link org.roaringbitmap.RoaringBitmap}.
+ */
+class RoaringBitmapArray {
+
+ private final RoaringBitmap[] data;
+ private final long numBits; // size of bit vector
+
+ private static int numOfBuckets(long numBits) {
+ if (numBits <= 0) {
+ throw new IllegalArgumentException("numBits must be positive, but got " + numBits);
+ }
+ return (int) Math.ceil(numBits / (double) Integer.MAX_VALUE);
+ }
+
+ private static RoaringBitmap[] initialVector(int numOfBuckets) {
+ RoaringBitmap[] vector = new RoaringBitmap[numOfBuckets];
+ for (int i = 0; i < numOfBuckets; i++) {
+ vector[i] = new RoaringBitmap();
+ }
+ return vector;
+ }
+
+ RoaringBitmapArray(long numBits) {
+ this(initialVector(numOfBuckets(numBits)), numBits);
+ }
+
+ private RoaringBitmapArray(RoaringBitmap[] data, long numBits) {
+ this.data = data;
+ this.numBits = numBits;
+ }
+
+ /**
+ * Returns true if the bit changed value.
+ */
+ boolean set(long index) {
+ int bucketNum = (int) (index / Integer.MAX_VALUE);
+ int bitIdx = (int) (index % Integer.MAX_VALUE);
+ if (!data[bucketNum].contains(bitIdx)) {
+ data[bucketNum].add(bitIdx);
+ return true;
+ }
+ return false;
+ }
+
+ boolean get(long index) {
+ int bucketNum = (int) (index / Integer.MAX_VALUE);
+ int bitIdx = (int) (index % Integer.MAX_VALUE);
+ return data[bucketNum].contains(bitIdx);
+ }
+
+ /**
+ * Size of bit vector
+ */
+ long bitSize() {
+ return numBits;
+ }
+
+ /**
+ * Number of set bits (1s)
+ */
+ long cardinality() {
+ long bitCount = 0;
+ for (RoaringBitmap bucket : data) {
+ bitCount += bucket.getCardinality();
+ }
+ return bitCount;
+ }
+
+ /**
+ * Combines the two RoaringBitmapArray using bitwise OR.
+ */
+ void putAll(RoaringBitmapArray bitmap) {
+ assert data.length == bitmap.data.length : "RoaringBitmapArray`s must be of equal length when merging";
+ for (int i = 0; i < data.length; i++) {
+ data[i].or(bitmap.data[i]);
+ }
+ }
+
+ /**
+ * Serializes the bit vector.
+ * The actual serialized size will be approximately the same as in memory.
+ *
+ * @param out the {@link java.io.DataOutputStream} to write the bit vector to
+ * @throws IOException
+ */
+ void writeTo(DataOutputStream out) throws IOException {
+ out.writeInt(data.length);
+ out.writeLong(numBits);
+ for (RoaringBitmap datum : data) {
+ datum.runOptimize();
+ datum.serialize(out);
+ }
+ }
+
+ static RoaringBitmapArray readFrom(DataInputStream in) throws IOException {
+ int numOfBuckets = in.readInt();
+ long numBits = in.readLong();
+ RoaringBitmap[] data = new RoaringBitmap[numOfBuckets];
+ for (int i = 0; i < numOfBuckets; i++) {
+ data[i] = new RoaringBitmap();
+ data[i].deserialize(in);
+ }
+ return new RoaringBitmapArray(data, numBits);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (other == null || !(other instanceof RoaringBitmapArray)) {
+ return false;
+ }
+ RoaringBitmapArray that = (RoaringBitmapArray) other;
+ return (this.numBits == that.numBits) && Arrays.equals(data, that.data);
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(data);
+ }
+
+}
\ No newline at end of file
diff --git a/common/sketch/src/test/java/org/apache/spark/util/sketch/Murmur3_128Suite.java b/common/sketch/src/test/java/org/apache/spark/util/sketch/Murmur3_128Suite.java
new file mode 100644
index 000000000000..e3b6c491d33a
--- /dev/null
+++ b/common/sketch/src/test/java/org/apache/spark/util/sketch/Murmur3_128Suite.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for the 128-bit Murmur3 hasher.
+ * All tests are taken from Guava's {@code com.google.common.hash.Murmur3Hash128Test}.
+ */
+public class Murmur3_128Suite {
+
+ @Test
+ public void testKnownValues() throws Exception {
+ assertHash(0, 0x629942693e10f867L, 0x92db0b82baeb5347L, "hell");
+ assertHash(1, 0xa78ddff5adae8d10L, 0x128900ef20900135L, "hello");
+ assertHash(2, 0x8a486b23f422e826L, 0xf962a2c58947765fL, "hello ");
+ assertHash(3, 0x2ea59f466f6bed8cL, 0xc610990acc428a17L, "hello w");
+ assertHash(4, 0x79f6305a386c572cL, 0x46305aed3483b94eL, "hello wo");
+ assertHash(5, 0xc2219d213ec1f1b5L, 0xa1d8e2e0a52785bdL, "hello wor");
+ assertHash(0, 0xe34bbc7bbc071b6cL, 0x7a433ca9c49a9347L,
+ "The quick brown fox jumps over the lazy dog");
+ assertHash(0, 0x658ca970ff85269aL, 0x43fee3eaa68e5c3eL,
+ "The quick brown fox jumps over the lazy cog");
+ }
+
+ private static void assertHash(int seed, long expected1, long expected2, String stringInput) throws Exception {
+ long[] hash128bit = new long[2];
+ byte[] data = stringInput.getBytes("UTF-8");
+ Murmur3_128.hashBytes(data, seed, hash128bit);
+ Assert.assertEquals(expected1, hash128bit[0]);
+ Assert.assertEquals(expected2, hash128bit[1]);
+ }
+}
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
index a0408d2da4df..cee8370256e9 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
@@ -104,9 +104,7 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
testMergeInPlace[T](typeName, numItems)(itemGen)
}
- testItemType[Byte]("Byte", 160) { _.nextInt().toByte }
-
- testItemType[Short]("Short", 1000) { _.nextInt().toShort }
+ testItemType[Short]("Short", 1000) { _.nextInt(Short.MaxValue + 1).toShort }
testItemType[Int]("Int", 100000) { _.nextInt() }
@@ -131,4 +129,59 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
filter1.mergeInPlace(filter2)
}
}
+
+ // Separate test for Byte type, since we can enumerate all possible values
+ test(s"accuracy - Byte") {
+ // Byte is from -128 to 127 inclusive
+ val allBytes = (-128 to 127).map(_.toByte)
+ val fpp = 0.05
+ val filter = BloomFilter.create(129, fpp)
+
+ val isEven: Byte => Boolean = _ % 2 == 0
+ val even = allBytes.filter(isEven)
+ val odd = allBytes.filterNot(isEven)
+ // insert all even bytes.
+ even.foreach(filter.put)
+
+ // false negative is not allowed.
+ assert(even.forall(filter.mightContain))
+
+ // The number of inserted items doesn't exceed `expectedNumItems`, so the `expectedFpp`
+ // should not be significantly higher than the one we passed in to create this bloom filter.
+ assert(filter.expectedFpp() - fpp < EPSILON)
+
+ val errorCount = odd.count(filter.mightContain)
+ // Also check the actual fpp is not significantly higher than we expected.
+ val actualFpp = errorCount.toDouble / odd.length
+ assert(actualFpp - fpp < EPSILON)
+
+ checkSerDe(filter)
+ }
+
+ // Separate test for Byte type, since we can enumerate all possible values
+ test(s"mergeInPlace - Byte") {
+ val allBytes = (-128 to 127).map(_.toByte)
+ val fpp = 0.05
+
+ val isEven: Byte => Boolean = _ % 2 == 0
+ val even = allBytes.filter(isEven)
+ val odd = allBytes.filterNot(isEven)
+
+ val filter1 = BloomFilter.create(256)
+ even.foreach(filter1.put)
+
+ val filter2 = BloomFilter.create(256)
+ odd.foreach(filter2.put)
+
+ filter1.mergeInPlace(filter2)
+
+ // After merge, `filter1` has all 256 byte values, which doesn't exceed `expectedNumItems`, so
+ // the `expectedFpp` should not be significantly higher than the default one.
+ assert(filter1.expectedFpp() - BloomFilter.DEFAULT_FPP < EPSILON)
+
+ even.foreach(i => assert(filter1.mightContain(i)))
+ odd.foreach(i => assert(filter1.mightContain(i)))
+
+ checkSerDe(filter1)
+ }
}
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/RoaringBitmapArraySuite.scala
similarity index 73%
rename from common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala
rename to common/sketch/src/test/scala/org/apache/spark/util/sketch/RoaringBitmapArraySuite.scala
index ff728f0ebcb8..16ded6af3a87 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/RoaringBitmapArraySuite.scala
@@ -21,34 +21,17 @@ import scala.util.Random
import org.scalatest.FunSuite // scalastyle:ignore funsuite
-class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite
+class RoaringBitmapArraySuite extends FunSuite { // scalastyle:ignore funsuite
- test("error case when create BitArray") {
- intercept[IllegalArgumentException](new BitArray(0))
- intercept[IllegalArgumentException](new BitArray(64L * Integer.MAX_VALUE + 1))
- }
-
- test("bitSize") {
- assert(new BitArray(64).bitSize() == 64)
- // BitArray is word-aligned, so 65~128 bits need 2 long to store, which is 128 bits.
- assert(new BitArray(65).bitSize() == 128)
- assert(new BitArray(127).bitSize() == 128)
- assert(new BitArray(128).bitSize() == 128)
- }
-
- test("set") {
- val bitArray = new BitArray(64)
- assert(bitArray.set(1))
- // Only returns true if the bit changed.
- assert(!bitArray.set(1))
- assert(bitArray.set(2))
+ test("error case when create RoaringBitmapArray") {
+ intercept[IllegalArgumentException](new RoaringBitmapArray(0))
}
test("normal operation") {
// use a fixed seed to make the test predictable.
val r = new Random(37)
- val bitArray = new BitArray(320)
+ val bitArray = new RoaringBitmapArray(320)
val indexes = (1 to 100).map(_ => r.nextInt(320).toLong).distinct
indexes.foreach(bitArray.set)
@@ -56,12 +39,20 @@ class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite
assert(bitArray.cardinality() == indexes.length)
}
+ test("set") {
+ val bitArray = new RoaringBitmapArray(64)
+ assert(bitArray.set(1))
+ // Only returns true if the bit changed.
+ assert(!bitArray.set(1))
+ assert(bitArray.set(2))
+ }
+
test("merge") {
// use a fixed seed to make the test predictable.
val r = new Random(37)
- val bitArray1 = new BitArray(64 * 6)
- val bitArray2 = new BitArray(64 * 6)
+ val bitArray1 = new RoaringBitmapArray(64 * 6)
+ val bitArray2 = new RoaringBitmapArray(64 * 6)
val indexes1 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct
val indexes2 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct