[GLUTEN-7690][CORE][CH][VL] GlutenConfig should support runtime configuration changes
beliefer committed Nov 4, 2024
1 parent 48d312a commit 7fca07c
Showing 27 changed files with 121 additions and 119 deletions.
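
The pattern throughout this diff: callers stop reading the global GlutenConfig.getConf snapshot and instead hold a per-session `new GlutenConfig(spark)`, so values set at runtime (spark.conf.set or SQL SET) are observed the next time a rule or operator consults a flag. A minimal sketch of the wrapper this implies — the class shape and config keys here are assumptions for illustration, not the actual Gluten definitions:

import org.apache.spark.sql.SparkSession

// Hedged sketch: a thin per-session wrapper whose getters read the live
// runtime conf on every access instead of caching a snapshot.
class GlutenConfig(spark: SparkSession) {
  // Re-evaluated per call, so `SET spark.gluten.enabled=false` issued
  // between queries changes what the next rule invocation sees.
  def enableGluten: Boolean =
    spark.conf.get("spark.gluten.enabled", "true").toBoolean

  def maxBatchSize: Int =
    spark.conf.get("spark.gluten.sql.columnar.maxBatchSize", "4096").toInt
}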
@@ -66,48 +66,48 @@ private object CHRuleApi {
// Gluten columnar: Transform rules.
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(_ => PushDownInputFileExpression.PreOffload)
-injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
-injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
+injector.injectTransform(c => FallbackOnANSIMode.apply(c.spark))
+injector.injectTransform(c => FallbackMultiCodegens.apply(c.spark))
injector.injectTransform(_ => RewriteSubqueryBroadcast())
-injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.session))
-injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session))
+injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.spark))
+injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.spark))
injector.injectTransform(_ => intercept(RewriteSparkPlanRulesManager()))
injector.injectTransform(_ => intercept(AddFallbackTagRule()))
injector.injectTransform(_ => intercept(TransformPreOverrides()))
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
-injector.injectTransform(c => intercept(RewriteTransformer.apply(c.session)))
+injector.injectTransform(c => intercept(RewriteTransformer.apply(c.spark)))
injector.injectTransform(_ => PushDownFilterToScan)
injector.injectTransform(_ => PushDownInputFileExpression.PostOffload)
injector.injectTransform(_ => EnsureLocalSortRequirements)
injector.injectTransform(_ => EliminateLocalSort)
injector.injectTransform(_ => CollapseProjectExecTransformer)
-injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session))
-injector.injectTransform(c => PushdownAggregatePreProjectionAheadExpand.apply(c.session))
+injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.spark))
+injector.injectTransform(c => PushdownAggregatePreProjectionAheadExpand.apply(c.spark))
injector.injectTransform(
c =>
intercept(
SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarTransformRules)(
-c.session)))
+c.spark)))
injector.injectTransform(c => InsertTransitions(c.outputsColumnar))

// Gluten columnar: Fallback policies.
injector.injectFallbackPolicy(
c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))

// Gluten columnar: Post rules.
-injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()))
+injector.injectPost(c => RemoveTopmostColumnarToRow(c.spark, c.ac.isAdaptiveContext()))
SparkShimLoader.getSparkShims
.getExtendedColumnarPostRules()
-.foreach(each => injector.injectPost(c => intercept(each(c.session))))
+.foreach(each => injector.injectPost(c => intercept(each(c.spark))))
injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf))
injector.injectTransform(
c =>
intercept(
-SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarPostRules)(c.session)))
+SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarPostRules)(c.spark)))

// Gluten columnar: Final rules.
-injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
-injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session))
+injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.spark))
+injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.spark))
injector.injectFinal(_ => RemoveFallbackTagRule())
}

@@ -34,13 +34,15 @@ import scala.collection.mutable
// --conf spark.sql.planChangeLog.batches=all
class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[LogicalPlan] with Logging {

+private val glutenConf = new GlutenConfig(spark)

private var lastPlan: LogicalPlan = null

override def apply(plan: LogicalPlan): LogicalPlan = {
val newPlan =
if (
-plan.resolved && GlutenConfig.getConf.enableGluten
-&& GlutenConfig.getConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan)
+plan.resolved && glutenConf.enableGluten
+&& glutenConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan)
) {
lastPlan = plan
visitPlan(plan)
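
Note the shape of the change here: the rule caches the wrapper in a val, but the flags themselves are re-read through it on every apply. A hedged usage sketch (the exact config keys are assumptions inferred from the flag names):

// Toggling the feature between queries now takes effect immediately,
// without constructing a new SparkSession.
spark.conf.set("spark.gluten.sql.commonSubexpressionEliminate", "false")
spark.sql(query).collect() // this run skips the rewrite
spark.conf.set("spark.gluten.sql.commonSubexpressionEliminate", "true")
spark.sql(query).collect() // this run applies it again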
@@ -36,9 +36,11 @@ import scala.util.control.Breaks.{break, breakable}
// see each other during transformation. In order to prevent BroadcastExec from being transformed
// to columnar while the BHJ falls back, BroadcastExec needs to be tagged as not transformable when
// applying queryStagePrepRules.
-case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extends Rule[SparkPlan] {
+case class FallbackBroadcastHashJoinPrepQueryStage(spark: SparkSession) extends Rule[SparkPlan] {

+private val glutenConf: GlutenConfig = new GlutenConfig(spark)

override def apply(plan: SparkPlan): SparkPlan = {
-val glutenConf: GlutenConfig = GlutenConfig.getConf
plan.foreach {
case bhj: BroadcastHashJoinExec =>
val buildSidePlan = bhj.buildSide match {
@@ -144,15 +146,17 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend

// For similar purpose with FallbackBroadcastHashJoinPrepQueryStage, executed during applying
// columnar rules.
-case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
+case class FallbackBroadcastHashJoin(spark: SparkSession) extends Rule[SparkPlan] {

+private val glutenConf: GlutenConfig = new GlutenConfig(spark)

private val enableColumnarBroadcastJoin: Boolean =
-GlutenConfig.getConf.enableColumnarBroadcastJoin &&
-GlutenConfig.getConf.enableColumnarBroadcastExchange
+glutenConf.enableColumnarBroadcastJoin &&
+glutenConf.enableColumnarBroadcastExchange

private val enableColumnarBroadcastNestedLoopJoin: Boolean =
-GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
-GlutenConfig.getConf.enableColumnarBroadcastExchange
+glutenConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
+glutenConf.enableColumnarBroadcastExchange

override def apply(plan: SparkPlan): SparkPlan = {
plan.foreachUp {
@@ -34,14 +34,11 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
* Note: this rule must be applied before the `PullOutPreProject` rule, because the
* `PullOutPreProject` rule will modify the attributes in some cases.
*/
-case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
+case class MergeTwoPhasesHashBaseAggregate(spark: SparkSession)
extends Rule[SparkPlan]
with Logging {

-val glutenConf: GlutenConfig = GlutenConfig.getConf
-val scanOnly: Boolean = glutenConf.enableScanOnly
-val enableColumnarHashAgg: Boolean = !scanOnly && glutenConf.enableColumnarHashAgg
-val replaceSortAggWithHashAgg: Boolean = GlutenConfig.getConf.forceToUseHashAgg
+private val glutenConf: GlutenConfig = new GlutenConfig(spark)

private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = {
// TODO: merging aggregates whose aggregate expressions contain filters is not yet supported.
@@ -59,7 +56,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
}

override def apply(plan: SparkPlan): SparkPlan = {
-if (!enableColumnarHashAgg) {
+if (glutenConf.enableScanOnly || !glutenConf.enableColumnarHashAgg) {
plan
} else {
plan.transformDown {
@@ -111,7 +108,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
_,
resultExpressions,
child: SortAggregateExec)
-if replaceSortAggWithHashAgg && !isStreaming && isPartialAgg(child, sortAgg) =>
+if glutenConf.forceToUseHashAgg && !isStreaming && isPartialAgg(child, sortAgg) =>
// convert to complete mode aggregate expressions
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
sortAgg.copy(
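
This file shows the main failure mode the commit fixes: scanOnly, enableColumnarHashAgg, and replaceSortAggWithHashAgg were construction-time vals, frozen when the rule object was created, so later SET commands were silently ignored; the rewrite reads each flag through glutenConf inside apply. A schematic contrast, building on the hypothetical wrapper sketched earlier:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

case class ExampleAggRule(spark: SparkSession) extends Rule[SparkPlan] {
  // Before: `val enabled = GlutenConfig.getConf.enableColumnarHashAgg`
  // snapshotted the flag once; runtime changes never reached the rule.
  private val glutenConf = new GlutenConfig(spark)

  override def apply(plan: SparkPlan): SparkPlan =
    // After: the flag is evaluated on every invocation.
    if (!glutenConf.enableGluten) plan
    else plan // a real rule would transform the plan here
}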
@@ -53,11 +53,11 @@ class RewriteDateTimestampComparisonRule(spark: SparkSession)
"yyyy"
)

+private val glutenConf = new GlutenConfig(spark)

override def apply(plan: LogicalPlan): LogicalPlan = {
if (
-plan.resolved &&
-GlutenConfig.getConf.enableGluten &&
-GlutenConfig.getConf.enableRewriteDateTimestampComparison
+plan.resolved && glutenConf.enableGluten && glutenConf.enableRewriteDateTimestampComparison
) {
visitPlan(plan)
} else {
@@ -39,12 +39,10 @@ import org.apache.spark.sql.types._
// Optimized result is `to_date(stringType)`
class RewriteToDateExpresstionRule(spark: SparkSession) extends Rule[LogicalPlan] with Logging {

+private val glutenConf = new GlutenConfig(spark)

override def apply(plan: LogicalPlan): LogicalPlan = {
-if (
-plan.resolved &&
-GlutenConfig.getConf.enableGluten &&
-GlutenConfig.getConf.enableCHRewriteDateConversion
-) {
+if (plan.resolved && glutenConf.enableGluten && glutenConf.enableCHRewriteDateConversion) {
visitPlan(plan)
} else {
plan
@@ -31,12 +31,15 @@ import org.apache.spark.sql.types._
* @param spark
*/
case class CHAggregateFunctionRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {

+private val glutenConf = new GlutenConfig(spark)

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case a: Aggregate =>
a.transformExpressions {
case avgExpr @ AggregateExpression(avg: Average, _, _, _, _)
-if GlutenConfig.getConf.enableCastAvgAggregateFunction &&
-GlutenConfig.getConf.enableColumnarHashAgg &&
+if glutenConf.enableCastAvgAggregateFunction &&
+glutenConf.enableColumnarHashAgg &&
!avgExpr.isDistinct && isDataTypeNeedConvert(avg.child.dataType) =>
AggregateExpression(
avg.copy(child = Cast(avg.child, DoubleType)),
@@ -52,73 +52,73 @@ private object VeloxRuleApi {
// Gluten columnar: Transform rules.
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(_ => PushDownInputFileExpression.PreOffload)
-injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
-injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
+injector.injectTransform(c => FallbackOnANSIMode.apply(c.spark))
+injector.injectTransform(c => FallbackMultiCodegens.apply(c.spark))
injector.injectTransform(_ => RewriteSubqueryBroadcast())
-injector.injectTransform(c => BloomFilterMightContainJointRewriteRule.apply(c.session))
-injector.injectTransform(c => ArrowScanReplaceRule.apply(c.session))
+injector.injectTransform(c => BloomFilterMightContainJointRewriteRule.apply(c.spark))
+injector.injectTransform(c => ArrowScanReplaceRule.apply(c.spark))
injector.injectTransform(_ => RewriteSparkPlanRulesManager())
injector.injectTransform(_ => AddFallbackTagRule())
injector.injectTransform(_ => TransformPreOverrides())
-injector.injectTransform(c => PartialProjectRule.apply(c.session))
+injector.injectTransform(c => PartialProjectRule.apply(c.spark))
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
-injector.injectTransform(c => RewriteTransformer.apply(c.session))
+injector.injectTransform(c => RewriteTransformer.apply(c.spark))
injector.injectTransform(_ => PushDownFilterToScan)
injector.injectTransform(_ => PushDownInputFileExpression.PostOffload)
injector.injectTransform(_ => EnsureLocalSortRequirements)
injector.injectTransform(_ => EliminateLocalSort)
injector.injectTransform(_ => CollapseProjectExecTransformer)
-injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session))
+injector.injectTransform(c => FlushableHashAggregateRule.apply(c.spark))
injector.injectTransform(c => InsertTransitions(c.outputsColumnar))

// Gluten columnar: Fallback policies.
injector.injectFallbackPolicy(
c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))

// Gluten columnar: Post rules.
-injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()))
+injector.injectPost(c => RemoveTopmostColumnarToRow(c.spark, c.ac.isAdaptiveContext()))
SparkShimLoader.getSparkShims
.getExtendedColumnarPostRules()
-.foreach(each => injector.injectPost(c => each(c.session)))
+.foreach(each => injector.injectPost(c => each(c.spark)))
injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf))

// Gluten columnar: Final rules.
-injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
-injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session))
+injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.spark))
+injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.spark))
injector.injectFinal(_ => RemoveFallbackTagRule())
}

def injectRas(injector: RasInjector): Unit = {
// Gluten RAS: Pre rules.
injector.inject(_ => RemoveTransitions)
injector.inject(_ => PushDownInputFileExpression.PreOffload)
-injector.inject(c => FallbackOnANSIMode.apply(c.session))
+injector.inject(c => FallbackOnANSIMode.apply(c.spark))
injector.inject(_ => RewriteSubqueryBroadcast())
-injector.inject(c => BloomFilterMightContainJointRewriteRule.apply(c.session))
-injector.inject(c => ArrowScanReplaceRule.apply(c.session))
+injector.inject(c => BloomFilterMightContainJointRewriteRule.apply(c.spark))
+injector.inject(c => ArrowScanReplaceRule.apply(c.spark))

// Gluten RAS: The RAS rule.
-injector.inject(c => EnumeratedTransform(c.session, c.outputsColumnar))
+injector.inject(c => EnumeratedTransform(c.spark, c.outputsColumnar))

// Gluten RAS: Post rules.
injector.inject(_ => RemoveTransitions)
-injector.inject(c => PartialProjectRule.apply(c.session))
+injector.inject(c => PartialProjectRule.apply(c.spark))
injector.inject(_ => RemoveNativeWriteFilesSortAndProject())
-injector.inject(c => RewriteTransformer.apply(c.session))
+injector.inject(c => RewriteTransformer.apply(c.spark))
injector.inject(_ => PushDownFilterToScan)
injector.inject(_ => PushDownInputFileExpression.PostOffload)
injector.inject(_ => EnsureLocalSortRequirements)
injector.inject(_ => EliminateLocalSort)
injector.inject(_ => CollapseProjectExecTransformer)
-injector.inject(c => FlushableHashAggregateRule.apply(c.session))
+injector.inject(c => FlushableHashAggregateRule.apply(c.spark))
injector.inject(c => InsertTransitions(c.outputsColumnar))
-injector.inject(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()))
+injector.inject(c => RemoveTopmostColumnarToRow(c.spark, c.ac.isAdaptiveContext()))
SparkShimLoader.getSparkShims
.getExtendedColumnarPostRules()
-.foreach(each => injector.inject(c => each(c.session)))
+.foreach(each => injector.inject(c => each(c.spark)))
injector.inject(c => ColumnarCollapseTransformStages(c.glutenConf))
-injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session))
-injector.inject(c => GlutenFallbackReporter(c.glutenConf, c.session))
+injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.spark))
+injector.inject(c => GlutenFallbackReporter(c.glutenConf, c.spark))
injector.inject(_ => RemoveFallbackTagRule())
}
}
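
Both rule-injector files make the same mechanical rename, c.session to c.spark, and pass c.glutenConf where rules need flags. The call sites imply a context roughly shaped like the following; this is a hypothetical reconstruction from usage, not the actual Gluten API:

import org.apache.spark.sql.SparkSession

// Hypothetical shape inferred from the call sites above.
trait RuleInjectionContext {
  def spark: SparkSession      // renamed from `session` in this commit
  def glutenConf: GlutenConfig // the per-session wrapper
  def outputsColumnar: Boolean // consumed by InsertTransitions
}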
@@ -16,7 +16,6 @@
*/
package org.apache.gluten.execution

-import org.apache.gluten.GlutenConfig
import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
import org.apache.gluten.expression.{ArrowProjection, ExpressionUtils}
import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
@@ -134,7 +133,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
}

override protected def doValidateInternal(): ValidationResult = {
-if (!GlutenConfig.getConf.enableColumnarPartialProject) {
+if (!glutenConf.enableColumnarPartialProject) {
return ValidationResult.failed("Config disable this feature")
}
if (UDFAttrNotExists) {
@@ -159,11 +158,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
if (!original.projectList.forall(validateExpression(_))) {
return ValidationResult.failed("Contains expression not supported")
}
-if (
-ExpressionUtils.hasComplexExpressions(
-original,
-GlutenConfig.getConf.fallbackExpressionsThreshold)
-) {
+if (ExpressionUtils.hasComplexExpressions(original, glutenConf.fallbackExpressionsThreshold)) {
return ValidationResult.failed("Fallback by complex expression")
}
ValidationResult.succeeded
@@ -16,7 +16,6 @@
*/
package org.apache.gluten.execution

-import org.apache.gluten.GlutenConfig
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
@@ -48,7 +47,7 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBas
val numInputRows = longMetric("numInputRows")
val numOutputBatches = longMetric("numOutputBatches")
val convertTime = longMetric("convertTime")
-val numRows = GlutenConfig.getConf.maxBatchSize
+val numRows = glutenConf.maxBatchSize
// This avoids calling `schema` in the RDD closure, so that we don't need to include the entire
// plan (this) in the closure.
val localSchema = schema
@@ -68,7 +67,7 @@
val numInputRows = longMetric("numInputRows")
val numOutputBatches = longMetric("numOutputBatches")
val convertTime = longMetric("convertTime")
-val numRows = GlutenConfig.getConf.maxBatchSize
+val numRows = glutenConf.maxBatchSize
val mode = BroadcastUtils.getBroadcastMode(outputPartitioning)
val relation = child.executeBroadcast()
BroadcastUtils.sparkToVeloxUnsafe(
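
On the operator side, ColumnarPartialProjectExec and RowToVeloxColumnarExec drop the import of the GlutenConfig companion and call glutenConf directly, presumably a member supplied by a shared Gluten plan base class, so batch sizing and validation thresholds resolve against the live session conf at execution time. A hedged sketch of that assumed access pattern:

import org.apache.spark.sql.SparkSession

// Assumed base-trait shape: operators obtain a session-bound config
// rather than the process-wide GlutenConfig.getConf snapshot.
trait GlutenPlanConfSupport {
  def sparkSession: SparkSession
  // With live-reading getters (see the sketch near the top), a changed
  // `maxBatchSize` affects the operator's next execution.
  protected def glutenConf: GlutenConfig = new GlutenConfig(sparkSession)
}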
@@ -26,8 +26,11 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

case class BloomFilterMightContainJointRewriteRule(spark: SparkSession) extends Rule[SparkPlan] {

+private val glutenConf = new GlutenConfig(spark)

override def apply(plan: SparkPlan): SparkPlan = {
-if (!GlutenConfig.getConf.enableNativeBloomFilter) {
+if (!glutenConf.enableNativeBloomFilter) {
return plan
}
val out = plan.transformWithSubqueries {
@@ -31,10 +31,13 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
* To transform regular aggregation to intermediate aggregation that internally enables
* optimizations such as flushing and abandoning.
*/
-case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] {
+case class FlushableHashAggregateRule(spark: SparkSession) extends Rule[SparkPlan] {
import FlushableHashAggregateRule._

+private val glutenConf = new GlutenConfig(spark)

override def apply(plan: SparkPlan): SparkPlan = {
-if (!GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) {
+if (!glutenConf.enableVeloxFlushablePartialAggregation) {
return plan
}
plan.transformUpWithPruning(_.containsPattern(EXCHANGE)) {
Expand Down