Commit

d2
beliefer committed Oct 27, 2024
1 parent 27e228d commit fa290b7
Showing 28 changed files with 102 additions and 90 deletions.
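
Across the rendered files the change is mechanical: rules stop reaching for the global `GlutenConfig.getConf` accessor (or a constructor-injected `SQLConf`) and instead build one `GlutenConfig` from the `SparkSession` they already receive, while the columnar-rule call context renames `c.conf` to `c.glutenConf`. A minimal sketch of the before/after shape, using only the `new GlutenConfig(spark)` constructor and flag names visible in the hunks below; the rule names themselves are hypothetical:

import org.apache.gluten.GlutenConfig
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Before: every apply() went through the process-global accessor.
class SomeRewriteRuleBefore(session: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan =
    if (GlutenConfig.getConf.enableGluten) plan else plan // real rules rewrite `plan` here
}

// After: one GlutenConfig instance is built from the session handed to the rule.
class SomeRewriteRuleAfter(spark: SparkSession) extends Rule[LogicalPlan] {
  private val glutenConf = new GlutenConfig(spark)
  override def apply(plan: LogicalPlan): LogicalPlan =
    if (glutenConf.enableGluten) plan else plan // real rules rewrite `plan` here
}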

@@ -54,11 +54,11 @@ private object CHRuleApi {
  injector.injectParser(
  (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface))
  injector.injectResolutionRule(
- spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf))
+ spark => new RewriteToDateExpresstionRule(spark))
  injector.injectResolutionRule(
- spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf))
+ spark => new RewriteDateTimestampComparisonRule(spark))
  injector.injectOptimizerRule(
- spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf))
+ spark => new CommonSubexpressionEliminateRule(spark))
  injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
  injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
  injector.injectOptimizerRule(_ => EqualToRewrite)
@@ -89,7 +89,7 @@ private object CHRuleApi {
  injector.injectTransform(
  c =>
  intercept(
- SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)))
+ SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarTransformRules)(c.session)))
  injector.injectTransform(c => InsertTransitions(c.outputsColumnar))

  // Gluten columnar: Fallback policies.
@@ -101,14 +101,14 @@ private object CHRuleApi {
  SparkShimLoader.getSparkShims
  .getExtendedColumnarPostRules()
  .foreach(each => injector.injectPost(c => intercept(each(c.session))))
- injector.injectPost(c => ColumnarCollapseTransformStages(c.conf))
+ injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf))
  injector.injectTransform(
  c =>
- intercept(SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)))
+ intercept(SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarPostRules)(c.session)))

  // Gluten columnar: Final rules.
  injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
- injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session))
+ injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session))
  injector.injectFinal(_ => RemoveFallbackTagRule())
  }

@@ -33,17 +33,19 @@ import scala.collection.mutable
  // 2. append two options to spark config
  // --conf spark.sql.planChangeLog.level=error
  // --conf spark.sql.planChangeLog.batches=all
- class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
+ class CommonSubexpressionEliminateRule(spark: SparkSession)
  extends Rule[LogicalPlan]
  with Logging {

+ private val glutenConf = new GlutenConfig(spark)
+
  private var lastPlan: LogicalPlan = null

  override def apply(plan: LogicalPlan): LogicalPlan = {
  val newPlan =
  if (
- plan.resolved && GlutenConfig.getConf.enableGluten
- && GlutenConfig.getConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan)
+ plan.resolved && glutenConf.enableGluten
+ && glutenConf.enableCommonSubexpressionEliminate && !plan.fastEquals(lastPlan)
  ) {
  lastPlan = plan
  visitPlan(plan)
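
The comment preserved at the top of this hunk points at Spark's plan-change log for inspecting what the rule rewrites. A minimal way to set those two options when reproducing the rule's behaviour locally (the config keys are the standard Spark SQL options named in that comment; the local session itself is just illustrative):

import org.apache.spark.sql.SparkSession

// Log every rewritten plan at ERROR level so the rule's changes show up in the driver log.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.planChangeLog.level", "error")
  .config("spark.sql.planChangeLog.batches", "all")
  .getOrCreate()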

@@ -36,9 +36,11 @@ import scala.util.control.Breaks.{break, breakable}
  // see each other during transformation. In order to prevent BroadcastExec being transformed
  // to columnar while BHJ fallbacks, BroadcastExec need to be tagged not transformable when applying
  // queryStagePrepRules.
- case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extends Rule[SparkPlan] {
+ case class FallbackBroadcastHashJoinPrepQueryStage(spark: SparkSession) extends Rule[SparkPlan] {
+
+ private val glutenConf: GlutenConfig = new GlutenConfig(spark)
+
  override def apply(plan: SparkPlan): SparkPlan = {
- val columnarConf: GlutenConfig = GlutenConfig.getConf
  plan.foreach {
  case bhj: BroadcastHashJoinExec =>
  val buildSidePlan = bhj.buildSide match {
@@ -53,8 +55,8 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
  case Some(exchange @ BroadcastExchangeExec(mode, child)) =>
  val isTransformable =
  if (
- !columnarConf.enableColumnarBroadcastExchange ||
- !columnarConf.enableColumnarBroadcastJoin
+ !glutenConf.enableColumnarBroadcastExchange ||
+ !glutenConf.enableColumnarBroadcastJoin
  ) {
  ValidationResult.failed(
  "columnar broadcast exchange is disabled or " +
@@ -107,8 +109,8 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
  case Some(exchange @ BroadcastExchangeExec(mode, child)) =>
  val isTransformable =
  if (
- !GlutenConfig.getConf.enableColumnarBroadcastExchange ||
- !GlutenConfig.getConf.enableColumnarBroadcastJoin
+ !glutenConf.enableColumnarBroadcastExchange ||
+ !glutenConf.enableColumnarBroadcastJoin
  ) {
  ValidationResult.failed(
  "columnar broadcast exchange is disabled or " +
@@ -146,13 +148,14 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
  // columnar rules.
  case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {

+ private val columnarConf: GlutenConfig = new GlutenConfig(session)
+
  private val enableColumnarBroadcastJoin: Boolean =
- GlutenConfig.getConf.enableColumnarBroadcastJoin &&
- GlutenConfig.getConf.enableColumnarBroadcastExchange
+ columnarConf.enableColumnarBroadcastJoin && columnarConf.enableColumnarBroadcastExchange

  private val enableColumnarBroadcastNestedLoopJoin: Boolean =
- GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
- GlutenConfig.getConf.enableColumnarBroadcastExchange
+ columnarConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
+ columnarConf.enableColumnarBroadcastExchange

  override def apply(plan: SparkPlan): SparkPlan = {
  plan.foreachUp {

@@ -34,14 +34,11 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
  * Note: this rule must be applied before the `PullOutPreProject` rule, because the
  * `PullOutPreProject` rule will modify the attributes in some cases.
  */
- case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
+ case class MergeTwoPhasesHashBaseAggregate(spark: SparkSession)
  extends Rule[SparkPlan]
  with Logging {

- val columnarConf: GlutenConfig = GlutenConfig.getConf
- val scanOnly: Boolean = columnarConf.enableScanOnly
- val enableColumnarHashAgg: Boolean = !scanOnly && columnarConf.enableColumnarHashAgg
- val replaceSortAggWithHashAgg: Boolean = GlutenConfig.getConf.forceToUseHashAgg
+ private val glutenConf: GlutenConfig = new GlutenConfig(spark)

  private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = {
  // TODO: now it can not support to merge agg which there are the filters in the aggregate exprs.
@@ -59,7 +56,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
  }

  override def apply(plan: SparkPlan): SparkPlan = {
- if (!enableColumnarHashAgg) {
+ if (!(!glutenConf.enableScanOnly && glutenConf.enableColumnarHashAgg)) {
  plan
  } else {
  plan.transformDown {
@@ -111,7 +108,7 @@ case class MergeTwoPhasesHashBaseAggregate(session: SparkSession)
  _,
  resultExpressions,
  child: SortAggregateExec)
- if replaceSortAggWithHashAgg && !isStreaming && isPartialAgg(child, sortAgg) =>
+ if glutenConf.forceToUseHashAgg && !isStreaming && isPartialAgg(child, sortAgg) =>
  // convert to complete mode aggregate expressions
  val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
  sortAgg.copy(

@@ -37,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
  // This rule try to make the filter condition into integer comparison, which is more efficient.
  // The above example will be rewritten into
  // select * from table where to_unixtime('2023-11-02', 'yyyy-MM-dd') >= unix_timestamp
- class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf)
+ class RewriteDateTimestampComparisonRule(spark: SparkSession)
  extends Rule[LogicalPlan]
  with Logging {

@@ -54,11 +54,11 @@ class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf)
  "yyyy"
  )

+ private val glutenConf = new GlutenConfig(spark)
+
  override def apply(plan: LogicalPlan): LogicalPlan = {
  if (
- plan.resolved &&
- GlutenConfig.getConf.enableGluten &&
- GlutenConfig.getConf.enableRewriteDateTimestampComparison
+ plan.resolved && glutenConf.enableGluten && glutenConf.enableRewriteDateTimestampComparison
  ) {
  visitPlan(plan)
  } else {

@@ -37,15 +37,15 @@ import org.apache.spark.sql.types._
  // Under ch backend, the StringType can be directly converted into DateType,
  // and the functions `from_unixtime` and `unix_timestamp` can be optimized here.
  // Optimized result is `to_date(stringType)`
- class RewriteToDateExpresstionRule(session: SparkSession, conf: SQLConf)
+ class RewriteToDateExpresstionRule(spark: SparkSession)
  extends Rule[LogicalPlan]
  with Logging {

+ private val glutenConf = new GlutenConfig(spark)
+
  override def apply(plan: LogicalPlan): LogicalPlan = {
  if (
- plan.resolved &&
- GlutenConfig.getConf.enableGluten &&
- GlutenConfig.getConf.enableCHRewriteDateConversion
+ plan.resolved && glutenConf.enableGluten && glutenConf.enableCHRewriteDateConversion
  ) {
  visitPlan(plan)
  } else {

@@ -31,12 +31,15 @@ import org.apache.spark.sql.types._
  * @param spark
  */
  case class CHAggregateFunctionRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+
+ private val glutenConf = new GlutenConfig(spark)
+
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
  case a: Aggregate =>
  a.transformExpressions {
  case avgExpr @ AggregateExpression(avg: Average, _, _, _, _)
- if GlutenConfig.getConf.enableCastAvgAggregateFunction &&
- GlutenConfig.getConf.enableColumnarHashAgg &&
+ if glutenConf.enableCastAvgAggregateFunction &&
+ glutenConf.enableColumnarHashAgg &&
  !avgExpr.isDistinct && isDataTypeNeedConvert(avg.child.dataType) =>
  AggregateExpression(
  avg.copy(child = Cast(avg.child, DoubleType)),

@@ -80,11 +80,11 @@ private object VeloxRuleApi {
  SparkShimLoader.getSparkShims
  .getExtendedColumnarPostRules()
  .foreach(each => injector.injectPost(c => each(c.session)))
- injector.injectPost(c => ColumnarCollapseTransformStages(c.conf))
+ injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf))

  // Gluten columnar: Final rules.
  injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
- injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session))
+ injector.injectFinal(c => GlutenFallbackReporter(c.glutenConf, c.session))
  injector.injectFinal(_ => RemoveFallbackTagRule())
  }

@@ -116,9 +116,9 @@ private object VeloxRuleApi {
  SparkShimLoader.getSparkShims
  .getExtendedColumnarPostRules()
  .foreach(each => injector.inject(c => each(c.session)))
- injector.inject(c => ColumnarCollapseTransformStages(c.conf))
+ injector.inject(c => ColumnarCollapseTransformStages(c.glutenConf))
  injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session))
- injector.inject(c => GlutenFallbackReporter(c.conf, c.session))
+ injector.inject(c => GlutenFallbackReporter(c.glutenConf, c.session))
  injector.inject(_ => RemoveFallbackTagRule())
  }
  }

@@ -131,7 +131,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
  }

  override protected def doValidateInternal(): ValidationResult = {
- if (!GlutenConfig.getConf.enableColumnarPartialProject) {
+ if (!glutenConf.enableColumnarPartialProject) {
  return ValidationResult.failed("Config disable this feature")
  }
  if (UDFAttrNotExists) {
@@ -159,7 +159,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
  if (
  ExpressionUtils.hasComplexExpressions(
  original,
- GlutenConfig.getConf.fallbackExpressionsThreshold)
+ glutenConf.fallbackExpressionsThreshold)
  ) {
  return ValidationResult.failed("Fallback by complex expression")
  }

@@ -48,7 +48,7 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBas
  val numInputRows = longMetric("numInputRows")
  val numOutputBatches = longMetric("numOutputBatches")
  val convertTime = longMetric("convertTime")
- val numRows = GlutenConfig.getConf.maxBatchSize
+ val numRows = glutenConf.maxBatchSize
  // This avoids calling `schema` in the RDD closure, so that we don't need to include the entire
  // plan (this) in the closure.
  val localSchema = schema
@@ -68,7 +68,7 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBas
  val numInputRows = longMetric("numInputRows")
  val numOutputBatches = longMetric("numOutputBatches")
  val convertTime = longMetric("convertTime")
- val numRows = GlutenConfig.getConf.maxBatchSize
+ val numRows = glutenConf.maxBatchSize
  val mode = BroadcastUtils.getBroadcastMode(outputPartitioning)
  val relation = child.executeBroadcast()
  BroadcastUtils.sparkToVeloxUnsafe(

@@ -26,8 +26,11 @@ import org.apache.spark.sql.catalyst.rules.Rule
  import org.apache.spark.sql.execution.SparkPlan

  case class BloomFilterMightContainJointRewriteRule(spark: SparkSession) extends Rule[SparkPlan] {
+
+ private val glutenConf = new GlutenConfig(spark)
+
  override def apply(plan: SparkPlan): SparkPlan = {
- if (!GlutenConfig.getConf.enableNativeBloomFilter) {
+ if (!glutenConf.enableNativeBloomFilter) {
  return plan
  }
  val out = plan.transformWithSubqueries {

@@ -31,10 +31,13 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
  * To transform regular aggregation to intermediate aggregation that internally enables
  * optimizations such as flushing and abandoning.
  */
- case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] {
+ case class FlushableHashAggregateRule(spark: SparkSession) extends Rule[SparkPlan] {
  import FlushableHashAggregateRule._
+
+ private val glutenConf = new GlutenConfig(spark)
+
  override def apply(plan: SparkPlan): SparkPlan = {
- if (!GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) {
+ if (!glutenConf.enableVeloxFlushablePartialAggregation) {
  return plan
  }
  plan.transformUpWithPruning(_.containsPattern(EXCHANGE)) {

@@ -28,13 +28,16 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXP
  import org.apache.spark.sql.types._

  case class HLLRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+
+ private val glutenConf = new GlutenConfig(spark)
+
  override def apply(plan: LogicalPlan): LogicalPlan = {
  plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) {
  case a: Aggregate =>
  a.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
  case aggExpr @ AggregateExpression(hll: HyperLogLogPlusPlus, _, _, _, _)
- if GlutenConfig.getConf.enableNativeHyperLogLogAggregateFunction &&
- GlutenConfig.getConf.enableColumnarHashAgg &&
+ if glutenConf.enableNativeHyperLogLogAggregateFunction &&
+ glutenConf.enableColumnarHashAgg &&
  isSupportedDataType(hll.child.dataType) =>
  val hllAdapter = HLLAdapter(
  hll.child,

@@ -31,8 +31,6 @@ object ColumnarRuleApplier {
  val session: SparkSession,
  val ac: AdaptiveContext,
  val outputsColumnar: Boolean) {
- val conf: GlutenConfig = {
- new GlutenConfig(Some(session))
- }
+ val glutenConf: GlutenConfig = new GlutenConfig(session)
  }
  }
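
`ColumnarRuleCall` now exposes the per-session config as `glutenConf`, constructed directly from `session` rather than `Some(session)`, and the rule APIs above consume it as `c.glutenConf`. A small sketch of a rule factory written against the new field; `MyPostRule` is hypothetical, while `injectPost` and `c.glutenConf` are taken from the hunks in this commit:

import org.apache.gluten.GlutenConfig
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical post-stage rule that receives the per-session config from the call context,
// in the style of ColumnarCollapseTransformStages(c.glutenConf) above.
case class MyPostRule(glutenConf: GlutenConfig) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan =
    if (glutenConf.enableGluten) plan else plan // placeholder body
}

// Registration inside a rule-API body, where `injector` is in scope:
//   injector.injectPost(c => MyPostRule(c.glutenConf))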

@@ -40,12 +40,12 @@ class GlutenInjector private[injector] (control: InjectorControl) {
  control.disabler().wrapColumnarRule(s => new GlutenColumnarRule(s, applier)))
  }

- private def applier(session: SparkSession): ColumnarRuleApplier = {
- val conf = new GlutenConfig(Some(session))
+ private def applier(spark: SparkSession): ColumnarRuleApplier = {
+ val conf = new GlutenConfig(spark)
  if (conf.enableRas) {
- return ras.createApplier(session)
+ return ras.createApplier(spark)
  }
- legacy.createApplier(session)
+ legacy.createApplier(spark)
  }
  }


@@ -16,7 +16,6 @@
  */
  package org.apache.gluten.execution

- import org.apache.gluten.GlutenConfig
  import org.apache.gluten.backendsapi.BackendsApiManager
  import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
  import org.apache.gluten.extension.ValidationResult
@@ -52,7 +51,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
  val fileFormat: ReadFileFormat

  def getRootFilePaths: Seq[String] = {
- if (GlutenConfig.getConf.scanFileSchemeValidationEnabled) {
+ if (glutenConf.scanFileSchemeValidationEnabled) {
  getRootPathsInternal
  } else {
  Seq.empty

@@ -180,7 +180,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
  }

  override protected def doValidateInternal(): ValidationResult = {
- if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) {
+ if (!glutenConf.broadcastNestedLoopJoinTransformerTransformerEnabled) {
  return ValidationResult.failed(
  s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled")
  }

@@ -132,7 +132,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
  val sparkConf: SparkConf = sparkContext.getConf
  val serializableHadoopConf: SerializableConfiguration = new SerializableConfiguration(
  sparkContext.hadoopConfiguration)
- val numaBindingInfo: GlutenNumaBindingInfo = GlutenConfig.getConf.numaBindingInfo
+ val numaBindingInfo: GlutenNumaBindingInfo = glutenConf.numaBindingInfo

  @transient
  private var wholeStageTransformerContext: Option[WholeStageTransformContext] = None
@@ -277,7 +277,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
  }(
  t =>
  logOnLevel(
- GlutenConfig.getConf.substraitPlanLogLevel,
+ glutenConf.substraitPlanLogLevel,
  s"$nodeName generating the substrait plan took: $t ms."))
  val inputRDDs = new ColumnarInputRDDsWrapper(columnarInputRDDs)
  // Check if BatchScan exists.
(Diffs for the remaining changed files are not rendered in this view.)
