Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,45 @@ int getVersionNumber() {
public abstract long totalCount();

/**
* Adds 1 to {@code item}.
* Increments {@code item}'s count by one.
*/
public abstract void add(Object item);

/**
* Adds {@code count} to {@code item}.
* Increments {@code item}'s count by {@code count}.
*/
public abstract void add(Object item, long count);

/**
* Increments {@code item}'s count by one.
*/
public abstract void addLong(long item);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add Javadoc.

Also update the other Javadoc comments to say "Increments item's count by one." or "Increments item's count by count."


/**
* Increments {@code item}'s count by {@code count}.
*/
public abstract void addLong(long item, long count);

/**
* Increments {@code item}'s count by one.
*/
public abstract void addString(String item);

/**
* Increments {@code item}'s count by {@code count}.
*/
public abstract void addString(String item, long count);

/**
* Increments {@code item}'s count by one.
*/
public abstract void addBinary(byte[] item);

/**
* Increments {@code item}'s count by {@code count}.
*/
public abstract void addBinary(byte[] item, long count);

/**
* Returns the estimated frequency of {@code item}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

Expand Down Expand Up @@ -146,27 +145,49 @@ public void add(Object item, long count) {
}
}

private void addString(String item, long count) {
/**
 * Increments {@code item}'s count by one.
 *
 * Convenience overload: delegates to {@link #addString(String, long)} with a count of 1.
 */
@Override
public void addString(String item) {
addString(item, 1);
}

/**
 * Increments {@code item}'s count by {@code count}.
 *
 * The string is converted to a byte array with {@code Utils.getBytesFromUTF8String} and the
 * update is then performed by {@link #addBinary(byte[], long)}, so any validation of
 * {@code count} happens there rather than here.
 */
@Override
public void addString(String item, long count) {
addBinary(Utils.getBytesFromUTF8String(item), count);
}

/**
 * Increments {@code item}'s count by one.
 *
 * Convenience overload: delegates to {@link #addLong(long, long)} with a count of 1.
 */
@Override
public void addLong(long item) {
addLong(item, 1);
}

@Override
public void addLong(long item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][buckets[i]] += count;
table[i][hash(item, i)] += count;
}

totalCount += count;
}

private void addLong(long item, long count) {
/**
 * Increments {@code item}'s count by one.
 *
 * Convenience overload: delegates to {@link #addBinary(byte[], long)} with a count of 1.
 */
@Override
public void addBinary(byte[] item) {
addBinary(item, 1);
}

@Override
public void addBinary(byte[] item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][hash(item, i)] += count;
table[i][buckets[i]] += count;
}

totalCount += count;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.types.{IntegralType, StringType}
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

/**
Expand Down Expand Up @@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* Null elements will be replaced by "null", and back ticks will be dropped from elements if they
* exist.
*
*
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
* @param col2 The name of the second column. Distinct items will make the column names
Expand Down Expand Up @@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType

require(
colType == StringType || colType.isInstanceOf[IntegralType],
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)
// Pick the per-row update function once, based on the column's data type, so that every row is
// read with the typed `InternalRow` getter matching that type.
val updater: (CountMinSketch, InternalRow) => Unit = colType match {
// For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
// instead of `addString` to avoid unnecessary conversion.
case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes)
// Integral columns: read the value with its typed getter and pass it to `addLong`.
case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
// Every other column type is rejected with an `IllegalArgumentException`.
case _ =>
throw new IllegalArgumentException(
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)
}

singleCol.rdd.aggregate(zero)(
(sketch: CountMinSketch, row: Row) => {
sketch.add(row.get(0))
singleCol.queryExecution.toRdd.aggregate(zero)(
(sketch: CountMinSketch, row: InternalRow) => {
updater(sketch, row)
sketch
},

(sketch1: CountMinSketch, sketch2: CountMinSketch) => {
sketch1.mergeInPlace(sketch2)
}
(sketch1, sketch2) => sketch1.mergeInPlace(sketch2)
)
}

Expand Down Expand Up @@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
require(colType == StringType || colType.isInstanceOf[IntegralType],
s"Bloom filter only supports string type and integral types, but got $colType.")

val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
(filter, row) =>
// For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
// instead of `putString` to avoid unnecessary conversion.
filter.putBinary(row.getUTF8String(0).getBytes)
filter
} else {
(filter, row) =>
// TODO: specialize it.
filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
filter
// Pick the per-row update function once, based on the column's data type, so that every row is
// read with the typed `InternalRow` getter matching that type.
val updater: (BloomFilter, InternalRow) => Unit = colType match {
// For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
// instead of `putString` to avoid unnecessary conversion.
case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
// Integral columns: read the value with its typed getter and pass it to `putLong`.
case ByteType => (filter, row) => filter.putLong(row.getByte(0))
case ShortType => (filter, row) => filter.putLong(row.getShort(0))
case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
case LongType => (filter, row) => filter.putLong(row.getLong(0))
// Every other column type is rejected with an `IllegalArgumentException`.
case _ =>
throw new IllegalArgumentException(
s"Bloom filter only supports string type and integral types, " +
s"and does not support type $colType."
)
}

singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
singleCol.queryExecution.toRdd.aggregate(zero)(
(filter: BloomFilter, row: InternalRow) => {
updater(filter, row)
filter
},
(filter1, filter2) => filter1.mergeInPlace(filter2)
)
}
}