diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 5c01841e5015a..e05f1ac3a50ae 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -201,7 +201,7 @@ private static int optimalNumOfHashFunctions(long n, long m) { * @param n expected insertions (must be positive) * @param p false positive rate (must be 0 < p < 1) */ - private static long optimalNumOfBits(long n, double p) { + static long optimalNumOfBits(long n, double p) { return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 0d4372b8738ee..f12e7d8fe9425 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.{lang => jl, util => ju} +import java.io.ByteArrayInputStream import scala.collection.JavaConverters._ @@ -25,7 +26,7 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder} import org.apache.spark.sql.functions.lit -import org.apache.spark.util.sketch.CountMinSketch +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. @@ -584,6 +585,86 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo } CountMinSketch.readFrom(ds.head()) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. + * @since 3.5.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. + * @since 3.5.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. + * @since 3.5.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN) + } + + /** + * Builds a Bloom filter over a specified column. 
+ * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. + * @since 3.5.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, expectedNumItems, numBits, Double.NaN) + } + + private def buildBloomFilter( + col: Column, + expectedNumItems: Long, + numBits: Long, + fpp: Double): BloomFilter = { + + val agg = if (!fpp.isNaN) { + Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(fpp)) + } else { + Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) + } + + val ds = sparkSession.newDataset(BinaryEncoder) { builder => + builder.getProjectBuilder + .setInput(root) + .addExpressions(agg.expr) + } + BloomFilter.readFrom(new ByteArrayInputStream(ds.head())) + } } private object DataFrameStatFunctions { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala index 62ff21332f105..b2686c9dcecd4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala @@ -176,4 +176,81 @@ class ClientDataFrameStatSuite extends RemoteSparkSession { assert(sketch.relativeError() === 0.001) assert(sketch.confidence() === 0.99 +- 5e-3) } + + test("Bloom filter -- Long Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toLong) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + test("Bloom filter -- Int Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + test("Bloom filter -- Short Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toShort) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + test("Bloom filter -- Byte Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toByte) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + test("Bloom filter -- String Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toString) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + private def checkBloomFilter(data: Seq[Any], df: DataFrame) = { + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(data.forall(filter1.mightContain)) + val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter2.bitSize() == 64 * 5) + assert(data.forall(filter2.mightContain)) + } + + test("Bloom filter -- Wrong dataType Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toDouble) + val message = intercept[SparkException] { + data.toDF("id").stat.bloomFilter("id", 1000, 0.03) + }.getMessage + assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE")) + } + + test("Bloom filter test invalid inputs") { + val df = spark.range(1000).toDF("id") + val message1 = intercept[SparkException] { + df.stat.bloomFilter("id", -1000, 100) + }.getMessage + assert(message1.contains("Expected insertions must be positive")) + + val message2 = intercept[SparkException] { + 
df.stat.bloomFilter("id", 1000, -100) + }.getMessage + assert(message2.contains("Number of bits must be positive")) + + val message3 = intercept[SparkException] { + df.stat.bloomFilter("id", 1000, -1.0) + }.getMessage + assert(message3.contains("False positive probability must be within range (0.0, 1.0)")) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 6e577e0f21257..96831e25b6440 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -163,9 +163,6 @@ object CheckConnectJvmClientCompatibility { // DataFrameNaFunctions ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"), - // DataFrameStatFunctions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"), - // Dataset ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.Dataset$" // private[sql] diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7136476b515f9..437bba05f51f0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical @@ -79,6 +80,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.CacheId import org.apache.spark.util.Utils +import org.apache.spark.util.sketch.BloomFilterHelper final case class InvalidCommandInput( private val message: String = "", @@ -1727,6 +1729,50 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { val ignoreNulls = extractBoolean(children(3), "ignoreNulls") Some(Lead(children.head, children(1), children(2), ignoreNulls)) + case "bloom_filter_agg" if fun.getArgumentsCount == 3 => + // [col, expectedNumItems: Long, numBits: Long] or + // [col, expectedNumItems: Long, fpp: Double] + val children = fun.getArgumentsList.asScala.map(transformExpression) + + // Check expectedNumItems > 0L + val expectedNumItemsExpr = children(1) + val expectedNumItems = expectedNumItemsExpr match { + case Literal(l: Long, LongType) => l + case _ => + throw InvalidPlanInput("Expected insertions must be long literal.") + } + if (expectedNumItems <= 0L) { + throw InvalidPlanInput("Expected insertions must be 
positive.") + } + + val numberBitsOrFpp = children(2) + + val numBitsExpr = numberBitsOrFpp match { + case Literal(numBits: Long, LongType) => + // Check numBits > 0L + if (numBits <= 0L) { + throw InvalidPlanInput("Number of bits must be positive.") + } + numberBitsOrFpp + case DoubleLiteral(fpp) => + // Check fpp not NaN and in (0.0, 1.0). + if (fpp.isNaN || fpp <= 0d || fpp >= 1d) { + throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)") + } + // Calculate numBits through expectedNumItems and fpp, + // refer to `BloomFilter.optimalNumOfBits(long, double)`. + val numBits = BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp) + if (numBits <= 0L) { + throw InvalidPlanInput("Number of bits must be positive") + } + Literal(numBits, LongType) + case _ => + throw InvalidPlanInput("The 3rd parameter must be double or long literal.") + } + Some( + new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr) + .toAggregateExpression()) + case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) => val children = fun.getArgumentsList.asScala.map(transformExpression) val timeCol = children.head diff --git a/connector/connect/server/src/main/scala/org/apache/spark/util/sketch/BloomFilterHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/util/sketch/BloomFilterHelper.scala new file mode 100644 index 0000000000000..bbb0ee3c2f1a4 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/util/sketch/BloomFilterHelper.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +package org.apache.spark.util.sketch + +/** + * `BloomFilterHelper` is used to bridge helper methods in `BloomFilter`. + */ +private[spark] object BloomFilterHelper { + def optimalNumOfBits(expectedNumItems: Long, fpp: Double): Long = + BloomFilter.optimalNumOfBits(expectedNumItems, fpp) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 980785e764cdb..2325b2b055f96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.BloomFilter /** @@ -78,7 +79,7 @@ case class BloomFilterAggregate( "exprName" -> "estimatedNumItems or numBits" ) ) - case (LongType, LongType, LongType) => + case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) => if (!estimatedNumItemsExpression.foldable) { DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -150,6 +151,15 @@ case class BloomFilterAggregate( Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) + // Mark as lazy so that `updater` is not evaluated during tree transformation. + private lazy val updater: BloomFilterUpdater = first.dataType match { + case LongType => LongUpdater + case IntegerType => IntUpdater + case ShortType => ShortUpdater + case ByteType => ByteUpdater + case StringType => BinaryUpdater + } + override def first: Expression = child override def second: Expression = estimatedNumItemsExpression @@ -174,7 +184,7 @@ case class BloomFilterAggregate( if (value == null) { return buffer } - buffer.putLong(value.asInstanceOf[Long]) + updater.update(buffer, value) buffer } @@ -224,3 +234,32 @@ object BloomFilterAggregate { bloomFilter } } + +private trait BloomFilterUpdater { + def update(bf: BloomFilter, v: Any): Boolean +} + +private object LongUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Long]) +} + +private object IntUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Int]) +} + +private object ShortUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Short]) +} + +private object ByteUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Byte]) +} + +private object BinaryUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putBinary(v.asInstanceOf[UTF8String].getBytes) +}
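
A minimal usage sketch of the client API added by this patch, not part of the diff itself: it assumes an already-connected Spark Connect SparkSession named `spark`, and the data and sizing values are illustrative only.

// Usage sketch (illustrative, assumes `spark` is a connected Spark Connect session).
import org.apache.spark.sql.DataFrame
import org.apache.spark.util.sketch.BloomFilter

val df: DataFrame = spark.range(0, 1000).toDF("id")

// Size the filter from expected insertions and a false positive probability.
// Internally numBits is derived via BloomFilter.optimalNumOfBits(n, p) = -n * ln(p) / ln(2)^2.
val byFpp: BloomFilter = df.stat.bloomFilter("id", 1000, 0.03)

// Or size it from expected insertions and an explicit number of bits
// (a Long third argument selects the numBits overload).
val byBits: BloomFilter = df.stat.bloomFilter("id", 1000, 64L * 5)

assert(byFpp.mightContain(42L)) // Bloom filters have no false negatives, so inserted ids are reported as possibly present
assert(byBits.bitSize() == 320)

With the fpp form, the server-side planner validates the literals and converts fpp to numBits (via BloomFilterHelper) before constructing the BloomFilterAggregate, as shown in the SparkConnectPlanner hunk above.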