diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
index 58539129..fd77ac65 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -17,7 +17,6 @@ package org.apache.spark.sql.blaze
import java.io.File
import java.util.UUID
-
import org.apache.commons.lang3.reflect.FieldUtils
import org.apache.spark.OneToOneDependency
import org.apache.spark.ShuffleDependency
@@ -98,8 +97,7 @@ import org.apache.spark.sql.execution.blaze.plan._
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
import org.apache.spark.sql.execution.joins.blaze.plan.NativeSortMergeJoinExecProvider
@@ -111,7 +109,6 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.FileSegment
import org.blaze.{protobuf => pb}
-
import com.thoughtworks.enableIf
class ShimsImpl extends Shims with Logging {
@@ -266,8 +263,9 @@ class ShimsImpl extends Shims with Logging {
override def createNativeShuffleExchangeExec(
outputPartitioning: Partitioning,
- child: SparkPlan): NativeShuffleExchangeBase =
- NativeShuffleExchangeExec(outputPartitioning, child)
+ child: SparkPlan,
+ shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase =
+ NativeShuffleExchangeExec(outputPartitioning, child, shuffleOrigin)
override def createNativeSortExec(
sortOrder: Seq[SortOrder],
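
Note: the shim-level change above threads the original exchange's shuffle origin through createNativeShuffleExchangeExec. The parameter is typed as Option[Any] because the ShuffleOrigin type does not exist on Spark 3.0, so the version-agnostic shim API cannot name it, and the None default keeps pre-existing two-argument call sites source-compatible. A minimal, self-contained sketch of that shape (all names here are hypothetical stand-ins, not project code):

    object ShimApiSketch {
      trait ExchangeShims {
        // Mirrors only the parameter shape of createNativeShuffleExchangeExec.
        def createExchange(child: String, shuffleOrigin: Option[Any] = None): String
      }

      def oldCallSite(shims: ExchangeShims): String =
        shims.createExchange("scan")               // pre-existing call sites: origin defaults to None

      def newCallSite(shims: ExchangeShims, origin: Any): String =
        shims.createExchange("scan", Some(origin)) // Spark 3.1+ path introduced by this patch
    }
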
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
index 143785a5..64284814 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.blaze.plan
import scala.collection.mutable
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
-
import org.apache.spark._
import org.apache.spark.rdd.MapPartitionsRDD
import org.apache.spark.rdd.RDD
@@ -37,12 +36,12 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
-
import com.thoughtworks.enableIf
case class NativeShuffleExchangeExec(
override val outputPartitioning: Partitioning,
- override val child: SparkPlan)
+ override val child: SparkPlan,
+ _shuffleOrigin: Option[Any] = None)
extends NativeShuffleExchangeBase(outputPartitioning, child) {
// NOTE: coordinator can be null after serialization/deserialization,
@@ -175,8 +174,10 @@ case class NativeShuffleExchangeExec(
@enableIf(
Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
System.getProperty("blaze.shim")))
- override def shuffleOrigin =
- org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
+ override def shuffleOrigin = {
+ import org.apache.spark.sql.execution.exchange.ShuffleOrigin
+ _shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
+ }
@enableIf(
Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
diff --git a/spark-extension/pom.xml b/spark-extension/pom.xml
index 88763cd3..2616420f 100644
--- a/spark-extension/pom.xml
+++ b/spark-extension/pom.xml
@@ -69,5 +69,10 @@
             <artifactId>scalatest_${scalaVersion}</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>com.thoughtworks.enableIf</groupId>
+            <artifactId>enableif_${scalaVersion}</artifactId>
+            <version>1.2.0</version>
+        </dependency>
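
Note: the new enableif_${scalaVersion} dependency brings the com.thoughtworks.enableIf macro annotation into the spark-extension module, which BlazeConverters (next file) now needs for its version-gated getShuffleOrigin definitions. The annotation's condition is evaluated while the macro expands, that is at build time against the blaze.shim system property supplied to the compiler, and a definition whose condition is false is dropped from the compiled class. A toy illustration of the pattern (not project code):

    import com.thoughtworks.enableIf

    object ShimGatedExample {
      // Both definitions share one signature; only the one whose condition holds for the
      // current -Dblaze.shim=... build survives macro expansion, so there is no clash.
      @enableIf(Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5")
        .contains(System.getProperty("blaze.shim")))
      def shuffleOriginSupported: Boolean = true

      @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
      def shuffleOriginSupported: Boolean = false
    }
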
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
index 4f055933..ad9adf5c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -15,9 +15,10 @@
*/
package org.apache.spark.sql.blaze
+import com.thoughtworks.enableIf
+
import scala.annotation.tailrec
import scala.collection.mutable
-
import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat
import org.apache.spark.SparkEnv
import org.apache.spark.broadcast.Broadcast
@@ -267,9 +268,18 @@ object BlazeConverters extends Logging {
}
Shims.get.createNativeShuffleExchangeExec(
outputPartitioning,
- addRenameColumnsExec(convertedChild))
+ addRenameColumnsExec(convertedChild),
+ getShuffleOrigin(exec))
}
+ @enableIf(
+ Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
+ System.getProperty("blaze.shim")))
+ def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin)
+
+ @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = None
+
def convertFileSourceScanExec(exec: FileSourceScanExec): SparkPlan = {
val (
relation,
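
Note: with the converter change above, the origin of the ShuffleExchangeExec being replaced is captured on Spark 3.1 through 3.5 and re-attached to the native exchange, whereas previously every native shuffle reported the hard-coded ENSURE_REQUIREMENTS; on Spark 3.0, which predates ShuffleOrigin, getShuffleOrigin returns None and behaviour is unchanged. A small sketch of what downstream code can now distinguish, restricted to names that appear in this patch (illustration only):

    import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}

    object OriginSketch {
      // ENSURE_REQUIREMENTS marks planner-inserted shuffles; any other origin carried by
      // the original exchange is now preserved instead of being overwritten.
      def describe(origin: ShuffleOrigin): String = origin match {
        case ENSURE_REQUIREMENTS => "planner-inserted shuffle (pre-patch default for every native shuffle)"
        case other               => s"origin preserved by this patch: $other"
      }
    }
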
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
index 883b459d..7c3ec435 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
@@ -16,7 +16,6 @@
package org.apache.spark.sql.blaze
import java.io.File
-
import org.apache.spark.ShuffleDependency
import org.apache.spark.TaskContext
import org.apache.spark.SparkContext
@@ -135,7 +134,8 @@ abstract class Shims {
def createNativeShuffleExchangeExec(
outputPartitioning: Partitioning,
- child: SparkPlan): NativeShuffleExchangeBase
+ child: SparkPlan,
+ shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase
def createNativeSortExec(
sortOrder: Seq[SortOrder],