From 0a36f67feada49c1fbf2dcbe3d5fafb07fef1b35 Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Wed, 27 Nov 2024 21:08:39 +0800
Subject: [PATCH 1/3] fix(shuffle): Propagate shuffle origin to native exchange exec to make AQE rebalance valid

---
 .../scala/org/apache/spark/sql/blaze/ShimsImpl.scala   | 10 ++++------
 .../blaze/plan/NativeShuffleExchangeExec.scala         |  9 ++++-----
 .../org/apache/spark/sql/blaze/BlazeConverters.scala   |  3 ++-
 .../main/scala/org/apache/spark/sql/blaze/Shims.scala  |  6 +++---
 4 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
index 58539129..488aa442 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -17,7 +17,6 @@ package org.apache.spark.sql.blaze
 
 import java.io.File
 import java.util.UUID
-
 import org.apache.commons.lang3.reflect.FieldUtils
 import org.apache.spark.OneToOneDependency
 import org.apache.spark.ShuffleDependency
@@ -98,8 +97,7 @@ import org.apache.spark.sql.execution.blaze.plan._
 import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
 import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
 import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ENSURE_REQUIREMENTS, ReusedExchangeExec, ShuffleOrigin}
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeSortMergeJoinExecProvider
@@ -111,7 +109,6 @@ import org.apache.spark.sql.types.StringType
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.storage.FileSegment
 import org.blaze.{protobuf => pb}
-
 import com.thoughtworks.enableIf
 
 class ShimsImpl extends Shims with Logging {
@@ -266,8 +263,9 @@ class ShimsImpl extends Shims with Logging {
 
   override def createNativeShuffleExchangeExec(
       outputPartitioning: Partitioning,
-      child: SparkPlan): NativeShuffleExchangeBase =
-    NativeShuffleExchangeExec(outputPartitioning, child)
+      child: SparkPlan,
+      shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS): NativeShuffleExchangeBase =
+    NativeShuffleExchangeExec(outputPartitioning, child, shuffleOrigin)
 
   override def createNativeSortExec(
       sortOrder: Seq[SortOrder],
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
index 143785a5..fdbe0011 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.blaze.plan
 import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
-
 import org.apache.spark._
 import org.apache.spark.rdd.MapPartitionsRDD
 import org.apache.spark.rdd.RDD
@@ -37,12 +36,13 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
 import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
-
 import com.thoughtworks.enableIf
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}
 
 case class NativeShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
-    override val child: SparkPlan)
+    override val child: SparkPlan,
+    _shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
   extends NativeShuffleExchangeBase(outputPartitioning, child) {
 
   // NOTE: coordinator can be null after serialization/deserialization,
@@ -175,8 +175,7 @@ case class NativeShuffleExchangeExec(
   @enableIf(
     Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
       System.getProperty("blaze.shim")))
-  override def shuffleOrigin =
-    org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
+  override def shuffleOrigin = _shuffleOrigin
 
   @enableIf(
     Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
index 4f055933..7851806d 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -267,7 +267,8 @@ object BlazeConverters extends Logging {
     }
     Shims.get.createNativeShuffleExchangeExec(
       outputPartitioning,
-      addRenameColumnsExec(convertedChild))
+      addRenameColumnsExec(convertedChild),
+      exec.shuffleOrigin)
   }
 
   def convertFileSourceScanExec(exec: FileSourceScanExec): SparkPlan = {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
index 883b459d..941a133c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
@@ -16,7 +16,6 @@ package org.apache.spark.sql.blaze
 
 import java.io.File
-
 import org.apache.spark.ShuffleDependency
 import org.apache.spark.TaskContext
 import org.apache.spark.SparkContext
@@ -34,7 +33,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.blaze.plan._
 import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ENSURE_REQUIREMENTS, ShuffleOrigin}
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.Generator
@@ -135,7 +134,8 @@ abstract class Shims {
 
   def createNativeShuffleExchangeExec(
       outputPartitioning: Partitioning,
-      child: SparkPlan): NativeShuffleExchangeBase
+      child: SparkPlan,
+      shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS): NativeShuffleExchangeBase
 
   def createNativeSortExec(
       sortOrder: Seq[SortOrder],

From f78991166218aa82ade62f13448246493afab326 Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Thu, 28 Nov 2024 20:25:02 +0800
Subject: [PATCH 2/3] fix for spark3

---
 .../org/apache/spark/sql/blaze/ShimsImpl.scala         |  4 ++--
 .../blaze/plan/NativeShuffleExchangeExec.scala         |  6 +++---
 spark-extension/pom.xml                                |  5 +++++
 .../apache/spark/sql/blaze/BlazeConverters.scala       | 13 +++++++++++--
 .../scala/org/apache/spark/sql/blaze/Shims.scala       |  4 ++--
 5 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
index 488aa442..fd77ac65 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -97,7 +97,7 @@ import org.apache.spark.sql.execution.blaze.plan._
 import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
 import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
 import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ENSURE_REQUIREMENTS, ReusedExchangeExec, ShuffleOrigin}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
 import org.apache.spark.sql.execution.joins.blaze.plan.NativeSortMergeJoinExecProvider
