diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
index 72a04f02c1ecb..42783babd124c 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
@@ -195,4 +195,19 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit
         |""".stripMargin) { _ => }
     assert(result(0).getLong(0) == 227302L)
   }
+
+  test("test 'GLUTEN-5016'") {
+    withSQLConf(("spark.gluten.sql.columnar.preferColumnar", "false")) {
+      val sql =
+        """
+          |SELECT
+          |    sum(l_quantity) AS sum_qty
+          |FROM
+          |    lineitem
+          |WHERE
+          |    l_shipdate <= date'1998-09-02'
+          |""".stripMargin
+      runSql(sql, noFallBack = true) { _ => }
+    }
+  }
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
index 4fbadb0b50ac7..d21956bfcc979 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
@@ -49,9 +49,12 @@ object MiscColumnarRules {
 
   // Aggregation transformation.
   private case class AggregationTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
-    override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    override def apply(plan: SparkPlan): SparkPlan = plan match {
+      case plan if TransformHints.isNotTransformable(plan) =>
+        plan
       case agg: HashAggregateExec =>
         genHashAggregateExec(agg)
+      case other => other
     }
 
     /**
@@ -105,13 +108,144 @@ object MiscColumnarRules {
     }
   }
 
+  // Exchange transformation.
+  private case class ExchangeTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
+    override def apply(plan: SparkPlan): SparkPlan = plan match {
+      case plan if TransformHints.isNotTransformable(plan) =>
+        plan
+      case plan: ShuffleExchangeExec =>
+        logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+        val child = plan.child
+        if (
+          (child.supportsColumnar || GlutenConfig.getConf.enablePreferColumnar) &&
+          BackendsApiManager.getSettings.supportColumnarShuffleExec()
+        ) {
+          BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
+        } else {
+          plan.withNewChildren(Seq(child))
+        }
+      case plan: BroadcastExchangeExec =>
+        val child = plan.child
+        logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+        ColumnarBroadcastExchangeExec(plan.mode, child)
+      case other => other
+    }
+  }
+
+  // Join transformation.
+  private case class JoinTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
+
+    /**
+     * Get the build side supported by the execution of vanilla Spark.
+     *
+     * @param plan
+     *   : shuffled hash join plan
+     * @return
+     *   the supported build side
+     */
+    private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
+      plan.joinType match {
+        case LeftOuter | LeftSemi => BuildRight
+        case RightOuter => BuildLeft
+        case _ => plan.buildSide
+      }
+    }
+
+    override def apply(plan: SparkPlan): SparkPlan = {
+      if (TransformHints.isNotTransformable(plan)) {
+        logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
+        plan match {
+          case shj: ShuffledHashJoinExec =>
+            if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
+              // Because we manually removed the build side limitation for LeftOuter, LeftSemi and
+              // RightOuter, we need to change the build side back if this join falls back to
+              // vanilla Spark for execution.
+              return ShuffledHashJoinExec(
+                shj.leftKeys,
+                shj.rightKeys,
+                shj.joinType,
+                getSparkSupportedBuildSide(shj),
+                shj.condition,
+                shj.left,
+                shj.right,
+                shj.isSkewJoin
+              )
+            } else {
+              return shj
+            }
+          case p =>
+            return p
+        }
+      }
+      plan match {
+        case plan: ShuffledHashJoinExec =>
+          val left = plan.left
+          val right = plan.right
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+          BackendsApiManager.getSparkPlanExecApiInstance
+            .genShuffledHashJoinExecTransformer(
+              plan.leftKeys,
+              plan.rightKeys,
+              plan.joinType,
+              plan.buildSide,
+              plan.condition,
+              left,
+              right,
+              plan.isSkewJoin)
+        case plan: SortMergeJoinExec =>
+          val left = plan.left
+          val right = plan.right
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+          SortMergeJoinExecTransformer(
+            plan.leftKeys,
+            plan.rightKeys,
+            plan.joinType,
+            plan.condition,
+            left,
+            right,
+            plan.isSkewJoin)
+        case plan: BroadcastHashJoinExec =>
+          val left = plan.left
+          val right = plan.right
+          BackendsApiManager.getSparkPlanExecApiInstance
+            .genBroadcastHashJoinExecTransformer(
+              plan.leftKeys,
+              plan.rightKeys,
+              plan.joinType,
+              plan.buildSide,
+              plan.condition,
+              left,
+              right,
+              isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
+        case plan: CartesianProductExec =>
+          val left = plan.left
+          val right = plan.right
+          BackendsApiManager.getSparkPlanExecApiInstance
+            .genCartesianProductExecTransformer(left, right, plan.condition)
+        case plan: BroadcastNestedLoopJoinExec =>
+          val left = plan.left
+          val right = plan.right
+          BackendsApiManager.getSparkPlanExecApiInstance
+            .genBroadcastNestedLoopJoinExecTransformer(
+              left,
+              right,
+              plan.buildSide,
+              plan.joinType,
+              plan.condition)
+        case other => other
+      }
+    }
+
+  }
+
   // Filter transformation.
   private case class FilterTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
     private val replace = new ReplaceSingleNode()
 
-    override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
+    override def apply(plan: SparkPlan): SparkPlan = plan match {
       case filter: FilterExec =>
         genFilterExec(filter)
+      case other => other
     }
 
     /**
@@ -155,39 +289,18 @@ object MiscColumnarRules {
   private case class RegularTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
     private val replace = new ReplaceSingleNode()
 
-    override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-      case plan => replace.replaceWithTransformerPlan(plan)
-    }
+    override def apply(plan: SparkPlan): SparkPlan = replace.replaceWithTransformerPlan(plan)
   }
 
   // Utility to replace single node within transformed Gluten node.
   // Children will be preserved as they are as children of the output node.
   class ReplaceSingleNode() extends LogLevelUtil with Logging {
-    private val columnarConf: GlutenConfig = GlutenConfig.getConf
 
     def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = {
       val plan = p
       if (TransformHints.isNotTransformable(plan)) {
         logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
         plan match {
-          case shj: ShuffledHashJoinExec =>
-            if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
-              // Because we manually removed the build side limitation for LeftOuter, LeftSemi and
-              // RightOuter, need to change the build side back if this join fallback into vanilla
-              // Spark for execution.
-              return ShuffledHashJoinExec(
-                shj.leftKeys,
-                shj.rightKeys,
-                shj.joinType,
-                getSparkSupportedBuildSide(shj),
-                shj.condition,
-                shj.left,
-                shj.right,
-                shj.isSkewJoin
-              )
-            } else {
-              return shj
-            }
           case plan: BatchScanExec =>
             return applyScanNotTransformable(plan)
           case plan: FileSourceScanExec =>
@@ -283,75 +396,6 @@ object MiscColumnarRules {
             plan.projectList,
             child,
             offset)
-        case plan: ShuffleExchangeExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-          val child = plan.child
-          if (
-            (child.supportsColumnar || columnarConf.enablePreferColumnar) &&
-            BackendsApiManager.getSettings.supportColumnarShuffleExec()
-          ) {
-            BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
-          } else {
-            plan.withNewChildren(Seq(child))
-          }
-        case plan: ShuffledHashJoinExec =>
-          val left = plan.left
-          val right = plan.right
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genShuffledHashJoinExecTransformer(
-              plan.leftKeys,
-              plan.rightKeys,
-              plan.joinType,
-              plan.buildSide,
-              plan.condition,
-              left,
-              right,
-              plan.isSkewJoin)
-        case plan: SortMergeJoinExec =>
-          val left = plan.left
-          val right = plan.right
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-          SortMergeJoinExecTransformer(
-            plan.leftKeys,
-            plan.rightKeys,
-            plan.joinType,
-            plan.condition,
-            left,
-            right,
-            plan.isSkewJoin)
-        case plan: BroadcastExchangeExec =>
-          val child = plan.child
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-          ColumnarBroadcastExchangeExec(plan.mode, child)
-        case plan: BroadcastHashJoinExec =>
-          val left = plan.left
-          val right = plan.right
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genBroadcastHashJoinExecTransformer(
-              plan.leftKeys,
-              plan.rightKeys,
-              plan.joinType,
-              plan.buildSide,
-              plan.condition,
-              left,
-              right,
-              isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
-        case plan: CartesianProductExec =>
-          val left = plan.left
-          val right = plan.right
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genCartesianProductExecTransformer(left, right, plan.condition)
-        case plan: BroadcastNestedLoopJoinExec =>
-          val left = plan.left
-          val right = plan.right
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genBroadcastNestedLoopJoinExecTransformer(
-              left,
-              right,
-              plan.buildSide,
-              plan.joinType,
-              plan.condition)
         case plan: WindowExec =>
           WindowExecTransformer(
             plan.windowExpression,
@@ -389,22 +433,6 @@ object MiscColumnarRules {
     }
   }
 
-    /**
-     * Get the build side supported by the execution of vanilla Spark.
- * - * @param plan - * : shuffled hash join plan - * @return - * the supported build side - */ - private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = { - plan.joinType match { - case LeftOuter | LeftSemi => BuildRight - case RightOuter => BuildLeft - case _ => plan.buildSide - } - } - private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match { case plan: FileSourceScanExec => val newPartitionFilters = @@ -489,18 +517,23 @@ object MiscColumnarRules { case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil { import TransformPreOverrides._ - private val subRules = List( - FilterTransformRule(), + private val topdownRules = List( + FilterTransformRule() + ) + private val bottomupRules = List( RegularTransformRule(), - AggregationTransformRule() + AggregationTransformRule(), + ExchangeTransformRule(), + JoinTransformRule() ) @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]() def apply(plan: SparkPlan): SparkPlan = { - val newPlan = subRules.foldLeft(plan)((p, rule) => rule.apply(p)) - planChangeLogger.logRule(ruleName, plan, newPlan) - newPlan + val plan0 = topdownRules.foldLeft(plan)((p, rule) => p.transformDown { case p => rule(p) }) + val plan1 = bottomupRules.foldLeft(plan0)((p, rule) => p.transformUp { case p => rule(p) }) + planChangeLogger.logRule(ruleName, plan, plan1) + plan1 } }
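
Reviewer note: the patch above changes how TransformPreOverrides drives its sub-rules. Each rule now rewrites exactly one node (falling through with "case other => other"), while the traversal order is decided once in TransformPreOverrides.apply: transformDown for the filter rule, transformUp for the regular, aggregation, exchange and join rules. The sketch below is a minimal, Spark-free model of that pattern for readers unfamiliar with it. Node, Scan, Filter, FilterTransformer and the helper names are illustrative assumptions for this sketch only, not Gluten or Spark APIs.

object RuleDriverSketch {

  // A toy plan tree standing in for SparkPlan.
  sealed trait Node {
    def children: Seq[Node]
    def withChildren(cs: Seq[Node]): Node
  }
  final case class Scan(table: String) extends Node {
    def children: Seq[Node] = Nil
    def withChildren(cs: Seq[Node]): Node = this
  }
  final case class Filter(child: Node) extends Node {
    def children: Seq[Node] = Seq(child)
    def withChildren(cs: Seq[Node]): Node = Filter(cs.head)
  }
  final case class FilterTransformer(child: Node) extends Node {
    def children: Seq[Node] = Seq(child)
    def withChildren(cs: Seq[Node]): Node = FilterTransformer(cs.head)
  }

  // Each rule rewrites exactly one node and returns every other node
  // unchanged, mirroring the `case other => other` arms added in the patch.
  type PlanRule = Node => Node

  val filterRule: PlanRule = {
    case f: Filter => FilterTransformer(f.child)
    case other => other
  }

  // Rewrite the node first, then recurse into the rewritten node's children
  // (analogous to SparkPlan.transformDown).
  def transformDown(n: Node, rule: PlanRule): Node = {
    val rewritten = rule(n)
    rewritten.withChildren(rewritten.children.map(transformDown(_, rule)))
  }

  // Recurse into the children first, then rewrite the node
  // (analogous to SparkPlan.transformUp).
  def transformUp(n: Node, rule: PlanRule): Node =
    rule(n.withChildren(n.children.map(transformUp(_, rule))))

  def main(args: Array[String]): Unit = {
    val plan: Node = Filter(Filter(Scan("lineitem")))
    // Same driver shape as the patched TransformPreOverrides.apply:
    // fold the top-down rules first, then fold the bottom-up rules
    // over the result.
    val topdownRules: List[PlanRule] = List(filterRule)
    val bottomupRules: List[PlanRule] = Nil
    val plan0 = topdownRules.foldLeft(plan)((p, r) => transformDown(p, r))
    val plan1 = bottomupRules.foldLeft(plan0)((p, r) => transformUp(p, r))
    println(plan1) // FilterTransformer(FilterTransformer(Scan(lineitem)))
  }
}

The point of the split is that a rule such as JoinTransformRule can check TransformHints.isNotTransformable on the one node it is handed and return it untouched, instead of every rule embedding its own transformUp/transformDown call; the driver owns the traversal, so a whole-plan walk happens once per rule group rather than being re-implemented inside each rule.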