This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit 2a16e7b

[NSE-537] Increase partition number adaptively for large SHJ stages (#538)

Closes #537

Authored by zhztheplayer on Nov 29, 2021
1 parent: a5eb496
Showing 11 changed files with 608 additions and 24 deletions.
@@ -26,6 +26,7 @@ public class SplitResult {
   private final long totalBytesWritten;
   private final long totalBytesSpilled;
   private final long[] partitionLengths;
+  private final long[] rawPartitionLengths;
 
   public SplitResult(
       long totalComputePidTime,
@@ -34,14 +35,16 @@ public SplitResult(
       long totalCompressTime,
       long totalBytesWritten,
       long totalBytesSpilled,
-      long[] partitionLengths) {
+      long[] partitionLengths,
+      long[] rawPartitionLengths) {
     this.totalComputePidTime = totalComputePidTime;
     this.totalWriteTime = totalWriteTime;
     this.totalSpillTime = totalSpillTime;
     this.totalCompressTime = totalCompressTime;
     this.totalBytesWritten = totalBytesWritten;
     this.totalBytesSpilled = totalBytesSpilled;
     this.partitionLengths = partitionLengths;
     this.rawPartitionLengths = rawPartitionLengths;
   }
 
   public long getTotalComputePidTime() {
@@ -71,4 +74,8 @@ public long getTotalBytesSpilled() {
   public long[] getPartitionLengths() {
     return partitionLengths;
   }
+
+  public long[] getRawPartitionLengths() {
+    return rawPartitionLengths;
+  }
 }
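
The new rawPartitionLengths field parallels partitionLengths: per the writer-side comment later in this commit, partitionLengths holds the per-partition bytes as written to disk (possibly compressed), while the raw array records pre-compression sizes, which better approximate the in-memory footprint. A hedged illustration with invented values:

```scala
// Illustrative values only; the constructor order follows the diff above.
val result = new SplitResult(
  0L,    // totalComputePidTime
  0L,    // totalWriteTime
  0L,    // totalSpillTime
  0L,    // totalCompressTime
  3072L, // totalBytesWritten
  0L,    // totalBytesSpilled
  Array(1024L, 2048L),  // partitionLengths: on-disk, possibly compressed
  Array(4096L, 8192L))  // rawPartitionLengths: pre-compression sizes
println(result.getRawPartitionLengths.sum) // 12288
```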
@@ -20,10 +20,14 @@ package com.intel.oap
 import java.util
 import java.util.Collections
 import java.util.Objects
 
+import scala.language.implicitConversions
+
 import com.intel.oap.GazellePlugin.GAZELLE_SESSION_EXTENSION_NAME
 import com.intel.oap.GazellePlugin.SPARK_SESSION_EXTS_KEY
 import com.intel.oap.extension.ColumnarOverrides
+import com.intel.oap.extension.{OptimizerOverrides, StrategyOverrides}
+
 import org.apache.spark.SparkConf
 import org.apache.spark.SparkContext
 import org.apache.spark.api.plugin.DriverPlugin
@@ -19,7 +19,9 @@ package com.intel.oap
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
 
 case class GazelleNumaBindingInfo(
     enableNumaBinding: Boolean,
@@ -82,6 +84,15 @@ class GazellePluginConfig(conf: SQLConf) extends Logging {
     conf.getConfString("spark.oap.sql.columnar.forceshuffledhashjoin", "false").toBoolean &&
       enableCpu
 
+  val resizeShuffledHashJoinInputPartitions: Boolean =
+    conf.getConfString("spark.oap.sql.columnar.shuffledhashjoin.resizeinputpartitions", "false")
+      .toBoolean && enableCpu
+
+  // build size limit for shj, per task
+  val shuffledHashJoinBuildSizeLimit: Long =
+    JavaUtils.byteStringAsBytes(
+      conf.getConfString("spark.oap.sql.columnar.shuffledhashjoin.buildsizelimit", "100m"))
+
   // enable or disable columnar sortmergejoin
   // this should be set with preferSortMergeJoin=false
   val enableColumnarSortMergeJoin: Boolean =
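
A minimal usage sketch of the two new knobs. The config keys and the "100m" default come from the diff above; the session-builder setup and app name are illustrative:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("shj-input-resizing") // hypothetical app name
  // Opt in to adaptive repartitioning of oversized SHJ inputs.
  .config("spark.oap.sql.columnar.shuffledhashjoin.resizeinputpartitions", "true")
  // Per-task budget for the build side; parsed by JavaUtils.byteStringAsBytes,
  // so byte-string values such as "100m" or "1g" are accepted.
  .config("spark.oap.sql.columnar.shuffledhashjoin.buildsizelimit", "200m")
  .getOrCreate()
```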
@@ -15,20 +15,33 @@
  * limitations under the License.
  */
 
-package com.intel.oap
+package com.intel.oap.extension
 
+import com.intel.oap.GazellePluginConfig
+import com.intel.oap.GazelleSparkExtensionsInjector
+
 import scala.collection.mutable
+
 import com.intel.oap.execution._
-import com.intel.oap.extension.LocalWindowExec
 import com.intel.oap.extension.columnar.ColumnarGuardRule
 import com.intel.oap.extension.columnar.RowGuard
 import com.intel.oap.sql.execution.RowToArrowColumnarExec
+import org.apache.spark.internal.config._
+import org.apache.spark.{MapOutputStatistics, SparkContext}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.BuildLeft
+import org.apache.spark.sql.catalyst.optimizer.BuildRight
+import org.apache.spark.sql.catalyst.plans.Cross
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.LeftAnti
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.catalyst.plans.LeftSemi
+import org.apache.spark.sql.catalyst.plans.RightOuter
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.ShufflePartitionSpec
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive._
+import org.apache.spark.sql.execution.adaptive.{ShuffleStageInfo, _}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -37,12 +50,13 @@ import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, ColumnarArrowEvalPythonExec}
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.CalendarIntervalType
+import org.apache.spark.util.ShufflePartitionUtils
 
 case class ColumnarPreOverrides() extends Rule[SparkPlan] {
   val columnarConf: GazellePluginConfig = GazellePluginConfig.getSessionConf
   var isSupportAdaptive: Boolean = true
-
 
   def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match {
     case RowGuard(child: CustomShuffleReaderExec) =>
       replaceWithColumnarPlan(child)
@@ -146,15 +160,25 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
         plan.withNewChildren(Seq(child))
       }
     case plan: ShuffledHashJoinExec =>
-      val left = replaceWithColumnarPlan(plan.left)
-      val right = replaceWithColumnarPlan(plan.right)
-      logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+      val maybeOptimized = if (
+        GazellePluginConfig.getSessionConf.resizeShuffledHashJoinInputPartitions &&
+            ShufflePartitionUtils.withCustomShuffleReaders(plan)) {
+        // We are on AQE execution. Try repartitioning inputs
+        // to avoid OOM as ColumnarShuffledHashJoin doesn't spill
+        // input data.
+        ShufflePartitionUtils.reoptimizeShuffledHashJoinInput(plan)
+      } else {
+        plan
+      }
+      val left = replaceWithColumnarPlan(maybeOptimized.left)
+      val right = replaceWithColumnarPlan(maybeOptimized.right)
+      logDebug(s"Columnar Processing for ${maybeOptimized.getClass} is currently supported.")
       ColumnarShuffledHashJoinExec(
-        plan.leftKeys,
-        plan.rightKeys,
-        plan.joinType,
-        plan.buildSide,
-        plan.condition,
+        maybeOptimized.leftKeys,
+        maybeOptimized.rightKeys,
+        maybeOptimized.joinType,
+        maybeOptimized.buildSide,
+        maybeOptimized.condition,
         left,
         right)
     case plan: BroadcastQueryStageExec =>
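
ShufflePartitionUtils itself is among the changed files that did not load on this page, so its exact logic is not visible here. As a hedged sketch of the idea the inline comment describes, assuming the target partition count is derived from build-side stage statistics and the per-task build size limit (function and parameter names are hypothetical):

```scala
// Hypothetical sketch, not the commit's actual implementation.
def targetPartitionCount(
    buildSideBytesTotal: Long,  // e.g. summed from MapOutputStatistics.bytesByPartitionId
    currentPartitions: Int,
    buildSizeLimitPerTask: Long // GazellePluginConfig.shuffledHashJoinBuildSizeLimit
): Int = {
  // Grow the partition count until each task's share of the build side
  // fits the configured limit; never shrink below the current count.
  val needed = math.ceil(buildSideBytesTotal.toDouble / buildSizeLimitPerTask).toInt
  math.max(currentPartitions, needed)
}
```

Under this reading, a 10 GiB build side with the 100m default would push a 20-partition stage up to 103 partitions.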
@@ -18,7 +18,6 @@
 package org.apache.spark.shuffle
 
 import java.io.IOException
-
 import com.google.common.annotations.VisibleForTesting
 import com.intel.oap.GazellePluginConfig
 import com.intel.oap.expression.ConverterUtils
@@ -34,7 +33,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 
 class ColumnarShuffleWriter[K, V](
     shuffleBlockResolver: IndexShuffleBlockResolver,
@@ -78,6 +77,8 @@
 
   private var partitionLengths: Array[Long] = _
 
+  private var rawPartitionLengths: Array[Long] = _
+
   private var firstRecordBatch: Boolean = true
 
   @throws[IOException]
@@ -180,6 +181,7 @@
     writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)
 
     partitionLengths = splitResult.getPartitionLengths
+    rawPartitionLengths = splitResult.getRawPartitionLengths
     try {
       shuffleBlockResolver.writeIndexFileAndCommit(
         dep.shuffleId,
@@ -191,7 +193,12 @@
         logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
       }
     }
-    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
+
+    // FIXME workaround: append the raw (uncompressed) sizes after the (possibly) compressed sizes
+    val unionPartitionLengths = ArrayBuffer[Long]()
+    unionPartitionLengths ++= partitionLengths
+    unionPartitionLengths ++= rawPartitionLengths
+    mapStatus = MapStatus(blockManager.shuffleServerId, unionPartitionLengths.toArray, mapId)
   }
 
   override def stop(success: Boolean): Option[MapStatus] = {
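
Since MapStatus carries a single lengths array, the writer doubles it: for n shuffle partitions, entries [0, n) are the (possibly) compressed on-disk lengths and entries [n, 2n) are the raw lengths. Any reader that knows n can split the array back apart; a minimal sketch of that counterpart (the helper name is illustrative, not part of this commit):

```scala
// Hypothetical decoding helper for the doubled-lengths workaround above.
def splitUnionLengths(
    unionLengths: Array[Long],
    numPartitions: Int): (Array[Long], Array[Long]) = {
  require(unionLengths.length == 2 * numPartitions,
    s"expected ${2 * numPartitions} entries, got ${unionLengths.length}")
  // First half: on-disk (possibly compressed) sizes; second half: raw sizes.
  unionLengths.splitAt(numPartitions)
}
```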