diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index be2bdf87d11..26f4eef6e7a 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -149,6 +149,8 @@ protected Compressor initialValue() {
 
   private final boolean dataPushFailureTrackingEnabled;
 
+  private final boolean shuffleWriteLimitEnabled;
+
   public static class ReduceFileGroups {
     public Map> partitionGroups;
     public Map pushFailedBatches;
@@ -211,6 +213,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
     }
     authEnabled = conf.authEnabledOnClient();
     dataPushFailureTrackingEnabled = conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled();
+    shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled();
 
     // init rpc env
     rpcEnv =
@@ -1067,6 +1070,10 @@ public int pushOrMergeData(
       Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length);
       System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length);
 
+      if (shuffleWriteLimitEnabled) {
+        pushState.addWrittenBytes(body.length);
+      }
+
      if (doPush) {
        // check limit
        limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort());
@@ -1789,6 +1796,8 @@ private void mapEndInternal(
     long[] bytesPerPartition =
         pushState.getBytesWrittenPerPartition(shuffleIntegrityCheckEnabled, numPartitions);
 
+    long bytesWritten = pushState.getBytesWritten();
+
     MapperEndResponse response =
         lifecycleManagerRef.askSync(
             new MapperEnd(
@@ -1801,7 +1810,8 @@ private void mapEndInternal(
                 numPartitions,
                 crc32PerPartition,
                 bytesPerPartition,
-                SerdeVersion.V1),
+                SerdeVersion.V1,
+                bytesWritten),
             rpcMaxRetries,
             rpcRetryWait,
             ClassTag$.MODULE$.apply(MapperEndResponse.class));
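Worth noting what the new counter accumulates: `pushState.addWrittenBytes(body.length)` is called with the full serialized batch, so the reported `bytesWritten` covers the batch header plus the (possibly compressed) payload rather than raw record bytes. Below is a minimal standalone sketch of that accounting; the 16-byte header size is an assumption inferred from the four header ints written above, not something stated in this diff.

```scala
// Illustrative sketch only (not Celeborn code): per-batch accounting in the spirit of
// pushState.addWrittenBytes(body.length) in ShuffleClientImpl.
object WrittenBytesSketch {
  // Assumption: a fixed 16-byte batch header (four ints at offsets 0/4/8/12).
  private val BatchHeaderSize = 16

  def main(args: Array[String]): Unit = {
    val batchPayloadSizes = Seq(1000, 2500, 400) // hypothetical compressed batch payloads
    var bytesWritten = 0L
    batchPayloadSizes.foreach { payload =>
      bytesWritten += BatchHeaderSize + payload // header + payload, i.e. body.length
    }
    println(s"bytesWritten reported in MapperEnd would be $bytesWritten")
  }
}
```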
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index f48f3cd72ba..f288f7911af 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -23,7 +23,7 @@ import java.security.SecureRandom
 import java.util
 import java.util.{function, List => JList}
 import java.util.concurrent._
-import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder}
 import java.util.function.{BiConsumer, BiFunction, Consumer}
 
 import scala.collection.JavaConverters._
@@ -132,6 +132,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
   private val mockDestroyFailure = conf.testMockDestroySlotsFailure
   private val authEnabled = conf.authEnabledOnClient
   private var applicationMeta: ApplicationMeta = _
+
+  private val shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled
+  private val shuffleWriteLimitThreshold = conf.shuffleWriteLimitThreshold
+  private val shuffleTotalWrittenBytes = JavaUtils.newConcurrentHashMap[Int, AtomicLong]()
+
   @VisibleForTesting
   def workerSnapshots(shuffleId: Int): util.Map[String, ShufflePartitionLocationInfo] =
     shuffleAllocatedWorkers.get(shuffleId)
@@ -439,7 +444,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
           numPartitions,
           crc32PerPartition,
           bytesWrittenPerPartition,
-          serdeVersion) =>
+          serdeVersion,
+          bytesWritten) =>
       logTrace(s"Received MapperEnd TaskEnd request, " +
         s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
       val partitionType = getPartitionType(shuffleId)
@@ -455,7 +461,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
             numPartitions,
             crc32PerPartition,
             bytesWrittenPerPartition,
-            serdeVersion)
+            serdeVersion,
+            bytesWritten)
         case PartitionType.MAP =>
           handleMapPartitionEnd(
             context,
@@ -933,7 +940,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
       numPartitions: Int,
       crc32PerPartition: Array[Int],
       bytesWrittenPerPartition: Array[Long],
-      serdeVersion: SerdeVersion): Unit = {
+      serdeVersion: SerdeVersion,
+      bytesWritten: Long): Unit = {
     val (mapperAttemptFinishedSuccess, allMapperFinished) =
       commitManager.finishMapperAttempt(
@@ -945,6 +953,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
         numPartitions = numPartitions,
         crc32PerPartition = crc32PerPartition,
         bytesWrittenPerPartition = bytesWrittenPerPartition)
+
+    if (mapperAttemptFinishedSuccess && shuffleWriteLimitEnabled) {
+      handleShuffleWriteLimitCheck(shuffleId, bytesWritten)
+      logDebug(s"Shuffle $shuffleId, mapId: $mapId, attemptId: $attemptId, " +
+        s"map written bytes: $bytesWritten, shuffle total written bytes: ${shuffleTotalWrittenBytes.get(
+          shuffleId).get()}, write limit threshold: $shuffleWriteLimitThreshold")
+    }
+
     if (mapperAttemptFinishedSuccess && allMapperFinished) {
       // last mapper finished. call mapper end
       logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" +
@@ -1258,6 +1274,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
         logInfo(s"[handleUnregisterShuffle] Wait for handleStageEnd complete costs ${cost}ms")
       }
     }
+
+    if (shuffleWriteLimitEnabled) {
+      shuffleTotalWrittenBytes.remove(shuffleId)
+    }
   }
 
   // add shuffleKey to delay shuffle removal set
@@ -2081,4 +2101,25 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
   }
 
   def getShuffleIdMapping = shuffleIdMapping
+
+  private def handleShuffleWriteLimitCheck(shuffleId: Int, writtenBytes: Long): Unit = {
+    if (!shuffleWriteLimitEnabled || shuffleWriteLimitThreshold <= 0) return
+
+    if (writtenBytes > 0) {
+      val totalBytesAccumulator =
+        shuffleTotalWrittenBytes.computeIfAbsent(shuffleId, (id: Int) => new AtomicLong(0))
+      val currentTotalBytes = totalBytesAccumulator.addAndGet(writtenBytes)
+
+      if (currentTotalBytes > shuffleWriteLimitThreshold) {
+        val reason =
+          s"Shuffle $shuffleId exceeded write limit threshold: current total ${currentTotalBytes} bytes, max allowed ${shuffleWriteLimitThreshold} bytes"
+        logError(reason)
+
+        cancelShuffleCallback match {
+          case Some(c) => c.accept(shuffleId, reason)
+          case _ => None
+        }
+      }
+    }
+  }
 }
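The driver-side logic above boils down to a per-shuffle running total plus a comparison against the threshold on every successful mapper attempt. The following self-contained sketch shows that pattern with the same `ConcurrentHashMap`/`AtomicLong` combination; the object name, threshold value, and the plain `println` callback are illustrative only, not Celeborn APIs.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import java.util.function.BiConsumer

// Simplified sketch of handleShuffleWriteLimitCheck: accumulate bytes per shuffle
// and fire a cancel callback once the running total crosses the threshold.
object WriteLimitCheckSketch {
  private val totals = new ConcurrentHashMap[Int, AtomicLong]()
  private val thresholdBytes = 2000L

  private val cancelCallback: BiConsumer[Integer, String] =
    (shuffleId: Integer, reason: String) => println(s"cancel shuffle $shuffleId: $reason")

  def onMapperEnd(shuffleId: Int, bytesWritten: Long): Unit = {
    if (bytesWritten <= 0) return
    val total = totals
      .computeIfAbsent(shuffleId, (_: Int) => new AtomicLong(0))
      .addAndGet(bytesWritten)
    if (total > thresholdBytes) {
      cancelCallback.accept(Int.box(shuffleId), s"total $total bytes exceeds $thresholdBytes bytes")
    }
  }

  def main(args: Array[String]): Unit = {
    onMapperEnd(0, 1500) // below the threshold, nothing happens
    onMapperEnd(0, 1000) // running total 2500 > 2000, the callback fires
  }
}
```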
diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
index 46714c4e826..03daf2cc32a 100644
--- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java
+++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java
@@ -40,10 +40,13 @@ public class PushState {
 
   private final Map failedBatchMap;
 
+  private long bytesWritten;
+
   public PushState(CelebornConf conf) {
     pushBufferMaxSize = conf.clientPushBufferMaxSize();
     inFlightRequestTracker = new InFlightRequestTracker(conf, this);
     failedBatchMap = JavaUtils.newConcurrentHashMap();
+    bytesWritten = 0;
   }
 
   public void cleanup() {
@@ -136,4 +139,12 @@ public void addDataWithOffsetAndLength(int partitionId, byte[] data, int offset,
         commitMetadataMap.computeIfAbsent(partitionId, id -> new CommitMetadata());
     commitMetadata.addDataWithOffsetAndLength(data, offset, length);
   }
+
+  public void addWrittenBytes(int length) {
+    bytesWritten += length;
+  }
+
+  public long getBytesWritten() {
+    return bytesWritten;
+  }
 }
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index a813a9e5015..2807d3f5464 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -377,6 +377,7 @@ message PbMapperEnd {
   int32 numPartitions = 7;
   repeated int32 crc32PerPartition = 8;
   repeated int64 bytesWrittenPerPartition = 9;
+  int64 bytesWritten = 10;
 }
 
 message PbLocationPushFailedBatches {
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 6b21abed540..949dff6affd 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1670,6 +1670,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def secretRedactionPattern = get(SECRET_REDACTION_PATTERN)
 
   def containerInfoProviderClass = get(CONTAINER_INFO_PROVIDER)
+
+  def shuffleWriteLimitEnabled: Boolean = get(SHUFFLE_WRITE_LIMIT_ENABLED)
+
+  def shuffleWriteLimitThreshold: Long = get(SHUFFLE_WRITE_LIMIT_THRESHOLD)
 }
 
 object CelebornConf extends Logging {
@@ -6826,4 +6830,19 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(false)
 
+  val SHUFFLE_WRITE_LIMIT_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.client.spark.shuffle.write.limit.enabled")
+      .categories("client")
+      .doc("Whether to enable the shuffle write limit check. When enabled, a shuffle whose " +
+        "total written bytes exceed `celeborn.client.spark.shuffle.write.limit.threshold` " +
+        "will be canceled to prevent cluster resource exhaustion.")
+      .version("0.7.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val SHUFFLE_WRITE_LIMIT_THRESHOLD: ConfigEntry[Long] =
+    buildConf("celeborn.client.spark.shuffle.write.limit.threshold")
+      .categories("client")
+      .doc("Threshold of total bytes written by a single shuffle. Once a shuffle writes " +
+        "more than this threshold, it will be canceled as oversized.")
+      .version("0.7.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("5TB")
 }
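For Spark users, these two client settings would typically be supplied through the Spark configuration. The sketch below assumes the usual `spark.celeborn.` prefix for passing Celeborn client configs through Spark; that prefix, the app name, and the 10TB value are illustrative and not taken from this diff.

```scala
// Hedged usage sketch: enabling the write limit from a Spark job. Adjust the
// "spark.celeborn." prefix to however Celeborn is wired into your deployment
// (the same keys can typically go into spark-defaults.conf).
import org.apache.spark.sql.SparkSession

object EnableShuffleWriteLimit {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("celeborn-write-limit-demo")
      .config("spark.celeborn.client.spark.shuffle.write.limit.enabled", "true")
      // cancel any shuffle whose total pushed bytes exceed 10TB instead of the 5TB default
      .config("spark.celeborn.client.spark.shuffle.write.limit.threshold", "10TB")
      .getOrCreate()

    // ... run shuffle-heavy queries; an oversized shuffle is canceled by the LifecycleManager
    spark.stop()
  }
}
```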
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 36f164d697e..a56a0b1bad3 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -218,7 +218,8 @@ object ControlMessages extends Logging {
       numPartitions: Int,
       crc32PerPartition: Array[Int],
       bytesWrittenPerPartition: Array[Long],
-      serdeVersion: SerdeVersion)
+      serdeVersion: SerdeVersion,
+      bytesWritten: Long)
     extends MasterMessage
 
   case class ReadReducerPartitionEnd(
@@ -737,7 +738,8 @@ object ControlMessages extends Logging {
           numPartitions,
           crc32PerPartition,
           bytesWrittenPerPartition,
-          serdeVersion) =>
+          serdeVersion,
+          bytesWritten) =>
       val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
         val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
         (k, resultValue)
@@ -753,6 +755,7 @@ object ControlMessages extends Logging {
          .addAllCrc32PerPartition(crc32PerPartition.map(Integer.valueOf).toSeq.asJava)
          .addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map(
            java.lang.Long.valueOf).toSeq.asJava)
+         .setBytesWritten(bytesWritten)
          .build().toByteArray
       new TransportMessage(MessageType.MAPPER_END, payload, serdeVersion)
 
@@ -1248,7 +1251,8 @@ object ControlMessages extends Logging {
         pbMapperEnd.getNumPartitions,
         crc32Array,
         bytesWrittenPerPartitionArray,
-        message.getSerdeVersion)
+        message.getSerdeVersion,
+        pbMapperEnd.getBytesWritten)
 
     case READ_REDUCER_PARTITION_END_VALUE =>
       val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload)
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
index 8be472b6447..26d06777e25 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala
@@ -160,7 +160,8 @@ class UtilsSuite extends CelebornFunSuite {
       1,
       Array.emptyIntArray,
       Array.emptyLongArray,
-      SerdeVersion.V1)
+      SerdeVersion.V1,
+      1)
     val mapperEndTrans =
       Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
     assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
@@ -172,6 +173,7 @@ class UtilsSuite extends CelebornFunSuite {
     assert(mapperEnd.numPartitions == mapperEndTrans.numPartitions)
     mapperEnd.crc32PerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.crc32PerPartition
     mapperEnd.bytesWrittenPerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.bytesWrittenPerPartition
+    assert(mapperEnd.bytesWritten == mapperEndTrans.bytesWritten)
   }
 
   test("validate HDFS compatible fs path") {
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index fcd28ec2c9c..81f06d8814a 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -137,6 +137,8 @@ license: |
 | celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always use spark built-in shuffle implementation. This configuration is deprecated, consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 0.3.0 | celeborn.shuffle.forceFallback.enabled |
 | celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 |  |
 | celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 |  |
+| celeborn.client.spark.shuffle.write.limit.enabled | false | false | Whether to enable the shuffle write limit check. When enabled, a shuffle whose total written bytes exceed `celeborn.client.spark.shuffle.write.limit.threshold` will be canceled to prevent cluster resource exhaustion. | 0.7.0 |  |
+| celeborn.client.spark.shuffle.write.limit.threshold | 5TB | false | Threshold of total bytes written by a single shuffle. Once a shuffle writes more than this threshold, it will be canceled as oversized. | 0.7.0 |  |
 | celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer |
 | celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure |
 | celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | celeborn.quota.identity.provider |
diff --git a/tests/spark-it/pom.xml b/tests/spark-it/pom.xml
index a87594f5ce2..733b71ea8eb 100644
--- a/tests/spark-it/pom.xml
+++ b/tests/spark-it/pom.xml
@@ -187,6 +187,11 @@
       minio
       test
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala
index d34c79419b8..ce79262c717 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala
@@ -18,7 +18,10 @@
 package org.apache.celeborn.tests.client
 
 import java.util
+import java.util.Collections
+import java.util.function.BiConsumer
 
+import org.mockito.Mockito
 import org.scalatest.concurrent.Eventually.eventually
 import org.scalatest.concurrent.Futures.{interval, timeout}
 import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
@@ -26,7 +29,10 @@ import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.network.protocol.SerdeVersion
+import org.apache.celeborn.common.protocol.message.ControlMessages.MapperEnd
 import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.rpc.RpcCallContext
 import org.apache.celeborn.service.deploy.MiniClusterFeature
 
 class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeature {
@@ -126,6 +132,68 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu
     }
   }
 
+  test("CELEBORN-2264: Support cancel shuffle when write bytes exceeds threshold") {
+    val conf = celebornConf.clone
+    conf.set(CelebornConf.SHUFFLE_WRITE_LIMIT_ENABLED.key, "true")
+      .set(CelebornConf.SHUFFLE_WRITE_LIMIT_THRESHOLD.key, "2000")
+    val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
+    val ctx = Mockito.mock(classOf[RpcCallContext])
+
+    // Custom BiConsumer callback to track if cancelShuffle is invoked
+    var isCancelShuffleInvoked = false
+    val cancelShuffleCallback = new BiConsumer[Integer, String] {
+      override def accept(shuffleId: Integer, reason: String): Unit = {
+        isCancelShuffleInvoked = true
+      }
+    }
+    lifecycleManager.registerCancelShuffleCallback(cancelShuffleCallback)
+
+    // Scenario 1: Same mapper with multiple attempts (total bytes exceed threshold but no cancel)
+    val shuffleId = 0
+    val mapId1 = 0
+    lifecycleManager.receiveAndReply(ctx)(MapperEnd(
+      shuffleId = shuffleId,
+      mapId = mapId1,
+      attemptId = 0,
+      2,
+      1,
+      Collections.emptyMap(),
+      1,
+      Array.emptyIntArray,
+      Array.emptyLongArray,
+      SerdeVersion.V1,
+      bytesWritten = 1500))
+    lifecycleManager.receiveAndReply(ctx)(MapperEnd(
+      shuffleId = shuffleId,
+      mapId = mapId1,
+      attemptId = 1,
+      2,
+      1,
+      Collections.emptyMap(),
+      1,
+      Array.emptyIntArray,
+      Array.emptyLongArray,
+      SerdeVersion.V1,
+      bytesWritten = 1500))
+    assert(!isCancelShuffleInvoked)
+
+    // Scenario 2: Total bytes of mapId1 + mapId2 exceed threshold (trigger cancel)
+    val mapId2 = 1
+    lifecycleManager.receiveAndReply(ctx)(MapperEnd(
+      shuffleId = shuffleId,
+      mapId = mapId2,
+      attemptId = 0,
+      2,
+      1,
+      Collections.emptyMap(),
+      1,
+      Array.emptyIntArray,
+      Array.emptyLongArray,
+      SerdeVersion.V1,
+      bytesWritten = 1000))
+    assert(isCancelShuffleInvoked)
+  }
+
   override def afterAll(): Unit = {
     logInfo("all test complete , stop celeborn mini cluster")
     shutdownMiniCluster()
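The test above also shows the integration point an engine needs: `registerCancelShuffleCallback` takes a `java.util.function.BiConsumer[Integer, String]` that receives the shuffle id and the reason once the threshold is crossed. The sketch below mirrors that wiring; the app id, threshold value, and the callback body are illustrative, and what "cancel" ultimately means (for example, failing the Spark stage) is left to the engine-side integration.

```scala
import java.util.function.BiConsumer

import org.apache.celeborn.client.LifecycleManager
import org.apache.celeborn.common.CelebornConf

// Sketch of registering a cancel handler on a LifecycleManager, mirroring the test above.
object RegisterCancelCallbackSketch {
  def main(args: Array[String]): Unit = {
    val conf = new CelebornConf()
      .set(CelebornConf.SHUFFLE_WRITE_LIMIT_ENABLED.key, "true")
      .set(CelebornConf.SHUFFLE_WRITE_LIMIT_THRESHOLD.key, "1TB")
    val lifecycleManager = new LifecycleManager("demo-app", conf)
    lifecycleManager.registerCancelShuffleCallback(new BiConsumer[Integer, String] {
      override def accept(shuffleId: Integer, reason: String): Unit = {
        // e.g. propagate the failure so the engine aborts the oversized job
        println(s"shuffle $shuffleId canceled: $reason")
      }
    })
  }
}
```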