[GLUTEN-5016][CH] Fix exchange fallback in simple aggregation sql if spark.gluten.sql.columnar.preferColumnar=false #5042

Merged · 1 commit · Mar 22, 2024
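What the diff shows: `ShuffleExchangeExec` used to be rewritten inside `ReplaceSingleNode`, which `RegularTransformRule` applied before `AggregationTransformRule` ran. For a plan shaped like `ShuffleExchangeExec(HashAggregateExec(...))`, the exchange was therefore visited while its child was still the vanilla row-based aggregate, so with `spark.gluten.sql.columnar.preferColumnar=false` the `child.supportsColumnar || enablePreferColumnar` check failed and the exchange fell back to vanilla Spark even for a simple aggregation query. The fix extracts dedicated `ExchangeTransformRule` and `JoinTransformRule` rules and sequences them after `AggregationTransformRule` in a bottom-up rule list, so the exchange decision sees the already-transformed columnar aggregate.

A minimal reproduction mirroring the new test; the snippet is illustrative (not part of the PR) and assumes a spark-shell session with a TPC-H `lineitem` table registered:

```scala
spark.conf.set("spark.gluten.sql.columnar.preferColumnar", "false")
// Before this fix, the shuffle exchange feeding the final aggregation fell back
// to a row-based ShuffleExchangeExec; after it, the exchange is offloaded along
// with the aggregation.
spark
  .sql("SELECT sum(l_quantity) AS sum_qty FROM lineitem WHERE l_shipdate <= date'1998-09-02'")
  .collect()
```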
GlutenClickHouseTPCHNullableSuite.scala
@@ -195,4 +195,19 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuite
|""".stripMargin) { _ => }
assert(result(0).getLong(0) == 227302L)
}

test("test 'GLUTEN-5016'") {
withSQLConf(("spark.gluten.sql.columnar.preferColumnar", "false")) {
val sql =
"""
|SELECT
| sum(l_quantity) AS sum_qty
|FROM
| lineitem
|WHERE
| l_shipdate <= date'1998-09-02'
|""".stripMargin
runSql(sql, noFallBack = true) { _ => }
}
}
}
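The `runSql(sql, noFallBack = true)` helper asserts that the query runs without any operator falling back to vanilla Spark. Outside the test harness, the effect can be checked with plain Spark APIs; a sketch (the exact name of the columnar shuffle node in the printed plan depends on the backend):

```scala
// Print the physical plan and look for a row-based ShuffleExchangeExec versus
// a Gluten columnar shuffle node. Assumes the same session/config as above.
val df = spark.sql(
  "SELECT sum(l_quantity) AS sum_qty FROM lineitem WHERE l_shipdate <= date'1998-09-02'")
println(df.queryExecution.executedPlan.treeString)
```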
MiscColumnarRules.scala
@@ -49,9 +49,12 @@ object MiscColumnarRules

// Aggregation transformation.
private case class AggregationTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
-   override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+   override def apply(plan: SparkPlan): SparkPlan = plan match {
case plan if TransformHints.isNotTransformable(plan) =>
plan
case agg: HashAggregateExec =>
genHashAggregateExec(agg)
case other => other
}

/**
@@ -105,13 +108,144 @@ object MiscColumnarRules {
}
}

// Exchange transformation.
private case class ExchangeTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
override def apply(plan: SparkPlan): SparkPlan = plan match {
case plan if TransformHints.isNotTransformable(plan) =>
plan
case plan: ShuffleExchangeExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
val child = plan.child
if (
(child.supportsColumnar || GlutenConfig.getConf.enablePreferColumnar) &&
BackendsApiManager.getSettings.supportColumnarShuffleExec()
) {
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
} else {
plan.withNewChildren(Seq(child))
}
case plan: BroadcastExchangeExec =>
val child = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarBroadcastExchangeExec(plan.mode, child)
case other => other
}
}
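This new `ExchangeTransformRule` is the exchange-handling code lifted out of `ReplaceSingleNode` essentially unchanged (the matching deletion appears in a later hunk). Isolating it is the heart of the fix: `TransformPreOverrides` can now run it after `AggregationTransformRule`, so by the time `child.supportsColumnar` is inspected, an aggregate child has already been replaced by its columnar transformer.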

// Join transformation.
private case class JoinTransformRule() extends Rule[SparkPlan] with LogLevelUtil {

/**
* Get the build side supported by the execution of vanilla Spark.
*
* @param plan
* : shuffled hash join plan
* @return
* the supported build side
*/
private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
plan.joinType match {
case LeftOuter | LeftSemi => BuildRight
case RightOuter => BuildLeft
case _ => plan.buildSide
}
}

override def apply(plan: SparkPlan): SparkPlan = {
if (TransformHints.isNotTransformable(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
plan match {
case shj: ShuffledHashJoinExec =>
if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
// Because we manually removed the build side limitation for LeftOuter, LeftSemi and
// RightOuter, we need to change the build side back if this join falls back to vanilla
// Spark for execution.
return ShuffledHashJoinExec(
shj.leftKeys,
shj.rightKeys,
shj.joinType,
getSparkSupportedBuildSide(shj),
shj.condition,
shj.left,
shj.right,
shj.isSkewJoin
)
} else {
return shj
}
case p =>
return p
}
}
plan match {
case plan: ShuffledHashJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
BackendsApiManager.getSparkPlanExecApiInstance
.genShuffledHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: SortMergeJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
SortMergeJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: BroadcastHashJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
case plan: CartesianProductExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genCartesianProductExecTransformer(left, right, plan.condition)
case plan: BroadcastNestedLoopJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastNestedLoopJoinExecTransformer(
left,
right,
plan.buildSide,
plan.joinType,
plan.condition)
case other => other
}
}

}
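`JoinTransformRule` likewise moves the join handling, together with the `getSparkSupportedBuildSide` helper used when a shuffled hash join falls back, out of `ReplaceSingleNode` verbatim; the corresponding deletions appear in the hunks below. The join behavior is unchanged; the extraction just lets each node kind be sequenced explicitly.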

// Filter transformation.
private case class FilterTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
private val replace = new ReplaceSingleNode()

-   override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
+   override def apply(plan: SparkPlan): SparkPlan = plan match {
case filter: FilterExec =>
genFilterExec(filter)
case other => other
}

/**
@@ -155,39 +289,18 @@ object MiscColumnarRules {
private case class RegularTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
private val replace = new ReplaceSingleNode()

-   override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-     case plan => replace.replaceWithTransformerPlan(plan)
-   }
+   override def apply(plan: SparkPlan): SparkPlan = replace.replaceWithTransformerPlan(plan)
}

// Utility to replace single node within transformed Gluten node.
// Children will be preserved as they are as children of the output node.
class ReplaceSingleNode() extends LogLevelUtil with Logging {
private val columnarConf: GlutenConfig = GlutenConfig.getConf

def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = {
val plan = p
if (TransformHints.isNotTransformable(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
plan match {
case shj: ShuffledHashJoinExec =>
if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
// Because we manually removed the build side limitation for LeftOuter, LeftSemi and
// RightOuter, we need to change the build side back if this join falls back to vanilla
// Spark for execution.
return ShuffledHashJoinExec(
shj.leftKeys,
shj.rightKeys,
shj.joinType,
getSparkSupportedBuildSide(shj),
shj.condition,
shj.left,
shj.right,
shj.isSkewJoin
)
} else {
return shj
}
case plan: BatchScanExec =>
return applyScanNotTransformable(plan)
case plan: FileSourceScanExec =>
@@ -283,75 +396,6 @@ object MiscColumnarRules {
plan.projectList,
child,
offset)
case plan: ShuffleExchangeExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
val child = plan.child
if (
(child.supportsColumnar || columnarConf.enablePreferColumnar) &&
BackendsApiManager.getSettings.supportColumnarShuffleExec()
) {
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
} else {
plan.withNewChildren(Seq(child))
}
case plan: ShuffledHashJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
BackendsApiManager.getSparkPlanExecApiInstance
.genShuffledHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: SortMergeJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
SortMergeJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: BroadcastExchangeExec =>
val child = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarBroadcastExchangeExec(plan.mode, child)
case plan: BroadcastHashJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
case plan: CartesianProductExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genCartesianProductExecTransformer(left, right, plan.condition)
case plan: BroadcastNestedLoopJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastNestedLoopJoinExecTransformer(
left,
right,
plan.buildSide,
plan.joinType,
plan.condition)
case plan: WindowExec =>
WindowExecTransformer(
plan.windowExpression,
@@ -389,22 +433,6 @@
}
}

/**
* Get the build side supported by the execution of vanilla Spark.
*
* @param plan
* : shuffled hash join plan
* @return
* the supported build side
*/
private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
plan.joinType match {
case LeftOuter | LeftSemi => BuildRight
case RightOuter => BuildLeft
case _ => plan.buildSide
}
}

private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
case plan: FileSourceScanExec =>
val newPartitionFilters =
@@ -489,18 +517,23 @@ object MiscColumnarRules {
case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil {
import TransformPreOverrides._

- private val subRules = List(
-   FilterTransformRule(),
+ private val topdownRules = List(
+   FilterTransformRule()
+ )
+ private val bottomupRules = List(
    RegularTransformRule(),
-   AggregationTransformRule()
+   AggregationTransformRule(),
+   ExchangeTransformRule(),
+   JoinTransformRule()
  )

@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()

def apply(plan: SparkPlan): SparkPlan = {
-   val newPlan = subRules.foldLeft(plan)((p, rule) => rule.apply(p))
-   planChangeLogger.logRule(ruleName, plan, newPlan)
-   newPlan
+   val plan0 = topdownRules.foldLeft(plan)((p, rule) => p.transformDown { case p => rule(p) })
+   val plan1 = bottomupRules.foldLeft(plan0)((p, rule) => p.transformUp { case p => rule(p) })
+   planChangeLogger.logRule(ruleName, plan, plan1)
+   plan1
}
}
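To make the ordering argument concrete, here is a minimal, self-contained toy model (not Spark or Gluten code; every name in it is illustrative) showing why the exchange rule must run after the aggregation rule when the exchange's decision depends on the child already being columnar:

```scala
object RuleOrderDemo extends App {
  sealed trait Node { def columnar: Boolean }
  case class Scan(columnar: Boolean = false) extends Node
  case class Agg(child: Node, columnar: Boolean = false) extends Node
  case class Exchange(child: Node, columnar: Boolean = false) extends Node

  // Bottom-up traversal: rewrite children first, then apply the rule to the node.
  def transformUp(n: Node)(rule: Node => Node): Node = rule(n match {
    case Agg(c, col) => Agg(transformUp(c)(rule), col)
    case Exchange(c, col) => Exchange(transformUp(c)(rule), col)
    case leaf => leaf
  })

  // Stand-in for AggregationTransformRule: offload the aggregate.
  val aggRule: Node => Node = {
    case Agg(c, _) => Agg(c, columnar = true)
    case other => other
  }
  // Stand-in for ExchangeTransformRule with preferColumnar=false: the exchange
  // goes columnar only if its child is already columnar.
  val exchangeRule: Node => Node = {
    case Exchange(c, _) => Exchange(c, columnar = c.columnar)
    case other => other
  }

  val plan: Node = Exchange(Agg(Scan()))
  // Old sequencing: the exchange is decided while the aggregate is still row-based.
  val old = List(exchangeRule, aggRule).foldLeft(plan)((p, r) => transformUp(p)(r))
  // New sequencing: the aggregation rule runs first, so the exchange sees a columnar child.
  val fixed = List(aggRule, exchangeRule).foldLeft(plan)((p, r) => transformUp(p)(r))
  println(s"exchange before agg: $old")   // Exchange(Agg(Scan(false),true),false)
  println(s"agg before exchange: $fixed") // Exchange(Agg(Scan(false),true),true)
}
```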
