diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
index 58539129..fd77ac65 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -17,7 +17,6 @@ package org.apache.spark.sql.blaze
 
 import java.io.File
 import java.util.UUID
-
 import org.apache.commons.lang3.reflect.FieldUtils
 import org.apache.spark.OneToOneDependency
 import org.apache.spark.ShuffleDependency
@@ -98,8 +97,7 @@ import org.apache.spark.sql.execution.blaze.plan._
 import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
 import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
 import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeSortMergeJoinExecProvider
@@ -111,7 +109,6 @@ import org.apache.spark.sql.types.StringType
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.storage.FileSegment
 import org.blaze.{protobuf => pb}
-
 import com.thoughtworks.enableIf
 
 class ShimsImpl extends Shims with Logging {
@@ -266,8 +263,9 @@ class ShimsImpl extends Shims with Logging {
 
   override def createNativeShuffleExchangeExec(
       outputPartitioning: Partitioning,
-      child: SparkPlan): NativeShuffleExchangeBase =
-    NativeShuffleExchangeExec(outputPartitioning, child)
+      child: SparkPlan,
+      shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase =
+    NativeShuffleExchangeExec(outputPartitioning, child, shuffleOrigin)
 
   override def createNativeSortExec(
       sortOrder: Seq[SortOrder],
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
index 143785a5..64284814 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.blaze.plan
 import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
-
 import org.apache.spark._
 import org.apache.spark.rdd.MapPartitionsRDD
 import org.apache.spark.rdd.RDD
@@ -37,12 +36,12 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
 import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
-
 import com.thoughtworks.enableIf
 
 case class NativeShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
-    override val child: SparkPlan)
+    override val child: SparkPlan,
+    _shuffleOrigin: Option[Any] = None)
     extends NativeShuffleExchangeBase(outputPartitioning, child) {
 
   // NOTE: coordinator can be null after serialization/deserialization,
@@ -175,8 +174,10 @@ case class NativeShuffleExchangeExec(
   @enableIf(
     Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
       System.getProperty("blaze.shim")))
-  override def shuffleOrigin =
-    org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
+  override def shuffleOrigin = {
+    import org.apache.spark.sql.execution.exchange.ShuffleOrigin;
+    _shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
+  }
 
   @enableIf(
     Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
diff --git a/spark-extension/pom.xml b/spark-extension/pom.xml
index 88763cd3..2616420f 100644
--- a/spark-extension/pom.xml
+++ b/spark-extension/pom.xml
@@ -69,5 +69,10 @@
       <artifactId>scalatest_${scalaVersion}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.thoughtworks.enableIf</groupId>
+      <artifactId>enableif_${scalaVersion}</artifactId>
+      <version>1.2.0</version>
+    </dependency>
   </dependencies>
 </project>
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
index 4f055933..ad9adf5c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -15,9 +15,10 @@
  */
 package org.apache.spark.sql.blaze
 
+import com.thoughtworks.enableIf
+
 import scala.annotation.tailrec
 import scala.collection.mutable
-
 import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat
 import org.apache.spark.SparkEnv
 import org.apache.spark.broadcast.Broadcast
@@ -267,9 +268,18 @@ object BlazeConverters extends Logging {
     }
     Shims.get.createNativeShuffleExchangeExec(
       outputPartitioning,
-      addRenameColumnsExec(convertedChild))
+      addRenameColumnsExec(convertedChild),
+      getShuffleOrigin(exec))
   }
 
+  @enableIf(
+    Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
+      System.getProperty("blaze.shim")))
+  def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin)
+
+  @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+  def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = None
+
   def convertFileSourceScanExec(exec: FileSourceScanExec): SparkPlan = {
     val (
       relation,
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
index 883b459d..7c3ec435 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
@@ -16,7 +16,6 @@
 package org.apache.spark.sql.blaze
 
 import java.io.File
-
 import org.apache.spark.ShuffleDependency
 import org.apache.spark.TaskContext
 import org.apache.spark.SparkContext
@@ -135,7 +134,8 @@ abstract class Shims {
 
   def createNativeShuffleExchangeExec(
       outputPartitioning: Partitioning,
-      child: SparkPlan): NativeShuffleExchangeBase
+      child: SparkPlan,
+      shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase
 
   def createNativeSortExec(
       sortOrder: Seq[SortOrder],
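Note: the shared Shims API declares the new parameter as Option[Any] rather than ShuffleOrigin because ShuffleOrigin does not exist in Spark 3.0, so the version-neutral interface cannot name the type; only code that @enableIf compiles exclusively for the Spark 3.1+ shims unwraps and casts it back. A minimal sketch of that pattern follows, assuming the same "blaze.shim" system property is set at compile time; the ShuffleOriginShimSketch object and its unwrap methods are hypothetical illustrations, not part of the patch.

    import com.thoughtworks.enableIf

    // Hypothetical helper (not part of the patch): shows how a version-neutral
    // Option[Any] is narrowed back to ShuffleOrigin only in shims where the
    // type exists. The @enableIf predicate is evaluated during compilation, so
    // the spark-3.0 build never references ShuffleOrigin at all.
    object ShuffleOriginShimSketch {
      @enableIf(
        Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
          System.getProperty("blaze.shim")))
      def unwrap(origin: Option[Any]): org.apache.spark.sql.execution.exchange.ShuffleOrigin =
        origin.get.asInstanceOf[org.apache.spark.sql.execution.exchange.ShuffleOrigin]

      // Spark 3.0 path: there is no ShuffleOrigin type, so nothing to unwrap.
      @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
      def unwrap(origin: Option[Any]): Unit = ()
    }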