@@ -264,7 +264,7 @@ class ShimsImpl extends Shims with Logging {
   override def createNativeShuffleExchangeExec(
       outputPartitioning: Partitioning,
       child: SparkPlan,
-      shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS): NativeShuffleExchangeBase =
+      shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase =
     NativeShuffleExchangeExec(outputPartitioning, child, shuffleOrigin)
 
   override def createNativeSortExec(
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
index fdbe0011..e81a7b33 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
@@ -37,12 +37,12 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
 import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
 import com.thoughtworks.enableIf
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}
+import org.apache.spark.sql.execution.exchange.ShuffleOrigin
 
 case class NativeShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     override val child: SparkPlan,
-    _shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
+    _shuffleOrigin: Option[Any] = None)
   extends NativeShuffleExchangeBase(outputPartitioning, child) {
 
   // NOTE: coordinator can be null after serialization/deserialization,
@@ -175,7 +175,7 @@ case class NativeShuffleExchangeExec(
   @enableIf(
     Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
       System.getProperty("blaze.shim")))
-  override def shuffleOrigin = _shuffleOrigin
+  override def shuffleOrigin = _shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
 
   @enableIf(
     Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
diff --git a/spark-extension/pom.xml b/spark-extension/pom.xml
index 88763cd3..2616420f 100644
--- a/spark-extension/pom.xml
+++ b/spark-extension/pom.xml
@@ -69,5 +69,10 @@
       <artifactId>scalatest_${scalaVersion}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.thoughtworks.enableIf</groupId>
+      <artifactId>enableif_${scalaVersion}</artifactId>
+      <version>1.2.0</version>
+    </dependency>
   </dependencies>
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
index 7851806d..ad9adf5c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -15,9 +15,10 @@
  */
 package org.apache.spark.sql.blaze
 
+import com.thoughtworks.enableIf
+
 import scala.annotation.tailrec
 import scala.collection.mutable
-
 import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat
 import org.apache.spark.SparkEnv
 import org.apache.spark.broadcast.Broadcast
@@ -268,9 +269,17 @@ object BlazeConverters extends Logging {
     Shims.get.createNativeShuffleExchangeExec(
       outputPartitioning,
       addRenameColumnsExec(convertedChild),
-      exec.shuffleOrigin)
+      getShuffleOrigin(exec))
   }
 
+  @enableIf(
+    Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
+      System.getProperty("blaze.shim")))
+  def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin)
+
+  @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+  def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = None
+
   def convertFileSourceScanExec(exec: FileSourceScanExec): SparkPlan = {
     val (
       relation,
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
index 941a133c..7c3ec435 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.blaze.plan._
 import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ENSURE_REQUIREMENTS, ShuffleOrigin}
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.Generator
@@ -135,7 +135,7 @@ abstract class Shims {
   def createNativeShuffleExchangeExec(
       outputPartitioning: Partitioning,
       child: SparkPlan,
-      shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS): NativeShuffleExchangeBase
+      shuffleOrigin: Option[Any] = None): NativeShuffleExchangeBase
 
   def createNativeSortExec(
       sortOrder: Seq[SortOrder],

From e5f0bfc13defc1226ea00a88f8dd319433805cfc Mon Sep 17 00:00:00 2001
From: Junfan Zhang
Date: Thu, 28 Nov 2024 20:32:56 +0800
Subject: [PATCH 3/3] fix2

---
 .../execution/blaze/plan/NativeShuffleExchangeExec.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
index e81a7b33..64284814 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
@@ -37,7 +37,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
 import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
 import com.thoughtworks.enableIf
-import org.apache.spark.sql.execution.exchange.ShuffleOrigin
 
 case class NativeShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
@@ -175,7 +174,10 @@ case class NativeShuffleExchangeExec(
   @enableIf(
     Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
       System.getProperty("blaze.shim")))
-  override def shuffleOrigin = _shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
+  override def shuffleOrigin = {
+    import org.apache.spark.sql.execution.exchange.ShuffleOrigin;
+    _shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
+  }
 
   @enableIf(
     Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(