Fix issue where case class Exchange may break AQE
When AQE is enabled, it only accepts ShuffleExchangeExec and BroadcastExchangeExec, so if we use case classes for ColumnarShuffleExchangeExec and ColumnarBroadcastExchangeExec, AQE will hit an exception.
To fix this, we add shadow classes that extend ShuffleExchangeExec or BroadcastExchangeExec while delegating to the ColumnarShuffleExchangeExec and ColumnarBroadcastExchangeExec implementations.
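
For context, AQE builds query stages by pattern-matching on the concrete exchange types, roughly as in this simplified sketch (not the actual Spark source; nextStageId() is a stand-in for Spark's internal stage-id counter):

def newQueryStage(e: Exchange): QueryStageExec = e match {
  // A subclass of ShuffleExchangeExec still satisfies this case, while a
  // standalone case class such as ColumnarShuffleExchangeExec matches
  // neither branch and fails with a MatchError.
  case s: ShuffleExchangeExec => ShuffleQueryStageExec(nextStageId(), s)
  case b: BroadcastExchangeExec => BroadcastQueryStageExec(nextStageId(), b)
}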

We then use a supportAdaptive check to decide whether to instantiate the shadow class or the case class, so that both DPP and AQE are supported.
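
Condensed from the ColumnarPreOverrides hunk below, that decision looks roughly like this (a sketch, with constructor arguments abridged to the three used in the diff):

val exchange =
  if (isSupportAdaptive)
    // Subclass of ShuffleExchangeExec, so AQE's type checks pass.
    new ColumnarShuffleExchangeAdaptor(
      plan.outputPartitioning, child, plan.canChangeNumPartitions)
  else
    // Case class path; DPP can still pattern-match on it when AQE is off.
    CoalesceBatchesExec(
      ColumnarShuffleExchangeExec(
        plan.outputPartitioning, child, plan.canChangeNumPartitions))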

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
xuechendi committed Jan 13, 2021
1 parent 821ceb9 commit d433318
Showing 4 changed files with 212 additions and 26 deletions.
59 changes: 44 additions & 15 deletions core/src/main/scala/com/intel/oap/ColumnarPlugin.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.SQLConf

case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
val columnarConf = ColumnarPluginConfig.getConf(conf)
var isSupportAdaptive: Boolean = true

def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match {
case RowGuard(child: CustomShuffleReaderExec) =>
@@ -102,8 +103,11 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
val child = replaceWithColumnarPlan(plan.child)
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
if ((child.supportsColumnar || columnarConf.enablePreferColumnar) && columnarConf.enableColumnarShuffle) {
if (SQLConf.get.adaptiveExecutionEnabled) {
ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.canChangeNumPartitions)
if (isSupportAdaptive) {
new ColumnarShuffleExchangeAdaptor(
plan.outputPartitioning,
child,
plan.canChangeNumPartitions)
} else {
CoalesceBatchesExec(
ColumnarShuffleExchangeExec(
@@ -133,7 +137,10 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
case plan: BroadcastExchangeExec =>
val child = replaceWithColumnarPlan(plan.child)
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarBroadcastExchangeExec(plan.mode, child)
if (isSupportAdaptive)
new ColumnarBroadcastExchangeAdaptor(plan.mode, child)
else
ColumnarBroadcastExchangeExec(plan.mode, child)
case plan: BroadcastHashJoinExec =>
if (columnarConf.enableColumnarBroadcastJoin) {
val left = replaceWithColumnarPlan(plan.left)
@@ -179,17 +186,17 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {

case plan: CustomShuffleReaderExec if columnarConf.enableColumnarShuffle =>
plan.child match {
case shuffle: ColumnarShuffleExchangeExec =>
case shuffle: ColumnarShuffleExchangeAdaptor =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
CoalesceBatchesExec(
ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs, plan.description))
case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec) =>
case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeAdaptor) =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
CoalesceBatchesExec(
ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs, plan.description))
case ShuffleQueryStageExec(_, reused: ReusedExchangeExec) =>
reused match {
case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) =>
case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeAdaptor) =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
CoalesceBatchesExec(
ColumnarCustomShuffleReaderExec(
@@ -228,22 +235,20 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
plan.withNewChildren(children)

case p =>
// Here we need to make an exception for operators that use BroadcastExchange
// as one side child but are not BroadcastHashJoin
val children = plan.children.map(replaceWithColumnarPlan)
logDebug(s"Columnar Processing for ${p.getClass} is currently not supported.")
p.withNewChildren(children.map(fallBackBroadcastExchangeOrNot))
}

def fallBackBroadcastQueryStage(curPlan: BroadcastQueryStageExec): BroadcastQueryStageExec = {
curPlan.plan match {
case originalBroadcastPlan: ColumnarBroadcastExchangeExec =>
case originalBroadcastPlan: ColumnarBroadcastExchangeAdaptor =>
BroadcastQueryStageExec(
curPlan.id,
BroadcastExchangeExec(
originalBroadcastPlan.mode,
DataToArrowColumnarExec(originalBroadcastPlan, 1)))
case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) =>
case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeAdaptor) =>
BroadcastQueryStageExec(
curPlan.id,
BroadcastExchangeExec(
@@ -258,11 +263,15 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
case p: ColumnarBroadcastExchangeExec =>
// aqe is disabled
BroadcastExchangeExec(p.mode, DataToArrowColumnarExec(p, 1))
case p: ColumnarBroadcastExchangeAdaptor =>
// aqe is disabled
BroadcastExchangeExec(p.mode, DataToArrowColumnarExec(p, 1))
case p: BroadcastQueryStageExec =>
// aqe is enabled
fallBackBroadcastQueryStage(p)
case other => other
}
def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable }

def apply(plan: SparkPlan): SparkPlan = {
replaceWithColumnarPlan(plan)
@@ -272,18 +281,16 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {

case class ColumnarPostOverrides(conf: SparkConf) extends Rule[SparkPlan] {
val columnarConf = ColumnarPluginConfig.getConf(conf)
var isSupportAdaptive: Boolean = true

def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match {
case plan: RowToColumnarExec =>
val child = replaceWithColumnarPlan(plan.child)
logDebug(s"ColumnarPostOverrides RowToArrowColumnarExec(${child.getClass})")
RowToArrowColumnarExec(child)
case ColumnarToRowExec(child: ColumnarShuffleExchangeExec)
if SQLConf.get.adaptiveExecutionEnabled && columnarConf.enableColumnarShuffle =>
// When AQE is enabled, we need to discard ColumnarToRowExec to avoid an extra transition
// if ColumnarShuffleExchangeExec is the last plan of the query stage.
case ColumnarToRowExec(child: ColumnarShuffleExchangeAdaptor) =>
replaceWithColumnarPlan(child)
case ColumnarToRowExec(child: ColumnarBroadcastExchangeExec) =>
case ColumnarToRowExec(child: ColumnarBroadcastExchangeAdaptor) =>
replaceWithColumnarPlan(child)
case ColumnarToRowExec(child: CoalesceBatchesExec) =>
plan.withNewChildren(Seq(replaceWithColumnarPlan(child.child)))
@@ -292,6 +299,8 @@ case class ColumnarPostOverrides(conf: SparkConf) extends Rule[SparkPlan] {
p.withNewChildren(children)
}

def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable }

def apply(plan: SparkPlan): SparkPlan = {
replaceWithColumnarPlan(plan)
}
@@ -306,9 +315,28 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit
val preOverrides = ColumnarPreOverrides(conf)
val postOverrides = ColumnarPostOverrides(conf)
val collapseOverrides = ColumnarCollapseCodegenStages(conf)
var isSupportAdaptive: Boolean = true

private def supportAdaptive(plan: SparkPlan): Boolean = {
// TODO migrate dynamic-partition-pruning onto adaptive execution.
// Only QueryStage will have Exchange as Leaf Plan
val isLeafPlanExchange = plan match {
case e: Exchange => true
case other => false
}
isLeafPlanExchange || (sanityCheck(plan) &&
!plan.logicalLink.exists(_.isStreaming) &&
!plan.expressions.exists(_.find(_.isInstanceOf[DynamicPruningSubquery]).isDefined) &&
plan.children.forall(supportAdaptive))
}

private def sanityCheck(plan: SparkPlan): Boolean =
plan.logicalLink.isDefined

override def preColumnarTransitions: Rule[SparkPlan] = plan => {
if (columnarEnabled) {
isSupportAdaptive = supportAdaptive(plan)
preOverrides.setAdaptiveSupport(isSupportAdaptive)
preOverrides(rowGuardOverrides(plan))
} else {
plan
@@ -317,6 +345,7 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit

override def postColumnarTransitions: Rule[SparkPlan] = plan => {
if (columnarEnabled) {
postOverrides.setAdaptiveSupport(isSupportAdaptive)
val tmpPlan = postOverrides(plan)
collapseOverrides(tmpPlan)
} else {
@@ -71,7 +71,7 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
s"ColumnarBroadcastExchange only support HashRelationMode")
}
@transient
private lazy val promise = Promise[broadcast.Broadcast[Any]]()
lazy val promise = Promise[broadcast.Broadcast[Any]]()

@transient
lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
@@ -211,12 +211,12 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
}
}

override protected def doPrepare(): Unit = {
override def doPrepare(): Unit = {
// Materialize the future.
relationFuture
}

override protected def doExecute(): RDD[InternalRow] = {
override def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException(
"BroadcastExchange does not support the execute() code path.")
}
@@ -240,3 +240,52 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
}

}

class ColumnarBroadcastExchangeAdaptor(mode: BroadcastMode, child: SparkPlan)
extends BroadcastExchangeExec(mode, child) {
val plan: ColumnarBroadcastExchangeExec = new ColumnarBroadcastExchangeExec(mode, child)

override def supportsColumnar = true
override def nodeName: String = plan.nodeName
override def output: Seq[Attribute] = plan.output

private[sql] override val runId: UUID = plan.runId

override def outputPartitioning: Partitioning = plan.outputPartitioning

override def doCanonicalize(): SparkPlan = plan.doCanonicalize()

@transient
private val timeout: Long = SQLConf.get.broadcastTimeout

override lazy val metrics = plan.metrics

val buildKeyExprs: Seq[Expression] = plan.buildKeyExprs

@transient
private lazy val promise = plan.promise

@transient
lazy override val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
plan.completionFuture

@transient
private[sql] override lazy val relationFuture
: java.util.concurrent.Future[broadcast.Broadcast[Any]] =
plan.relationFuture

override protected def doPrepare(): Unit = plan.doPrepare()

override protected def doExecute(): RDD[InternalRow] = plan.doExecute()

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] =
plan.doExecuteBroadcast[T]()

// Note: compare against the broadcast adaptor type, not the shuffle adaptor.
override def canEqual(other: Any): Boolean = other.isInstanceOf[ColumnarBroadcastExchangeAdaptor]

override def equals(other: Any): Boolean = other match {
case that: ColumnarBroadcastExchangeAdaptor =>
(that canEqual this) && super.equals(that)
case _ => false
}
}
@@ -66,7 +66,7 @@ case class ColumnarShuffleExchangeExec(
canChangeNumPartitions: Boolean = true)
extends Exchange {

private lazy val writeMetrics =
private[sql] lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private[sql] lazy val readMetrics =
SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
@@ -92,7 +92,7 @@ case class ColumnarShuffleExchangeExec(
super.stringArgs ++ Iterator(s"[id=#$id]")
//super.stringArgs ++ Iterator(output.map(o => s"${o}#${o.dataType.simpleString}"))

private val serializer: Serializer = new ArrowColumnarBatchSerializer(
val serializer: Serializer = new ArrowColumnarBatchSerializer(
longMetric("avgReadBatchNumRows"),
longMetric("numOutputRows"))

@@ -129,8 +129,8 @@ case class ColumnarShuffleExchangeExec(
longMetric("compressTime"))
}

private var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
protected override def doExecute(): RDD[InternalRow] = {
var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
override def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException()
}
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
@@ -156,6 +156,110 @@

}

class ColumnarShuffleExchangeAdaptor(
override val outputPartitioning: Partitioning,
child: SparkPlan,
canChangeNumPartitions: Boolean = true)
extends ShuffleExchangeExec(outputPartitioning, child, canChangeNumPartitions) {

private[sql] lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private[sql] override lazy val readMetrics =
SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics: Map[String, SQLMetric] = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"bytesSpilled" -> SQLMetrics.createSizeMetric(sparkContext, "shuffle bytes spilled"),
"computePidTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_computepid"),
"splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_split"),
"spillTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle spill time"),
"compressTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_compress"),
"avgReadBatchNumRows" -> SQLMetrics
.createAverageMetric(sparkContext, "avg read batch num rows"),
"numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics
.createMetric(sparkContext, "number of output rows")) ++ readMetrics ++ writeMetrics

override def nodeName: String = "ColumnarExchange"
override def output: Seq[Attribute] = child.output

override def supportsColumnar: Boolean = true

override def stringArgs =
super.stringArgs ++ Iterator(s"[id=#$id]")
//super.stringArgs ++ Iterator(output.map(o => s"${o}#${o.dataType.simpleString}"))

val serializer: Serializer = new ArrowColumnarBatchSerializer(
longMetric("avgReadBatchNumRows"),
longMetric("numOutputRows"))

@transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar()

// 'mapOutputStatisticsFuture' is only needed when AQE is enabled.
@transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
if (inputColumnarRDD.getNumPartitions == 0) {
Future.successful(null)
} else {
sparkContext.submitMapStage(columnarShuffleDependency)
}
}

/**
* A [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
@transient
lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
ColumnarShuffleExchangeExec.prepareShuffleDependency(
inputColumnarRDD,
child.output,
outputPartitioning,
serializer,
writeMetrics,
longMetric("dataSize"),
longMetric("bytesSpilled"),
longMetric("numInputRows"),
longMetric("computePidTime"),
longMetric("splitTime"),
longMetric("spillTime"),
longMetric("compressTime"))
}

var cachedShuffleRDD: ShuffledColumnarBatchRDD = _
override def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException()
}
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
if (cachedShuffleRDD == null) {
cachedShuffleRDD = new ShuffledColumnarBatchRDD(columnarShuffleDependency, readMetrics)
}
cachedShuffleRDD
}

// 'shuffleDependency' is only needed when AQE is enabled; columnar shuffle uses 'columnarShuffleDependency' instead.
@transient
override lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] =
new ShuffleDependency[Int, InternalRow, InternalRow](
_rdd = new ColumnarShuffleExchangeExec.DummyPairRDDWithPartitions(
sparkContext,
inputColumnarRDD.getNumPartitions),
partitioner = columnarShuffleDependency.partitioner) {

override val shuffleId: Int = columnarShuffleDependency.shuffleId

override val shuffleHandle: ShuffleHandle = columnarShuffleDependency.shuffleHandle
}

override def canEqual(other: Any): Boolean = other.isInstanceOf[ColumnarShuffleExchangeAdaptor]

override def equals(other: Any): Boolean = other match {
case that: ColumnarShuffleExchangeAdaptor =>
(that canEqual this) && super.equals(that)
case _ => false
}

}

object ColumnarShuffleExchangeExec extends Logging {

class DummyPairRDDWithPartitions(@transient private val sc: SparkContext, numPartitions: Int)