@@ -25,6 +25,7 @@ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
2525import org .apache .spark .sql .catalyst .plans ._
2626import org .apache .spark .sql .catalyst .rules .Rule
2727import org .apache .spark .sql .execution ._
28+ import org .apache .spark .sql .execution .aggregate .HashAggregateExec
2829import org .apache .spark .sql .execution .exchange .{EnsureRequirements , ShuffleExchangeExec }
2930import org .apache .spark .sql .execution .joins .SortMergeJoinExec
3031import org .apache .spark .sql .internal .SQLConf
@@ -130,13 +131,15 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
130131 }
131132 }
132133
// Diff hunk for canSplitLeftSide: the helper gains a second parameter, `plan: SparkPlan`.
// Removed version (`-` lines): returns true when joinType is one of
//   Inner, Cross, LeftSemi, LeftAnti or LeftOuter.
// Added version (`+` lines): keeps the same join-type test and additionally requires that
//   `plan.find(_.isInstanceOf[HashAggregateExec])` yields nothing, i.e. no HashAggregateExec
//   is found in `plan` (presumably a search over the whole plan subtree — confirm against
//   TreeNode.find).
// NOTE(review): only HashAggregateExec is matched here; whether other aggregate operators
//   need the same guard cannot be determined from this hunk — confirm.
133- private def canSplitLeftSide (joinType : JoinType ) = {
134- joinType == Inner || joinType == Cross || joinType == LeftSemi ||
135- joinType == LeftAnti || joinType == LeftOuter
134+ private def canSplitLeftSide (joinType : JoinType , plan : SparkPlan ) = {
135+ (joinType == Inner || joinType == Cross || joinType == LeftSemi ||
136+ joinType == LeftAnti || joinType == LeftOuter ) &&
137+ plan.find(_.isInstanceOf [HashAggregateExec ]).isEmpty
136138 }
137139
// Diff hunk for canSplitRightSide: the helper gains a second parameter, `plan: SparkPlan`.
// Removed version (`-` lines): returns true when joinType is Inner, Cross or RightOuter.
// Added version (`+` lines): keeps the same join-type test and additionally requires that
//   `plan.find(_.isInstanceOf[HashAggregateExec])` yields nothing, i.e. no HashAggregateExec
//   is found in `plan` (presumably a search over the whole plan subtree — confirm against
//   TreeNode.find).
// NOTE(review): the HashAggregateExec check is written out separately in each of the two
//   canSplit* helpers; a shared private predicate would keep them in sync.
138- private def canSplitRightSide (joinType : JoinType ) = {
139- joinType == Inner || joinType == Cross || joinType == RightOuter
140+ private def canSplitRightSide (joinType : JoinType , plan : SparkPlan ) = {
141+ (joinType == Inner || joinType == Cross || joinType == RightOuter ) &&
142+ plan.find(_.isInstanceOf [HashAggregateExec ]).isEmpty
140143 }
141144
142145 private def getSizeInfo (medianSize : Long , sizes : Seq [Long ]): String = {
@@ -199,8 +202,8 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
199202
200203 | ${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
201204 """ .stripMargin)
202- val canSplitLeft = canSplitLeftSide(joinType)
203- val canSplitRight = canSplitRightSide(joinType)
205+ val canSplitLeft = canSplitLeftSide(joinType, s1 )
206+ val canSplitRight = canSplitRightSide(joinType, s2 )
204207 // We use the actual partition sizes (may be coalesced) to calculate target size, so that
205208 // the final data distribution is even (coalesced partitions + split partitions).
206209 val leftActualSizes = left.partitionsWithSizes.map(_._2)
0 commit comments