-
Notifications
You must be signed in to change notification settings - Fork 418
[CELEBORN-2264] Support cancel shuffle when write bytes exceeds threshold #3601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ import java.security.SecureRandom | |
| import java.util | ||
| import java.util.{function, List => JList} | ||
| import java.util.concurrent._ | ||
| import java.util.concurrent.atomic.{AtomicInteger, LongAdder} | ||
| import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder} | ||
| import java.util.function.{BiConsumer, BiFunction, Consumer} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
|
|
@@ -132,6 +132,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends | |
| private val mockDestroyFailure = conf.testMockDestroySlotsFailure | ||
| private val authEnabled = conf.authEnabledOnClient | ||
| private var applicationMeta: ApplicationMeta = _ | ||
|
|
||
| private val shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled | ||
| private val shuffleWriteLimitThreshold = conf.shuffleWriteLimitThreshold | ||
| private val shuffleTotalWrittenBytes = JavaUtils.newConcurrentHashMap[Int, AtomicLong]() | ||
|
|
||
| @VisibleForTesting | ||
| def workerSnapshots(shuffleId: Int): util.Map[String, ShufflePartitionLocationInfo] = | ||
| shuffleAllocatedWorkers.get(shuffleId) | ||
|
|
@@ -439,7 +444,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends | |
| numPartitions, | ||
| crc32PerPartition, | ||
| bytesWrittenPerPartition, | ||
| serdeVersion) => | ||
| serdeVersion, | ||
| bytesWritten) => | ||
| logTrace(s"Received MapperEnd TaskEnd request, " + | ||
| s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}") | ||
| val partitionType = getPartitionType(shuffleId) | ||
|
|
@@ -455,7 +461,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends | |
| numPartitions, | ||
| crc32PerPartition, | ||
| bytesWrittenPerPartition, | ||
| serdeVersion) | ||
| serdeVersion, | ||
| bytesWritten) | ||
| case PartitionType.MAP => | ||
| handleMapPartitionEnd( | ||
| context, | ||
|
|
@@ -933,7 +940,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends | |
| numPartitions: Int, | ||
| crc32PerPartition: Array[Int], | ||
| bytesWrittenPerPartition: Array[Long], | ||
| serdeVersion: SerdeVersion): Unit = { | ||
| serdeVersion: SerdeVersion, | ||
| bytesWritten: Long): Unit = { | ||
|
|
||
| val (mapperAttemptFinishedSuccess, allMapperFinished) = | ||
| commitManager.finishMapperAttempt( | ||
|
|
@@ -945,6 +953,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends | |
| numPartitions = numPartitions, | ||
| crc32PerPartition = crc32PerPartition, | ||
| bytesWrittenPerPartition = bytesWrittenPerPartition) | ||
|
|
||
| if (mapperAttemptFinishedSuccess && shuffleWriteLimitEnabled) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we consider adding a negative test case for this feature? |
||
| handleShuffleWriteLimitCheck(shuffleId, bytesWritten) | ||
| logDebug(s"Shuffle $shuffleId, mapId: $mapId, attemptId: $attemptId, " + | ||
| s"map written bytes: $bytesWritten, shuffle total written bytes: ${shuffleTotalWrittenBytes.get( | ||
| shuffleId).get()}, write limit threshold: $shuffleWriteLimitThreshold") | ||
| } | ||
|
|
||
| if (mapperAttemptFinishedSuccess && allMapperFinished) { | ||
| // last mapper finished. call mapper end | ||
| logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" + | ||
|
|
@@ -1258,6 +1274,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends | |
| logInfo(s"[handleUnregisterShuffle] Wait for handleStageEnd complete costs ${cost}ms") | ||
| } | ||
| } | ||
|
|
||
| if (shuffleWriteLimitEnabled) { | ||
| shuffleTotalWrittenBytes.remove(shuffleId) | ||
| } | ||
| } | ||
|
|
||
| // add shuffleKey to delay shuffle removal set | ||
|
|
@@ -2081,4 +2101,25 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends | |
| } | ||
|
|
||
| def getShuffleIdMapping = shuffleIdMapping | ||
|
|
||
| /** | ||
|  * Adds the bytes reported by a finished mapper to the running total of its shuffle and | ||
|  * requests cancellation of the shuffle once that total exceeds `shuffleWriteLimitThreshold`. | ||
|  * | ||
|  * Does nothing when the write limit is disabled, when the threshold is not positive, or when | ||
|  * `writtenBytes` is not positive. | ||
|  * | ||
|  * @param shuffleId    shuffle whose accumulated total is updated | ||
|  * @param writtenBytes bytes to add to that total | ||
|  */ | ||
| private def handleShuffleWriteLimitCheck(shuffleId: Int, writtenBytes: Long): Unit = { | ||
|   val limitActive = shuffleWriteLimitEnabled && shuffleWriteLimitThreshold > 0 | ||
|
|
||
|   if (limitActive && writtenBytes > 0) { | ||
|     // The per-shuffle counter is created on first use. | ||
|     val totalBytesAccumulator = | ||
|       shuffleTotalWrittenBytes.computeIfAbsent(shuffleId, (id: Int) => new AtomicLong(0)) | ||
|     val currentTotalBytes = totalBytesAccumulator.addAndGet(writtenBytes) | ||
|
|
||
|     if (currentTotalBytes > shuffleWriteLimitThreshold) { | ||
|       val reason = | ||
|         s"Shuffle $shuffleId exceeded write limit threshold: current total ${currentTotalBytes} bytes, max allowed ${shuffleWriteLimitThreshold} bytes" | ||
|       logError(reason) | ||
|
|
||
|       // The cancel callback is optional; nothing is invoked when none is registered. | ||
|       cancelShuffleCallback.foreach(_.accept(shuffleId, reason)) | ||
|     } | ||
|   } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -218,7 +218,8 @@ object ControlMessages extends Logging { | |
| numPartitions: Int, | ||
| crc32PerPartition: Array[Int], | ||
| bytesWrittenPerPartition: Array[Long], | ||
| serdeVersion: SerdeVersion) | ||
| serdeVersion: SerdeVersion, | ||
| bytesWritten: Long) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IIUC, this patch changes the protocol. How is this handled?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's not a big deal, since MapperEnd is only used on the engine side. |
||
| extends MasterMessage | ||
|
|
||
| case class ReadReducerPartitionEnd( | ||
|
|
@@ -737,7 +738,8 @@ object ControlMessages extends Logging { | |
| numPartitions, | ||
| crc32PerPartition, | ||
| bytesWrittenPerPartition, | ||
| serdeVersion) => | ||
| serdeVersion, | ||
| bytesWritten) => | ||
| val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) => | ||
| val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v) | ||
| (k, resultValue) | ||
|
|
@@ -753,6 +755,7 @@ object ControlMessages extends Logging { | |
| .addAllCrc32PerPartition(crc32PerPartition.map(Integer.valueOf).toSeq.asJava) | ||
| .addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map( | ||
| java.lang.Long.valueOf).toSeq.asJava) | ||
| .setBytesWritten(bytesWritten) | ||
| .build().toByteArray | ||
| new TransportMessage(MessageType.MAPPER_END, payload, serdeVersion) | ||
|
|
||
|
|
@@ -1248,7 +1251,8 @@ object ControlMessages extends Logging { | |
| pbMapperEnd.getNumPartitions, | ||
| crc32Array, | ||
| bytesWrittenPerPartitionArray, | ||
| message.getSerdeVersion) | ||
| message.getSerdeVersion, | ||
| pbMapperEnd.getBytesWritten) | ||
|
|
||
| case READ_REDUCER_PARTITION_END_VALUE => | ||
| val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should clean up the shuffleId-related data when the shuffle expires.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes.