Changes from all commits
47 commits
af12b55
support
LuciferYang Mar 9, 2023
6f2e736
format
LuciferYang Mar 9, 2023
92da589
remove Alias
LuciferYang Mar 9, 2023
71fba34
support Int/Byte/Short
LuciferYang Mar 10, 2023
ab78bae
support string
LuciferYang Mar 10, 2023
ebcec0b
support string
LuciferYang Mar 10, 2023
6ee5528
change to check numBits
LuciferYang Mar 10, 2023
dfffd53
tmp fix
LuciferYang Mar 10, 2023
8d0e517
refactor
LuciferYang Mar 10, 2023
d193de0
move check to server side
LuciferYang Mar 10, 2023
943fcdb
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 13, 2023
dc62996
try remove lazy
LuciferYang Mar 14, 2023
df56a05
make lazy
LuciferYang Mar 14, 2023
d225405
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 15, 2023
f588a6d
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 16, 2023
e7bb98a
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 16, 2023
7a02ad1
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 21, 2023
f6474c8
add Serializable
LuciferYang Mar 21, 2023
6c756a6
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 22, 2023
824fdaa
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 23, 2023
ad1d205
change since from 3.4.0 to 3.5.0
LuciferYang Mar 23, 2023
9b05105
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 27, 2023
b22a615
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 28, 2023
64bb472
Merge branch 'apache:master' into SPARK-42664
LuciferYang Mar 31, 2023
dbada4c
Merge branch 'apache:master' into SPARK-42664
LuciferYang Apr 3, 2023
886809b
Merge branch 'apache:master' into SPARK-42664
LuciferYang Apr 4, 2023
176ce09
merge case
LuciferYang Apr 7, 2023
ea2234c
refactor to remove pass dataType
LuciferYang Apr 7, 2023
5bd6169
chanage to always pass 3 parameters
LuciferYang Apr 7, 2023
43d6305
add NaN check in Planner
LuciferYang Apr 7, 2023
dbbb61a
bridge optimalNumOfBits
LuciferYang Apr 7, 2023
9520355
Merge branch 'apache:master' into SPARK-42664
LuciferYang Apr 8, 2023
3a24371
Merge branch 'apache:master' into SPARK-42664
LuciferYang Apr 10, 2023
c9c9f79
Merge branch 'apache:master' into SPARK-42664
LuciferYang Apr 17, 2023
a244420
Merge branch 'apache:master' into SPARK-42664
LuciferYang Apr 19, 2023
d004993
Merge branch 'upmaster' into SPARK-42664
LuciferYang Apr 24, 2023
62fdf91
Merge branch 'apache:master' into SPARK-42664
LuciferYang May 8, 2023
60217da
Merge branch 'apache:master' into SPARK-42664
LuciferYang May 11, 2023
79d7d89
Merge branch 'apache:master' into SPARK-42664
LuciferYang May 15, 2023
61bea56
Merge branch 'apache:master' into SPARK-42664
LuciferYang May 16, 2023
80d6e6d
Merge branch 'apache:master' into SPARK-42664
LuciferYang May 31, 2023
caab9bf
Merge branch 'apache:master' into SPARK-42664
LuciferYang Jun 1, 2023
e3aeb86
Merge branch 'upmaster' into SPARK-42664
LuciferYang Jun 5, 2023
1fdc5c0
Merge branch 'apache:master' into SPARK-42664
LuciferYang Jun 16, 2023
8861190
Merge branch 'apache:master' into SPARK-42664
LuciferYang Jun 28, 2023
efcc1b5
Merge branch 'upmaster' into SPARK-42664
LuciferYang Jul 13, 2023
5134264
Merge branch 'upmaster' into SPARK-42664
LuciferYang Aug 7, 2023
@@ -201,7 +201,7 @@ private static int optimalNumOfHashFunctions(long n, long m) {
* @param n expected insertions (must be positive)
* @param p false positive rate (must be 0 < p < 1)
*/
private static long optimalNumOfBits(long n, double p) {
static long optimalNumOfBits(long n, double p) {
return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
}

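The only change above is visibility: `optimalNumOfBits` goes from `private` to package-private so the Connect server (through the `BloomFilterHelper` bridge added below) can reuse the standard sizing formula m = -n * ln(p) / (ln 2)^2. A quick sanity-check sketch of that formula in Scala; the numbers are illustrative, chosen to match the test values, and are not part of the diff:

```scala
// Minimal sketch of the sizing formula; mirrors the (long) cast in the Java method above.
def optimalNumOfBits(n: Long, p: Double): Long =
  (-n * math.log(p) / (math.log(2) * math.log(2))).toLong

// For 1000 expected items at a 3% false-positive rate this yields roughly 7298 bits.
optimalNumOfBits(1000L, 0.03)
```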
@@ -18,14 +18,15 @@
package org.apache.spark.sql

import java.{lang => jl, util => ju}
import java.io.ByteArrayInputStream

import scala.collection.JavaConverters._

import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
import org.apache.spark.sql.functions.lit
import org.apache.spark.util.sketch.CountMinSketch
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

/**
* Statistic functions for `DataFrame`s.
@@ -584,6 +585,86 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
}
CountMinSketch.readFrom(ds.head())
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName
* name of the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param fpp
* expected false positive probability of the filter.
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col
* the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param fpp
* expected false positive probability of the filter.
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(col, expectedNumItems, -1L, fpp)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param colName
* name of the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param numBits
* expected number of bits of the filter.
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
}

/**
* Builds a Bloom filter over a specified column.
*
* @param col
* the column over which the filter is built
* @param expectedNumItems
* expected number of items which will be put into the filter.
* @param numBits
* expected number of bits of the filter.
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
}

private def buildBloomFilter(
col: Column,
expectedNumItems: Long,
numBits: Long,
fpp: Double): BloomFilter = {

val agg = if (!fpp.isNaN) {
[Review comment, @LuciferYang (Contributor Author), Apr 7, 2023] Before the "chanage to always pass 3 parameters" commit, all 4 parameters were always passed. Now the client passes (col, expectedNumItems, fpp) when !fpp.isNaN, and (col, expectedNumItems, numBits) otherwise.
Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(fpp))
} else {
Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits))
}

val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
builder.getProjectBuilder
.setInput(root)
.addExpressions(agg.expr)
}
BloomFilter.readFrom(new ByteArrayInputStream(ds.head()))
}
}

private object DataFrameStatFunctions {
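With these overloads in place, the Connect Scala client exposes the same `df.stat.bloomFilter` surface as the classic API. A minimal usage sketch, assuming an existing `SparkSession` named `spark`; the column name and sizing values are illustrative:

```scala
// Usage sketch for the new Connect client API (illustrative values).
import org.apache.spark.util.sketch.BloomFilter

val df = spark.range(1000).toDF("id")

// Size the filter by expected items plus false-positive probability ...
val bf1: BloomFilter = df.stat.bloomFilter("id", 1000L, 0.03) // expectedNumItems, fpp

// ... or by expected items plus an explicit number of bits.
val bf2: BloomFilter = df.stat.bloomFilter("id", 1000L, 320L) // expectedNumItems, numBits

assert(bf1.mightContainLong(42L))
```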
@@ -176,4 +176,81 @@ class ClientDataFrameStatSuite extends RemoteSparkSession {
assert(sketch.relativeError() === 0.001)
assert(sketch.confidence() === 0.99 +- 5e-3)
}

test("Bloom filter -- Long Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000).map(_.toLong)
val df = data.toDF("id")
checkBloomFilter(data, df)
}

test("Bloom filter -- Int Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000)
val df = data.toDF("id")
checkBloomFilter(data, df)
}

test("Bloom filter -- Short Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000).map(_.toShort)
val df = data.toDF("id")
checkBloomFilter(data, df)
}

test("Bloom filter -- Byte Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000).map(_.toByte)
val df = data.toDF("id")
checkBloomFilter(data, df)
}

test("Bloom filter -- String Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000).map(_.toString)
val df = data.toDF("id")
checkBloomFilter(data, df)
}

private def checkBloomFilter(data: Seq[Any], df: DataFrame) = {
val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
assert(filter1.expectedFpp() - 0.03 < 1e-3)
assert(data.forall(filter1.mightContain))
val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5)
assert(filter2.bitSize() == 64 * 5)
assert(data.forall(filter2.mightContain))
}

test("Bloom filter -- Wrong dataType Column") {
val session = spark
import session.implicits._
val data = Range(0, 1000).map(_.toDouble)
val message = intercept[SparkException] {
data.toDF("id").stat.bloomFilter("id", 1000, 0.03)
}.getMessage
assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE"))
}

test("Bloom filter test invalid inputs") {
val df = spark.range(1000).toDF("id")
val message1 = intercept[SparkException] {
df.stat.bloomFilter("id", -1000, 100)
}.getMessage
assert(message1.contains("Expected insertions must be positive"))

val message2 = intercept[SparkException] {
df.stat.bloomFilter("id", 1000, -100)
}.getMessage
assert(message2.contains("Number of bits must be positive"))

val message3 = intercept[SparkException] {
df.stat.bloomFilter("id", 1000, -1.0)
}.getMessage
assert(message3.contains("False positive probability must be within range (0.0, 1.0)"))
}
}
@@ -163,9 +163,6 @@ object CheckConnectJvmClientCompatibility {
// DataFrameNaFunctions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"),

// DataFrameStatFunctions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),

// Dataset
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.Dataset$" // private[sql]
@@ -49,6 +49,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
@@ -79,6 +80,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.CacheId
import org.apache.spark.util.Utils
import org.apache.spark.util.sketch.BloomFilterHelper

final case class InvalidCommandInput(
private val message: String = "",
@@ -1727,6 +1729,50 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
Some(Lead(children.head, children(1), children(2), ignoreNulls))

case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
// [col, expectedNumItems: Long, numBits: Long] or
// [col, expectedNumItems: Long, fpp: Double]
val children = fun.getArgumentsList.asScala.map(transformExpression)

// Check expectedNumItems > 0L
val expectedNumItemsExpr = children(1)
val expectedNumItems = expectedNumItemsExpr match {
case Literal(l: Long, LongType) => l
case _ =>
throw InvalidPlanInput("Expected insertions must be long literal.")
}
if (expectedNumItems <= 0L) {
throw InvalidPlanInput("Expected insertions must be positive.")
}

val numberBitsOrFpp = children(2)

val numBitsExpr = numberBitsOrFpp match {
case Literal(numBits: Long, LongType) =>
// Check numBits > 0L
if (numBits <= 0L) {
throw InvalidPlanInput("Number of bits must be positive.")
}
numberBitsOrFpp
case DoubleLiteral(fpp) =>
// Check fpp not NaN and in (0.0, 1.0).
if (fpp.isNaN || fpp <= 0d || fpp >= 1d) {
throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)")
}
// Calculate numBits through expectedNumItems and fpp,
// refer to `BloomFilter.optimalNumOfBits(long, double)`.
val numBits = BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp)
if (numBits <= 0L) {
throw InvalidPlanInput("Number of bits must be positive")
}
Literal(numBits, LongType)
case _ =>
throw InvalidPlanInput("The 3rd parameter must be double or long literal.")
}
Some(
new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
.toAggregateExpression())

case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val timeCol = children.head
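On the server side, `BloomFilterAggregate` only ever receives a bit count: when the client sent an fpp, the planner converts it with the same formula `BloomFilter` uses. A simplified standalone sketch of that resolution logic; this is not the planner code itself, `Either` merely stands in for the Long-versus-Double literal match above, and the error strings echo the checks shown there:

```scala
// Simplified sketch of the third-argument resolution performed by the planner above.
def resolveNumBits(expectedNumItems: Long, numBitsOrFpp: Either[Long, Double]): Long = {
  require(expectedNumItems > 0L, "Expected insertions must be positive.")
  numBitsOrFpp match {
    case Left(numBits) =>
      require(numBits > 0L, "Number of bits must be positive.")
      numBits
    case Right(fpp) =>
      require(!fpp.isNaN && fpp > 0d && fpp < 1d,
        "False positive probability must be within range (0.0, 1.0)")
      // Same conversion as BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp).
      (-expectedNumItems * math.log(fpp) / (math.log(2) * math.log(2))).toLong
  }
}
```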
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.sketch

/**
* `BloomFilterHelper` is used to bridge helper methods in `BloomFilter`.
*/
private[spark] object BloomFilterHelper {
def optimalNumOfBits(expectedNumItems: Long, fpp: Double): Long =
BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
}
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.BloomFilter

/**
@@ -78,7 +79,7 @@ case class BloomFilterAggregate(
"exprName" -> "estimatedNumItems or numBits"
)
)
case (LongType, LongType, LongType) =>
case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) =>
if (!estimatedNumItemsExpression.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
@@ -150,6 +151,15 @@ case class BloomFilterAggregate(
Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))

// Mark as lazy so that `updater` is not evaluated during tree transformation.
private lazy val updater: BloomFilterUpdater = first.dataType match {
case LongType => LongUpdater
case IntegerType => IntUpdater
case ShortType => ShortUpdater
case ByteType => ByteUpdater
case StringType => BinaryUpdater
}

override def first: Expression = child

override def second: Expression = estimatedNumItemsExpression
@@ -174,7 +184,7 @@ case class BloomFilterAggregate(
if (value == null) {
return buffer
}
buffer.putLong(value.asInstanceOf[Long])
updater.update(buffer, value)
buffer
}

@@ -224,3 +234,32 @@ object BloomFilterAggregate {
bloomFilter
}
}

private trait BloomFilterUpdater {
def update(bf: BloomFilter, v: Any): Boolean
}

private object LongUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Long])
}

private object IntUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Int])
}

private object ShortUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Short])
}

private object ByteUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putLong(v.asInstanceOf[Byte])
}

private object BinaryUpdater extends BloomFilterUpdater with Serializable {
override def update(bf: BloomFilter, v: Any): Boolean =
bf.putBinary(v.asInstanceOf[UTF8String].getBytes)
}