6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}

/**
* :: DeveloperApi ::
@@ -65,6 +65,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
* @param keyOrdering key ordering for RDD's shuffles
* @param aggregator map/reduce-side aggregator for RDD's shuffle
* @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
* @param shuffleWriterProcessor the processor that controls the write behavior in ShuffleMapTask
*/
@DeveloperApi
class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@@ -73,7 +74,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val serializer: Serializer = SparkEnv.get.serializer,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
val mapSideCombine: Boolean = false,
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor)
extends Dependency[Product2[K, V]] {

if (mapSideCombine) {
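For orientation, here is a hedged sketch (not part of this diff) of how code inside the Spark tree could pass the new constructor parameter; the helper name and the RDD are hypothetical, and ShuffleWriteProcessor is private[spark], so this is not user-facing API. Omitting the argument falls back to the default new ShuffleWriteProcessor.

package org.apache.spark  // ShuffleWriteProcessor is private[spark], so this must live in-tree

import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriteProcessor

// Hypothetical helper, for illustration only: builds a dependency whose map tasks
// write through the supplied processor instead of the default one.
object ShuffleDependencyExample {
  def dependencyWith(
      pairs: RDD[(Int, String)],
      customProcessor: ShuffleWriteProcessor): ShuffleDependency[Int, String, String] = {
    new ShuffleDependency[Int, String, String](
      pairs,
      new HashPartitioner(4),
      shuffleWriterProcessor = customProcessor)
  }
}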
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -92,25 +92,7 @@ private[spark] class ShuffleMapTask(
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L

var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](
dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
case e: Exception =>
try {
if (writer != null) {
writer.stop(success = false)
}
} catch {
case e: Exception =>
log.debug("Could not stop writer", e)
}
throw e
}
dep.shuffleWriterProcessor.writeProcess(rdd, dep, partitionId, context, partition)
}

override def preferredLocations: Seq[TaskLocation] = preferredLocs
core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.MapStatus

/**
* The interface for customizing the shuffle write process. The driver creates a
* ShuffleWriteProcessor and puts it into [[ShuffleDependency]], and executors use it in
* each ShuffleMapTask.
*/
private[spark] class ShuffleWriteProcessor extends Serializable with Logging {

/**
* Create a [[ShuffleWriteMetricsReporter]] from the task context. Because the reporter is
* invoked per row, performance needs careful consideration here.
*/
protected def createMetricsReporter(context: TaskContext): ShuffleWriteMetricsReporter = {
context.taskMetrics().shuffleWriteMetrics
}

/**
* The write process for a particular partition. It controls the life cycle of the
* [[ShuffleWriter]] obtained from the [[ShuffleManager]], triggers the RDD computation, and
* finally returns the [[MapStatus]] for this task.
*/
def writeProcess(
rdd: RDD[_],
dep: ShuffleDependency[_, _, _],
partitionId: Int,
context: TaskContext,
partition: Partition): MapStatus = {
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](
dep.shuffleHandle,
partitionId,
context,
createMetricsReporter(context))
writer.write(
rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
case e: Exception =>
try {
if (writer != null) {
writer.stop(success = false)
}
} catch {
case e: Exception =>
log.debug("Could not stop writer", e)
}
throw e
}
}
}
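As a usage sketch (hypothetical, mirroring the pattern the SQL module adopts further down in this diff), a subclass only needs to override createMetricsReporter to decorate the per-task reporter; the writeProcess logic above stays untouched.

package org.apache.spark.shuffle

import org.apache.spark.TaskContext

// Illustrative subclass: returns the default task-level reporter unchanged.
// A real implementation (e.g. the SQL one below) would wrap it in a decorator
// that also feeds operator-scoped metrics.
private[spark] class PassThroughShuffleWriteProcessor extends ShuffleWriteProcessor {
  override protected def createMetricsReporter(
      context: TaskContext): ShuffleWriteMetricsReporter = {
    context.taskMetrics().shuffleWriteMetrics
  }
}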
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -214,6 +214,9 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"),

// [SPARK-26139] Implement shuffle write metrics in SQL
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"),

// Data Source V2 API changes
(problem: Problem) => problem match {
case MissingClassProblem(cls) =>
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -23,14 +23,15 @@ import java.util.function.Supplier
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleMetricsReporter}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.MutablePair
@@ -46,10 +47,13 @@ case class ShuffleExchangeExec(

// NOTE: coordinator can be null after serialization/deserialization,
// e.g. it can be null on the Executor side

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")
) ++ SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
) ++ readMetrics ++ writeMetrics

override def nodeName: String = {
val extraInfo = coordinator match {
@@ -90,7 +94,11 @@ case class ShuffleExchangeExec(
private[exchange] def prepareShuffleDependency()
: ShuffleDependency[Int, InternalRow, InternalRow] = {
ShuffleExchangeExec.prepareShuffleDependency(
child.execute(), child.output, newPartitioning, serializer)
child.execute(),
child.output,
newPartitioning,
serializer,
writeMetrics)
}

/**
@@ -109,7 +117,7 @@
assert(newPartitioning.isInstanceOf[HashPartitioning])
newPartitioning = UnknownPartitioning(indices.length)
}
new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices)
new ShuffledRowRDD(shuffleDependency, readMetrics, specifiedPartitionStartIndices)
}

/**
@@ -204,7 +212,9 @@ object ShuffleExchangeExec {
rdd: RDD[InternalRow],
outputAttributes: Seq[Attribute],
newPartitioning: Partitioning,
serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
serializer: Serializer,
writeMetrics: Map[String, SQLMetric])
: ShuffleDependency[Int, InternalRow, InternalRow] = {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
@@ -333,8 +343,22 @@
new ShuffleDependency[Int, InternalRow, InternalRow](
rddWithPartitionIds,
new PartitionIdPassthrough(part.numPartitions),
serializer)
serializer,
shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics))

dependency
}

/**
* Create a customized [[ShuffleWriteProcessor]] for SQL that wraps the default metrics reporter
* in a [[SQLShuffleWriteMetricsReporter]] and uses it as the reporter of the returned
* [[ShuffleWriteProcessor]].
*/
def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor = {
new ShuffleWriteProcessor {
override protected def createMetricsReporter(
context: TaskContext): ShuffleWriteMetricsReporter = {
new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics)
}
}
}
}
30 changes: 23 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter
import org.apache.spark.sql.execution.metric.{SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter}

/**
* Take the first `limit` elements and collect them to a single partition.
@@ -38,13 +38,21 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
override def outputPartitioning: Partitioning = SinglePartition
override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics = readMetrics ++ writeMetrics
protected override def doExecute(): RDD[InternalRow] = {
val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
val shuffled = new ShuffledRowRDD(
ShuffleExchangeExec.prepareShuffleDependency(
locallyLimited, child.output, SinglePartition, serializer),
metrics)
locallyLimited,
child.output,
SinglePartition,
serializer,
writeMetrics),
readMetrics)
shuffled.mapPartitionsInternal(_.take(limit))
}
}
@@ -154,7 +162,11 @@ case class TakeOrderedAndProjectExec(

private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)

override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
private lazy val readMetrics =
SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics = readMetrics ++ writeMetrics

protected override def doExecute(): RDD[InternalRow] = {
val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
@@ -165,8 +177,12 @@
}
val shuffled = new ShuffledRowRDD(
ShuffleExchangeExec.prepareShuffleDependency(
localTopK, child.output, SinglePartition, serializer),
metrics)
localTopK,
child.output,
SinglePartition,
serializer,
writeMetrics),
readMetrics)
shuffled.mapPartitions { iter =>
val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
if (projectList != child.output) {
sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.metric
import java.text.NumberFormat
import java.util.Locale

import scala.concurrent.duration._

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
@@ -78,6 +80,7 @@ object SQLMetrics {
private val SUM_METRIC = "sum"
private val SIZE_METRIC = "size"
private val TIMING_METRIC = "timing"
private val NS_TIMING_METRIC = "nsTiming"
private val AVERAGE_METRIC = "average"

private val baseForAvgMetric: Int = 10
@@ -121,6 +124,13 @@
acc
}

def createNanoTimingMetric(sc: SparkContext, name: String): SQLMetric = {
// Same as createTimingMetric, but the value is recorded in nanoseconds and normalized to
// milliseconds when displayed.
val acc = new SQLMetric(NS_TIMING_METRIC, -1)
acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = false)
acc
}

/**
* Create a metric to report the average information (including min, med, max) like
* avg hash probe. As average metrics are double values, this kind of metrics should be
@@ -163,6 +173,8 @@
Utils.bytesToString
} else if (metricsType == TIMING_METRIC) {
Utils.msDurationToString
} else if (metricsType == NS_TIMING_METRIC) {
duration => Utils.msDurationToString(duration.nanos.toMillis)
} else {
throw new IllegalStateException("unexpected metrics type: " + metricsType)
}
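To make the nanosecond normalization concrete, a small sketch of what the new nsTiming branch does with a raw value (the numbers are illustrative):

import scala.concurrent.duration._

// A shuffle write time accumulated in nanoseconds by the writer...
val writeTimeNs = 2500000L
// ...is converted to milliseconds before reusing the existing ms formatter,
// so 2,500,000 ns renders the same way a 2 ms timing metric would.
val displayedMs = writeTimeNs.nanos.toMillis  // 2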
sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric

import org.apache.spark.SparkContext
import org.apache.spark.executor.TempShuffleReadMetrics
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter

/**
* A shuffle metrics reporter for SQL exchange operators.
@@ -95,3 +96,57 @@
FETCH_WAIT_TIME -> SQLMetrics.createTimingMetric(sc, "fetch wait time"),
RECORDS_READ -> SQLMetrics.createMetric(sc, "records read"))
}

/**
* A shuffle write metrics reporter for SQL exchange operators.
* @param metricsReporter The underlying reporter that also needs to be updated by this
*                        SQLShuffleWriteMetricsReporter.
* @param metrics Shuffle write metrics of the current SparkPlan.
*/
private[spark] class SQLShuffleWriteMetricsReporter(
metricsReporter: ShuffleWriteMetricsReporter,
metrics: Map[String, SQLMetric]) extends ShuffleWriteMetricsReporter {
private[this] val _bytesWritten =
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_BYTES_WRITTEN)
private[this] val _recordsWritten =
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN)
private[this] val _writeTime =
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)

override private[spark] def incBytesWritten(v: Long): Unit = {
metricsReporter.incBytesWritten(v)
_bytesWritten.add(v)
}
override private[spark] def decRecordsWritten(v: Long): Unit = {
metricsReporter.decRecordsWritten(v)
_recordsWritten.set(_recordsWritten.value - v)
}
override private[spark] def incRecordsWritten(v: Long): Unit = {
metricsReporter.incRecordsWritten(v)
_recordsWritten.add(v)
}
override private[spark] def incWriteTime(v: Long): Unit = {
metricsReporter.incWriteTime(v)
_writeTime.add(v)
}
override private[spark] def decBytesWritten(v: Long): Unit = {
metricsReporter.decBytesWritten(v)
_bytesWritten.set(_bytesWritten.value - v)
}
}

private[spark] object SQLShuffleWriteMetricsReporter {
val SHUFFLE_BYTES_WRITTEN = "shuffleBytesWritten"
val SHUFFLE_RECORDS_WRITTEN = "shuffleRecordsWritten"
val SHUFFLE_WRITE_TIME = "shuffleWriteTime"

/**
* Create all shuffle-write-related metrics and return the Map.
*/
def createShuffleWriteMetrics(sc: SparkContext): Map[String, SQLMetric] = Map(
SHUFFLE_BYTES_WRITTEN ->
SQLMetrics.createSizeMetric(sc, "shuffle bytes written"),
SHUFFLE_RECORDS_WRITTEN ->
SQLMetrics.createMetric(sc, "shuffle records written"),
SHUFFLE_WRITE_TIME ->
SQLMetrics.createNanoTimingMetric(sc, "shuffle write time"))
}