early exits
Donghan Zhang committed Feb 21, 2024
1 parent 9a82c32 commit e2873ce
Showing 3 changed files with 54 additions and 29 deletions.
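
In short: this commit threads the new selectedJoinParts option from the Driver CLI through Join into JoinBase, and changes computeJoin and computeRange to return Option[DataFrame]. When specific join parts are selected, only those join part tables are computed and the jobs exit early with None; the final join table is neither computed nor written, and callers now pattern-match on the optional result.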
31 changes: 19 additions & 12 deletions spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -233,8 +233,7 @@ object Driver {
       descr =
         "Start date to compute join backfill, this start date will override start partition in conf.")
     val selectedJoinParts: ScallopOption[List[String]] =
-      opt[List[String]](required = false,
-                        descr = "A list of join parts that require backfilling.")
+      opt[List[String]](required = false, descr = "A list of join parts that require backfilling.")
     lazy val joinConf: api.Join = parseConf[api.Join](confPath())
     override def subcommandName() = s"join_${joinConf.metaData.name}"
   }
@@ -250,18 +249,26 @@ object Driver {
     )
     val df = join.computeJoin(args.stepDays.toOption, args.startPartitionOverride.toOption)
 
-    if (args.shouldExport()) {
-      args.exportTableToLocal(args.joinConf.metaData.outputTable, tableUtils)
-    }
+    df match {
+      case None => {
+        logger.info("Selected join parts are populated. No final join is required.")
+        return
+      }
+      case Some(df) => {
+        if (args.shouldExport()) {
+          args.exportTableToLocal(args.joinConf.metaData.outputTable, tableUtils)
+        }
 
-    if (args.shouldPerformValidate()) {
-      val keys = CompareJob.getJoinKeys(args.joinConf, tableUtils)
-      args.validateResult(df, keys, tableUtils)
-    }
+        if (args.shouldPerformValidate()) {
+          val keys = CompareJob.getJoinKeys(args.joinConf, tableUtils)
+          args.validateResult(df, keys, tableUtils)
+        }
 
-    df.show(numRows = 3, truncate = 0, vertical = true)
-    logger.info(
-      s"\nShowing three rows of output above.\nQuery table `${args.joinConf.metaData.outputTable}` for more.\n")
+        df.show(numRows = 3, truncate = 0, vertical = true)
+        logger.info(
+          s"\nShowing three rows of output above.\nQuery table `${args.joinConf.metaData.outputTable}` for more.\n")
+      }
+    }
   }
 }
 
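Because computeJoin now returns Option[DataFrame], the driver pattern-matches and exits early on None. Below is a minimal, self-contained sketch of that calling pattern; Frame, run, and every other name here are illustrative stand-ins, not Chronon APIs.

// Sketch of the Option-based early exit: None means "only the selected join
// parts were computed, so there is no final output to export or validate".
object EarlyExitSketch {
  final case class Frame(rows: Seq[String]) // stand-in for a Spark DataFrame

  def computeJoin(selectedJoinParts: Option[List[String]]): Option[Frame] =
    if (selectedJoinParts.isDefined) None // early exit: skip the final join
    else Some(Frame(Seq("row-1", "row-2")))

  def run(selectedJoinParts: Option[List[String]]): Unit =
    computeJoin(selectedJoinParts) match {
      case None     => println("Selected join parts are populated. No final join is required.")
      case Some(df) => println(s"Final output rows: ${df.rows.mkString(", ")}")
    }

  def main(args: Array[String]): Unit = {
    run(Some(List("join_part_a"))) // early-exit path
    run(None)                      // full-join path
  }
}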
16 changes: 10 additions & 6 deletions spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -67,7 +67,7 @@ class Join(joinConf: api.Join,
            mutationScan: Boolean = true,
            showDf: Boolean = false,
            selectedJoinParts: Option[List[String]] = None)
-    extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, mutationScan, showDf) {
+    extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, mutationScan, showDf, selectedJoinParts) {
 
   private val bootstrapTable = joinConf.metaData.bootstrapTable
 
@@ -157,8 +157,7 @@ class Join(joinConf: api.Join,
 
     val coveringSetsPerJoinPart: Seq[(JoinPartMetadata, Seq[CoveringSet])] = bootstrapInfo.joinParts
       .filter(part => selectedJoinParts.isEmpty || selectedJoinParts.get.contains(part.joinPart.fullPrefix))
-      .map {
-        joinPartMetadata =>
+      .map { joinPartMetadata =>
         val coveringSets = distinctBootstrapSets.map {
           case (hashes, rowCount) =>
             val schema = hashes.toSet.flatMap(bootstrapInfo.hashToSchema.apply)
@@ -172,7 +171,7 @@ class Join(joinConf: api.Join,
             CoveringSet(hashes, rowCount, isCovering)
         }
         (joinPartMetadata, coveringSets)
-        }
+      }
 
     logger.info(
       s"\n======= CoveringSet for JoinPart ${joinConf.metaData.name} for PartitionRange(${leftRange.start}, ${leftRange.end}) =======\n")
@@ -188,7 +187,9 @@ class Join(joinConf: api.Join,
     coveringSetsPerJoinPart
   }
 
-  override def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame = {
+  override def computeRange(leftDf: DataFrame,
+                            leftRange: PartitionRange,
+                            bootstrapInfo: BootstrapInfo): Option[DataFrame] = {
     val leftTaggedDf = if (leftDf.schema.names.contains(Constants.TimeColumn)) {
       leftDf.withTimeBasedColumn(Constants.TimePartitionColumn)
     } else {
@@ -262,6 +263,9 @@ class Join(joinConf: api.Join,
     }
     val rightResults = Await.result(Future.sequence(rightResultsFuture), Duration.Inf).flatten
 
+    // early exit if selectedJoinParts is defined. Otherwise, we combine all join parts
+    if (selectedJoinParts.isDefined) return None
+
     // combine bootstrap table and join part tables
     // sequentially join bootstrap table and each join part table. some column may exist both on left and right because
     // a bootstrap source can cover a partial date range. we combine the columns using coalesce-rule
@@ -290,7 +294,7 @@ class Join(joinConf: api.Join,
                                          bootstrapInfo,
                                          leftDf.columns)
     finalDf.explain()
-    finalDf
+    Some(finalDf)
   }
 
   def applyDerivation(baseDf: DataFrame, bootstrapInfo: BootstrapInfo, leftColumns: Seq[String]): DataFrame = {
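
Two changes above do the selection work: the filter in coveringSetsPerJoinPart restricts computation to the selected join parts, and the guard added after the right-side tables are computed returns None before the final combine. A small sketch of the filter's semantics, with plain strings standing in for JoinPartMetadata (illustrative names only, not Chronon APIs):

// Sketch of the join-part selection filter: an empty selection keeps every
// part, a non-empty selection keeps only the listed prefixes.
object PartFilterSketch {
  def selectParts(allParts: Seq[String], selected: Option[List[String]]): Seq[String] =
    allParts.filter(part => selected.isEmpty || selected.get.contains(part))

  def main(args: Array[String]): Unit = {
    val parts = Seq("user_features", "item_features", "context_features")
    println(selectParts(parts, None))                        // keeps all three parts
    println(selectParts(parts, Some(List("item_features")))) // List(item_features)
  }
}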
36 changes: 25 additions & 11 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -38,7 +38,8 @@ abstract class JoinBase(joinConf: api.Join,
                         tableUtils: TableUtils,
                         skipFirstHole: Boolean,
                         mutationScan: Boolean = true,
-                        showDf: Boolean = false) {
+                        showDf: Boolean = false,
+                        selectedJoinParts: Option[Seq[String]] = None) {
   @transient lazy val logger = LoggerFactory.getLogger(getClass)
   assert(Option(joinConf.metaData.outputNamespace).nonEmpty, s"output namespace could not be empty or null")
   val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinOffline, joinConf)
@@ -286,9 +287,9 @@ abstract class JoinBase(joinConf: api.Join,
     Some(rightDfWithDerivations)
   }
 
-  def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): DataFrame
+  def computeRange(leftDf: DataFrame, leftRange: PartitionRange, bootstrapInfo: BootstrapInfo): Option[DataFrame]
 
-  def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): DataFrame = {
+  def computeJoin(stepDays: Option[Int] = None, overrideStartPartition: Option[String] = None): Option[DataFrame] = {
 
     assert(Option(joinConf.metaData.team).nonEmpty,
            s"join.metaData.team needs to be set for join ${joinConf.metaData.name}")
@@ -337,7 +338,7 @@ abstract class JoinBase(joinConf: api.Join,
     def finalResult: DataFrame = tableUtils.sql(rangeToFill.genScanQuery(null, outputTable))
     if (unfilledRanges.isEmpty) {
       logger.info(s"\nThere is no data to compute based on end partition of ${rangeToFill.end}.\n\n Exiting..")
-      return finalResult
+      return Some(finalResult)
     }
 
     stepDays.foreach(metrics.gauge("step_days", _))
@@ -358,14 +359,27 @@ abstract class JoinBase(joinConf: api.Join,
       leftDf(joinConf, range, tableUtils).map { leftDfInRange =>
         if (showDf) leftDfInRange.prettyPrint()
         // set autoExpand = true to ensure backward compatibility due to column ordering changes
-        computeRange(leftDfInRange, range, bootstrapInfo).save(outputTable, tableProps, autoExpand = true)
-        val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
-        metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
-        metrics.gauge(Metrics.Name.PartitionCount, range.partitions.length)
-        logger.info(s"Wrote to table $outputTable, into partitions: ${range.toString} $progress in $elapsedMins mins")
+        val finalDf = computeRange(leftDfInRange, range, bootstrapInfo)
+        if (selectedJoinParts.isDefined) {
+          assert(finalDf.isEmpty, "finalDf should be empty")
+          logger.info(s"Skipping final join for range: ${range.toString} $progress")
+        } else {
+          finalDf.get.save(outputTable, tableProps, autoExpand = true)
+          val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
+          metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
+          metrics.gauge(Metrics.Name.PartitionCount, range.partitions.length)
+          logger.info(
+            s"Wrote to table $outputTable, into partitions: ${range.toString} $progress in $elapsedMins mins")
+        }
       }
     }
-    logger.info(s"Wrote to table $outputTable, into partitions: $unfilledRanges")
-    finalResult
+    if (selectedJoinParts.isDefined) {
+      logger.info(s"Completed join parts: ${selectedJoinParts.get.mkString(", ")}")
+      logger.info(s"Skipping final join...")
+      None
+    } else {
+      logger.info(s"Wrote to table $outputTable, into partitions: $unfilledRanges")
+      Some(finalResult)
+    }
   }
 }

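The loop in computeJoin now encodes a contract: when selectedJoinParts is defined, computeRange must return None (enforced by the assert) and nothing is written; otherwise the range result is saved and the method ends with Some(finalResult). A compact sketch of that contract under the same stand-in assumptions as the earlier sketches (no Spark; save is a stub, not Chronon's DataFrame save):

// Sketch of the computeRange/computeJoin contract around selected join parts.
object RangeContractSketch {
  // Mirrors Join.computeRange: returns None iff join parts are selected.
  def computeRange(selectedJoinParts: Option[Seq[String]]): Option[String] =
    if (selectedJoinParts.isDefined) None else Some("final-df")

  def save(df: String): Unit = println(s"saved: $df") // stand-in for DataFrame.save

  def runRange(selectedJoinParts: Option[Seq[String]]): Unit = {
    val finalDf = computeRange(selectedJoinParts)
    if (selectedJoinParts.isDefined) {
      assert(finalDf.isEmpty, "finalDf should be empty")
      println("Skipping final join for this range")
    } else {
      finalDf.foreach(save)
    }
  }

  def main(args: Array[String]): Unit = {
    runRange(Some(Seq("join_part_a"))) // early-exit path: assert + skip write
    runRange(None)                     // full path: compute and save
  }
}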