Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Only trigger Comet Final aggregation on Comet partial aggregation #264

Merged
merged 1 commit into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
Expand Down Expand Up @@ -319,26 +320,42 @@ class CometSparkSessionExtensions
}

case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val modes = aggExprs.map(_.mode).distinct
// The aggExprs could be empty. For example, if the aggregate functions only have
// distinct aggregate functions or only have group by, the aggExprs is empty and
// modes is empty too. If aggExprs is not empty, we need to verify all the aggregates
// have the same mode.
assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
nativeOp,
op,
groupingExprs,
aggExprs,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
SerializedPlan(None))
case None =>
val modes = aggExprs.map(_.mode).distinct

if (!modes.isEmpty && modes.size != 1) {
// This shouldn't happen, as all aggregate expressions should share the same mode.
// Nevertheless, fall back to Spark here.
op
} else {
val sparkFinalMode = {
!modes.isEmpty && modes.head == Final && findPartialAgg(child).isEmpty
}

if (sparkFinalMode) {
op
} else {
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val modes = aggExprs.map(_.mode).distinct
// The aggExprs may be empty. For example, if the query has only distinct
// aggregate functions, or only a group by, then aggExprs is empty and
// modes is empty too. If aggExprs is not empty, we need to verify that all
// the aggregates have the same mode.
assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
nativeOp,
op,
groupingExprs,
aggExprs,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
SerializedPlan(None))
case None =>
op
}
}
}

case op: ShuffledHashJoinExec
Expand Down Expand Up @@ -596,6 +613,20 @@ class CometSparkSessionExtensions
}
}
}

/**
 * Searches `plan` for a Comet aggregate whose aggregate expressions are all in `Partial` mode.
 *
 * The search stops at the first node that matches one of the cases below:
 *   - a `CometHashAggregateExec` with only `Partial`-mode expressions: that node is returned;
 *   - a Spark `HashAggregateExec` with only `Partial`-mode expressions: the result is `None`;
 *   - an `AQEShuffleReadExec` or `ShuffleQueryStageExec`: the search restarts on the wrapped
 *     plan (`child` / `plan` respectively) and that result is returned.
 *
 * If no node matches any case, the result is `None`.
 *
 * Note: an aggregate with no aggregate expressions at all trivially satisfies the
 * "only `Partial`-mode expressions" condition.
 */
def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
  // Outer Option: did any node match? Inner Option: what that node resolved to.
  val firstMatch: Option[Option[CometHashAggregateExec]] = plan.collectFirst {
    case cometAgg: CometHashAggregateExec
        if !cometAgg.aggregateExpressions.exists(_.mode != Partial) =>
      Some(cometAgg)
    case sparkAgg: HashAggregateExec
        if !sparkAgg.aggregateExpressions.exists(_.mode != Partial) =>
      // A Spark partial aggregate was reached first: report "no Comet partial aggregate".
      None
    case shuffleRead: AQEShuffleReadExec =>
      findPartialAgg(shuffleRead.child)
    case queryStage: ShuffleQueryStageExec =>
      // Presumably the stage's plan is not reachable as an ordinary child, hence the
      // explicit recursion — TODO confirm.
      findPartialAgg(queryStage.plan)
  }
  firstMatch.flatMap(identity)
}
}

// This rule is responsible for eliminating redundant transitions between row-based and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

test("Only trigger Comet Final aggregation on Comet partial aggregation") {
  // Configuration applied while the query runs: Comet itself, Comet shuffle execution,
  // and Comet columnar shuffle are all switched on.
  val cometSettings = Seq(
    CometConf.COMET_ENABLED.key -> "true",
    CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
    CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true")

  withTempView("lowerCaseData") {
    lowerCaseData.createOrReplaceTempView("lowerCaseData")

    withSQLConf(cometSettings: _*) {
      // An aggregate query over the temp view; the result is handed to checkSparkAnswer
      // (presumably compared against plain Spark's answer — see CometTestBase).
      val lastOfN = sql("SELECT LAST(n) FROM lowerCaseData")
      checkSparkAnswer(lastOfN)
    }
  }
}

test(
"Average expression in Comet Final should handle " +
"all null inputs from partial Spark aggregation") {
Expand Down
Loading