Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(shuffle): Propagate shuffle origin to native exchange exec to make AQE rebalance valid #663

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package org.apache.spark.sql.blaze

import java.io.File
import java.util.UUID

import org.apache.commons.lang3.reflect.FieldUtils
import org.apache.spark.OneToOneDependency
import org.apache.spark.ShuffleDependency
Expand Down Expand Up @@ -98,8 +97,7 @@ import org.apache.spark.sql.execution.blaze.plan._
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
import org.apache.spark.sql.execution.joins.blaze.plan.NativeSortMergeJoinExecProvider
Expand All @@ -111,7 +109,6 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.FileSegment
import org.blaze.{protobuf => pb}

import com.thoughtworks.enableIf

class ShimsImpl extends Shims with Logging {
Expand Down Expand Up @@ -266,8 +263,9 @@ class ShimsImpl extends Shims with Logging {

override def createNativeShuffleExchangeExec(
outputPartitioning: Partitioning,
child: SparkPlan): NativeShuffleExchangeBase =
NativeShuffleExchangeExec(outputPartitioning, child)
child: SparkPlan,
shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase =
NativeShuffleExchangeExec(outputPartitioning, child, shuffleOrigin)

override def createNativeSortExec(
sortOrder: Seq[SortOrder],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.blaze.plan
import scala.collection.mutable
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future

import org.apache.spark._
import org.apache.spark.rdd.MapPartitionsRDD
import org.apache.spark.rdd.RDD
Expand All @@ -37,12 +36,12 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter

import com.thoughtworks.enableIf

case class NativeShuffleExchangeExec(
override val outputPartitioning: Partitioning,
override val child: SparkPlan)
override val child: SparkPlan,
_shuffleOrigin: Option[Any] = None)
extends NativeShuffleExchangeBase(outputPartitioning, child) {

// NOTE: coordinator can be null after serialization/deserialization,
Expand Down Expand Up @@ -175,8 +174,10 @@ case class NativeShuffleExchangeExec(
@enableIf(
Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
System.getProperty("blaze.shim")))
override def shuffleOrigin =
org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
override def shuffleOrigin = {
import org.apache.spark.sql.execution.exchange.ShuffleOrigin;
_shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
}

@enableIf(
Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
Expand Down
5 changes: 5 additions & 0 deletions spark-extension/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,10 @@
<artifactId>scalatest_${scalaVersion}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.thoughtworks.enableIf</groupId>
<artifactId>enableif_${scalaVersion}</artifactId>
<version>1.2.0</version>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
*/
package org.apache.spark.sql.blaze

import com.thoughtworks.enableIf

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat
import org.apache.spark.SparkEnv
import org.apache.spark.broadcast.Broadcast
Expand Down Expand Up @@ -267,9 +268,18 @@ object BlazeConverters extends Logging {
}
Shims.get.createNativeShuffleExchangeExec(
outputPartitioning,
addRenameColumnsExec(convertedChild))
addRenameColumnsExec(convertedChild),
getShuffleOrigin(exec))
}

@enableIf(
Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
System.getProperty("blaze.shim")))
def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin)

@enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = None

def convertFileSourceScanExec(exec: FileSourceScanExec): SparkPlan = {
val (
relation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package org.apache.spark.sql.blaze

import java.io.File

import org.apache.spark.ShuffleDependency
import org.apache.spark.TaskContext
import org.apache.spark.SparkContext
Expand Down Expand Up @@ -135,7 +134,8 @@ abstract class Shims {

def createNativeShuffleExchangeExec(
outputPartitioning: Partitioning,
child: SparkPlan): NativeShuffleExchangeBase
child: SparkPlan,
shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase

def createNativeSortExec(
sortOrder: Seq[SortOrder],
Expand Down