From 8d0240a3b49fdf76b8fe3bc3a2fac43d0d402566 Mon Sep 17 00:00:00 2001
From: Sophie Wang
Date: Mon, 2 Oct 2023 15:44:59 -0700
Subject: [PATCH] Add BloomFilter conf

---
 .../scala/ai/chronon/spark/Extensions.scala   |  2 +-
 .../scala/ai/chronon/spark/JoinBase.scala     | 39 +++++++--------
 .../scala/ai/chronon/spark/JoinUtils.scala    | 48 ++++++++++++++++++-
 .../chronon/spark/SparkSessionBuilder.scala   |  3 +-
 .../scala/ai/chronon/spark/TableUtils.scala   |  3 ++
 .../ai/chronon/spark/test/JoinTest.scala      |  9 ++--
 6 files changed, 78 insertions(+), 26 deletions(-)

diff --git a/spark/src/main/scala/ai/chronon/spark/Extensions.scala b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
index 3a25039778..751379997e 100644
--- a/spark/src/main/scala/ai/chronon/spark/Extensions.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -132,7 +132,7 @@ object Extensions {
                              totalCount: Long,
                              tableName: String,
                              partitionRange: PartitionRange,
-                             fpp: Double = 0.03): BloomFilter = {
+                             fpp: Double = 0.01): BloomFilter = {
       val approxCount =
         df.filter(df.col(col).isNotNull).select(approx_count_distinct(col)).collect()(0).getLong(0)
       if (approxCount == 0) {
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
index 3e0fd0fac9..7894f595b8 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -173,33 +173,34 @@ abstract class JoinBase(joinConf: api.Join,
       return None
     }
 
+    //todo validate how threshold should be used. left count?
     println(s"\nBackfill is required for ${joinPart.groupBy.metaData.name} for $rowCount rows on range $unfilledRange")
-
-    val leftBlooms = joinConf.leftKeyCols.toSeq.map { key =>
-      key -> leftDf.generateBloomFilter(key, rowCount, joinConf.left.table, unfilledRange)
-    }.toMap
+    val rightBloomMap = JoinUtils.genBloomFilterIfNeeded(leftDf, joinPart, joinConf, rowCount, unfilledRange, tableUtils)
+//    val leftBlooms = joinConf.leftKeyCols.toSeq.map { key =>
+//      key -> leftDf.generateBloomFilter(key, rowCount, joinConf.left.table, unfilledRange)
+//    }.toMap
+//
+//    val rightBloomMap = joinPart.rightToLeft.mapValues(leftBlooms(_)).toMap
+//    val bloomSizes = rightBloomMap.map { case (col, bloom) => s"$col -> ${bloom.bitSize()}" }.pretty
+//    println(s"""
+//               |JoinPart Info:
+//               | part name : ${joinPart.groupBy.metaData.name},
+//               | left type : ${joinConf.left.dataModel},
+//               | right type: ${joinPart.groupBy.dataModel},
+//               | accuracy : ${joinPart.groupBy.inferredAccuracy},
+//               | part unfilled range: $unfilledRange,
+//               | left row count: $rowCount
+//               | bloom sizes: $bloomSizes
+//               | groupBy: ${joinPart.groupBy.toString}
+//               |""".stripMargin)
 
     val rightSkewFilter = joinConf.partSkewFilter(joinPart)
-    val rightBloomMap = joinPart.rightToLeft.mapValues(leftBlooms(_)).toMap
-    val bloomSizes = rightBloomMap.map { case (col, bloom) => s"$col -> ${bloom.bitSize()}" }.pretty
-    println(s"""
-               |JoinPart Info:
-               | part name : ${joinPart.groupBy.metaData.name},
-               | left type : ${joinConf.left.dataModel},
-               | right type: ${joinPart.groupBy.dataModel},
-               | accuracy : ${joinPart.groupBy.inferredAccuracy},
-               | part unfilled range: $unfilledRange,
-               | left row count: $rowCount
-               | bloom sizes: $bloomSizes
-               | groupBy: ${joinPart.groupBy.toString}
-               |""".stripMargin)
-
     def genGroupBy(partitionRange: PartitionRange) =
       GroupBy.from(joinPart.groupBy,
                    partitionRange,
                    tableUtils,
                    computeDependency = true,
-                   Option(rightBloomMap),
+                   rightBloomMap,
                    rightSkewFilter,
                    mutationScan = mutationScan,
                   showDf = showDf)
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
index 29d7ae77fa..b97fcd60b1 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -1,14 +1,18 @@
 package ai.chronon.spark
 
-import ai.chronon.api.Constants
+import ai.chronon.api.{Constants, JoinPart}
 import ai.chronon.api.DataModel.Events
 import ai.chronon.api.Extensions._
+import ai.chronon.api.Extensions.JoinOps
 import ai.chronon.spark.Extensions._
 import com.google.gson.Gson
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{coalesce, col, udf}
+import org.apache.spark.util.sketch.BloomFilter
 
+import scala.collection.JavaConverters._
+import scala.collection.Seq
 import scala.util.ScalaJavaConversions.MapOps
 
 object JoinUtils {
@@ -262,6 +266,48 @@
     labelMap.groupBy(_._2).map { case (v, kvs) => (v, tableUtils.chunk(kvs.keySet.toSet)) }
   }
 
+  /**
+    * Generate a Bloom filter per right key column of the given joinPart, but only when the
+    * left row count for the unfilled range is at or below tableUtils.bloomFilterThreshold.
+    * @return Some(map of right key column -> BloomFilter) when the filter is generated,
+    *         None when the row count exceeds the threshold
+    */
+
+  def genBloomFilterIfNeeded(leftDf: DataFrame,
+                             joinPart: ai.chronon.api.JoinPart,
+                             joinConf: ai.chronon.api.Join,
+                             rowCount: Long,
+                             unfilledRange: PartitionRange,
+                             tableUtils: TableUtils): Option[Map[String, BloomFilter]] = {
+    println(
+      s"\nRow count to be filled for ${joinPart.groupBy.metaData.name}: $rowCount. BloomFilter threshold: ${tableUtils.bloomFilterThreshold}")
+
+    // apply bloom filter only when the left row count is at or below the threshold
+    if (rowCount > tableUtils.bloomFilterThreshold) {
+      println("Row count is above threshold. Skipping Bloom filter generation.")
+      Option.empty
+    } else {
+      val leftBlooms = joinConf.leftKeyCols.toSeq.map { key =>
+        key -> leftDf.generateBloomFilter(key, rowCount, joinConf.left.table, unfilledRange)
+      }.toMap
+
+      val rightBloomMap = joinPart.rightToLeft.mapValues(leftBlooms(_)).toMap
+      val bloomSizes = rightBloomMap.map { case (col, bloom) => s"$col -> ${bloom.bitSize()}" }.pretty
+      println(s"""
+                 |JoinPart Info:
+                 | part name : ${joinPart.groupBy.metaData.name},
+                 | left type : ${joinConf.left.dataModel},
+                 | right type: ${joinPart.groupBy.dataModel},
+                 | accuracy : ${joinPart.groupBy.inferredAccuracy},
+                 | part unfilled range: $unfilledRange,
+                 | left row count: $rowCount
+                 | bloom sizes: $bloomSizes
+                 | groupBy: ${joinPart.groupBy.toString}
+                 |""".stripMargin)
+      Some(rightBloomMap)
+    }
+  }
+
   def filterColumns(df: DataFrame, filter: Seq[String]): DataFrame = {
     val columnsToDrop = df.columns
       .filterNot(col => filter.contains(col))
diff --git a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala
index 946d60d2a3..2ac7a1beb2 100644
--- a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala
+++ b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala
@@ -10,7 +10,7 @@
 import scala.util.Properties
 
 object SparkSessionBuilder {
 
-  val DefaultWarehouseDir = new File("/tmp/chronon/spark-warehouse")
+  val DefaultWarehouseDir = new File("./spark-warehouse")
   def expandUser(path: String): String = path.replaceFirst("~", System.getProperty("user.home"))
   // we would want to share locally generated warehouse during CI testing
@@ -40,6 +40,7 @@ object SparkSessionBuilder {
       .config("spark.sql.catalogImplementation", "hive")
      .config("spark.hadoop.hive.exec.max.dynamic.partitions", 30000)
       .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
"LEGACY") + .config("spark.chronon.backfill.bloomfilter.threshold","100") additionalConfig.foreach { configMap => configMap.foreach { config => baseBuilder = baseBuilder.config(config._1, config._2) } diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index a388aae1e9..232e59f75b 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -24,6 +24,9 @@ case class TableUtils(sparkSession: SparkSession) { sparkSession.conf.get("spark.chronon.partition.format", "yyyy-MM-dd") val partitionSpec: PartitionSpec = PartitionSpec(partitionFormat, WindowUtils.Day.millis) val backfillValidationEnforced = sparkSession.conf.get("spark.chronon.backfill.validation.enabled", "true").toBoolean + // Threshold to control whether or not to use bloomfilter on join backfill. If the row approximate count is under this threshold, we will use bloomfilter. + // We are choosing approximate count so that optimal number of bits is at-least 1G for default fpp of 0.01 + val bloomFilterThreshold = sparkSession.conf.get("spark.chronon.backfill.bloomfilter.threshold", "800000000").toLong sparkSession.sparkContext.setLogLevel("ERROR") // converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d") diff --git a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala index 723eac282d..c5ad6f7f03 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala @@ -20,7 +20,10 @@ import scala.util.ScalaJavaConversions.ListOps class JoinTest { - val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true) + val spark: SparkSession = SparkSessionBuilder.build( + "JoinTest", + local = true, + additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100"))) private val tableUtils = TableUtils(spark) private val today = tableUtils.partitionSpec.at(System.currentTimeMillis()) @@ -368,9 +371,7 @@ class JoinTest { DataFrameGen.entities(spark, weightSchema, 1000, partitions = 400).save(weightTable) val weightSource = Builders.Source.entities( - query = Builders.Query(selects = Builders.Selects("weight"), - startPartition = yearAgo, - endPartition = today), + query = Builders.Query(selects = Builders.Selects("weight"), startPartition = yearAgo, endPartition = today), snapshotTable = weightTable )