Add BloomFilter conf
Sophie Wang committed Oct 2, 2023
1 parent 3e854df commit 8d0240a
Showing 6 changed files with 78 additions and 26 deletions.
2 changes: 1 addition & 1 deletion spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -132,7 +132,7 @@ object Extensions {
totalCount: Long,
tableName: String,
partitionRange: PartitionRange,
fpp: Double = 0.03): BloomFilter = {
fpp: Double = 0.01): BloomFilter = {
val approxCount =
df.filter(df.col(col).isNotNull).select(approx_count_distinct(col)).collect()(0).getLong(0)
if (approxCount == 0) {
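For orientation, here is a minimal standalone sketch of the sizing logic that generateBloomFilter wraps, using the tightened default fpp of 0.01. keyBloom is a hypothetical helper written for this note; only approx_count_distinct and df.stat.bloomFilter are stock Spark APIs, and Chronon's version above additionally tracks the table name and partition range.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.approx_count_distinct
import org.apache.spark.util.sketch.BloomFilter

// Hypothetical helper: size a bloom filter for one key column from its
// approximate distinct count, skipping generation when the column is all null.
def keyBloom(df: DataFrame, keyCol: String, fpp: Double = 0.01): Option[BloomFilter] = {
  val approxCount = df
    .filter(df.col(keyCol).isNotNull)
    .select(approx_count_distinct(keyCol))
    .collect()(0)
    .getLong(0)
  if (approxCount == 0) None
  else Some(df.stat.bloomFilter(keyCol, approxCount, fpp))
}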
39 changes: 20 additions & 19 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -173,33 +173,34 @@ abstract class JoinBase(joinConf: api.Join,
return None
}

// TODO: validate how the threshold should be used (left count?)
println(s"\nBackfill is required for ${joinPart.groupBy.metaData.name} for $rowCount rows on range $unfilledRange")

val leftBlooms = joinConf.leftKeyCols.toSeq.map { key =>
key -> leftDf.generateBloomFilter(key, rowCount, joinConf.left.table, unfilledRange)
}.toMap
val rightBloomMap = JoinUtils.genBloomFilterIfNeeded(leftDf, joinPart, joinConf, rowCount, unfilledRange, tableUtils)
// val leftBlooms = joinConf.leftKeyCols.toSeq.map { key =>
// key -> leftDf.generateBloomFilter(key, rowCount, joinConf.left.table, unfilledRange)
// }.toMap
//
// val rightBloomMap = joinPart.rightToLeft.mapValues(leftBlooms(_)).toMap
// val bloomSizes = rightBloomMap.map { case (col, bloom) => s"$col -> ${bloom.bitSize()}" }.pretty
// println(s"""
// |JoinPart Info:
// | part name : ${joinPart.groupBy.metaData.name},
// | left type : ${joinConf.left.dataModel},
// | right type: ${joinPart.groupBy.dataModel},
// | accuracy : ${joinPart.groupBy.inferredAccuracy},
// | part unfilled range: $unfilledRange,
// | left row count: $rowCount
// | bloom sizes: $bloomSizes
// | groupBy: ${joinPart.groupBy.toString}
// |""".stripMargin)

val rightSkewFilter = joinConf.partSkewFilter(joinPart)
val rightBloomMap = joinPart.rightToLeft.mapValues(leftBlooms(_)).toMap
val bloomSizes = rightBloomMap.map { case (col, bloom) => s"$col -> ${bloom.bitSize()}" }.pretty
println(s"""
|JoinPart Info:
| part name : ${joinPart.groupBy.metaData.name},
| left type : ${joinConf.left.dataModel},
| right type: ${joinPart.groupBy.dataModel},
| accuracy : ${joinPart.groupBy.inferredAccuracy},
| part unfilled range: $unfilledRange,
| left row count: $rowCount
| bloom sizes: $bloomSizes
| groupBy: ${joinPart.groupBy.toString}
|""".stripMargin)

def genGroupBy(partitionRange: PartitionRange) =
GroupBy.from(joinPart.groupBy,
partitionRange,
tableUtils,
computeDependency = true,
Option(rightBloomMap),
rightBloomMap,
rightSkewFilter,
mutationScan = mutationScan,
showDf = showDf)
48 changes: 47 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -1,14 +1,18 @@
package ai.chronon.spark

import ai.chronon.api.Constants
import ai.chronon.api.{Constants, JoinPart}
import ai.chronon.api.DataModel.Events
import ai.chronon.api.Extensions._
import ai.chronon.api.Extensions.JoinOps
import ai.chronon.spark.Extensions._
import com.google.gson.Gson
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{coalesce, col, udf}
import org.apache.spark.util.sketch.BloomFilter

import scala.collection.JavaConverters._
import scala.collection.Seq
import scala.util.ScalaJavaConversions.MapOps

object JoinUtils {
@@ -262,6 +266,48 @@ object JoinUtils {
labelMap.groupBy(_._2).map { case (v, kvs) => (v, tableUtils.chunk(kvs.keySet.toSet)) }
}

/**
 * Generate bloom filters for a joinPart's key columns if the left row count is under the
 * configured threshold.
 * @param leftDf left-side DataFrame for the unfilled range
 * @param joinPart join part whose right keys need bloom filters
 * @param joinConf join config providing the left key columns and left table
 * @param rowCount left row count for the unfilled range
 * @param unfilledRange partition range being backfilled
 * @param tableUtils source of the bloom filter threshold config
 * @return Some(right key column -> bloom filter) when generated, None when skipped
 */
def genBloomFilterIfNeeded(leftDf: DataFrame,
joinPart: ai.chronon.api.JoinPart,
joinConf: ai.chronon.api.Join,
rowCount: Long,
unfilledRange: PartitionRange,
tableUtils: TableUtils): Option[Map[String, BloomFilter]] = {
println(
  s"\nRow count to be filled for ${joinPart.groupBy.metaData.name}: $rowCount. BloomFilter threshold: ${tableUtils.bloomFilterThreshold}")

// apply bloom filter when row count is below threshold
if (rowCount > tableUtils.bloomFilterThreshold) {
println("Row count is above threshold. Skip gen bloom filter.")
Option.empty
} else {
val leftBlooms = joinConf.leftKeyCols.toSeq.map { key =>
key -> leftDf.generateBloomFilter(key, rowCount, joinConf.left.table, unfilledRange)
}.toMap

val rightBloomMap = joinPart.rightToLeft.mapValues(leftBlooms(_)).toMap
val bloomSizes = rightBloomMap.map { case (col, bloom) => s"$col -> ${bloom.bitSize()}" }.pretty
println(s"""
|JoinPart Info:
| part name : ${joinPart.groupBy.metaData.name},
| left type : ${joinConf.left.dataModel},
| right type: ${joinPart.groupBy.dataModel},
| accuracy : ${joinPart.groupBy.inferredAccuracy},
| part unfilled range: $unfilledRange,
| left row count: $rowCount
| bloom sizes: $bloomSizes
| groupBy: ${joinPart.groupBy.toString}
|""".stripMargin)
Some(rightBloomMap)
}
}

def filterColumns(df: DataFrame, filter: Seq[String]): DataFrame = {
val columnsToDrop = df.columns
.filterNot(col => filter.contains(col))
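genBloomFilterIfNeeded above returns Option[Map[String, BloomFilter]], which JoinBase now hands straight to GroupBy.from. As a rough illustration of how such a map can prune the right side, here is a hand-written sketch; it is not Chronon's actual GroupBy internals and assumes string-typed join keys.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf
import org.apache.spark.util.sketch.BloomFilter

// Sketch: drop right-side rows whose key definitely has no match on the left.
// Null keys are kept so that downstream join semantics stay unchanged.
def pruneWithBlooms(rightDf: DataFrame, blooms: Map[String, BloomFilter]): DataFrame =
  blooms.foldLeft(rightDf) { case (df, (keyCol, bloom)) =>
    val mightContain = udf((key: String) => key == null || bloom.mightContain(key))
    df.filter(mightContain(df.col(keyCol)))
  }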
spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala
@@ -10,7 +10,7 @@ import scala.util.Properties

object SparkSessionBuilder {

val DefaultWarehouseDir = new File("/tmp/chronon/spark-warehouse")
val DefaultWarehouseDir = new File("./spark-warehouse")

def expandUser(path: String): String = path.replaceFirst("~", System.getProperty("user.home"))
// we would want to share locally generated warehouse during CI testing
@@ -40,6 +40,7 @@ object SparkSessionBuilder {
.config("spark.sql.catalogImplementation", "hive")
.config("spark.hadoop.hive.exec.max.dynamic.partitions", 30000)
.config("spark.sql.legacy.timeParserPolicy", "LEGACY")
.config("spark.chronon.backfill.bloomfilter.threshold","100")

additionalConfig.foreach { configMap =>
configMap.foreach { config => baseBuilder = baseBuilder.config(config._1, config._2) }
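The builder now seeds spark.chronon.backfill.bloomfilter.threshold with a default of 100 rows; since additionalConfig is applied after these defaults, callers can still override it, as the JoinTest change below does. A sketch with an illustrative app name and threshold value:

val spark = SparkSessionBuilder.build(
  "BloomFilterDemo", // illustrative name, not from this commit
  local = true,
  additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "200000000")))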
3 changes: 3 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -24,6 +24,9 @@ case class TableUtils(sparkSession: SparkSession) {
sparkSession.conf.get("spark.chronon.partition.format", "yyyy-MM-dd")
val partitionSpec: PartitionSpec = PartitionSpec(partitionFormat, WindowUtils.Day.millis)
val backfillValidationEnforced = sparkSession.conf.get("spark.chronon.backfill.validation.enabled", "true").toBoolean
// Threshold controlling whether to use a bloom filter on join backfill: if the approximate left row
// count is under this threshold, we generate bloom filters; above it, we skip them.
// The default approximate count is chosen so that the optimal number of bits is at least ~1G at the default fpp of 0.01.
val bloomFilterThreshold = sparkSession.conf.get("spark.chronon.backfill.bloomfilter.threshold", "800000000").toLong

sparkSession.sparkContext.setLogLevel("ERROR")
// converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d")
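The 800M-row fallback ties back to the fpp change in Extensions.scala: a quick back-of-the-envelope check with the standard optimal-size formula for bloom filters (a sketch for this note, not code from the commit):

// Optimal bloom filter size: m = -n * ln(p) / (ln 2)^2 bits.
val n = 800000000L // default threshold: approximate distinct left keys
val p = 0.01       // default fpp after this commit
val bits = -n * math.log(p) / math.pow(math.log(2), 2)
// bits ≈ 7.7e9, i.e. roughly 1 GB of filter per key column at the threshold
println(f"optimal bits: $bits%.3g (~${bits / 8 / 1e9}%.2f GB)")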
9 changes: 5 additions & 4 deletions spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
@@ -20,7 +20,10 @@ import scala.util.ScalaJavaConversions.ListOps

class JoinTest {

val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true)
val spark: SparkSession = SparkSessionBuilder.build(
"JoinTest",
local = true,
additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
private val tableUtils = TableUtils(spark)

private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
@@ -368,9 +371,7 @@ class JoinTest {
DataFrameGen.entities(spark, weightSchema, 1000, partitions = 400).save(weightTable)

val weightSource = Builders.Source.entities(
query = Builders.Query(selects = Builders.Selects("weight"),
startPartition = yearAgo,
endPartition = today),
query = Builders.Query(selects = Builders.Selects("weight"), startPartition = yearAgo, endPartition = today),
snapshotTable = weightTable
)

